pytoda.datasets.utils.wrappers module¶
Summary¶
Classes:

WrapperCDist – Wrapper for torch.cdist module for easy argument passing.

WrapperKLDiv – Wrapper for KL-Divergence for easy argument passing.
Reference¶
class WrapperCDist(p=2)[source]¶

Bases: torch.nn.modules.module.Module

Wrapper for torch.cdist module for easy argument passing.
__init__(p=2)[source]¶

Constructor.

Parameters
    p (int, optional) – p value for the p-norm distance to calculate between each vector pair. Defaults to 2.
forward(set1, set2)[source]¶

Computes the pairwise p-norms.

Parameters
    set1 (Tensor) – Input tensor of shape [batch_size, length1, dim].
    set2 (Tensor) – Input tensor of shape [batch_size, length2, dim].

Returns
    Tensor of shape [batch_size, length1, length2] representing the pairwise distances.

Return type
    Tensor
training: bool¶
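A minimal usage sketch, assuming forward simply delegates to torch.cdist with the configured p value (shapes follow the signature documented above):

    import torch

    from pytoda.datasets.utils.wrappers import WrapperCDist

    # Pairwise Euclidean (p=2) distances between two batched point sets.
    cdist = WrapperCDist(p=2)

    set1 = torch.randn(4, 10, 16)  # [batch_size, length1, dim]
    set2 = torch.randn(4, 7, 16)   # [batch_size, length2, dim]

    distances = cdist(set1, set2)  # [batch_size, length1, length2] == [4, 10, 7]

    # Assuming the delegation above, this matches a direct torch.cdist call.
    assert torch.allclose(distances, torch.cdist(set1, set2, p=2))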
class WrapperKLDiv(reduction='mean')[source]¶

Bases: torch.nn.modules.module.Module

Wrapper for KL-Divergence for easy argument passing.
__init__(reduction='mean')[source]¶

Constructor.

Parameters
    reduction (str, optional) – One of ‘none’, ‘batchmean’, ‘sum’, ‘mean’. Defaults to ‘mean’.
forward(set1, set2)[source]¶

Computes the KL-Divergence.

Parameters
    set1 (Tensor) – Input tensor of arbitrary shape.
    set2 (Tensor) – Tensor of the same shape as input.

Returns
    Scalar by default. If reduction = ‘none’, then same shape as input.

Return type
    Tensor
training: bool¶
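A minimal usage sketch, assuming forward delegates to torch.nn.functional.kl_div(set1, set2, reduction=...); under that convention set1 is expected in log-space and set2 as probabilities:

    import torch
    import torch.nn.functional as F

    from pytoda.datasets.utils.wrappers import WrapperKLDiv

    kl = WrapperKLDiv(reduction='batchmean')

    # F.kl_div expects the first argument as log-probabilities and the
    # second as probabilities; the same convention is assumed here.
    set1 = F.log_softmax(torch.randn(8, 20), dim=-1)
    set2 = F.softmax(torch.randn(8, 20), dim=-1)

    loss = kl(set1, set2)  # scalar, since reduction='batchmean'

    # With reduction='none', the output keeps the input shape instead.
    elementwise = WrapperKLDiv(reduction='none')(set1, set2)  # shape [8, 20]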