Coverage for datamodules.py: 21%

from __future__ import annotations

import logging
import warnings
from argparse import _ArgumentGroup
from collections import namedtuple
from pathlib import Path, PosixPath
from typing import Dict, List, Optional, Union

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from pandas.core.frame import DataFrame
from torch import Tensor
from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm

from elfragmentador import constants, spectra

# One training example: encoded peptide sequence and modifications, precursor
# charge, collision energy (NCE), encoded spectrum, and normalized iRT.
TrainBatch = namedtuple(
    "TrainBatch",
    "encoded_sequence, encoded_mods, charge, nce, encoded_spectra, norm_irt",
)


def match_lengths(
    nested_list: Union[List[List[Union[int, float]]], List[List[int]]],
    max_len: int,
    name: str = "items",
    verbose: bool = True,
) -> Tensor:
    """Zero-pad every sub-list to max_len and stack them into a single tensor."""
    lengths = [len(x) for x in nested_list]
    unique_lengths = set(lengths)
    match_max = [1 for x in lengths if x == max_len]

    out_message = (
        f"{len(match_max)}/{len(nested_list)} "
        f"{name} actually match the max sequence length of"
        f" {max_len},"
        f" found {unique_lengths}"
    )

    if len(match_max) == len(nested_list):
        logging.info(out_message)
    else:
        logging.warning(out_message)

    out = [
        x + ([0] * (max_len - len(x))) if len(x) != max_len else x for x in nested_list
    ]
    out = torch.stack([torch.Tensor(x).T for x in out])
    return out
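

# A minimal sketch of what match_lengths does, with invented values: two ragged
# rows are zero-padded to a common length of 5 and stacked.
#
#     match_lengths([[1, 2, 3], [1, 2, 3, 4, 5]], max_len=5, name="demo")
#     # logs: "1/2 demo actually match the max sequence length of 5, found {3, 5}"
#     # tensor([[1., 2., 3., 0., 0.],
#     #         [1., 2., 3., 4., 5.]])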


def match_colnames(df: DataFrame) -> Dict[str, Optional[str]]:
    """Guess which columns of df hold each field the dataset needs."""

    def match_col(string1, string2, colnames, match_mode="in", combine_mode=None):
        m = {
            "in": lambda q, t: q in t,
            "startswith": lambda q, t: q.startswith(t) or t.startswith(q),
            "equals": lambda q, t: q == t,
        }
        match_fun = m[match_mode]
        match_indices1 = [i for i, x in enumerate(colnames) if match_fun(string1, x)]

        if string2 is None:
            match_indices = match_indices1
        else:
            match_indices2 = [
                i for i, x in enumerate(colnames) if match_fun(string2, x)
            ]
            if combine_mode == "union":
                match_indices = set(match_indices1).union(set(match_indices2))
            elif combine_mode == "intersect":
                match_indices = set(match_indices1).intersection(set(match_indices2))
            else:
                raise NotImplementedError

        try:
            out_index = list(match_indices)[0]
        except IndexError:
            out_index = None

        return out_index

    colnames = list(df)
    out = {
        "SeqE": match_col("Encoding", "Seq", colnames, combine_mode="intersect"),
        "ModE": match_col("Encoding", "Mod", colnames, combine_mode="intersect"),
        "SpecE": match_col("Encoding", "Spec", colnames, combine_mode="intersect"),
        # "harg" matches both "Charge" and "charge"
        "Ch": match_col("harg", None, colnames),
        "iRT": match_col("IRT", "iRT", colnames, combine_mode="union"),
        "NCE": match_col(
            "nce", "NCE", colnames, combine_mode="union", match_mode="startswith"
        ),
    }
    out = {k: (colnames[v] if v is not None else None) for k, v in out.items()}
    logging.info(f">>> Mapped column names to the provided dataset {out}")
    return out
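

# A hypothetical illustration of the mapping, assuming a frame with these
# column names (all invented):
#
#     df = pd.DataFrame(columns=["SeqEncoding", "ModEncoding", "SpecEncoding",
#                                "Charge", "mIRT", "CollisionEnergy"])
#     match_colnames(df)
#     # {'SeqE': 'SeqEncoding', 'ModE': 'ModEncoding', 'SpecE': 'SpecEncoding',
#     #  'Ch': 'Charge', 'iRT': 'mIRT', 'NCE': None}
#
# "CollisionEnergy" maps to None because NCE columns are matched by the
# "nce"/"NCE" name prefix; a missing NCE column triggers the imputation below.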


class PeptideDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        df: DataFrame,
        max_spec: int = 1_000_000,
        drop_missing_vals: bool = False,
    ) -> None:
        super().__init__()
        logging.info("\n>>> Initializing Dataset")
        if drop_missing_vals:
            former_len = len(df)
            df.dropna(inplace=True)
            logging.warning(
                f"\n>>> {len(df)}/{former_len} rows left after dropping missing values"
            )

        if max_spec < len(df):
            logging.warning(
                "\n>>> Subsampling the data down to "
                f"{max_spec} rows, change the 'max_spec' argument"
                " if you don't want this to happen"
            )
            df = df.sample(n=int(max_spec))

        self.df = df  # TODO remove this for memory ...

        name_match = match_colnames(df)

        # Encodings are stored as stringified lists, so they are parsed with eval
        seq_encoding_iter = tqdm(
            self.df[name_match["SeqE"]], "Decoding sequence encodings"
        )
        sequence_encodings = [eval(x) for x in seq_encoding_iter]
        sequence_encodings = match_lengths(
            sequence_encodings, constants.MAX_TENSOR_SEQUENCE, "Sequences"
        )
        self.sequence_encodings = sequence_encodings.long()

        if name_match["ModE"] is None:
            logging.warning(
                "Found missing Modification Encodings,"
                " assuming all peptides are unmodified."
                " Please fix the data for future use,"
                " since this imputation will be removed in the future"
            )
            mod_encodings = [
                [0] * constants.MAX_TENSOR_SEQUENCE for _ in sequence_encodings
            ]
        else:
            mod_encodings_iter = tqdm(
                self.df[name_match["ModE"]], "Decoding Modification encodings"
            )
            mod_encodings = [eval(x) for x in mod_encodings_iter]

        mod_encodings = match_lengths(
            mod_encodings, constants.MAX_TENSOR_SEQUENCE, "Mods"
        )
        self.mod_encodings = mod_encodings.long()

        spec_encoding_iter = tqdm(
            self.df[name_match["SpecE"]], "Decoding Spec Encodings"
        )
        spectra_encodings = [eval(x) for x in spec_encoding_iter]
        spectra_encodings = match_lengths(
            spectra_encodings, constants.NUM_FRAG_EMBEDINGS, "Spectra"
        )
        self.spectra_encodings = spectra_encodings.float()
        avg_peaks = torch.sum(spectra_encodings > 0.01, axis=1).float().mean()

        spectra_lengths = len(self.spectra_encodings[0])
        sequence_lengths = len(self.sequence_encodings[0])

        try:
            # iRTs are normalized by dividing by 100
            irts = np.array(self.df[name_match["iRT"]]).astype("float") / 100
            self.norm_irts = torch.from_numpy(irts).float().unsqueeze(1)
            del irts
        except ValueError as e:
            logging.error(self.df[name_match["iRT"]])
            logging.error(e)
            raise e

        if name_match["NCE"] is None:
            nces = (
                torch.Tensor([float("nan")] * len(self.norm_irts)).float().unsqueeze(1)
            )
        else:
            try:
                nces = np.array(self.df[name_match["NCE"]]).astype("float")
                nces = torch.from_numpy(nces).float().unsqueeze(1)
            except ValueError as e:
                logging.error(self.df[name_match["NCE"]])
                logging.error(e)
                raise e

        self.nces = nces

        if torch.any(self.nces.isnan()):
            # TODO decide if here should be the place to impute NCEs ... and warn ...
            warnings.warn(
                (
                    "Found missing values in NCEs, assuming 30."
                    " Please fix the data for future use,"
                    " since this imputation will be removed in the future"
                ),
                FutureWarning,
            )
            self.nces = torch.where(self.nces.isnan(), torch.Tensor([30.0]), self.nces)

            # This syntax is available in torch >= 1.8; switch to it once colab
            # migrates to that version:
            # self.nces = torch.nan_to_num(self.nces, nan=30.0)

        charges = np.array(self.df[name_match["Ch"]]).astype("long")
        self.charges = torch.Tensor(charges).long().unsqueeze(1)

        logging.info(
            f"Dataset Initialized with {len(df)} entries."
            f" Sequence length: {sequence_lengths}"
            f" Spectra length: {spectra_lengths}"
            f"; Average Peaks/spec: {avg_peaks}"
        )
        logging.info(">>> Done Initializing dataset\n")

    @staticmethod
    def from_sptxt(
        filepath: str,
        max_spec: int = 1_000_000,
        filter_df: bool = True,
        *args,
        **kwargs,
    ) -> PeptideDataset:
        df = spectra.encode_sptxt(str(filepath), max_spec=max_spec, *args, **kwargs)
        if filter_df:
            df = filter_df_on_sequences(df)

        return PeptideDataset(df)

    @staticmethod
    def from_csv(filepath: Union[str, Path], max_spec: int = 1_000_000):
        df = filter_df_on_sequences(pd.read_csv(str(filepath)))
        return PeptideDataset(df, max_spec=max_spec)

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, index: int) -> TrainBatch:
        encoded_sequence = self.sequence_encodings[index]
        encoded_mods = self.mod_encodings[index]
        encoded_spectra = self.spectra_encodings[index]
        norm_irt = self.norm_irts[index]
        charge = self.charges[index]
        nce = self.nces[index]

        out = TrainBatch(
            encoded_sequence=encoded_sequence,
            encoded_mods=encoded_mods,
            charge=charge,
            nce=nce,
            encoded_spectra=encoded_spectra,
            norm_irt=norm_irt,
        )
        return out
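

# A hedged usage sketch (the csv path is hypothetical): build a dataset from a
# pre-encoded csv and pull one example out by index.
#
#     ds = PeptideDataset.from_csv("encoded_spectra.csv", max_spec=1000)
#     item = ds[0]
#     item.encoded_sequence.shape  # torch.Size([constants.MAX_TENSOR_SEQUENCE])
#     item.norm_irt                # iRT / 100 as a 1-element float tensor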


def filter_df_on_sequences(df: DataFrame, name: str = "") -> DataFrame:
    name_match = match_colnames(df)
    logging.info(list(df))
    logging.warning(f"Removing large sequences, currently {name}: {len(df)}")

    seq_iterable = tqdm(df[name_match["SeqE"]], desc="Decoding tensor seqs")

    df = (
        df[[len(eval(x)) <= constants.MAX_TENSOR_SEQUENCE for x in seq_iterable]]
        .copy()
        .reset_index(drop=True)
    )

    logging.warning(f"Rows left {name}: {len(df)}")
    return df
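

# A hypothetical sketch of the filter, pretending constants.MAX_TENSOR_SEQUENCE
# were 5 (the real value comes from the package):
#
#     df = pd.DataFrame({"SeqEncoding": ["[1, 2, 3]", "[1, 2, 3, 4, 5, 6, 7]"]})
#     filter_df_on_sequences(df, name="demo")  # keeps only the first row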


class PeptideDataModule(pl.LightningDataModule):
    def __init__(
        self,
        batch_size: int = 64,
        base_dir: Union[str, PosixPath] = ".",
        drop_missing_vals: bool = False,
    ) -> None:
        super().__init__()
        self.batch_size = batch_size
        self.drop_missing_vals = drop_missing_vals
        base_dir = Path(base_dir)

        train_path = list(base_dir.glob("*train*.csv"))
        val_path = list(base_dir.glob("*val*.csv"))

        assert (
            len(train_path) > 0
        ), f"Train file matching '*train*.csv' not found in '{base_dir}'"
        assert (
            len(val_path) > 0
        ), f"Val file matching '*val*.csv' not found in '{base_dir}'"

        train_df = pd.concat([pd.read_csv(str(x)) for x in train_path])
        train_df = filter_df_on_sequences(train_df)
        val_df = pd.concat([pd.read_csv(str(x)) for x in val_path])
        val_df = filter_df_on_sequences(val_df)

        self.train_df = train_df
        self.val_df = val_df

    @staticmethod
    def add_model_specific_args(parser: _ArgumentGroup) -> _ArgumentGroup:
        parser.add_argument("--batch_size", type=int, default=64)
        parser.add_argument("--data_dir", type=str, default=".")
        # store_true avoids the argparse pitfall where type=bool parses any
        # non-empty string (including "False") as True
        parser.add_argument("--drop_missing_vals", action="store_true")
        return parser

    def setup(self, stage: Optional[str] = None) -> None:
        # stage is accepted (and ignored) so the Lightning Trainer can call
        # setup(stage=...) without a TypeError
        self.train_dataset = PeptideDataset(
            self.train_df, drop_missing_vals=self.drop_missing_vals
        )
        self.val_dataset = PeptideDataset(
            self.val_df, drop_missing_vals=self.drop_missing_vals
        )

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_dataset, num_workers=0, batch_size=self.batch_size, shuffle=True
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.val_dataset, batch_size=self.batch_size)
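

# A hedged end-to-end sketch (the directory name is hypothetical): the base dir
# is expected to contain *train*.csv and *val*.csv files of pre-encoded peptides.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    dm = PeptideDataModule(batch_size=32, base_dir="./encoded_data")
    dm.setup()
    for batch in dm.train_dataloader():
        # default_collate stacks each TrainBatch field into a batch-first tensor
        print(batch.encoded_sequence.shape, batch.norm_irt.shape)
        break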