"""Datasets for smiles and transformations of smiles."""
import logging
import torch
from pytoda.warnings import device_warning
from ..smiles.processing import split_selfies
from ..smiles.smiles_language import SMILESLanguage, SMILESTokenizer
from ._smi_eager_dataset import _SmiEagerDataset
from ._smi_lazy_dataset import _SmiLazyDataset
from .base_dataset import DatasetDelegator
from .utils import concatenate_file_based_datasets
logger = logging.getLogger(__name__)
SMILES_DATASET_IMPLEMENTATIONS = { # get class and acceptable keywords
'eager': (_SmiEagerDataset, {'index_col', 'names'}),
'lazy': (_SmiLazyDataset, {'chunk_size', 'index_col', 'names'}),
} # name cannot be passed


class SMILESDataset(DatasetDelegator):
"""Dataset of SMILES."""
    def __init__(
self,
*smi_filepaths: str,
backend: str = 'eager',
name: str = 'smiles-dataset',
device: torch.device = None,
**kwargs,
) -> None:
"""
Initialize a SMILES dataset.
Args:
smi_filepaths (Files): paths to .smi files.
name (str): name of the SMILESDataset.
            backend (str): memory management backend.
                Defaults to 'eager', preferring speed over memory consumption.
device (torch.device): DEPRECATED
kwargs (dict): additional arguments for dataset constructor.
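
        Example:
            A minimal usage sketch; the file path ``molecules.smi`` is a
            hypothetical placeholder::

                dataset = SMILESDataset('molecules.smi', backend='eager')
                len(dataset)  # number of SMILES across all files
                dataset[0]    # raw SMILES string of the first sample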
"""
device_warning(device)
# Parse language object and data paths
self.smi_filepaths = smi_filepaths
self.backend = backend
self.name = name
dataset_class, valid_keys = SMILES_DATASET_IMPLEMENTATIONS[self.backend]
self.kwargs = dict((k, v) for k, v in kwargs.items() if k in valid_keys)
self.kwargs['name'] = 'SMILES'
self.dataset = concatenate_file_based_datasets(
filepaths=self.smi_filepaths, dataset_class=dataset_class, **self.kwargs
)
DatasetDelegator.__init__(self) # delegate to self.dataset
if self.has_duplicate_keys:
raise KeyError('Please remove duplicates from your .smi file.')


class SMILESTokenizerDataset(DatasetDelegator):
"""Dataset of token indexes from SMILES."""
    def __init__(
self,
*smi_filepaths: str,
smiles_language: SMILESLanguage = None,
canonical: bool = False,
augment: bool = False,
kekulize: bool = False,
all_bonds_explicit: bool = False,
all_hs_explicit: bool = False,
remove_bonddir: bool = False,
remove_chirality: bool = False,
selfies: bool = False,
sanitize: bool = True,
randomize: bool = False,
add_start_and_stop: bool = False,
padding: bool = True,
padding_length: int = None,
vocab_file: str = None,
iterate_dataset: bool = True,
backend: str = 'eager',
device: torch.device = None,
name: str = 'smiles-encoder-dataset',
**kwargs,
) -> None:
"""
Initialize a dataset providing token indexes from source SMILES.
        The dataset's transformations of SMILES and their encodings can be
        adapted, depending on the smiles_language used (see SMILESTokenizer).
Args:
smi_filepaths (Files): paths to .smi files.
            smiles_language (SMILESLanguage): a SMILES language that transforms
                and encodes SMILES to token indexes. Defaults to None, in which
                case a SMILESTokenizer is instantiated with the following
                arguments.
            canonical (bool): perform canonicalization of SMILES (one unique
                string per molecule). If True, the other transformations
                (augment etc., see below) do not apply. Defaults to False.
augment (bool): perform SMILES augmentation. Defaults to False.
kekulize (bool): kekulizes SMILES (implicit aromaticity only).
Defaults to False.
            all_bonds_explicit (bool): make all bonds explicit. Defaults to
                False; applies only if kekulize is True.
            all_hs_explicit (bool): make all hydrogens explicit. Defaults to
                False; applies only if kekulize is True.
randomize (bool): perform a true randomization of SMILES tokens.
Defaults to False.
remove_bonddir (bool): Remove directional info of bonds.
Defaults to False.
remove_chirality (bool): Remove chirality information.
Defaults to False.
            selfies (bool): whether SELFIES are used instead of SMILES.
                Defaults to False.
sanitize (bool): RDKit sanitization of the molecule.
Defaults to True.
add_start_and_stop (bool): add start and stop token indexes.
Defaults to False.
padding (bool): pad sequences to longest in the smiles language.
Defaults to True.
padding_length (int): padding to match manually set length,
applies only if padding is True. Defaults to None.
vocab_file (str): Optional .json to load vocabulary. Defaults to
None.
            iterate_dataset (bool): whether to iterate over all SMILES in the
                dataset to extend/build the vocabulary, find the longest
                sequence, and check the passed padding length if applicable.
                Defaults to True.
            backend (str): memory management backend.
                Defaults to 'eager', preferring speed over memory consumption.
name (str): name of the SMILESTokenizerDataset.
device (torch.device): DEPRECATED
kwargs (dict): additional arguments for dataset constructor.
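
        Example:
            A minimal usage sketch; the file path ``molecules.smi`` is a
            hypothetical placeholder::

                dataset = SMILESTokenizerDataset(
                    'molecules.smi', padding=True, add_start_and_stop=True
                )
                token_indexes = dataset[0]  # torch.Tensor of token indexes

            Alternatively, a fully configured SMILESTokenizer can be passed via
            ``smiles_language``; its settings then take precedence over the
            constructor arguments listed above.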
"""
device_warning(device)
self.name = name
self.dataset = SMILESDataset(*smi_filepaths, backend=backend, **kwargs)
DatasetDelegator.__init__(self) # delegate to self.dataset
if smiles_language is not None:
self.smiles_language = smiles_language
params = (
"canonical, augment, kekulize, all_bonds_explicit, selfies, sanitize, "
"all_hs_explicit, remove_bonddir, remove_chirality, randomize, "
"add_start_and_stop, padding, padding_length"
)
logger.error(
'Since you provided a smiles_language, the following parameters to this'
f' class will be ignored: {params}.\nHere are the problems:'
)
mismatch = False
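            # compare each constructor argument against the attribute of the
            # same name on the provided smiles_language; the language wins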
for p in params.split(','):
if eval(p.strip()) != eval(f'smiles_language.{p.strip()}'):
logger.error(
f'Provided arg {p.strip()}:{eval(p.strip())} does not match the '
f'smiles_language value: {eval(f"smiles_language.{p.strip()}")}'
' NOTE: smiles_language value takes preference!!'
)
mismatch = True
if not mismatch:
logger.error('Looking great, no problems found!')
else:
logger.error(
'To get rid of this, adapt the smiles_language *offline*, feed it '
'ready for intended usage, and adapt the constructor args to be '
'identical with their equivalents in the language object'
)
else:
language_kwargs = {} # SMILES default
if selfies:
language_kwargs = dict(
name='selfies-language', smiles_tokenizer=split_selfies
)
self.smiles_language = SMILESTokenizer(
**language_kwargs,
canonical=canonical,
augment=augment,
kekulize=kekulize,
all_bonds_explicit=all_bonds_explicit,
all_hs_explicit=all_hs_explicit,
randomize=randomize,
remove_bonddir=remove_bonddir,
remove_chirality=remove_chirality,
selfies=selfies,
sanitize=sanitize,
add_start_and_stop=add_start_and_stop,
padding=padding,
padding_length=padding_length,
)
if vocab_file:
self.smiles_language.load_vocabulary(vocab_file)
if iterate_dataset:
# uses the smiles transforms
self.smiles_language.add_dataset(self.dataset)
try:
if (
self.smiles_language.padding
and self.smiles_language.padding_length is None
):
try:
# max_sequence_token_length has to be set somehow
if smiles_language is not None or iterate_dataset:
self.smiles_language.set_max_padding()
except AttributeError:
raise TypeError(
'Setting a maximum padding length requires a '
'smiles_language with `set_max_padding` method. See '
'`SMILESTokenizer`.'
)
except AttributeError:
# SmilesLanguage w/o padding support passed.
pass
def __getitem__(self, index: int) -> torch.Tensor:
"""
Generates one sample of data.
Args:
index (int): index of the sample to fetch.
Returns:
            torch.Tensor: a torch tensor of token indexes for the current
                sample.
"""
return self.smiles_language.smiles_to_token_indexes(self.dataset[index])
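

# Usage sketch (illustrative only, not executed on import; 'molecules.smi' is a
# hypothetical path). With the default padding=True, all samples share a fixed
# length, so the dataset can be batched with a standard torch DataLoader:
#
#     from torch.utils.data import DataLoader
#     dataset = SMILESTokenizerDataset('molecules.smi', add_start_and_stop=True)
#     loader = DataLoader(dataset, batch_size=32, shuffle=True)
#     batch = next(iter(loader))  # tensor of shape [batch_size, padded_length]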