pytoda.datasets.set_matching_dataset module

Summary

Classes:

PairedSetMatchingDataset

Dataset class for the case where the set to match is another random set.

PermutedSetMatchingDataset

Dataset class for the case where the set to match is permuted.

SetMatchingDataset

Base class for set matching datasets.

Functions:

get_subsampling_indexes

Return indexers to remove random elements of an item and its permutation.

hungarian_assignment

Compute targets for one training sample.

Reference

hungarian_assignment(set_reference, set_matching, cost_metric_function)

Compute targets for one training sample.

Parameters
  • set_reference (torch.Tensor) – Tensor with elements of set_reference.

  • set_matching (torch.Tensor) – Tensor with elements of set_matching.

  • cost_metric_function (nn.Module) – Function wrapped as an nn.Module that computes the metric used in constructing the cost matrix.

Returns

Tuple containing the Hungarian matching indices of set_reference vs. set_matching and of set_matching vs. set_reference.

Return type

Tuple
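A minimal sketch of what the two returned index arrays mean, assuming NumPy/SciPy in place of torch and a pairwise Euclidean distance in place of cost_metric_function. The solver is scipy.optimize.linear_sum_assignment, which solves the same linear assignment problem as the Hungarian method. The names euclidean_cost and hungarian_assignment_sketch are hypothetical and not part of pytoda.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def euclidean_cost(set_reference, set_matching):
    """Hypothetical stand-in for cost_metric_function: pairwise Euclidean distances."""
    # (n, 1, d) - (1, m, d) broadcasts to (n, m, d); reduce over features.
    diff = set_reference[:, None, :] - set_matching[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))


def hungarian_assignment_sketch(set_reference, set_matching, cost_fn=euclidean_cost):
    """Return (reference -> matching, matching -> reference) index arrays."""
    cost = cost_fn(set_reference, set_matching)  # shape (n_reference, n_matching)
    # Minimum-cost assignment of reference rows to matching columns.
    _, ref_to_match = linear_sum_assignment(cost)
    # Same problem seen from the matching set (transposed cost matrix).
    _, match_to_ref = linear_sum_assignment(cost.T)
    return ref_to_match, match_to_ref


reference = np.array([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])
matching = reference[[2, 0, 1]]  # matching[j] = reference[[2, 0, 1][j]]
ref_to_match, match_to_ref = hungarian_assignment_sketch(reference, matching)
print(ref_to_match.tolist(), match_to_ref.tolist())  # [1, 2, 0] [2, 0, 1]
```

For a square cost matrix with a unique optimum, the second array is the inverse permutation of the first.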

get_subsampling_indexes(min_set_length, max_set_length, permutation, shuffle=True)

Return indexers to remove random elements of an item and its permutation.

Parameters
  • min_set_length (int) – Minimum number of elements in the set.

  • max_set_length (int) – Maximum number of elements in the set.

  • permutation (Tensor) – Tensor of integers defining a permutation, i.e. the indices of a range in arbitrary order.

  • shuffle (bool) – If True, the first returned indexer also shuffles the elements.

Returns

A tuple of three items:
  • a Tensor of integers for indexing a subset of elements (shuffled or not);
  • a Tensor of integers for indexing the same elements in the permuted item;
  • the number of elements in the item (length).

Return type

Tuple[Tensor, Tensor, int]
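The following sketch shows one way such indexers could be built. It is an assumption about the behavior, not pytoda's implementation: it uses NumPy instead of torch, draws the length uniformly from [min_set_length, max_set_length], and treats the permuted item as item[permutation]. The function get_subsampling_indexes_sketch and its rng argument are hypothetical.

```python
import numpy as np


def get_subsampling_indexes_sketch(min_set_length, max_set_length, permutation,
                                   shuffle=True, rng=None):
    """Hypothetical sketch: indexers for a random subset of an item and its permutation."""
    rng = np.random.default_rng() if rng is None else rng
    # Item length drawn uniformly from [min_set_length, max_set_length].
    length = int(rng.integers(min_set_length, max_set_length + 1))
    # Positions kept in the reference item; choice() without replacement
    # already returns them in random order.
    kept = rng.choice(max_set_length, size=length, replace=False)
    reference_indexer = kept if shuffle else np.sort(kept)
    # The permuted item is item[permutation], so position j holds element
    # permutation[j]; keep the positions whose source element was kept.
    permuted_indexer = np.nonzero(np.isin(permutation, kept))[0]
    return reference_indexer, permuted_indexer, length


rng = np.random.default_rng(0)
item = np.arange(10.0).reshape(5, 2)  # 5 elements with 2 features each
permutation = rng.permutation(5)
permuted_item = item[permutation]
ref_idx, perm_idx, length = get_subsampling_indexes_sketch(3, 5, permutation, rng=rng)
# Both indexers select the same elements, generally in a different order.
print(sorted(item[ref_idx].tolist()) == sorted(permuted_item[perm_idx].tolist()))  # True
```

Keeping the second indexer in permuted order preserves the permutation between the two subsets, so the matching task stays nontrivial after subsampling.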

class SetMatchingDataset(min_set_length, cost_metric_function, set_padding_value=0.0, seed=None, shuffle=True, device=None)

Bases: Generic[torch.utils.data.dataset.T_co]

Base class for set matching datasets.

dataset: torch.utils.data.dataset.Dataset
__init__(min_set_length, cost_metric_function, set_padding_value=0.0, seed=None, shuffle=True, device=None)
Base class for set matching datasets that allows returning subsets of varying length, controlled by the passed min_set_length.

Parameters
  • min_set_length (int) – Lower bound on the number of elements in the returned sets. This should equal the maximum item length of the passed dataset if varying set lengths are not desired.

  • cost_metric_function (nn.Module) – Function wrapped as an nn.Module that computes the metric used in constructing the cost matrix.

  • set_padding_value (float) – The constant value with which to pad each set. Defaults to 0.0. NOTE: No padding is done if min_set_length = max_set_length.

  • seed (Optional[int]) – If passed, all samples are generated once with this seed (using a local RNG only). Defaults to None, where individual samples are generated when the DistributionalDataset is indexed (using the global RNG).

  • shuffle (Optional[bool]) – Whether the sets should be shuffled again before subsampling. Adds another layer of randomness. Defaults to True.

  • device (torch.device) – DEPRECATED

Note

  1. Requires child classes to set the dataset attribute.

  2. Batches generated by DataLoader will have batch size in the first dim.
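Padding each set to max_set_length is what lets sets of different lengths be stacked into one batch with the batch size in the first dimension (default DataLoader collation stacks equal-shaped samples along a new leading dimension). A minimal NumPy sketch, assuming padding rows are appended after the real elements; the helper pad_set is hypothetical and not part of pytoda.

```python
import numpy as np


def pad_set(subset, max_set_length, set_padding_value=0.0):
    """Hypothetical helper: pad a (length, n_features) set to (max_set_length, n_features)."""
    length, n_features = subset.shape
    padded = np.full((max_set_length, n_features), set_padding_value, dtype=subset.dtype)
    padded[:length] = subset  # real elements first, padding rows after
    return padded


# Two sets of different lengths (3 and 5 elements, 2 features each).
sets = [np.ones((3, 2)), np.ones((5, 2))]
padded = [pad_set(s, max_set_length=5, set_padding_value=-1.0) for s in sets]
batch = np.stack(padded)  # batch size in the first dimension
print(batch.shape)  # (2, 5, 2)
```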

property permutation

Class attribute that defines the permutation to use in creating set_matching.

Raises

NotImplementedError – Not implemented by the base class; child classes must override this property.

Return type

Tensor

get_matching_set(index, reference_set)

Gets the corresponding set to match to the reference set.

Parameters
  • index (int) – The index to be sampled.

  • reference_set (Tensor) – Tensor that represents samples of the reference set.

Raises

NotImplementedError – Not implemented by the base class; child classes must override this method.

Return type

Tensor
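The base class only fixes the contract; each child class supplies both hooks. The sketch below mimics that pattern with hypothetical classes (SetMatchingBaseSketch, IdentityChildSketch) and NumPy arrays in place of tensors. It does not reproduce pytoda's constructor, padding, or subsampling logic.

```python
import numpy as np


class SetMatchingBaseSketch:
    """Hypothetical stand-in for the abstract hooks of SetMatchingDataset."""

    def __init__(self, max_set_length):
        self.max_set_length = max_set_length

    @property
    def permutation(self):
        raise NotImplementedError  # child classes must override

    def get_matching_set(self, index, reference_set):
        raise NotImplementedError  # child classes must override


class IdentityChildSketch(SetMatchingBaseSketch):
    """Hypothetical child: identity permutation, matching set equals the reference."""

    @property
    def permutation(self):
        return np.arange(self.max_set_length)

    def get_matching_set(self, index, reference_set):
        return reference_set[self.permutation]


child = IdentityChildSketch(max_set_length=4)
print(child.permutation.tolist())  # [0, 1, 2, 3]
try:
    SetMatchingBaseSketch(4).permutation
except NotImplementedError:
    print("the base class leaves permutation to child classes")
```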

class PairedSetMatchingDataset(dataset, dataset_to_match, min_set_length, cost_metric_function, set_padding_value, seed=None, shuffle=True, noise_std=0.0)

Bases: Generic[torch.utils.data.dataset.T_co]

Dataset class for the case where the set to match is another random set.

__init__(dataset, dataset_to_match, min_set_length, cost_metric_function, set_padding_value, seed=None, shuffle=True, noise_std=0.0)

Pairs of sets and their Hungarian assignments, where the set to match is provided by a second dataset.

Parameters
  • dataset (Dataset) – A torch.utils.data.Dataset (or subclass) whose samples represent set_reference.

  • dataset_to_match (Dataset) – A torch.utils.data.Dataset (or subclass) whose samples represent set_matching.

  • min_set_length (int) – Lower bound on the number of elements in the returned sets. This should equal the maximum item length of the passed dataset if varying set lengths are not desired.

  • set_padding_value (float) – The constant value with which to pad each set. NOTE: No padding is done if min_set_length = max_set_length.

  • cost_metric_function (nn.Module) – Function wrapped as an nn.Module that computes the metric used in constructing the cost matrix.

  • seed (Optional[int]) – If passed, all samples are generated once with this seed (using a local RNG only). Defaults to None, where individual samples are generated when the DistributionalDataset is indexed (using the global RNG).

  • shuffle (Optional[bool]) – Whether the sets should be shuffled again before subsampling. Adds another layer of randomness. Defaults to True.

  • noise_std (float, optional) – Standard deviation of the zero-mean normal distribution used to generate noise. Defaults to 0.0. Unused by this class; it is accepted only to keep the constructor signature consistent with PermutedSetMatchingDataset.

dataset: torch.utils.data.dataset.Dataset
property permutation

Class attribute that defines the permutation to use in creating set_matching.

Returns

A fixed tensor containing the range 0, ..., max_set_length - 1, i.e. the identity permutation.

Return type

Tensor

get_matching_set(index, set_reference)

Gets the corresponding set to match to the reference set.

Parameters
  • index (int) – The index to be sampled.

  • set_reference (Tensor) – Tensor that represents samples of the reference set.

Returns

Tensor of the corresponding matching set.

Return type

Tensor
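A hedged sketch of the two hooks described above, assuming the matching set is the item at the same index of dataset_to_match and using plain lists of NumPy arrays as stand-in datasets. PairedSketch is hypothetical. Because the permutation is the identity, subsampling selects the same positions in both sets, so any nontrivial assignment comes from the content of the two sets rather than from a permutation.

```python
import numpy as np


class PairedSketch:
    """Hypothetical illustration of the PairedSetMatchingDataset hooks."""

    def __init__(self, dataset, dataset_to_match, max_set_length):
        self.dataset = dataset
        self.dataset_to_match = dataset_to_match
        self.max_set_length = max_set_length

    @property
    def permutation(self):
        # Fixed identity permutation: 0, 1, ..., max_set_length - 1.
        return np.arange(self.max_set_length)

    def get_matching_set(self, index, set_reference):
        # Assumed behavior: the item at the same index of the second dataset;
        # set_reference is not used to build it.
        return self.dataset_to_match[index]


rng = np.random.default_rng(0)
reference_sets = [rng.normal(size=(4, 2)) for _ in range(3)]
sets_to_match = [rng.normal(size=(4, 2)) for _ in range(3)]
paired = PairedSketch(reference_sets, sets_to_match, max_set_length=4)
matching = paired.get_matching_set(1, reference_sets[1])
print(paired.permutation.tolist(), matching.shape)  # [0, 1, 2, 3] (4, 2)
```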

class PermutedSetMatchingDataset(dataset, min_set_length, cost_metric_function, set_padding_value, seed=None, noise_std=0.0, shuffle=True)

Bases: Generic[torch.utils.data.dataset.T_co]

Dataset class for the case where the set to match is permuted.

__init__(dataset, min_set_length, cost_metric_function, set_padding_value, seed=None, noise_std=0.0, shuffle=True)

Pairs of sets and their Hungarian assignments, where the set to match is a permutation of the given sets.

Parameters
  • dataset (Dataset) – A torch.utils.data.Dataset (or subclass) whose samples represent set_reference.

  • min_set_length (int) – Lower bound on the number of elements in the returned sets. This should equal the maximum item length of the passed dataset if varying set lengths are not desired.

  • set_padding_value (float) – The constant value with which to pad each set. NOTE: No padding is done if min_set_length = max_set_length.

  • cost_metric_function (nn.Module) – Function wrapped as an nn.Module that computes the metric used in constructing the cost matrix.

  • seed (Optional[int]) – If passed, all samples are generated once with this seed (using a local RNG only). Defaults to None, where individual samples are generated when the DistributionalDataset is indexed (using the global RNG).

  • noise_std (float, optional) – Standard deviation of the zero-mean normal distribution used to generate additive noise. Defaults to 0.0.

  • shuffle (Optional[bool]) – Whether the sets should be shuffled again before subsampling. Adds another layer of randomness. Defaults to True.

dataset: torch.utils.data.dataset.Dataset
property permutation

Class attribute that defines the permutation to use in creating set_matching.

Returns

Tensor of a randomly generated permutation of length max_set_length.

Return type

Tensor

get_matching_set(index, set_reference)

Gets the corresponding set to match to the reference set.

Parameters
  • index (int) – The index to be sampled.

  • set_reference (Tensor) – Tensor that represents samples of the reference set.

Returns

Tensor of the permuted reference set with additive noise.

Return type

Tensor
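To see why the Hungarian assignment is a sensible target here, the sketch below builds a permuted set as described (random permutation, then additive Gaussian noise with mean 0 and standard deviation noise_std) and solves the assignment. It assumes NumPy/SciPy instead of torch and a Euclidean cost in place of cost_metric_function; it is an illustration, not pytoda's code. With noise_std = 0.0 and distinct elements, the assignment recovers the inverse of the permutation.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

rng = np.random.default_rng(42)
max_set_length, n_features, noise_std = 6, 3, 0.0

set_reference = rng.normal(size=(max_set_length, n_features))
# Random permutation, playing the role of the permutation property.
permutation = rng.permutation(max_set_length)
# Matching set: permuted reference plus zero-mean Gaussian noise.
noise = noise_std * rng.normal(size=(max_set_length, n_features))
set_matching = set_reference[permutation] + noise

# Assumed metric: pairwise Euclidean distances, shape (n_reference, n_matching).
cost = np.linalg.norm(set_reference[:, None, :] - set_matching[None, :, :], axis=-1)
_, ref_to_match = linear_sum_assignment(cost)

# Without noise, set_matching[ref_to_match[i]] equals set_reference[i],
# so ref_to_match is the inverse permutation.
print(np.array_equal(ref_to_match, np.argsort(permutation)))  # True
```

With noise_std > 0, the recovered assignment can deviate from the inverse permutation when elements lie closer together than the noise scale.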