"""Testing Protein Sequence dataset."""
import os
import random
import time
import unittest
import numpy as np
from importlib_resources import files
from torch.utils.data import DataLoader
from pytoda.datasets import ProteinSequenceDataset
from pytoda.tests.utils import TestFileContent
SMI_CONTENT = os.linesep.join(['EGK\tID3', 'S\tID1', 'FGAAV\tID2', 'NCCS\tID4'])
MORE_SMI_CONTENT = os.linesep.join(['KGE\tID5', 'K\tID6', 'SCCN\tID7', 'K\tID8'])
BROKEN_SMI_CONTENT = os.linesep.join(['KGE\tID5', 'K\tID6', 'SCfCN\tID7'])
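# 'SCfCN' contains a lowercase 'f', which is outside the amino-acid
# vocabulary and should therefore be tokenized as <UNK> (see the
# unknown-token test below).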
AS_SMI_CONTENT = os.linesep.join(
    [
        'LGQGTRTNVVKTMLAVMVTEYVEHGPVLVRNLSDV\tP29597',
        'LGKGTFGKVAKELLTLFVMEYANGGEFVVENMTDL\tP31751',
    ]
)
AS_MORE_SMI_CONTENT = os.linesep.join(
    [
        'IGEGSTGIVAKEMVVMVVMEFLEGGADVIDSLSDF\tQ9P286',
        'LGKGTFGKVAKELLTLFVMEYANGGEFVVENMTDL\tP31749',  # same AS, different protein
    ]
)
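# The AS_* fixtures hold kinase active-site sequences keyed by UniProt
# accession, presumably matching entries in the packaged
# kinase_activesite_alignment.smi used by the sequence_augment tests below.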
FASTA_CONTENT_UNIPROT = r""">sp|Q6GZX0|005R_FRG3G Uncharacterized protein 005R OS=Frog virus 3 (isolate Goorha) OX=654924 GN=FV3-005R PE=4 SV=1
MQNPLPEVMSPEHDKRTTTPMSKEANKFIRELDKKPGDLAVVSDFVKRNTGKRLPIGKRS
NLYVRICDLSGTIYMGETFILESWEELYLPEPTKMEVLGTLESCCGIPPFPEWIVMVGED
QCVYAYGDEEILLFAYSVKQLVEEGIQETGISYKYPDDISDVDEEVLQQDEEIQKIRKKT
REFVDKDAQEFQDFLNSLDASLLS
>sp|Q91G88|006L_IIV6 Putative KilA-N domain-containing protein 006L OS=Invertebrate iridescent virus 6 OX=176652 GN=IIV6-006L PE=3 SV=1
MDSLNEVCYEQIKGTFYKGLFGDFPLIVDKKTGCFNATKLCVLGGKRFVDWNKTLRSKKL
IQYYETRCDIKTESLLYEIKGDNNDEITKQITGTYLPKEFILDIASWISVEFYDKCNNII
""" # length 204, 120
FASTA_CONTENT_GENERIC = (
FASTA_CONTENT_UNIPROT
+ r""">generic_header eager upfp would concat to sequence above.
LLLLLLLLLLLLLLLL
"""
) # length 16
all_keys = ['ID3', 'ID1', 'ID2', 'ID4', 'Q6GZX0', 'Q91G88']
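# The .smi fixtures are tab-separated 'SEQUENCE\tID' lines; the IDs double
# as keys for the indexed-access methods tested below.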


class TestProteinSequenceDatasetEagerBackend(unittest.TestCase):
    """Testing ProteinSequence dataset with eager backend."""

    def setUp(self):
self.backend = 'eager'
print(f'backend is {self.backend}')
self.smi_content = SMI_CONTENT
self.smi_other_content = MORE_SMI_CONTENT
self.smi_broken_content = BROKEN_SMI_CONTENT
self.as_smi_content = AS_SMI_CONTENT
self.as_more_smi_content = AS_MORE_SMI_CONTENT
        # FASTA_CONTENT_GENERIC would fail here: the eager parser (upfp)
        # only understands UniProt headers
self.fasta_content = FASTA_CONTENT_UNIPROT

    def test___len__smi(self) -> None:
"""Test __len__."""
with TestFileContent(self.smi_content) as a_test_file:
with TestFileContent(self.smi_other_content) as another_test_file:
protein_sequence_dataset = ProteinSequenceDataset(
a_test_file.filename,
another_test_file.filename,
backend=self.backend,
)
self.assertEqual(len(protein_sequence_dataset), 8)

    def test___len__fasta(self) -> None:
"""Test __len__."""
with TestFileContent(self.fasta_content) as a_test_file:
protein_sequence_dataset = ProteinSequenceDataset(
a_test_file.filename, filetype='.fasta', backend=self.backend
)
            # the eager backend only parses UniProt headers,
            # so just the two UniProt records are counted
self.assertEqual(len(protein_sequence_dataset), 2)
time.sleep(1)

    def test___getitem__(self) -> None:
"""Test __getitem__."""
with TestFileContent(self.smi_content) as a_test_file:
with TestFileContent(self.smi_other_content) as another_test_file:
protein_sequence_dataset = ProteinSequenceDataset(
a_test_file.filename,
another_test_file.filename,
padding=True,
add_start_and_stop=True,
backend=self.backend,
)
pad_index = protein_sequence_dataset.protein_language.token_to_index[
'<PAD>'
]
start_index = protein_sequence_dataset.protein_language.token_to_index[
'<START>'
]
stop_index = protein_sequence_dataset.protein_language.token_to_index[
'<STOP>'
]
e_index = protein_sequence_dataset.protein_language.token_to_index['E']
g_index = protein_sequence_dataset.protein_language.token_to_index['G']
k_index = protein_sequence_dataset.protein_language.token_to_index['K']
n_index = protein_sequence_dataset.protein_language.token_to_index['N']
c_index = protein_sequence_dataset.protein_language.token_to_index['C']
s_index = protein_sequence_dataset.protein_language.token_to_index['S']
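                # Sequences are left-padded to the longest entry across both
                # files ('FGAAV', 5 tokens), plus <START>/<STOP> when
                # requested, giving length 7 here.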
self.assertListEqual(
protein_sequence_dataset[0].numpy().flatten().tolist(),
[
pad_index,
pad_index,
start_index,
e_index,
g_index,
k_index,
stop_index,
],
)
self.assertListEqual(
protein_sequence_dataset[3].numpy().flatten().tolist(),
[
pad_index,
start_index,
n_index,
c_index,
c_index,
s_index,
stop_index,
],
)
self.assertListEqual(
protein_sequence_dataset[7].numpy().flatten().tolist(),
[
pad_index,
pad_index,
pad_index,
pad_index,
start_index,
k_index,
stop_index,
],
)
protein_sequence_dataset = ProteinSequenceDataset(
a_test_file.filename,
another_test_file.filename,
padding=False,
add_start_and_stop=False,
backend=self.backend,
)
self.assertListEqual(
protein_sequence_dataset[0].numpy().flatten().tolist(),
[e_index, g_index, k_index],
)
self.assertListEqual(
protein_sequence_dataset[3].numpy().flatten().tolist(),
[n_index, c_index, c_index, s_index],
)
self.assertListEqual(
protein_sequence_dataset[7].numpy().flatten().tolist(), [k_index]
)
# Test padding but no start and stop token
protein_sequence_dataset = ProteinSequenceDataset(
a_test_file.filename,
another_test_file.filename,
padding=True,
add_start_and_stop=False,
backend=self.backend,
)
self.assertListEqual(
protein_sequence_dataset[0].numpy().flatten().tolist(),
[pad_index, pad_index, e_index, g_index, k_index],
)
self.assertListEqual(
protein_sequence_dataset[3].numpy().flatten().tolist(),
[pad_index, n_index, c_index, c_index, s_index],
)
self.assertListEqual(
protein_sequence_dataset[7].numpy().flatten().tolist(),
[pad_index, pad_index, pad_index, pad_index, k_index],
)
# Test augmentation / order reversion
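                # With augment_by_revert=True each __getitem__ call randomly
                # returns the sequence either as-is or reversed; seeding
                # `random` makes the sequence of draws below reproducible.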
protein_sequence_dataset = ProteinSequenceDataset(
a_test_file.filename,
another_test_file.filename,
augment_by_revert=True,
)
random.seed(42)
for reverted_sequence in ['EGK', 'KGE', 'KGE', 'KGE']:
token_indexes = (
protein_sequence_dataset[0].numpy().flatten().tolist()
)
sequence = protein_sequence_dataset.protein_language.token_indexes_to_sequence(
token_indexes
)
self.assertEqual(sequence, reverted_sequence)
for reverted_sequence in ['S', 'S', 'S', 'S']:
token_indexes = (
protein_sequence_dataset[1].numpy().flatten().tolist()
)
sequence = protein_sequence_dataset.protein_language.token_indexes_to_sequence(
token_indexes
)
self.assertEqual(sequence, reverted_sequence)
# Test UNIREP vocab
protein_sequence_dataset = ProteinSequenceDataset(
a_test_file.filename,
another_test_file.filename,
amino_acid_dict='unirep',
padding=True,
add_start_and_stop=False,
backend=self.backend,
)
pad_index = protein_sequence_dataset.protein_language.token_to_index[
'<PAD>'
]
e_index = protein_sequence_dataset.protein_language.token_to_index['E']
g_index = protein_sequence_dataset.protein_language.token_to_index['G']
k_index = protein_sequence_dataset.protein_language.token_to_index['K']
n_index = protein_sequence_dataset.protein_language.token_to_index['N']
c_index = protein_sequence_dataset.protein_language.token_to_index['C']
s_index = protein_sequence_dataset.protein_language.token_to_index['S']
self.assertListEqual(
protein_sequence_dataset[0].numpy().flatten().tolist(),
[pad_index, pad_index, e_index, g_index, k_index],
)
self.assertListEqual(
protein_sequence_dataset[3].numpy().flatten().tolist(),
[pad_index, n_index, c_index, c_index, s_index],
)
self.assertListEqual(
protein_sequence_dataset[7].numpy().flatten().tolist(),
[pad_index, pad_index, pad_index, pad_index, k_index],
)
# Test parsing of .fasta file
with TestFileContent(self.fasta_content) as a_test_file:
protein_sequence_dataset = ProteinSequenceDataset(
a_test_file.filename,
filetype='.fasta',
add_start_and_stop=True,
backend=self.backend,
)
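            # 204 residues + <START> + <STOP> = 206 tokens; item 1 holds the
            # 120-residue record, so its leading positions are all <PAD>
            # (index 0, hence the zero sum below).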
a_tokenized_sequence = protein_sequence_dataset[1].tolist()
self.assertEqual(len(a_tokenized_sequence), 206)
# padded to length + start + stop
self.assertEqual(sum(a_tokenized_sequence[:-123]), 0)
time.sleep(1)
# Test case with unknown token in dataset
for iterate in [False, True]:
with TestFileContent(self.smi_broken_content) as a_test_file:
protein_sequence_dataset = ProteinSequenceDataset(
a_test_file.filename,
add_start_and_stop=False,
padding=False,
backend=self.backend,
iterate_dataset=iterate,
)
self.assertListEqual(
protein_sequence_dataset[2].tolist(),
[
protein_sequence_dataset.protein_language.token_to_index['S'],
protein_sequence_dataset.protein_language.token_to_index['C'],
protein_sequence_dataset.protein_language.token_to_index[
'<UNK>'
],
protein_sequence_dataset.protein_language.token_to_index['C'],
protein_sequence_dataset.protein_language.token_to_index['N'],
],
)
"""
With sequence augmentation strategies
"""
# TODO: Test swapping, test noise, test custom file
with TestFileContent(self.as_smi_content) as a_test_file:
with TestFileContent(self.as_more_smi_content) as another_test_file:
                # discard_lowercase=True without further augmentation leaves
                # the active-site sequences unchanged
protein_sequence_dataset = ProteinSequenceDataset(
a_test_file.filename,
another_test_file.filename,
padding=False,
add_start_and_stop=False,
backend=self.backend,
sequence_augment={'discard_lowercase': True},
)
p_to_t = (
protein_sequence_dataset.protein_language.sequence_to_token_indexes
)
gts = [
p_to_t('LGQGTRTNVVKTMLAVMVTEYVEHGPVLVRNLSDV'),
p_to_t('LGKGTFGKVAKELLTLFVMEYANGGEFVVENMTDL'),
]
self.assertListEqual(
protein_sequence_dataset[0].numpy().flatten().tolist(), gts[0]
)
self.assertListEqual(
protein_sequence_dataset[3].numpy().flatten().tolist(), gts[1]
)
# Converted to full sequence
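                # With discard_lowercase=False the lowercase residues from the
                # alignment are kept as well, so each item appears to be the
                # uppercased full sequence rather than the active site alone.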
protein_sequence_dataset = ProteinSequenceDataset(
a_test_file.filename,
another_test_file.filename,
padding=False,
add_start_and_stop=False,
backend=self.backend,
sequence_augment={'discard_lowercase': False},
)
gts = [
p_to_t(
'ITQLSHLGQGTRTNVYEGRLRVEGSGDPEEGKMDDEDPLVPGRDRGQELRVVLKVLDPSHHDIALAFYETASLMSQVSHTHLAFVHGVCVRGPENIMVTEYVEHGPLDVWLRRERGHVPMAWKMVVAQQLASALSYLENKNLVHGNVCGRNILLARLGLAEGTSPFIKLSDPGVGLGALSREERVERIPWLAPECLPGGANSLSTAMDKWGFGATLLEICFDGEAPLQSRSPSEKEHFYQRQHRLPEPSCPQLATLTSQCLTYEPTQRPSFRTILRDLTR'
),
p_to_t(
'FEYLKLLGKGTFGKVILVKEKATGRYYAMKILKKEVIVAKDEVAHTLTENRVLQNSRHPFLTALKYSFQTHDRLCFVMEYANGGELFFHLSRERVFSEDRARFYGAEIVSALDYLHSEKNVVYRDLKLENLMLDKDGHIKITDFGLCKEGIKDGATMKTFCGTPEYLAPEVLEDNDYGRAVDWWGLGVVMYEMMCGRLPFYNQDHEKLFELILMEEIRFPRTLGPEAKSLLSGLLKKDPKQRLGGGSEDAKEIMQHRFF'
),
]
self.assertListEqual(
protein_sequence_dataset[0].numpy().flatten().tolist(), gts[0]
)
self.assertListEqual(
protein_sequence_dataset[3].numpy().flatten().tolist(), gts[1]
)
# Flip substrings
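                # With flip probability 1.0 each aligned substring is reversed
                # in place, e.g. 'LGQGTRTNV' -> 'VNTRTGQGL'.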
protein_sequence_dataset = ProteinSequenceDataset(
a_test_file.filename,
another_test_file.filename,
padding=False,
add_start_and_stop=False,
backend=self.backend,
sequence_augment={
'discard_lowercase': True,
'flip_substrings': 1.0,
},
)
gts = [
p_to_t('VNTRTGQGLVKTMALVPGHEVYETVMVVLNRLDSV'),
p_to_t('VKGFTGKGLAKELTLLEGGNAYEMVFFVVNEMDTL'),
]
self.assertListEqual(
protein_sequence_dataset[0].numpy().flatten().tolist(), gts[0]
)
self.assertListEqual(
protein_sequence_dataset[3].numpy().flatten().tolist(), gts[1]
)
# Swap substrings
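                # With swap probability 1.0 aligned substrings exchange
                # positions; here the leading block 'LGQGTRTNV' ends up at
                # the tail of the sequence.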
protein_sequence_dataset = ProteinSequenceDataset(
a_test_file.filename,
another_test_file.filename,
padding=False,
add_start_and_stop=False,
backend=self.backend,
sequence_augment={
'discard_lowercase': True,
'swap_substrings': 1.0,
},
)
gts = [
p_to_t('VKTMLAVMVTEYVEHGPVLVRNLSDVLGQGTRTNV'),
p_to_t('AKELLTLFVMEYANGGEFVVENMTDLLGKGTFGKV'),
]
self.assertListEqual(
protein_sequence_dataset[0].numpy().flatten().tolist(), gts[0]
)
self.assertListEqual(
protein_sequence_dataset[3].numpy().flatten().tolist(), gts[1]
)
# Add noise (on active site)
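                # NOTE: the two floats read as (mutation rate inside the
                # active site, rate outside it); this interpretation is
                # inferred from this block and the next, not from API docs.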
protein_sequence_dataset = ProteinSequenceDataset(
a_test_file.filename,
another_test_file.filename,
padding=False,
add_start_and_stop=False,
backend=self.backend,
sequence_augment={
'discard_lowercase': True,
'swap_substrings': 1.0,
'noise': (0.5, 0.0),
},
)
gts = [
p_to_t('VKMMLQEPVPWCVGNMHKNCKNLSDVDGQGTRAQY'),
p_to_t('AGEHWTSFVMDFANGQKPCVCNMTDDDGHAHFVKV'),
]
np.random.seed(42)
self.assertListEqual(
protein_sequence_dataset[0].numpy().flatten().tolist(), gts[0]
)
np.random.seed(42)
self.assertListEqual(
protein_sequence_dataset[3].numpy().flatten().tolist(), gts[1]
)
# Add noise (outside active site)
protein_sequence_dataset = ProteinSequenceDataset(
a_test_file.filename,
another_test_file.filename,
padding=False,
add_start_and_stop=False,
backend=self.backend,
sequence_augment={
'discard_lowercase': False,
'noise': (0.0, 0.1),
},
)
gts = [
p_to_t(
'ITWLSQLGQGTRTNVYEGRLRVEGSGDIEEGKMDDEDKLVMGRDRGQELRFVLKVLDPSHHDIPLAFYETASLMSQVSHTHLAFVHGVCVPGIENIMVTEYVEHGPLDVWLRAERGHVPTAWKMVAAQQLASALSKLENKNLVHGNVCGRNILLARLGLAEITSPFIKLSDPGVQLGALSRWERVERIPWLAPECLPNGANSTSTAADKWGFGRTLLEICFDGEAPLQSRSPSEKEWFYQRYRRLPEPSPPQLATLTWQFLKYAPTIRPSFRTSLRDRQR'
),
p_to_t(
'FEWLKQLGKGTFGKVILVKEKATGRYYAMKILKKEVIVAHDEVWHTLHENRVLQNSRHPFLTALKRSPQTHDRLCFVMEYANGGELFFHLSRERPFIEDRARFYGAYIVSALDYLASEKNVVTRDLKLENLMYDKDGHKKITDFGLCKEGIKDNATMKTFCGTPEYLAPEVLEDNDYQRAVDWWWLGVVMYEMMCGRLPFNNQDHTKLFALILMEERRFPRTLGPEAKSLLSGLLKKDPWQRLGYRSEDAKEPMQHRFF'
),
]
np.random.seed(42)
self.assertListEqual(
protein_sequence_dataset[0].numpy().flatten().tolist(), gts[0]
)
np.random.seed(42)
self.assertListEqual(
protein_sequence_dataset[3].numpy().flatten().tolist(), gts[1]
)
                # Test passing an explicit alignment file
path = files('pytoda.proteins.metadata').joinpath(
'kinase_activesite_alignment.smi'
)
protein_sequence_dataset = ProteinSequenceDataset(
a_test_file.filename,
another_test_file.filename,
padding=False,
add_start_and_stop=False,
backend=self.backend,
sequence_augment={
'alignment_path': path,
'discard_lowercase': True,
'flip_substrings': 1.0,
},
)
gts = [
p_to_t('VNTRTGQGLVKTMALVPGHEVYETVMVVLNRLDSV'),
p_to_t('VKGFTGKGLAKELTLLEGGNAYEMVFFVVNEMDTL'),
]
self.assertListEqual(
protein_sequence_dataset[0].numpy().flatten().tolist(), gts[0]
)
self.assertListEqual(
protein_sequence_dataset[3].numpy().flatten().tolist(), gts[1]
)

    def test_data_loader(self) -> None:
"""Test data_loader."""
with TestFileContent(self.smi_content) as a_test_file:
with TestFileContent(self.smi_other_content) as another_test_file:
protein_sequence_dataset = ProteinSequenceDataset(
a_test_file.filename,
another_test_file.filename,
add_start_and_stop=False,
backend=self.backend,
)
data_loader = DataLoader(
protein_sequence_dataset, batch_size=4, shuffle=True
)
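                # Longest sequence across both files is 'FGAAV' (5 tokens),
                # so every padded batch has shape (batch_size, 5).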
for batch_index, batch in enumerate(data_loader):
self.assertEqual(batch.shape, (4, 5))
if batch_index > 10:
break
protein_sequence_dataset = ProteinSequenceDataset(
a_test_file.filename,
another_test_file.filename,
add_start_and_stop=True,
backend=self.backend,
)
data_loader = DataLoader(
protein_sequence_dataset, batch_size=4, shuffle=True
)
for batch_index, batch in enumerate(data_loader):
self.assertEqual(batch.shape, (4, 7))
if batch_index > 10:
break

    def _test_indexed(self, ds, keys, index):
key = keys[index]
positive_index = index % len(ds)
# get_key (support for negative index?)
self.assertEqual(key, ds.get_key(positive_index))
self.assertEqual(key, ds.get_key(index))
# get_index
self.assertEqual(positive_index, ds.get_index(key))
# get_item_from_key
self.assertTrue(all(ds[index] == ds.get_item_from_key(key)))
# keys
self.assertSequenceEqual(keys, list(ds.keys()))
# duplicate keys
self.assertFalse(ds.has_duplicate_keys)

    def test_all_base_for_indexed_methods(self):
with TestFileContent(self.smi_content) as a_test_file:
with TestFileContent(self.smi_other_content) as another_test_file:
protein_sequence_ds = ProteinSequenceDataset(
a_test_file.filename,
another_test_file.filename,
backend=self.backend,
)
protein_sequence_ds_0 = ProteinSequenceDataset(
a_test_file.filename, backend=self.backend
)
protein_sequence_ds_1 = ProteinSequenceDataset(
another_test_file.filename, backend=self.backend
)
all_smiles, all_keys = zip(
*(
pair.split('\t')
for pair in (
self.smi_content.split(os.linesep)
+ self.smi_other_content.split(os.linesep)
)
)
)
for ds, keys in [
(protein_sequence_ds, all_keys),
(protein_sequence_ds_0, all_keys[:4]),
(protein_sequence_ds_1, all_keys[4:]),
# no transformation on
# concat delegation to _SmiLazyDataset/_SmiEagerDataset
(protein_sequence_ds_0 + protein_sequence_ds_1, all_keys),
]:
index = -1
self._test_indexed(ds, keys, index)
# duplicate
duplicate_ds = protein_sequence_ds_0 + protein_sequence_ds_0
self.assertTrue(duplicate_ds.has_duplicate_keys)
        # ProteinSequenceDataset checks for duplicate keys and raises
with TestFileContent(self.smi_content) as a_test_file:
with self.assertRaises(KeyError):
protein_sequence_ds = ProteinSequenceDataset(
a_test_file.filename, a_test_file.filename, backend=self.backend
)


class TestProteinSequenceDatasetLazyBackend(
    TestProteinSequenceDatasetEagerBackend
):  # noqa
    """Testing ProteinSequence dataset with lazy backend."""

    def setUp(self):
self.backend = 'lazy'
print(f'backend is {self.backend}')
self.smi_content = SMI_CONTENT
self.smi_other_content = MORE_SMI_CONTENT
self.smi_broken_content = BROKEN_SMI_CONTENT
self.fasta_content = FASTA_CONTENT_GENERIC
self.as_smi_content = AS_SMI_CONTENT
self.as_more_smi_content = AS_MORE_SMI_CONTENT

    def test___len__fasta(self) -> None:
"""Test __len__."""
with TestFileContent(self.fasta_content) as a_test_file:
protein_sequence_dataset = ProteinSequenceDataset(
a_test_file.filename, filetype='.fasta', backend=self.backend
)
            # the lazy backend also counts the generic-header record
self.assertEqual(len(protein_sequence_dataset), 3)
time.sleep(1)


if __name__ == '__main__':
    unittest.main()