diff cleanlab_issue_handler.py @ 0:ecc18228c32e draft default tip
planemo upload for repository https://github.com/cleanlab/cleanlab commit ac4753a61ee908bc2a5953b6c6d38d2bbbacc6c0
| author | bgruening |
|---|---|
| date | Wed, 28 May 2025 11:30:39 +0000 |
| parents | |
| children | |
```
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/cleanlab_issue_handler.py Wed May 28 11:30:39 2025 +0000
@@ -0,0 +1,187 @@
```

```python
import argparse

import numpy as np
import pandas as pd
from cleanlab.datalab.datalab import Datalab
from cleanlab.regression.rank import get_label_quality_scores
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_val_predict, KFold, StratifiedKFold
from xgboost import XGBClassifier

# -------------------
# Issue Handler
# -------------------


class IssueHandler:
    def __init__(self, dataset, task, target_column, n_splits=3, quality_threshold=0.2):
        self.dataset = dataset
        self.task = task
        self.target_column = target_column
        self.n_splits = n_splits
        self.quality_threshold = quality_threshold
        self.issues = None
        self.features = self.dataset.drop(target_column, axis=1).columns.tolist()
        self.issue_summary = None
        self.pred_probs = None

    def report_issues(self):
        X = self.dataset.drop(self.target_column, axis=1)
        y = self.dataset[self.target_column]

        # Ensure compatibility with Galaxy
        X = X.to_numpy() if hasattr(X, 'to_numpy') else np.asarray(X)
        y = y.to_numpy() if hasattr(y, 'to_numpy') else np.asarray(y)

        if self.task == 'classification':
            model = XGBClassifier(use_label_encoder=False, eval_metric='logloss', random_state=42)
            cv = StratifiedKFold(n_splits=self.n_splits, shuffle=True, random_state=42)
            self.pred_probs = cross_val_predict(model, X, y, cv=cv, method='predict_proba')

            lab = Datalab(self.dataset, label_name=self.target_column)
            lab.find_issues(pred_probs=self.pred_probs)
            self.issues = lab.get_issues()
            self.issue_summary = lab.get_issue_summary()
            print(self.issue_summary)

        elif self.task == 'regression':
            model = LinearRegression()
            cv = KFold(n_splits=self.n_splits, shuffle=True, random_state=42)
            pred_y = cross_val_predict(model, X, y, cv=cv, method='predict')
            scores = get_label_quality_scores(y, pred_y, method='residual')
            is_low_quality = scores < self.quality_threshold
            self.issues = pd.DataFrame({
                'label_quality': scores,
                'is_low_quality': is_low_quality
            })
            self.issue_summary = {
                'quality_threshold': self.quality_threshold,
                'num_low_quality': int(is_low_quality.sum()),
                'mean_label_quality': float(np.mean(scores)),
                'median_label_quality': float(np.median(scores)),
                'min_label_quality': float(np.min(scores)),
                'max_label_quality': float(np.max(scores)),
            }
            print("Regression Issue Summary:")
            for k, v in self.issue_summary.items():
                print(f"{k.replace('_', ' ').capitalize()}: {v:.4f}" if isinstance(v, float) else f"{k.replace('_', ' ').capitalize()}: {v}")

        return self.dataset.copy(), self.issues.copy(), self.issue_summary

    def clean_selected_issues(self, method='remove', label_issues=True, outliers=True, near_duplicates=True, non_iid=True):
        if self.issues is None:
            raise RuntimeError("Must run report_issues() before cleaning.")

        if self.task == 'regression':
            clean_mask = self.issues['is_low_quality'].fillna(False)
        else:
            clean_mask = pd.Series([False] * len(self.dataset))
            for issue_type, use_flag in [
                ('is_label_issue', label_issues),
                ('is_outlier_issue', outliers),
                ('is_near_duplicate_issue', near_duplicates),
                ('is_non_iid_issue', non_iid)
            ]:
                if use_flag and issue_type in self.issues.columns:
                    clean_mask |= self.issues[issue_type].fillna(False)

        if method == 'remove':
            return self.dataset[~clean_mask].copy()

        elif method == 'replace' and self.task == 'classification':
            most_likely = np.argmax(self.pred_probs, axis=1)
            fixed = self.dataset.copy()
            to_fix = self.issues['is_label_issue'] & label_issues
            fixed.loc[to_fix, self.target_column] = most_likely[to_fix]
            return fixed

        elif method == 'replace' and self.task == 'regression':
            raise NotImplementedError("Replace method not implemented for regression label correction.")

        else:
            raise ValueError("Invalid method or unsupported combination.")

# -------------------
# Main CLI Entry
# -------------------


def main():
    parser = argparse.ArgumentParser(description="Cleanlab Issue Handler CLI")
    parser.add_argument("--input_file", nargs=2, required=True, metavar=('FILE', 'EXT'), help="Input file path and its extension")
    parser.add_argument("--task", required=True, choices=["classification", "regression"], help="Type of ML task")
    parser.add_argument("--target_column", default="target", help="Name of the target column")
    parser.add_argument("--method", default="remove", choices=["remove", "replace"], help="Cleaning method")
    parser.add_argument("--summary", action="store_true", help="Print and save issue summary only, no cleaning")
    parser.add_argument("--no-label-issues", action="store_true", help="Exclude label issues from cleaning")
    parser.add_argument("--no-outliers", action="store_true", help="Exclude outlier issues from cleaning")
    parser.add_argument("--no-near-duplicates", action="store_true", help="Exclude near-duplicate issues from cleaning")
    parser.add_argument("--no-non-iid", action="store_true", help="Exclude non-i.i.d. issues from cleaning")
    parser.add_argument('--quality-threshold', type=float, default=0.2, help='Threshold for low-quality labels (regression only)')

    args = parser.parse_args()

    # Load dataset based on file extension
    file_path, file_ext = args.input_file
    file_ext = file_ext.lower()

    print(f"Loading dataset from: {file_path} with extension: {file_ext}")

    if file_ext == "csv":
        df = pd.read_csv(file_path)
    elif file_ext in ["tsv", "tabular"]:
        df = pd.read_csv(file_path, sep="\t")
    else:
        raise ValueError(f"Unsupported file format: {file_ext}")

    # Run IssueHandler
    handler = IssueHandler(dataset=df,
                           task=args.task,
                           target_column=args.target_column,
                           quality_threshold=args.quality_threshold)
    _, issues, summary = handler.report_issues()

    # Save summary
    if summary is not None:
        with open("summary.txt", "w") as f:
            if args.task == "regression":
                f.write("Regression Issue Summary:\n")
                for k, v in summary.items():
                    text = f"{k.replace('_', ' ').capitalize()}: {v:.4f}" if isinstance(v, float) else f"{k.replace('_', ' ').capitalize()}: {v}"
                    f.write(text + "\n")
            else:
                f.write(str(summary))
        print("Issue summary saved to: summary.txt")

    if args.summary:
        return

    # Clean selected issues
    cleaned_df = handler.clean_selected_issues(
        method=args.method,
        label_issues=not args.no_label_issues,
        outliers=not args.no_outliers,
        near_duplicates=not args.no_near_duplicates,
        non_iid=not args.no_non_iid
    )

    print(f"Cleaned dataset shape: {cleaned_df.shape}")
    print(f"Original dataset shape: {df.shape}")

    output_filename = "cleaned_data"
    if file_ext == "csv":
        cleaned_df.to_csv(output_filename, index=False)
    elif file_ext in ["tsv", "tabular"]:
        cleaned_df.to_csv(output_filename, sep="\t", index=False)
    else:
        raise ValueError(f"Unsupported output format: {file_ext}")

    print(f"Cleaned dataset saved to: {output_filename}")

# -------------------
# Entry point
# -------------------


if __name__ == "__main__":
    main()
```
