comparison train_test_split.py @ 3:0a1812986bc3 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author bgruening
date Wed, 09 Aug 2023 11:10:37 +0000
parents 38c4f8a98038
children
comparison
equal deleted inserted replaced
2:38c4f8a98038 3:0a1812986bc3
1 import argparse 1 import argparse
2 import json 2 import json
3 import warnings
4 from distutils.version import LooseVersion as Version
5
3 import pandas as pd 6 import pandas as pd
4 import warnings 7 from galaxy_ml import __version__ as galaxy_ml_version
5
6 from galaxy_ml.model_validations import train_test_split 8 from galaxy_ml.model_validations import train_test_split
7 from galaxy_ml.utils import get_cv, read_columns 9 from galaxy_ml.utils import get_cv, read_columns
8 10
9 11
10 def _get_single_cv_split(params, array, infile_labels=None, 12 def _get_single_cv_split(params, array, infile_labels=None, infile_groups=None):
11 infile_groups=None): 13 """output (train, test) subset from a cv splitter
12 """ output (train, test) subset from a cv splitter
13 14
14 Parameters 15 Parameters
15 ---------- 16 ----------
16 params : dict 17 params : dict
17 Galaxy tool inputs 18 Galaxy tool inputs
23 File path to dataset containing group values 24 File path to dataset containing group values
24 """ 25 """
25 y = None 26 y = None
26 groups = None 27 groups = None
27 28
28 nth_split = params['mode_selection']['nth_split'] 29 nth_split = params["mode_selection"]["nth_split"]
29 30
30 # read groups 31 # read groups
31 if infile_groups: 32 if infile_groups:
32 header = 'infer' if (params['mode_selection']['cv_selector'] 33 header = (
33 ['groups_selector']['header_g']) else None 34 "infer"
34 column_option = (params['mode_selection']['cv_selector'] 35 if (params["mode_selection"]["cv_selector"]["groups_selector"]["header_g"])
35 ['groups_selector']['column_selector_options_g'] 36 else None
36 ['selected_column_selector_option_g']) 37 )
37 if column_option in ['by_index_number', 'all_but_by_index_number', 38 column_option = params["mode_selection"]["cv_selector"]["groups_selector"][
38 'by_header_name', 'all_but_by_header_name']: 39 "column_selector_options_g"
39 c = (params['mode_selection']['cv_selector']['groups_selector'] 40 ]["selected_column_selector_option_g"]
40 ['column_selector_options_g']['col_g']) 41 if column_option in [
42 "by_index_number",
43 "all_but_by_index_number",
44 "by_header_name",
45 "all_but_by_header_name",
46 ]:
47 c = params["mode_selection"]["cv_selector"]["groups_selector"][
48 "column_selector_options_g"
49 ]["col_g"]
41 else: 50 else:
42 c = None 51 c = None
43 52
44 groups = read_columns(infile_groups, c=c, c_option=column_option, 53 groups = read_columns(
45 sep='\t', header=header, parse_dates=True) 54 infile_groups,
55 c=c,
56 c_option=column_option,
57 sep="\t",
58 header=header,
59 parse_dates=True,
60 )
46 groups = groups.ravel() 61 groups = groups.ravel()
47 62
48 params['mode_selection']['cv_selector']['groups_selector'] = groups 63 params["mode_selection"]["cv_selector"]["groups_selector"] = groups
49 64
50 # read labels 65 # read labels
51 if infile_labels: 66 if infile_labels:
52 target_input = (params['mode_selection'] 67 target_input = params["mode_selection"]["cv_selector"].pop("target_input")
53 ['cv_selector'].pop('target_input')) 68 header = "infer" if target_input["header1"] else None
54 header = 'infer' if target_input['header1'] else None 69 col_index = target_input["col"][0] - 1
55 col_index = target_input['col'][0] - 1 70 df = pd.read_csv(infile_labels, sep="\t", header=header, parse_dates=True)
56 df = pd.read_csv(infile_labels, sep='\t', header=header,
57 parse_dates=True)
58 y = df.iloc[:, col_index].values 71 y = df.iloc[:, col_index].values
59 72
60 # construct the cv splitter object 73 # construct the cv splitter object
61 splitter, groups = get_cv(params['mode_selection']['cv_selector']) 74 cv_selector = params["mode_selection"]["cv_selector"]
75 if Version(galaxy_ml_version) < Version("0.8.3"):
76 cv_selector.pop("n_stratification_bins", None)
77 splitter, groups = get_cv(cv_selector)
62 78
63 total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups) 79 total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups)
64 if nth_split > total_n_splits: 80 if nth_split > total_n_splits:
65 raise ValueError("Total number of splits is {}, but got `nth_split` " 81 raise ValueError(
66 "= {}".format(total_n_splits, nth_split)) 82 "Total number of splits is {}, but got `nth_split` "
83 "= {}".format(total_n_splits, nth_split)
84 )
67 85
68 i = 1 86 i = 1
69 for train_index, test_index in splitter.split(array.values, y=y, groups=groups): 87 for train_index, test_index in splitter.split(array.values, y=y, groups=groups):
70 # suppose nth_split >= 1 88 # suppose nth_split >= 1
71 if i == nth_split: 89 if i == nth_split:
77 test = array.iloc[test_index, :] 95 test = array.iloc[test_index, :]
78 96
79 return train, test 97 return train, test
80 98
81 99
82 def main(inputs, infile_array, outfile_train, outfile_test, 100 def main(
83 infile_labels=None, infile_groups=None): 101 inputs,
102 infile_array,
103 outfile_train,
104 outfile_test,
105 infile_labels=None,
106 infile_groups=None,
107 ):
84 """ 108 """
85 Parameter 109 Parameter
86 --------- 110 ---------
87 inputs : str 111 inputs : str
88 File path to galaxy tool parameter 112 File path to galaxy tool parameter
100 File path to dataset containing train split 124 File path to dataset containing train split
101 125
102 outfile_test : str 126 outfile_test : str
103 File path to dataset containing test split 127 File path to dataset containing test split
104 """ 128 """
105 warnings.simplefilter('ignore') 129 warnings.simplefilter("ignore")
106 130
107 with open(inputs, 'r') as param_handler: 131 with open(inputs, "r") as param_handler:
108 params = json.load(param_handler) 132 params = json.load(param_handler)
109 133
110 input_header = params['header0'] 134 input_header = params["header0"]
111 header = 'infer' if input_header else None 135 header = "infer" if input_header else None
112 array = pd.read_csv(infile_array, sep='\t', header=header, 136 array = pd.read_csv(infile_array, sep="\t", header=header, parse_dates=True)
113 parse_dates=True)
114 137
115 # train test split 138 # train test split
116 if params['mode_selection']['selected_mode'] == 'train_test_split': 139 if params["mode_selection"]["selected_mode"] == "train_test_split":
117 options = params['mode_selection']['options'] 140 options = params["mode_selection"]["options"]
118 shuffle_selection = options.pop('shuffle_selection') 141 shuffle_selection = options.pop("shuffle_selection")
119 options['shuffle'] = shuffle_selection['shuffle'] 142 options["shuffle"] = shuffle_selection["shuffle"]
120 if infile_labels: 143 if infile_labels:
121 header = 'infer' if shuffle_selection['header1'] else None 144 header = "infer" if shuffle_selection["header1"] else None
122 col_index = shuffle_selection['col'][0] - 1 145 col_index = shuffle_selection["col"][0] - 1
123 df = pd.read_csv(infile_labels, sep='\t', header=header, 146 df = pd.read_csv(infile_labels, sep="\t", header=header, parse_dates=True)
124 parse_dates=True)
125 labels = df.iloc[:, col_index].values 147 labels = df.iloc[:, col_index].values
126 options['labels'] = labels 148 options["labels"] = labels
127 149
128 train, test = train_test_split(array, **options) 150 train, test = train_test_split(array, **options)
129 151
130 # cv splitter 152 # cv splitter
131 else: 153 else:
132 train, test = _get_single_cv_split(params, array, 154 train, test = _get_single_cv_split(
133 infile_labels=infile_labels, 155 params, array, infile_labels=infile_labels, infile_groups=infile_groups
134 infile_groups=infile_groups) 156 )
135 157
136 print("Input shape: %s" % repr(array.shape)) 158 print("Input shape: %s" % repr(array.shape))
137 print("Train shape: %s" % repr(train.shape)) 159 print("Train shape: %s" % repr(train.shape))
138 print("Test shape: %s" % repr(test.shape)) 160 print("Test shape: %s" % repr(test.shape))
139 train.to_csv(outfile_train, sep='\t', header=input_header, index=False) 161 train.to_csv(outfile_train, sep="\t", header=input_header, index=False)
140 test.to_csv(outfile_test, sep='\t', header=input_header, index=False) 162 test.to_csv(outfile_test, sep="\t", header=input_header, index=False)
141 163
142 164
143 if __name__ == '__main__': 165 if __name__ == "__main__":
144 aparser = argparse.ArgumentParser() 166 aparser = argparse.ArgumentParser()
145 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) 167 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
146 aparser.add_argument("-X", "--infile_array", dest="infile_array") 168 aparser.add_argument("-X", "--infile_array", dest="infile_array")
147 aparser.add_argument("-y", "--infile_labels", dest="infile_labels") 169 aparser.add_argument("-y", "--infile_labels", dest="infile_labels")
148 aparser.add_argument("-g", "--infile_groups", dest="infile_groups") 170 aparser.add_argument("-g", "--infile_groups", dest="infile_groups")
149 aparser.add_argument("-o", "--outfile_train", dest="outfile_train") 171 aparser.add_argument("-o", "--outfile_train", dest="outfile_train")
150 aparser.add_argument("-t", "--outfile_test", dest="outfile_test") 172 aparser.add_argument("-t", "--outfile_test", dest="outfile_test")
151 args = aparser.parse_args() 173 args = aparser.parse_args()
152 174
153 main(args.inputs, args.infile_array, args.outfile_train, 175 main(
154 args.outfile_test, args.infile_labels, args.infile_groups) 176 args.inputs,
177 args.infile_array,
178 args.outfile_train,
179 args.outfile_test,
180 args.infile_labels,
181 args.infile_groups,
182 )