pytoda.datasets.distributional_dataset module

Summary

Classes:

DistributionalDataset

Generates samples from a specified distribution.

StochasticItems

Sample an item from a distribution on the fly on indexing.

Reference

class StochasticItems(distribution, shape, device=None)

Bases: object

Sample an item from a distribution on the fly on indexing.

Parameters
  • distribution (torch.distributions.distribution.Distribution) – An instance of a torch distribution class to sample from. For example, with loc = torch.tensor(0.0) and scale = torch.tensor(1.0), torch.distributions.normal.Normal(loc, scale) is a valid argument; calling .sample() on it returns an item drawn from this distribution.

  • shape (torch.Size) – The desired shape of each item.

  • device (torch.device) – DEPRECATED
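The on-the-fly sampling described above can be sketched as follows. This is a minimal, hypothetical class (StochasticItemsSketch), not pytoda's implementation; it assumes only the public torch.distributions API, where Distribution.sample(sample_shape) returns a tensor of shape sample_shape + batch_shape + event_shape. With scalar loc and scale, the batch and event shapes are empty, so each item has exactly the requested shape.

```python
import torch
from torch.distributions import Distribution, Normal


class StochasticItemsSketch:
    """Hypothetical stand-in for StochasticItems: draws a fresh sample per index."""

    def __init__(self, distribution: Distribution, shape: torch.Size) -> None:
        self.distribution = distribution
        self.shape = torch.Size(shape)

    def __getitem__(self, index: int) -> torch.Tensor:
        # The index is ignored: every access draws a new sample of the
        # requested shape (plus the distribution's batch_shape, if any).
        return self.distribution.sample(self.shape)


items = StochasticItemsSketch(
    Normal(torch.tensor(0.0), torch.tensor(1.0)), torch.Size([3, 5])
)
first = items[0]
second = items[0]  # same index, but a new random draw
print(first.shape)  # torch.Size([3, 5])
```

Because nothing is stored, indexing the same position twice returns different values; this keeps memory use constant regardless of how many items are requested.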

class DistributionalDataset(dataset_size, item_shape, distribution_function, seed=None, device=None)

Bases: Generic[torch.utils.data.dataset.T_co]

Generates samples from a specified distribution.

__init__(dataset_size, item_shape, distribution_function, seed=None, device=None)

Dataset of synthetic samples, each of a given shape, drawn from a specified distribution.

Parameters
  • dataset_size (int) – Number of items to generate (N).

  • item_shape (Tuple[int]) – The shape of each item tensor returned when indexing the dataset. For example, for 2D items representing a time series of 3 timesteps with 5 features: (3, 5).

  • distribution_function (torch.distributions.distribution.Distribution) – An instance of a distribution class from which individual data items can be sampled by calling its .sample() method. This can either be an object initialised directly from torch.distributions, such as torch.distributions.normal.Normal(loc=0.0, scale=1.0), or one created from a factory using a keyword. For example, DISTRIBUTION_FUNCTION_FACTORY['normal'](loc=0.0, scale=1.0) is a valid argument, since the factory (found in utils.factories.py) initialises the distribution class based on a string keyword and passes the relevant arguments to it.

  • seed (Optional[int]) – If passed, all items are generated once, up front, with this seed (using a local RNG only). Defaults to None, in which case individual items are generated on the fly each time the DistributionalDataset is indexed (using the global RNG).

  • device (torch.device) – DEPRECATED
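The seeded and unseeded modes described above can be sketched as follows. This is a minimal, hypothetical re-implementation (DistributionalDatasetSketch), not pytoda's code. FACTORY_SKETCH is a hypothetical stand-in for DISTRIBUTION_FUNCTION_FACTORY, and realising the "local RNG" with torch.random.fork_rng (which saves and restores the global CPU RNG state around the seeded block) is an assumption about one way to obtain that behaviour.

```python
from typing import Optional, Tuple

import torch
from torch.distributions import Distribution, Normal, Uniform
from torch.utils.data import Dataset

# Hypothetical keyword factory, standing in for DISTRIBUTION_FUNCTION_FACTORY.
FACTORY_SKETCH = {"normal": Normal, "uniform": Uniform}


class DistributionalDatasetSketch(Dataset):
    """Hypothetical re-implementation illustrating the seed semantics."""

    def __init__(
        self,
        dataset_size: int,
        item_shape: Tuple[int, ...],
        distribution_function: Distribution,
        seed: Optional[int] = None,
    ) -> None:
        self.dataset_size = dataset_size
        self.item_shape = torch.Size(item_shape)
        self.distribution = distribution_function
        self.data: Optional[torch.Tensor] = None
        if seed is not None:
            # Fork the CPU RNG so seeding here leaves the global state untouched.
            with torch.random.fork_rng(devices=[]):
                torch.manual_seed(seed)
                # Generate all N items at once: shape (N, *item_shape).
                self.data = self.distribution.sample(
                    torch.Size([dataset_size]) + self.item_shape
                )

    def __len__(self) -> int:
        return self.dataset_size

    def __getitem__(self, index: int) -> torch.Tensor:
        if self.data is not None:
            return self.data[index]  # fixed, pre-generated item
        # No seed: draw a fresh item from the global RNG on every access.
        return self.distribution.sample(self.item_shape)


normal = FACTORY_SKETCH["normal"](loc=0.0, scale=1.0)
fixed_a = DistributionalDatasetSketch(10, (3, 5), normal, seed=42)
fixed_b = DistributionalDatasetSketch(10, (3, 5), normal, seed=42)
print(torch.equal(fixed_a[0], fixed_b[0]))  # True: same seed, same items
```

The seeded mode trades memory (the whole N × item_shape tensor is held) for reproducibility, while the unseeded mode stores nothing but returns different values when the same position is indexed twice.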