Source code for pytoda.datasets.gene_expression_dataset

"""GeneExpressionDataset module."""
import torch

from pytoda.warnings import device_warning

from ..types import GeneList, Optional
from ._table_dataset import _TableEagerDataset, _TableLazyDataset
from .base_dataset import DatasetDelegator

TABLE_DATASET_IMPLEMENTATIONS = {'eager': _TableEagerDataset, 'lazy': _TableLazyDataset}


class GeneExpressionDataset(DatasetDelegator):
    """Gene expression dataset implementation."""

    def __init__(
        self,
        *gene_expression_filepaths: str,
        gene_list: GeneList = None,
        standardize: bool = True,
        min_max: bool = False,
        processing_parameters: dict = {},
        impute: Optional[float] = 0.0,
        dtype: torch.dtype = torch.float,
        backend: str = 'eager',
        chunk_size: int = 10000,
        device: torch.device = None,
        **kwargs,
    ) -> None:
        """
        Initialize a gene expression dataset.

        Args:
            gene_expression_filepaths (Files): paths to .csv files.
                Currently, the only supported format is .csv, with gene
                profiles on rows and gene names as columns.
            gene_list (GeneList): a list of genes. Defaults to None.
            standardize (bool): perform data standardization. Defaults to
                True.
            min_max (bool): perform min-max scaling. Defaults to False.
            processing_parameters (dict): processing parameters. Keys can be
                'min', 'max' or 'mean', 'std' respectively. Values must be
                readable by `np.array`, and the required order and subset of
                features have to match that determined by the dataset setup
                (see `self.gene_list` after initialization). Defaults to {}.
            impute (Optional[float]): NaN imputation with value if given.
                Defaults to 0.0.
            dtype (torch.dtype): data type. Defaults to torch.float.
            backend (str): memory management backend. Defaults to 'eager',
                preferring speed over memory consumption.
            chunk_size (int): size of the chunks in case of lazy reading;
                ignored with the 'eager' backend. Defaults to 10000.
            device (torch.device): DEPRECATED.
            kwargs (dict): additional parameters for pd.read_csv.
        """
        device_warning(device)
        if backend not in TABLE_DATASET_IMPLEMENTATIONS:
            raise RuntimeError(
                'backend={} not supported! '.format(backend)
                + 'Select one in [{}]'.format(
                    ','.join(TABLE_DATASET_IMPLEMENTATIONS.keys())
                )
            )
        self.dataset = TABLE_DATASET_IMPLEMENTATIONS[backend](
            filepaths=gene_expression_filepaths,
            feature_list=gene_list,
            standardize=standardize,
            min_max=min_max,
            processing_parameters=processing_parameters,
            impute=impute,
            dtype=dtype,
            chunk_size=chunk_size,
            **kwargs,
        )
        # If it was not passed, gene_list is the common subset across files.
        self.gene_list = self.dataset.feature_list
        self.number_of_features = len(self.gene_list)
        DatasetDelegator.__init__(self)  # delegate to self.dataset
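A minimal usage sketch, not part of the module source: it assumes a
hypothetical file 'my_expression.csv' with gene profiles on rows and gene
names as columns, as the docstring requires, and that GeneExpressionDataset
is exported from pytoda.datasets. Since DatasetDelegator forwards indexing
and length to the wrapped table dataset, instances should plug directly into
a standard torch DataLoader.

from torch.utils.data import DataLoader

from pytoda.datasets import GeneExpressionDataset

# Hypothetical filepath; standardization is on by default.
dataset = GeneExpressionDataset(
    'my_expression.csv',
    backend='eager',  # keep data in memory; use 'lazy' for very large files
)
print(dataset.number_of_features)  # genes retained, see dataset.gene_list

# __getitem__/__len__ are delegated to self.dataset, so batches are float
# tensors of shape (batch_size, number_of_features).
loader = DataLoader(dataset, batch_size=32, shuffle=True)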