Mercurial > repos > bgruening > cleanlab
annotate cleanlab_issue_handler.py @ 0:ecc18228c32e draft default tip
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
| author | bgruening |
|---|---|
| date | Wed, 28 May 2025 11:30:39 +0000 |
| parents | |
| children |
| rev | line source |
|---|---|
|
0
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
1 import argparse |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
2 |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
3 import numpy as np |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
4 import pandas as pd |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
5 from cleanlab.datalab.datalab import Datalab |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
6 from cleanlab.regression.rank import get_label_quality_scores |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
7 from sklearn.linear_model import LinearRegression |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
8 from sklearn.model_selection import cross_val_predict, KFold, StratifiedKFold |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
9 from xgboost import XGBClassifier |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
10 |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
11 # ------------------- |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
12 # Issue Handler |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
13 # ------------------- |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
14 |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
15 |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
class IssueHandler:
    """Detect and clean label/data issues in a tabular dataset using cleanlab.

    Typical use: construct the handler, call :meth:`report_issues` to compute
    per-row issue flags, then call :meth:`clean_selected_issues` to obtain a
    cleaned copy of the dataset.

    Args:
        dataset: pandas DataFrame holding the feature columns and the target.
        task: Either ``'classification'`` or ``'regression'``.
        target_column: Name of the column in ``dataset`` holding the labels.
        n_splits: Number of cross-validation folds used for out-of-sample
            predictions.
        quality_threshold: Regression only; rows whose label-quality score is
            below this value are flagged as low quality.
    """

    # Issue-flag columns looked up in ``self.issues`` for classification, in
    # the same order as the boolean switches of ``clean_selected_issues``.
    _CLASSIFICATION_FLAGS = (
        'is_label_issue',
        'is_outlier_issue',
        'is_near_duplicate_issue',
        'is_non_iid_issue',
    )

    def __init__(self, dataset, task, target_column, n_splits=3, quality_threshold=0.2):
        self.dataset = dataset
        self.task = task
        self.target_column = target_column
        self.n_splits = n_splits
        self.quality_threshold = quality_threshold
        # Populated by report_issues(); None means "not analysed yet".
        self.issues = None
        self.features = self.dataset.drop(target_column, axis=1).columns.tolist()
        self.issue_summary = None
        # Out-of-sample class probabilities (classification only).
        self.pred_probs = None

    @staticmethod
    def _format_summary_item(key, value):
        """Render one summary entry as ``'Key name: value'`` (floats to 4 decimals)."""
        label = key.replace('_', ' ').capitalize()
        if isinstance(value, float):
            return f"{label}: {value:.4f}"
        return f"{label}: {value}"

    def _flag_as_mask(self, column):
        """Return ``self.issues[column]`` as a positional boolean NumPy array.

        Missing values count as "no issue". A plain array (rather than a
        Series) is returned on purpose: the issues table and the dataset may
        carry different indexes, so rows must be matched by position.
        """
        return self.issues[column].fillna(False).to_numpy(dtype=bool)

    def report_issues(self):
        """Compute issue flags and a summary for the dataset.

        For classification, out-of-sample class probabilities are obtained by
        cross-validating an ``XGBClassifier`` and are handed to cleanlab's
        ``Datalab``. For regression, out-of-sample predictions of a
        ``LinearRegression`` are scored with cleanlab's residual-based label
        quality score and thresholded with ``quality_threshold``.

        The summary is also printed to stdout.

        Returns:
            Tuple ``(dataset_copy, issues_copy, issue_summary)``.

        Raises:
            ValueError: If ``task`` is neither ``'classification'`` nor
                ``'regression'``.
        """
        # Fail early with a clear message instead of crashing later on
        # ``self.issues.copy()`` when no branch below has run.
        if self.task not in ('classification', 'regression'):
            raise ValueError(
                f"Unsupported task: {self.task!r} (expected 'classification' or 'regression')."
            )

        X = self.dataset.drop(self.target_column, axis=1)
        y = self.dataset[self.target_column]

        # Ensure compatibility with Galaxy
        X = X.to_numpy() if hasattr(X, 'to_numpy') else np.asarray(X)
        y = y.to_numpy() if hasattr(y, 'to_numpy') else np.asarray(y)

        if self.task == 'classification':
            model = XGBClassifier(use_label_encoder=False, eval_metric='logloss', random_state=42)
            cv = StratifiedKFold(n_splits=self.n_splits, shuffle=True, random_state=42)
            self.pred_probs = cross_val_predict(model, X, y, cv=cv, method='predict_proba')

            lab = Datalab(self.dataset, label_name=self.target_column)
            lab.find_issues(pred_probs=self.pred_probs)
            self.issues = lab.get_issues()
            self.issue_summary = lab.get_issue_summary()
            print(self.issue_summary)

        else:  # regression
            model = LinearRegression()
            cv = KFold(n_splits=self.n_splits, shuffle=True, random_state=42)
            pred_y = cross_val_predict(model, X, y, cv=cv, method='predict')
            scores = get_label_quality_scores(y, pred_y, method='residual')
            is_low_quality = scores < self.quality_threshold
            self.issues = pd.DataFrame({
                'label_quality': scores,
                'is_low_quality': is_low_quality
            })
            self.issue_summary = {
                'quality_threshold': self.quality_threshold,
                'num_low_quality': int(is_low_quality.sum()),
                'mean_label_quality': float(np.mean(scores)),
                'median_label_quality': float(np.median(scores)),
                'min_label_quality': float(np.min(scores)),
                'max_label_quality': float(np.max(scores)),
            }
            print("Regression Issue Summary:")
            for k, v in self.issue_summary.items():
                print(self._format_summary_item(k, v))

        return self.dataset.copy(), self.issues.copy(), self.issue_summary

    def clean_selected_issues(self, method='remove', label_issues=True, outliers=True, near_duplicates=True, non_iid=True):
        """Return a cleaned copy of the dataset based on the reported issues.

        Args:
            method: ``'remove'`` drops every flagged row; ``'replace'``
                (classification only) overwrites the label of rows flagged as
                label issues with the most probable class.
            label_issues: Include rows flagged ``is_label_issue``.
            outliers: Include rows flagged ``is_outlier_issue``.
            near_duplicates: Include rows flagged ``is_near_duplicate_issue``.
            non_iid: Include rows flagged ``is_non_iid_issue``.

        The four boolean switches only affect classification; for regression
        the ``is_low_quality`` flag alone decides which rows are removed.

        Returns:
            A new DataFrame; ``self.dataset`` is left untouched.

        Raises:
            RuntimeError: If ``report_issues()`` has not been run.
            NotImplementedError: For ``method='replace'`` on regression.
            ValueError: For any other unsupported method/task combination.
        """
        if self.issues is None:
            raise RuntimeError("Must run report_issues() before cleaning.")

        if method == 'remove':
            if self.task == 'regression':
                flagged = self._flag_as_mask('is_low_quality')
            else:
                flagged = np.zeros(len(self.dataset), dtype=bool)
                switches = (label_issues, outliers, near_duplicates, non_iid)
                for issue_type, use_flag in zip(self._CLASSIFICATION_FLAGS, switches):
                    if use_flag and issue_type in self.issues.columns:
                        flagged |= self._flag_as_mask(issue_type)
            # Positional boolean mask: works whatever index the dataset has.
            return self.dataset[~flagged].copy()

        if method == 'replace' and self.task == 'classification':
            # Columns of pred_probs follow the sorted class labels, so map the
            # argmax column back to the actual label instead of writing the
            # bare column number into the target column.
            classes = np.unique(self.dataset[self.target_column].to_numpy())
            most_likely = classes[np.argmax(self.pred_probs, axis=1)]
            fixed = self.dataset.copy()
            to_fix = self._flag_as_mask('is_label_issue') & bool(label_issues)
            fixed.loc[to_fix, self.target_column] = most_likely[to_fix]
            return fixed

        if method == 'replace' and self.task == 'regression':
            raise NotImplementedError("Replace method not implemented for regression label correction.")

        raise ValueError("Invalid method or unsupported combination.")
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
103 |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
104 # ------------------- |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
105 # Main CLI Entry |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
106 # ------------------- |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
107 |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
108 |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
109 def main(): |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
110 parser = argparse.ArgumentParser(description="Cleanlab Issue Handler CLI") |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
111 parser.add_argument("--input_file", nargs=2, required=True, metavar=('FILE', 'EXT'), help="Input file path and its extension") |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
112 parser.add_argument("--task", required=True, choices=["classification", "regression"], help="Type of ML task") |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
113 parser.add_argument("--target_column", default="target", help="Name of the target column") |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
114 parser.add_argument("--method", default="remove", choices=["remove", "replace"], help="Cleaning method") |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
115 parser.add_argument("--summary", action="store_true", help="Print and save issue summary only, no cleaning") |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
116 parser.add_argument("--no-label-issues", action="store_true", help="Exclude label issues from cleaning") |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
117 parser.add_argument("--no-outliers", action="store_true", help="Exclude outlier issues from cleaning") |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
118 parser.add_argument("--no-near-duplicates", action="store_true", help="Exclude near-duplicate issues from cleaning") |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
119 parser.add_argument("--no-non-iid", action="store_true", help="Exclude non-i.i.d. issues from cleaning") |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
120 parser.add_argument('--quality-threshold', type=float, default=0.2, help='Threshold for low-quality labels (regression only)') |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
121 |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
122 args = parser.parse_args() |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
123 |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
124 # Load dataset based on file extension |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
125 file_path, file_ext = args.input_file |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
126 file_ext = file_ext.lower() |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
127 |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
128 print(f"Loading dataset from: {file_path} with extension: {file_ext}") |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
129 |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
130 if file_ext == "csv": |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
131 df = pd.read_csv(file_path) |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
132 elif file_ext in ["tsv", "tabular"]: |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
133 df = pd.read_csv(file_path, sep="\t") |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
134 else: |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
135 raise ValueError(f"Unsupported file format: {file_ext}") |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
136 |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
137 # Run IssueHandler |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
138 handler = IssueHandler(dataset=df, |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
139 task=args.task, |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
140 target_column=args.target_column, |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
141 quality_threshold=args.quality_threshold) |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
142 _, issues, summary = handler.report_issues() |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
143 |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
144 # Save summary |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
145 if summary is not None: |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
146 with open("summary.txt", "w") as f: |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
147 if args.task == "regression": |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
148 f.write("Regression Issue Summary:\n") |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
149 for k, v in summary.items(): |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
150 text = f"{k.replace('_', ' ').capitalize()}: {v:.4f}" if isinstance(v, float) else f"{k.replace('_', ' ').capitalize()}: {v}" |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
151 f.write(text + "\n") |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
152 else: |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
153 f.write(str(summary)) |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
154 print("Issue summary saved to: summary.txt") |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
155 |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
156 if args.summary: |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
157 return |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
158 |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
159 # Clean selected issues |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
160 cleaned_df = handler.clean_selected_issues( |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
161 method=args.method, |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
162 label_issues=not args.no_label_issues, |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
163 outliers=not args.no_outliers, |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
164 near_duplicates=not args.no_near_duplicates, |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
165 non_iid=not args.no_non_iid |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
166 ) |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
167 |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
168 print(f"Cleaned dataset shape: {cleaned_df.shape}") |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
169 print(f"Original dataset shape: {df.shape}") |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
170 |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
171 output_filename = "cleaned_data" |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
172 if file_ext == "csv": |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
173 cleaned_df.to_csv(output_filename, index=False) |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
174 elif file_ext in ["tsv", "tabular"]: |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
175 cleaned_df.to_csv(output_filename, sep="\t", index=False) |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
176 else: |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
177 raise ValueError(f"Unsupported output format: {file_ext}") |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
178 |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
179 print(f"Cleaned dataset saved to: {output_filename}") |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
180 |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
181 # ------------------- |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
182 # Entry point |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
183 # ------------------- |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
184 |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
185 |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
186 if __name__ == "__main__": |
|
ecc18228c32e
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
bgruening
parents:
diff
changeset
|
187 main() |
