"""Implementation of ProteinProteinInteractionDataset."""
import pandas as pd
import torch
from numpy import iterable
from torch.utils.data import Dataset
from pytoda.warnings import device_warning
from ..proteins.protein_language import ProteinLanguage
from ..types import Files, List, Sequence, Tensor, Tuple, Union
from .protein_sequence_dataset import ProteinSequenceDataset


class ProteinProteinInteractionDataset(Dataset):
"""
    PPI Dataset implementation. Designed for two (or more) sources of
    protein sequences and one source of discrete labels.

    NOTE: Supports only classification (possibly multitask) tasks, not
    regression.
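
    Example:
        A minimal usage sketch; the file paths and column names below are
        hypothetical::

            dataset = ProteinProteinInteractionDataset(
                sequence_filepaths=['peptides.smi', 'receptors.smi'],
                entity_names=['peptides', 'receptors'],
                labels_filepath='labels.csv',
            )
            # one token-index tensor per entity, plus the label tensor
            peptide_tensor, receptor_tensor, labels = dataset[0]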
"""

    def __init__(
self,
sequence_filepaths: Union[Files, Sequence[Files]],
entity_names: Sequence[str],
labels_filepath: str,
sequence_filetypes: Union[str, List[str]] = 'infer',
annotations_column_names: Union[List[int], List[str]] = None,
protein_languages: Union[ProteinLanguage, List[ProteinLanguage]] = None,
paddings: Union[bool, Sequence[bool]] = True,
padding_lengths: Union[int, Sequence[int]] = None,
add_start_and_stops: Union[bool, Sequence[bool]] = False,
augment_by_reverts: Union[bool, Sequence[bool]] = False,
randomizes: Union[bool, Sequence[bool]] = False,
iterate_datasets: Union[bool, Sequence[bool]] = False,
device: torch.device = None,
) -> None:
"""
Initialize a protein protein interactiondataset.
Args:
            sequence_filepaths (Union[Files, Sequence[Files]]):
                paths to .smi (also as .csv) or .fasta (.gz) files of protein
                sequences. For each item in the iterable, one protein sequence
                dataset is created. The filepaths can be nested, i.e. each
                protein sequence dataset can be created from an iterable of
                filepaths of the same type, see sequence_filetypes.
            entity_names (Sequence[str]): List of protein sequence entities,
                e.g. ['Peptides', 'T-Cell-Receptors']. These names should be
                column names of the labels_filepath, in the same order as
                sequence_filepaths.
labels_filepath (str): path to .csv file with classification
labels.
            sequence_filetypes (Union[str, List[str]]): the filetypes of the
                sequence files. Can either be a str if all files have
                identical types or a Sequence if different entities have
                different types. Different types within the same entity are
                not supported. Supported formats are {.smi, .csv, .fasta,
                .fasta.gz}. Default is `infer`, i.e. filetypes are inferred
                automatically.
            annotations_column_names (Union[List[int], List[str]]): column
                indexes (positional ints or name strings) for the annotations.
                Defaults to None, i.e. all columns except the entity_names are
                considered annotation labels.
            protein_languages (Union[ProteinLanguage, List[ProteinLanguage]]):
one or multiple ProteinLanguages. If multiple are provided, exactly one
should be given for each entity. If only one is provided, the same
language will be used for all entities. You can also pass child instances
like ProteinFeatureLanguage. Defaults to None, i.e., creating a single
protein language with iupac dictionary for all entities.
paddings (Union[bool, Sequence[bool]]): pad sequences to longest in
the protein language. Defaults to True.
            padding_lengths (Union[int, Sequence[int]]): manually set the
                padded sequence length (used only if padding is enabled).
                Defaults to None.
add_start_and_stops (Union[bool, Sequence[bool]]): add start and
stop token indexes. Defaults to False.
            augment_by_reverts (Union[bool, Sequence[bool]]): perform a
                stochastic reversion of the amino acid sequence. Defaults to
                False.
randomizes (Union[bool, Sequence[bool]]): perform a true
randomization of the amino acid sequences. Defaults to False.
            iterate_datasets (Union[bool, Sequence[bool]]): whether to go
                through all items in the datasets to detect unknown
                characters, find the longest sequence and check the passed
                padding length, if applicable. Defaults to False.
            device (torch.device): DEPRECATED and ignored; passing a value
                only emits a warning.
"""
Dataset.__init__(self)
device_warning(device)
        assert len(entity_names) == len(
            sequence_filepaths
        ), 'sequence_filepaths should be an iterable of the same length as entity_names'
self.labels_filepath = labels_filepath
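        # Capitalize entity names so they match the capitalized label columns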
self.entities = list(map(lambda x: x.capitalize(), entity_names))
# wrap single filepath per entity to treat equally as iterable (*args)
self.sequence_filepaths = [
[filepath] if isinstance(filepath, str) else filepath
for filepath in sequence_filepaths
]
        # Infer the filetype from the first sequence file of each entity
        if sequence_filetypes == 'infer':
            self.filetypes = [
                # keep the compound suffix of compressed files intact
                '.fasta.gz'
                if filepaths[0].endswith('.fasta.gz')
                else '.' + filepaths[0].split('.')[-1]
                for filepaths in self.sequence_filepaths
            ]
elif sequence_filetypes in ['.smi', '.csv', '.fasta', '.fasta.gz']:
self.filetypes = [sequence_filetypes] * len(self.entities)
elif len(sequence_filetypes) == len(self.entities) and all(
x in ['.smi', '.csv', '.fasta', '.fasta.gz'] for x in sequence_filetypes
):
self.filetypes = sequence_filetypes
else:
raise ValueError(f'Unsupported filetype: {sequence_filetypes}')
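        # Broadcast the per-entity arguments: a scalar is replicated for
        # every entity; a sequence of matching length is used as given.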
(
self.paddings,
self.padding_lengths,
self.add_start_and_stops,
self.augment_by_reverts,
self.randomizes,
self.iterate_datasets,
) = map(
(
lambda x: x
if iterable(x) and len(x) == len(self.entities)
else [x] * len(self.entities)
),
(
paddings,
padding_lengths,
add_start_and_stops,
augment_by_reverts,
randomizes,
iterate_datasets,
),
)
if protein_languages is None:
# Objects will be constructed in `ProteinSequenceDataset`
self.protein_languages = [None for _ in self.entities]
self.pl_methods = ['iupac' for _ in self.entities]
else:
if isinstance(protein_languages, Sequence):
# Multiple languages were passed
self.protein_languages = protein_languages
self.pl_methods = [p.method for p in protein_languages]
elif isinstance(protein_languages, ProteinLanguage):
# Single language was passed
self.protein_languages = [protein_languages for _ in self.entities]
self.pl_methods = [protein_languages.method for _ in self.entities]
            else:
                raise TypeError(
                    'Received unknown type for protein_languages: '
                    f'{type(protein_languages)}'
                )
        # Check that passed languages are consistent with add_start_and_stops;
        # languages still to be constructed (None entries) are skipped.
        for i, (lang, start) in enumerate(
            zip(self.protein_languages, self.add_start_and_stops)
        ):
            assert lang is None or lang.add_start_and_stop == start, (
                f'add_start_and_stop differs for language {i}: '
                f'{start} vs. {lang.add_start_and_stop}'
            )
# Create protein sequence datasets.
self.datasets = [
ProteinSequenceDataset(
*filepaths,
filetype=self.filetypes[index],
protein_language=self.protein_languages[index],
padding=self.paddings[index],
padding_length=self.padding_lengths[index],
add_start_and_stop=self.add_start_and_stops[index],
augment_by_revert=self.augment_by_reverts[index],
randomize=self.randomizes[index],
name=self.entities[index],
iterate_dataset=self.iterate_datasets[index],
)
for index, filepaths in enumerate(self.sequence_filepaths)
]
# Retrieve the possibly updated protein languages
self.protein_languages = [data.protein_language for data in self.datasets]
# Labels
self.labels_df = pd.read_csv(self.labels_filepath)
        # Capitalize the column names so they match the capitalized entities
        self.labels_df.columns = [
            str(column).capitalize() for column in self.labels_df.columns
        ]
columns = self.labels_df.columns
        # Determine the annotation label columns
        if annotations_column_names is None:
            self.labels = [column for column in columns if column not in self.entities]
        elif all(isinstance(column, int) for column in annotations_column_names):
            self.labels = columns[annotations_column_names]
        elif all(isinstance(column, str) for column in annotations_column_names):
            self.labels = list(map(lambda x: x.capitalize(), annotations_column_names))
        else:
            raise RuntimeError(
                'annotations_column_names should be an iterable of int or str'
            )
# get the number of labels
self.number_of_tasks = len(self.labels)
assert all(
list(map(lambda x: x in self.labels_df.columns, self.entities))
), 'At least one given entity name was not found in labels_filepath.'
        # Keep only rows whose ids are present in every entity's dataset
masks = []
mask = pd.Series([True] * len(self.labels_df), index=self.labels_df.index)
for entity, dataset in zip(self.entities, self.datasets):
# prune rows (in mask) with ids unavailable in respective dataset
local_mask = self.labels_df[entity].isin(set(dataset.keys()))
mask = mask & local_mask
masks.append(local_mask)
self.labels_df = self.labels_df.loc[mask]
# to investigate missing ids per entity
self.masks_df = pd.concat(masks, axis=1)
self.masks_df.columns = self.entities
self.number_of_samples = len(self.labels_df)

    def __len__(self) -> int:
        """Total number of samples."""
        return self.number_of_samples

    def __getitem__(self, index: int) -> Tuple[Tensor, ...]:
        """
        Generates one sample of data.

        Args:
            index (int): index of the sample to fetch.

        Returns:
            Tuple: a tuple of len(self.entities) + 1 torch.Tensors,
                representing the protein token indexes for each entity
                followed by the property labels (annotations).
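
        Example:
            An illustrative unpacking of one sample::

                *protein_tensors, labels = dataset[index]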
"""
# sample selection
selected_sample = self.labels_df.iloc[index]
# labels (annotations)
labels_tensor = torch.tensor(
list(selected_sample[self.labels].values),
dtype=torch.float,
)
# samples (Protein sequences)
proteins_tensors = [
ds.get_item_from_key(selected_sample[ds.name]) for ds in self.datasets
]
        return (*proteins_tensors, labels_tensor)