# HG changeset patch # User bgruening # Date 1573121111 18000 # Node ID 4bb7ad4713101aaaa5b4d0e41e3bf52134d2e115 # Parent b28f2931005774c48b0caeeb6fe02c830465aebf "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit d6333e7294e67be5968a41f404b66699cad4ae53" diff -r b28f29310057 -r 4bb7ad471310 simple_model_fit.py --- a/simple_model_fit.py Fri Nov 01 16:52:34 2019 -0400 +++ b/simple_model_fit.py Thu Nov 07 05:05:11 2019 -0500 @@ -7,6 +7,44 @@ from sklearn.pipeline import Pipeline +N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1)) + + +# TODO import from galaxy_ml.utils in future versions +def clean_params(estimator, n_jobs=None): + """clean unwanted hyperparameter settings + + If n_jobs is not None, set it into the estimator, if applicable + + Return + ------ + Cleaned estimator object + """ + ALLOWED_CALLBACKS = ('EarlyStopping', 'TerminateOnNaN', + 'ReduceLROnPlateau', 'CSVLogger', 'None') + + estimator_params = estimator.get_params() + + for name, p in estimator_params.items(): + # all potential unauthorized file write + if name == 'memory' or name.endswith('__memory') \ + or name.endswith('_path'): + new_p = {name: None} + estimator.set_params(**new_p) + elif n_jobs is not None and (name == 'n_jobs' or + name.endswith('__n_jobs')): + new_p = {name: n_jobs} + estimator.set_params(**new_p) + elif name.endswith('callbacks'): + for cb in p: + cb_type = cb['callback_selection']['callback_type'] + if cb_type not in ALLOWED_CALLBACKS: + raise ValueError( + "Prohibited callback type: %s!" % cb_type) + + return estimator + + def _get_X_y(params, infile1, infile2): """ read from inputs and output X and y @@ -107,6 +145,7 @@ # load model with open(infile_estimator, 'rb') as est_handler: estimator = load_model(est_handler) + estimator = clean_params(estimator, n_jobs=N_JOBS) X_train, y_train = _get_X_y(params, infile1, infile2)