"""Testing Protein Sequence dataset."""
import os
import random
import time
import unittest

import numpy as np
from importlib_resources import files
from torch.utils.data import DataLoader

from pytoda.datasets import ProteinSequenceDataset
from pytoda.tests.utils import TestFileContent

SMI_CONTENT = os.linesep.join(['EGK\tID3', 'S\tID1', 'FGAAV\tID2', 'NCCS\tID4'])
MORE_SMI_CONTENT = os.linesep.join(['KGE\tID5', 'K\tID6', 'SCCN\tID7', 'K\tID8'])
BROKEN_SMI_CONTENT = os.linesep.join(['KGE\tID5', 'K\tID6', 'SCfCN\tID7'])
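# BROKEN_SMI_CONTENT: the lowercase 'f' in 'SCfCN' is not a valid amino-acid
# token; the unknown-token test below expects it to map to '<UNK>'.
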
AS_SMI_CONTENT = os.linesep.join(
    [
        'LGQGTRTNVVKTMLAVMVTEYVEHGPVLVRNLSDV\tP29597',
        'LGKGTFGKVAKELLTLFVMEYANGGEFVVENMTDL\tP31751',
    ]
)
AS_MORE_SMI_CONTENT = os.linesep.join(
    [
        'IGEGSTGIVAKEMVVMVVMEFLEGGADVIDSLSDF\tQ9P286',
        'LGKGTFGKVAKELLTLFVMEYANGGEFVVENMTDL\tP31749',  # same AS, different protein
    ]
)
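# The AS_* fixtures hold 35-residue kinase active-site sequences keyed by
# UniProt accession; P31751 and P31749 deliberately share the same active
# site, which the sequence-augmentation tests below rely on.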


FASTA_CONTENT_UNIPROT = r""">sp|Q6GZX0|005R_FRG3G Uncharacterized protein 005R OS=Frog virus 3 (isolate Goorha) OX=654924 GN=FV3-005R PE=4 SV=1
MQNPLPEVMSPEHDKRTTTPMSKEANKFIRELDKKPGDLAVVSDFVKRNTGKRLPIGKRS
NLYVRICDLSGTIYMGETFILESWEELYLPEPTKMEVLGTLESCCGIPPFPEWIVMVGED
QCVYAYGDEEILLFAYSVKQLVEEGIQETGISYKYPDDISDVDEEVLQQDEEIQKIRKKT
REFVDKDAQEFQDFLNSLDASLLS
>sp|Q91G88|006L_IIV6 Putative KilA-N domain-containing protein 006L OS=Invertebrate iridescent virus 6 OX=176652 GN=IIV6-006L PE=3 SV=1
MDSLNEVCYEQIKGTFYKGLFGDFPLIVDKKTGCFNATKLCVLGGKRFVDWNKTLRSKKL
IQYYETRCDIKTESLLYEIKGDNNDEITKQITGTYLPKEFILDIASWISVEFYDKCNNII
"""  # length 204, 120

FASTA_CONTENT_GENERIC = (
    FASTA_CONTENT_UNIPROT
    + r""">generic_header eager upfp would concat to sequence above.
LLLLLLLLLLLLLLLL
"""
)  # length 16
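# Only FASTA_CONTENT_UNIPROT is safe for the eager backend: its upfp-based
# parser recognizes UniProt headers only, so the generic record would be
# concatenated to the sequence above it (hence eager len == 2, lazy len == 3
# in the __len__ tests below).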

all_keys = ['ID3', 'ID1', 'ID2', 'ID4', 'Q6GZX0', 'Q91G88']
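# Minimal usage sketch of the pattern exercised throughout these tests
# (a comment-only illustration, not executed at import time):
#
#     with TestFileContent(SMI_CONTENT) as a_file:
#         ds = ProteinSequenceDataset(
#             a_file.filename, padding=True, add_start_and_stop=True
#         )
#         token_indexes = ds[0]  # torch tensor of token indexes for 'EGK'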


class TestProteinSequenceDatasetEagerBackend(unittest.TestCase):
    """Testing ProteinSequence dataset with eager backend."""

    def setUp(self):
        self.backend = 'eager'
        print(f'backend is {self.backend}')
        self.smi_content = SMI_CONTENT
        self.smi_other_content = MORE_SMI_CONTENT
        self.smi_broken_content = BROKEN_SMI_CONTENT
        self.as_smi_content = AS_SMI_CONTENT
        self.as_more_smi_content = AS_MORE_SMI_CONTENT
        # would fail with FASTA_CONTENT_GENERIC
        self.fasta_content = FASTA_CONTENT_UNIPROT

    def test___len__smi(self) -> None:
        """Test __len__."""
        with TestFileContent(self.smi_content) as a_test_file:
            with TestFileContent(self.smi_other_content) as another_test_file:
                protein_sequence_dataset = ProteinSequenceDataset(
                    a_test_file.filename,
                    another_test_file.filename,
                    backend=self.backend,
                )
                self.assertEqual(len(protein_sequence_dataset), 8)

    def test___len__fasta(self) -> None:
        """Test __len__."""
        with TestFileContent(self.fasta_content) as a_test_file:
            protein_sequence_dataset = ProteinSequenceDataset(
                a_test_file.filename, filetype='.fasta', backend=self.backend
            )
            # eager only uniprot headers
            self.assertEqual(len(protein_sequence_dataset), 2)
            time.sleep(1)

    def test___getitem__(self) -> None:
        """Test __getitem__."""
        with TestFileContent(self.smi_content) as a_test_file:
            with TestFileContent(self.smi_other_content) as another_test_file:
                protein_sequence_dataset = ProteinSequenceDataset(
                    a_test_file.filename,
                    another_test_file.filename,
                    padding=True,
                    add_start_and_stop=True,
                    backend=self.backend,
                )
                pad_index = protein_sequence_dataset.protein_language.token_to_index[
                    '<PAD>'
                ]
                start_index = protein_sequence_dataset.protein_language.token_to_index[
                    '<START>'
                ]
                stop_index = protein_sequence_dataset.protein_language.token_to_index[
                    '<STOP>'
                ]
                e_index = protein_sequence_dataset.protein_language.token_to_index['E']
                g_index = protein_sequence_dataset.protein_language.token_to_index['G']
                k_index = protein_sequence_dataset.protein_language.token_to_index['K']
                n_index = protein_sequence_dataset.protein_language.token_to_index['N']
                c_index = protein_sequence_dataset.protein_language.token_to_index['C']
                s_index = protein_sequence_dataset.protein_language.token_to_index['S']
                self.assertListEqual(
                    protein_sequence_dataset[0].numpy().flatten().tolist(),
                    [
                        pad_index,
                        pad_index,
                        start_index,
                        e_index,
                        g_index,
                        k_index,
                        stop_index,
                    ],
                )
                self.assertListEqual(
                    protein_sequence_dataset[3].numpy().flatten().tolist(),
                    [
                        pad_index,
                        start_index,
                        n_index,
                        c_index,
                        c_index,
                        s_index,
                        stop_index,
                    ],
                )
                self.assertListEqual(
                    protein_sequence_dataset[7].numpy().flatten().tolist(),
                    [
                        pad_index,
                        pad_index,
                        pad_index,
                        pad_index,
                        start_index,
                        k_index,
                        stop_index,
                    ],
                )
                protein_sequence_dataset = ProteinSequenceDataset(
                    a_test_file.filename,
                    another_test_file.filename,
                    padding=False,
                    add_start_and_stop=False,
                    backend=self.backend,
                )
                self.assertListEqual(
                    protein_sequence_dataset[0].numpy().flatten().tolist(),
                    [e_index, g_index, k_index],
                )
                self.assertListEqual(
                    protein_sequence_dataset[3].numpy().flatten().tolist(),
                    [n_index, c_index, c_index, s_index],
                )
                self.assertListEqual(
                    protein_sequence_dataset[7].numpy().flatten().tolist(), [k_index]
                )
                # Test padding but no start and stop token
                protein_sequence_dataset = ProteinSequenceDataset(
                    a_test_file.filename,
                    another_test_file.filename,
                    padding=True,
                    add_start_and_stop=False,
                    backend=self.backend,
                )
                self.assertListEqual(
                    protein_sequence_dataset[0].numpy().flatten().tolist(),
                    [pad_index, pad_index, e_index, g_index, k_index],
                )
                self.assertListEqual(
                    protein_sequence_dataset[3].numpy().flatten().tolist(),
                    [pad_index, n_index, c_index, c_index, s_index],
                )
                self.assertListEqual(
                    protein_sequence_dataset[7].numpy().flatten().tolist(),
                    [pad_index, pad_index, pad_index, pad_index, k_index],
                )
                # Test augmentation / order reversion
                protein_sequence_dataset = ProteinSequenceDataset(
                    a_test_file.filename,
                    another_test_file.filename,
                    augment_by_revert=True,
                )
                random.seed(42)
                for reverted_sequence in ['EGK', 'KGE', 'KGE', 'KGE']:
                    token_indexes = (
                        protein_sequence_dataset[0].numpy().flatten().tolist()
                    )
                    sequence = protein_sequence_dataset.protein_language.token_indexes_to_sequence(
                        token_indexes
                    )
                    self.assertEqual(sequence, reverted_sequence)
                for reverted_sequence in ['S', 'S', 'S', 'S']:
                    token_indexes = (
                        protein_sequence_dataset[1].numpy().flatten().tolist()
                    )
                    sequence = protein_sequence_dataset.protein_language.token_indexes_to_sequence(
                        token_indexes
                    )
                    self.assertEqual(sequence, reverted_sequence)
                # Test UNIREP vocab
                protein_sequence_dataset = ProteinSequenceDataset(
                    a_test_file.filename,
                    another_test_file.filename,
                    amino_acid_dict='unirep',
                    padding=True,
                    add_start_and_stop=False,
                    backend=self.backend,
                )
                pad_index = protein_sequence_dataset.protein_language.token_to_index[
                    '<PAD>'
                ]
                e_index = protein_sequence_dataset.protein_language.token_to_index['E']
                g_index = protein_sequence_dataset.protein_language.token_to_index['G']
                k_index = protein_sequence_dataset.protein_language.token_to_index['K']
                n_index = protein_sequence_dataset.protein_language.token_to_index['N']
                c_index = protein_sequence_dataset.protein_language.token_to_index['C']
                s_index = protein_sequence_dataset.protein_language.token_to_index['S']
                self.assertListEqual(
                    protein_sequence_dataset[0].numpy().flatten().tolist(),
                    [pad_index, pad_index, e_index, g_index, k_index],
                )
                self.assertListEqual(
                    protein_sequence_dataset[3].numpy().flatten().tolist(),
                    [pad_index, n_index, c_index, c_index, s_index],
                )
                self.assertListEqual(
                    protein_sequence_dataset[7].numpy().flatten().tolist(),
                    [pad_index, pad_index, pad_index, pad_index, k_index],
                )
        # Test parsing of .fasta file
        with TestFileContent(self.fasta_content) as a_test_file:
            protein_sequence_dataset = ProteinSequenceDataset(
                a_test_file.filename,
                filetype='.fasta',
                add_start_and_stop=True,
                backend=self.backend,
            )
            a_tokenized_sequence = protein_sequence_dataset[1].tolist()
            # padded to length + start + stop
            self.assertEqual(len(a_tokenized_sequence), 206)
            self.assertEqual(sum(a_tokenized_sequence[:-123]), 0)
            time.sleep(1)
        # Test case with unknown token in dataset
        for iterate in [False, True]:
            with TestFileContent(self.smi_broken_content) as a_test_file:
                protein_sequence_dataset = ProteinSequenceDataset(
                    a_test_file.filename,
                    add_start_and_stop=False,
                    padding=False,
                    backend=self.backend,
                    iterate_dataset=iterate,
                )
                self.assertListEqual(
                    protein_sequence_dataset[2].tolist(),
                    [
                        protein_sequence_dataset.protein_language.token_to_index['S'],
                        protein_sequence_dataset.protein_language.token_to_index['C'],
                        protein_sequence_dataset.protein_language.token_to_index[
                            '<UNK>'
                        ],
                        protein_sequence_dataset.protein_language.token_to_index['C'],
                        protein_sequence_dataset.protein_language.token_to_index['N'],
                    ],
                )
        """
        With sequence augmentation strategies
        """
        # TODO: Test swapping, test noise, test custom file
        with TestFileContent(self.as_smi_content) as a_test_file:
            with TestFileContent(self.as_more_smi_content) as another_test_file:
                # No *actual* changes
                protein_sequence_dataset = ProteinSequenceDataset(
                    a_test_file.filename,
                    another_test_file.filename,
                    padding=False,
                    add_start_and_stop=False,
                    backend=self.backend,
                    sequence_augment={'discard_lowercase': True},
                )
                p_to_t = (
                    protein_sequence_dataset.protein_language.sequence_to_token_indexes
                )
                gts = [
                    p_to_t('LGQGTRTNVVKTMLAVMVTEYVEHGPVLVRNLSDV'),
                    p_to_t('LGKGTFGKVAKELLTLFVMEYANGGEFVVENMTDL'),
                ]
                self.assertListEqual(
                    protein_sequence_dataset[0].numpy().flatten().tolist(), gts[0]
                )
                self.assertListEqual(
                    protein_sequence_dataset[3].numpy().flatten().tolist(), gts[1]
                )
                # Converted to full sequence
                protein_sequence_dataset = ProteinSequenceDataset(
                    a_test_file.filename,
                    another_test_file.filename,
                    padding=False,
                    add_start_and_stop=False,
                    backend=self.backend,
                    sequence_augment={'discard_lowercase': False},
                )
                gts = [
                    p_to_t(
                        'ITQLSHLGQGTRTNVYEGRLRVEGSGDPEEGKMDDEDPLVPGRDRGQELRVVLKVLDPSHHDIALAFYETASLMSQVSHTHLAFVHGVCVRGPENIMVTEYVEHGPLDVWLRRERGHVPMAWKMVVAQQLASALSYLENKNLVHGNVCGRNILLARLGLAEGTSPFIKLSDPGVGLGALSREERVERIPWLAPECLPGGANSLSTAMDKWGFGATLLEICFDGEAPLQSRSPSEKEHFYQRQHRLPEPSCPQLATLTSQCLTYEPTQRPSFRTILRDLTR'
                    ),
                    p_to_t(
                        'FEYLKLLGKGTFGKVILVKEKATGRYYAMKILKKEVIVAKDEVAHTLTENRVLQNSRHPFLTALKYSFQTHDRLCFVMEYANGGELFFHLSRERVFSEDRARFYGAEIVSALDYLHSEKNVVYRDLKLENLMLDKDGHIKITDFGLCKEGIKDGATMKTFCGTPEYLAPEVLEDNDYGRAVDWWGLGVVMYEMMCGRLPFYNQDHEKLFELILMEEIRFPRTLGPEAKSLLSGLLKKDPKQRLGGGSEDAKEIMQHRFF'
                    ),
                ]
                self.assertListEqual(
                    protein_sequence_dataset[0].numpy().flatten().tolist(), gts[0]
                )
                self.assertListEqual(
                    protein_sequence_dataset[3].numpy().flatten().tolist(), gts[1]
                )
                # Flip substrings
                protein_sequence_dataset = ProteinSequenceDataset(
                    a_test_file.filename,
                    another_test_file.filename,
                    padding=False,
                    add_start_and_stop=False,
                    backend=self.backend,
                    sequence_augment={
                        'discard_lowercase': True,
                        'flip_substrings': 1.0,
                    },
                )
                gts = [
                    p_to_t('VNTRTGQGLVKTMALVPGHEVYETVMVVLNRLDSV'),
                    p_to_t('VKGFTGKGLAKELTLLEGGNAYEMVFFVVNEMDTL'),
                ]
                self.assertListEqual(
                    protein_sequence_dataset[0].numpy().flatten().tolist(), gts[0]
                )
                self.assertListEqual(
                    protein_sequence_dataset[3].numpy().flatten().tolist(), gts[1]
                )
                # Swap substrings
                protein_sequence_dataset = ProteinSequenceDataset(
                    a_test_file.filename,
                    another_test_file.filename,
                    padding=False,
                    add_start_and_stop=False,
                    backend=self.backend,
                    sequence_augment={
                        'discard_lowercase': True,
                        'swap_substrings': 1.0,
                    },
                )
                gts = [
                    p_to_t('VKTMLAVMVTEYVEHGPVLVRNLSDVLGQGTRTNV'),
                    p_to_t('AKELLTLFVMEYANGGEFVVENMTDLLGKGTFGKV'),
                ]
                self.assertListEqual(
                    protein_sequence_dataset[0].numpy().flatten().tolist(), gts[0]
                )
                self.assertListEqual(
                    protein_sequence_dataset[3].numpy().flatten().tolist(), gts[1]
                )
                # Add noise (on active site)
                protein_sequence_dataset = ProteinSequenceDataset(
                    a_test_file.filename,
                    another_test_file.filename,
                    padding=False,
                    add_start_and_stop=False,
                    backend=self.backend,
                    sequence_augment={
                        'discard_lowercase': True,
                        'swap_substrings': 1.0,
                        'noise': (0.5, 0.0),
                    },
                )
                gts = [
                    p_to_t('VKMMLQEPVPWCVGNMHKNCKNLSDVDGQGTRAQY'),
                    p_to_t('AGEHWTSFVMDFANGQKPCVCNMTDDDGHAHFVKV'),
                ]
                np.random.seed(42)
                self.assertListEqual(
                    protein_sequence_dataset[0].numpy().flatten().tolist(), gts[0]
                )
                np.random.seed(42)
                self.assertListEqual(
                    protein_sequence_dataset[3].numpy().flatten().tolist(), gts[1]
                )
                # Add noise (outside active site)
                protein_sequence_dataset = ProteinSequenceDataset(
                    a_test_file.filename,
                    another_test_file.filename,
                    padding=False,
                    add_start_and_stop=False,
                    backend=self.backend,
                    sequence_augment={
                        'discard_lowercase': False,
                        'noise': (0.0, 0.1),
                    },
                )
                gts = [
                    p_to_t(
                        'ITWLSQLGQGTRTNVYEGRLRVEGSGDIEEGKMDDEDKLVMGRDRGQELRFVLKVLDPSHHDIPLAFYETASLMSQVSHTHLAFVHGVCVPGIENIMVTEYVEHGPLDVWLRAERGHVPTAWKMVAAQQLASALSKLENKNLVHGNVCGRNILLARLGLAEITSPFIKLSDPGVQLGALSRWERVERIPWLAPECLPNGANSTSTAADKWGFGRTLLEICFDGEAPLQSRSPSEKEWFYQRYRRLPEPSPPQLATLTWQFLKYAPTIRPSFRTSLRDRQR'
                    ),
                    p_to_t(
                        'FEWLKQLGKGTFGKVILVKEKATGRYYAMKILKKEVIVAHDEVWHTLHENRVLQNSRHPFLTALKRSPQTHDRLCFVMEYANGGELFFHLSRERPFIEDRARFYGAYIVSALDYLASEKNVVTRDLKLENLMYDKDGHKKITDFGLCKEGIKDNATMKTFCGTPEYLAPEVLEDNDYQRAVDWWWLGVVMYEMMCGRLPFNNQDHTKLFALILMEERRFPRTLGPEAKSLLSGLLKKDPWQRLGYRSEDAKEPMQHRFF'
                    ),
                ]
                np.random.seed(42)
                self.assertListEqual(
                    protein_sequence_dataset[0].numpy().flatten().tolist(), gts[0]
                )
                np.random.seed(42)
                self.assertListEqual(
                    protein_sequence_dataset[3].numpy().flatten().tolist(), gts[1]
                )
                # Test passing a file
                path = files('pytoda.proteins.metadata').joinpath(
                    'kinase_activesite_alignment.smi'
                )
                protein_sequence_dataset = ProteinSequenceDataset(
                    a_test_file.filename,
                    another_test_file.filename,
                    padding=False,
                    add_start_and_stop=False,
                    backend=self.backend,
                    sequence_augment={
                        'alignment_path': path,
                        'discard_lowercase': True,
                        'flip_substrings': 1.0,
                    },
                )
                gts = [
                    p_to_t('VNTRTGQGLVKTMALVPGHEVYETVMVVLNRLDSV'),
                    p_to_t('VKGFTGKGLAKELTLLEGGNAYEMVFFVVNEMDTL'),
                ]
                self.assertListEqual(
                    protein_sequence_dataset[0].numpy().flatten().tolist(), gts[0]
                )
                self.assertListEqual(
                    protein_sequence_dataset[3].numpy().flatten().tolist(), gts[1]
                )

    def test_data_loader(self) -> None:
        """Test data_loader."""
        with TestFileContent(self.smi_content) as a_test_file:
            with TestFileContent(self.smi_other_content) as another_test_file:
                protein_sequence_dataset = ProteinSequenceDataset(
                    a_test_file.filename,
                    another_test_file.filename,
                    add_start_and_stop=False,
                    backend=self.backend,
                )
                data_loader = DataLoader(
                    protein_sequence_dataset, batch_size=4, shuffle=True
                )
                for batch_index, batch in enumerate(data_loader):
                    self.assertEqual(batch.shape, (4, 5))
                    if batch_index > 10:
                        break
                protein_sequence_dataset = ProteinSequenceDataset(
                    a_test_file.filename,
                    another_test_file.filename,
                    add_start_and_stop=True,
                    backend=self.backend,
                )
                data_loader = DataLoader(
                    protein_sequence_dataset, batch_size=4, shuffle=True
                )
                for batch_index, batch in enumerate(data_loader):
                    self.assertEqual(batch.shape, (4, 7))
                    if batch_index > 10:
                        break

    def _test_indexed(self, ds, keys, index):
        key = keys[index]
        positive_index = index % len(ds)
        # get_key (support for negative index?)
        self.assertEqual(key, ds.get_key(positive_index))
        self.assertEqual(key, ds.get_key(index))
        # get_index
        self.assertEqual(positive_index, ds.get_index(key))
        # get_item_from_key
        self.assertTrue(all(ds[index] == ds.get_item_from_key(key)))
        # keys
        self.assertSequenceEqual(keys, list(ds.keys()))
        # duplicate keys
        self.assertFalse(ds.has_duplicate_keys)

    def test_all_base_for_indexed_methods(self):
        with TestFileContent(self.smi_content) as a_test_file:
            with TestFileContent(self.smi_other_content) as another_test_file:
                protein_sequence_ds = ProteinSequenceDataset(
                    a_test_file.filename,
                    another_test_file.filename,
                    backend=self.backend,
                )
                protein_sequence_ds_0 = ProteinSequenceDataset(
                    a_test_file.filename, backend=self.backend
                )
                protein_sequence_ds_1 = ProteinSequenceDataset(
                    another_test_file.filename, backend=self.backend
                )
                all_smiles, all_keys = zip(
                    *(
                        pair.split('\t')
                        for pair in (
                            self.smi_content.split(os.linesep)
                            + self.smi_other_content.split(os.linesep)
                        )
                    )
                )
                for ds, keys in [
                    (protein_sequence_ds, all_keys),
                    (protein_sequence_ds_0, all_keys[:4]),
                    (protein_sequence_ds_1, all_keys[4:]),
                    # no transformation on
                    # concat delegation to _SmiLazyDataset/_SmiEagerDataset
                    (protein_sequence_ds_0 + protein_sequence_ds_1, all_keys),
                ]:
                    index = -1
                    self._test_indexed(ds, keys, index)
                # duplicate
                duplicate_ds = protein_sequence_ds_0 + protein_sequence_ds_0
                self.assertTrue(duplicate_ds.has_duplicate_keys)
        # ProteinSequenceDataset tests and raises
        with TestFileContent(self.smi_content) as a_test_file:
            with self.assertRaises(KeyError):
                protein_sequence_ds = ProteinSequenceDataset(
                    a_test_file.filename, a_test_file.filename, backend=self.backend
                )


class TestProteinSequenceDatasetLazyBackend(
    TestProteinSequenceDatasetEagerBackend
):  # noqa
    """Testing ProteinSequence dataset with lazy backend."""

    def setUp(self):
        self.backend = 'lazy'
        print(f'backend is {self.backend}')
        self.smi_content = SMI_CONTENT
        self.smi_other_content = MORE_SMI_CONTENT
        self.smi_broken_content = BROKEN_SMI_CONTENT
        self.fasta_content = FASTA_CONTENT_GENERIC
        self.as_smi_content = AS_SMI_CONTENT
        self.as_more_smi_content = AS_MORE_SMI_CONTENT

    def test___len__fasta(self) -> None:
        """Test __len__."""
        with TestFileContent(self.fasta_content) as a_test_file:
            protein_sequence_dataset = ProteinSequenceDataset(
                a_test_file.filename, filetype='.fasta', backend=self.backend
            )
            # generic sequences
            self.assertEqual(len(protein_sequence_dataset), 3)
            time.sleep(1)


if __name__ == '__main__':
    unittest.main()