pytoda.datasets.distributional_dataset module
Summary
Classes:
DistributionalDataset | Generates samples from a specified distribution.
StochasticItems | Sample an item from a distribution on the fly on indexing.
Reference
class StochasticItems(distribution, shape, device=None)
Bases: object
Sample an item from a distribution on the fly on indexing.
Parameters
distribution (torch.distributions.distribution.Distribution) – An instance of a torch distribution class to sample from. For example, torch.distributions.normal.Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0)), on which calling .sample() returns a single draw from that distribution.
shape (torch.Size) – The desired shape of each item.
device (torch.device) – DEPRECATED
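The lazy, sample-on-indexing behaviour described above can be sketched as follows. This is a minimal illustration assuming PyTorch is installed, not pytoda's implementation: the class name LazySampledItems is hypothetical, and it ignores the index entirely, drawing a fresh tensor of the requested shape via Distribution.sample(shape) on every access.

```python
import torch
from torch.distributions.normal import Normal


class LazySampledItems:
    """Hypothetical stand-in for StochasticItems (not pytoda's code).

    Indexing ignores the index and draws a fresh sample of `shape`
    from `distribution` using the global torch RNG.
    """

    def __init__(self, distribution: torch.distributions.Distribution, shape):
        self.distribution = distribution
        self.shape = torch.Size(shape)

    def __getitem__(self, index: int) -> torch.Tensor:
        # A new draw on every access: the same index can return different values.
        return self.distribution.sample(self.shape)


# Scalar-parameter normal, as in the parameter description above.
normal = Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0))
items = LazySampledItems(normal, shape=(3, 5))
first = items[0]
second = items[0]
print(first.shape)  # torch.Size([3, 5])
```

Because nothing is cached, reading the same index twice yields two independent draws; that is the trade-off against precomputing the items.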
class DistributionalDataset(dataset_size, item_shape, distribution_function, seed=None, device=None)
Bases: Generic[torch.utils.data.dataset.T_co]
Generates samples from a specified distribution.
__init__(dataset_size, item_shape, distribution_function, seed=None, device=None)
Dataset of synthetic samples drawn from a specified distribution, each with the given shape.
Parameters
dataset_size (int) – Number of items to generate (N).
item_shape (Tuple[int]) – The shape of each item tensor returned when indexing the dataset. For example, for 2D items holding a time series of 3 timesteps with 5 features: (3, 5).
distribution_function (torch.distributions.distribution.Distribution) – An instance of a distribution class from which individual data items can be sampled by calling its .sample() method. It can either be created directly from torch.distributions, e.g. torch.distributions.normal.Normal(loc=0.0, scale=1.0), or obtained from a factory by keyword, e.g. DISTRIBUTION_FUNCTION_FACTORY['normal'](loc=0.0, scale=1.0). The latter is a valid argument because the factory (found in utils.factories.py) selects the distribution class by its string keyword and passes the relevant arguments to it.
seed (Optional[int]) – If passed, all items are generated once with this seed (using a local RNG only). Defaults to None, in which case individual items are generated on the fly each time the DistributionalDataset is indexed (using the global RNG).
device (torch.device) – DEPRECATED
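The two modes selected by seed can be sketched as below, assuming PyTorch is installed. SketchDistributionalDataset and HYPOTHETICAL_FACTORY are illustrative names, not pytoda's DistributionalDataset or DISTRIBUTION_FUNCTION_FACTORY, and using torch.random.fork_rng to keep the seeding local is an assumption about how a "local RNG" might be realised. With a seed, all N items are drawn once and stored; without one, each lookup draws a new item from the global RNG.

```python
import torch
from torch.distributions.normal import Normal
from torch.utils.data import Dataset

# Hypothetical keyword factory, loosely modelled on DISTRIBUTION_FUNCTION_FACTORY.
HYPOTHETICAL_FACTORY = {"normal": Normal}


class SketchDistributionalDataset(Dataset):
    """Illustrative sketch of the documented behaviour (not pytoda's code)."""

    def __init__(self, dataset_size, item_shape, distribution_function, seed=None):
        self.dataset_size = dataset_size
        self.item_shape = torch.Size(item_shape)
        self.distribution_function = distribution_function
        if seed is not None:
            # Eager mode: fork the CPU RNG so the global CPU RNG state is
            # restored afterwards, then draw all N items as (N, *item_shape).
            with torch.random.fork_rng(devices=[]):
                torch.manual_seed(seed)
                self.data = distribution_function.sample(
                    (dataset_size, *self.item_shape)
                )
        else:
            # Lazy mode: store nothing; items are drawn on indexing.
            self.data = None

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, index):
        if not -self.dataset_size <= index < self.dataset_size:
            raise IndexError(f"index {index} out of range")
        if self.data is not None:
            return self.data[index]
        # Fresh draw from the global RNG on every access.
        return self.distribution_function.sample(self.item_shape)


# Build the distribution through the keyword factory, then both dataset modes.
dist = HYPOTHETICAL_FACTORY["normal"](loc=0.0, scale=1.0)
seeded_a = SketchDistributionalDataset(4, (3, 5), dist, seed=42)
seeded_b = SketchDistributionalDataset(4, (3, 5), dist, seed=42)
lazy = SketchDistributionalDataset(4, (3, 5), dist)

print(len(seeded_a), seeded_a[0].shape)  # 4 torch.Size([3, 5])
print(torch.equal(seeded_a[2], seeded_b[2]))  # True: same seed, same items
print(torch.equal(lazy[0], lazy[0]))  # almost surely False: redrawn each time
```

The seeded mode trades memory (all N items are held in one tensor) for items that stay identical across epochs and runs.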