comparison train_test_eval.py @ 2:38c4f8a98038 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 5b2ac730ec6d3b762faa9034eddd19ad1b347476"
author bgruening
date Mon, 16 Dec 2019 10:07:37 +0000
parents c1b0c8232816
children 0a1812986bc3
comparison
equal deleted inserted replaced
1:c1b0c8232816 2:38c4f8a98038
1 import argparse 1 import argparse
2 import joblib 2 import joblib
3 import json 3 import json
4 import numpy as np 4 import numpy as np
5 import os
5 import pandas as pd 6 import pandas as pd
6 import pickle 7 import pickle
7 import warnings 8 import warnings
8 from itertools import chain 9 from itertools import chain
9 from scipy.io import mmread 10 from scipy.io import mmread
27 28
28 _fit_and_score = try_get_attr('galaxy_ml.model_validations', '_fit_and_score') 29 _fit_and_score = try_get_attr('galaxy_ml.model_validations', '_fit_and_score')
29 setattr(_search, '_fit_and_score', _fit_and_score) 30 setattr(_search, '_fit_and_score', _fit_and_score)
30 setattr(_validation, '_fit_and_score', _fit_and_score) 31 setattr(_validation, '_fit_and_score', _fit_and_score)
31 32
32 N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1)) 33 N_JOBS = int(os.environ.get('GALAXY_SLOTS', 1))
33 CACHE_DIR = './cached' 34 CACHE_DIR = os.path.join(os.getcwd(), 'cached')
35 del os
34 NON_SEARCHABLE = ('n_jobs', 'pre_dispatch', 'memory', '_path', 36 NON_SEARCHABLE = ('n_jobs', 'pre_dispatch', 'memory', '_path',
35 'nthread', 'callbacks') 37 'nthread', 'callbacks')
36 ALLOWED_CALLBACKS = ('EarlyStopping', 'TerminateOnNaN', 'ReduceLROnPlateau', 38 ALLOWED_CALLBACKS = ('EarlyStopping', 'TerminateOnNaN', 'ReduceLROnPlateau',
37 'CSVLogger', 'None') 39 'CSVLogger', 'None')
38 40
401 del main_est.fit_params 403 del main_est.fit_params
402 del main_est.model_class_ 404 del main_est.model_class_
403 del main_est.validation_data 405 del main_est.validation_data
404 if getattr(main_est, 'data_generator_', None): 406 if getattr(main_est, 'data_generator_', None):
405 del main_est.data_generator_ 407 del main_est.data_generator_
406 del main_est.data_batch_generator
407 408
408 with open(outfile_object, 'wb') as output_handler: 409 with open(outfile_object, 'wb') as output_handler:
409 pickle.dump(estimator, output_handler, 410 pickle.dump(estimator, output_handler,
410 pickle.HIGHEST_PROTOCOL) 411 pickle.HIGHEST_PROTOCOL)
411 412