Source code for pytoda.datasets.polymer_dataset

"""PolymerTokenizerDataset module."""
import pandas as pd
import torch
from numpy import iterable
from torch.utils.data import Dataset

from pytoda.warnings import device_warning

from ..smiles.polymer_language import PolymerTokenizer
from ..types import List, Sequence, Tensor, Tuple, Union
from .smiles_dataset import SMILESDataset


class PolymerTokenizerDataset(Dataset):
    """
    Dataset of SMILES from multiple entities encoded as token indexes.

    Creates a tuple of SMILES datasets, one per given entity (i.e. molecule
    class, e.g. monomer and catalyst). Rows in the annotation df need to have
    column names identical to the entities, mapping to SMILES in the datasets.
    Uses a PolymerTokenizer.
    """
    def __init__(
        self,
        *smi_filepaths: str,
        entity_names: Sequence[str],
        annotations_filepath: str,
        annotations_column_names: Union[List[int], List[str]] = None,
        smiles_language: PolymerTokenizer = None,
        canonical: Union[Sequence[bool], bool] = False,
        augment: Union[Sequence[bool], bool] = False,
        kekulize: Union[Sequence[bool], bool] = False,
        all_bonds_explicit: Union[Sequence[bool], bool] = False,
        all_hs_explicit: Union[Sequence[bool], bool] = False,
        randomize: Union[Sequence[bool], bool] = False,
        remove_bonddir: Union[Sequence[bool], bool] = False,
        remove_chirality: Union[Sequence[bool], bool] = False,
        selfies: Union[Sequence[bool], bool] = False,
        sanitize: Union[Sequence[bool], bool] = True,
        padding: Union[Sequence[bool], bool] = True,
        padding_length: Union[Sequence[int], int] = None,
        iterate_dataset: bool = True,
        backend: str = 'eager',
        device: torch.device = None,
        **kwargs,
    ) -> None:
        """
        Initialize a Polymer dataset.

        All SMILES dataset parameters can be controlled either separately for
        each dataset (by an iterable of correct length) or globally
        (bool/int).

        Args:
            smi_filepaths (Files): paths to .smi files, one per entity.
            entity_names (Sequence[str]): list of chemical entities.
            annotations_filepath (str): path to .csv with the IDs of the
                chemical entities and their properties. Needs to have one
                column per entity name.
            annotations_column_names (Union[List[int], List[str]]): indexes
                (positional or strings) for the annotations. Defaults to
                None, i.e. all columns except the entity_names are
                annotation labels.
            smiles_language (PolymerTokenizer): a polymer language. Defaults
                to None, in which case a new object is created.
            padding (Union[Sequence[bool], bool]): pad sequences to the
                longest in the smiles language. Defaults to True. Controlled
                either for each dataset separately (by iterable) or globally
                (bool).
            padding_length (Union[Sequence[int], int]): manually sets the
                number of applied paddings; applies only if padding is True.
                Defaults to None. Controlled either for each dataset
                separately (by iterable) or globally (int).
            canonical (Union[Sequence[bool], bool]): performs
                canonicalization of SMILES (one original string for one
                molecule). If True, the other transformations (augment etc.,
                see below) do not apply.
            augment (Union[Sequence[bool], bool]): perform SMILES
                augmentation. Defaults to False.
            kekulize (Union[Sequence[bool], bool]): kekulizes SMILES
                (implicit aromaticity only). Defaults to False.
            all_bonds_explicit (Union[Sequence[bool], bool]): makes all bonds
                explicit. Defaults to False; only applies if kekulize = True.
            all_hs_explicit (Union[Sequence[bool], bool]): makes all
                hydrogens explicit. Defaults to False; only applies if
                kekulize = True.
            randomize (Union[Sequence[bool], bool]): perform a true
                randomization of SMILES tokens. Defaults to False.
            remove_bonddir (Union[Sequence[bool], bool]): remove directional
                info of bonds. Defaults to False.
            remove_chirality (Union[Sequence[bool], bool]): remove chirality
                information. Defaults to False.
            selfies (Union[Sequence[bool], bool]): whether SELFIES is used
                instead of SMILES. Defaults to False.
            sanitize (Union[Sequence[bool], bool]): RDKit sanitization of the
                molecule. Defaults to True.
            iterate_dataset (bool): whether to go through all SMILES in the
                dataset to build/extend the vocab, find the longest sequence
                and check the passed padding length, if applicable. Defaults
                to True.
            backend (str): memory management backend. Defaults to 'eager',
                preferring speed over memory consumption.
            device (torch.device): DEPRECATED.
            kwargs (dict): additional arguments for the dataset constructor.

        NOTE: If a parameter that can be given as Union[Sequence[bool], bool]
        is given as a Sequence[bool] of wrong length (!= len(entity_names)),
        the first list item is used for all datasets.
        """
        device_warning(device)
        self.backend = backend

        if len(entity_names) != len(smi_filepaths):
            raise ValueError('Give 1 .smi file per entity')

        # Setup parameters: broadcast scalars to one value per entity,
        # keep sequences of matching length as-is
        (
            self.paddings,
            self.padding_lengths,
            self.canonicals,
            self.augments,
            self.kekulizes,
            self.all_bonds_explicits,
            self.all_hs_explicits,
            self.randomizes,
            self.remove_bonddirs,
            self.remove_chiralitys,
            self.selfies,
            self.sanitize,
        ) = map(
            (
                lambda x: x
                if iterable(x) and len(x) == len(entity_names)
                else [x] * len(entity_names)
            ),
            (
                padding,
                padding_length,
                canonical,
                augment,
                kekulize,
                all_bonds_explicit,
                all_hs_explicit,
                randomize,
                remove_bonddir,
                remove_chirality,
                selfies,
                sanitize,
            ),
        )

        if smiles_language is None:
            self.smiles_language = PolymerTokenizer(  # defaults to add smiles
                entity_names=entity_names,
                padding=self.paddings[0],
                padding_length=self.padding_lengths[0],
                canonical=self.canonicals[0],
                augment=self.augments[0],
                kekulize=self.kekulizes[0],
                all_bonds_explicit=self.all_bonds_explicits[0],
                all_hs_explicit=self.all_hs_explicits[0],
                randomize=self.randomizes[0],
                remove_bonddir=self.remove_bonddirs[0],
                remove_chirality=self.remove_chiralitys[0],
                selfies=self.selfies[0],
                sanitize=self.sanitize[0],
                add_start_and_stop=True,
            )
            for index, entity in enumerate(entity_names):
                self.smiles_language.set_smiles_transforms(
                    entity,
                    canonical=self.canonicals[index],
                    augment=self.augments[index],
                    kekulize=self.kekulizes[index],
                    all_bonds_explicit=self.all_bonds_explicits[index],
                    all_hs_explicit=self.all_hs_explicits[index],
                    remove_bonddir=self.remove_bonddirs[index],
                    remove_chirality=self.remove_chiralitys[index],
                    selfies=self.selfies[index],
                    sanitize=self.sanitize[index],
                )
            # set_encoding_transforms only after adding smiles,
            # while smiles transforms are needed to add_dataset
        else:
            self.smiles_language = smiles_language

        self.entities = self.smiles_language.entities
        self.datasets = [
            SMILESDataset(smi_filepath, name=self.entities[index], **kwargs)
            for index, smi_filepath in enumerate(smi_filepaths)
        ]

        if iterate_dataset:
            for dataset in self.datasets:
                self.smiles_language.update_entity(dataset.name)
                self.smiles_language.add_dataset(dataset)
            if padding and None in self.padding_lengths:
                # take care, this will call a transform reset
                self.smiles_language.set_max_padding()
            self.smiles_language.current_entity = None

        if smiles_language is None:
            for index, entity in enumerate(entity_names):
                # smiles_transforms might have been reset
                self.smiles_language.set_smiles_transforms(
                    entity,
                    canonical=self.canonicals[index],
                    augment=self.augments[index],
                    kekulize=self.kekulizes[index],
                    all_bonds_explicit=self.all_bonds_explicits[index],
                    all_hs_explicit=self.all_hs_explicits[index],
                    remove_bonddir=self.remove_bonddirs[index],
                    remove_chirality=self.remove_chiralitys[index],
                    selfies=self.selfies[index],
                    sanitize=self.sanitize[index],
                )
                self.smiles_language.set_encoding_transforms(
                    entity,
                    randomize=self.randomizes[index],
                    add_start_and_stop=True,
                    padding=self.paddings[index],
                    padding_length=self.padding_lengths[index],
                )

        # Read and post-process the annotations dataframe
        self.annotations_filepath = annotations_filepath
        self.annotated_data_df = pd.read_csv(self.annotations_filepath)
        # Capitalize the column names (first letter upper, rest lower)
        self.annotated_data_df.columns = map(
            lambda x: str(x).capitalize(),
            self.annotated_data_df.columns,
        )
        columns = self.annotated_data_df.columns

        # handle annotation index
        assert all(
            entity in columns for entity in self.entities
        ), 'Some of the chemical entities were not found in the label csv.'

        # handle labels
        if annotations_column_names is None:
            self.labels = [
                column for column in columns if column not in self.entities
            ]
        elif all(
            isinstance(column, int) for column in annotations_column_names
        ):
            self.labels = columns[annotations_column_names]
        elif all(
            isinstance(column, str) for column in annotations_column_names
        ):
            self.labels = list(
                map(lambda x: x.capitalize(), annotations_column_names)
            )
        else:
            raise RuntimeError(
                'annotations_column_names should be an iterable '
                'containing int or str'
            )
        # get the number of labels
        self.number_of_tasks = len(self.labels)

        # filter data based on availability
        masks = []
        mask = pd.Series(
            [True] * len(self.annotated_data_df),
            index=self.annotated_data_df.index,
        )
        for entity, dataset in zip(self.entities, self.datasets):
            # prune rows (in mask) with ids unavailable in respective dataset
            local_mask = self.annotated_data_df[entity].isin(
                set(dataset.keys())
            )
            mask = mask & local_mask
            masks.append(local_mask)
        self.annotated_data_df = self.annotated_data_df.loc[mask]
        # to investigate missing ids per entity
        self.masks_df = pd.concat(masks, axis=1)
        self.masks_df.columns = self.entities
    def __len__(self) -> int:
        """Total number of samples."""
        return len(self.annotated_data_df)

    def __getitem__(self, index: int) -> Tuple[Tensor, ...]:
        """
        Generates one sample of data.

        Args:
            index (int): index of the sample to fetch.

        Returns:
            Tuple: a tuple containing len(self.entities) + 1 torch.Tensors
                representing, respectively, the compound token indexes for
                each chemical entity and the property labels (annotations).
        """
        # sample selection for all entities/datasets
        selected_sample = self.annotated_data_df.iloc[index]
        # labels (annotations)
        labels_tensor = torch.tensor(
            list(selected_sample[self.labels].values), dtype=torch.float
        )
        # samples (SMILES token indexes)
        smiles_tensors = tuple(
            self.smiles_language.smiles_to_token_indexes(
                dataset.get_item_from_key(selected_sample[dataset.name]),
                dataset.name,
            )
            for dataset in self.datasets
        )
        return tuple([*smiles_tensors, labels_tensor])
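
For orientation, a minimal usage sketch follows. It is not part of the
module; all file paths, entity names and annotation columns below are
hypothetical placeholders, and it assumes the .csv has one ID column per
entity (matching entity_names) plus numeric property columns used as labels.

# Hypothetical usage sketch for PolymerTokenizerDataset.
from pytoda.datasets.polymer_dataset import PolymerTokenizerDataset

dataset = PolymerTokenizerDataset(
    'monomers.smi',  # hypothetical .smi file for the first entity
    'catalysts.smi',  # hypothetical .smi file for the second entity
    entity_names=['monomer', 'catalyst'],
    annotations_filepath='annotations.csv',  # one ID column per entity
    padding=True,  # global setting, broadcast to all entities
    augment=[True, False],  # per-entity setting via a sequence
)

# Each sample is a tuple: one token-index tensor per entity, labels last.
monomer_tokens, catalyst_tokens, labels = dataset[0]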