annotate main.py @ 5:7808193b5626 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit cefdfdc13838de5108e13f54ecd69babb44009a1
author bgruening
date Wed, 26 Mar 2025 16:32:35 +0000
parents 3957cd124013
children
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
rev   line source
0
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
1 """
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
2 Tabular data prediction using TabPFN
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
3 """
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
4 import argparse
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
5 import time
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
6
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
7 import matplotlib.pyplot as plt
2
abe1c3ac9145 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
8 import numpy as np
0
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
9 import pandas as pd
2
abe1c3ac9145 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
10 from sklearn.metrics import (
abe1c3ac9145 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
11 average_precision_score,
abe1c3ac9145 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
12 precision_recall_curve,
abe1c3ac9145 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
13 r2_score,
4
3957cd124013 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
14 root_mean_squared_error,
2
abe1c3ac9145 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
15 )
4
3957cd124013 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
16 from sklearn.preprocessing import label_binarize
2
abe1c3ac9145 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
17 from tabpfn import TabPFNClassifier, TabPFNRegressor
0
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
18
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
19
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
def separate_features_labels(data):
    """Split a tab-separated table into feature columns and a label column.

    The last column of the table is taken as the target; every column before
    it is a feature.

    :param data: path (or buffer) of a tab-separated file, read with
        ``pd.read_csv(..., sep="\\t")`` and its default header handling.
    :returns: ``(features, labels)`` -- a DataFrame with all but the last
        column, and a Series holding the last column.
    """
    frame = pd.read_csv(data, sep="\t")
    target = frame.iloc[:, -1]
    return frame.iloc[:, :-1], target
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
25
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
26
4
3957cd124013 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
def classification_plot(y_true, y_scores):
    """Draw precision-recall curve(s) and save them to ``output_plot.png``.

    Binary problems get a single PR curve scored with column 1 of
    ``y_scores``. Multiclass problems get one PR curve per class plus a
    micro-averaged curve.

    :param y_true: true class labels of the evaluated samples.
    :param y_scores: class-probability matrix ``(n_samples, n_classes)`` as
        returned by ``predict_proba``. NOTE(review): columns are assumed to be
        ordered like the sorted unique labels of ``y_true`` (the classifier's
        ``classes_``) -- confirm when test labels may miss a training class.
    """
    classes = np.unique(y_true)
    y_scores = np.asarray(y_scores)
    fig = plt.figure(figsize=(8, 6))
    try:
        if len(classes) == 2:
            # Without an explicit pos_label, sklearn only accepts {0, 1} or
            # {-1, 1} labels; column 1 of the probabilities belongs to the
            # second sorted class, so that class is the positive label.
            pos_label = classes[1]
            positive_scores = y_scores[:, 1]
            precision, recall, _ = precision_recall_curve(
                y_true, positive_scores, pos_label=pos_label
            )
            average_precision = average_precision_score(
                y_true, positive_scores, pos_label=pos_label
            )
            plt.plot(
                recall,
                precision,
                label=f"Precision-Recall Curve (AP={average_precision:.2f})",
            )
            plt.title("Precision-Recall Curve (binary classification)")
        else:
            y_true_bin = label_binarize(y_true, classes=classes)
            # One-vs-rest PR curve per class, labelled with the real class value.
            for i, cls in enumerate(classes):
                precision, recall, _ = precision_recall_curve(
                    y_true_bin[:, i], y_scores[:, i]
                )
                ap_score = average_precision_score(y_true_bin[:, i], y_scores[:, i])
                plt.plot(recall, precision, label=f"Class {cls} (AP = {ap_score:.2f})")
            # Micro-average: pool every (sample, class) decision into one curve.
            precision, recall, _ = precision_recall_curve(
                y_true_bin.ravel(), y_scores.ravel()
            )
            plt.plot(
                recall, precision, linestyle="--", color="black", label="Micro-average"
            )
            plt.title("Precision-Recall Curve (Multiclass Classification)")
        plt.xlabel("Recall")
        plt.ylabel("Precision")
        plt.legend(loc="lower left")
        plt.grid(True)
        plt.savefig("output_plot.png")
    finally:
        # Release the pyplot figure so repeated calls do not accumulate figures.
        plt.close(fig)
abe1c3ac9145 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
68
abe1c3ac9145 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
69
abe1c3ac9145 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
def regression_plot(xval, yval, title, xlabel, ylabel):
    """Scatter ``xval`` against ``yval`` with a ``y = x`` reference line.

    The plot is saved to ``output_plot.png``.

    :param xval: values for the x axis (the caller passes true targets).
    :param yval: values for the y axis (the caller passes predictions).
    :param title: plot title.
    :param xlabel: x-axis label.
    :param ylabel: y-axis label.
    """
    x = np.asarray(xval)
    y = np.asarray(yval)
    fig = plt.figure(figsize=(8, 6))
    try:
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(title)
        plt.grid(True)
        plt.scatter(x, y, alpha=0.8)
        if x.size and y.size:
            # The identity line must span the data's value range on both axes,
            # not the sample indices 0..n-1.
            lo = min(x.min(), y.min())
            hi = max(x.max(), y.max())
            plt.plot([lo, hi], [lo, hi], color="red", linestyle="--", label="y = x")
        # Build the legend only after the labelled artist exists; calling it
        # earlier produces an empty legend.
        plt.legend(loc="lower left")
        plt.savefig("output_plot.png")
    finally:
        # Release the pyplot figure so repeated calls do not accumulate figures.
        plt.close(fig)
abe1c3ac9145 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
81
abe1c3ac9145 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
82
0
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
def train_evaluate(args):
    """
    Train TabPFN and predict.

    Fits a TabPFN classifier or regressor on the training table and predicts
    the test table. Writes ``output_predicted_data`` (test features plus a
    ``predicted_labels`` column). When the test table carries labels, it also
    writes ``output_plot.png`` through ``classification_plot`` or
    ``regression_plot``. Prints the elapsed time of training and prediction.

    :param args: mapping with keys ``train_data`` and ``test_data`` (TSV
        paths), ``testhaslabels`` (the string ``"true"`` when the test table's
        last column is the label) and ``selected_task``
        (``"Classification"``; any other value runs regression).
    """
    # prepare train data
    tr_features, tr_labels = separate_features_labels(args["train_data"])
    # prepare test data
    if args["testhaslabels"] == "true":
        te_features, te_labels = separate_features_labels(args["test_data"])
    else:
        te_features = pd.read_csv(args["test_data"], sep="\t")
        te_labels = []
    s_time = time.time()
    if args["selected_task"] == "Classification":
        classifier = TabPFNClassifier(random_state=42)
        classifier.fit(tr_features, tr_labels)
        y_eval = classifier.predict(te_features)
        if len(te_labels) > 0:
            # Probabilities are only needed for the PR plot; skip this extra
            # inference pass when there are no labels to evaluate against.
            pred_probas_test = classifier.predict_proba(te_features)
            classification_plot(te_labels, pred_probas_test)
    else:
        regressor = TabPFNRegressor(random_state=42)
        regressor.fit(tr_features, tr_labels)
        y_eval = regressor.predict(te_features)
        if len(te_labels) > 0:
            score = root_mean_squared_error(te_labels, y_eval)
            r2_metric_score = r2_score(te_labels, y_eval)
            regression_plot(
                te_labels,
                y_eval,
                f"Scatter plot: True vs predicted values. RMSE={score:.2f}, R2={r2_metric_score:.2f}",
                "True values",
                "Predicted values",
            )
    # Shared output for both tasks: test features with the predictions appended.
    te_features["predicted_labels"] = y_eval
    te_features.to_csv("output_predicted_data", sep="\t", index=False)
    e_time = time.time()
    print(
        f"Time taken by TabPFN for training and prediction: {e_time - s_time} seconds"
    )
0
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
129
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
130
2a0c6d2090f4 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # (short flag, long flag, help text) for every required command-line option,
    # registered in this order.
    option_specs = (
        ("-trdata", "--train_data", "Train data"),
        ("-tedata", "--test_data", "Test data"),
        ("-testhaslabels", "--testhaslabels", "if test data contain labels"),
        ("-selectedtask", "--selected_task", "Type of machine learning task"),
    )
    for short_flag, long_flag, help_text in option_specs:
        parser.add_argument(short_flag, long_flag, required=True, help=help_text)
    # get argument values as a plain dict keyed by the long option names
    cli_args = vars(parser.parse_args())
    train_evaluate(cli_args)