"""GeneExpressionDataset module."""
import torch
from pytoda.warnings import device_warning
from ..types import GeneList, Optional
from ._table_dataset import _TableEagerDataset, _TableLazyDataset
from .base_dataset import DatasetDelegator
TABLE_DATASET_IMPLEMENTATIONS = {
    'eager': _TableEagerDataset,
    'lazy': _TableLazyDataset,
}


class GeneExpressionDataset(DatasetDelegator):
    """
    Gene expression dataset implementation.
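
    A minimal usage sketch (the filepath below is hypothetical, for
    illustration only)::

        dataset = GeneExpressionDataset('profiles.csv')
        sample = dataset[0]  # torch.Tensor with one value per gene
        len(dataset.gene_list) == dataset.number_of_features  # True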
    """

    def __init__(
        self,
        *gene_expression_filepaths: str,
        gene_list: GeneList = None,
        standardize: bool = True,
        min_max: bool = False,
        processing_parameters: dict = {},
        impute: Optional[float] = 0.0,
        dtype: torch.dtype = torch.float,
        backend: str = 'eager',
        chunk_size: int = 10000,
        device: Optional[torch.device] = None,
        **kwargs,
    ) -> None:
        """
        Initialize a gene expression dataset.

        Args:
            gene_expression_filepaths (Files): paths to .csv files.
                Currently, the only supported format is .csv, with gene
                profiles on rows and gene names as columns.
            gene_list (GeneList): a list of genes to restrict the dataset
                to. Defaults to None, in which case the common subset of
                genes across all files is used.
            standardize (bool): perform data standardization. Defaults to True.
            min_max (bool): perform min-max scaling. Defaults to False.
            processing_parameters (dict): processing parameters.
                Keys can be 'min' and 'max' (for min-max scaling) or 'mean'
                and 'std' (for standardization). Values must be readable by
                `np.array`, and the required order and subset of features
                has to match that determined by the dataset setup (see
                `self.gene_list` after initialization). Defaults to {}.
            impute (Optional[float]): value used to impute NaNs, if
                given. Defaults to 0.0.
            dtype (torch.dtype): data type. Defaults to torch.float.
            backend (str): memory management backend, either 'eager' or
                'lazy'. Defaults to 'eager', preferring speed over memory
                consumption.
            chunk_size (int): size of the chunks read with the 'lazy'
                backend; ignored with the 'eager' backend. Defaults to
                10000.
            device (torch.device): DEPRECATED; if given, only a warning
                is raised.
            kwargs (dict): additional parameters for pd.read_csv.
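
        Example:
            Precomputed `processing_parameters` for standardization; the
            values and feature order are hypothetical and must match
            `self.gene_list` after setup::

                processing_parameters={'mean': [0.1, 2.3], 'std': [0.9, 1.1]}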
        """
        device_warning(device)
        if backend not in TABLE_DATASET_IMPLEMENTATIONS:
            raise RuntimeError(
                f'backend={backend} not supported! Select one of '
                f'[{",".join(TABLE_DATASET_IMPLEMENTATIONS.keys())}]'
            )
        self.dataset = TABLE_DATASET_IMPLEMENTATIONS[backend](
            filepaths=gene_expression_filepaths,
            feature_list=gene_list,
            standardize=standardize,
            min_max=min_max,
            processing_parameters=processing_parameters,
            impute=impute,
            dtype=dtype,
            chunk_size=chunk_size,
            **kwargs,
        )
        # If gene_list was not passed, this is the common subset of genes
        # across the files.
        self.gene_list = self.dataset.feature_list
        self.number_of_features = len(self.gene_list)
        DatasetDelegator.__init__(self)  # delegate to self.dataset
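

# A minimal, hypothetical usage sketch (filepaths are placeholders): the
# 'lazy' backend reads the .csv files in chunks, trading speed for lower
# memory consumption.
#
#     dataset = GeneExpressionDataset(
#         'cohort_a.csv',
#         'cohort_b.csv',
#         backend='lazy',
#         chunk_size=5000,
#     )
#     # Without an explicit gene_list, the common gene subset of both
#     # files is used.
#     print(dataset.number_of_features)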