Mercurial > repos > bgruening > flexynesis
diff flexynesis_plot.py @ 6:33816f44fc7d draft default tip
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/flexynesis commit 6b520305ec30e6dc37eba92c67a5368cea0fc5ad
author | bgruening |
---|---|
date | Wed, 23 Jul 2025 07:49:41 +0000 |
parents | 466b593fd87e |
children |
line wrap: on
line diff
--- a/flexynesis_plot.py Fri Jul 04 14:57:40 2025 +0000 +++ b/flexynesis_plot.py Wed Jul 23 07:49:41 2025 +0000 @@ -11,10 +11,8 @@ import numpy as np import pandas as pd import seaborn as sns -import torch from flexynesis import ( build_cox_model, - get_important_features, plot_dim_reduced, plot_hazard_ratios, plot_kaplan_meier_curves, @@ -55,35 +53,13 @@ elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']: df = pd.read_csv(labels_input, sep='\t') - # Check if this is the specific format with sample_id, known_label, predicted_label - required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label'] - if all(col in df.columns for col in required_cols): - return df - else: - raise ValueError(f"Labels file {labels_input} does not contain required columns: {required_cols}") + print(f"available columns: {df.columns.tolist()}") + return df except Exception as e: raise ValueError(f"Error loading labels from {labels_input}: {e}") from e -def load_survival_data(survival_path): - """Load survival data from a file. First column should be sample_id""" - try: - # Determine file extension - file_ext = Path(survival_path).suffix.lower() - - if file_ext == '.csv': - df = pd.read_csv(survival_path, index_col=0) - elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']: - df = pd.read_csv(survival_path, sep='\t', index_col=0) - else: - raise ValueError(f"Unsupported file extension: {file_ext}") - return df - - except Exception as e: - raise ValueError(f"Error loading survival data from {survival_path}: {e}") from e - - def load_omics(omics_path): """Load omics data from a file. First column should be features""" try: @@ -102,19 +78,17 @@ raise ValueError(f"Error loading omics data from {omics_path}: {e}") from e -def load_model(model_path): - """Load flexynesis model from pickle file""" - try: - with open(model_path, 'rb') as f: - model = torch.load(f, weights_only=False) - return model - except Exception as e: - raise ValueError(f"Error loading model from {model_path}: {e}") from e +def match_samples_to_embeddings(sample_names, labels): + """Filter label data to match sample names in the embeddings""" + # Create a DataFrame from sample_names to preserve order + sample_df = pd.DataFrame({'sample_names': sample_names}) + # left_join + first_column = labels.columns[0] + df_matched = sample_df.merge(labels, left_on='sample_names', right_on=first_column, how='left') -def match_samples_to_embeddings(sample_names, label_data): - """Filter label data to match sample names in the embeddings""" - df_matched = label_data[label_data['sample_id'].isin(sample_names)] + # remove sample_names to keep the initial structure + df_matched = df_matched.drop('sample_names', axis=1) return df_matched @@ -214,124 +188,149 @@ def generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base): """Generate dimensionality reduction plots""" - # Parse target values from comma-separated string - if args.target_value: - target_values = [val.strip() for val in args.target_value.split(',')] - else: - # If no target values specified, use all unique variables - target_values = matched_labels['variable'].unique().tolist() + # Check if this is the specific format with sample_id, known_label, predicted_label + required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label'] + is_flexynesis_format = all(col in matched_labels.columns for col in required_cols) + + if not args.color: + if is_flexynesis_format: + print("Detected flexynesis labels format") + print(f"Generating {args.method.upper()} plots for known and predicted labels...") + else: + print("Labels are not in flexynesis format (Custom labels), please specify a color variable with --color") + + # Parse target values from comma-separated string + if args.target_value: + target_values = [val.strip() for val in args.target_value.split(',')] + else: + # If no target values specified, use all unique variables + target_values = matched_labels['variable'].unique().tolist() + + print(f"Generating {args.method.upper()} plots for {len(target_values)} target variable(s): {', '.join(target_values)}") - print(f"Generating {args.method.upper()} plots for {len(target_values)} target variable(s): {', '.join(target_values)}") + # Check variables + available_vars = matched_labels['variable'].unique() + missing_vars = [var for var in target_values if var not in available_vars] + + if missing_vars: + print(f"Warning: The following target variables were not found in the data: {', '.join(missing_vars)}") + print(f"Available variables: {', '.join(available_vars)}") + + # Filter to only process available variables + valid_vars = [var for var in target_values if var in available_vars] - # Check variables - available_vars = matched_labels['variable'].unique() - missing_vars = [var for var in target_values if var not in available_vars] + if not valid_vars: + raise ValueError(f"None of the specified target variables were found in the data. Available: {', '.join(available_vars)}") + + # Generate plots for each valid target variable + for var in valid_vars: + print(f"\nPlotting variable: {var}") - if missing_vars: - print(f"Warning: The following target variables were not found in the data: {', '.join(missing_vars)}") - print(f"Available variables: {', '.join(available_vars)}") + # Filter matched labels for current variable + var_labels = matched_labels[matched_labels['variable'] == var].copy() + var_labels = var_labels.drop_duplicates(subset='sample_id') + + if var_labels.empty: + print(f"Warning: No data found for variable '{var}', skipping...") + continue - # Filter to only process available variables - valid_vars = [var for var in target_values if var in available_vars] + # Auto-detect color type + known_color_type = detect_color_type(var_labels['known_label']) + predicted_color_type = detect_color_type(var_labels['predicted_label']) + + print(f" Auto-detected color types - Known: {known_color_type}, Predicted: {predicted_color_type}") - if not valid_vars: - raise ValueError(f"None of the specified target variables were found in the data. Available: {', '.join(available_vars)}") + try: + # Plot 1: Known labels + print(f" Creating known labels plot for {var}...") + fig_known = plot_dim_reduced( + matrix=embeddings, + labels=var_labels['known_label'], + method=args.method, + color_type=known_color_type + ) + + output_path_known = output_dir / f"{output_name_base}_{var}_known.{args.format}" + print(f" Saving known labels plot to: {output_path_known.name}") + fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight') - # Generate plots for each valid target variable - for var in valid_vars: - print(f"\nPlotting variable: {var}") + # Plot 2: Predicted labels + print(f" Creating predicted labels plot for {var}...") + fig_predicted = plot_dim_reduced( + matrix=embeddings, + labels=var_labels['predicted_label'], + method=args.method, + color_type=predicted_color_type + ) + + output_path_predicted = output_dir / f"{output_name_base}_{var}_predicted.{args.format}" + print(f" Saving predicted labels plot to: {output_path_predicted.name}") + fig_predicted.save(output_path_predicted, dpi=args.dpi, bbox_inches='tight') - # Filter matched labels for current variable - var_labels = matched_labels[matched_labels['variable'] == var].copy() - var_labels = var_labels.drop_duplicates(subset='sample_id') + print(f" ✓ Successfully created plots for variable '{var}'") + + except Exception as e: + print(f" ✗ Error creating plots for variable '{var}': {e}") + continue - if var_labels.empty: - print(f"Warning: No data found for variable '{var}', skipping...") - continue + print(f"\nDimensionality reduction plots completed for {len(valid_vars)} variable(s)!") + + else: + # check if the color variable exists in matched_labels + if args.color not in matched_labels.columns: + raise ValueError(f"Color variable '{args.color}' not found in matched labels. Available columns: {matched_labels.columns.tolist()}") # Auto-detect color type - known_color_type = detect_color_type(var_labels['known_label']) - predicted_color_type = detect_color_type(var_labels['predicted_label']) - - print(f" Auto-detected color types - Known: {known_color_type}, Predicted: {predicted_color_type}") + color_type = detect_color_type(matched_labels[args.color]) - try: - # Plot 1: Known labels - print(f" Creating known labels plot for {var}...") - fig_known = plot_dim_reduced( - matrix=embeddings, - labels=var_labels['known_label'], - method=args.method, - color_type=known_color_type - ) - - output_path_known = output_dir / f"{output_name_base}_{var}_known.{args.format}" - print(f" Saving known labels plot to: {output_path_known.name}") - fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight') + print(f" Auto-detected color type: {color_type}") - # Plot 2: Predicted labels - print(f" Creating predicted labels plot for {var}...") - fig_predicted = plot_dim_reduced( - matrix=embeddings, - labels=var_labels['predicted_label'], - method=args.method, - color_type=predicted_color_type - ) + # Plot: Specified color column + print(f" Creating plot for {args.color}...") + fig = plot_dim_reduced( + matrix=embeddings, + labels=matched_labels[args.color], + method=args.method, + color_type=color_type + ) - output_path_predicted = output_dir / f"{output_name_base}_{var}_predicted.{args.format}" - print(f" Saving predicted labels plot to: {output_path_predicted.name}") - fig_predicted.save(output_path_predicted, dpi=args.dpi, bbox_inches='tight') - - print(f" ✓ Successfully created plots for variable '{var}'") + output_path = output_dir / f"{output_name_base}_{args.color}.{args.format}" + print(f" Saving plot to: {output_path.name}") + fig.save(output_path, dpi=args.dpi, bbox_inches='tight') - except Exception as e: - print(f" ✗ Error creating plots for variable '{var}': {e}") - continue - - print(f"\nDimensionality reduction plots completed for {len(valid_vars)} variable(s)!") + print(f" ✓ Successfully created plot for variable '{args.color}'") -def generate_km_plots(survival_data, label_data, args, output_dir, output_name_base): +def generate_km_plots(survival_data, labels, args, output_dir, output_name_base): """Generate Kaplan-Meier plots""" + + # Check if this is the specific format with sample_id, known_label, predicted_label + required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label'] + is_flexynesis_format = all(col in labels.columns for col in required_cols) + + if not is_flexynesis_format: + raise ValueError(f"Labels are not in flexynesis format (Custom labels). Please provide a valid label file with the required columns, {required_cols}.") + print("Generating Kaplan-Meier curves of risk subtypes...") - # Reset index and rename the index column to sample_id - survival_data = survival_data.reset_index() if survival_data.columns[0] != 'sample_id': survival_data = survival_data.rename(columns={survival_data.columns[0]: 'sample_id'}) - # Convert survival event column to binary (0/1) based on event_value # Check if the event column exists if args.surv_event_var not in survival_data.columns: raise ValueError(f"Column '{args.surv_event_var}' not found in survival data") - # Convert to string for comparison to handle mixed types - survival_data[args.surv_event_var] = survival_data[args.surv_event_var].astype(str) - event_value_str = str(args.event_value) - - # Create binary event column (1 if matches event_value, 0 otherwise) - survival_data[f'{args.surv_event_var}_binary'] = ( - survival_data[args.surv_event_var] == event_value_str - ).astype(int) - - # Filter for survival category and class_label == '1:DECEASED' - label_data['class_label'] = label_data['class_label'].astype(str) - - label_data = label_data[(label_data['variable'] == args.surv_event_var) & (label_data['class_label'] == event_value_str)] - - # check survival data - for col in [args.surv_time_var, args.surv_event_var]: - if col not in survival_data.columns: - raise ValueError(f"Column '{col}' not found in survival data") + labels = labels[(labels['variable'] == args.surv_event_var)] # Merge survival data with labels - df_deceased = pd.merge(survival_data, label_data, on='sample_id', how='inner') + df_deceased = pd.merge(survival_data, labels, on='sample_id', how='inner') + df_deceased = df_deceased.dropna(subset=[args.surv_time_var, args.surv_event_var]) if df_deceased.empty: raise ValueError("No matching samples found after merging survival and label data.") # Get risk scores - risk_scores = df_deceased['probability'].values + risk_scores = df_deceased['predicted_label'].values # Compute groups (e.g., median split) quantiles = np.quantile(risk_scores, [0.5]) @@ -340,7 +339,7 @@ fig_known = plot_kaplan_meier_curves( durations=df_deceased[args.surv_time_var], - events=df_deceased[f'{args.surv_event_var}_binary'], + events=df_deceased[args.surv_event_var], categorical_variable=group_labels ) @@ -351,12 +350,21 @@ print("Kaplan-Meier plot saved successfully!") -def generate_cox_plots(model, clinical_train, clinical_test, omics_train, omics_test, args, output_dir, output_name_base): +def generate_cox_plots(important_features, clinical_train, clinical_test, omics_train, omics_test, args, output_dir, output_name_base): """Generate Cox proportional hazards plots""" print("Generating Cox proportional hazards analysis...") + # Check if this is the specific format with target_variable, importance + required_cols = ['target_variable', 'layer', 'importance'] + is_flexynesis_format = all(col in important_features.columns for col in required_cols) + + if not is_flexynesis_format: + raise ValueError(f"Labels are not in flexynesis format (Custom labels). Please provide a valid important_features file with the required columns, {required_cols}.") + # Parse clinical variables - clinical_vars = [var.strip() for var in args.clinical_variables.split(',')] + clinical_vars = [] + if args.clinical_variables: + clinical_vars = [var.strip() for var in args.clinical_variables.split(',')] # Validate that survival variables are included required_vars = [args.surv_time_var, args.surv_event_var] @@ -382,10 +390,22 @@ # Get top survival markers print(f"Extracting top {args.top_features} important features for {args.surv_event_var}...") try: - imp = get_important_features(model, - var=args.surv_event_var, - top=args.top_features - )['name'].unique().tolist() + print(f"Loading {args.top_features} important features from: {args.important_features}") + imp_features = load_labels(args.important_features) + imp_features = imp_features[imp_features['target_variable'] == args.surv_event_var] + if args.layer not in imp_features['layer'].unique(): + print(f"Available class labels: {imp_features['layer'].unique()}") + raise ValueError(f"Class label '{args.layer}' not found in important features data: {args.important_features}") + imp_features = imp_features[imp_features['layer'] == args.layer] + if imp_features.empty: + raise ValueError(f"No important features found for target variable '{args.surv_event_var}' in {args.important_features}") + imp_features = imp_features.sort_values(by='importance', ascending=False) + + if len(imp_features) < args.top_features: + raise ValueError(f"Requested top {args.top_features} features, but only {len(imp_features)} available in {args.important_features}") + + imp = imp_features['name'].unique().tolist()[0:args.top_features] + print(f"Top features: {', '.join(imp)}") except Exception as e: raise ValueError(f"Error getting important features: {e}") @@ -418,15 +438,6 @@ if df.empty: raise ValueError("No samples remain after filtering for survival data") - # Convert survival event column to binary (0/1) based on event_value - # Convert to string for comparison to handle mixed types - df[args.surv_event_var] = df[args.surv_event_var].astype(str) - event_value_str = str(args.event_value) - - df[f'{args.surv_event_var}'] = ( - df[args.surv_event_var] == event_value_str - ).astype(int) - # Build Cox model print(f"Building Cox model with time variable: {args.surv_time_var}, event variable: {args.surv_event_var}") try: @@ -459,124 +470,190 @@ """Generate scatter plot of known vs predicted labels""" print("Generating scatter plots of known vs predicted labels...") - # Parse target values from comma-separated string - if args.target_value: - target_values = [val.strip() for val in args.target_value.split(',')] - else: - # If no target values specified, use all unique variables - target_values = labels['variable'].unique().tolist() + # Check if this is the specific format with sample_id, known_label, predicted_label + required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label'] + is_flexynesis_format = all(col in labels.columns for col in required_cols) + + if is_flexynesis_format: + # Parse target values from comma-separated string + if args.target_value: + target_values = [val.strip() for val in args.target_value.split(',')] + else: + # If no target values specified, use all unique variables + target_values = labels['variable'].unique().tolist() + + print(f"Processing target values: {target_values}") - print(f"Processing target values: {target_values}") + successful_plots = 0 + skipped_plots = 0 + + for target_value in target_values: + print(f"\nProcessing target value: '{target_value}'") + + # Filter labels for the current target value + target_labels = labels[labels['variable'] == target_value] + + if target_labels.empty: + print(f" Warning: No data found for target value '{target_value}' - skipping") + skipped_plots += 1 + continue + + # Check if labels are numeric and convert + true_values = pd.to_numeric(target_labels['known_label'], errors='coerce') + predicted_values = pd.to_numeric(target_labels['predicted_label'], errors='coerce') - successful_plots = 0 - skipped_plots = 0 + if true_values.isna().all() or predicted_values.isna().all(): + print(f"No valid numeric values found for known or predicted labels in '{target_value}'") + skipped_plots += 1 + continue + + try: + print(f" Generating scatter plot for '{target_value}'...") + fig = plot_scatter(true_values, predicted_values) - for target_value in target_values: - print(f"\nProcessing target value: '{target_value}'") + # Create output filename with target value + safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_') + output_filename = f"{output_name_base}_{safe_target_name}.{args.format}" + + output_path = output_dir / output_filename + print(f" Saving scatter plot to: {output_path.absolute()}") + fig.save(output_path, dpi=args.dpi, bbox_inches='tight') + + successful_plots += 1 + print(f" Scatter plot for '{target_value}' generated successfully!") - # Filter labels for the current target value - target_labels = labels[labels['variable'] == target_value] + except Exception as e: + print(f" Error generating plot for '{target_value}': {str(e)}") + skipped_plots += 1 + + # Summary + print(" Summary:") + print(f" Successfully generated: {successful_plots} plots") + print(f" Skipped: {skipped_plots} plots") - if target_labels.empty: - print(f" Warning: No data found for target value '{target_value}' - skipping") - skipped_plots += 1 - continue + if successful_plots == 0: + raise ValueError("No scatter plots could be generated. Check your data and target values.") + + print("Scatter plot generation completed!") + + if not is_flexynesis_format: + print("Labels are not in flexynesis format (Custom labels)") + + if not args.true_label or not args.predicted_label: + raise ValueError("For custom labels, please specify --true_label and --predicted_label arguments.") # Check if labels are numeric and convert - true_values = pd.to_numeric(target_labels['known_label'], errors='coerce') - predicted_values = pd.to_numeric(target_labels['predicted_label'], errors='coerce') + true_values = pd.to_numeric(labels[args.true_label], errors='coerce') + predicted_values = pd.to_numeric(labels[args.predicted_label], errors='coerce') if true_values.isna().all() or predicted_values.isna().all(): - print(f"No valid numeric values found for known or predicted labels in '{target_value}'") - skipped_plots += 1 - continue + print("No valid numeric values found for known or predicted labels") try: - print(f" Generating scatter plot for '{target_value}'...") + print(" Generating scatter plot...") fig = plot_scatter(true_values, predicted_values) # Create output filename with target value - safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_') - if len(target_values) > 1: - output_filename = f"{output_name_base}_{safe_target_name}.{args.format}" - else: - output_filename = f"{output_name_base}.{args.format}" + output_filename = f"{output_name_base}.{args.format}" output_path = output_dir / output_filename print(f" Saving scatter plot to: {output_path.absolute()}") fig.save(output_path, dpi=args.dpi, bbox_inches='tight') - successful_plots += 1 - print(f" Scatter plot for '{target_value}' generated successfully!") - except Exception as e: - print(f" Error generating plot for '{target_value}': {str(e)}") - skipped_plots += 1 + print(f" Error generating plot: {str(e)}") - # Summary - print(" Summary:") - print(f" Successfully generated: {successful_plots} plots") - print(f" Skipped: {skipped_plots} plots") - - if successful_plots == 0: - raise ValueError("No scatter plots could be generated. Check your data and target values.") - - print("Scatter plot generation completed!") + print("Scatter plot generation completed!") def generate_label_concordance_heatmap(labels, args, output_dir, output_name_base): """Generate label concordance heatmap""" print("Generating label concordance heatmaps...") - # Parse target values from comma-separated string - if args.target_value: - target_values = [val.strip() for val in args.target_value.split(',')] - else: - # If no target values specified, use all unique variables - target_values = labels['variable'].unique().tolist() + # Check if this is the specific format with sample_id, known_label, predicted_label + required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label'] + is_flexynesis_format = all(col in labels.columns for col in required_cols) + + if is_flexynesis_format: + # Parse target values from comma-separated string + if args.target_value: + target_values = [val.strip() for val in args.target_value.split(',')] + else: + # If no target values specified, use all unique variables + target_values = labels['variable'].unique().tolist() - print(f"Processing target values: {target_values}") + print(f"Processing target values: {target_values}") + + for target_value in target_values: + print(f"\nProcessing target value: '{target_value}'") + + # Filter labels for the current target value + target_labels = labels[labels['variable'] == target_value] + + if target_labels.empty: + print(f" Warning: No data found for target value '{target_value}' - skipping") + continue + + true_values = target_labels['known_label'].tolist() + predicted_values = target_labels['predicted_label'].tolist() - for target_value in target_values: - print(f"\nProcessing target value: '{target_value}'") + try: + print(f" Generating heatmap for '{target_value}'...") + fig = plot_label_concordance_heatmap(true_values, predicted_values) + plt.close(fig) - # Filter labels for the current target value - target_labels = labels[labels['variable'] == target_value] + # Create output filename with target value + safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_') + output_filename = f"{output_name_base}_{safe_target_name}.{args.format}" + + output_path = output_dir / output_filename + print(f" Saving heatmap to: {output_path.absolute()}") + fig.savefig(output_path, dpi=args.dpi, bbox_inches='tight') - if target_labels.empty: - print(f" Warning: No data found for target value '{target_value}' - skipping") - continue + except Exception as e: + print(f" Error generating heatmap for '{target_value}': {str(e)}") + continue + + print("Label concordance heatmap generated successfully!") - true_values = target_labels['known_label'].tolist() - predicted_values = target_labels['predicted_label'].tolist() + if not is_flexynesis_format: + print("Labels are not in flexynesis format (Custom labels)") + + if not args.true_label or not args.predicted_label: + raise ValueError("For custom labels, please specify --true_label and --predicted_label arguments.") + + true_values = labels[args.true_label].tolist() + predicted_values = labels[args.predicted_label].tolist() try: - print(f" Generating heatmap for '{target_value}'...") + print(" Generating heatmap for...") fig = plot_label_concordance_heatmap(true_values, predicted_values) plt.close(fig) # Create output filename with target value - safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_') - if len(target_values) > 1: - output_filename = f"{output_name_base}_{safe_target_name}.{args.format}" - else: - output_filename = f"{output_name_base}.{args.format}" + output_filename = f"{output_name_base}.{args.format}" output_path = output_dir / output_filename print(f" Saving heatmap to: {output_path.absolute()}") fig.savefig(output_path, dpi=args.dpi, bbox_inches='tight') except Exception as e: - print(f" Error generating heatmap for '{target_value}': {str(e)}") - continue + print(f" Error generating heatmap': {str(e)}") - print("Label concordance heatmap generated successfully!") + print("Label concordance heatmap generated successfully!") def generate_pr_curves(labels, args, output_dir, output_name_base): """Generate precision-recall curves""" print("Generating precision-recall curves...") + # Check if this is the specific format with sample_id, known_label, predicted_label + required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label'] + is_flexynesis_format = all(col in labels.columns for col in required_cols) + + if not is_flexynesis_format: + raise ValueError(f"Labels are not in flexynesis format (Custom labels). Please provide a valid label file with the required columns, {required_cols}.") + # Parse target values from comma-separated string if args.target_value: target_values = [val.strip() for val in args.target_value.split(',')] @@ -707,10 +784,7 @@ # Create output filename with target value safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_') - if len(target_values) > 1: - output_filename = f"{output_name_base}_{safe_target_name}.{args.format}" - else: - output_filename = f"{output_name_base}.{args.format}" + output_filename = f"{output_name_base}_{safe_target_name}.{args.format}" output_path = output_dir / output_filename print(f" Saving precision-recall curve to: {output_path.absolute()}") @@ -729,6 +803,13 @@ """Generate ROC curves""" print("Generating ROC curves...") + # Check if this is the specific format with sample_id, known_label, predicted_label + required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label'] + is_flexynesis_format = all(col in labels.columns for col in required_cols) + + if not is_flexynesis_format: + raise ValueError(f"Labels are not in flexynesis format (Custom labels). Please provide a valid label file with the required columns, {required_cols}.") + # Parse target values from comma-separated string if args.target_value: target_values = [val.strip() for val in args.target_value.split(',')] @@ -859,10 +940,7 @@ # Create output filename with target value safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_') - if len(target_values) > 1: - output_filename = f"{output_name_base}_{safe_target_name}.{args.format}" - else: - output_filename = f"{output_name_base}.{args.format}" + output_filename = f"{output_name_base}_{safe_target_name}.{args.format}" output_path = output_dir / output_filename print(f" Saving ROC curve to: {output_path.absolute()}") @@ -880,6 +958,13 @@ def generate_box_plots(labels, args, output_dir, output_name_base): """Generate box plots for model predictions""" + # Check if this is the specific format with sample_id, known_label, predicted_label + required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label'] + is_flexynesis_format = all(col in labels.columns for col in required_cols) + + if not is_flexynesis_format: + raise ValueError(f"Labels are not in flexynesis format (Custom labels). Please provide a valid label file with the required columns, {required_cols}.") + print("Generating box plots...") # Parse target values from comma-separated string @@ -993,6 +1078,8 @@ help="Path to input data embeddings file (CSV or tabular format). Required for dimred plots.") parser.add_argument("--method", type=str, default='pca', choices=['pca', 'umap'], help="Transformation method ('pca' or 'umap'). Default is 'pca'. Used for dimred plots.") + parser.add_argument("--color", type=str, default=None, + help="User-defined color for the plot.") # Arguments for Kaplan-Meier parser.add_argument("--survival_data", type=str, @@ -1001,12 +1088,10 @@ help="Column name for survival time") parser.add_argument("--surv_event_var", type=str, required=False, help="Column name for survival event") - parser.add_argument("--event_value", type=str, required=False, - help="Value in event column that represents an event (e.g., 'DECEASED')") # Arguments for Cox analysis - parser.add_argument("--model", type=str, - help="Path to trained flexynesis model (pickle file). Required for cox plots.") + parser.add_argument("--important_features", type=str, + help="Path to calculated feature importance file. Required for cox plots.") parser.add_argument("--clinical_train", type=str, help="Path to training dataset (pickle file). Required for cox plots.") parser.add_argument("--clinical_test", type=str, @@ -1025,11 +1110,18 @@ help="Number of folds for cross-validation. Default is 5") parser.add_argument("--random_state", type=int, default=42, help="Random seed for reproducibility. Default is 42") + parser.add_argument("--layer", type=str, default=None, + help="Class label for filtering important features.") # Arguments for dimred, scatter plot, heatmap, PR curves, ROC curves, and box plots parser.add_argument("--target_value", type=str, default=None, help="Target value for scatter plot.") + # Arguments for scatter plots and concordance heatmaps + parser.add_argument("--true_label", type=str, default=None, + help="Column name for true labels in scatter plots and concordance heatmaps.") + parser.add_argument("--predicted_label", type=str, default=None, + help="Column name for predicted labels in scatter plots and concordance heatmaps.") # Common arguments parser.add_argument("--output_dir", type=str, default='output', help="Output directory. Default is 'output'") @@ -1073,14 +1165,12 @@ raise ValueError("--surv_time_var is required for Kaplan-Meier plots") if not args.surv_event_var: raise ValueError("--surv_event_var is required for Kaplan-Meier plots") - if not args.event_value: - raise ValueError("--event_value is required for Kaplan-Meier plots") if args.plot_type in ['cox']: - if not args.model: - raise ValueError("--model is required when plot_type is 'cox'") - if not os.path.isfile(args.model): - raise FileNotFoundError(f"Model file not found: {args.model}") + if not args.important_features: + raise ValueError("--important_features is required when plot_type is 'cox'") + if not os.path.isfile(args.important_features): + raise FileNotFoundError(f"Important features file not found: {args.important_features}") if not args.clinical_train: raise ValueError("--clinical_train is required when plot_type is 'cox'") if not os.path.isfile(args.clinical_train): @@ -1102,17 +1192,17 @@ if not args.surv_event_var: raise ValueError("--surv_event_var is required for Cox plots") if not args.clinical_variables: - raise ValueError("--clinical_variables is required for Cox plots") + print("--clinical_variables is not set for Cox plots") if not isinstance(args.top_features, int) or args.top_features <= 0: raise ValueError("--top_features must be a positive integer") - if not args.event_value: - raise ValueError("--event_value is required for Kaplan-Meier plots") if not args.crossval: args.crossval = False if not isinstance(args.n_splits, int) or args.n_splits <= 0: raise ValueError("--n_splits must be a positive integer") if not isinstance(args.random_state, int): raise ValueError("--random_state must be an integer") + if not args.layer: + print("--layer is not specified, using all classes from labels") if args.plot_type in ['scatter']: if not args.labels: @@ -1174,7 +1264,7 @@ survival_name = Path(args.survival_data).stem output_name_base = f"{survival_name}_km" elif args.plot_type == 'cox': - model_name = Path(args.model).stem + model_name = Path(args.important_features).stem output_name_base = f"{model_name}_cox" elif args.plot_type == 'scatter': labels_name = Path(args.labels).stem @@ -1196,33 +1286,34 @@ if args.plot_type in ['dimred']: # Load labels print(f"Loading labels from: {args.labels}") - label_data = load_labels(args.labels) + labels = load_labels(args.labels) # Load embeddings data print(f"Loading embeddings from: {args.embeddings}") embeddings, sample_names = load_embeddings(args.embeddings) print(f"embeddings shape: {embeddings.shape}") # Match samples to embeddings - matched_labels = match_samples_to_embeddings(sample_names, label_data) + matched_labels = match_samples_to_embeddings(sample_names, labels) print(f"Successfully matched {len(matched_labels)} samples for dimensionality reduction") - + print(f"Matched labels shape: {matched_labels.shape}") + print(f"Columns in matched labels: {matched_labels.columns.tolist()}") generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base) elif args.plot_type in ['kaplan_meier']: # Load labels print(f"Loading labels from: {args.labels}") - label_data = load_labels(args.labels) + labels = load_labels(args.labels) # Load survival data print(f"Loading survival data from: {args.survival_data}") - survival_data = load_survival_data(args.survival_data) + survival_data = load_labels(args.survival_data) print(f"Survival data shape: {survival_data.shape}") - generate_km_plots(survival_data, label_data, args, output_dir, output_name_base) + generate_km_plots(survival_data, labels, args, output_dir, output_name_base) elif args.plot_type in ['cox']: - # Load model and datasets - print(f"Loading model from: {args.model}") - model = load_model(args.model) + # Load important_features and datasets + print(f"Loading important features from: {args.important_features}") + important_features = load_labels(args.important_features) print(f"Loading training dataset from: {args.clinical_train}") clinical_train = load_omics(args.clinical_train) print(f"Loading test dataset from: {args.clinical_test}") @@ -1232,42 +1323,42 @@ print(f"Loading test omics dataset from: {args.omics_test}") omics_test = load_omics(args.omics_test) - generate_cox_plots(model, clinical_train, clinical_test, omics_test, omics_train, args, output_dir, output_name_base) + generate_cox_plots(important_features, clinical_train, clinical_test, omics_test, omics_train, args, output_dir, output_name_base) elif args.plot_type in ['scatter']: # Load labels print(f"Loading labels from: {args.labels}") - label_data = load_labels(args.labels) + labels = load_labels(args.labels) - generate_plot_scatter(label_data, args, output_dir, output_name_base) + generate_plot_scatter(labels, args, output_dir, output_name_base) elif args.plot_type in ['concordance_heatmap']: # Load labels print(f"Loading labels from: {args.labels}") - label_data = load_labels(args.labels) + labels = load_labels(args.labels) - generate_label_concordance_heatmap(label_data, args, output_dir, output_name_base) + generate_label_concordance_heatmap(labels, args, output_dir, output_name_base) elif args.plot_type in ['pr_curve']: # Load labels print(f"Loading labels from: {args.labels}") - label_data = load_labels(args.labels) + labels = load_labels(args.labels) - generate_pr_curves(label_data, args, output_dir, output_name_base) + generate_pr_curves(labels, args, output_dir, output_name_base) elif args.plot_type in ['roc_curve']: # Load labels print(f"Loading labels from: {args.labels}") - label_data = load_labels(args.labels) + labels = load_labels(args.labels) - generate_roc_curves(label_data, args, output_dir, output_name_base) + generate_roc_curves(labels, args, output_dir, output_name_base) elif args.plot_type in ['box_plot']: # Load labels print(f"Loading labels from: {args.labels}") - label_data = load_labels(args.labels) + labels = load_labels(args.labels) - generate_box_plots(label_data, args, output_dir, output_name_base) + generate_box_plots(labels, args, output_dir, output_name_base) print("All plots generated successfully!")