Source code for pytoda.datasets.drug_sensitivity_dataset

"""Implementation of DrugSensitivityDataset."""
import pandas as pd
import torch
from torch.utils.data import Dataset

from pytoda.warnings import device_warning

from ..smiles.smiles_language import SMILESLanguage
from ..types import DrugSensitivityData, GeneList, Iterable, Tuple
from .gene_expression_dataset import GeneExpressionDataset
from .smiles_dataset import SMILESTokenizerDataset


class DrugSensitivityDataset(Dataset):
    """
    Drug sensitivity dataset implementation.

    Combines a SMILES dataset, a gene expression dataset and a .csv table of
    drug sensitivity labels. Only rows whose drug is available in the SMILES
    dataset and whose cell line is available in the gene expression dataset
    are kept.
    """

    def __init__(
        self,
        drug_sensitivity_filepath: str,
        smi_filepath: str,
        gene_expression_filepath: str,
        column_names: Tuple[str] = None,
        drug_sensitivity_dtype: torch.dtype = torch.float,
        drug_sensitivity_min_max: bool = True,
        drug_sensitivity_processing_parameters: dict = None,
        smiles_language: SMILESLanguage = None,
        padding: bool = True,
        padding_length: int = None,
        add_start_and_stop: bool = False,
        augment: bool = False,
        canonical: bool = False,
        kekulize: bool = False,
        all_bonds_explicit: bool = False,
        all_hs_explicit: bool = False,
        randomize: bool = False,
        remove_bonddir: bool = False,
        remove_chirality: bool = False,
        selfies: bool = False,
        sanitize: bool = True,
        vocab_file: str = None,
        iterate_dataset: bool = True,
        gene_list: GeneList = None,
        gene_expression_standardize: bool = True,
        gene_expression_min_max: bool = False,
        gene_expression_processing_parameters: dict = None,
        gene_expression_dtype: torch.dtype = torch.float,
        gene_expression_kwargs: dict = None,
        backend: str = 'eager',
        device: torch.device = None,
    ) -> None:
        """
        Initialize a drug sensitivity dataset.

        Args:
            drug_sensitivity_filepath (str): path to drug sensitivity .csv
                file. Currently, the only supported format is .csv, with an
                index and three header columns named as specified in the
                column_names argument.
            smi_filepath (str): path to .smi file.
            gene_expression_filepath (str): path to gene expression .csv file.
                Currently, the only supported format is .csv, with an index
                and header columns containing gene names.
            column_names (Tuple[str]): names of the columns in the drug
                sensitivity file holding the drug name, the cell line name
                and the label, respectively.
                Defaults to ['drug', 'cell_line', 'IC50'].
            drug_sensitivity_dtype (torch.dtype): drug sensitivity data type.
                Defaults to torch.float.
            drug_sensitivity_min_max (bool): min-max scale drug sensitivity
                data. Defaults to True.
            drug_sensitivity_processing_parameters (dict): transformation
                parameters for drug sensitivity data, e.g. for min-max
                scaling. Either a flat dict {'min': ..., 'max': ...} or the
                nested form {'processing': 'min_max', 'parameters':
                {'min': ..., 'max': ...}} that this class stores in
                `drug_sensitivity_processing_parameters` after scaling (so
                the parameters of one dataset can be reused for another).
                Missing bounds are computed from the data. Defaults to {}.
            smiles_language (SMILESLanguage): a smiles language.
                Defaults to None.
            padding (bool): pad sequences to longest in the smiles language.
                Defaults to True.
            padding_length (int): manually sets number of applied paddings,
                applies only if padding is True. Defaults to None.
            add_start_and_stop (bool): add start and stop token indexes.
                Defaults to False.
            augment (bool): perform SMILES augmentation. Defaults to False.
            canonical (bool): performs canonicalization of SMILES (one
                original string for one molecule), if True, then other
                transformations (augment etc, see below) do not apply.
                Defaults to False.
            kekulize (bool): kekulizes SMILES (implicit aromaticity only).
                Defaults to False.
            all_bonds_explicit (bool): Makes all bonds explicit. Defaults to
                False, only applies if kekulize = True.
            all_hs_explicit (bool): Makes all hydrogens explicit. Defaults to
                False, only applies if kekulize = True.
            randomize (bool): perform a true randomization of SMILES tokens.
                Defaults to False.
            remove_bonddir (bool): Remove directional info of bonds.
                Defaults to False.
            remove_chirality (bool): Remove chirality information.
                Defaults to False.
            selfies (bool): Whether selfies is used instead of smiles,
                defaults to False.
            sanitize (bool): RDKit sanitization of the molecule.
                Defaults to True.
            vocab_file (str): Optional .json to load vocabulary. Tries to load
                metadata if `iterate_dataset` is False. Defaults to None.
            iterate_dataset (bool): whether to go through all SMILES in the
                dataset to extend/build vocab, find longest sequence, and
                checks the passed padding length if applicable.
                Defaults to True.
            gene_list (GeneList): a list of genes.
            gene_expression_standardize (bool): perform gene expression
                data standardization. Defaults to True.
            gene_expression_min_max (bool): perform min-max scaling on gene
                expression data. Defaults to False.
            gene_expression_processing_parameters (dict): transformation
                parameters for gene expression, e.g. for min-max scaling.
                Defaults to {}.
            gene_expression_dtype (torch.dtype): gene expression data type.
                Defaults to torch.float.
            gene_expression_kwargs (dict): additional parameters for
                GeneExpressionDataset. Defaults to {}.
            backend (str): memory management backend.
                Defaults to eager, prefer speed over memory consumption.
                Note that at the moment only the gene expression and the
                smiles datasets implement both backends. The drug sensitivity
                data are loaded in memory.
            device (torch.device): DEPRECATED

        Raises:
            TypeError: if column_names is not iterable.
            ValueError: if column_names does not contain exactly 3 names.
        """
        Dataset.__init__(self)
        # Resolve defaults here rather than in the signature so that each
        # instance gets fresh list/dict objects instead of sharing (and
        # possibly mutating) one object created at definition time.
        if column_names is None:
            column_names = ['drug', 'cell_line', 'IC50']
        if drug_sensitivity_processing_parameters is None:
            drug_sensitivity_processing_parameters = {}
        if gene_expression_processing_parameters is None:
            gene_expression_processing_parameters = {}
        if gene_expression_kwargs is None:
            gene_expression_kwargs = {}

        self.drug_sensitivity_filepath = drug_sensitivity_filepath
        self.smi_filepath = smi_filepath
        self.gene_expression_filepath = gene_expression_filepath
        # backend
        self.backend = backend

        if not isinstance(column_names, Iterable):
            raise TypeError(f'Column names was {type(column_names)}, not Iterable.')
        if not len(column_names) == 3:
            raise ValueError(f'Please pass 3 column names not {len(column_names)}')
        self.column_names = column_names
        self.drug_name, self.cell_name, self.label_name = self.column_names
        device_warning(device)

        # SMILES
        self.smiles_dataset = SMILESTokenizerDataset(
            self.smi_filepath,
            smiles_language=smiles_language,
            augment=augment,
            canonical=canonical,
            kekulize=kekulize,
            all_bonds_explicit=all_bonds_explicit,
            all_hs_explicit=all_hs_explicit,
            remove_bonddir=remove_bonddir,
            remove_chirality=remove_chirality,
            selfies=selfies,
            sanitize=sanitize,
            randomize=randomize,
            padding=padding,
            padding_length=padding_length,
            add_start_and_stop=add_start_and_stop,
            vocab_file=vocab_file,
            iterate_dataset=iterate_dataset,
            backend=self.backend,
        )
        # gene expression
        self.gene_expression_dataset = GeneExpressionDataset(
            self.gene_expression_filepath,
            gene_list=gene_list,
            standardize=gene_expression_standardize,
            min_max=gene_expression_min_max,
            processing_parameters=gene_expression_processing_parameters,
            dtype=gene_expression_dtype,
            backend=self.backend,
            index_col=0,
            **gene_expression_kwargs,
        )
        # drug sensitivity
        self.drug_sensitivity_dtype = drug_sensitivity_dtype
        self.drug_sensitivity_min_max = drug_sensitivity_min_max
        self.drug_sensitivity_processing_parameters = (
            drug_sensitivity_processing_parameters
        )
        self.drug_sensitivity_df = pd.read_csv(
            self.drug_sensitivity_filepath, index_col=0
        )
        # filter data based on the availability
        drug_mask = self.drug_sensitivity_df[self.drug_name].isin(
            set(self.smiles_dataset.keys())
        )
        profile_mask = self.drug_sensitivity_df[self.cell_name].isin(
            set(self.gene_expression_dataset.keys())
        )
        # .copy() makes the filtered frame independent of the loaded one, so
        # the in-place label scaling below does not write into a slice
        # (which pandas flags with SettingWithCopyWarning).
        self.drug_sensitivity_df = self.drug_sensitivity_df.loc[
            drug_mask & profile_mask
        ].copy()
        # to investigate missing ids per entity (one boolean per input row)
        self.masks_df = pd.concat([drug_mask, profile_mask], axis=1)
        self.masks_df.columns = [self.drug_name, self.cell_name]
        self.number_of_samples = len(self.drug_sensitivity_df)

        # NOTE: optional min-max scaling
        if self.drug_sensitivity_min_max:
            minimum, maximum = self._get_min_max_parameters()
            self.drug_sensitivity_df[self.label_name] = (
                self.drug_sensitivity_df[self.label_name] - minimum
            ) / (maximum - minimum)
            # stored in nested form; _get_min_max_parameters accepts it back
            self.drug_sensitivity_processing_parameters = {
                'processing': 'min_max',
                'parameters': {'min': minimum, 'max': maximum},
            }

    def _get_min_max_parameters(self) -> Tuple[float, float]:
        """
        Resolve the min-max scaling bounds for the label column.

        Reads `drug_sensitivity_processing_parameters`, accepting both a flat
        dict {'min': ..., 'max': ...} and the nested form
        {'processing': 'min_max', 'parameters': {'min': ..., 'max': ...}}
        produced by this class after scaling. Bounds that are not provided
        are computed from the (already filtered) label column.

        Returns:
            Tuple[float, float]: the minimum and maximum used for scaling.
        """
        parameters = self.drug_sensitivity_processing_parameters
        # nested form, as stored by a previously constructed instance
        parameters = parameters.get('parameters', parameters)
        labels = self.drug_sensitivity_df[self.label_name]
        minimum = parameters.get('min', labels.min())
        maximum = parameters.get('max', labels.max())
        return minimum, maximum

    def __len__(self) -> int:
        "Total number of samples."
        return self.number_of_samples

    def __getitem__(self, index: int) -> DrugSensitivityData:
        """
        Generates one sample of data.

        Args:
            index (int): index of the sample to fetch.

        Returns:
            DrugSensitivityData: a tuple containing three torch.Tensors,
                representing respectively: compound token indexes,
                gene expression values and IC50 for the current sample.
        """
        # drug sensitivity
        selected_sample = self.drug_sensitivity_df.iloc[index]
        ic50_tensor = torch.tensor(
            [selected_sample[self.label_name]],
            dtype=self.drug_sensitivity_dtype,
        )
        # SMILES
        token_indexes_tensor = self.smiles_dataset.get_item_from_key(
            selected_sample[self.drug_name]
        )
        # gene_expression
        gene_expression_tensor = self.gene_expression_dataset.get_item_from_key(
            selected_sample[self.cell_name]
        )
        return token_indexes_tensor, gene_expression_tensor, ic50_tensor