"""Testing AnnotatedDataset dataset with eager backend."""
import os
import unittest

import numpy as np

from pytoda.datasets import AnnotatedDataset, SMILESTokenizerDataset, indexed, keyed
from pytoda.tests.utils import TestFileContent

# must contain all keys present in the annotations
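# entries are tab-separated SMILES / ChEMBL ID pairs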
SMILES_CONTENT = os.linesep.join(
    ['CCO	CHEMBL545', 'C	CHEMBL17564', 'CO	CHEMBL14688', 'NCCS	CHEMBL602']
)
ANNOTATED_CONTENT = os.linesep.join(
    [
        'label_0,label_1,annotation_index',
        '2.3,3.4,CHEMBL545',
        '4.5,5.6,CHEMBL17564',
        '6.7,7.8,CHEMBL602',
    ]
)
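# note: CHEMBL14688 appears only in SMILES_CONTENT and has no annotation, so the
# annotated dataset holds 3 samples while the SMILES dataset holds 4; this is why
# annotation index 2 corresponds to SMILES index 3 in the tests below.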


class TestAnnotatedDataset(unittest.TestCase):
    """Testing annotated dataset."""

    def setUp(self):
        self.smiles_content = SMILES_CONTENT
        self.annotated_content = ANNOTATED_CONTENT

    def test___getitem__(self) -> None:
        """Test __getitem__."""
        with TestFileContent(self.smiles_content) as smiles_file:
            with TestFileContent(self.annotated_content) as annotation_file:
                smiles_dataset = SMILESTokenizerDataset(
                    smiles_file.filename, add_start_and_stop=True, backend='eager'
                )
                annotated_dataset = AnnotatedDataset(
                    annotation_file.filename, dataset=smiles_dataset
                )

                pad_index = smiles_dataset.smiles_language.padding_index
                start_index = smiles_dataset.smiles_language.start_index
                stop_index = smiles_dataset.smiles_language.stop_index
                c_index = smiles_dataset.smiles_language.token_to_index['C']
                o_index = smiles_dataset.smiles_language.token_to_index['O']
                n_index = smiles_dataset.smiles_language.token_to_index['N']
                s_index = smiles_dataset.smiles_language.token_to_index['S']

                # test first sample
                smiles_tokens, labels = annotated_dataset[0]
                self.assertEqual(
                    smiles_tokens.numpy().flatten().tolist(),
                    [pad_index, start_index, c_index, c_index, o_index, stop_index],
                )
                self.assertTrue(
                    np.allclose(labels.numpy().flatten().tolist(), [2.3, 3.4])
                )
                # test last sample
                smiles_tokens, labels = annotated_dataset[2]
                self.assertEqual(
                    smiles_tokens.numpy().flatten().tolist(),
                    [start_index, n_index, c_index, c_index, s_index, stop_index],
                )
                self.assertTrue(
                    np.allclose(labels.numpy().flatten().tolist(), [6.7, 7.8])
                )

    def test___getitem___from_indexed_annotation(self) -> None:
        """Test __getitem__ with index in the annotation file."""
        with TestFileContent(self.smiles_content) as smiles_file:
            with TestFileContent(self.annotated_content) as annotation_file:
                smiles_dataset = SMILESTokenizerDataset(
                    smiles_file.filename, add_start_and_stop=True, backend='eager'
                )
                annotated_dataset = AnnotatedDataset(
                    annotation_file.filename, dataset=smiles_dataset
                )

                pad_index = smiles_dataset.smiles_language.padding_index
                start_index = smiles_dataset.smiles_language.start_index
                stop_index = smiles_dataset.smiles_language.stop_index
                c_index = smiles_dataset.smiles_language.token_to_index['C']
                o_index = smiles_dataset.smiles_language.token_to_index['O']
                n_index = smiles_dataset.smiles_language.token_to_index['N']
                s_index = smiles_dataset.smiles_language.token_to_index['S']

                # test first sample
                smiles_tokens, labels = annotated_dataset[0]
                self.assertEqual(
                    smiles_tokens.numpy().flatten().tolist(),
                    [pad_index, start_index, c_index, c_index, o_index, stop_index],
                )
                self.assertTrue(
                    np.allclose(labels.numpy().flatten().tolist(), [2.3, 3.4])
                )
                # test last sample
                smiles_tokens, labels = annotated_dataset[2]
                self.assertEqual(
                    smiles_tokens.numpy().flatten().tolist(),
                    [start_index, n_index, c_index, c_index, s_index, stop_index],
                )
                self.assertTrue(
                    np.allclose(labels.numpy().flatten().tolist(), [6.7, 7.8])
                )

    def _test_indexed(self, ds, keys, index):
        key = keys[index]
        positive_index = index % len(ds)
        # get_key (support for negative index?)
        self.assertEqual(key, ds.get_key(positive_index))
        self.assertEqual(key, ds.get_key(index))
        # get_index
        self.assertEqual(positive_index, ds.get_index(key))
        # get_item_from_key returning labels as well
        for from_index, from_key in zip(ds[index], ds.get_item_from_key(key)):
            self.assertTrue(all(from_index == from_key))
        # keys
        self.assertSequenceEqual(keys, list(ds.keys()))
        # duplicate keys
        self.assertFalse(ds.has_duplicate_keys)

    def test_all_base_for_indexed_methods(self):
        with TestFileContent(self.smiles_content) as smiles_file:
            with TestFileContent(self.annotated_content) as annotation_file:
                smiles_dataset = SMILESTokenizerDataset(
                    smiles_file.filename, add_start_and_stop=True, backend='eager'
                )
                annotated_dataset = AnnotatedDataset(
                    annotation_file.filename,
                    dataset=smiles_dataset,
                    index_col=0,
                    label_columns=['label_1'],
                )
                duplicate_ds = AnnotatedDataset(
                    annotation_file.filename,
                    dataset=smiles_dataset + smiles_dataset,
                )

                all_keys = [
                    row.split(',')[-1]
                    for row in self.annotated_content.split(os.linesep)[1:]
                ]
                for ds, keys in [
                    (annotated_dataset, all_keys),
                ]:
                    index = -1
                    self._test_indexed(ds, keys, index)

                # duplicates in datasource can be checked directly
                self.assertTrue(duplicate_ds.datasource.has_duplicate_keys)
                # DataFrame is the dataset
                self.assertFalse(duplicate_ds.has_duplicate_keys)


class TestChangeIndexingReturn(unittest.TestCase):
    """Testing indexed/keyed wrappers changing the items returned on indexing."""

    def setUp(self):
        self.smiles_content = SMILES_CONTENT
        self.annotated_content = ANNOTATED_CONTENT

    def test_return_integer_index(self) -> None:
        """Test __getitem__ with index in dataset."""
        with TestFileContent(self.smiles_content) as smiles_file:
            with TestFileContent(self.annotated_content) as annotation_file:
                smiles_dataset = SMILESTokenizerDataset(smiles_file.filename)
                # default
                default_annotated_dataset = AnnotatedDataset(
                    annotation_file.filename, dataset=smiles_dataset
                )
                # outer
                indexed_annotated_dataset = indexed(default_annotated_dataset)
                # inner
                annotated_dataset = AnnotatedDataset(
                    annotation_file.filename, dataset=indexed(smiles_dataset)
                )

                # default
                # relevant to check that `smiles_dataset` was not mutated from
                # the `indexed(smiles_dataset)` call
                smiles_tokens, labels = default_annotated_dataset[2]
                self.check_CHEMBL602(smiles_tokens, labels)

                # outer __getitem__
                (smiles_tokens, labels), sample_index = indexed_annotated_dataset[2]
                self.assertEqual(sample_index, 2)
                self.check_CHEMBL602(smiles_tokens, labels)
                # outer get_item_from_key
                (
                    (smiles_tokens, labels),
                    sample_index,
                ) = indexed_annotated_dataset.get_item_from_key('CHEMBL602')
                self.assertEqual(sample_index, 2)
                self.check_CHEMBL602(smiles_tokens, labels)

                # inner __getitem__
                (smiles_tokens, sample_index), labels = annotated_dataset[0]
                self.assertEqual(sample_index, 0)
                # inner __getitem__ with different index in smiles_dataset
                (smiles_tokens, sample_index), labels = annotated_dataset[2]
                self.assertEqual(sample_index, 3)
                self.check_CHEMBL602(smiles_tokens, labels)
                # inner get_item_from_key
                (smiles_tokens, sample_index), labels = annotated_dataset.get_item_from_key(
                    'CHEMBL602'
                )
                self.assertEqual(sample_index, 3)
                self.check_CHEMBL602(smiles_tokens, labels)

    def test_return_key(self) -> None:
        """Test __getitem__ with key in dataset."""
        with TestFileContent(self.smiles_content) as smiles_file:
            with TestFileContent(self.annotated_content) as annotation_file:
                smiles_dataset = SMILESTokenizerDataset(smiles_file.filename)
                # default
                default_annotated_dataset = AnnotatedDataset(
                    annotation_file.filename, dataset=smiles_dataset
                )
                # outer
                keyed_annotated_dataset = keyed(default_annotated_dataset)
                # inner
                annotated_dataset = AnnotatedDataset(
                    annotation_file.filename, dataset=keyed(smiles_dataset)
                )

                # default
                # relevant to check that `smiles_dataset` was not mutated from
                # the `keyed(smiles_dataset)` call
                smiles_tokens, labels = default_annotated_dataset[2]
                self.check_CHEMBL602(smiles_tokens, labels)

                # outer __getitem__
                (smiles_tokens, labels), sample_key = keyed_annotated_dataset[2]
                self.assertEqual(sample_key, 'CHEMBL602')
                self.check_CHEMBL602(smiles_tokens, labels)
                # outer get_item_from_key
                (smiles_tokens, labels), sample_key = keyed_annotated_dataset.get_item_from_key(
                    'CHEMBL602'
                )
                self.assertEqual(sample_key, 'CHEMBL602')
                self.check_CHEMBL602(smiles_tokens, labels)

                # inner __getitem__
                (smiles_tokens, sample_key), labels = annotated_dataset[0]
                self.assertEqual(sample_key, 'CHEMBL545')
                # inner __getitem__ with different index in smiles_dataset
                (smiles_tokens, sample_key), labels = annotated_dataset[2]
                self.assertEqual(sample_key, 'CHEMBL602')
                self.check_CHEMBL602(smiles_tokens, labels)
                # inner get_item_from_key
                (smiles_tokens, sample_key), labels = annotated_dataset.get_item_from_key(
                    'CHEMBL602'
                )
                self.assertEqual(sample_key, 'CHEMBL602')
                self.check_CHEMBL602(smiles_tokens, labels)

    def test_return_key_index_stacked(self) -> None:
        """Test __getitem__ with stacked indexed and keyed wrappers."""
        with TestFileContent(self.smiles_content) as smiles_file:
            with TestFileContent(self.annotated_content) as annotation_file:
                smiles_dataset = keyed(
                    indexed(
                        SMILESTokenizerDataset(
                            smiles_file.filename,
                        )
                    )
                )
                annotated_dataset = indexed(
                    keyed(
                        AnnotatedDataset(
                            annotation_file.filename,
                            dataset=smiles_dataset,
                        )
                    )
                )

                (smiles_tokens, smiles_index), smiles_key = smiles_dataset[3]
                self.assertEqual(smiles_key, 'CHEMBL602')
                self.assertEqual(smiles_index, 3)
                self.check_CHEMBL602(smiles_tokens)

                (smiles_tokens, smiles_index), smiles_key = smiles_dataset.get_item_from_key(
                    'CHEMBL602'
                )
                self.assertEqual(smiles_key, 'CHEMBL602')
                self.assertEqual(smiles_index, 3)
                self.check_CHEMBL602(smiles_tokens)

                (
                    (
                        (((smiles_tokens, smiles_index), smiles_key), labels),  # inner
                        annotation_key,
                    ),
                    annotation_index,
                ) = annotated_dataset[2]
                self.assertEqual(smiles_key, 'CHEMBL602')
                self.assertEqual(smiles_index, 3)
                self.assertEqual(annotation_index, 2)
                self.assertEqual(annotation_key, 'CHEMBL602')

                (
                    (
                        (((smiles_tokens, smiles_index), smiles_key), labels),  # inner
                        annotation_key,
                    ),
                    annotation_index,
                ) = annotated_dataset.get_item_from_key('CHEMBL602')
                self.assertEqual(smiles_key, 'CHEMBL602')
                self.assertEqual(smiles_index, 3)
                self.assertEqual(annotation_index, 2)
                self.assertEqual(annotation_key, 'CHEMBL602')

    def check_CHEMBL602(self, smiles_tokens, labels=None):
        """Check that indexing results lack unwanted index/key from hidden calls."""
        self.assertEqual(len(smiles_tokens.numpy().flatten().tolist()), 4)
        if labels is not None:
            self.assertEqual(len(labels.numpy().flatten().tolist()), 2)


if __name__ == '__main__':
    unittest.main()