Source code for pytoda.tests.test_transforms
"""Testing transforms."""
import random
import unittest
import torch
from pytoda.transforms import (
AugmentByReversing,
Compose,
LeftPadding,
ListToTensor,
ToTensor,
)
[docs]class TestTransforms(unittest.TestCase):
"""Testing transforms."""
[docs] def test_left_padding(self) -> None:
"""Test LeftPadding."""
padding_index = 0
padding_lengths = [8, 4]
# Molecules that are too long will be cut and a warning will be raised.
for padding_length in padding_lengths:
transform = LeftPadding(
padding_index=padding_index, padding_length=padding_length
)
for mol in ['C(N)CS', 'CCO']:
self.assertEqual(len(transform(list(mol))), padding_length)
[docs] def test_augment_by_reversing(self) -> None:
"""Test AugmentByReversing."""
sequence = 'ABC'
ground_truths = ['ABC', 'ABC', 'CBA']
for k in range(15):
for p, ground_truth in zip([0.0, 0.5, 1.0], ground_truths):
random.seed(42)
transform = AugmentByReversing(p=p)
self.assertEqual(transform(sequence), ground_truth)
[docs] def test_to_tensor(self) -> None:
"""Test ToTensor."""
tokens = [2, 3, 4]
transform = ToTensor()
tensor = transform(tokens)
self.assertListEqual(
[tokens[0], tokens[1], tokens[2]], [tensor[0], tensor[1], tensor[2]]
)
self.assertTrue(torch.is_tensor(tensor))
self.assertEqual(len(tensor), 3)
self.assertEqual(len(tensor.shape), 1)
self.assertRaises(TypeError, ToTensor, dtype=42)
[docs] def test_list_to_tensor(self) -> None:
"""Test ListToTensor."""
tokens = [(2, 3, 4), (2, 3, 4)]
transform = ListToTensor()
tensor = transform(tokens)
self.assertEqual(len(tensor), 2)
self.assertEqual(len(tensor.shape), 2)
self.assertEqual(tensor.shape[-1], 3)
self.assertListEqual(
[tokens[0][0], tokens[0][1], tokens[0][2]],
[tensor[0][0], tensor[0][1], tensor[0][2]],
)
self.assertTrue(torch.is_tensor(tensor))
self.assertRaises(TypeError, ListToTensor, dtype=42)
[docs] def test_compose(self) -> None:
"""Test Compose."""
# Test equality
c1 = Compose([ToTensor()])
c2 = Compose([ToTensor()])
c3 = Compose([LeftPadding(padding_length=2, padding_index=0), ToTensor()])
c4 = Compose([ToTensor(dtype=torch.long)])
self.assertTrue(c1 == c2)
self.assertFalse(c1 == c3)
self.assertFalse(c1 == c4)
# Test repr
self.assertIsNotNone(repr(c1))
if __name__ == '__main__':
unittest.main()