comparison simple_model_fit.py @ 8:f2c240cce242 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 208a8d348e7c7a182cfbe1b6f17868146428a7e2"
author bgruening
date Tue, 13 Apr 2021 20:47:28 +0000
parents b2156aa78d1e
children 022e85ead64f
comparison
equal deleted inserted replaced
7:c9b521fcc3ac 8:f2c240cce242
1 import argparse 1 import argparse
2 import json 2 import json
3 import pandas as pd
4 import pickle 3 import pickle
5 4
5 import pandas as pd
6 from galaxy_ml.utils import load_model, read_columns 6 from galaxy_ml.utils import load_model, read_columns
7 from scipy.io import mmread
7 from sklearn.pipeline import Pipeline 8 from sklearn.pipeline import Pipeline
8 9
9 10
# Number of parallel jobs granted by Galaxy via GALAXY_SLOTS (1 when unset).
N_JOBS = int(__import__("os").getenv("GALAXY_SLOTS", 1))
11 12
12 13
# TODO import from galaxy_ml.utils in future versions
def clean_params(estimator, n_jobs=None):
    """Clean unwanted hyperparameter settings of ``estimator`` in place.

    * ``memory``, ``*__memory`` and ``*_path`` hyperparameters (potential
      unauthorized file writes) are reset to ``None``.
    * If ``n_jobs`` is not None, every ``n_jobs`` / ``*__n_jobs``
      hyperparameter is set to it.
    * Every ``*callbacks`` hyperparameter is checked against a whitelist of
      callback types; a ``None`` callbacks value is accepted.

    Parameters
    ----------
    estimator : object
        Object exposing ``get_params()`` and ``set_params(**params)``
        (e.g. a scikit-learn estimator or Pipeline).
    n_jobs : int or None, default None
        Value assigned to ``n_jobs`` hyperparameters; None leaves them as is.

    Returns
    -------
    object
        The same ``estimator`` object, cleaned.

    Raises
    ------
    ValueError
        If a callback's ``callback_type`` is not in the allowed list.
    """
    ALLOWED_CALLBACKS = (
        "EarlyStopping",
        "TerminateOnNaN",
        "ReduceLROnPlateau",
        "CSVLogger",
        "None",
    )

    # Iterate over a snapshot so set_params() below cannot disturb the loop.
    estimator_params = list(estimator.get_params().items())

    for name, p in estimator_params:
        # all potential unauthorized file write
        if name == "memory" or name.endswith("__memory") or name.endswith("_path"):
            estimator.set_params(**{name: None})
        elif n_jobs is not None and (name == "n_jobs" or name.endswith("__n_jobs")):
            estimator.set_params(**{name: n_jobs})
        elif name.endswith("callbacks"):
            # Each callback is a dict like
            # {"callback_selection": {"callback_type": ...}}; a None value
            # (no callbacks configured) has nothing to validate.
            for cb in p or ():
                cb_type = cb["callback_selection"]["callback_type"]
                if cb_type not in ALLOWED_CALLBACKS:
                    raise ValueError("Prohibited callback type: %s!" % cb_type)

    return estimator
46 49
47 50
def _column_selection(selector_options, option_key, col_key):
    """Return ``(column_option, columns)`` for one column-selector group.

    ``columns`` is ``selector_options[col_key]`` when the selected option
    needs an explicit column specification, otherwise ``None``.
    """
    column_option = selector_options[option_key]
    if column_option in (
        "by_index_number",
        "all_but_by_index_number",
        "by_header_name",
        "all_but_by_header_name",
    ):
        return column_option, selector_options[col_key]
    return column_option, None


def _get_X_y(params, infile1, infile2):
    """Read from inputs and output X and y.

    Parameters
    ----------
    params : dict
        Tool inputs parameter; this function reads ``params["input_options"]``.
    infile1 : str
        File path to the training features: a tab-separated table when
        ``selected_input`` is ``"tabular"``, a Matrix Market file when it
        is ``"sparse"``.
    infile2 : str
        File path to the tab-separated table holding the target column(s).

    Returns
    -------
    (X, y)
        ``X`` is the selected feature columns cast to float (tabular input)
        or the matrix returned by ``mmread`` (sparse input). ``y`` is the
        selected target column(s); a single-column 2-D result is flattened
        to 1-D.

    Raises
    ------
    ValueError
        If ``selected_input`` is neither ``"tabular"`` nor ``"sparse"``.
    """
    input_options = params["input_options"]
    # Tables already parsed, keyed by path + header mode, so the target
    # table is not read twice when it is the same file as the features.
    loaded_df = {}

    input_type = input_options["selected_input"]
    # tabular input
    if input_type == "tabular":
        header = "infer" if input_options["header1"] else None
        column_option, c = _column_selection(
            input_options["column_selector_options_1"],
            "selected_column_selector_option",
            "col1",
        )
        df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True)
        loaded_df[infile1 + repr(header)] = df
        X = read_columns(df, c=c, c_option=column_option).astype(float)
    # sparse input
    elif input_type == "sparse":
        # Use a context manager so the handle is closed after parsing.
        with open(infile1, "r") as sparse_handler:
            X = mmread(sparse_handler)
    else:
        raise ValueError("Unsupported input type: %r" % (input_type,))

    # Get target y
    header = "infer" if input_options["header2"] else None
    column_option, c = _column_selection(
        input_options["column_selector_options_2"],
        "selected_column_selector_option2",
        "col2",
    )

    df_key = infile2 + repr(header)
    if df_key in loaded_df:
        target_df = loaded_df[df_key]
    else:
        target_df = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True)

    y = read_columns(
        target_df,
        c=c,
        c_option=column_option,
        sep="\t",
        header=header,
        parse_dates=True,
    )
    if len(y.shape) == 2 and y.shape[1] == 1:
        y = y.ravel()

    return X, y
115 121
116 122
def main(inputs, infile_estimator, infile1, infile2, out_object, out_weights=None):
    """Fit the loaded estimator on the training data and pickle it.

    Parameters
    ----------
    inputs : str
        File path to galaxy tool parameter (JSON).
    infile_estimator : str
        File path to the estimator, loaded with ``load_model``.
    infile1 : str
        File path to the training features (X).
    infile2 : str
        File path to the training targets (y).
    out_object : str
        File path for output of the fitted estimator (pickle).
    out_weights : str, optional
        File path for output of weights; only used when the final estimator
        has both ``model_`` and ``save_weights``.
    """
    with open(inputs, "r") as fh:
        params = json.load(fh)

    # load model
    with open(infile_estimator, "rb") as fh:
        estimator = load_model(fh)
    estimator = clean_params(estimator, n_jobs=N_JOBS)

    X_train, y_train = _get_X_y(params, infile1, infile2)
    estimator.fit(X_train, y_train)

    # For a Pipeline, the weight handling below applies to its last step.
    if isinstance(estimator, Pipeline):
        final_est = estimator.steps[-1][-1]
    else:
        final_est = estimator

    # NOTE(review): the model_/save_weights pair presumably identifies a
    # Keras-backed wrapper; its in-memory model and training state are
    # stripped so only the wrapper itself is pickled.
    if hasattr(final_est, "model_") and hasattr(final_est, "save_weights"):
        if out_weights:
            final_est.save_weights(out_weights)
        for attr in ("model_", "fit_params", "model_class_"):
            delattr(final_est, attr)
        for attr in ("validation_data", "data_generator_"):
            if getattr(final_est, attr, None):
                delattr(final_est, attr)

    with open(out_object, "wb") as fh:
        pickle.dump(estimator, fh, pickle.HIGHEST_PROTOCOL)
171 175
172 176
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--inputs", dest="inputs", required=True)
    # (short flag, long flag, dest) for the optional arguments.
    optional_args = (
        ("-X", "--infile_estimator", "infile_estimator"),
        ("-y", "--infile1", "infile1"),
        ("-g", "--infile2", "infile2"),
        ("-o", "--out_object", "out_object"),
        ("-t", "--out_weights", "out_weights"),
    )
    for short_flag, long_flag, dest in optional_args:
        parser.add_argument(short_flag, long_flag, dest=dest)

    # Every dest above matches a parameter name of main(), so the parsed
    # namespace can be forwarded directly as keyword arguments.
    main(**vars(parser.parse_args()))