Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

import time
import logging
from typing import Callable, Dict, List, Tuple, Union
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter

import torch
import pytorch_lightning as pl
import numpy as np
from numpy import float32, float64, ndarray
import pandas as pd
from pandas.core.series import Series
from tqdm.auto import tqdm
import uniplot

from elfragmentador import constants
from elfragmentador.model import PepTransformerModel
from elfragmentador.datamodules import PeptideDataset
from elfragmentador.metrics import PearsonCorrelation

20 

21 

def build_evaluate_parser() -> ArgumentParser:
    """Build the command line parser for the checkpoint evaluation script.

    Returns
    -------
    ArgumentParser
        Parser with a positional ``checkpoint_path``, mutually exclusive
        ``--sptxt`` / ``--csv`` inputs, mutually exclusive
        ``--overwrite_nce`` / ``--screen_nce`` options and general
        evaluation settings.
    """
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "checkpoint_path", type=str, help="Checkpoint to use for the testing"
    )
    input_group = parser.add_mutually_exclusive_group()
    input_group.add_argument("--sptxt", type=str, help="Sptxt file to use for testing")
    input_group.add_argument("--csv", type=str, help="Csv file to use for testing")
    parser.add_argument(
        "--device",
        default="cpu",
        type=str,
        help="Device to move the model to during the evaluation",
    )
    parser.add_argument(
        "--batch_size",
        default=4,
        type=int,
        help="Batch size to use during the evaluation",
    )
    nce_group = parser.add_mutually_exclusive_group()
    nce_group.add_argument(
        "--overwrite_nce",
        type=float,
        help="NCE to overwrite the collision energy with",
    )
    nce_group.add_argument(
        "--screen_nce",
        type=str,
        help="Comma delimited series of collision energies to use",
    )
    parser.add_argument(
        "--max_spec",
        # argparse only applies ``type`` to *string* defaults, so the default
        # must already be an int (``1e6`` would stay the float 1000000.0).
        default=1_000_000,
        type=int,
        help="Maximum number of spectra to read",
    )
    parser.add_argument(
        "--out_csv", type=str, help="Optional csv file to output results to"
    )
    return parser

63 

64 

def evaluate_checkpoint(
    checkpoint_path: str,
    sptxt_path: str,
    batch_size=4,
    device="cpu",
    out_csv=None,
    max_spec=1e6,
):
    """Load a model checkpoint, evaluate it on an sptxt file and report metrics.

    Parameters
    ----------
    checkpoint_path : str
        Path passed to ``PepTransformerModel.load_from_checkpoint``.
    sptxt_path : str
        Spectral library used as evaluation data.
    batch_size : int
        Batch size used during the evaluation.
    device : str
        Device the model is moved to during the evaluation.
    out_csv : str, optional
        If given, the per-spectrum results are written to this csv file.
    max_spec : int or float
        Maximum number of spectra to read from ``sptxt_path``.

    Returns
    -------
    tuple of (pd.DataFrame, dict)
        Per-spectrum results sorted from lowest to highest cosine similarity,
        and the summary metrics dictionary (which is also logged).
    """
    model = PepTransformerModel.load_from_checkpoint(checkpoint_path=checkpoint_path)
    model.eval()

    out, summ_out = evaluate_on_sptxt(
        model,
        filepath=sptxt_path,
        batch_size=batch_size,
        device=device,
        max_spec=max_spec,
    )
    # The results only contain "Spectra_Similarity_Cosine" and
    # "Spectra_Similarity_Pearson"; sort so the worst predictions come first.
    # reset_index() keeps the original row position as an "index" column.
    out = (
        pd.DataFrame(out)
        .sort_values(["Spectra_Similarity_Cosine"])
        .reset_index()
    )
    logging.info(summ_out)
    if out_csv is not None:
        logging.info(f">>> Saving results to {out_csv}")
        out.to_csv(out_csv, index=False)

    return out, summ_out

89 

90 

def evaluate_on_sptxt(
    model, filepath, batch_size=4, device="cpu", max_spec=1e6, *args, **kwargs
):
    """Evaluate ``model`` on the spectra stored in an sptxt library.

    Any extra positional or keyword arguments are forwarded unchanged to
    ``PeptideDataset.from_sptxt``. The return value is whatever
    ``evaluate_on_dataset`` produces for the resulting dataset.
    """
    dataset = PeptideDataset.from_sptxt(
        *args, filepath=filepath, max_spec=max_spec, **kwargs
    )
    result = evaluate_on_dataset(
        model=model,
        dataset=dataset,
        batch_size=batch_size,
        device=device,
    )
    return result

100 

101 

def evaluate_on_csv(model, filepath, batch_size=4, device="cpu", max_spec=1e6):
    """Evaluate ``model`` on the spectra stored in a csv file.

    The csv is loaded with ``PeptideDataset.from_csv`` (reading at most
    ``max_spec`` spectra). The return value is whatever
    ``evaluate_on_dataset`` produces for that dataset.
    """
    csv_dataset = PeptideDataset.from_csv(filepath=filepath, max_spec=max_spec)
    result = evaluate_on_dataset(
        model=model,
        dataset=csv_dataset,
        batch_size=batch_size,
        device=device,
    )
    return result

107 

108 

def terminal_plot_similarity(similarities, name=""):
    """Plot a histogram and a cumulative distribution of values in the terminal.

    NaN entries are dropped before plotting. Otherwise a single NaN (for
    instance a Pearson correlation on a flat spectrum) would turn the mean
    and every quantile into NaN.

    Parameters
    ----------
    similarities : array-like
        Values to plot. Any sequence convertible to a float array is accepted.
    name : str
        Label used in the plot titles.

    Returns
    -------
    None
        Always. When no non-missing values are present, a warning is logged and
        nothing is plotted.
    """
    values = np.asarray(similarities, dtype=float).ravel()
    values = values[~np.isnan(values)]
    if values.size == 0:
        logging.warning("Skipping because all values are missing")
        return None

    uniplot.histogram(
        values,
        title=f"{name} mean:{values.mean()}",
    )

    qs = [0, 0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99, 1]
    # Quantiles are taken on (1 - value), so the "high" percentiles of the
    # values sit at the *low* end of this array and are flipped back with 1 - q.
    similarity_quantiles = np.quantile(1 - values, qs)
    p90 = similarity_quantiles[2]
    p10 = similarity_quantiles[-3]
    q1 = similarity_quantiles[5]
    med = similarity_quantiles[4]
    q3 = similarity_quantiles[3]
    title = f"Accumulative distribution (y) of the 1 - {name} (x)"
    title += f"\nP90={1-p90:.3f} Q3={1-q3:.3f}"
    title += f" Median={1-med:.3f} Q1={1-q1:.3f} P10={1-p10:.3f}"
    uniplot.plot(xs=similarity_quantiles, ys=qs, lines=True, title=title)

130 

131 

def evaluate_on_dataset(
    model: PepTransformerModel,
    dataset: PeptideDataset,
    batch_size: int = 4,
    device: str = "cpu",
    overwrite_nce: Union[float, bool] = False,
) -> Tuple[pd.DataFrame, Dict[str, Union[float64, float32, None]]]:
    """Run ``model`` over ``dataset`` and compute spectrum and RT metrics.

    Parameters
    ----------
    model : PepTransformerModel
        Model to evaluate. It is put in eval mode and moved to ``device``.
    dataset : PeptideDataset
        Dataset whose ``df`` provides the "ModSequences", "RTs", "iRT" and
        "Charges" columns, and whose batches provide ``encoded_sequence``,
        ``charge``, ``encoded_mods``, ``nce`` and ``encoded_spectra``.
    batch_size : int
        Batch size for the DataLoader.
    device : str
        Device where the forward pass is run.
    overwrite_nce : float or bool
        When truthy, every batch's NCE is replaced by this value.

    Returns
    -------
    tuple of (pd.DataFrame, dict)
        Per-spectrum results with the columns "ModSequence", "Charges",
        "Predicted_iRT", "Real_RT", "Spectra_Similarity_Cosine",
        "Spectra_Similarity_Pearson" and "RT_Error". The summary dict holds
        the R^2 of the normalized RT fit (``None`` when every RT is missing)
        and the NaN-ignoring averages of both spectrum similarities.
    """
    dl = torch.utils.data.DataLoader(dataset, batch_size)
    cs = torch.nn.CosineSimilarity()
    pc = PearsonCorrelation()

    model.eval()
    model.to(device)
    rt_results = []
    mod_sequences = dataset.df["ModSequences"]
    rt_real = dataset.df["RTs"]
    irt_real = dataset.df["iRT"]

    # Prefer iRT as the reference when more than one entry has it annotated.
    if np.sum(~np.isnan(np.array(irt_real).astype("float"))) > 1:
        logging.warning("Using iRT instead of RT")
        rt_real = irt_real

    charges = dataset.df["Charges"]
    spec_results_cs = []
    spec_results_pc = []

    logging.info(">>> Starting Evaluation of the spectra <<<")
    start_time = time.perf_counter()
    with torch.no_grad():
        for b in tqdm(dl):
            if overwrite_nce:
                # Broadcast the scalar NCE over the whole batch.
                nce = torch.where(
                    torch.tensor(True), torch.tensor(overwrite_nce), b.nce
                )
            else:
                nce = b.nce

            outs = model.forward(
                src=b.encoded_sequence.clone().to(device),
                charge=b.charge.clone().to(device),
                mods=b.encoded_mods.clone().to(device),
                nce=nce.clone().to(device),
            )

            out_spec = outs.spectra.cpu().clone()
            # Scale each predicted spectrum (row) so that its maximum is 1.
            out_spec = out_spec / out_spec.max(axis=1).values.unsqueeze(0).T

            spec_results_cs.append(cs(out_spec, b.encoded_spectra))
            spec_results_pc.append(pc(out_spec, b.encoded_spectra))
            rt_results.append(outs.irt.cpu().clone().flatten())
            # Drop references before the next forward pass to lower peak memory.
            del b
            del outs

    elapsed_time = time.perf_counter() - start_time

    rt_results = torch.cat(rt_results) * 100
    spec_results_pc = torch.cat(spec_results_pc)
    spec_results_cs = torch.cat(spec_results_cs)

    logging.info(
        f">> Elapsed time for {len(spec_results_cs)} results was {elapsed_time}."
    )
    logging.info(
        f">> {len(spec_results_cs) / elapsed_time} results/sec"
        f"; {elapsed_time / len(spec_results_cs)} sec/res"
    )
    out = {
        "ModSequence": mod_sequences,
        "Charges": charges,
        "Predicted_iRT": rt_results.numpy().flatten(),
        "Real_RT": rt_real.to_numpy().flatten(),
        "Spectra_Similarity_Cosine": spec_results_cs.numpy().flatten(),
        "Spectra_Similarity_Pearson": spec_results_pc.numpy().flatten(),
    }

    terminal_plot_similarity(out["Spectra_Similarity_Pearson"], "Pearson Similarity")
    terminal_plot_similarity(out["Spectra_Similarity_Cosine"], "Cosine Similarity")

    # TODO consider the possibility of stratifying on files before normalizing
    # Convert once to float so object columns holding None become NaN.
    real_rt_float = np.array(rt_real).astype("float")
    missing_vals = np.isnan(real_rt_float)
    logging.warning(
        f"Will remove {np.sum(missing_vals)}/{len(missing_vals)} "
        "because they have missing iRTs"
    )
    norm_p_irt, rev_p_irt = norm(out["Predicted_iRT"])
    norm_r_irt, rev_r_irt = norm(real_rt_float)

    if np.sum(missing_vals) == len(norm_p_irt):
        # Nothing to fit or plot against.
        rt_fit = {"determination": None}
    else:
        rt_fit = polyfit(norm_p_irt[~missing_vals], norm_r_irt[~missing_vals])
        uniplot.plot(
            ys=rev_p_irt(norm_p_irt)[~missing_vals],
            xs=rev_r_irt(norm_r_irt)[~missing_vals],
            title=(
                f"Predicted iRT (y) vs RT (x)"
                f" (normalized R2={rt_fit['determination']})"
            ),
        )

    # Map the normalized predictions onto the real-RT scale before comparing.
    rt_errors = abs(rev_r_irt(norm_r_irt) - rev_r_irt(norm_p_irt))
    out.update({"RT_Error": rt_errors})
    terminal_plot_similarity(rt_errors[~missing_vals], "RT prediction error")

    summ_out = {
        "normRT Rsquared": rt_fit["determination"],
        # nanmean so a single undefined similarity does not poison the average.
        "AverageSpectraCosineSimilarity": np.nanmean(out["Spectra_Similarity_Cosine"]),
        "AverageSpectraPearsonSimilarty": np.nanmean(out["Spectra_Similarity_Pearson"]),
    }
    return pd.DataFrame(out), summ_out

244 

245 

def norm(x: ndarray) -> Tuple[ndarray, Callable[[ndarray], ndarray]]:
    """Standardize ``x`` to zero mean and unit standard deviation, ignoring NaNs.

    Parameters
    ----------
    x : array-like
        Values to normalize. They are converted to a float array, so ``None``
        entries (e.g. in object arrays) become NaN. NaN entries stay NaN in the
        output and are excluded from the mean and standard deviation.

    Returns
    -------
    tuple
        ``(normalized, reverse)``, where ``reverse(y)`` maps values from the
        normalized scale back to the original one (``y * sd + mean``).
        If ``x`` has zero spread, the standard deviation is treated as 1, so
        the normalized values are 0 instead of NaN (from 0/0).
    """
    x = np.asarray(x, dtype=float)
    sd = np.nanstd(x)
    m = np.nanmean(x)
    if sd == 0:
        # Constant input: every value sits exactly at the mean.
        sd = 1.0
    out = (x - m) / sd
    return out, lambda y: (y * sd) + m

252 

253 

# Polynomial Regression
# Implementation adapted from:
# https://stackoverflow.com/questions/893657/
def polyfit(
    x: ndarray, y: ndarray, degree: int = 1
) -> Dict[str, Union[List[float], float64]]:
    """Fit a least-squares polynomial mapping ``x`` to ``y`` and report its R^2.

    Parameters
    ----------
    x, y : array-like
        Paired observations (converted to float arrays).
    degree : int
        Degree of the fitted polynomial.

    Returns
    -------
    dict
        "polynomial": coefficient list as returned by ``np.polyfit`` (highest
        power first). "determination": coefficient of determination,
        SSreg / SStot. It is NaN when ``y`` has no variance, because R^2 is
        undefined there; the unguarded division used to return inf/NaN with
        a RuntimeWarning.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    coeffs = np.polyfit(x, y, degree)

    # Fitted values and mean of the observations
    yhat = np.polyval(coeffs, x)
    ybar = y.mean()
    ssreg = np.sum((yhat - ybar) ** 2)
    sstot = np.sum((y - ybar) ** 2)
    determination = ssreg / sstot if sstot > 0 else float64(np.nan)

    return {"polynomial": coeffs.tolist(), "determination": determination}

279 

280 

def apply_polyfit(x, polynomial):
    """Evaluate a polynomial at ``x``.

    ``polynomial`` follows the ``np.polyfit`` / ``polyfit`` convention:
    coefficients ordered from the highest power down to the constant term.
    The previous implementation applied power 1 to the *first* coefficient,
    which is only correct for degree 1.

    Horner's scheme uses only ``*`` and ``+``, so ``x`` may be a Python
    number, a numpy array or a torch tensor.

    Parameters
    ----------
    x : number, array or tensor
        Point(s) at which to evaluate the polynomial.
    polynomial : sequence of numbers
        Coefficients, highest power first. Must not be empty (an empty
        sequence raises IndexError).

    Returns
    -------
    Same kind as ``x`` (a plain coefficient for a degree-0 polynomial).
    """
    result = polynomial[0]
    for coefficient in polynomial[1:]:
        result = result * x + coefficient
    return result

287 

288 

def evaluate_landmark_rt(
    model: PepTransformerModel, *, charge: int = 2, nce: float = 25
):
    """Check the model's retention time predictions on the iRT landmark peptides.

    Every peptide in ``constants.IRT_PEPTIDES`` (procal and Biognosys iRT
    peptides) is predicted. A linear fit of the predicted iRTs against the
    theoretical ones is then computed and plotted.

    Parameters
    ----------
    model : PepTransformerModel
        A model to test the predictions on. It is switched to eval mode.
    charge : int, keyword-only
        Precursor charge used for every prediction (default 2).
    nce : float, keyword-only
        Collision energy used for every prediction (default 25).

    Returns
    -------
    dict
        The ``polyfit`` result (keys "polynomial" and "determination") of the
        theoretical iRTs (x) against the predicted ones (y).
    """
    model.eval()
    real_rt = []
    pred_rt = []
    with torch.no_grad():
        for seq, desc in constants.IRT_PEPTIDES.items():
            out = model.predict_from_seq(seq, charge, nce)
            # Same x100 scaling as evaluate_on_dataset applies to model iRTs
            # (presumably the model predicts iRT / 100 — confirm).
            # .cpu() lets this work when the model lives on a GPU.
            pred_rt.append(100 * out.irt.cpu().numpy())
            real_rt.append(np.array(desc["irt"]))

    real_rt = np.array(real_rt).flatten()
    pred_rt = np.array(pred_rt).flatten()

    # TODO make this return a correlation coefficient
    fit = polyfit(real_rt, pred_rt)
    logging.info(fit)
    uniplot.plot(xs=real_rt, ys=pred_rt)
    return fit