comparison keras_train_and_eval.py @ 3:0a1812986bc3 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author bgruening
date Wed, 09 Aug 2023 11:10:37 +0000
parents 38c4f8a98038
children ba7fb6b33cd0
comparison of revisions 2:38c4f8a98038 (old) and 3:0a1812986bc3 (new)
1 import argparse 1 import argparse
2 import joblib
3 import json 2 import json
4 import numpy as np
5 import os 3 import os
6 import pandas as pd
7 import pickle
8 import warnings 4 import warnings
9 from itertools import chain 5 from itertools import chain
6
7 import joblib
8 import numpy as np
9 import pandas as pd
10 from galaxy_ml.keras_galaxy_models import (
11 _predict_generator,
12 KerasGBatchClassifier,
13 )
14 from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5
15 from galaxy_ml.model_validations import train_test_split
16 from galaxy_ml.utils import (
17 clean_params,
18 gen_compute_scores,
19 get_main_estimator,
20 get_module,
21 get_scoring,
22 read_columns,
23 SafeEval
24 )
10 from scipy.io import mmread 25 from scipy.io import mmread
11 from sklearn.pipeline import Pipeline 26 from sklearn.metrics._scorer import _check_multimetric_scoring
12 from sklearn.metrics.scorer import _check_multimetric_scoring
13 from sklearn import model_selection
14 from sklearn.model_selection._validation import _score 27 from sklearn.model_selection._validation import _score
15 from sklearn.model_selection import _search, _validation 28 from sklearn.utils import _safe_indexing, indexable
16 from sklearn.utils import indexable, safe_indexing 29
17 30 N_JOBS = int(os.environ.get("GALAXY_SLOTS", 1))
18 from galaxy_ml.externals.selene_sdk.utils import compute_score 31 CACHE_DIR = os.path.join(os.getcwd(), "cached")
19 from galaxy_ml.model_validations import train_test_split 32 NON_SEARCHABLE = (
20 from galaxy_ml.keras_galaxy_models import _predict_generator 33 "n_jobs",
21 from galaxy_ml.utils import (SafeEval, get_scoring, load_model, 34 "pre_dispatch",
22 read_columns, try_get_attr, get_module, 35 "memory",
23 clean_params, get_main_estimator) 36 "_path",
24 37 "_dir",
25 38 "nthread",
26 _fit_and_score = try_get_attr('galaxy_ml.model_validations', '_fit_and_score') 39 "callbacks",
27 setattr(_search, '_fit_and_score', _fit_and_score) 40 )
28 setattr(_validation, '_fit_and_score', _fit_and_score) 41 ALLOWED_CALLBACKS = (
29 42 "EarlyStopping",
30 N_JOBS = int(os.environ.get('GALAXY_SLOTS', 1)) 43 "TerminateOnNaN",
31 CACHE_DIR = os.path.join(os.getcwd(), 'cached') 44 "ReduceLROnPlateau",
32 del os 45 "CSVLogger",
33 NON_SEARCHABLE = ('n_jobs', 'pre_dispatch', 'memory', '_path', 46 "None",
34 'nthread', 'callbacks') 47 )
35 ALLOWED_CALLBACKS = ('EarlyStopping', 'TerminateOnNaN', 'ReduceLROnPlateau',
36 'CSVLogger', 'None')
37 48
38 49
39 def _eval_swap_params(params_builder): 50 def _eval_swap_params(params_builder):
40 swap_params = {} 51 swap_params = {}
41 52
42 for p in params_builder['param_set']: 53 for p in params_builder["param_set"]:
43 swap_value = p['sp_value'].strip() 54 swap_value = p["sp_value"].strip()
44 if swap_value == '': 55 if swap_value == "":
45 continue 56 continue
46 57
47 param_name = p['sp_name'] 58 param_name = p["sp_name"]
48 if param_name.lower().endswith(NON_SEARCHABLE): 59 if param_name.lower().endswith(NON_SEARCHABLE):
49 warnings.warn("Warning: `%s` is not eligible for search and was " 60 warnings.warn(
50 "omitted!" % param_name) 61 "Warning: `%s` is not eligible for search and was "
62 "omitted!" % param_name
63 )
51 continue 64 continue
52 65
53 if not swap_value.startswith(':'): 66 if not swap_value.startswith(":"):
54 safe_eval = SafeEval(load_scipy=True, load_numpy=True) 67 safe_eval = SafeEval(load_scipy=True, load_numpy=True)
55 ev = safe_eval(swap_value) 68 ev = safe_eval(swap_value)
56 else: 69 else:
57 # Have `:` before search list, asks for estimator evaluation 70 # Have `:` before search list, asks for estimator evaluation
58 safe_eval_es = SafeEval(load_estimators=True) 71 safe_eval_es = SafeEval(load_estimators=True)
75 if arr is None: 88 if arr is None:
76 nones.append(idx) 89 nones.append(idx)
77 else: 90 else:
78 new_arrays.append(arr) 91 new_arrays.append(arr)
79 92
80 if kwargs['shuffle'] == 'None': 93 if kwargs["shuffle"] == "None":
81 kwargs['shuffle'] = None 94 kwargs["shuffle"] = None
82 95
83 group_names = kwargs.pop('group_names', None) 96 group_names = kwargs.pop("group_names", None)
84 97
85 if group_names is not None and group_names.strip(): 98 if group_names is not None and group_names.strip():
86 group_names = [name.strip() for name in 99 group_names = [name.strip() for name in group_names.split(",")]
87 group_names.split(',')]
88 new_arrays = indexable(*new_arrays) 100 new_arrays = indexable(*new_arrays)
89 groups = kwargs['labels'] 101 groups = kwargs["labels"]
90 n_samples = new_arrays[0].shape[0] 102 n_samples = new_arrays[0].shape[0]
91 index_arr = np.arange(n_samples) 103 index_arr = np.arange(n_samples)
92 test = index_arr[np.isin(groups, group_names)] 104 test = index_arr[np.isin(groups, group_names)]
93 train = index_arr[~np.isin(groups, group_names)] 105 train = index_arr[~np.isin(groups, group_names)]
94 rval = list(chain.from_iterable( 106 rval = list(
95 (safe_indexing(a, train), 107 chain.from_iterable(
96 safe_indexing(a, test)) for a in new_arrays)) 108 (_safe_indexing(a, train), _safe_indexing(a, test)) for a in new_arrays
109 )
110 )
97 else: 111 else:
98 rval = train_test_split(*new_arrays, **kwargs) 112 rval = train_test_split(*new_arrays, **kwargs)
99 113
100 for pos in nones: 114 for pos in nones:
101 rval[pos * 2: 2] = [None, None] 115 rval[pos * 2: 2] = [None, None]
102 116
103 return rval 117 return rval
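
A minimal, illustrative sketch of the group-aware branch above, assuming the elided signature is train_test_split_none(*arrays, **kwargs); the data is made up purely for illustration:

import numpy as np

X = np.arange(12).reshape(6, 2)
y = np.array([0, 1, 0, 1, 0, 1])
groups = np.array(["a", "a", "b", "b", "c", "c"])

# every sample whose group label is listed in `group_names` goes to the test
# split; the third positional array is None, so two None placeholders are
# appended for it at the end of the returned list
X_train, X_test, y_train, y_test, g_train, g_test = train_test_split_none(
    X, y, None,
    shuffle="group",
    group_names="c",
    labels=groups,
)
# X_test and y_test now hold the two "c" samples; g_train and g_test are None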
104 118
105 119
106 def _evaluate(y_true, pred_probas, scorer, is_multimetric=True): 120 def _evaluate_keras_and_sklearn_scores(
107 """ output scores based on input scorer 121 estimator,
122 data_generator,
123 X,
124 y=None,
125 sk_scoring=None,
126 steps=None,
127 batch_size=32,
128 return_predictions=False,
129 ):
130 """output scores for bother keras and sklearn metrics
108 131
109 Parameters 132 Parameters
110 ---------- 133 ----------
111 y_true : array 134 estimator : object
112 True label or target values 135 Fitted `galaxy_ml.keras_galaxy_models.KerasGBatchClassifier`.
113 pred_probas : array 136 data_generator : object
114 Prediction values, probability for classification problem 137 From `galaxy_ml.preprocessors.ImageDataFrameBatchGenerator`.
115 scorer : dict 138 X : 2-D array
116 dict of `sklearn.metrics.scorer.SCORER` 139 Contains indices of images that need to be evaluated.
117 is_multimetric : bool, default is True 140 y : None
141 Target value.
142 sk_scoring : dict
143 Galaxy tool input parameters.
144 steps : integer or None
145 Evaluation/prediction steps before stopping.
146 batch_size : integer
147 Number of samples in a batch
148 return_predictions : bool, default is False
149 Whether to return predictions and true labels.
118 """ 150 """
119 if y_true.ndim == 1 or y_true.shape[-1] == 1: 151 scores = {}
120 pred_probas = pred_probas.ravel() 152
121 pred_labels = (pred_probas > 0.5).astype('int32') 153 generator = data_generator.flow(X, y=y, batch_size=batch_size)
122 targets = y_true.ravel().astype('int32') 154 # keras metrics evaluation
123 if not is_multimetric: 155 # handle scorer, convert to scorer dict
124 preds = pred_labels if scorer.__class__.__name__ == \ 156 generator.reset()
125 '_PredictScorer' else pred_probas 157 score_results = estimator.model_.evaluate_generator(generator, steps=steps)
126 score = scorer._score_func(targets, preds, **scorer._kwargs) 158 metrics_names = estimator.model_.metrics_names
127 159 if not isinstance(metrics_names, list):
128 return score 160 scores[metrics_names] = score_results
129 else: 161 else:
130 scores = {} 162 scores = dict(zip(metrics_names, score_results))
131 for name, one_scorer in scorer.items(): 163
132 preds = pred_labels if one_scorer.__class__.__name__\ 164 if sk_scoring["primary_scoring"] == "default" and not return_predictions:
133 == '_PredictScorer' else pred_probas 165 return scores
134 score = one_scorer._score_func(targets, preds, 166
135 **one_scorer._kwargs) 167 generator.reset()
136 scores[name] = score 168 predictions, y_true = _predict_generator(estimator.model_, generator, steps=steps)
137 169
138 # TODO: multi-class metrics 170 # for sklearn metrics
139 # multi-label 171 if sk_scoring["primary_scoring"] != "default":
140 else: 172 scorer = get_scoring(sk_scoring)
141 pred_labels = (pred_probas > 0.5).astype('int32') 173 if not isinstance(scorer, (dict, list)):
142 targets = y_true.astype('int32') 174 scorer = [sk_scoring["primary_scoring"]]
143 if not is_multimetric: 175 scorer = _check_multimetric_scoring(estimator, scoring=scorer)
144 preds = pred_labels if scorer.__class__.__name__ == \ 176 sk_scores = gen_compute_scores(y_true, predictions, scorer)
145 '_PredictScorer' else pred_probas 177 scores.update(sk_scores)
146 score, _ = compute_score(preds, targets, 178
147 scorer._score_func) 179 if return_predictions:
148 return score 180 return scores, predictions, y_true
149 else: 181 else:
150 scores = {} 182 return scores, None, None
151 for name, one_scorer in scorer.items(): 183
152 preds = pred_labels if one_scorer.__class__.__name__\ 184
153 == '_PredictScorer' else pred_probas 185 def main(
154 score, _ = compute_score(preds, targets, 186 inputs,
155 one_scorer._score_func) 187 infile_estimator,
156 scores[name] = score 188 infile1,
157 189 infile2,
158 return scores 190 outfile_result,
159 191 outfile_object=None,
160 192 outfile_y_true=None,
161 def main(inputs, infile_estimator, infile1, infile2, 193 outfile_y_preds=None,
162 outfile_result, outfile_object=None, 194 groups=None,
163 outfile_weights=None, outfile_y_true=None, 195 ref_seq=None,
164 outfile_y_preds=None, groups=None, 196 intervals=None,
165 ref_seq=None, intervals=None, targets=None, 197 targets=None,
166 fasta_path=None): 198 fasta_path=None,
199 ):
167 """ 200 """
168 Parameters 201 Parameters
169 ---------- 202 ----------
170 inputs : str 203 inputs : str
171 File path to galaxy tool parameter 204 File path to galaxy tool parameter.
172 205
173 infile_estimator : str 206 infile_estimator : str
174 File path to estimator 207 File path to estimator.
175 208
176 infile1 : str 209 infile1 : str
177 File path to dataset containing features 210 File path to dataset containing features.
178 211
179 infile2 : str 212 infile2 : str
180 File path to dataset containing target values 213 File path to dataset containing target values.
181 214
182 outfile_result : str 215 outfile_result : str
183 File path to save the results, either cv_results or test result 216 File path to save the results, either cv_results or test result.
184 217
185 outfile_object : str, optional 218 outfile_object : str, optional
186 File path to save searchCV object 219 File path to save searchCV object.
187
188 outfile_weights : str, optional
189 File path to save deep learning model weights
190 220
191 outfile_y_true : str, optional 221 outfile_y_true : str, optional
192 File path to target values for prediction 222 File path to target values for prediction.
193 223
194 outfile_y_preds : str, optional 224 outfile_y_preds : str, optional
195 File path to save deep learning model weights 225 File path to save predictions.
196 226
197 groups : str 227 groups : str
198 File path to dataset containing group labels 228 File path to dataset containing group labels.
199 229
200 ref_seq : str 230 ref_seq : str
201 File path to dataset containing genome sequence file 231 File path to dataset containing genome sequence file.
202 232
203 intervals : str 233 intervals : str
204 File path to dataset containing interval file 234 File path to dataset containing interval file.
205 235
206 targets : str 236 targets : str
207 File path to dataset containing compressed target bed file 237 File path to dataset containing compressed target bed file.
208 238
209 fasta_path : str 239 fasta_path : str
210 File path to dataset containing fasta file 240 File path to dataset containing fasta file.
211 """ 241 """
212 warnings.simplefilter('ignore') 242 warnings.simplefilter("ignore")
213 243
214 with open(inputs, 'r') as param_handler: 244 with open(inputs, "r") as param_handler:
215 params = json.load(param_handler) 245 params = json.load(param_handler)
216 246
217 # load estimator 247 # load estimator
218 with open(infile_estimator, 'rb') as estimator_handler: 248 estimator = load_model_from_h5(infile_estimator)
219 estimator = load_model(estimator_handler)
220 249
221 estimator = clean_params(estimator) 250 estimator = clean_params(estimator)
222 251
223 # swap hyperparameter 252 # swap hyperparameter
224 swapping = params['experiment_schemes']['hyperparams_swapping'] 253 swapping = params["experiment_schemes"]["hyperparams_swapping"]
225 swap_params = _eval_swap_params(swapping) 254 swap_params = _eval_swap_params(swapping)
226 estimator.set_params(**swap_params) 255 estimator.set_params(**swap_params)
227 256
228 estimator_params = estimator.get_params() 257 estimator_params = estimator.get_params()
229 258
230 # store read dataframe object 259 # store read dataframe object
231 loaded_df = {} 260 loaded_df = {}
232 261
233 input_type = params['input_options']['selected_input'] 262 input_type = params["input_options"]["selected_input"]
234 # tabular input 263 # tabular input
235 if input_type == 'tabular': 264 if input_type == "tabular":
236 header = 'infer' if params['input_options']['header1'] else None 265 header = "infer" if params["input_options"]["header1"] else None
237 column_option = (params['input_options']['column_selector_options_1'] 266 column_option = params["input_options"]["column_selector_options_1"][
238 ['selected_column_selector_option']) 267 "selected_column_selector_option"
239 if column_option in ['by_index_number', 'all_but_by_index_number', 268 ]
240 'by_header_name', 'all_but_by_header_name']: 269 if column_option in [
241 c = params['input_options']['column_selector_options_1']['col1'] 270 "by_index_number",
271 "all_but_by_index_number",
272 "by_header_name",
273 "all_but_by_header_name",
274 ]:
275 c = params["input_options"]["column_selector_options_1"]["col1"]
242 else: 276 else:
243 c = None 277 c = None
244 278
245 df_key = infile1 + repr(header) 279 df_key = infile1 + repr(header)
246 df = pd.read_csv(infile1, sep='\t', header=header, 280 df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True)
247 parse_dates=True)
248 loaded_df[df_key] = df 281 loaded_df[df_key] = df
249 282
250 X = read_columns(df, c=c, c_option=column_option).astype(float) 283 X = read_columns(df, c=c, c_option=column_option).astype(float)
251 # sparse input 284 # sparse input
252 elif input_type == 'sparse': 285 elif input_type == "sparse":
253 X = mmread(open(infile1, 'r')) 286 X = mmread(open(infile1, "r"))
254 287
255 # fasta_file input 288 # fasta_file input
256 elif input_type == 'seq_fasta': 289 elif input_type == "seq_fasta":
257 pyfaidx = get_module('pyfaidx') 290 pyfaidx = get_module("pyfaidx")
258 sequences = pyfaidx.Fasta(fasta_path) 291 sequences = pyfaidx.Fasta(fasta_path)
259 n_seqs = len(sequences.keys()) 292 n_seqs = len(sequences.keys())
260 X = np.arange(n_seqs)[:, np.newaxis] 293 X = np.arange(n_seqs)[:, np.newaxis]
261 for param in estimator_params.keys(): 294 for param in estimator_params.keys():
262 if param.endswith('fasta_path'): 295 if param.endswith("fasta_path"):
263 estimator.set_params( 296 estimator.set_params(**{param: fasta_path})
264 **{param: fasta_path})
265 break 297 break
266 else: 298 else:
267 raise ValueError( 299 raise ValueError(
268 "The selected estimator doesn't support " 300 "The selected estimator doesn't support "
269 "fasta file input! Please consider using " 301 "fasta file input! Please consider using "
270 "KerasGBatchClassifier with " 302 "KerasGBatchClassifier with "
271 "FastaDNABatchGenerator/FastaProteinBatchGenerator " 303 "FastaDNABatchGenerator/FastaProteinBatchGenerator "
272 "or having GenomeOneHotEncoder/ProteinOneHotEncoder " 304 "or having GenomeOneHotEncoder/ProteinOneHotEncoder "
273 "in pipeline!") 305 "in pipeline!"
274 306 )
275 elif input_type == 'refseq_and_interval': 307
308 elif input_type == "refseq_and_interval":
276 path_params = { 309 path_params = {
277 'data_batch_generator__ref_genome_path': ref_seq, 310 "data_batch_generator__ref_genome_path": ref_seq,
278 'data_batch_generator__intervals_path': intervals, 311 "data_batch_generator__intervals_path": intervals,
279 'data_batch_generator__target_path': targets 312 "data_batch_generator__target_path": targets,
280 } 313 }
281 estimator.set_params(**path_params) 314 estimator.set_params(**path_params)
282 n_intervals = sum(1 for line in open(intervals)) 315 n_intervals = sum(1 for line in open(intervals))
283 X = np.arange(n_intervals)[:, np.newaxis] 316 X = np.arange(n_intervals)[:, np.newaxis]
284 317
285 # Get target y 318 # Get target y
286 header = 'infer' if params['input_options']['header2'] else None 319 header = "infer" if params["input_options"]["header2"] else None
287 column_option = (params['input_options']['column_selector_options_2'] 320 column_option = params["input_options"]["column_selector_options_2"][
288 ['selected_column_selector_option2']) 321 "selected_column_selector_option2"
289 if column_option in ['by_index_number', 'all_but_by_index_number', 322 ]
290 'by_header_name', 'all_but_by_header_name']: 323 if column_option in [
291 c = params['input_options']['column_selector_options_2']['col2'] 324 "by_index_number",
325 "all_but_by_index_number",
326 "by_header_name",
327 "all_but_by_header_name",
328 ]:
329 c = params["input_options"]["column_selector_options_2"]["col2"]
292 else: 330 else:
293 c = None 331 c = None
294 332
295 df_key = infile2 + repr(header) 333 df_key = infile2 + repr(header)
296 if df_key in loaded_df: 334 if df_key in loaded_df:
297 infile2 = loaded_df[df_key] 335 infile2 = loaded_df[df_key]
298 else: 336 else:
299 infile2 = pd.read_csv(infile2, sep='\t', 337 infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True)
300 header=header, parse_dates=True)
301 loaded_df[df_key] = infile2 338 loaded_df[df_key] = infile2
302 339
303 y = read_columns( 340 y = read_columns(
304 infile2, 341 infile2,
305 c=c, 342 c=c,
306 c_option=column_option, 343 c_option=column_option,
307 sep='\t', 344 sep="\t",
308 header=header, 345 header=header,
309 parse_dates=True) 346 parse_dates=True,
347 )
310 if len(y.shape) == 2 and y.shape[1] == 1: 348 if len(y.shape) == 2 and y.shape[1] == 1:
311 y = y.ravel() 349 y = y.ravel()
312 if input_type == 'refseq_and_interval': 350 if input_type == "refseq_and_interval":
313 estimator.set_params( 351 estimator.set_params(data_batch_generator__features=y.ravel().tolist())
314 data_batch_generator__features=y.ravel().tolist())
315 y = None 352 y = None
316 # end y 353 # end y
317 354
318 # load groups 355 # load groups
319 if groups: 356 if groups:
320 groups_selector = (params['experiment_schemes']['test_split'] 357 groups_selector = (
321 ['split_algos']).pop('groups_selector') 358 params["experiment_schemes"]["test_split"]["split_algos"]
322 359 ).pop("groups_selector")
323 header = 'infer' if groups_selector['header_g'] else None 360
324 column_option = \ 361 header = "infer" if groups_selector["header_g"] else None
325 (groups_selector['column_selector_options_g'] 362 column_option = groups_selector["column_selector_options_g"][
326 ['selected_column_selector_option_g']) 363 "selected_column_selector_option_g"
327 if column_option in ['by_index_number', 'all_but_by_index_number', 364 ]
328 'by_header_name', 'all_but_by_header_name']: 365 if column_option in [
329 c = groups_selector['column_selector_options_g']['col_g'] 366 "by_index_number",
367 "all_but_by_index_number",
368 "by_header_name",
369 "all_but_by_header_name",
370 ]:
371 c = groups_selector["column_selector_options_g"]["col_g"]
330 else: 372 else:
331 c = None 373 c = None
332 374
333 df_key = groups + repr(header) 375 df_key = groups + repr(header)
334 if df_key in loaded_df: 376 if df_key in loaded_df:
335 groups = loaded_df[df_key] 377 groups = loaded_df[df_key]
336 378
337 groups = read_columns( 379 groups = read_columns(
338 groups, 380 groups,
339 c=c, 381 c=c,
340 c_option=column_option, 382 c_option=column_option,
341 sep='\t', 383 sep="\t",
342 header=header, 384 header=header,
343 parse_dates=True) 385 parse_dates=True,
386 )
344 groups = groups.ravel() 387 groups = groups.ravel()
345 388
346 # del loaded_df 389 # del loaded_df
347 del loaded_df 390 del loaded_df
348 391
349 # cache iraps_core fits could increase search speed significantly 392 # cache iraps_core fits could increase search speed significantly
350 memory = joblib.Memory(location=CACHE_DIR, verbose=0) 393 memory = joblib.Memory(location=CACHE_DIR, verbose=0)
351 main_est = get_main_estimator(estimator) 394 main_est = get_main_estimator(estimator)
352 if main_est.__class__.__name__ == 'IRAPSClassifier': 395 if main_est.__class__.__name__ == "IRAPSClassifier":
353 main_est.set_params(memory=memory) 396 main_est.set_params(memory=memory)
354 397
355 # handle scorer, convert to scorer dict 398 # handle scorer, convert to scorer dict
356 scoring = params['experiment_schemes']['metrics']['scoring'] 399 scoring = params["experiment_schemes"]["metrics"]["scoring"]
357 scorer = get_scoring(scoring) 400 scorer = get_scoring(scoring)
358 scorer, _ = _check_multimetric_scoring(estimator, scoring=scorer) 401 if not isinstance(scorer, (dict, list)):
402 scorer = [scoring["primary_scoring"]]
403 scorer = _check_multimetric_scoring(estimator, scoring=scorer)
359 404
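
The rewrite above tracks a scikit-learn API change: the old import path sklearn.metrics.scorer returned a (scorers, is_multimetric) tuple from _check_multimetric_scoring, while the sklearn.metrics._scorer version used here returns only the scorer dict, hence the dropped tuple unpacking. A hedged sketch of the resulting usage, with placeholder metric names:

from sklearn.metrics._scorer import _check_multimetric_scoring
from sklearn.model_selection._validation import _score

# build a {name: scorer} mapping from a list of metric names (these are private
# sklearn helpers, so exact behaviour depends on the installed sklearn release)
scorer = _check_multimetric_scoring(estimator, scoring=["accuracy", "f1_macro"])
test_scores = _score(estimator, X_test, y_test, scorer)
# test_scores is a dict such as {"accuracy": 0.9, "f1_macro": 0.88}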
360 # handle test (first) split 405 # handle test (first) split
361 test_split_options = (params['experiment_schemes'] 406 test_split_options = params["experiment_schemes"]["test_split"]["split_algos"]
362 ['test_split']['split_algos']) 407
363 408 if test_split_options["shuffle"] == "group":
364 if test_split_options['shuffle'] == 'group': 409 test_split_options["labels"] = groups
365 test_split_options['labels'] = groups 410 if test_split_options["shuffle"] == "stratified":
366 if test_split_options['shuffle'] == 'stratified':
367 if y is not None: 411 if y is not None:
368 test_split_options['labels'] = y 412 test_split_options["labels"] = y
369 else: 413 else:
370 raise ValueError("Stratified shuffle split is not " 414 raise ValueError(
371 "applicable on empty target values!") 415 "Stratified shuffle split is not " "applicable on empty target values!"
372 416 )
373 X_train, X_test, y_train, y_test, groups_train, groups_test = \ 417
374 train_test_split_none(X, y, groups, **test_split_options) 418 X_train, X_test, y_train, y_test, groups_train, groups_test = train_test_split_none(
375 419 X, y, groups, **test_split_options
376 exp_scheme = params['experiment_schemes']['selected_exp_scheme'] 420 )
421
422 exp_scheme = params["experiment_schemes"]["selected_exp_scheme"]
377 423
378 # handle validation (second) split 424 # handle validation (second) split
379 if exp_scheme == 'train_val_test': 425 if exp_scheme == "train_val_test":
380 val_split_options = (params['experiment_schemes'] 426 val_split_options = params["experiment_schemes"]["val_split"]["split_algos"]
381 ['val_split']['split_algos']) 427
382 428 if val_split_options["shuffle"] == "group":
383 if val_split_options['shuffle'] == 'group': 429 val_split_options["labels"] = groups_train
384 val_split_options['labels'] = groups_train 430 if val_split_options["shuffle"] == "stratified":
385 if val_split_options['shuffle'] == 'stratified':
386 if y_train is not None: 431 if y_train is not None:
387 val_split_options['labels'] = y_train 432 val_split_options["labels"] = y_train
388 else: 433 else:
389 raise ValueError("Stratified shuffle split is not " 434 raise ValueError(
390 "applicable on empty target values!") 435 "Stratified shuffle split is not "
391 436 "applicable on empty target values!"
392 X_train, X_val, y_train, y_val, groups_train, groups_val = \ 437 )
393 train_test_split_none(X_train, y_train, groups_train, 438
394 **val_split_options) 439 (
440 X_train,
441 X_val,
442 y_train,
443 y_val,
444 groups_train,
445 groups_val,
446 ) = train_test_split_none(X_train, y_train, groups_train, **val_split_options)
395 447
396 # train and eval 448 # train and eval
397 if hasattr(estimator, 'validation_data'): 449 if hasattr(estimator, "config") and hasattr(estimator, "model_type"):
398 if exp_scheme == 'train_val_test': 450 if exp_scheme == "train_val_test":
399 estimator.fit(X_train, y_train, 451 estimator.fit(X_train, y_train, validation_data=(X_val, y_val))
400 validation_data=(X_val, y_val)) 452 else:
401 else: 453 estimator.fit(X_train, y_train, validation_data=(X_test, y_test))
402 estimator.fit(X_train, y_train,
403 validation_data=(X_test, y_test))
404 else: 454 else:
405 estimator.fit(X_train, y_train) 455 estimator.fit(X_train, y_train)
406 456
407 if hasattr(estimator, 'evaluate'): 457 if isinstance(estimator, KerasGBatchClassifier):
458 scores = {}
408 steps = estimator.prediction_steps 459 steps = estimator.prediction_steps
409 batch_size = estimator.batch_size 460 batch_size = estimator.batch_size
410 generator = estimator.data_generator_.flow(X_test, y=y_test, 461 data_generator = estimator.data_generator_
411 batch_size=batch_size) 462
412 predictions, y_true = _predict_generator(estimator.model_, generator, 463 scores, predictions, y_true = _evaluate_keras_and_sklearn_scores(
413 steps=steps) 464 estimator,
414 scores = _evaluate(y_true, predictions, scorer, is_multimetric=True) 465 data_generator,
415 466 X_test,
416 else: 467 y=y_test,
417 if hasattr(estimator, 'predict_proba'): 468 sk_scoring=scoring,
469 steps=steps,
470 batch_size=batch_size,
471 return_predictions=bool(outfile_y_true),
472 )
473
474 else:
475 scores = {}
476 if hasattr(estimator, "model_") and hasattr(estimator.model_, "metrics_names"):
477 batch_size = estimator.batch_size
478 score_results = estimator.model_.evaluate(
479 X_test, y=y_test, batch_size=batch_size, verbose=0
480 )
481 metrics_names = estimator.model_.metrics_names
482 if not isinstance(metrics_names, list):
483 scores[metrics_names] = score_results
484 else:
485 scores = dict(zip(metrics_names, score_results))
486
487 if hasattr(estimator, "predict_proba"):
418 predictions = estimator.predict_proba(X_test) 488 predictions = estimator.predict_proba(X_test)
419 else: 489 else:
420 predictions = estimator.predict(X_test) 490 predictions = estimator.predict(X_test)
421 491
422 y_true = y_test 492 y_true = y_test
423 scores = _score(estimator, X_test, y_test, scorer, 493 sk_scores = _score(estimator, X_test, y_test, scorer)
424 is_multimetric=True) 494 scores.update(sk_scores)
495
496 # handle output
425 if outfile_y_true: 497 if outfile_y_true:
426 try: 498 try:
427 pd.DataFrame(y_true).to_csv(outfile_y_true, sep='\t', 499 pd.DataFrame(y_true).to_csv(outfile_y_true, sep="\t", index=False)
428 index=False)
429 pd.DataFrame(predictions).astype(np.float32).to_csv( 500 pd.DataFrame(predictions).astype(np.float32).to_csv(
430 outfile_y_preds, sep='\t', index=False, 501 outfile_y_preds,
431 float_format='%g', chunksize=10000) 502 sep="\t",
503 index=False,
504 float_format="%g",
505 chunksize=10000,
506 )
432 except Exception as e: 507 except Exception as e:
433 print("Error in saving predictions: %s" % e) 508 print("Error in saving predictions: %s" % e)
434
435 # handle output 509 # handle output
436 for name, score in scores.items(): 510 for name, score in scores.items():
437 scores[name] = [score] 511 scores[name] = [score]
438 df = pd.DataFrame(scores) 512 df = pd.DataFrame(scores)
439 df = df[sorted(df.columns)] 513 df = df[sorted(df.columns)]
440 df.to_csv(path_or_buf=outfile_result, sep='\t', 514 df.to_csv(path_or_buf=outfile_result, sep="\t", header=True, index=False)
441 header=True, index=False)
442 515
443 memory.clear(warn=False) 516 memory.clear(warn=False)
444 517
445 if outfile_object: 518 if outfile_object:
446 main_est = estimator 519 dump_model_to_h5(estimator, outfile_object)
447 if isinstance(estimator, Pipeline): 520
448 main_est = estimator.steps[-1][-1] 521
449 522 if __name__ == "__main__":
450 if hasattr(main_est, 'model_') \
451 and hasattr(main_est, 'save_weights'):
452 if outfile_weights:
453 main_est.save_weights(outfile_weights)
454 del main_est.model_
455 del main_est.fit_params
456 del main_est.model_class_
457 del main_est.validation_data
458 if getattr(main_est, 'data_generator_', None):
459 del main_est.data_generator_
460
461 with open(outfile_object, 'wb') as output_handler:
462 pickle.dump(estimator, output_handler,
463 pickle.HIGHEST_PROTOCOL)
464
465
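
The pickle/weights round trip removed on the left is replaced by galaxy_ml's HDF5 persistence helpers already imported at the top of the new revision; a minimal sketch of that round trip, with placeholder file names:

from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5

dump_model_to_h5(estimator, "trained_model.h5")     # save the fitted estimator
estimator = load_model_from_h5("trained_model.h5")  # restore it later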
466 if __name__ == '__main__':
467 aparser = argparse.ArgumentParser() 523 aparser = argparse.ArgumentParser()
468 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) 524 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
469 aparser.add_argument("-e", "--estimator", dest="infile_estimator") 525 aparser.add_argument("-e", "--estimator", dest="infile_estimator")
470 aparser.add_argument("-X", "--infile1", dest="infile1") 526 aparser.add_argument("-X", "--infile1", dest="infile1")
471 aparser.add_argument("-y", "--infile2", dest="infile2") 527 aparser.add_argument("-y", "--infile2", dest="infile2")
472 aparser.add_argument("-O", "--outfile_result", dest="outfile_result") 528 aparser.add_argument("-O", "--outfile_result", dest="outfile_result")
473 aparser.add_argument("-o", "--outfile_object", dest="outfile_object") 529 aparser.add_argument("-o", "--outfile_object", dest="outfile_object")
474 aparser.add_argument("-w", "--outfile_weights", dest="outfile_weights")
475 aparser.add_argument("-l", "--outfile_y_true", dest="outfile_y_true") 530 aparser.add_argument("-l", "--outfile_y_true", dest="outfile_y_true")
476 aparser.add_argument("-p", "--outfile_y_preds", dest="outfile_y_preds") 531 aparser.add_argument("-p", "--outfile_y_preds", dest="outfile_y_preds")
477 aparser.add_argument("-g", "--groups", dest="groups") 532 aparser.add_argument("-g", "--groups", dest="groups")
478 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") 533 aparser.add_argument("-r", "--ref_seq", dest="ref_seq")
479 aparser.add_argument("-b", "--intervals", dest="intervals") 534 aparser.add_argument("-b", "--intervals", dest="intervals")
480 aparser.add_argument("-t", "--targets", dest="targets") 535 aparser.add_argument("-t", "--targets", dest="targets")
481 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") 536 aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
482 args = aparser.parse_args() 537 args = aparser.parse_args()
483 538
484 main(args.inputs, args.infile_estimator, args.infile1, args.infile2, 539 main(
485 args.outfile_result, outfile_object=args.outfile_object, 540 args.inputs,
486 outfile_weights=args.outfile_weights, 541 args.infile_estimator,
487 outfile_y_true=args.outfile_y_true, 542 args.infile1,
488 outfile_y_preds=args.outfile_y_preds, 543 args.infile2,
489 groups=args.groups, 544 args.outfile_result,
490 ref_seq=args.ref_seq, intervals=args.intervals, 545 outfile_object=args.outfile_object,
491 targets=args.targets, fasta_path=args.fasta_path) 546 outfile_y_true=args.outfile_y_true,
547 outfile_y_preds=args.outfile_y_preds,
548 groups=args.groups,
549 ref_seq=args.ref_seq,
550 intervals=args.intervals,
551 targets=args.targets,
552 fasta_path=args.fasta_path,
553 )