Mercurial > repos > bgruening > stacking_ensemble_models
comparison train_test_eval.py @ 2:38c4f8a98038 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 5b2ac730ec6d3b762faa9034eddd19ad1b347476"
author | bgruening |
---|---|
date | Mon, 16 Dec 2019 10:07:37 +0000 |
parents | c1b0c8232816 |
children | 0a1812986bc3 |
comparison
equal
deleted
inserted
replaced
1:c1b0c8232816 | 2:38c4f8a98038 |
---|---|
1 import argparse | 1 import argparse |
2 import joblib | 2 import joblib |
3 import json | 3 import json |
4 import numpy as np | 4 import numpy as np |
 | 5 import os |
5 import pandas as pd | 6 import pandas as pd |
6 import pickle | 7 import pickle |
7 import warnings | 8 import warnings |
8 from itertools import chain | 9 from itertools import chain |
9 from scipy.io import mmread | 10 from scipy.io import mmread |
27 | 28 |
28 _fit_and_score = try_get_attr('galaxy_ml.model_validations', '_fit_and_score') | 29 _fit_and_score = try_get_attr('galaxy_ml.model_validations', '_fit_and_score') |
29 setattr(_search, '_fit_and_score', _fit_and_score) | 30 setattr(_search, '_fit_and_score', _fit_and_score) |
30 setattr(_validation, '_fit_and_score', _fit_and_score) | 31 setattr(_validation, '_fit_and_score', _fit_and_score) |
31 | 32 |
32 N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1)) | 33 N_JOBS = int(os.environ.get('GALAXY_SLOTS', 1)) |
33 CACHE_DIR = './cached' | 34 CACHE_DIR = os.path.join(os.getcwd(), 'cached') |
 | 35 del os |
34 NON_SEARCHABLE = ('n_jobs', 'pre_dispatch', 'memory', '_path', | 36 NON_SEARCHABLE = ('n_jobs', 'pre_dispatch', 'memory', '_path', |
35 'nthread', 'callbacks') | 37 'nthread', 'callbacks') |
36 ALLOWED_CALLBACKS = ('EarlyStopping', 'TerminateOnNaN', 'ReduceLROnPlateau', | 38 ALLOWED_CALLBACKS = ('EarlyStopping', 'TerminateOnNaN', 'ReduceLROnPlateau', |
37 'CSVLogger', 'None') | 39 'CSVLogger', 'None') |
38 | 40 |
401 del main_est.fit_params | 403 del main_est.fit_params |
402 del main_est.model_class_ | 404 del main_est.model_class_ |
403 del main_est.validation_data | 405 del main_est.validation_data |
404 if getattr(main_est, 'data_generator_', None): | 406 if getattr(main_est, 'data_generator_', None): |
405 del main_est.data_generator_ | 407 del main_est.data_generator_ |
406 del main_est.data_batch_generator | |
407 | 408 |
408 with open(outfile_object, 'wb') as output_handler: | 409 with open(outfile_object, 'wb') as output_handler: |
409 pickle.dump(estimator, output_handler, | 410 pickle.dump(estimator, output_handler, |
410 pickle.HIGHEST_PROTOCOL) | 411 pickle.HIGHEST_PROTOCOL) |
411 | 412 |