Mercurial > repos > bgruening > cleanlab
comparison cleanlab_issue_handler.py @ 0:ecc18228c32e draft default tip
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
| author | bgruening |
|---|---|
| date | Wed, 28 May 2025 11:30:39 +0000 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:ecc18228c32e |
|---|---|
| 1 import argparse | |
| 2 | |
| 3 import numpy as np | |
| 4 import pandas as pd | |
| 5 from cleanlab.datalab.datalab import Datalab | |
| 6 from cleanlab.regression.rank import get_label_quality_scores | |
| 7 from sklearn.linear_model import LinearRegression | |
| 8 from sklearn.model_selection import cross_val_predict, KFold, StratifiedKFold | |
| 9 from xgboost import XGBClassifier | |
| 10 | |
| 11 # ------------------- | |
| 12 # Issue Handler | |
| 13 # ------------------- | |
| 14 | |
| 15 | |
class IssueHandler:
    """Detect and clean label/data issues in a tabular dataset using cleanlab.

    Typical usage is ``report_issues()`` followed by
    ``clean_selected_issues()``.

    Args:
        dataset: pandas DataFrame holding the feature columns and the target.
        task: ``'classification'`` or ``'regression'``.
        target_column: Name of the label column inside ``dataset``.
        n_splits: Number of cross-validation folds used for out-of-sample
            predictions.
        quality_threshold: Regression only; rows whose label quality score is
            below this value are flagged as low quality.
    """

    def __init__(self, dataset, task, target_column, n_splits=3, quality_threshold=0.2):
        self.dataset = dataset
        self.task = task
        self.target_column = target_column
        self.n_splits = n_splits
        self.quality_threshold = quality_threshold
        # Per-row issue table, filled in by report_issues().
        self.issues = None
        # Every column except the target is treated as a feature.
        self.features = self.dataset.drop(target_column, axis=1).columns.tolist()
        self.issue_summary = None
        # Out-of-sample class probabilities (classification only).
        self.pred_probs = None

    @staticmethod
    def _format_summary_line(key, value):
        """Render one regression summary entry as ``'Some key: value'``."""
        label = key.replace('_', ' ').capitalize()
        if isinstance(value, float):
            return f"{label}: {value:.4f}"
        return f"{label}: {value}"

    def _issue_flags(self, column):
        """Return ``column`` of ``self.issues`` as a positional boolean array.

        Missing values count as "no issue"; a missing column yields all-False.
        Working positionally keeps the mask valid even when ``self.dataset``
        and ``self.issues`` carry different indexes.
        """
        if column not in self.issues.columns:
            return np.zeros(len(self.dataset), dtype=bool)
        return self.issues[column].fillna(False).to_numpy(dtype=bool)

    def report_issues(self):
        """Run cross-validated predictions and collect per-row issues.

        Returns:
            A tuple ``(dataset_copy, issues_copy, issue_summary)``.

        Raises:
            ValueError: If ``self.task`` is not a supported task type.
        """
        if self.task not in ('classification', 'regression'):
            # Fail early instead of hitting an AttributeError on a None
            # issues table at the end of the method.
            raise ValueError(f"Unsupported task: {self.task!r}")

        X = self.dataset.drop(self.target_column, axis=1)
        y = self.dataset[self.target_column]

        # Ensure compatibility with Galaxy
        X = X.to_numpy() if hasattr(X, 'to_numpy') else np.asarray(X)
        y = y.to_numpy() if hasattr(y, 'to_numpy') else np.asarray(y)

        if self.task == 'classification':
            model = XGBClassifier(use_label_encoder=False, eval_metric='logloss', random_state=42)
            cv = StratifiedKFold(n_splits=self.n_splits, shuffle=True, random_state=42)
            self.pred_probs = cross_val_predict(model, X, y, cv=cv, method='predict_proba')

            # NOTE(review): only pred_probs are handed to Datalab here, so
            # which issue types are actually reported depends on cleanlab.
            lab = Datalab(self.dataset, label_name=self.target_column)
            lab.find_issues(pred_probs=self.pred_probs)
            self.issues = lab.get_issues()
            self.issue_summary = lab.get_issue_summary()
            print(self.issue_summary)

        else:  # regression
            model = LinearRegression()
            cv = KFold(n_splits=self.n_splits, shuffle=True, random_state=42)
            pred_y = cross_val_predict(model, X, y, cv=cv, method='predict')
            scores = get_label_quality_scores(y, pred_y, method='residual')
            is_low_quality = scores < self.quality_threshold
            self.issues = pd.DataFrame({
                'label_quality': scores,
                'is_low_quality': is_low_quality
            })
            self.issue_summary = {
                'quality_threshold': self.quality_threshold,
                'num_low_quality': int(is_low_quality.sum()),
                'mean_label_quality': float(np.mean(scores)),
                'median_label_quality': float(np.median(scores)),
                'min_label_quality': float(np.min(scores)),
                'max_label_quality': float(np.max(scores)),
            }
            print("Regression Issue Summary:")
            for k, v in self.issue_summary.items():
                print(self._format_summary_line(k, v))

        return self.dataset.copy(), self.issues.copy(), self.issue_summary

    def clean_selected_issues(self, method='remove', label_issues=True, outliers=True, near_duplicates=True, non_iid=True):
        """Return a cleaned copy of the dataset based on the reported issues.

        Args:
            method: ``'remove'`` drops flagged rows; ``'replace'`` overwrites
                flagged labels with the most likely predicted class
                (classification only).
            label_issues: Include rows flagged in ``is_label_issue``.
            outliers: Include rows flagged in ``is_outlier_issue``.
            near_duplicates: Include rows flagged in
                ``is_near_duplicate_issue``.
            non_iid: Include rows flagged in ``is_non_iid_issue``.

        Returns:
            A new DataFrame; ``self.dataset`` is left untouched.

        Raises:
            RuntimeError: If ``report_issues()`` has not been run, or if
                ``'replace'`` is requested without predicted probabilities.
            NotImplementedError: For ``'replace'`` on a regression task.
            ValueError: For any other method/task combination.
        """
        if self.issues is None:
            raise RuntimeError("Must run report_issues() before cleaning.")

        # Build the mask positionally (NumPy) so it never depends on the
        # dataset and the issues table sharing the same index labels.
        if self.task == 'regression':
            clean_mask = self._issue_flags('is_low_quality')
        else:
            clean_mask = np.zeros(len(self.dataset), dtype=bool)
            for issue_type, use_flag in [
                ('is_label_issue', label_issues),
                ('is_outlier_issue', outliers),
                ('is_near_duplicate_issue', near_duplicates),
                ('is_non_iid_issue', non_iid)
            ]:
                if use_flag:
                    clean_mask |= self._issue_flags(issue_type)

        if method == 'remove':
            return self.dataset[~clean_mask].copy()

        elif method == 'replace' and self.task == 'classification':
            if self.pred_probs is None:
                raise RuntimeError("Predicted probabilities are missing; run report_issues() first.")
            pred_probs = np.asarray(self.pred_probs)
            most_likely = np.argmax(pred_probs, axis=1)
            # argmax yields column positions; translate them to the actual
            # class labels (sorted unique labels) when the counts line up.
            # For labels already encoded as 0..K-1 this is a no-op.
            classes = np.unique(self.dataset[self.target_column].to_numpy())
            if len(classes) == pred_probs.shape[1]:
                most_likely = classes[most_likely]
            fixed = self.dataset.copy()
            to_fix = self._issue_flags('is_label_issue') & bool(label_issues)
            fixed.loc[to_fix, self.target_column] = most_likely[to_fix]
            return fixed

        elif method == 'replace' and self.task == 'regression':
            raise NotImplementedError("Replace method not implemented for regression label correction.")

        else:
            raise ValueError("Invalid method or unsupported combination.")
| 103 | |
| 104 # ------------------- | |
| 105 # Main CLI Entry | |
| 106 # ------------------- | |
| 107 | |
| 108 | |
def main(argv=None):
    """Command-line entry point: report issues and optionally clean a dataset.

    Loads a CSV/TSV file, runs ``IssueHandler.report_issues()``, writes the
    issue summary to ``summary.txt`` and, unless ``--summary`` is given,
    writes the cleaned dataset to ``cleaned_data`` in the input's format.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    Raises:
        ValueError: If the file extension is unsupported or the target column
            is not present in the loaded dataset.
    """
    parser = argparse.ArgumentParser(description="Cleanlab Issue Handler CLI")
    parser.add_argument("--input_file", nargs=2, required=True, metavar=('FILE', 'EXT'), help="Input file path and its extension")
    parser.add_argument("--task", required=True, choices=["classification", "regression"], help="Type of ML task")
    parser.add_argument("--target_column", default="target", help="Name of the target column")
    parser.add_argument("--method", default="remove", choices=["remove", "replace"], help="Cleaning method")
    parser.add_argument("--summary", action="store_true", help="Print and save issue summary only, no cleaning")
    parser.add_argument("--no-label-issues", action="store_true", help="Exclude label issues from cleaning")
    parser.add_argument("--no-outliers", action="store_true", help="Exclude outlier issues from cleaning")
    parser.add_argument("--no-near-duplicates", action="store_true", help="Exclude near-duplicate issues from cleaning")
    parser.add_argument("--no-non-iid", action="store_true", help="Exclude non-i.i.d. issues from cleaning")
    parser.add_argument('--quality-threshold', type=float, default=0.2, help='Threshold for low-quality labels (regression only)')

    args = parser.parse_args(argv)

    # Load dataset based on file extension
    file_path, file_ext = args.input_file
    file_ext = file_ext.lower()

    print(f"Loading dataset from: {file_path} with extension: {file_ext}")

    # One lookup drives both reading and writing, so the output format can
    # never disagree with the (already validated) input format.
    separators = {"csv": ",", "tsv": "\t", "tabular": "\t"}
    if file_ext not in separators:
        raise ValueError(f"Unsupported file format: {file_ext}")
    sep = separators[file_ext]
    df = pd.read_csv(file_path, sep=sep)

    # Fail with a clear message instead of a bare KeyError further down.
    if args.target_column not in df.columns:
        raise ValueError(
            f"Target column '{args.target_column}' not found in dataset columns: {list(df.columns)}"
        )

    # Run IssueHandler
    handler = IssueHandler(dataset=df,
                           task=args.task,
                           target_column=args.target_column,
                           quality_threshold=args.quality_threshold)
    _, issues, summary = handler.report_issues()

    # Save summary
    if summary is not None:
        with open("summary.txt", "w", encoding="utf-8") as f:
            if args.task == "regression":
                f.write("Regression Issue Summary:\n")
                for k, v in summary.items():
                    label = k.replace('_', ' ').capitalize()
                    text = f"{label}: {v:.4f}" if isinstance(v, float) else f"{label}: {v}"
                    f.write(text + "\n")
            else:
                f.write(str(summary))
        print("Issue summary saved to: summary.txt")

    if args.summary:
        return

    # Clean selected issues
    cleaned_df = handler.clean_selected_issues(
        method=args.method,
        label_issues=not args.no_label_issues,
        outliers=not args.no_outliers,
        near_duplicates=not args.no_near_duplicates,
        non_iid=not args.no_non_iid
    )

    print(f"Cleaned dataset shape: {cleaned_df.shape}")
    print(f"Original dataset shape: {df.shape}")

    # Write the result back using the same delimiter as the input.
    output_filename = "cleaned_data"
    cleaned_df.to_csv(output_filename, sep=sep, index=False)

    print(f"Cleaned dataset saved to: {output_filename}")
| 180 | |
| 181 # ------------------- | |
| 182 # Entry point | |
| 183 # ------------------- | |
| 184 | |
| 185 | |
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
