# Source code for pytoda.datasets.tests.test_distributional_dataset

"""Testing DistributionalDataset."""
import unittest

import torch

from pytoda.datasets import DistributionalDataset
from pytoda.datasets.utils.factories import DISTRIBUTION_FUNCTION_FACTORY

# Parameter grid shared by the tests in this module.
# `distribution_types` and `distribution_args` are parallel lists: entry i of
# one belongs to entry i of the other (the tests iterate them with `zip`).
distribution_types = ['normal', 'uniform']
distribution_args = [
    dict(loc=0.0, scale=1.0),
    dict(low=0.0, high=1.0),
]

# Dataset lengths, per-item shapes and seeds the tests loop over.
dataset_sizes = [100, 10000]
item_shapes = [
    (1, 16),
    (10, 16),
]
seeds = [None, 1, 42]


class TestDistributionalDataset(unittest.TestCase):
    """Test DistributionalDataset class.

    Every test walks the full grid of module-level parameters
    (distribution type/arguments, dataset size, item shape and seed) and
    builds a fresh ``DistributionalDataset`` for each combination.
    """

    def test__len__(self) -> None:
        """Test __len__.

        The length reported by the dataset must equal the size it was
        constructed with, for every parameter combination.
        """
        for dist_type, dist_args in zip(distribution_types, distribution_args):
            distribution_function = DISTRIBUTION_FUNCTION_FACTORY[dist_type](
                **dist_args
            )
            for size in dataset_sizes:
                for shape in item_shapes:
                    for seed in seeds:
                        # subTest labels a failure with the exact combination
                        # and lets the remaining combinations still run.
                        with self.subTest(
                            dist_type=dist_type,
                            size=size,
                            shape=shape,
                            seed=seed,
                        ):
                            dataset = DistributionalDataset(
                                size,
                                shape,
                                distribution_function,
                                seed=seed,
                            )
                            self.assertEqual(len(dataset), size)

    def test__getitem__(self) -> None:
        """Test __getitem__.

        The same index is read twice. Both reads must have the requested
        item shape. When a seed is given the two reads must be identical;
        when the seed is ``None`` they are expected to differ.
        """
        for dist_type, dist_args in zip(distribution_types, distribution_args):
            distribution_function = DISTRIBUTION_FUNCTION_FACTORY[dist_type](
                **dist_args
            )
            for size in dataset_sizes:
                for shape in item_shapes:
                    for seed in seeds:
                        with self.subTest(
                            dist_type=dist_type,
                            size=size,
                            shape=shape,
                            seed=seed,
                        ):
                            dataset = DistributionalDataset(
                                size,
                                shape,
                                distribution_function,
                                seed=seed,
                            )

                            # Index 42 is valid for every size in the grid.
                            sample1_1 = dataset[42]
                            sample1_2 = dataset[42]

                            # Test shapes
                            self.assertEqual(sample1_1.shape, sample1_2.shape)
                            self.assertEqual(sample1_1.shape, shape)

                            # Test content
                            if seed is None:
                                self.assertFalse(
                                    torch.equal(sample1_1, sample1_2)
                                )
                            else:
                                self.assertTrue(
                                    torch.equal(sample1_1, sample1_2)
                                )
if __name__ == '__main__':
    # Run this module's test cases when it is executed directly as a script.
    unittest.main()