comparison keras_train_and_eval.py @ 13:ebd3bd2f2985 draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 80417bf0158a9b596e485dd66408f738f405145a
author bgruening
date Mon, 02 Oct 2023 08:46:12 +0000
parents 9ac42e46dfbd
children
comparing 12:0460590afd6e with 13:ebd3bd2f2985
@@ -1,36 +1,45 @@
 import argparse
 import json
 import os
-import pickle
 import warnings
 from itertools import chain
 
 import joblib
 import numpy as np
 import pandas as pd
-from galaxy_ml.externals.selene_sdk.utils import compute_score
-from galaxy_ml.keras_galaxy_models import _predict_generator
+from galaxy_ml.keras_galaxy_models import (
+    _predict_generator,
+    KerasGBatchClassifier,
+)
+from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5
 from galaxy_ml.model_validations import train_test_split
-from galaxy_ml.utils import (clean_params, get_main_estimator,
-                             get_module, get_scoring, load_model, read_columns,
-                             SafeEval, try_get_attr)
+from galaxy_ml.utils import (
+    clean_params,
+    gen_compute_scores,
+    get_main_estimator,
+    get_module,
+    get_scoring,
+    read_columns,
+    SafeEval
+)
 from scipy.io import mmread
-from sklearn.metrics.scorer import _check_multimetric_scoring
-from sklearn.model_selection import _search, _validation
+from sklearn.metrics._scorer import _check_multimetric_scoring
 from sklearn.model_selection._validation import _score
-from sklearn.pipeline import Pipeline
-from sklearn.utils import indexable, safe_indexing
-
-_fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score")
-setattr(_search, "_fit_and_score", _fit_and_score)
-setattr(_validation, "_fit_and_score", _fit_and_score)
+from sklearn.utils import _safe_indexing, indexable
 
 N_JOBS = int(os.environ.get("GALAXY_SLOTS", 1))
 CACHE_DIR = os.path.join(os.getcwd(), "cached")
-del os
-NON_SEARCHABLE = ("n_jobs", "pre_dispatch", "memory", "_path", "nthread", "callbacks")
+NON_SEARCHABLE = (
+    "n_jobs",
+    "pre_dispatch",
+    "memory",
+    "_path",
+    "_dir",
+    "nthread",
+    "callbacks",
+)
 ALLOWED_CALLBACKS = (
     "EarlyStopping",
     "TerminateOnNaN",
     "ReduceLROnPlateau",
     "CSVLogger",
@@ -94,17 +103,17 @@
         index_arr = np.arange(n_samples)
         test = index_arr[np.isin(groups, group_names)]
         train = index_arr[~np.isin(groups, group_names)]
         rval = list(
             chain.from_iterable(
-                (safe_indexing(a, train), safe_indexing(a, test)) for a in new_arrays
+                (_safe_indexing(a, train), _safe_indexing(a, test)) for a in new_arrays
            )
        )
    else:
        rval = train_test_split(*new_arrays, **kwargs)

    for pos in nones:
        rval[pos * 2: 2] = [None, None]

    return rval


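Note: the None-placeholder plumbing above is what lets main() pass groups=None straight through the split. A minimal sketch of the behavior (hypothetical arrays; assumes galaxy_ml's extended train_test_split, which is called with a shuffle keyword such as "simple"):

    import numpy as np

    X = np.arange(20).reshape(10, 2)
    y = np.arange(10)

    # groups=None is stripped before splitting, then (None, None) is spliced
    # back at the original position, so the 6-way unpacking always works.
    X_train, X_test, y_train, y_test, groups_train, groups_test = train_test_split_none(
        X, y, None, test_size=0.25, random_state=42, shuffle="simple"
    )
    assert groups_train is None and groups_test is None
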
@@ -111,64 +120,65 @@
-def _evaluate(y_true, pred_probas, scorer, is_multimetric=True):
-    """output scores based on input scorer
+def _evaluate_keras_and_sklearn_scores(
+    estimator,
+    data_generator,
+    X,
+    y=None,
+    sk_scoring=None,
+    steps=None,
+    batch_size=32,
+    return_predictions=False,
+):
+    """output scores for both keras and sklearn metrics
 
     Parameters
     ----------
-    y_true : array
-        True label or target values
-    pred_probas : array
-        Prediction values, probability for classification problem
-    scorer : dict
-        dict of `sklearn.metrics.scorer.SCORER`
-    is_multimetric : bool, default is True
+    estimator : object
+        Fitted `galaxy_ml.keras_galaxy_models.KerasGBatchClassifier`.
+    data_generator : object
+        From `galaxy_ml.preprocessors.ImageDataFrameBatchGenerator`.
+    X : 2-D array
+        Contains indices of images that need to be evaluated.
+    y : None
+        Target value.
+    sk_scoring : dict
+        Galaxy tool input parameters.
+    steps : integer or None
+        Evaluation/prediction steps before stop.
+    batch_size : integer
+        Number of samples in a batch.
+    return_predictions : bool, default is False
+        Whether to return predictions and true labels.
     """
-    if y_true.ndim == 1 or y_true.shape[-1] == 1:
-        pred_probas = pred_probas.ravel()
-        pred_labels = (pred_probas > 0.5).astype("int32")
-        targets = y_true.ravel().astype("int32")
-        if not is_multimetric:
-            preds = (
-                pred_labels
-                if scorer.__class__.__name__ == "_PredictScorer"
-                else pred_probas
-            )
-            score = scorer._score_func(targets, preds, **scorer._kwargs)
-
-            return score
-        else:
-            scores = {}
-            for name, one_scorer in scorer.items():
-                preds = (
-                    pred_labels
-                    if one_scorer.__class__.__name__ == "_PredictScorer"
-                    else pred_probas
-                )
-                score = one_scorer._score_func(targets, preds, **one_scorer._kwargs)
-                scores[name] = score
-
-    # TODO: multi-class metrics
-    # multi-label
-    else:
-        pred_labels = (pred_probas > 0.5).astype("int32")
-        targets = y_true.astype("int32")
-        if not is_multimetric:
-            preds = (
-                pred_labels
-                if scorer.__class__.__name__ == "_PredictScorer"
-                else pred_probas
-            )
-            score, _ = compute_score(preds, targets, scorer._score_func)
-            return score
-        else:
-            scores = {}
-            for name, one_scorer in scorer.items():
-                preds = (
-                    pred_labels
-                    if one_scorer.__class__.__name__ == "_PredictScorer"
-                    else pred_probas
-                )
-                score, _ = compute_score(preds, targets, one_scorer._score_func)
-                scores[name] = score
-
-    return scores
+    scores = {}
+
+    generator = data_generator.flow(X, y=y, batch_size=batch_size)
+    # keras metrics evaluation
+    # handle scorer, convert to scorer dict
+    generator.reset()
+    score_results = estimator.model_.evaluate_generator(generator, steps=steps)
+    metrics_names = estimator.model_.metrics_names
+    if not isinstance(metrics_names, list):
+        scores[metrics_names] = score_results
+    else:
+        scores = dict(zip(metrics_names, score_results))
+
+    if sk_scoring["primary_scoring"] == "default" and not return_predictions:
+        return scores
+
+    generator.reset()
+    predictions, y_true = _predict_generator(estimator.model_, generator, steps=steps)
+
+    # for sklearn metrics
+    if sk_scoring["primary_scoring"] != "default":
+        scorer = get_scoring(sk_scoring)
+        if not isinstance(scorer, (dict, list)):
+            scorer = [sk_scoring["primary_scoring"]]
+        scorer = _check_multimetric_scoring(estimator, scoring=scorer)
+        sk_scores = gen_compute_scores(y_true, predictions, scorer)
+        scores.update(sk_scores)
+
+    if return_predictions:
+        return scores, predictions, y_true
+    else:
+        return scores, None, None
 
 
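Note: the metric merging in _evaluate_keras_and_sklearn_scores() is plain dict plumbing: Keras evaluate_generator() returns a list aligned index-by-index with model.metrics_names, and any sklearn scores are folded into the same dict. A self-contained illustration with fabricated numbers:

    # Keras side: metrics_names pairs positionally with evaluate() results.
    metrics_names = ["loss", "acc"]
    score_results = [0.35, 0.88]
    scores = dict(zip(metrics_names, score_results))

    # sklearn side: extra scores are merged into the same mapping.
    sk_scores = {"balanced_accuracy": 0.86}
    scores.update(sk_scores)
    print(scores)  # {'loss': 0.35, 'acc': 0.88, 'balanced_accuracy': 0.86}
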
@@ -175,84 +185,81 @@
 def main(
     inputs,
     infile_estimator,
     infile1,
     infile2,
     outfile_result,
+    outfile_history=None,
     outfile_object=None,
-    outfile_weights=None,
     outfile_y_true=None,
     outfile_y_preds=None,
     groups=None,
     ref_seq=None,
     intervals=None,
     targets=None,
     fasta_path=None,
 ):
     """
     Parameters
     ----------
     inputs : str
-        File path to galaxy tool parameter
+        File path to galaxy tool parameter.
 
     infile_estimator : str
-        File path to estimator
+        File path to estimator.
 
     infile1 : str
-        File path to dataset containing features
+        File path to dataset containing features.
 
     infile2 : str
-        File path to dataset containing target values
+        File path to dataset containing target values.
 
     outfile_result : str
-        File path to save the results, either cv_results or test result
+        File path to save the results, either cv_results or test result.
+
+    outfile_history : str, optional
+        File path to save the training history.
 
     outfile_object : str, optional
-        File path to save searchCV object
-
-    outfile_weights : str, optional
-        File path to save deep learning model weights
+        File path to save searchCV object.
 
     outfile_y_true : str, optional
-        File path to target values for prediction
+        File path to target values for prediction.
 
     outfile_y_preds : str, optional
-        File path to save deep learning model weights
+        File path to save predictions.
 
     groups : str
-        File path to dataset containing groups labels
+        File path to dataset containing groups labels.
 
     ref_seq : str
-        File path to dataset containing genome sequence file
+        File path to dataset containing genome sequence file.
 
     intervals : str
-        File path to dataset containing interval file
+        File path to dataset containing interval file.
 
     targets : str
-        File path to dataset compressed target bed file
+        File path to dataset compressed target bed file.
 
     fasta_path : str
-        File path to dataset containing fasta file
+        File path to dataset containing fasta file.
     """
     warnings.simplefilter("ignore")
 
     with open(inputs, "r") as param_handler:
         params = json.load(param_handler)
 
     # load estimator
-    with open(infile_estimator, "rb") as estimator_handler:
-        estimator = load_model(estimator_handler)
+    estimator = load_model_from_h5(infile_estimator)
 
     estimator = clean_params(estimator)
 
     # swap hyperparameter
     swapping = params["experiment_schemes"]["hyperparams_swapping"]
     swap_params = _eval_swap_params(swapping)
     estimator.set_params(**swap_params)
-
     estimator_params = estimator.get_params()
-
     # store read dataframe object
     loaded_df = {}
 
     input_type = params["input_options"]["selected_input"]
     # tabular input
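Note: the hyperparameter swap relies on sklearn's standard set_params() contract, so any estimator (or pipeline step, via double-underscore paths) can be rewired before fitting. A minimal sketch with a stock sklearn estimator and hand-written values (the tool builds the real swap_params from the Galaxy form via _eval_swap_params(swapping)):

    from sklearn.svm import SVC

    estimator = SVC()
    swap_params = {"C": 10.0, "kernel": "rbf"}  # hypothetical swap values
    estimator.set_params(**swap_params)
    assert estimator.get_params()["C"] == 10.0
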
@@ -331,11 +338,16 @@
     else:
         infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True)
         loaded_df[df_key] = infile2
 
     y = read_columns(
-        infile2, c=c, c_option=column_option, sep="\t", header=header, parse_dates=True
+        infile2,
+        c=c,
+        c_option=column_option,
+        sep="\t",
+        header=header,
+        parse_dates=True,
     )
     if len(y.shape) == 2 and y.shape[1] == 1:
         y = y.ravel()
     if input_type == "refseq_and_interval":
         estimator.set_params(data_batch_generator__features=y.ravel().tolist())
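Note: the shape check after read_columns() just collapses a one-column 2-D target into the 1-D vector that sklearn and Keras APIs expect. In plain numpy:

    import numpy as np

    y = np.array([[0], [1], [1]])  # shape (3, 1), e.g. a single selected column
    if len(y.shape) == 2 and y.shape[1] == 1:
        y = y.ravel()
    print(y.shape)  # (3,)
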
@@ -385,19 +397,13 @@
     if main_est.__class__.__name__ == "IRAPSClassifier":
         main_est.set_params(memory=memory)
 
     # handle scorer, convert to scorer dict
     scoring = params["experiment_schemes"]["metrics"]["scoring"]
-    if scoring is not None:
-        # get_scoring() expects secondary_scoring to be a comma separated string (not a list)
-        # Check if secondary_scoring is specified
-        secondary_scoring = scoring.get("secondary_scoring", None)
-        if secondary_scoring is not None:
-            # If secondary_scoring is specified, convert the list into comman separated string
-            scoring["secondary_scoring"] = ",".join(scoring["secondary_scoring"])
-
     scorer = get_scoring(scoring)
-    scorer, _ = _check_multimetric_scoring(estimator, scoring=scorer)
+    if not isinstance(scorer, (dict, list)):
+        scorer = [scoring["primary_scoring"]]
+    scorer = _check_multimetric_scoring(estimator, scoring=scorer)
 
     # handle test (first) split
     test_split_options = params["experiment_schemes"]["test_split"]["split_algos"]
 
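Note: the import move from sklearn.metrics.scorer to sklearn.metrics._scorer tracks sklearn's module rename, and newer releases return the scorer dict directly rather than a (scorers, is_multimetric) tuple, hence the dropped unpacking. A sketch against the helper (a private API, so subject to change between sklearn versions):

    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics._scorer import _check_multimetric_scoring

    estimator = LogisticRegression()
    # Older sklearn: scorers, is_multimetric = _check_multimetric_scoring(...)
    scorers = _check_multimetric_scoring(estimator, scoring=["accuracy", "f1_macro"])
    print(sorted(scorers))  # ['accuracy', 'f1_macro']
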
@@ -404,23 +410,18 @@
     if test_split_options["shuffle"] == "group":
         test_split_options["labels"] = groups
     if test_split_options["shuffle"] == "stratified":
         if y is not None:
             test_split_options["labels"] = y
         else:
             raise ValueError(
                 "Stratified shuffle split is not " "applicable on empty target values!"
             )
 
-    (
-        X_train,
-        X_test,
-        y_train,
-        y_test,
-        groups_train,
-        _groups_test,
-    ) = train_test_split_none(X, y, groups, **test_split_options)
+    X_train, X_test, y_train, y_test, groups_train, groups_test = train_test_split_none(
+        X, y, groups, **test_split_options
+    )
 
     exp_scheme = params["experiment_schemes"]["selected_exp_scheme"]
 
     # handle validation (second) split
     if exp_scheme == "train_val_test":
@@ -441,18 +442,25 @@
         X_train,
         X_val,
         y_train,
         y_val,
         groups_train,
-        _groups_val,
+        groups_val,
     ) = train_test_split_none(X_train, y_train, groups_train, **val_split_options)
 
     # train and eval
-    if hasattr(estimator, "validation_data"):
+    if hasattr(estimator, "config") and hasattr(estimator, "model_type"):
         if exp_scheme == "train_val_test":
-            estimator.fit(X_train, y_train, validation_data=(X_val, y_val))
+            history = estimator.fit(X_train, y_train, validation_data=(X_val, y_val))
         else:
-            estimator.fit(X_train, y_train, validation_data=(X_test, y_test))
+            history = estimator.fit(X_train, y_train, validation_data=(X_test, y_test))
     else:
-        estimator.fit(X_train, y_train)
-
-    if hasattr(estimator, "evaluate"):
+        history = estimator.fit(X_train, y_train)
+    if "callbacks" in estimator_params:
+        for cb in estimator_params["callbacks"]:
+            if cb["callback_selection"]["callback_type"] == "CSVLogger":
+                hist_df = pd.DataFrame(history.history)
+                hist_df["epoch"] = np.arange(1, estimator_params["epochs"] + 1)
+                epo_col = hist_df.pop('epoch')
+                hist_df.insert(0, 'epoch', epo_col)
+                hist_df.to_csv(path_or_buf=outfile_history, sep="\t", header=True, index=False)
+                break
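Note: the CSVLogger branch rebuilds the per-epoch history table from history.history instead of relying on the callback's own file. The same DataFrame plumbing, runnable with a fabricated history dict (len(hist_df) stands in here for estimator_params["epochs"]):

    import numpy as np
    import pandas as pd

    history = {"loss": [0.9, 0.5, 0.3], "val_loss": [1.0, 0.7, 0.6]}  # fabricated
    hist_df = pd.DataFrame(history)
    hist_df["epoch"] = np.arange(1, len(hist_df) + 1)
    epo_col = hist_df.pop("epoch")
    hist_df.insert(0, "epoch", epo_col)  # make epoch the first column
    hist_df.to_csv("history.tsv", sep="\t", header=True, index=False)
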
@@ -459,18 +467,39 @@
+    if isinstance(estimator, KerasGBatchClassifier):
+        scores = {}
         steps = estimator.prediction_steps
         batch_size = estimator.batch_size
-        generator = estimator.data_generator_.flow(
-            X_test, y=y_test, batch_size=batch_size
+        data_generator = estimator.data_generator_
+
+        scores, predictions, y_true = _evaluate_keras_and_sklearn_scores(
+            estimator,
+            data_generator,
+            X_test,
+            y=y_test,
+            sk_scoring=scoring,
+            steps=steps,
+            batch_size=batch_size,
+            return_predictions=bool(outfile_y_true),
         )
-        predictions, y_true = _predict_generator(
-            estimator.model_, generator, steps=steps
-        )
-        scores = _evaluate(y_true, predictions, scorer, is_multimetric=True)
-
-    else:
+
+    else:
+        scores = {}
+        if hasattr(estimator, "model_") and hasattr(estimator.model_, "metrics_names"):
+            batch_size = estimator.batch_size
+            score_results = estimator.model_.evaluate(
+                X_test, y=y_test, batch_size=batch_size, verbose=0
+            )
+            metrics_names = estimator.model_.metrics_names
+            if not isinstance(metrics_names, list):
+                scores[metrics_names] = score_results
+            else:
+                scores = dict(zip(metrics_names, score_results))
+
         if hasattr(estimator, "predict_proba"):
             predictions = estimator.predict_proba(X_test)
         else:
             predictions = estimator.predict(X_test)
 
         y_true = y_test
-        scores = _score(estimator, X_test, y_test, scorer, is_multimetric=True)
+        sk_scores = _score(estimator, X_test, y_test, scorer)
+        scores.update(sk_scores)
+
@@ -477,41 +506,25 @@
+    # handle output
     if outfile_y_true:
         try:
             pd.DataFrame(y_true).to_csv(outfile_y_true, sep="\t", index=False)
             pd.DataFrame(predictions).astype(np.float32).to_csv(
                 outfile_y_preds,
                 sep="\t",
                 index=False,
                 float_format="%g",
                 chunksize=10000,
             )
         except Exception as e:
             print("Error in saving predictions: %s" % e)
-
     # handle output
     for name, score in scores.items():
         scores[name] = [score]
     df = pd.DataFrame(scores)
     df = df[sorted(df.columns)]
     df.to_csv(path_or_buf=outfile_result, sep="\t", header=True, index=False)
 
     memory.clear(warn=False)
 
     if outfile_object:
-        main_est = estimator
-        if isinstance(estimator, Pipeline):
-            main_est = estimator.steps[-1][-1]
-
-        if hasattr(main_est, "model_") and hasattr(main_est, "save_weights"):
-            if outfile_weights:
-                main_est.save_weights(outfile_weights)
-            del main_est.model_
-            del main_est.fit_params
-            del main_est.model_class_
-            if getattr(main_est, "validation_data", None):
-                del main_est.validation_data
-            if getattr(main_est, "data_generator_", None):
-                del main_est.data_generator_
-
-        with open(outfile_object, "wb") as output_handler:
-            pickle.dump(estimator, output_handler, pickle.HIGHEST_PROTOCOL)
+        dump_model_to_h5(estimator, outfile_object)
 
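Note: persistence moves from pickle to galaxy_ml's HDF5 helpers, which also make the old manual stripping of model_/fit_params/data_generator_ unnecessary. A round-trip sketch, assuming dump_model_to_h5/load_model_from_h5 take an estimator plus a file path exactly as called above:

    from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    X, y = make_classification(n_samples=50, random_state=0)
    est = LogisticRegression().fit(X, y)

    dump_model_to_h5(est, "model.h5")          # was: pickle.dump(...)
    restored = load_model_from_h5("model.h5")  # was: load_model(open(..., "rb"))
    print(restored.score(X, y))
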
@@ -518,32 +531,32 @@
 
 if __name__ == "__main__":
     aparser = argparse.ArgumentParser()
     aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
     aparser.add_argument("-e", "--estimator", dest="infile_estimator")
     aparser.add_argument("-X", "--infile1", dest="infile1")
     aparser.add_argument("-y", "--infile2", dest="infile2")
     aparser.add_argument("-O", "--outfile_result", dest="outfile_result")
+    aparser.add_argument("-hi", "--outfile_history", dest="outfile_history")
     aparser.add_argument("-o", "--outfile_object", dest="outfile_object")
-    aparser.add_argument("-w", "--outfile_weights", dest="outfile_weights")
     aparser.add_argument("-l", "--outfile_y_true", dest="outfile_y_true")
     aparser.add_argument("-p", "--outfile_y_preds", dest="outfile_y_preds")
     aparser.add_argument("-g", "--groups", dest="groups")
     aparser.add_argument("-r", "--ref_seq", dest="ref_seq")
     aparser.add_argument("-b", "--intervals", dest="intervals")
     aparser.add_argument("-t", "--targets", dest="targets")
     aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
     args = aparser.parse_args()
 
     main(
         args.inputs,
         args.infile_estimator,
         args.infile1,
         args.infile2,
         args.outfile_result,
+        outfile_history=args.outfile_history,
         outfile_object=args.outfile_object,
-        outfile_weights=args.outfile_weights,
         outfile_y_true=args.outfile_y_true,
         outfile_y_preds=args.outfile_y_preds,
         groups=args.groups,
         ref_seq=args.ref_seq,
         intervals=args.intervals,
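Note: the command line changes accordingly: -w/--outfile_weights is gone and -hi/--outfile_history is new, e.g. (file names hypothetical) python keras_train_and_eval.py -i inputs.json -e estimator.h5 -X features.tsv -y targets.tsv -O results.tsv -hi history.tsv -o fitted_model.h5.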