diff plot_ml_performance.py @ 3:e73eb091612b draft default tip
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
| author   | bgruening |
|----------|-----------|
| date     | Wed, 07 Aug 2024 10:20:05 +0000 |
| parents  | 2cfa4aabda3e |
| children | |
--- a/plot_ml_performance.py	Thu Jan 16 18:49:18 2020 +0000
+++ b/plot_ml_performance.py	Wed Aug 07 10:20:05 2024 +0000
@@ -1,9 +1,17 @@
 import argparse
+
+import matplotlib.pyplot as plt
 import pandas as pd
 import plotly
-import pickle
 import plotly.graph_objs as go
-from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, roc_curve, auc
+from galaxy_ml.model_persist import load_model_from_h5
+from galaxy_ml.utils import clean_params
+from sklearn.metrics import (
+    auc,
+    confusion_matrix,
+    precision_recall_fscore_support,
+    roc_curve,
+)
 from sklearn.preprocessing import label_binarize
@@ -13,61 +21,51 @@
     Args:
         infile_input: str, input tabular file with true labels
         infile_output: str, input tabular file with predicted labels
-        infile_trained_model: str, input trained model file (zip)
+        infile_trained_model: str, input trained model file (h5mlm)
     """
-    df_input = pd.read_csv(infile_input, sep='\t', parse_dates=True)
-    df_output = pd.read_csv(infile_output, sep='\t', parse_dates=True)
+    df_input = pd.read_csv(infile_input, sep="\t", parse_dates=True)
+    df_output = pd.read_csv(infile_output, sep="\t", parse_dates=True)
     true_labels = df_input.iloc[:, -1].copy()
     predicted_labels = df_output.iloc[:, -1].copy()
     axis_labels = list(set(true_labels))
     c_matrix = confusion_matrix(true_labels, predicted_labels)
-    data = [
-        go.Heatmap(
-            z=c_matrix,
-            x=axis_labels,
-            y=axis_labels,
-            colorscale='Portland',
-        )
-    ]
+    fig, ax = plt.subplots(figsize=(7, 7))
+    im = plt.imshow(c_matrix, cmap="viridis")
+    # add number of samples to each cell of confusion matrix plot
+    for i in range(len(c_matrix)):
+        for j in range(len(c_matrix)):
+            ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
+    ax.set_ylabel("True class labels")
+    ax.set_xlabel("Predicted class labels")
+    ax.set_title("Confusion Matrix between true and predicted class labels")
+    ax.set_xticks(axis_labels)
+    ax.set_yticks(axis_labels)
+    fig.colorbar(im, ax=ax)
+    fig.tight_layout()
+    plt.savefig("output_confusion.png", dpi=120)
 
-    layout = go.Layout(
-        title='Confusion Matrix between true and predicted class labels',
-        xaxis=dict(title='Predicted class labels'),
-        yaxis=dict(title='True class labels')
+    # plot precision, recall and f_score for each class label
+    precision, recall, f_score, _ = precision_recall_fscore_support(
+        true_labels, predicted_labels
     )
-    fig = go.Figure(data=data, layout=layout)
-    plotly.offline.plot(fig, filename="output_confusion.html", auto_open=False)
-
-    # plot precision, recall and f_score for each class label
-    precision, recall, f_score, _ = precision_recall_fscore_support(true_labels, predicted_labels)
-
     trace_precision = go.Scatter(
-        x=axis_labels,
-        y=precision,
-        mode='lines+markers',
-        name='Precision'
+        x=axis_labels, y=precision, mode="lines+markers", name="Precision"
     )
     trace_recall = go.Scatter(
-        x=axis_labels,
-        y=recall,
-        mode='lines+markers',
-        name='Recall'
+        x=axis_labels, y=recall, mode="lines+markers", name="Recall"
     )
     trace_fscore = go.Scatter(
-        x=axis_labels,
-        y=f_score,
-        mode='lines+markers',
-        name='F-score'
+        x=axis_labels, y=f_score, mode="lines+markers", name="F-score"
    )
     layout_prf = go.Layout(
-        title='Precision, recall and f-score of true and predicted class labels',
-        xaxis=dict(title='Class labels'),
-        yaxis=dict(title='Precision, recall and f-score')
+        title="Precision, recall and f-score of true and predicted class labels",
+        xaxis=dict(title="Class labels"),
+        yaxis=dict(title="Precision, recall and f-score"),
     )
     data_prf = [trace_precision, trace_recall, trace_fscore]
@@ -75,8 +73,8 @@
     plotly.offline.plot(fig_prf, filename="output_prf.html", auto_open=False)
 
     # plot roc and auc curves for different classes
-    with open(infile_trained_model, 'rb') as model_file:
-        model = pickle.load(model_file)
+    classifier_object = load_model_from_h5(infile_trained_model)
+    model = clean_params(classifier_object)
 
     # remove the last column (label column)
     test_data = df_input.iloc[:, :-1]
@@ -84,9 +82,9 @@
 
     try:
         # find the probability estimating method
-        if 'predict_proba' in model_items:
+        if "predict_proba" in model_items:
             y_score = model.predict_proba(test_data)
-        elif 'decision_function' in model_items:
+        elif "decision_function" in model_items:
             y_score = model.decision_function(test_data)
 
         true_labels_list = true_labels.tolist()
@@ -104,43 +102,44 @@
                 trace = go.Scatter(
                     x=fpr[i],
                     y=tpr[i],
-                    mode='lines+markers',
-                    name='ROC curve of class {0} (AUC = {1:0.2f})'.format(i, roc_auc[i])
+                    mode="lines+markers",
+                    name="ROC curve of class {0} (AUC = {1:0.2f})".format(
+                        i, roc_auc[i]
+                    ),
                 )
                 data_roc.append(trace)
         else:
             try:
                 y_score_binary = y_score[:, 1]
-            except:
+            except Exception:
                 y_score_binary = y_score
             fpr, tpr, _ = roc_curve(one_hot_labels, y_score_binary, pos_label=1)
             roc_auc = auc(fpr, tpr)
             trace = go.Scatter(
                 x=fpr,
                 y=tpr,
-                mode='lines+markers',
-                name='ROC curve (AUC = {0:0.2f})'.format(roc_auc)
+                mode="lines+markers",
+                name="ROC curve (AUC = {0:0.2f})".format(roc_auc),
             )
             data_roc.append(trace)
-        trace_diag = go.Scatter(
-            x=[0, 1],
-            y=[0, 1],
-            mode='lines',
-            name='Chance'
-        )
+        trace_diag = go.Scatter(x=[0, 1], y=[0, 1], mode="lines", name="Chance")
         data_roc.append(trace_diag)
         layout_roc = go.Layout(
-            title='Receiver operating characteristics (ROC) and area under curve (AUC)',
-            xaxis=dict(title='False positive rate'),
-            yaxis=dict(title='True positive rate')
+            title="Receiver operating characteristics (ROC) and area under curve (AUC)",
+            xaxis=dict(title="False positive rate"),
+            yaxis=dict(title="True positive rate"),
         )
         fig_roc = go.Figure(data=data_roc, layout=layout_roc)
         plotly.offline.plot(fig_roc, filename="output_roc.html", auto_open=False)
 
     except Exception as exp:
-        print("Plotting the ROC-AUC graph failed. This exception was raised: {}".format(exp))
+        print(
+            "Plotting the ROC-AUC graph failed. This exception was raised: {}".format(
+                exp
+            )
        )
 
 
 if __name__ == "__main__":
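The main behavioural change in this revision is that the confusion matrix is no longer rendered as a plotly heatmap written to `output_confusion.html`; it is drawn with matplotlib and saved as `output_confusion.png`, while the precision/recall/F-score and ROC plots remain plotly HTML outputs. The sketch below is a minimal, standalone illustration of that matplotlib approach (imshow heatmap plus per-cell counts), not the tool itself: the label arrays are made-up example data and the explicit `Agg` backend is an assumption for headless rendering.

```python
# Minimal sketch of the confusion-matrix rendering used by the updated script.
# Assumptions: example labels below are invented; the Agg backend is chosen
# here only so the script runs without a display.
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

true_labels = [0, 0, 1, 1, 2, 2, 2, 1]
predicted_labels = [0, 1, 1, 1, 2, 0, 2, 1]

c_matrix = confusion_matrix(true_labels, predicted_labels)

fig, ax = plt.subplots(figsize=(7, 7))
im = ax.imshow(c_matrix, cmap="viridis")
# write the sample count into every cell so the heatmap is readable on its own
for i in range(len(c_matrix)):
    for j in range(len(c_matrix)):
        ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
ax.set_xlabel("Predicted class labels")
ax.set_ylabel("True class labels")
ax.set_title("Confusion Matrix between true and predicted class labels")
fig.colorbar(im, ax=ax)
fig.tight_layout()
fig.savefig("output_confusion.png", dpi=120)
```

The other notable change is model loading: the trained estimator now comes from an `h5mlm` file via `galaxy_ml.model_persist.load_model_from_h5` followed by `galaxy_ml.utils.clean_params`, replacing the previous `pickle.load` of a zipped model.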