Source code for pytoda.data_splitter

"""Data splitting utilties."""
import hashlib
import os
import random
from math import ceil
from typing import Tuple

import numpy as np
import pandas as pd

from .types import Files


def csv_data_splitter(
    data_filepaths: Files,
    save_path: str,
    data_type: str,
    mode: str,
    seed: int = 42,
    test_fraction: float = 0.1,
    number_of_columns: int = 12,
    **kwargs,
) -> Tuple[str, str]:
    """
    Generic splitting of csv data into train and test files.

    This is an eager splitter that tries to fit the entire dataset into memory.

    Args:
        data_filepaths (Files): a list of .csv files that contain the data.
        save_path (str): folder to store the training/testing datasets.
        data_type (str): data type (only used as prefix for the saved files).
        mode (str): how to split the data; either "random" or "file".
            - random: does a random split across all samples in all files.
            - file: randomly splits the files into training and testing.
        seed (int): random seed used for the split. Defaults to 42.
        test_fraction (float): portion of samples in the testing data.
            Defaults to 0.1.
        number_of_columns (int): number of columns used to create the hash.
            Defaults to 12.
        kwargs (dict): additional parameters for pd.read_csv.

    Returns:
        Tuple[str, str]: a tuple pointing to the train and test files.
    """
    # preparation
    random.seed(seed)
    data_filepaths = (
        [data_filepaths] if isinstance(data_filepaths, str) else data_filepaths
    )
    file_suffix = '{}_{}_fraction_{}_id_{}_seed_{}.csv'
    hash_fn = hashlib.md5()
    if 'index_col' not in kwargs:
        kwargs['index_col'] = 0

    # helper function to produce a unique hash. Base the hash on df.to_string()
    # but only consider the first number_of_columns columns (slow otherwise)
    def _hash_from_df_columns(df, number_of_columns):
        return df.reindex(sorted(df.columns), axis=1).to_string(
            columns=sorted(df.columns)[:number_of_columns]
        )

    # NOTE: if all *.csv files contain a single sample only,
    # the splitting modes collapse

    # file splitting mode
    if mode == 'file':
        if len(data_filepaths) < 2:
            raise RuntimeError(
                'mode={} requires at least two input files.'.format(mode)
            )
        file_indexes = np.arange(len(data_filepaths))
        random.shuffle(file_indexes)
        # compile splits (ceil ensures test_df is not empty)
        test_df = pd.concat(
            [
                pd.read_csv(data_filepaths[index], **kwargs)
                for index in file_indexes[: ceil(test_fraction * len(data_filepaths))]
            ]
        )
        train_df = pd.concat(
            [
                pd.read_csv(data_filepaths[index], **kwargs)
                for index in file_indexes[ceil(test_fraction * len(data_filepaths)) :]
            ]
        )
    # random splitting mode:
    # build a joint df from all samples, then split
    elif mode == 'random':
        df = pd.concat(
            [pd.read_csv(data_path, **kwargs) for data_path in data_filepaths],
            sort=False,
        )
        sample_indexes = np.arange(df.shape[0])
        random.shuffle(sample_indexes)
        splitting_index = ceil(test_fraction * df.shape[0])
        test_df = df.iloc[sample_indexes[:splitting_index]]
        train_df = df.iloc[sample_indexes[splitting_index:]]
    else:
        raise ValueError('Choose mode to be from the set {"random", "file"}.')

    # generate hash (by default uses the first number_of_columns
    # columns to define the hash)
    number_of_columns = (
        number_of_columns
        if test_df.shape[1] > number_of_columns
        else test_df.shape[1]
    )
    hash_str = _hash_from_df_columns(
        train_df, number_of_columns
    ) + _hash_from_df_columns(test_df, number_of_columns)
    hash_fn.update(hash_str.encode('utf-8'))
    hash_id = str(int(hash_fn.hexdigest(), 16))[:6]

    # generate file paths
    train_filepath = os.path.join(
        save_path,
        file_suffix.format(data_type, 'train', 1 - test_fraction, hash_id, seed),
    )
    test_filepath = os.path.join(
        save_path,
        file_suffix.format(data_type, 'test', test_fraction, hash_id, seed),
    )

    # write the datasets
    for path, df in zip([train_filepath, test_filepath], [train_df, test_df]):
        df.to_csv(path)

    return train_filepath, test_filepath
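A minimal usage sketch follows. The input file names (data_part_0.csv, data_part_1.csv) and the output folder (splits) are hypothetical placeholders, not part of the library; only the import path and the function signature come from the source above.

from pytoda.data_splitter import csv_data_splitter

# Hypothetical CSV files and an existing output folder.
train_path, test_path = csv_data_splitter(
    data_filepaths=['data_part_0.csv', 'data_part_1.csv'],
    save_path='splits',
    data_type='example',
    mode='random',
    seed=42,
    test_fraction=0.2,
)
print(train_path)  # e.g. splits/example_train_fraction_0.8_id_<hash>_seed_42.csv
print(test_path)   # e.g. splits/example_test_fraction_0.2_id_<hash>_seed_42.csv

Because the file names embed a hash derived from the split contents, repeated calls with the same data, seed, and fraction resolve to the same output paths.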