diff ml_visualization_ex.py @ 23:bc3b489825b2 draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 80417bf0158a9b596e485dd66408f738f405145a
author bgruening
date Mon, 02 Oct 2023 07:59:32 +0000
parents 006db575e1f3
children
line wrap: on
line diff
--- a/ml_visualization_ex.py	Thu Aug 11 07:41:31 2022 +0000
+++ b/ml_visualization_ex.py	Mon Oct 02 07:59:32 2023 +0000
@@ -9,13 +9,19 @@
 import pandas as pd
 import plotly
 import plotly.graph_objs as go
-from galaxy_ml.utils import load_model, read_columns, SafeEval
-from keras.models import model_from_json
-from keras.utils import plot_model
-from sklearn.feature_selection.base import SelectorMixin
-from sklearn.metrics import (auc, average_precision_score, confusion_matrix,
-                             precision_recall_curve, roc_curve)
+from galaxy_ml.model_persist import load_model_from_h5
+from galaxy_ml.utils import read_columns, SafeEval
+from sklearn.feature_selection._base import SelectorMixin
+from sklearn.metrics import (
+    auc,
+    average_precision_score,
+    confusion_matrix,
+    precision_recall_curve,
+    roc_curve,
+)
 from sklearn.pipeline import Pipeline
+from tensorflow.keras.models import model_from_json
+from tensorflow.keras.utils import plot_model
 
 safe_eval = SafeEval()
 
@@ -357,8 +363,7 @@
     plot_format = params["plotting_selection"]["plot_format"]
 
     if plot_type == "feature_importances":
-        with open(infile_estimator, "rb") as estimator_handler:
-            estimator = load_model(estimator_handler)
+        estimator = load_model_from_h5(infile_estimator)
 
         column_option = params["plotting_selection"]["column_selector_options"][
             "selected_column_selector_option"