import numpy as np
from nabu.cuda.processing import CudaProcessing
from ..preproc.flatfield import FlatFieldArrays
from ..utils import get_cuda_srcfile
from ..io.reader import load_images_from_dataurl_dict
from ..cuda.utils import __has_pycuda__
class CudaFlatFieldArrays(FlatFieldArrays):
    """
    Flat-field normalization of radios using a CUDA kernel.

    This class extends `FlatFieldArrays`: all the flat-field parameters are
    forwarded to the parent constructor. On top of that, the CUDA kernel is
    compiled and the flats/darks are uploaded to the GPU once, at construction,
    so that `normalize_radios` only has to launch the kernel.
    """

    def __init__(
        self,
        radios_shape,
        flats,
        darks,
        radios_indices=None,
        interpolation="linear",
        distortion_correction=None,
        nan_value=1.0,
        radios_srcurrent=None,
        flats_srcurrent=None,
        cuda_options=None,
    ):
        """
        Initialize a flat-field normalization CUDA process.
        Please read the documentation of nabu.preproc.flatfield.FlatField for help
        on the parameters.

        Parameters
        -----------
        cuda_options: dict, optional
            Keyword arguments forwarded to `CudaProcessing`.
            `None` (default) means no extra argument.

        Raises
        ------
        NotImplementedError
            If `distortion_correction` is not None (not supported by this backend).
        ValueError
            If `interpolation` is not "linear" (raised when building the kernel).
        """
        # Fail early: nothing is allocated if the request cannot be honoured.
        if distortion_correction is not None:
            raise NotImplementedError("Flats distortion correction is not implemented with the Cuda backend")

        super().__init__(
            radios_shape,
            flats,
            darks,
            radios_indices=radios_indices,
            interpolation=interpolation,
            distortion_correction=distortion_correction,
            radios_srcurrent=radios_srcurrent,
            flats_srcurrent=flats_srcurrent,
            nan_value=nan_value,
        )
        self.cuda_processing = CudaProcessing(**(cuda_options or {}))
        self._init_cuda_kernels()
        self._load_flats_and_darks_on_gpu()

    def _init_cuda_kernels(self):
        """
        Compile the "flatfield_normalization" kernel from "flatfield.cu".

        The number of flats and darks (and the NaN replacement value, when one
        is set) are passed to the compiler as "-D" definitions.

        Raises
        ------
        ValueError
            If the interpolation mode is not "linear".
        """
        # TODO: support other interpolation modes
        if self.interpolation != "linear":
            raise ValueError("Interpolation other than linear is not yet implemented in the cuda back-end")

        self._cuda_fname = get_cuda_srcfile("flatfield.cu")
        options = [
            "-DN_FLATS=%d" % self.n_flats,
            "-DN_DARKS=%d" % self.n_darks,
        ]
        # NAN_VALUE is left undefined when no replacement value is requested
        if self.nan_value is not None:
            options.append("-DNAN_VALUE=%f" % self.nan_value)
        # Signature: 3 pointers, 3 ints, 2 pointers - see the call in normalize_radios()
        self.cuda_kernel = self.cuda_processing.kernel(
            "flatfield_normalization", self._cuda_fname, signature="PPPiiiPP", options=options
        )
        # Kernel size arguments, taken from self.shape[1] and self.shape[0] respectively
        self._nx = np.int32(self.shape[1])
        self._ny = np.int32(self.shape[0])

    def _load_flats_and_darks_on_gpu(self):
        """
        Upload flats, darks and the interpolation tables to the GPU.

        Flats and darks are stored as float32 stacks of shape
        `(n_flats,) + self.shape` and `(n_darks,) + self.shape`, filled in the
        order given by `_sorted_flat_indices` / `_sorted_dark_indices`.
        """
        # Flats
        self.d_flats = self.cuda_processing.allocate_array("d_flats", (self.n_flats,) + self.shape, np.float32)
        for i, flat_idx in enumerate(self._sorted_flat_indices):
            self.d_flats[i].set(np.ascontiguousarray(self.flats[flat_idx], dtype=np.float32))
        # Darks
        self.d_darks = self.cuda_processing.allocate_array("d_darks", (self.n_darks,) + self.shape, np.float32)
        for i, dark_idx in enumerate(self._sorted_dark_indices):
            self.d_darks[i].set(np.ascontiguousarray(self.darks[dark_idx], dtype=np.float32))
        # NOTE: d_darks_indices is uploaded here but is not passed to the kernel in normalize_radios()
        self.d_darks_indices = self.cuda_processing.to_device(
            "d_darks_indices", np.array(self._sorted_dark_indices, dtype=np.int32)
        )
        # Indices and weights of the flats (both are passed to the kernel)
        self.d_flats_indices = self.cuda_processing.to_device("d_flats_indices", self.flats_idx)
        self.d_flats_weights = self.cuda_processing.to_device("d_flats_weights", self.flats_weights)

    def normalize_radios(self, radios):
        """
        Apply a flat-field correction, with the current parameters, to a stack
        of radios.

        Parameters
        -----------
        radios: `pycuda.gpuarray.GPUArray`
            Radios chunk. It must be an instance of the array class of the
            underlying `CudaProcessing`, have dtype float32 and the shape
            given at construction (`radios_shape`).

        Returns
        --------
        radios: `pycuda.gpuarray.GPUArray`
            The array object that was passed as input.

        Raises
        ------
        ValueError
            If `radios` has the wrong type, dtype or shape.
        """
        if not isinstance(radios, self.cuda_processing.array_class):
            raise ValueError("Expected a pycuda.gpuarray (got %s)" % str(type(radios)))
        if radios.dtype != np.float32:
            raise ValueError("radios must be in float32 dtype (got %s)" % str(radios.dtype))
        if radios.shape != self.radios_shape:
            raise ValueError("Expected radios shape = %s but got %s" % (str(self.radios_shape), str(radios.shape)))
        self.cuda_kernel(
            radios,
            self.d_flats,
            self.d_darks,
            self._nx,
            self._ny,
            np.int32(self.n_radios),
            self.d_flats_indices,
            self.d_flats_weights,
        )
        # Optional per-radio scaling by the source current ratios, done after the kernel
        if self.normalize_srcurrent:
            for i in range(self.n_radios):
                radios[i] *= self.srcurrent_ratios[i]
        return radios
 
# Alias: `CudaFlatField` is the same class object as `CudaFlatFieldArrays`.
CudaFlatField = CudaFlatFieldArrays
class CudaFlatFieldDataUrls(CudaFlatField):
    """
    CUDA flat-field normalization where flats and darks are first loaded with
    `load_images_from_dataurl_dict`, then handed to `CudaFlatField`.
    """

    def __init__(
        self,
        radios_shape,
        flats,
        darks,
        radios_indices=None,
        interpolation="linear",
        distortion_correction=None,
        nan_value=1.0,
        radios_srcurrent=None,
        flats_srcurrent=None,
        cuda_options=None,
        **chunk_reader_kwargs,
    ):
        """
        Initialize a flat-field normalization CUDA process.

        `flats` and `darks` are each passed to `load_images_from_dataurl_dict`
        (along with `chunk_reader_kwargs`); the results are given to the
        `CudaFlatField` constructor together with all the other parameters,
        which are forwarded unchanged.

        Parameters
        -----------
        flats, darks:
            Inputs of `load_images_from_dataurl_dict`.
        chunk_reader_kwargs:
            Extra keyword arguments forwarded to `load_images_from_dataurl_dict`
            for both flats and darks.

        Please read the documentation of `CudaFlatField` for the other parameters.
        """
        flats_arrays_dict = load_images_from_dataurl_dict(flats, **chunk_reader_kwargs)
        darks_arrays_dict = load_images_from_dataurl_dict(darks, **chunk_reader_kwargs)
        super().__init__(
            radios_shape,
            flats_arrays_dict,
            darks_arrays_dict,
            radios_indices=radios_indices,
            interpolation=interpolation,
            distortion_correction=distortion_correction,
            # nan_value must be forwarded explicitly, otherwise the parent
            # silently falls back to its own default.
            nan_value=nan_value,
            radios_srcurrent=radios_srcurrent,
            flats_srcurrent=flats_srcurrent,
            cuda_options=cuda_options,
        )