pytoda.datasets.utils.wrappers module

Summary

Classes:

WrapperCDist

Module wrapper around torch.cdist for easy argument passing.

WrapperKLDiv

Wrapper for KL-Divergence for easy argument passing.

Reference

class WrapperCDist(p=2)

Bases: torch.nn.modules.module.Module

Module wrapper around torch.cdist for easy argument passing.

__init__(p=2)

Constructor.

Parameters

p (int, optional) – p value for the p-norm distance to calculate between each vector pair. Defaults to 2.

forward(set1, set2)

Computes the pairwise p-norm distances between the vectors of set1 and set2.

Parameters
  • set1 (Tensor) – Input tensor of shape [batch_size, length1, dim]

  • set2 (Tensor) – Input tensor of shape [batch_size, length2, dim]

Returns

Tensor of shape [batch_size, length1, length2] representing the pairwise distances.

Return type

Tensor
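The shape contract above can be illustrated with a minimal sketch. The class `CDistSketch` below is a hypothetical stand-in written for this example, not the pytoda source; it assumes the wrapper simply stores `p` and forwards both inputs to `torch.cdist`.

```python
import torch
from torch import nn


class CDistSketch(nn.Module):
    """Hypothetical stand-in for WrapperCDist (assumed behaviour, not the pytoda code)."""

    def __init__(self, p=2):
        super().__init__()
        self.p = p  # order of the p-norm used for the distance

    def forward(self, set1, set2):
        # torch.cdist computes the distance between every pair of row vectors
        # taken from set1 and set2, independently for each batch entry.
        return torch.cdist(set1, set2, p=self.p)


# Two batches of vectors with a shared feature dimension of 4.
set1 = torch.zeros(2, 3, 4)  # [batch_size=2, length1=3, dim=4]
set2 = torch.ones(2, 5, 4)   # [batch_size=2, length2=5, dim=4]

distances = CDistSketch(p=2)(set1, set2)
print(distances.shape)  # torch.Size([2, 3, 5])
```

Because the module keeps `p` as state, it can be passed around as a callable wherever a two-argument distance function is expected, without the caller needing to know which norm is in use.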

training: bool

class WrapperKLDiv(reduction='mean')

Bases: torch.nn.modules.module.Module

Wrapper for KL-Divergence for easy argument passing.

__init__(reduction='mean')

Constructor.

Parameters

reduction (str, optional) – One of 'none', 'batchmean', 'sum', 'mean'. Defaults to 'mean'.

forward(set1, set2)

Computes the KL-Divergence.

Parameters
  • set1 (Tensor) – Input tensor of arbitrary shape.

  • set2 (Tensor) – Tensor of the same shape as set1.

Returns

Scalar by default. If reduction = 'none', a tensor of the same shape as the input.

Return type

Tensor
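The effect of the reduction argument on the return shape can be sketched with `torch.nn.functional.kl_div` directly. This is an illustration under assumptions: the page does not state which PyTorch call the wrapper delegates to, nor whether set1 is expected to hold log-probabilities. The example follows the `kl_div` convention, where the first argument is log-probabilities and the second is probabilities. The tensors `p` and `q` are made up for the example.

```python
import torch
import torch.nn.functional as F

# Two discrete distributions per row (each row sums to 1).
p = torch.tensor([[0.5, 0.5], [0.25, 0.75]])  # target probabilities
q = torch.tensor([[0.4, 0.6], [0.5, 0.5]])    # model probabilities

# Assumption: the first argument holds log-probabilities, as F.kl_div expects.
log_q = q.log()

elementwise = F.kl_div(log_q, p, reduction='none')      # same shape as the input
total = F.kl_div(log_q, p, reduction='sum')             # scalar: sum over all elements
per_sample = F.kl_div(log_q, p, reduction='batchmean')  # scalar: sum / batch size

print(elementwise.shape)  # torch.Size([2, 2])
print(total.dim())        # 0
```

In PyTorch, 'mean' averages over every element of the output, while 'batchmean' divides the summed divergence by the batch size, so the two give different values whenever a sample has more than one element.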

training: bool