pytoda.datasets.annotated_dataset module

Implementation of AnnotatedDataset class.

Summary

Classes:

AnnotatedDataset

Annotated samples in the order of the annotations CSV, fetching data from the passed dataset.

Reference

class AnnotatedDataset(annotations_filepath, dataset, annotation_index=-1, label_columns=None, dtype=torch.float32, device=None, **kwargs)

Bases: Generic[torch.utils.data.dataset.T_co]

Annotated samples in the order of the annotations CSV, fetching data from the passed dataset.

__init__(annotations_filepath, dataset, annotation_index=-1, label_columns=None, dtype=torch.float32, device=None, **kwargs)

Initialize an annotated dataset from an additional annotations dataframe. For example, the dataset could hold SMILES and the annotations could be single-task or multi-task labels.

Parameters
  • annotations_filepath (str) – path to the annotations of a dataset. Currently, the supported formats are column-separated files (e.g. CSV). The default structure assumes that the last column contains an id that is also used as a key in the provided dataset.

  • dataset (AnyBaseDataset) – an instance of AnyBaseDataset (supporting the key lookup API of KeyDataset), e.g. a SMILESDataset.

  • annotation_index (Union[int, str]) – positional index or name of the column containing the keys used to get items from the passed dataset. Defaults to -1, i.e. the last column.

  • label_columns (Union[List[int], List[str]]) – positional indexes or names of the columns holding the annotation labels. Defaults to None, i.e. all columns except the annotation index are considered annotation labels.

  • dtype (torch.dtype) – torch data type for labels. Defaults to torch.float32.

  • device (torch.device) – DEPRECATED

  • kwargs (dict) – additional keyword arguments passed to pd.read_csv.
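
The defaults described for annotation_index and label_columns can be illustrated with a small standard-library sketch. This is not pytoda's implementation: the helper names resolve_column and split_columns, the column names, and the CSV content are all hypothetical, chosen only to show how a positional or string index might select the key column and the label columns.

```python
import csv
import io

# Hypothetical annotations file: two label columns, then the key column last.
ANNOTATIONS_CSV = "toxic,soluble,molecule_id\n1,0,mol_b\n0,1,mol_a\n"


def resolve_column(header, column):
    """Return the column name for a positional index or a column name."""
    if isinstance(column, int):
        return header[column]  # negative positions count from the end
    if column not in header:
        raise KeyError(f"unknown column: {column!r}")
    return column


def split_columns(header, annotation_index=-1, label_columns=None):
    """Return (key column, label columns) following the documented defaults."""
    key_column = resolve_column(header, annotation_index)
    if label_columns is None:
        # Default: every column except the annotation index is a label.
        labels = [name for name in header if name != key_column]
    else:
        labels = [resolve_column(header, column) for column in label_columns]
    return key_column, labels


header = next(csv.reader(io.StringIO(ANNOTATIONS_CSV)))

# Defaults: the last column holds the keys, the remaining columns are labels.
default_split = split_columns(header)

# Explicit choice: key column by name, a single label column by position.
custom_split = split_columns(
    header, annotation_index="molecule_id", label_columns=[0]
)

print(default_split)
print(custom_split)
```

With the defaults, the last column is treated as the key column and the remaining columns as labels; passing names or positions narrows the labels to the selected columns.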

get_item_from_key(key)

Get item via key.

Parameters

key (Hashable) – key of the item and annotations to fetch.

Returns

a tuple containing the item itself (with type depending on the passed dataset) and a torch.Tensor of labels for the current item.

Return type

AnnotatedData
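
The key lookup described above can be sketched with the standard library alone. This is an assumption-laden illustration, not the library's code: a plain dict (items_by_key) stands in for the wrapped dataset, plain lists of floats stand in for the torch.Tensor of labels, and the column names and CSV content are hypothetical.

```python
import csv
import io

# Hypothetical annotations: label columns first, key column last.
ANNOTATIONS_CSV = "toxic,soluble,molecule_id\n1,0,mol_b\n0,1,mol_a\n"

# Stand-in for the wrapped dataset: maps keys to items (here SMILES strings).
items_by_key = {"mol_a": "CCO", "mol_b": "c1ccccc1"}

rows = list(csv.DictReader(io.StringIO(ANNOTATIONS_CSV)))

# Labels per key; a list of floats stands in for the label tensor.
labels_by_key = {
    row["molecule_id"]: [float(row["toxic"]), float(row["soluble"])]
    for row in rows
}


def get_item_from_key(key):
    """Return (item, labels) for a key, mirroring the documented return value."""
    return items_by_key[key], labels_by_key[key]


# Samples follow the row order of the annotations file, not the dataset order.
samples = [get_item_from_key(row["molecule_id"]) for row in rows]
print(samples)
```

Iterating over the annotation rows yields the samples in the order of the annotations file, which is the ordering the class description refers to.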