Source code for pytoda.datasets.utils.wrappers

import torch
import torch.nn as nn

from ...types import Tensor


class WrapperCDist(nn.Module):
    """Wrapper for torch.cdist module for easy argument passing."""

    def __init__(self, p: int = 2) -> None:
        """Constructor.

        Args:
            p (int, optional): p value for the p-norm distance to calculate
                between each vector pair. Defaults to 2.
        """
        super(WrapperCDist, self).__init__()
        self.p = p

    def forward(self, set1: Tensor, set2: Tensor) -> Tensor:
        """Computes the pairwise p-norms.

        Args:
            set1 (Tensor): Input tensor of shape [batch_size, length1, dim].
            set2 (Tensor): Input tensor of shape [batch_size, length2, dim].

        Returns:
            Tensor: Tensor of shape [batch_size, length1, length2]
                representing the pairwise distances.
        """
        return torch.cdist(set1, set2, self.p)
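
# Usage sketch (illustrative, not part of the original module; the tensor
# shapes below are assumptions chosen for demonstration):
#
#     >>> cdist = WrapperCDist(p=2)
#     >>> set1 = torch.randn(4, 5, 3)  # [batch_size=4, length1=5, dim=3]
#     >>> set2 = torch.randn(4, 7, 3)  # [batch_size=4, length2=7, dim=3]
#     >>> cdist(set1, set2).shape      # one distance per pair of vectors
#     torch.Size([4, 5, 7])
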
class WrapperKLDiv(nn.Module):
    """Wrapper for KL-Divergence for easy argument passing."""

    def __init__(self, reduction: str = 'mean') -> None:
        """Constructor.

        Args:
            reduction (str, optional): One of 'none', 'batchmean', 'sum',
                'mean'. Defaults to 'mean'.
        """
        super(WrapperKLDiv, self).__init__()
        self.reduction = reduction

    def forward(self, set1: Tensor, set2: Tensor) -> Tensor:
        """Computes the KL-Divergence.

        Args:
            set1 (Tensor): Input tensor of arbitrary shape.
            set2 (Tensor): Tensor of the same shape as input.

        Returns:
            Tensor: Scalar by default. If reduction is 'none', then same
                shape as input.
        """
        return nn.functional.kl_div(set1, set2, reduction=self.reduction)
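
# Usage sketch (illustrative, not part of the original module). Note that
# nn.functional.kl_div expects its first argument in log-space and, with the
# default log_target=False, its second argument as probabilities:
#
#     >>> kl = WrapperKLDiv(reduction='batchmean')
#     >>> logits = torch.randn(8, 10)                              # assumed shape
#     >>> log_probs = nn.functional.log_softmax(logits, dim=-1)    # log-space input
#     >>> target = nn.functional.softmax(torch.randn(8, 10), dim=-1)
#     >>> loss = kl(log_probs, target)  # scalar tensor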