# Source code for pytoda.datasets.tests.test_set_matching_dataset
"""Testing SetMatchingDataset."""
import unittest
from typing import List
import torch
from scipy.optimize import linear_sum_assignment
from torch.utils.data import DataLoader
from pytoda.datasets import (
DistributionalDataset,
PairedSetMatchingDataset,
PermutedSetMatchingDataset,
)
from pytoda.datasets.utils.factories import (
DISTRIBUTION_FUNCTION_FACTORY,
METRIC_FUNCTION_FACTORY,
)
# Seeds for the permutation/cropping RNG of the matching datasets
# (None = unseeded, i.e. non-deterministic).
seeds = [None, 42]
# Seeds for the underlying data-sampling distributions (None = unseeded).
distribution_seeds = [None, 42]
# Minimum set lengths to exercise; sets are presumably cropped to a random
# length between the minimum and `max_set_length` — confirm in the dataset.
min_set_length = [5, 2]
# Padding value used to fill cropped sets up to `max_set_length`.
set_padding_value = 6.0
dataset_size = 250
max_set_length = 10
# Each dataset item: up to `max_set_length` elements of dimension 4.
item_shape = (max_set_length, 4)
# Cost metric used for the matching targets (p-norm with p=2, i.e. Euclidean).
cost_metric = 'p-norm'
cost_metric_args = {'p': 2}
# Parallel lists: each distribution type with its matching kwargs.
distribution_type = ['normal', 'uniform']
distribution_args = [{'loc': 0.0, 'scale': 1.0}, {'low': 0, 'high': 1}]
noise_std = [0.001, 0.1] # Pytorch >1.7 errors with a noise of 0.0.
cost_metric_function = METRIC_FUNCTION_FACTORY[cost_metric](**cost_metric_args)
# Whether set2 is a permutation of set1 (True) or an independent pair (False).
permute = [True, False]
# Maps the `permute` flag to the dataset class under test.
DATASET_FACTORY = {True: PermutedSetMatchingDataset, False: PairedSetMatchingDataset}
class TestSetMatchingDataset(unittest.TestCase):
    """Test SetMatchingDataset classes (permuted and paired variants).

    Each sample is a 5-tuple: (set1, set2, idx12, idx21, length), where
    idx12/idx21 are the matching index tensors and length is the (possibly
    cropped) set length.
    """

    def test_permuted_set_matching_dataset(self) -> None:
        """Test PermutedSetMatchingDataset class."""

        def tolist(x: torch.Tensor) -> List:
            # Flatten to a plain list for element-wise almost-equal checks.
            return x.flatten().tolist()

        for dist_type, dist_args in zip(distribution_type, distribution_args):
            distribution_function = DISTRIBUTION_FUNCTION_FACTORY[dist_type](
                **dist_args
            )
            for dist_seed in distribution_seeds:
                for seed in seeds:
                    for noise in noise_std:
                        for min_len in min_set_length:
                            s1 = DistributionalDataset(
                                dataset_size,
                                item_shape,
                                distribution_function,
                                seed=dist_seed,
                            )
                            datasets = [s1]
                            permuted_dataset = PermutedSetMatchingDataset(
                                *datasets,
                                min_len,
                                cost_metric_function,
                                set_padding_value=set_padding_value,
                                noise_std=noise,
                                seed=seed,
                            )
                            # Test length
                            self.assertEqual(
                                len(permuted_dataset), dataset_size
                            )
                            # Test __getitem__
                            sample1 = permuted_dataset[0]
                            sample2 = permuted_dataset[0]
                            self.assertIsInstance(sample1, tuple)
                            self.assertEqual(len(sample1), 5)
                            if dist_seed is not None and seed is not None:
                                # since both distribution seed and permutation
                                # seed are fixed, sampling twice with same
                                # index should return identical samples.
                                for item1, item2 in zip(sample1, sample2):
                                    self.assertTrue(torch.equal(item1, item2))
                                self.assertTrue(
                                    torch.equal(
                                        sample1[1][sample1[2].long(), :],
                                        sample2[1][sample2[2].long(), :],
                                    )
                                )
                                # When noise ~0, permutation of set2 should
                                # return set1 and vice versa.
                                if noise < 0.01:
                                    for a, b in zip(
                                        tolist(
                                            sample1[1][sample1[2].long(), :]
                                        ),
                                        tolist(sample1[0]),
                                    ):
                                        self.assertAlmostEqual(a, b, places=2)
                                    for a, b in zip(
                                        tolist(
                                            sample1[0][sample1[3].long(), :]
                                        ),
                                        tolist(sample1[1]),
                                    ):
                                        self.assertAlmostEqual(a, b, places=2)
                                else:
                                    self.assertFalse(
                                        torch.equal(
                                            sample1[1][sample1[2].long(), :],
                                            sample1[0],
                                        )
                                    )
                                    self.assertFalse(
                                        torch.equal(
                                            sample1[0][sample1[3].long(), :],
                                            sample1[1],
                                        )
                                    )
                            elif dist_seed is not None:
                                # Since only distribution seed is fixed, for
                                # fixed length settings the reference set
                                # returned at index 0 must be identical.
                                # NOTE: since items are padded when lengths
                                # vary, a lower limit on max set length is
                                # required to test that permutations are not
                                # equal when permutation seed is None.
                                if min_len == max_set_length:
                                    self.assertTrue(
                                        torch.equal(sample1[0], sample2[0])
                                    )
                                    if max_set_length > 3:
                                        self.assertFalse(
                                            torch.equal(
                                                sample1[2], sample2[2]
                                            ),
                                            msg=f'{sample1},{sample2}',
                                        )
                                    if noise < 0.01:
                                        for a, b in zip(
                                            tolist(
                                                sample1[1][
                                                    sample1[2].long(), :
                                                ]
                                            ),
                                            tolist(
                                                sample2[1][
                                                    sample2[2].long(), :
                                                ]
                                            ),
                                        ):
                                            self.assertAlmostEqual(
                                                a, b, places=2
                                            )
                                elif sample1[-1] != sample2[-1]:
                                    # reason for asserting false is that
                                    # length cropping is a random event
                                    # dependent on permutation seed which is
                                    # None in this setting
                                    self.assertFalse(
                                        torch.equal(sample1[0], sample2[0]),
                                        msg=f'{sample1},{sample2}',
                                    )
                                    self.assertFalse(
                                        torch.equal(sample1[1], sample2[1]),
                                    )
                                # Within a single sample the permutation must
                                # still round-trip (noise permitting).
                                if noise < 0.01:
                                    for a, b in zip(
                                        tolist(
                                            sample1[1][sample1[2].long(), :]
                                        ),
                                        tolist(sample1[0]),
                                    ):
                                        self.assertAlmostEqual(a, b, places=2)
                                    for a, b in zip(
                                        tolist(
                                            sample1[0][sample1[3].long(), :]
                                        ),
                                        tolist(sample1[1]),
                                    ):
                                        self.assertAlmostEqual(a, b, places=2)
                                else:
                                    self.assertFalse(
                                        torch.equal(
                                            sample1[1][sample1[2].long(), :],
                                            sample1[0],
                                        )
                                    )
                                    self.assertFalse(
                                        torch.equal(
                                            sample1[0][sample1[3].long(), :],
                                            sample1[1],
                                        )
                                    )
                            elif seed is not None:
                                # since distribution seed is None, the sampled
                                # sets must be different but the lengths and
                                # permutations should be the same since the
                                # permutation seed is set.
                                self.assertFalse(
                                    torch.equal(sample1[0], sample2[0])
                                )
                                self.assertFalse(
                                    torch.equal(sample1[1], sample2[1])
                                )
                                self.assertTrue(
                                    torch.equal(sample1[-1], sample2[-1])
                                )
                                if noise < 0.01:
                                    for a, b in zip(
                                        tolist(sample1[2]),
                                        tolist(sample2[2]),
                                    ):
                                        self.assertAlmostEqual(a, b, places=2)
                                    for a, b in zip(
                                        tolist(
                                            sample1[1][sample1[2].long(), :]
                                        ),
                                        tolist(sample1[0]),
                                    ):
                                        self.assertAlmostEqual(a, b, places=2)
                                    for a, b in zip(
                                        tolist(
                                            sample1[0][sample1[3].long(), :]
                                        ),
                                        tolist(sample1[1]),
                                    ):
                                        self.assertAlmostEqual(a, b, places=2)
                                else:
                                    self.assertFalse(
                                        torch.equal(
                                            sample1[1][sample1[2].long(), :],
                                            sample1[0],
                                        )
                                    )
                                    self.assertFalse(
                                        torch.equal(
                                            sample1[0][sample1[3].long(), :],
                                            sample1[1],
                                        )
                                    )
                            else:
                                # since both seeds are none, the sets and
                                # cropped lengths must be different.
                                # Difference in permutations are only checked
                                # if length>3 due to item-wise padding.
                                self.assertFalse(
                                    torch.equal(sample1[0], sample2[0])
                                )
                                self.assertFalse(
                                    torch.equal(sample1[1], sample2[1])
                                )

    def test_paired_set_matching_dataset(self) -> None:
        """Test PairedSetMatchingDataset class."""
        # Similar reasoning with respect to seeds follows from above.
        # Main difference is that the hungarian assignments, i.e, the targets
        # returned are not tested since these assignments are not dependent on
        # the permutation seed but only on the pair of sets generated.
        noise = 0.0
        for dist_type, dist_args in zip(distribution_type, distribution_args):
            distribution_function = DISTRIBUTION_FUNCTION_FACTORY[dist_type](
                **dist_args
            )
            for dist_seed in distribution_seeds:
                for seed in seeds:
                    for min_len in min_set_length:
                        # Use distinct (but derived) seeds for the two sets
                        # so a seeded run still yields a non-trivial pair.
                        if dist_seed is not None:
                            seed_s1 = dist_seed
                            seed_s2 = dist_seed + 1
                        else:
                            seed_s1 = seed_s2 = dist_seed
                        s1 = DistributionalDataset(
                            dataset_size,
                            item_shape,
                            distribution_function,
                            seed=seed_s1,
                        )
                        datasets = [s1]
                        s2 = DistributionalDataset(
                            dataset_size,
                            item_shape,
                            distribution_function,
                            seed=seed_s2,
                        )
                        datasets.append(s2)
                        paired_dataset = PairedSetMatchingDataset(
                            *datasets,
                            min_len,
                            cost_metric_function,
                            set_padding_value=set_padding_value,
                            noise_std=noise,
                            seed=seed,
                        )
                        # Test length
                        self.assertEqual(len(paired_dataset), dataset_size)
                        # Test __getitem__
                        sample1 = paired_dataset[0]
                        sample2 = paired_dataset[0]
                        # Recompute the hungarian assignments independently
                        # and check the returned targets against them.
                        sample1_hungarian12 = linear_sum_assignment(
                            torch.cdist(sample1[0], sample1[1]).numpy()
                        )[1]
                        sample1_hungarian21 = linear_sum_assignment(
                            torch.cdist(sample1[1], sample1[0]).numpy()
                        )[1]
                        self.assertTrue(
                            torch.equal(
                                sample1[2].int(),
                                torch.from_numpy(sample1_hungarian12).int(),
                            )
                        )
                        self.assertTrue(
                            torch.equal(
                                sample1[3].int(),
                                torch.from_numpy(sample1_hungarian21).int(),
                            )
                        )
                        self.assertIsInstance(sample1, tuple)
                        self.assertEqual(len(sample1), 5)
                        self.assertFalse(torch.equal(sample1[0], sample1[1]))
                        if dist_seed is not None and seed is not None:
                            for item1, item2 in zip(sample1, sample2):
                                self.assertTrue(torch.equal(item1, item2))
                        elif dist_seed is not None:
                            if min_len == max_set_length:
                                for item1, item2 in zip(sample1, sample2):
                                    self.assertTrue(
                                        torch.equal(item1, item2)
                                    )
                            elif sample1[-1] != sample2[-1]:
                                for item1, item2 in zip(
                                    sample1[:2], sample2[:2]
                                ):
                                    self.assertFalse(
                                        torch.equal(item1, item2)
                                    )
                        elif seed is not None:
                            self.assertFalse(
                                torch.equal(sample1[0], sample2[0])
                            )
                            self.assertFalse(
                                torch.equal(sample1[1], sample2[1])
                            )
                            self.assertTrue(
                                torch.equal(sample1[-1], sample2[-1])
                            )
                        else:
                            for item1, item2 in zip(sample1[:2], sample2[:2]):
                                self.assertFalse(torch.equal(item1, item2))

    def test_data_loader(self) -> None:
        """Test data_loader or SetMatchingDataset."""
        for dist_seed in distribution_seeds:
            for dist_type, dist_args in zip(
                distribution_type, distribution_args
            ):
                distribution_function = DISTRIBUTION_FUNCTION_FACTORY[
                    dist_type
                ](**dist_args)
                for noise in noise_std:
                    for seed in seeds:
                        for permute_ in permute:
                            for min_len in min_set_length:
                                if dist_seed is None:
                                    seed_s1 = seed_s2 = dist_seed
                                else:
                                    seed_s1 = dist_seed
                                    seed_s2 = dist_seed + 1
                                s1 = DistributionalDataset(
                                    dataset_size,
                                    item_shape,
                                    distribution_function,
                                    seed=seed_s1,
                                )
                                datasets = [s1]
                                # Paired datasets need a second source set;
                                # permuted datasets derive set2 from set1.
                                if not permute_:
                                    s2 = DistributionalDataset(
                                        dataset_size,
                                        item_shape,
                                        distribution_function,
                                        seed=seed_s2,
                                    )
                                    datasets.append(s2)
                                setmatch_dataset = DATASET_FACTORY[permute_](
                                    *datasets,
                                    min_len,
                                    cost_metric_function,
                                    set_padding_value=set_padding_value,
                                    noise_std=noise,
                                    seed=seed,
                                )
                                data_loader = DataLoader(
                                    setmatch_dataset,
                                    batch_size=25,
                                )
                                for batch_index, batch in enumerate(
                                    data_loader
                                ):
                                    (
                                        set1_batch,
                                        set2_batch,
                                        idx12_batch,
                                        idx21_batch,
                                        len_batch,
                                    ) = batch
                                    self.assertEqual(
                                        set1_batch.shape,
                                        (
                                            25,
                                            max_set_length,
                                            4,
                                        ),
                                    )
                                    self.assertTrue(
                                        torch.unique(
                                            set1_batch, dim=0
                                        ).size(0)
                                        == 25
                                    )
                                    self.assertTrue(
                                        torch.unique(
                                            set2_batch, dim=0
                                        ).size(0)
                                        == 25
                                    )
                                    self.assertEqual(
                                        set1_batch.shape, set2_batch.shape
                                    )
                                    self.assertEqual(
                                        idx12_batch.shape, idx21_batch.shape
                                    )
                                    self.assertEqual(len(len_batch), 25)
                                    self.assertFalse(
                                        torch.equal(set1_batch, set2_batch)
                                    )
                                    # BUG FIX: the original compared the
                                    # module-level *list* `noise_std` to 0.0
                                    # (always False); the loop scalar `noise`
                                    # was clearly intended. NOTE: with the
                                    # current noise values (0.001, 0.1) this
                                    # branch still never runs — see the
                                    # comment on `noise_std` about torch
                                    # rejecting a noise of 0.0.
                                    if permute_ and noise == 0.0:
                                        # Gather each batch row by its index
                                        # tensor to undo the permutation.
                                        ordered_set1 = set1_batch[
                                            torch.arange(
                                                0, idx21_batch.size(0)
                                            )
                                            .unsqueeze(1)
                                            .repeat((1, idx21_batch.size(1))),
                                            idx21_batch,
                                        ]
                                        ordered_set2 = set2_batch[
                                            torch.arange(
                                                0, idx12_batch.size(0)
                                            )
                                            .unsqueeze(1)
                                            .repeat((1, idx12_batch.size(1))),
                                            idx12_batch,
                                        ]
                                        self.assertTrue(
                                            torch.equal(
                                                set1_batch, ordered_set2
                                            )
                                        )
                                        self.assertTrue(
                                            torch.equal(
                                                ordered_set1, set2_batch
                                            )
                                        )
# Run the test suite when this module is executed directly.
if __name__ == '__main__':
    unittest.main()