Source code for pytoda.smiles.polymer_language

"""Polymer language handling."""
from typing import Sequence

from ..types import Indexes, Tensor, Union  # , delegate_kwargs
from .smiles_language import SMILESTokenizer
from .transforms import compose_encoding_transforms, compose_smiles_transforms


# @delegate_kwargs
class PolymerTokenizer(SMILESTokenizer):
    """
    PolymerTokenizer class.

    PolymerTokenizer is an extension of SMILESTokenizer, adding special
    start and stop tokens per entity. A polymer language is usually shared
    across several SMILES datasets (e.g. different entity sources).
    """
    def __init__(
        self,
        entity_names: Sequence[str],
        name: str = 'polymer-language',
        add_start_and_stop: bool = True,
        **kwargs,
    ) -> None:
        """
        Initialize a Polymer language able to encode different entities.

        Args:
            entity_names (Sequence[str]): a list of entity names that the
                polymer language can distinguish.
            name (str): name of the PolymerTokenizer.
            add_start_and_stop (bool): add start and stop token indexes.
                Defaults to True.
            kwargs (dict): additional parameters passed to SMILESTokenizer.

        NOTE: See `set_smiles_transforms` and `set_encoding_transforms` to
        change the transforms temporarily and reset with
        `reset_initial_transforms`. Assignment of class attributes in the
        parameter list will trigger such a reset.
        """
        super().__init__(
            name=name, add_start_and_stop=add_start_and_stop, **kwargs
        )
        self.entities = list(map(lambda x: x.capitalize(), entity_names))
        self.init_kwargs['entity_names'] = self.entities
        self.current_entity = None

        # rebuild basic vocab to group special tokens
        self.start_entity_tokens, self.stop_entity_tokens = (
            list(map(lambda x: '<' + x.upper() + '_' + s + '>', entity_names))
            for s in ['START', 'STOP']
        )
        # required for `token_indexes_to_smiles`
        self.special_indexes.update(
            enumerate(
                self.start_entity_tokens + self.stop_entity_tokens,
                start=len(self.special_indexes),
            )
        )

        self.setup_vocab()
        if kwargs.get('vocab_file', None):
            self.load_vocabulary(kwargs['vocab_file'])
        self.reset_initial_transforms()
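    # A minimal usage sketch (illustrative, not part of the module source):
    # 'monomer' and 'catalyst' are hypothetical entity names; any further
    # kwargs are passed through to SMILESTokenizer.
    #
    # >>> tokenizer = PolymerTokenizer(entity_names=['monomer', 'catalyst'])
    # >>> tokenizer.entities
    # ['Monomer', 'Catalyst']
    # >>> tokenizer.start_entity_tokens
    # ['<MONOMER_START>', '<CATALYST_START>']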
    def _check_entity(self, entity: str) -> str:
        entity_ = entity.capitalize()
        if entity_ not in self.entities:
            raise ValueError(f'Unknown entity was given ({entity_})')
        return entity_
    def update_entity(self, entity: str) -> None:
        """
        Update the current entity and the default transforms (used e.g. in
        `add_dataset`) of the Polymer language object.

        Args:
            entity (str): a chemical entity (e.g. 'Monomer').
        """
        self.current_entity = self._check_entity(entity)
        self.transform_smiles = self.all_smiles_transforms[self.current_entity]
        self.transform_encoding = self.all_encoding_transforms[
            self.current_entity
        ]
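    # Sketch (illustrative): after switching the default entity, calls that
    # omit an explicit `entity` argument use the 'Monomer' transforms
    # (assuming 'monomer' was among the `entity_names` at construction).
    #
    # >>> tokenizer.update_entity('monomer')
    # >>> tokenizer.current_entity
    # 'Monomer'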
    def smiles_to_token_indexes(
        self, smiles: str, entity: str = None
    ) -> Union[Indexes, Tensor]:
        """
        Transform character-level SMILES into a sequence of token indexes.

        In case of add_start_and_stop, inserts entity-specific tokens.

        Args:
            smiles (str): a SMILES (or SELFIES) representation.
            entity (str): a chemical entity (e.g. 'Monomer'). Defaults to
                None, where the current entity is used (initially the
                SMILESTokenizer default).

        Returns:
            Union[Indexes, Tensor]: indexes representation for the
                SMILES/SELFIES provided.
        """
        if entity is None:
            # default behavior given by call to update_entity()
            entity = self.current_entity
        else:
            entity = self._check_entity(entity)

        return self.all_encoding_transforms[entity](
            [
                self.token_to_index.get(token, self.unknown_index)
                for token in self.smiles_tokenizer(
                    self.all_smiles_transforms[entity](smiles)
                )
            ]
        )
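    # Sketch (illustrative): passing an explicit entity wraps the indexes in
    # that entity's start/stop tokens (given add_start_and_stop=True); the
    # concrete index values depend on the vocabulary.
    #
    # >>> indexes = tokenizer.smiles_to_token_indexes('CCO', entity='monomer')
    # >>> # first/last entries map to '<MONOMER_START>' / '<MONOMER_STOP>'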
    def reset_initial_transforms(self):
        """
        Reset smiles and token indexes transforms as on initialization,
        including entity-specific transforms.
        """
        super().reset_initial_transforms()
        if not hasattr(self, 'entities'):  # call from base class
            return
        self.current_entity = None
        self.all_smiles_transforms = {
            None: self.transform_smiles,
        }
        self.all_encoding_transforms = {
            None: self.transform_encoding,
        }
        for entity in self.entities:
            self.set_smiles_transforms(entity)
            self.set_encoding_transforms(entity)
    def set_smiles_transforms(
        self,
        entity,
        canonical=None,
        augment=None,
        kekulize=None,
        all_bonds_explicit=None,
        all_hs_explicit=None,
        remove_bonddir=None,
        remove_chirality=None,
        selfies=None,
        sanitize=None,
    ):
        """Helper function to reversibly change the transforms per entity."""
        entity = self._check_entity(entity)
        self.all_smiles_transforms[entity] = compose_smiles_transforms(
            canonical=canonical if canonical is not None else self.canonical,
            augment=augment if augment is not None else self.augment,
            kekulize=kekulize if kekulize is not None else self.kekulize,
            all_bonds_explicit=all_bonds_explicit
            if all_bonds_explicit is not None
            else self.all_bonds_explicit,
            all_hs_explicit=all_hs_explicit
            if all_hs_explicit is not None
            else self.all_hs_explicit,
            remove_bonddir=remove_bonddir
            if remove_bonddir is not None
            else self.remove_bonddir,
            remove_chirality=remove_chirality
            if remove_chirality is not None
            else self.remove_chirality,
            selfies=selfies if selfies is not None else self.selfies,
            sanitize=sanitize if sanitize is not None else self.sanitize,
        )
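    # Sketch (illustrative): temporarily enabling SMILES augmentation for a
    # single entity; `reset_initial_transforms` restores the defaults.
    #
    # >>> tokenizer.set_smiles_transforms('monomer', augment=True)
    # >>> tokenizer.reset_initial_transforms()  # back to initial behavior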
    def set_encoding_transforms(
        self,
        entity,
        randomize=None,
        add_start_and_stop=None,
        padding=None,
        padding_length=None,
    ):
        """
        Helper function to reversibly change the transforms per entity.

        Addresses entity-specific start and stop tokens.
        """
        entity = self._check_entity(entity)
        start_index = self.token_to_index['<' + entity.upper() + '_START>']
        stop_index = self.token_to_index['<' + entity.upper() + '_STOP>']

        self.all_encoding_transforms[entity] = compose_encoding_transforms(
            randomize=randomize if randomize is not None else self.randomize,
            add_start_and_stop=add_start_and_stop
            if add_start_and_stop is not None
            else self.add_start_and_stop,
            start_index=start_index,
            stop_index=stop_index,
            padding=padding if padding is not None else self.padding,
            padding_length=padding_length
            if padding_length is not None
            else self.padding_length,
            padding_index=self.padding_index,
        )
        if add_start_and_stop is not None:
            self._set_token_len_fn(add_start_and_stop)
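# A minimal usage sketch (illustrative, not part of the module): overriding
# the encoding transforms for a single entity; `padding_length=128` is an
# arbitrary example value, and all unset arguments fall back to the
# tokenizer-wide settings.
#
# >>> tokenizer.set_encoding_transforms(
# ...     'catalyst', padding=True, padding_length=128
# ... )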