Source code for pytoda.datasets.dataframe_dataset

"""KeyDataset for pandas DataFrames ."""
import numpy as np
import pandas as pd

from ..types import Hashable, Iterator
from .base_dataset import KeyDataset


[docs]class DataFrameDataset(KeyDataset): """ Dataset of rows from pandas.DataFrame """ def __init__(self, df: pd.DataFrame): super(DataFrameDataset).__init__() self.df = df def __len__(self) -> int: """Total number of samples.""" return len(self.df) def __getitem__(self, index: int) -> np.array: """ Generates one sample of data. Args: index (int): index of the sample to fetch. Returns: np.array: a selected row of the CSV. """ return self.df.iloc[index].values
[docs] def get_key(self, index: int) -> Hashable: """Get key from integer index.""" return self.df.index[index]
[docs] def get_index(self, key: Hashable) -> int: """Get index for first datum mapping to the given key.""" # item will raise if not single value (deprecated in pandas) try: indexes = np.nonzero(self.df.index == key)[0] return indexes.item() except ValueError: if len(indexes) == 0: raise KeyError else: # key not unique, return first as ConcatKeyDataset return indexes[0]
[docs] def get_item_from_key(self, key: Hashable) -> np.array: """Get item via key""" return self.df.loc[key, :].values
[docs] def keys(self) -> Iterator: """Iterator over index of dataframe.""" return iter(self.df.index)