Source code for pytoda.datasets.drug_sensitivity_dose_dataset

"""Implementation of DrugSensitivityDoseDataset."""
from typing import Callable

import numpy as np
import torch

from pytoda.warnings import device_warning

from ..smiles.smiles_language import SMILESTokenizer
from ..types import DrugSensitivityDoseData, GeneList, Tuple
from .drug_sensitivity_dataset import DrugSensitivityDataset


class DrugSensitivityDoseDataset(DrugSensitivityDataset):
    """
    Drug sensitivity dose dataset implementation.

    Extends DrugSensitivityDataset so that each sample additionally carries
    the (transformed) drug dose at which the viability label was measured.
    """

    def __init__(
        self,
        drug_sensitivity_filepath: str,
        smi_filepath: str,
        gene_expression_filepath: str,
        smiles_language: SMILESTokenizer,
        column_names: Tuple[str] = ('drug', 'cell_line', 'dose', 'viability'),
        dose_transform: Callable[[float], float] = np.log10,
        iterate_dataset: bool = False,
        gene_list: GeneList = None,
        gene_expression_standardize: bool = True,
        gene_expression_min_max: bool = False,
        gene_expression_processing_parameters: dict = None,
        gene_expression_dtype: torch.dtype = torch.float,
        gene_expression_kwargs: dict = None,
        backend: str = 'eager',
        device: torch.device = None,
        **kwargs,
    ) -> None:
        """
        Initialize a drug sensitivity dose dataset.

        Args:
            drug_sensitivity_filepath (str): path to drug sensitivity .csv
                file. Currently, the only supported format is .csv, with an
                index and four header columns named as specified in the
                column_names argument.
            smi_filepath (str): path to .smi file.
            gene_expression_filepath (str): path to gene expression .csv
                file. Currently, the only supported format is .csv, with an
                index and header columns containing gene names.
            smiles_language (SMILESTokenizer): a smiles language/tokenizer
                must be passed. Specifies tokens and all transforms for
                SMILES conversion.
            column_names (Tuple[str]): Names of columns in data files to
                retrieve molecules, cell-line-data, drug dose and viability
                (label), in that order. Any sequence (tuple or list) is
                accepted. Defaults to
                ('drug', 'cell_line', 'dose', 'viability').
                The first two names and the last one are passed to
                DrugSensitivityDataset; the third one (dose) is handled here.
            dose_transform (Callable[[float], float]): A callable to convert
                the raw concentration into an input for the model. E.g. if
                the raw concentration is uMol, a log10 transform could make
                sense. Defaults to np.log10.
                NOTE: To switch it off, pass `lambda x:x`.
            iterate_dataset (bool): whether to go through all SMILES in the
                dataset to extend/build vocab, find longest sequence, and
                checks the passed padding length if applicable. Defaults to
                False.
            gene_list (GeneList): a list of genes.
            gene_expression_standardize (bool): perform gene expression data
                standardization. Defaults to True.
            gene_expression_min_max (bool): perform min-max scaling on gene
                expression data. Defaults to False.
            gene_expression_processing_parameters (dict): transformation
                parameters for gene expression, e.g. for min-max scaling.
                Defaults to None, in which case an empty dict is passed on.
            gene_expression_dtype (torch.dtype): accepted for signature
                compatibility. Defaults to torch.float.
                NOTE(review): currently NOT forwarded to the parent class.
            gene_expression_kwargs (dict): accepted for signature
                compatibility. Defaults to None.
                NOTE(review): currently NOT forwarded to the parent class.
            backend (str): memory management backend. Defaults to eager,
                prefer speed over memory consumption. Note that at the
                moment only the gene expression and the smiles datasets
                implement both backends. The drug sensitivity data are
                loaded in memory.
            device (torch.device): DEPRECATED
            **kwargs: Additional keyword arguments for parent class
                (DrugSensitivityDataset).

        Raises:
            ValueError: if fewer than four column names are given.
        """
        # Normalise to a list: the annotation advertises a tuple, but
        # `tuple + list` raises TypeError, so a plain slice-and-concatenate
        # only ever worked for list inputs.
        column_names = list(column_names)
        if len(column_names) < 4:
            # Fail before the (potentially expensive) parent initialisation.
            # With only three names the dose column would silently alias the
            # label column.
            raise ValueError(
                'column_names must hold four names (drug, cell line, dose, '
                f'viability), got {column_names}.'
            )

        # A fresh dict per instance avoids sharing one mutable default
        # object between all datasets.
        if gene_expression_processing_parameters is None:
            gene_expression_processing_parameters = {}

        # NOTE(review): gene_expression_dtype and gene_expression_kwargs are
        # not passed on below (same as before this change). Presumably the
        # parent accepts them - confirm against DrugSensitivityDataset and
        # forward them if so.
        super().__init__(
            drug_sensitivity_filepath=drug_sensitivity_filepath,
            smi_filepath=smi_filepath,
            gene_expression_filepath=gene_expression_filepath,
            # drug, cell line and label; the dose column is handled here.
            column_names=column_names[:2] + [column_names[-1]],
            drug_sensitivity_min_max=False,
            drug_sensitivity_processing_parameters={},
            smiles_language=smiles_language,
            iterate_dataset=iterate_dataset,
            gene_list=gene_list,
            gene_expression_standardize=gene_expression_standardize,
            gene_expression_min_max=gene_expression_min_max,
            gene_expression_processing_parameters=(
                gene_expression_processing_parameters
            ),
            backend=backend,
            **kwargs,
        )
        self.dose_name = column_names[2]
        self.dose_transform = dose_transform
        device_warning(device)

    def __getitem__(self, index: int) -> DrugSensitivityDoseData:
        """
        Generates one sample of data.

        Args:
            index (int): index of the sample to fetch.

        Returns:
            DrugSensitivityDoseData: a tuple containing four torch.Tensors,
                representing respectively:
                - compound token indexes,
                - gene expression values,
                - drug concentration (after applying dose_transform),
                - cell viability.
        """
        token_indexes, gene_expression, viability = super().__getitem__(index)
        # Select the column first, then the position: this reads a single
        # scalar instead of materialising the whole row as a Series.
        raw_dose = self.drug_sensitivity_df[self.dose_name].iloc[index]
        dose = torch.tensor(
            [self.dose_transform(raw_dose)],
            dtype=self.drug_sensitivity_dtype,
        )
        return token_indexes, gene_expression, dose, viability