diff plot_ml_performance.py @ 3:e73eb091612b draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
author bgruening
date Wed, 07 Aug 2024 10:20:05 +0000
parents 2cfa4aabda3e
children
--- a/plot_ml_performance.py	Thu Jan 16 18:49:18 2020 +0000
+++ b/plot_ml_performance.py	Wed Aug 07 10:20:05 2024 +0000
@@ -1,9 +1,17 @@
 import argparse
+
+import matplotlib.pyplot as plt
 import pandas as pd
 import plotly
-import pickle
 import plotly.graph_objs as go
-from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, roc_curve, auc
+from galaxy_ml.model_persist import load_model_from_h5
+from galaxy_ml.utils import clean_params
+from sklearn.metrics import (
+    auc,
+    confusion_matrix,
+    precision_recall_fscore_support,
+    roc_curve,
+)
 from sklearn.preprocessing import label_binarize
 
 
@@ -13,61 +21,51 @@
     Args:
         infile_input: str, input tabular file with true labels
         infile_output: str, input tabular file with predicted labels
-        infile_trained_model: str, input trained model file (zip)
+        infile_trained_model: str, input trained model file (h5mlm)
     """
 
-    df_input = pd.read_csv(infile_input, sep='\t', parse_dates=True)
-    df_output = pd.read_csv(infile_output, sep='\t', parse_dates=True)
+    df_input = pd.read_csv(infile_input, sep="\t", parse_dates=True)
+    df_output = pd.read_csv(infile_output, sep="\t", parse_dates=True)
     true_labels = df_input.iloc[:, -1].copy()
     predicted_labels = df_output.iloc[:, -1].copy()
     axis_labels = sorted(set(true_labels))
     c_matrix = confusion_matrix(true_labels, predicted_labels)
-    data = [
-        go.Heatmap(
-            z=c_matrix,
-            x=axis_labels,
-            y=axis_labels,
-            colorscale='Portland',
-        )
-    ]
+    fig, ax = plt.subplots(figsize=(7, 7))
+    im = ax.imshow(c_matrix, cmap="viridis")
+    # add number of samples to each cell of confusion matrix plot
+    for i in range(len(c_matrix)):
+        for j in range(len(c_matrix)):
+            ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
+    ax.set_ylabel("True class labels")
+    ax.set_xlabel("Predicted class labels")
+    ax.set_title("Confusion Matrix between true and predicted class labels")
+    ax.set_xticks(range(len(axis_labels)), labels=axis_labels)
+    ax.set_yticks(range(len(axis_labels)), labels=axis_labels)
+    fig.colorbar(im, ax=ax)
+    fig.tight_layout()
+    fig.savefig("output_confusion.png", dpi=120)
 
-    layout = go.Layout(
-        title='Confusion Matrix between true and predicted class labels',
-        xaxis=dict(title='Predicted class labels'),
-        yaxis=dict(title='True class labels')
+    # plot precision, recall and f_score for each class label
+    precision, recall, f_score, _ = precision_recall_fscore_support(
+        true_labels, predicted_labels
     )
 
-    fig = go.Figure(data=data, layout=layout)
-    plotly.offline.plot(fig, filename="output_confusion.html", auto_open=False)
-
-    # plot precision, recall and f_score for each class label
-    precision, recall, f_score, _ = precision_recall_fscore_support(true_labels, predicted_labels)
-
     trace_precision = go.Scatter(
-        x=axis_labels,
-        y=precision,
-        mode='lines+markers',
-        name='Precision'
+        x=axis_labels, y=precision, mode="lines+markers", name="Precision"
     )
 
     trace_recall = go.Scatter(
-        x=axis_labels,
-        y=recall,
-        mode='lines+markers',
-        name='Recall'
+        x=axis_labels, y=recall, mode="lines+markers", name="Recall"
     )
 
     trace_fscore = go.Scatter(
-        x=axis_labels,
-        y=f_score,
-        mode='lines+markers',
-        name='F-score'
+        x=axis_labels, y=f_score, mode="lines+markers", name="F-score"
     )
 
     layout_prf = go.Layout(
-        title='Precision, recall and f-score of true and predicted class labels',
-        xaxis=dict(title='Class labels'),
-        yaxis=dict(title='Precision, recall and f-score')
+        title="Precision, recall and f-score of true and predicted class labels",
+        xaxis=dict(title="Class labels"),
+        yaxis=dict(title="Precision, recall and f-score"),
     )
 
     data_prf = [trace_precision, trace_recall, trace_fscore]
@@ -75,8 +73,8 @@
     plotly.offline.plot(fig_prf, filename="output_prf.html", auto_open=False)
 
     # plot roc and auc curves for different classes
-    with open(infile_trained_model, 'rb') as model_file:
-        model = pickle.load(model_file)
+    classifier_object = load_model_from_h5(infile_trained_model)
+    model = clean_params(classifier_object)
 
     # remove the last column (label column)
     test_data = df_input.iloc[:, :-1]
@@ -84,9 +82,9 @@
 
     try:
         # find the probability estimating method
-        if 'predict_proba' in model_items:
+        if "predict_proba" in model_items:
             y_score = model.predict_proba(test_data)
-        elif 'decision_function' in model_items:
+        elif "decision_function" in model_items:
             y_score = model.decision_function(test_data)
 
         true_labels_list = true_labels.tolist()
@@ -104,43 +102,44 @@
                 trace = go.Scatter(
                     x=fpr[i],
                     y=tpr[i],
-                    mode='lines+markers',
-                    name='ROC curve of class {0} (AUC = {1:0.2f})'.format(i, roc_auc[i])
+                    mode="lines+markers",
+                    name="ROC curve of class {0} (AUC = {1:0.2f})".format(
+                        i, roc_auc[i]
+                    ),
                 )
                 data_roc.append(trace)
         else:
             try:
                 y_score_binary = y_score[:, 1]
-            except:
+            except Exception:
                 y_score_binary = y_score
             fpr, tpr, _ = roc_curve(one_hot_labels, y_score_binary, pos_label=1)
             roc_auc = auc(fpr, tpr)
             trace = go.Scatter(
                 x=fpr,
                 y=tpr,
-                mode='lines+markers',
-                name='ROC curve (AUC = {0:0.2f})'.format(roc_auc)
+                mode="lines+markers",
+                name="ROC curve (AUC = {0:0.2f})".format(roc_auc),
             )
             data_roc.append(trace)
 
-        trace_diag = go.Scatter(
-            x=[0, 1],
-            y=[0, 1],
-            mode='lines',
-            name='Chance'
-        )
+        trace_diag = go.Scatter(x=[0, 1], y=[0, 1], mode="lines", name="Chance")
         data_roc.append(trace_diag)
         layout_roc = go.Layout(
-            title='Receiver operating characteristics (ROC) and area under curve (AUC)',
-            xaxis=dict(title='False positive rate'),
-            yaxis=dict(title='True positive rate')
+            title="Receiver operating characteristics (ROC) and area under curve (AUC)",
+            xaxis=dict(title="False positive rate"),
+            yaxis=dict(title="True positive rate"),
         )
 
         fig_roc = go.Figure(data=data_roc, layout=layout_roc)
         plotly.offline.plot(fig_roc, filename="output_roc.html", auto_open=False)
 
     except Exception as exp:
-        print("Plotting the ROC-AUC graph failed. This exception was raised: {}".format(exp))
+        print(
+            "Plotting the ROC-AUC graph failed. This exception was raised: {}".format(
+                exp
+            )
+        )
 
 
 if __name__ == "__main__":
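For reference, the `predict_proba` / `decision_function` fallback that the patched `try` block relies on can be exercised on a plain scikit-learn estimator. The following is a minimal standalone sketch, not part of the commit: `demo_scores` is an illustrative name, `hasattr` stands in for the script's `model_items` membership test, and a freshly fitted estimator stands in for the model that `load_model_from_h5` would return.

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC


def demo_scores(model, X):
    # Prefer class probabilities; fall back to decision-function margins,
    # mirroring the "find the probability estimating method" branch above.
    if hasattr(model, "predict_proba"):
        return model.predict_proba(X)
    return model.decision_function(X)


X, y = make_classification(n_samples=100, n_features=5, random_state=0)
print(demo_scores(LogisticRegression(max_iter=200).fit(X, y), X).shape)  # (100, 2) probabilities
print(demo_scores(LinearSVC().fit(X, y), X).shape)  # (100,) margins

In the binary branch the script then keeps only the positive-class column (`y_score[:, 1]`), so the 1-D margin array returned by `decision_function` is exactly the case the `except Exception` fallback handles before `roc_curve` is called.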