Mercurial > repos > bgruening > create_tool_recommendation_model
comparison utils.py @ 3:98bc44d17561 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 65d36f271296a38deeceb0d0e8d471b2898ee8f4"
| author | bgruening |
|---|---|
| date | Tue, 07 Jul 2020 07:24:21 +0000 |
| parents | 50753817983a |
| children | f0da532be419 |
Diff legend: equal · deleted · inserted · replaced
| 2:50753817983a | 3:98bc44d17561 |
|---|---|
| 1 import os | |
| 2 import numpy as np | 1 import numpy as np |
| 3 import json | 2 import json |
| 4 import h5py | 3 import h5py |
| 5 import random | 4 import random |
| 5 from numpy.random import choice | |
| 6 | 6 |
| 7 from keras import backend as K | 7 from keras import backend as K |
| 8 | 8 |
| 9 | 9 |
| 10 def read_file(file_path): | 10 def read_file(file_path): |
| 52 Create a weighted loss function. Penalise the misclassification | 52 Create a weighted loss function. Penalise the misclassification |
| 53 of classes more with the higher usage | 53 of classes more with the higher usage |
| 54 """ | 54 """ |
| 55 weight_values = list(class_weights.values()) | 55 weight_values = list(class_weights.values()) |
| 56 weight_values.extend(weight_values) | 56 weight_values.extend(weight_values) |
| 57 | |
| 58 def weighted_binary_crossentropy(y_true, y_pred): | 57 def weighted_binary_crossentropy(y_true, y_pred): |
| 59 # add another dimension to compute dot product | 58 # add another dimension to compute dot product |
| 60 expanded_weights = K.expand_dims(weight_values, axis=-1) | 59 expanded_weights = K.expand_dims(weight_values, axis=-1) |
| 61 return K.dot(K.binary_crossentropy(y_true, y_pred), expanded_weights) | 60 return K.dot(K.binary_crossentropy(y_true, y_pred), expanded_weights) |
| 62 return weighted_binary_crossentropy | 61 return weighted_binary_crossentropy |
| 63 | 62 |
| 64 | 63 |
| 65 def balanced_sample_generator(train_data, train_labels, batch_size, l_tool_tr_samples): | 64 def balanced_sample_generator(train_data, train_labels, batch_size, l_tool_tr_samples, reverse_dictionary): |
| 66 while True: | 65 while True: |
| 67 dimension = train_data.shape[1] | 66 dimension = train_data.shape[1] |
| 68 n_classes = train_labels.shape[1] | 67 n_classes = train_labels.shape[1] |
| 69 tool_ids = list(l_tool_tr_samples.keys()) | 68 tool_ids = list(l_tool_tr_samples.keys()) |
| 69 random.shuffle(tool_ids) | |
| 70 generator_batch_data = np.zeros([batch_size, dimension]) | 70 generator_batch_data = np.zeros([batch_size, dimension]) |
| 71 generator_batch_labels = np.zeros([batch_size, n_classes]) | 71 generator_batch_labels = np.zeros([batch_size, n_classes]) |
| 72 generated_tool_ids = choice(tool_ids, batch_size) | |
| 72 for i in range(batch_size): | 73 for i in range(batch_size): |
| 73 random_toolid_index = random.sample(range(0, len(tool_ids)), 1)[0] | 74 random_toolid = generated_tool_ids[i] |
| 74 random_toolid = tool_ids[random_toolid_index] | |
| 75 sample_indices = l_tool_tr_samples[str(random_toolid)] | 75 sample_indices = l_tool_tr_samples[str(random_toolid)] |
| 76 random_index = random.sample(range(0, len(sample_indices)), 1)[0] | 76 random_index = random.sample(range(0, len(sample_indices)), 1)[0] |
| 77 random_tr_index = sample_indices[random_index] | 77 random_tr_index = sample_indices[random_index] |
| 78 generator_batch_data[i] = train_data[random_tr_index] | 78 generator_batch_data[i] = train_data[random_tr_index] |
| 79 generator_batch_labels[i] = train_labels[random_tr_index] | 79 generator_batch_labels[i] = train_labels[random_tr_index] |
| 127 # compute scores for published recommendations | 127 # compute scores for published recommendations |
| 128 if standard_topk_prediction_pos in reverse_data_dictionary: | 128 if standard_topk_prediction_pos in reverse_data_dictionary: |
| 129 pred_t_name = reverse_data_dictionary[int(standard_topk_prediction_pos)] | 129 pred_t_name = reverse_data_dictionary[int(standard_topk_prediction_pos)] |
| 130 if last_tool_name in standard_conn: | 130 if last_tool_name in standard_conn: |
| 131 pub_tools = standard_conn[last_tool_name] | 131 pub_tools = standard_conn[last_tool_name] |
| 132 if pred_t_name in pub_tools: | 132 if pred_t_name in pub_tools: |
| 133 pub_precision = 1.0 | 133 pub_precision = 1.0 |
| 134 if last_tool_id in lowest_tool_ids: | 134 # count precision only when there is actually true published tools |
| 135 lowest_pub_prec = 1.0 | 135 if last_tool_id in lowest_tool_ids: |
| 136 if standard_topk_prediction_pos in usage_scores: | 136 lowest_pub_prec = 1.0 |
| 137 usage_wt_score.append(np.log(usage_scores[standard_topk_prediction_pos] + 1.0)) | 137 else: |
| 138 lowest_pub_prec = np.nan | |
| 139 if standard_topk_prediction_pos in usage_scores: | |
| 140 usage_wt_score.append(np.log(usage_scores[standard_topk_prediction_pos] + 1.0)) | |
| 141 else: | |
| 142 # count precision only when there is actually true published tools | |
| 143 # else set to np.nan. Set to 0 only when there is wrong prediction | |
| 144 pub_precision = np.nan | |
| 145 lowest_pub_prec = np.nan | |
| 138 # compute scores for normal recommendations | 146 # compute scores for normal recommendations |
| 139 if normal_topk_prediction_pos in reverse_data_dictionary: | 147 if normal_topk_prediction_pos in reverse_data_dictionary: |
| 140 pred_t_name = reverse_data_dictionary[int(normal_topk_prediction_pos)] | 148 pred_t_name = reverse_data_dictionary[int(normal_topk_prediction_pos)] |
| 141 if pred_t_name in actual_next_tool_names: | 149 if pred_t_name in actual_next_tool_names: |
| 142 if normal_topk_prediction_pos in usage_scores: | 150 if normal_topk_prediction_pos in usage_scores: |
| 143 usage_wt_score.append(np.log(usage_scores[normal_topk_prediction_pos] + 1.0)) | 151 usage_wt_score.append(np.log(usage_scores[normal_topk_prediction_pos] + 1.0)) |
| 144 top_precision = 1.0 | 152 top_precision = 1.0 |
| 145 if last_tool_id in lowest_tool_ids: | 153 if last_tool_id in lowest_tool_ids: |
| 146 lowest_norm_prec = 1.0 | 154 lowest_norm_prec = 1.0 |
| 155 else: | |
| 156 lowest_norm_prec = np.nan | |
| 147 if len(usage_wt_score) > 0: | 157 if len(usage_wt_score) > 0: |
| 148 mean_usage = np.mean(usage_wt_score) | 158 mean_usage = np.mean(usage_wt_score) |
| 149 return mean_usage, top_precision, pub_precision, lowest_pub_prec, lowest_norm_prec | 159 return mean_usage, top_precision, pub_precision, lowest_pub_prec, lowest_norm_prec |
| 150 | 160 |
| 151 | 161 |
| 166 precision = np.zeros([len(y), len(topk_list)]) | 176 precision = np.zeros([len(y), len(topk_list)]) |
| 167 usage_weights = np.zeros([len(y), len(topk_list)]) | 177 usage_weights = np.zeros([len(y), len(topk_list)]) |
| 168 epo_pub_prec = np.zeros([len(y), len(topk_list)]) | 178 epo_pub_prec = np.zeros([len(y), len(topk_list)]) |
| 169 epo_lowest_tools_pub_prec = list() | 179 epo_lowest_tools_pub_prec = list() |
| 170 epo_lowest_tools_norm_prec = list() | 180 epo_lowest_tools_norm_prec = list() |
| 171 | 181 lowest_counter = 0 |
| 172 # loop over all the test samples and find prediction precision | 182 # loop over all the test samples and find prediction precision |
| 173 for i in range(size): | 183 for i in range(size): |
| 174 lowest_pub_topk = list() | 184 lowest_pub_topk = list() |
| 175 lowest_norm_topk = list() | 185 lowest_norm_topk = list() |
| 176 actual_classes_pos = np.where(y[i] > 0)[0] | 186 actual_classes_pos = np.where(y[i] > 0)[0] |
| 179 for index, abs_topk in enumerate(topk_list): | 189 for index, abs_topk in enumerate(topk_list): |
| 180 usg_wt_score, absolute_precision, pub_prec, lowest_p_prec, lowest_n_prec = compute_precision(model, test_sample, y, reverse_data_dictionary, usage_scores, actual_classes_pos, abs_topk, standard_conn, last_tool_id, lowest_tool_ids) | 190 usg_wt_score, absolute_precision, pub_prec, lowest_p_prec, lowest_n_prec = compute_precision(model, test_sample, y, reverse_data_dictionary, usage_scores, actual_classes_pos, abs_topk, standard_conn, last_tool_id, lowest_tool_ids) |
| 181 precision[i][index] = absolute_precision | 191 precision[i][index] = absolute_precision |
| 182 usage_weights[i][index] = usg_wt_score | 192 usage_weights[i][index] = usg_wt_score |
| 183 epo_pub_prec[i][index] = pub_prec | 193 epo_pub_prec[i][index] = pub_prec |
| 184 if last_tool_id in lowest_tool_ids: | 194 lowest_pub_topk.append(lowest_p_prec) |
| 185 lowest_pub_topk.append(lowest_p_prec) | 195 lowest_norm_topk.append(lowest_n_prec) |
| 186 lowest_norm_topk.append(lowest_n_prec) | 196 epo_lowest_tools_pub_prec.append(lowest_pub_topk) |
| 197 epo_lowest_tools_norm_prec.append(lowest_norm_topk) | |
| 187 if last_tool_id in lowest_tool_ids: | 198 if last_tool_id in lowest_tool_ids: |
| 188 epo_lowest_tools_pub_prec.append(lowest_pub_topk) | 199 lowest_counter += 1 |
| 189 epo_lowest_tools_norm_prec.append(lowest_norm_topk) | |
| 190 mean_precision = np.mean(precision, axis=0) | 200 mean_precision = np.mean(precision, axis=0) |
| 191 mean_usage = np.mean(usage_weights, axis=0) | 201 mean_usage = np.mean(usage_weights, axis=0) |
| 192 mean_pub_prec = np.mean(epo_pub_prec, axis=0) | 202 mean_pub_prec = np.nanmean(epo_pub_prec, axis=0) |
| 193 mean_lowest_pub_prec = np.mean(epo_lowest_tools_pub_prec, axis=0) | 203 mean_lowest_pub_prec = np.nanmean(epo_lowest_tools_pub_prec, axis=0) |
| 194 mean_lowest_norm_prec = np.mean(epo_lowest_tools_norm_prec, axis=0) | 204 mean_lowest_norm_prec = np.nanmean(epo_lowest_tools_norm_prec, axis=0) |
| 195 return mean_usage, mean_precision, mean_pub_prec, mean_lowest_pub_prec, mean_lowest_norm_prec, len(epo_lowest_tools_pub_prec) | 205 return mean_usage, mean_precision, mean_pub_prec, mean_lowest_pub_prec, mean_lowest_norm_prec, lowest_counter |
| 196 | 206 |
| 197 | 207 |
| 198 def save_model(results, data_dictionary, compatible_next_tools, trained_model_path, class_weights, standard_connections): | 208 def save_model(results, data_dictionary, compatible_next_tools, trained_model_path, class_weights, standard_connections): |
| 199 # save files | 209 # save files |
| 200 trained_model = results["model"] | 210 trained_model = results["model"] |
