diff flexynesis_plot.py @ 3:525c661a7fdc draft default tip
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/flexynesis commit b2463fb68d0ae54864d87718ee72f5e063aa4587
author: bgruening
date:   Tue, 24 Jun 2025 05:55:40 +0000
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/flexynesis_plot.py	Tue Jun 24 05:55:40 2025 +0000
@@ -0,0 +1,1282 @@

#!/usr/bin/env python
"""Generate plots using flexynesis.

This script generates dimensionality reduction plots, Kaplan-Meier survival curves,
and Cox proportional hazards models from data processed by flexynesis.
"""

import argparse
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from flexynesis import (
    build_cox_model,
    get_important_features,
    plot_dim_reduced,
    plot_hazard_ratios,
    plot_kaplan_meier_curves,
    plot_pr_curves,
    plot_roc_curves,
    plot_scatter
)
from scipy.stats import kruskal, mannwhitneyu


def load_embeddings(embeddings_path):
    """Load embeddings from a file."""
    try:
        # Determine file extension
        file_ext = Path(embeddings_path).suffix.lower()

        if file_ext == '.csv':
            df = pd.read_csv(embeddings_path, index_col=0)
        elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']:
            df = pd.read_csv(embeddings_path, sep='\t', index_col=0)
        else:
            raise ValueError(f"Unsupported file extension: {file_ext}")

        return df, df.index.tolist()

    except Exception as e:
        raise ValueError(f"Error loading embeddings from {embeddings_path}: {e}") from e


def load_labels(labels_input):
    """Load predicted labels generated by flexynesis."""
    try:
        # Determine file extension
        file_ext = Path(labels_input).suffix.lower()

        if file_ext == '.csv':
            df = pd.read_csv(labels_input)
        elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']:
            df = pd.read_csv(labels_input, sep='\t')
        else:
            raise ValueError(f"Unsupported file extension: {file_ext}")

        # Check for the expected long format with known and predicted labels
        required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label']
        if all(col in df.columns for col in required_cols):
            return df
        else:
            raise ValueError(f"Labels file {labels_input} does not contain required columns: {required_cols}")

    except Exception as e:
        raise ValueError(f"Error loading labels from {labels_input}: {e}") from e


def load_survival_data(survival_path):
    """Load survival data from a file. First column should be sample_id."""
    try:
        # Determine file extension
        file_ext = Path(survival_path).suffix.lower()

        if file_ext == '.csv':
            df = pd.read_csv(survival_path, index_col=0)
        elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']:
            df = pd.read_csv(survival_path, sep='\t', index_col=0)
        else:
            raise ValueError(f"Unsupported file extension: {file_ext}")
        return df

    except Exception as e:
        raise ValueError(f"Error loading survival data from {survival_path}: {e}") from e
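

# Illustrative sketch (hypothetical values, not part of the tool's logic): a
# minimal labels table in the long format that load_labels() expects - one row
# per sample/variable/class combination, with the six required columns.
def _example_labels_table():
    """Build a tiny labels DataFrame with the required columns."""
    return pd.DataFrame({
        'sample_id': ['s1', 's1'],
        'variable': ['SUBTYPE', 'SUBTYPE'],
        'class_label': ['LumA', 'LumB'],
        'probability': [0.92, 0.08],
        'known_label': ['LumA', 'LumA'],
        'predicted_label': ['LumA', 'LumA'],
    })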


def load_omics(omics_path):
    """Load omics data from a file. First column should be features."""
    try:
        # Determine file extension
        file_ext = Path(omics_path).suffix.lower()

        if file_ext == '.csv':
            df = pd.read_csv(omics_path, index_col=0)
        elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']:
            df = pd.read_csv(omics_path, sep='\t', index_col=0)
        else:
            raise ValueError(f"Unsupported file extension: {file_ext}")
        return df

    except Exception as e:
        raise ValueError(f"Error loading omics data from {omics_path}: {e}") from e


def load_model(model_path):
    """Load a flexynesis model from a pickle file."""
    try:
        with open(model_path, 'rb') as f:
            model = torch.load(f, weights_only=False)
        return model
    except Exception as e:
        raise ValueError(f"Error loading model from {model_path}: {e}") from e


def match_samples_to_embeddings(sample_names, label_data):
    """Filter label data to match sample names in the embeddings."""
    df_matched = label_data[label_data['sample_id'].isin(sample_names)]
    return df_matched


def detect_color_type(labels_series):
    """Auto-detect whether target variables should be treated as categorical or numerical."""
    # Remove NaN
    clean_labels = labels_series.dropna()

    if clean_labels.empty:
        return 'categorical'  # default if no labels

    # Check if all values can be converted to numbers
    try:
        numeric_labels = pd.to_numeric(clean_labels, errors='coerce')

        # If conversion failed -> categorical
        if numeric_labels.isna().any():
            return 'categorical'

        # Check the number of unique values
        unique_count = len(clean_labels.unique())
        total_count = len(clean_labels)

        # Few unique values relative to the total -> categorical.
        # Threshold: unique values < 10 OR unique/total < 0.1
        if unique_count < 10 or (unique_count / total_count) < 0.1:
            return 'categorical'
        else:
            return 'numerical'

    except Exception:
        return 'categorical'
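

# Illustrative sketch of the detect_color_type() heuristic above (toy series,
# defined but never called by the tool): strings fail numeric conversion, and
# numeric series are split on the unique-value thresholds.
def _example_detect_color_type():
    """Demonstrate the categorical/numerical heuristic on toy series."""
    # Non-numeric strings -> conversion yields NaN -> categorical
    assert detect_color_type(pd.Series(['LumA', 'LumB', 'LumA'])) == 'categorical'
    # 5 unique integers (< 10 unique values) -> categorical
    assert detect_color_type(pd.Series([0, 1, 2, 3, 4] * 40)) == 'categorical'
    # 200 distinct floats (>= 10 unique values, unique/total = 1.0) -> numerical
    assert detect_color_type(pd.Series(np.linspace(0.0, 1.0, 200))) == 'numerical'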
+ """ + df = pd.DataFrame({title_x: categorical_x, title_y: numerical_y}) + + # Compute p-value + groups = df[title_x].unique() + if len(groups) == 2: + group1 = df[df[title_x] == groups[0]][title_y] + group2 = df[df[title_x] == groups[1]][title_y] + stat, p = mannwhitneyu(group1, group2, alternative='two-sided') + test_name = "Mann-Whitney U" + else: + group_data = [df[df[title_x] == group][title_y] for group in groups] + stat, p = kruskal(*group_data) + test_name = "Kruskal-Wallis" + + # Create a boxplot with jittered points + plt.figure(figsize=figsize) + sns.boxplot(x=title_x, y=title_y, hue=title_x, data=df, palette='Set2', legend=False, fill=False) + sns.stripplot(x=title_x, y=title_y, data=df, color='black', size=jittersize, jitter=True, dodge=True, alpha=0.4) + + # Labels and p-value annotation + plt.xlabel(title_x) + plt.ylabel(title_y) + plt.text( + x=-0.4, + y=plt.ylim()[1], + s=f'{test_name} p = {p:.3e}', + verticalalignment='top', + horizontalalignment='left', + fontsize=12, + bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor='gray') + ) + + plt.tight_layout() + return plt.gcf() + + +def generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base): + """Generate dimensionality reduction plots""" + + # Parse target variables + target_vars = [var.strip() for var in args.target_variables.split(',')] + + print(f"Generating {args.method.upper()} plots for {len(target_vars)} target variable(s): {', '.join(target_vars)}") + + # Check variables + available_vars = matched_labels['variable'].unique() + missing_vars = [var for var in target_vars if var not in available_vars] + + if missing_vars: + print(f"Warning: The following target variables were not found in the data: {', '.join(missing_vars)}") + print(f"Available variables: {', '.join(available_vars)}") + + # Filter to only process available variables + valid_vars = [var for var in target_vars if var in available_vars] + + if not valid_vars: + raise ValueError(f"None of the specified target variables were found in the data. 


def generate_km_plots(survival_data, label_data, args, output_dir, output_name_base):
    """Generate Kaplan-Meier plots"""
    print("Generating Kaplan-Meier curves of risk subtypes...")

    # Reset index and rename the index column to sample_id
    survival_data = survival_data.reset_index()
    if survival_data.columns[0] != 'sample_id':
        survival_data = survival_data.rename(columns={survival_data.columns[0]: 'sample_id'})

    # Check that the required survival columns exist
    for col in [args.surv_time_var, args.surv_event_var]:
        if col not in survival_data.columns:
            raise ValueError(f"Column '{col}' not found in survival data")

    # Convert the survival event column to binary (0/1) based on event_value.
    # Convert to string for comparison to handle mixed types.
    survival_data[args.surv_event_var] = survival_data[args.surv_event_var].astype(str)
    event_value_str = str(args.event_value)

    # Create a binary event column (1 if the value matches event_value, 0 otherwise)
    survival_data[f'{args.surv_event_var}_binary'] = (
        survival_data[args.surv_event_var] == event_value_str
    ).astype(int)

    # Keep only label rows for the survival variable whose class_label matches the event value
    label_data['class_label'] = label_data['class_label'].astype(str)

    label_data = label_data[(label_data['variable'] == args.surv_event_var) & (label_data['class_label'] == event_value_str)]

    # Merge survival data with labels
    df_deceased = pd.merge(survival_data, label_data, on='sample_id', how='inner')

    if df_deceased.empty:
        raise ValueError("No matching samples found after merging survival and label data.")

    # Get risk scores
    risk_scores = df_deceased['probability'].values

    # Compute groups (median split)
    quantiles = np.quantile(risk_scores, [0.5])
    groups = np.digitize(risk_scores, quantiles)
    group_labels = ['low_risk' if g == 0 else 'high_risk' for g in groups]

    fig_known = plot_kaplan_meier_curves(
        durations=df_deceased[args.surv_time_var],
        events=df_deceased[f'{args.surv_event_var}_binary'],
        categorical_variable=group_labels
    )

    output_path_known = output_dir / f"{output_name_base}_km_risk_subtypes.{args.format}"
    print(f"Saving Kaplan-Meier plot to: {output_path_known.absolute()}")
    fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight')

    print("Kaplan-Meier plot saved successfully!")
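

# Illustrative sketch of the median split used above (toy numbers, defined but
# never called by the tool): np.digitize() assigns scores below the median to
# bin 0 ('low_risk') and scores at or above it to bin 1 ('high_risk').
def _example_median_risk_split():
    """Demonstrate the low/high risk grouping on toy risk scores."""
    risk_scores = np.array([0.1, 0.4, 0.6, 0.9])
    quantiles = np.quantile(risk_scores, [0.5])   # array([0.5])
    groups = np.digitize(risk_scores, quantiles)  # [0, 0, 1, 1]
    return ['low_risk' if g == 0 else 'high_risk' for g in groups]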


def generate_cox_plots(model, clinical_train, clinical_test, omics_train, omics_test, args, output_dir, output_name_base):
    """Generate Cox proportional hazards plots"""
    print("Generating Cox proportional hazards analysis...")

    # Parse clinical variables
    clinical_vars = [var.strip() for var in args.clinical_variables.split(',')]

    # Validate that the survival variables are included
    required_vars = [args.surv_time_var, args.surv_event_var]
    for var in required_vars:
        if var not in clinical_vars:
            clinical_vars.append(var)

    print(f"Using clinical variables: {', '.join(clinical_vars)}")

    # Filter datasets for the clinical variables
    if all(var in clinical_train.columns and var in clinical_test.columns for var in clinical_vars):
        df_clin_train = clinical_train[clinical_vars]
        df_clin_test = clinical_test[clinical_vars]
        # Drop rows with NaN in clinical variables
        df_clin_train = df_clin_train.dropna(subset=clinical_vars)
        df_clin_test = df_clin_test.dropna(subset=clinical_vars)
    else:
        raise ValueError(f"Not all clinical variables found in datasets. Available in train dataset: {clinical_train.columns.tolist()}, Available in test dataset: {clinical_test.columns.tolist()}")

    # Combine
    df_clin = pd.concat([df_clin_train, df_clin_test], axis=0)

    # Get top survival markers
    print(f"Extracting top {args.top_features} important features for {args.surv_event_var}...")
    try:
        imp = get_important_features(model,
                                     var=args.surv_event_var,
                                     top=args.top_features
                                     )['name'].unique().tolist()
        print(f"Top features: {', '.join(imp)}")
    except Exception as e:
        raise ValueError(f"Error getting important features: {e}") from e

    # Extract feature data from the omics datasets
    try:
        omics_test = omics_test.loc[omics_test.index.isin(imp)]
        omics_train = omics_train.loc[omics_train.index.isin(imp)]
        # Drop rows with NaN in the omics datasets
        omics_test = omics_test.dropna(subset=omics_test.columns)
        omics_train = omics_train.dropna(subset=omics_train.columns)

        df_imp = pd.concat([omics_train, omics_test], axis=1)
        df_imp = df_imp.T  # Transpose to have samples as rows

        print(f"Feature data shape: {df_imp.shape}")
    except Exception as e:
        raise ValueError(f"Error extracting feature subset: {e}") from e

    # Combine markers with clinical variables
    df = pd.merge(df_imp, df_clin, left_index=True, right_index=True)
    print(f"Combined data shape: {df.shape}")

    # Remove samples without survival endpoints
    initial_samples = len(df)
    df = df[df[args.surv_event_var].notna()]
    final_samples = len(df)
    print(f"Removed {initial_samples - final_samples} samples without survival data")

    if df.empty:
        raise ValueError("No samples remain after filtering for survival data")

    # Convert the survival event column to binary (0/1) based on event_value.
    # Convert to string for comparison to handle mixed types.
    df[args.surv_event_var] = df[args.surv_event_var].astype(str)
    event_value_str = str(args.event_value)

    df[args.surv_event_var] = (
        df[args.surv_event_var] == event_value_str
    ).astype(int)

    # Build the Cox model
    print(f"Building Cox model with time variable: {args.surv_time_var}, event variable: {args.surv_event_var}")
    try:
        coxm = build_cox_model(df,
                               duration_col=args.surv_time_var,
                               event_col=args.surv_event_var,
                               crossval=args.crossval,
                               n_splits=args.n_splits,
                               random_state=args.random_state)
        print("Cox model built successfully")
    except Exception as e:
        raise ValueError(f"Error building Cox model: {e}") from e

    # Generate the hazard ratios plot
    try:
        print("Generating hazard ratios plot...")
        fig = plot_hazard_ratios(coxm)

        output_path = output_dir / f"{output_name_base}_hazard_ratios.{args.format}"
        print(f"Saving hazard ratios plot to: {output_path.absolute()}")
        fig.save(output_path, dpi=args.dpi, bbox_inches='tight')

        print("Cox proportional hazards analysis completed successfully!")

    except Exception as e:
        raise ValueError(f"Error generating hazard ratios plot: {e}") from e
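

# Example invocation for the Cox analysis above (hypothetical file names):
#   python flexynesis_plot.py --plot_type cox --model model.pkl \
#       --clinical_train clin_train.tsv --clinical_test clin_test.tsv \
#       --omics_train omics_train.tsv --omics_test omics_test.tsv \
#       --surv_time_var OS_MONTHS --surv_event_var OS_STATUS \
#       --event_value '1:DECEASED' --clinical_variables AGE,SEX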


def generate_plot_scatter(labels, args, output_dir, output_name_base):
    """Generate scatter plots of known vs predicted labels"""
    print("Generating scatter plots of known vs predicted labels...")

    # Parse target values from the comma-separated string
    if args.target_value:
        target_values = [val.strip() for val in args.target_value.split(',')]
    else:
        # If no target values are specified, use all unique variables
        target_values = labels['variable'].unique().tolist()

    print(f"Processing target values: {target_values}")

    successful_plots = 0
    skipped_plots = 0

    for target_value in target_values:
        print(f"\nProcessing target value: '{target_value}'")

        # Filter labels for the current target value
        target_labels = labels[labels['variable'] == target_value]

        if target_labels.empty:
            print(f" Warning: No data found for target value '{target_value}' - skipping")
            skipped_plots += 1
            continue

        # Check if labels are numeric and convert
        true_values = pd.to_numeric(target_labels['known_label'], errors='coerce')
        predicted_values = pd.to_numeric(target_labels['predicted_label'], errors='coerce')

        if true_values.isna().all() or predicted_values.isna().all():
            print(f" No valid numeric values found for known or predicted labels in '{target_value}'")
            skipped_plots += 1
            continue

        try:
            print(f" Generating scatter plot for '{target_value}'...")
            fig = plot_scatter(true_values, predicted_values)

            # Create the output filename with the target value
            safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
            if len(target_values) > 1:
                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
            else:
                output_filename = f"{output_name_base}.{args.format}"

            output_path = output_dir / output_filename
            print(f" Saving scatter plot to: {output_path.absolute()}")
            fig.save(output_path, dpi=args.dpi, bbox_inches='tight')

            successful_plots += 1
            print(f" Scatter plot for '{target_value}' generated successfully!")

        except Exception as e:
            print(f" Error generating plot for '{target_value}': {str(e)}")
            skipped_plots += 1

    # Summary
    print(" Summary:")
    print(f" Successfully generated: {successful_plots} plots")
    print(f" Skipped: {skipped_plots} plots")

    if successful_plots == 0:
        raise ValueError("No scatter plots could be generated. Check your data and target values.")

    print("Scatter plot generation completed!")


def generate_label_concordance_heatmap(labels, args, output_dir, output_name_base):
    """Generate label concordance heatmaps"""
    print("Generating label concordance heatmaps...")

    # Parse target values from the comma-separated string
    if args.target_value:
        target_values = [val.strip() for val in args.target_value.split(',')]
    else:
        # If no target values are specified, use all unique variables
        target_values = labels['variable'].unique().tolist()

    print(f"Processing target values: {target_values}")

    for target_value in target_values:
        print(f"\nProcessing target value: '{target_value}'")

        # Filter labels for the current target value
        target_labels = labels[labels['variable'] == target_value]

        if target_labels.empty:
            print(f" Warning: No data found for target value '{target_value}' - skipping")
            continue

        true_values = target_labels['known_label'].tolist()
        predicted_values = target_labels['predicted_label'].tolist()

        try:
            print(f" Generating heatmap for '{target_value}'...")
            fig = plot_label_concordance_heatmap(true_values, predicted_values)

            # Create the output filename with the target value
            safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
            if len(target_values) > 1:
                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
            else:
                output_filename = f"{output_name_base}.{args.format}"

            output_path = output_dir / output_filename
            print(f" Saving heatmap to: {output_path.absolute()}")
            fig.savefig(output_path, dpi=args.dpi, bbox_inches='tight')
            plt.close(fig)

        except Exception as e:
            print(f" Error generating heatmap for '{target_value}': {str(e)}")
            continue

    print("Label concordance heatmap generated successfully!")
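

# Illustrative sketch of the row-normalised cross-tabulation behind the
# heatmap above (toy labels, defined but never called by the tool).
def _example_concordance_crosstab():
    """Show how each row of the crosstab is normalised to sum to 1."""
    known = pd.Series(['A', 'A', 'B', 'B'], name='Labels Set 1')
    predicted = pd.Series(['A', 'B', 'B', 'B'], name='Labels Set 2')
    ct = pd.crosstab(known, predicted)
    # Row 'A' -> [0.5, 0.5]; row 'B' -> [0.0, 1.0]
    return ct.div(ct.sum(axis=1), axis=0)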


def generate_pr_curves(labels, args, output_dir, output_name_base):
    """Generate precision-recall curves"""
    print("Generating precision-recall curves...")

    # Parse target values from the comma-separated string
    if args.target_value:
        target_values = [val.strip() for val in args.target_value.split(',')]
    else:
        # If no target values are specified, use all unique variables
        target_values = labels['variable'].unique().tolist()

    print(f"Processing target values: {target_values}")

    for target_value in target_values:
        print(f"\nProcessing target value: '{target_value}'")

        # Filter labels for the current target value
        target_labels = labels[labels['variable'] == target_value]

        # Check if this is a regression problem (no class probabilities)
        prob_columns = target_labels['class_label'].unique()
        non_na_probs = target_labels['probability'].notna().sum()

        print(f" Class labels found: {list(prob_columns)}")
        print(f" Non-NaN probabilities: {non_na_probs}/{len(target_labels)}")

        # If most probabilities are NaN, this is likely a regression problem
        if non_na_probs < len(target_labels) * 0.1:  # Less than 10% valid probabilities
            print(" Detected regression problem - precision-recall curves not applicable")
            print(f" Skipping '{target_value}' (use regression evaluation metrics instead)")
            continue

        # Debug: check data quality
        total_rows = len(target_labels)
        missing_labels = target_labels['known_label'].isna().sum()
        missing_probs = target_labels['probability'].isna().sum()
        unique_samples = target_labels['sample_id'].nunique()
        unique_classes = target_labels['class_label'].nunique()

        print(f" Data summary: {total_rows} total rows, {unique_samples} unique samples, {unique_classes} unique classes")
        print(f" Missing data: {missing_labels} missing known_label, {missing_probs} missing probability")

        if missing_labels > 0:
            print(f" Warning: Found {missing_labels} missing known_label values")
            missing_samples = target_labels[target_labels['known_label'].isna()]['sample_id'].unique()[:5]
            print(f" Sample IDs with missing known_label: {list(missing_samples)}")

            # Remove rows with missing known_label
            target_labels = target_labels.dropna(subset=['known_label'])
            if target_labels.empty:
                print(f" Error: No valid known_label data remaining for '{target_value}' - skipping")
                continue

        # 1. Pivot to wide format
        prob_df = target_labels.pivot(index='sample_id', columns='class_label', values='probability')

        print(f" After pivot: {prob_df.shape[0]} samples x {prob_df.shape[1]} classes")
        print(f" Class columns: {list(prob_df.columns)}")

        # Check for NaN values in the probability data
        nan_counts = prob_df.isna().sum()
        if nan_counts.any():
            print(f" NaN counts per class: {dict(nan_counts)}")
            print(f" Samples with any NaN: {prob_df.isna().any(axis=1).sum()}/{len(prob_df)}")

            # Drop only rows where ALL probabilities are NaN
            all_nan_rows = prob_df.isna().all(axis=1)
            if all_nan_rows.any():
                print(f" Dropping {all_nan_rows.sum()} samples with all NaN probabilities")
                prob_df = prob_df[~all_nan_rows]

            remaining_nans = prob_df.isna().sum().sum()
            if remaining_nans > 0:
                print(f" Warning: {remaining_nans} individual NaN values remain - filling with 0")
                prob_df = prob_df.fillna(0)

        if prob_df.empty:
            print(f" Error: No valid probability data remaining for '{target_value}' - skipping")
            continue

        # 2. Get true labels
        true_labels_df = target_labels.drop_duplicates('sample_id')[['sample_id', 'known_label']].set_index('sample_id')

        # 3. Align indices - only keep samples that exist in both datasets
        common_indices = prob_df.index.intersection(true_labels_df.index)
        if len(common_indices) == 0:
            print(f" Error: No common sample_ids between probability and true label data for '{target_value}' - skipping")
            continue

        print(f" Found {len(common_indices)} samples with both probability and true label data")

        # Filter both datasets to the common indices
        prob_df_aligned = prob_df.loc[common_indices]
        y_true = true_labels_df.loc[common_indices]['known_label']

        # 4. Final check for NaN values
        if y_true.isna().any():
            print(f" Error: True labels still contain NaN after alignment for '{target_value}' - skipping")
            continue

        if prob_df_aligned.isna().any().any():
            print(f" Error: Probability data still contains NaN after alignment for '{target_value}' - skipping")
            continue

        # 5. Convert categorical labels to integer labels
        # Create a mapping from class names to integers
        class_names = list(prob_df_aligned.columns)
        class_to_int = {class_name: i for i, class_name in enumerate(class_names)}

        print(f" Class mapping: {class_to_int}")

        # Convert true labels to integers
        y_true_np = y_true.map(class_to_int).to_numpy()
        y_probs_np = prob_df_aligned.to_numpy()

        print(f" Data shape: y_true={y_true_np.shape}, y_probs={y_probs_np.shape}")
        print(f" Unique true labels (integers): {set(y_true_np)}")
        print(f" Class labels (columns): {class_names}")
        print(f" Label distribution: {dict(zip(*np.unique(y_true_np, return_counts=True)))}")

        # Check for any unmapped labels (they will be NaN)
        if pd.isna(y_true_np).any():
            print(" Error: Some true labels could not be mapped to class columns")
            unmapped_labels = set(y_true[y_true.map(class_to_int).isna()])
            print(f" Unmapped labels: {unmapped_labels}")
            print(f" Available classes: {class_names}")
            continue

        try:
            print(f" Generating precision-recall curve for '{target_value}'...")
            fig = plot_pr_curves(y_true_np, y_probs_np)

            # Create the output filename with the target value
            safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
            if len(target_values) > 1:
                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
            else:
                output_filename = f"{output_name_base}.{args.format}"

            output_path = output_dir / output_filename
            print(f" Saving precision-recall curve to: {output_path.absolute()}")
            fig.save(output_path, dpi=args.dpi, bbox_inches='tight')

        except Exception as e:
            print(f" Error generating precision-recall curve for '{target_value}': {str(e)}")
            print(f" Debug info - y_true type: {type(y_true_np)}, contains NaN: {pd.isna(y_true_np).any()}")
            print(f" Debug info - y_probs type: {type(y_probs_np)}, contains NaN: {pd.isna(y_probs_np).any()}")
            continue

    print("Precision-recall curves generated successfully!")
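

# Illustrative sketch of the long-to-wide reshape used above (toy data,
# defined but never called by the tool): pivot() turns one row per
# sample/class pair into one row per sample with one column per class.
def _example_probability_pivot():
    """Reshape long-format probabilities into a samples x classes matrix."""
    long_df = pd.DataFrame({
        'sample_id': ['s1', 's1', 's2', 's2'],
        'class_label': ['A', 'B', 'A', 'B'],
        'probability': [0.9, 0.1, 0.2, 0.8],
    })
    wide = long_df.pivot(index='sample_id', columns='class_label', values='probability')
    # wide.loc['s1'] -> A: 0.9, B: 0.1; wide.loc['s2'] -> A: 0.2, B: 0.8
    return wide.to_numpy()  # shape: (2 samples, 2 classes)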


def generate_roc_curves(labels, args, output_dir, output_name_base):
    """Generate ROC curves"""
    print("Generating ROC curves...")

    # Parse target values from the comma-separated string
    if args.target_value:
        target_values = [val.strip() for val in args.target_value.split(',')]
    else:
        # If no target values are specified, use all unique variables
        target_values = labels['variable'].unique().tolist()

    print(f"Processing target values: {target_values}")

    for target_value in target_values:
        print(f"\nProcessing target value: '{target_value}'")

        # Filter labels for the current target value
        target_labels = labels[labels['variable'] == target_value]

        # Check if this is a regression problem (no class probabilities)
        prob_columns = target_labels['class_label'].unique()
        non_na_probs = target_labels['probability'].notna().sum()

        print(f" Class labels found: {list(prob_columns)}")
        print(f" Non-NaN probabilities: {non_na_probs}/{len(target_labels)}")

        # If most probabilities are NaN, this is likely a regression problem
        if non_na_probs < len(target_labels) * 0.1:  # Less than 10% valid probabilities
            print(" Detected regression problem - ROC curves not applicable")
            print(f" Skipping '{target_value}' (use regression evaluation metrics instead)")
            continue

        # Debug: check data quality
        total_rows = len(target_labels)
        missing_labels = target_labels['known_label'].isna().sum()
        missing_probs = target_labels['probability'].isna().sum()
        unique_samples = target_labels['sample_id'].nunique()
        unique_classes = target_labels['class_label'].nunique()

        print(f" Data summary: {total_rows} total rows, {unique_samples} unique samples, {unique_classes} unique classes")
        print(f" Missing data: {missing_labels} missing known_label, {missing_probs} missing probability")

        if missing_labels > 0:
            print(f" Warning: Found {missing_labels} missing known_label values")
            missing_samples = target_labels[target_labels['known_label'].isna()]['sample_id'].unique()[:5]
            print(f" Sample IDs with missing known_label: {list(missing_samples)}")

            # Remove rows with missing known_label
            target_labels = target_labels.dropna(subset=['known_label'])
            if target_labels.empty:
                print(f" Error: No valid known_label data remaining for '{target_value}' - skipping")
                continue

        # 1. Pivot to wide format
        prob_df = target_labels.pivot(index='sample_id', columns='class_label', values='probability')

        print(f" After pivot: {prob_df.shape[0]} samples x {prob_df.shape[1]} classes")
        print(f" Class columns: {list(prob_df.columns)}")

        # Check for NaN values in the probability data
        nan_counts = prob_df.isna().sum()
        if nan_counts.any():
            print(f" NaN counts per class: {dict(nan_counts)}")
            print(f" Samples with any NaN: {prob_df.isna().any(axis=1).sum()}/{len(prob_df)}")

            # Drop only rows where ALL probabilities are NaN
            all_nan_rows = prob_df.isna().all(axis=1)
            if all_nan_rows.any():
                print(f" Dropping {all_nan_rows.sum()} samples with all NaN probabilities")
                prob_df = prob_df[~all_nan_rows]

            remaining_nans = prob_df.isna().sum().sum()
            if remaining_nans > 0:
                print(f" Warning: {remaining_nans} individual NaN values remain - filling with 0")
                prob_df = prob_df.fillna(0)

        if prob_df.empty:
            print(f" Error: No valid probability data remaining for '{target_value}' - skipping")
            continue

        # 2. Get true labels
        true_labels_df = target_labels.drop_duplicates('sample_id')[['sample_id', 'known_label']].set_index('sample_id')

        # 3. Align indices - only keep samples that exist in both datasets
        common_indices = prob_df.index.intersection(true_labels_df.index)
        if len(common_indices) == 0:
            print(f" Error: No common sample_ids between probability and true label data for '{target_value}' - skipping")
            continue

        print(f" Found {len(common_indices)} samples with both probability and true label data")

        # Filter both datasets to the common indices
        prob_df_aligned = prob_df.loc[common_indices]
        y_true = true_labels_df.loc[common_indices]['known_label']
        # 4. Final check for NaN values
        if y_true.isna().any():
            print(f" Error: True labels still contain NaN after alignment for '{target_value}' - skipping")
            continue

        if prob_df_aligned.isna().any().any():
            print(f" Error: Probability data still contains NaN after alignment for '{target_value}' - skipping")
            continue

        # 5. Convert categorical labels to integer labels
        # Create a mapping from class names to integers
        class_names = list(prob_df_aligned.columns)
        class_to_int = {class_name: i for i, class_name in enumerate(class_names)}

        print(f" Class mapping: {class_to_int}")

        # Convert true labels to integers
        y_true_np = y_true.map(class_to_int).to_numpy()
        y_probs_np = prob_df_aligned.to_numpy()

        print(f" Data shape: y_true={y_true_np.shape}, y_probs={y_probs_np.shape}")
        print(f" Unique true labels (integers): {set(y_true_np)}")
        print(f" Class labels (columns): {class_names}")
        print(f" Label distribution: {dict(zip(*np.unique(y_true_np, return_counts=True)))}")

        # Check for any unmapped labels (they will be NaN)
        if pd.isna(y_true_np).any():
            print(" Error: Some true labels could not be mapped to class columns")
            unmapped_labels = set(y_true[y_true.map(class_to_int).isna()])
            print(f" Unmapped labels: {unmapped_labels}")
            print(f" Available classes: {class_names}")
            continue

        try:
            print(f" Generating ROC curve for '{target_value}'...")
            fig = plot_roc_curves(y_true_np, y_probs_np)

            # Create the output filename with the target value
            safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
            if len(target_values) > 1:
                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
            else:
                output_filename = f"{output_name_base}.{args.format}"

            output_path = output_dir / output_filename
            print(f" Saving ROC curve to: {output_path.absolute()}")
            fig.save(output_path, dpi=args.dpi, bbox_inches='tight')

        except Exception as e:
            print(f" Error generating ROC curve for '{target_value}': {str(e)}")
            print(f" Debug info - y_true type: {type(y_true_np)}, contains NaN: {pd.isna(y_true_np).any()}")
            print(f" Debug info - y_probs type: {type(y_probs_np)}, contains NaN: {pd.isna(y_probs_np).any()}")
            continue

    print("ROC curves generated successfully!")


def generate_box_plots(labels, args, output_dir, output_name_base):
    """Generate box plots for model predictions"""

    print("Generating box plots...")

    # Parse target values from the comma-separated string
    if args.target_value:
        target_values = [val.strip() for val in args.target_value.split(',')]
    else:
        # If no target values are specified, use all unique variables
        target_values = labels['variable'].unique().tolist()

    print(f"Processing target values: {target_values}")

    for target_value in target_values:
        print(f"\nProcessing target value: '{target_value}'")

        # Filter labels for the current target value
        target_labels = labels[labels['variable'] == target_value]

        if target_labels.empty:
            print(f" Warning: No data found for target value '{target_value}' - skipping")
            continue

        # Check if this is a classification problem (has probabilities)
        prob_columns = target_labels['class_label'].unique()
        non_na_probs = target_labels['probability'].notna().sum()

        print(f" Class labels found: {list(prob_columns)}")
        print(f" Non-NaN probabilities: {non_na_probs}/{len(target_labels)}")

        # If most probabilities are NaN, this is likely a regression problem
        if non_na_probs < len(target_labels) * 0.1:  # Less than 10% valid probabilities
print(" Detected regression problem - precision-recall curves not applicable") + print(f" Skipping '{target_value}' (use regression evaluation metrics instead)") + continue + + # Debug: Check data quality + total_rows = len(target_labels) + missing_labels = target_labels['known_label'].isna().sum() + missing_probs = target_labels['probability'].isna().sum() + unique_samples = target_labels['sample_id'].nunique() + unique_classes = target_labels['class_label'].nunique() + + print(f" Data summary: {total_rows} total rows, {unique_samples} unique samples, {unique_classes} unique classes") + print(f" Missing data: {missing_labels} missing known_label, {missing_probs} missing probability") + + if missing_labels > 0: + print(f" Warning: Found {missing_labels} missing known_label values") + missing_samples = target_labels[target_labels['known_label'].isna()]['sample_id'].unique()[:5] + print(f" Sample IDs with missing known_label: {list(missing_samples)}") + + # Remove rows with missing known_label + target_labels = target_labels.dropna(subset=['known_label']) + if target_labels.empty: + print(f" Error: No valid known_label data remaining for '{target_value}' - skipping") + continue + + # Remove rows with missing data + clean_data = target_labels.dropna(subset=['known_label', 'probability']) + + if clean_data.empty: + print(" No valid data after cleaning - skipping") + continue + + # Get unique classes + classes = clean_data['class_label'].unique() + + for class_label in classes: + print(f" Generating box plot for class: {class_label}") + + # Filter for current class + class_data = clean_data[clean_data['class_label'] == class_label] + + try: + # Create the box plot + fig = plot_boxplot( + categorical_x=class_data['known_label'], + numerical_y=class_data['probability'], + title_x='True Label', + title_y=f'Predicted Probability ({class_label})', + ) + + # Save the plot + safe_class_name = str(class_label).replace('/', '_').replace('\\', '_').replace(' ', '_').replace(':', '_') + safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_') + output_filename = f"{output_name_base}_{safe_target_name}_{safe_class_name}.{args.format}" + output_path = output_dir / output_filename + + print(f" Saving box plot to: {output_path.absolute()}") + fig.savefig(output_path, dpi=args.dpi, bbox_inches='tight') + plt.close(fig) + + except Exception as e: + print(f" Error generating box plot for class '{class_label}': {str(e)}") + continue + + +def main(): + """Main function to parse arguments and generate plots""" + parser = argparse.ArgumentParser(description="Generate plots using flexynesis") + + # Required arguments + parser.add_argument("--labels", type=str, required=False, + help="Path to labels file generated by flexynesis") + + # Plot type + parser.add_argument("--plot_type", type=str, required=True, + choices=['dimred', 'kaplan_meier', 'cox', 'scatter', 'concordance_heatmap', 'pr_curve', 'roc_curve', 'box_plot'], + help="Type of plot to generate: 'dimred' for dimensionality reduction, 'kaplan_meier' for survival analysis, 'cox' for Cox proportional hazards analysis, 'scatter' for scatter plots, 'concordance_heatmap' for label concordance heatmaps, 'pr_curve' for precision-recall curves, 'roc_curve' for ROC curves, or 'box_plot' for box plots.") + + # Arguments for dimensionality reduction + parser.add_argument("--embeddings", type=str, + help="Path to input data embeddings file (CSV or tabular format). 
                        help="Path to input data embeddings file (CSV or tabular format). Required for dimred plots.")
    parser.add_argument("--method", type=str, default='pca', choices=['pca', 'umap'],
                        help="Transformation method ('pca' or 'umap'). Default is 'pca'. Used for dimred plots.")
    parser.add_argument("--target_variables", type=str, required=False,
                        help="Comma-separated list of target variables to plot.")

    # Arguments for Kaplan-Meier
    parser.add_argument("--survival_data", type=str,
                        help="Path to survival data file with columns: duration and event. Required for kaplan_meier plots.")
    parser.add_argument("--surv_time_var", type=str, required=False,
                        help="Column name for survival time")
    parser.add_argument("--surv_event_var", type=str, required=False,
                        help="Column name for survival event")
    parser.add_argument("--event_value", type=str, required=False,
                        help="Value in event column that represents an event (e.g., 'DECEASED')")

    # Arguments for Cox analysis
    parser.add_argument("--model", type=str,
                        help="Path to trained flexynesis model (pickle file). Required for cox plots.")
    parser.add_argument("--clinical_train", type=str,
                        help="Path to training clinical dataset (CSV or tabular format). Required for cox plots.")
    parser.add_argument("--clinical_test", type=str,
                        help="Path to test clinical dataset (CSV or tabular format). Required for cox plots.")
    parser.add_argument("--omics_train", type=str, default=None,
                        help="Path to training omics dataset. Required for cox plots.")
    parser.add_argument("--omics_test", type=str, default=None,
                        help="Path to test omics dataset. Required for cox plots.")
    parser.add_argument("--clinical_variables", type=str,
                        help="Comma-separated list of clinical variables to include in Cox model (e.g., 'AGE,SEX,HISTOLOGICAL_DIAGNOSIS,STUDY')")
    parser.add_argument("--top_features", type=int, default=20,
                        help="Number of top important features to include in Cox model. Default is 20")
    parser.add_argument("--crossval", action='store_true',
                        help="If set, performs K-fold cross-validation and returns the average C-index. Default is False")
    parser.add_argument("--n_splits", type=int, default=5,
                        help="Number of folds for cross-validation. Default is 5")
    parser.add_argument("--random_state", type=int, default=42,
                        help="Random seed for reproducibility. Default is 42")

    # Arguments for scatter plots, heatmaps, PR curves, ROC curves, and box plots
    parser.add_argument("--target_value", type=str, default=None,
                        help="Comma-separated list of target values to plot.")

    # Common arguments
    parser.add_argument("--output_dir", type=str, default='output',
                        help="Output directory. Default is 'output'")
    parser.add_argument("--output_name", type=str, default=None,
                        help="Output filename base")
    parser.add_argument("--format", type=str, default='jpg', choices=['png', 'pdf', 'svg', 'jpg'],
                        help="Output format for the plot. Default is 'jpg'")
    parser.add_argument("--dpi", type=int, default=300,
                        help="DPI for the output image. Default is 300")

    args = parser.parse_args()

    try:
        # Validate plot type
        if not args.plot_type:
            raise ValueError("Please specify a plot type using --plot_type")
        if args.plot_type not in ['dimred', 'kaplan_meier', 'cox', 'scatter', 'concordance_heatmap', 'pr_curve', 'roc_curve', 'box_plot']:
            raise ValueError(f"Invalid plot type: {args.plot_type}. Must be one of: 'dimred', 'kaplan_meier', 'cox', 'scatter', 'concordance_heatmap', 'pr_curve', 'roc_curve', 'box_plot'")
        # Validate requirements for each plot type
        if args.plot_type in ['dimred']:
            if not args.embeddings:
                raise ValueError("--embeddings is required when plot_type is 'dimred'")
            if not os.path.isfile(args.embeddings):
                raise FileNotFoundError(f"Embeddings file not found: {args.embeddings}")
            if not args.labels:
                raise ValueError("--labels is required for dimensionality reduction plots")
            if not args.method:
                raise ValueError("--method is required for dimensionality reduction plots")
            if not args.target_variables:
                raise ValueError("--target_variables is required for dimensionality reduction plots")

        if args.plot_type in ['kaplan_meier']:
            if not args.survival_data:
                raise ValueError("--survival_data is required when plot_type is 'kaplan_meier'")
            if not os.path.isfile(args.survival_data):
                raise FileNotFoundError(f"Survival data file not found: {args.survival_data}")
            if not args.labels:
                raise ValueError("--labels is required for Kaplan-Meier plots")
            if not args.surv_time_var:
                raise ValueError("--surv_time_var is required for Kaplan-Meier plots")
            if not args.surv_event_var:
                raise ValueError("--surv_event_var is required for Kaplan-Meier plots")
            if not args.event_value:
                raise ValueError("--event_value is required for Kaplan-Meier plots")

        if args.plot_type in ['cox']:
            if not args.model:
                raise ValueError("--model is required when plot_type is 'cox'")
            if not os.path.isfile(args.model):
                raise FileNotFoundError(f"Model file not found: {args.model}")
            if not args.clinical_train:
                raise ValueError("--clinical_train is required when plot_type is 'cox'")
            if not os.path.isfile(args.clinical_train):
                raise FileNotFoundError(f"Training dataset file not found: {args.clinical_train}")
            if not args.clinical_test:
                raise ValueError("--clinical_test is required when plot_type is 'cox'")
            if not os.path.isfile(args.clinical_test):
                raise FileNotFoundError(f"Test dataset file not found: {args.clinical_test}")
            if not args.omics_train:
                raise ValueError("--omics_train is required when plot_type is 'cox'")
            if not os.path.isfile(args.omics_train):
                raise FileNotFoundError(f"Training omics dataset file not found: {args.omics_train}")
            if not args.omics_test:
                raise ValueError("--omics_test is required when plot_type is 'cox'")
            if not os.path.isfile(args.omics_test):
                raise FileNotFoundError(f"Test omics dataset file not found: {args.omics_test}")
            if not args.surv_time_var:
                raise ValueError("--surv_time_var is required for Cox plots")
            if not args.surv_event_var:
                raise ValueError("--surv_event_var is required for Cox plots")
            if not args.clinical_variables:
                raise ValueError("--clinical_variables is required for Cox plots")
            if not isinstance(args.top_features, int) or args.top_features <= 0:
                raise ValueError("--top_features must be a positive integer")
            if not args.event_value:
                raise ValueError("--event_value is required for Cox plots")
            if not isinstance(args.n_splits, int) or args.n_splits <= 0:
                raise ValueError("--n_splits must be a positive integer")
            if not isinstance(args.random_state, int):
                raise ValueError("--random_state must be an integer")

        if args.plot_type in ['scatter']:
            if not args.labels:
                raise ValueError("--labels is required for scatter plots")
            if not args.target_value:
                print("--target_value is not specified, using all unique variables from labels")
            if not os.path.isfile(args.labels):
                raise FileNotFoundError(f"Labels file not found: {args.labels}")

        if args.plot_type in ['concordance_heatmap']:
            if not args.labels:
                raise ValueError("--labels is required for concordance heatmaps")
            if not args.target_value:
                print("--target_value is not specified, using all unique variables from labels")
            if not os.path.isfile(args.labels):
                raise FileNotFoundError(f"Labels file not found: {args.labels}")

        if args.plot_type in ['pr_curve']:
            if not args.labels:
                raise ValueError("--labels is required for precision-recall curves")
            if not args.target_value:
                print("--target_value is not specified, using all unique variables from labels")
            if not os.path.isfile(args.labels):
                raise FileNotFoundError(f"Labels file not found: {args.labels}")

        if args.plot_type in ['roc_curve']:
            if not args.labels:
                raise ValueError("--labels is required for ROC curves")
            if not args.target_value:
                print("--target_value is not specified, using all unique variables from labels")
            if not os.path.isfile(args.labels):
                raise FileNotFoundError(f"Labels file not found: {args.labels}")

        if args.plot_type in ['box_plot']:
            if not args.labels:
                raise ValueError("--labels is required for box plots")
            if not args.target_value:
                print("--target_value is not specified, using all unique variables from labels")
            if not os.path.isfile(args.labels):
                raise FileNotFoundError(f"Labels file not found: {args.labels}")

        # Validate other arguments
        if args.method not in ['pca', 'umap']:
            raise ValueError("Method must be 'pca' or 'umap'")

        # Create the output directory
        output_dir = Path(args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        print(f"Output directory: {output_dir.absolute()}")

        # Generate the output filename base
        if args.output_name:
            output_name_base = args.output_name
        else:
            if args.plot_type == 'dimred':
                embeddings_name = Path(args.embeddings).stem
                output_name_base = f"{embeddings_name}_{args.method}"
            elif args.plot_type == 'kaplan_meier':
                survival_name = Path(args.survival_data).stem
                output_name_base = f"{survival_name}_km"
            elif args.plot_type == 'cox':
                model_name = Path(args.model).stem
                output_name_base = f"{model_name}_cox"
            elif args.plot_type == 'scatter':
                labels_name = Path(args.labels).stem
                output_name_base = f"{labels_name}_scatter"
            elif args.plot_type == 'concordance_heatmap':
                labels_name = Path(args.labels).stem
                output_name_base = f"{labels_name}_concordance"
            elif args.plot_type == 'pr_curve':
                labels_name = Path(args.labels).stem
                output_name_base = f"{labels_name}_pr_curves"
            elif args.plot_type == 'roc_curve':
                labels_name = Path(args.labels).stem
                output_name_base = f"{labels_name}_roc_curves"
            elif args.plot_type == 'box_plot':
                labels_name = Path(args.labels).stem
                output_name_base = f"{labels_name}_box_plot"

        # Generate plots based on the plot type
        if args.plot_type in ['dimred']:
            # Load labels
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)
            # Load embeddings data
            print(f"Loading embeddings from: {args.embeddings}")
            embeddings, sample_names = load_embeddings(args.embeddings)
            print(f"Embeddings shape: {embeddings.shape}")

            # Match samples to embeddings
            matched_labels = match_samples_to_embeddings(sample_names, label_data)
            print(f"Successfully matched {len(matched_labels)} samples for dimensionality reduction")

            generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base)

        elif args.plot_type in ['kaplan_meier']:
            # Load labels
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)
            # Load survival data
            print(f"Loading survival data from: {args.survival_data}")
            survival_data = load_survival_data(args.survival_data)
            print(f"Survival data shape: {survival_data.shape}")

            generate_km_plots(survival_data, label_data, args, output_dir, output_name_base)
        elif args.plot_type in ['cox']:
            # Load the model and datasets
            print(f"Loading model from: {args.model}")
            model = load_model(args.model)
            print(f"Loading training clinical dataset from: {args.clinical_train}")
            clinical_train = load_omics(args.clinical_train)
            print(f"Loading test clinical dataset from: {args.clinical_test}")
            clinical_test = load_omics(args.clinical_test)
            print(f"Loading training omics dataset from: {args.omics_train}")
            omics_train = load_omics(args.omics_train)
            print(f"Loading test omics dataset from: {args.omics_test}")
            omics_test = load_omics(args.omics_test)

            generate_cox_plots(model, clinical_train, clinical_test, omics_train, omics_test, args, output_dir, output_name_base)

        elif args.plot_type in ['scatter']:
            # Load labels
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)

            generate_plot_scatter(label_data, args, output_dir, output_name_base)

        elif args.plot_type in ['concordance_heatmap']:
            # Load labels
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)

            generate_label_concordance_heatmap(label_data, args, output_dir, output_name_base)

        elif args.plot_type in ['pr_curve']:
            # Load labels
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)

            generate_pr_curves(label_data, args, output_dir, output_name_base)

        elif args.plot_type in ['roc_curve']:
            # Load labels
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)

            generate_roc_curves(label_data, args, output_dir, output_name_base)

        elif args.plot_type in ['box_plot']:
            # Load labels
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)

            generate_box_plots(label_data, args, output_dir, output_name_base)

        print("All plots generated successfully!")

    except (FileNotFoundError, ValueError, pd.errors.ParserError) as e:
        print(f"Error: {e}")
        return 1

    return 0


if __name__ == "__main__":
    sys.exit(main())
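

# Example invocations for the label-based plot types (hypothetical file names):
#   python flexynesis_plot.py --plot_type kaplan_meier --labels predicted_labels.tsv \
#       --survival_data survival.tsv --surv_time_var OS_MONTHS \
#       --surv_event_var OS_STATUS --event_value '1:DECEASED'
#   python flexynesis_plot.py --plot_type roc_curve --labels predicted_labels.tsv
#   python flexynesis_plot.py --plot_type box_plot --labels predicted_labels.tsv \
#       --target_value SUBTYPE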