flexynesis_plot.py @ 3:525c661a7fdc (draft, default, tip)

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/flexynesis commit b2463fb68d0ae54864d87718ee72f5e063aa4587

author: bgruening
date:   Tue, 24 Jun 2025 05:55:40 +0000
#!/usr/bin/env python
"""Generate plots using flexynesis.

This script generates dimensionality reduction plots, Kaplan-Meier survival curves,
and Cox proportional hazards models from data processed by flexynesis."""
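# Example invocations (illustrative only; file names and column values are
# hypothetical):
#   python flexynesis_plot.py --plot_type dimred --embeddings embeddings.tsv \
#       --labels predicted_labels.tsv --target_variables SUBTYPE --method umap
#   python flexynesis_plot.py --plot_type kaplan_meier --labels predicted_labels.tsv \
#       --survival_data clin.tsv --surv_time_var OS_MONTHS \
#       --surv_event_var OS_STATUS --event_value '1:DECEASED'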

import argparse
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from flexynesis import (
    build_cox_model,
    get_important_features,
    plot_dim_reduced,
    plot_hazard_ratios,
    plot_kaplan_meier_curves,
    plot_pr_curves,
    plot_roc_curves,
    plot_scatter
)
from scipy.stats import kruskal, mannwhitneyu


def load_embeddings(embeddings_path):
    """Load embeddings from a file."""
    try:
        # Determine file extension
        file_ext = Path(embeddings_path).suffix.lower()

        if file_ext == '.csv':
            df = pd.read_csv(embeddings_path, index_col=0)
        elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']:
            df = pd.read_csv(embeddings_path, sep='\t', index_col=0)
        else:
            raise ValueError(f"Unsupported file extension: {file_ext}")

        return df, df.index.tolist()

    except Exception as e:
        raise ValueError(f"Error loading embeddings from {embeddings_path}: {e}") from e


def load_labels(labels_input):
    """Load predicted labels from flexynesis."""
    try:
        # Determine file extension
        file_ext = Path(labels_input).suffix.lower()

        if file_ext == '.csv':
            df = pd.read_csv(labels_input)
        elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']:
            df = pd.read_csv(labels_input, sep='\t')
        else:
            raise ValueError(f"Unsupported file extension: {file_ext}")

        # Check for the long-format columns produced by flexynesis
        required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label']
        if all(col in df.columns for col in required_cols):
            return df
        else:
            raise ValueError(f"Labels file {labels_input} does not contain required columns: {required_cols}")

    except Exception as e:
        raise ValueError(f"Error loading labels from {labels_input}: {e}") from e

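# The labels table is expected in long format, one row per (sample, variable,
# class). An illustrative sketch (values are hypothetical):
#   sample_id  variable  class_label  probability  known_label  predicted_label
#   s1         SUBTYPE   LumA         0.91         LumA         LumA
#   s1         SUBTYPE   Basal        0.09         LumA         LumA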

def load_survival_data(survival_path):
    """Load survival data from a file. First column should be sample_id."""
    try:
        # Determine file extension
        file_ext = Path(survival_path).suffix.lower()

        if file_ext == '.csv':
            df = pd.read_csv(survival_path, index_col=0)
        elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']:
            df = pd.read_csv(survival_path, sep='\t', index_col=0)
        else:
            raise ValueError(f"Unsupported file extension: {file_ext}")
        return df

    except Exception as e:
        raise ValueError(f"Error loading survival data from {survival_path}: {e}") from e


def load_omics(omics_path):
    """Load omics data from a file. First column should be features."""
    try:
        # Determine file extension
        file_ext = Path(omics_path).suffix.lower()

        if file_ext == '.csv':
            df = pd.read_csv(omics_path, index_col=0)
        elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']:
            df = pd.read_csv(omics_path, sep='\t', index_col=0)
        else:
            raise ValueError(f"Unsupported file extension: {file_ext}")
        return df

    except Exception as e:
        raise ValueError(f"Error loading omics data from {omics_path}: {e}") from e


def load_model(model_path):
    """Load flexynesis model from pickle file."""
    try:
        with open(model_path, 'rb') as f:
            model = torch.load(f, weights_only=False)
        return model
    except Exception as e:
        raise ValueError(f"Error loading model from {model_path}: {e}") from e


def match_samples_to_embeddings(sample_names, label_data):
    """Filter label data to match sample names in the embeddings."""
    df_matched = label_data[label_data['sample_id'].isin(sample_names)]
    return df_matched


def detect_color_type(labels_series):
    """Auto-detect whether target variables should be treated as categorical or numerical."""
    # Remove NaN
    clean_labels = labels_series.dropna()

    if clean_labels.empty:
        return 'categorical'  # default output if no labels

    # Check if all values can be converted to numbers
    try:
        numeric_labels = pd.to_numeric(clean_labels, errors='coerce')

        # If conversion failed -> categorical
        if numeric_labels.isna().any():
            return 'categorical'

        # Check number of unique values
        unique_count = len(clean_labels.unique())
        total_count = len(clean_labels)

        # If few unique values relative to total -> categorical
        # Threshold: if unique values < 10 OR unique/total < 0.1
        if unique_count < 10 or (unique_count / total_count) < 0.1:
            return 'categorical'
        else:
            return 'numerical'

    except Exception:
        return 'categorical'

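# Illustrative behaviour of the heuristic above (examples, not exhaustive):
#   detect_color_type(pd.Series(['LumA', 'Basal', 'Her2']))  -> 'categorical'
#   detect_color_type(pd.Series(range(100)))                 -> 'numerical'
#   detect_color_type(pd.Series([0, 1] * 50))                -> 'categorical'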

def plot_label_concordance_heatmap(labels1, labels2, figsize=(12, 10)):
    """
    Plot a heatmap reflecting the concordance between two sets of labels using pandas crosstab.

    Parameters:
    - labels1: The first set of labels.
    - labels2: The second set of labels.
    """
    # Compute the cross-tabulation
    ct = pd.crosstab(pd.Series(labels1, name='Labels Set 1'), pd.Series(labels2, name='Labels Set 2'))
    # Normalize the cross-tabulation matrix row-wise, so each row sums to 1
    ct_normalized = ct.div(ct.sum(axis=1), axis=0)

    # Plot the heatmap
    plt.figure(figsize=figsize)
    sns.heatmap(ct_normalized, annot=True, cmap='viridis', linewidths=.5)
    plt.title('Concordance between label groups')

    return plt.gcf()

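# Illustrative row-normalized crosstab: labels1 = ['A', 'A', 'B'] versus
# labels2 = ['A', 'B', 'B'] yields row A -> [0.5, 0.5] and row B -> [0.0, 1.0].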

def plot_boxplot(categorical_x, numerical_y, title_x='Categories', title_y='Values', figsize=(10, 6), jittersize=4):
    """
    Create a boxplot with jittered points to visualize the distribution of predicted
    probabilities across categories: the x axis shows the true labels and the y axis
    shows the predicted probabilities for a specific class.
    """
    df = pd.DataFrame({title_x: categorical_x, title_y: numerical_y})

    # Compute p-value
    groups = df[title_x].unique()
    if len(groups) == 2:
        group1 = df[df[title_x] == groups[0]][title_y]
        group2 = df[df[title_x] == groups[1]][title_y]
        stat, p = mannwhitneyu(group1, group2, alternative='two-sided')
        test_name = "Mann-Whitney U"
    else:
        group_data = [df[df[title_x] == group][title_y] for group in groups]
        stat, p = kruskal(*group_data)
        test_name = "Kruskal-Wallis"

    # Create a boxplot with jittered points
    plt.figure(figsize=figsize)
    sns.boxplot(x=title_x, y=title_y, hue=title_x, data=df, palette='Set2', legend=False, fill=False)
    sns.stripplot(x=title_x, y=title_y, data=df, color='black', size=jittersize, jitter=True, dodge=True, alpha=0.4)

    # Labels and p-value annotation
    plt.xlabel(title_x)
    plt.ylabel(title_y)
    plt.text(
        x=-0.4,
        y=plt.ylim()[1],
        s=f'{test_name} p = {p:.3e}',
        verticalalignment='top',
        horizontalalignment='left',
        fontsize=12,
        bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor='gray')
    )

    plt.tight_layout()
    return plt.gcf()

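# Minimal usage sketch for plot_boxplot (inputs are hypothetical):
#   fig = plot_boxplot(
#       categorical_x=pd.Series(['A', 'A', 'B', 'B']),
#       numerical_y=pd.Series([0.9, 0.8, 0.2, 0.1]),
#       title_x='True Label',
#       title_y='Predicted Probability',
#   )
#   fig.savefig('boxplot.png', dpi=300)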

def generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base):
    """Generate dimensionality reduction plots"""

    # Parse target variables
    target_vars = [var.strip() for var in args.target_variables.split(',')]

    print(f"Generating {args.method.upper()} plots for {len(target_vars)} target variable(s): {', '.join(target_vars)}")

    # Check variables
    available_vars = matched_labels['variable'].unique()
    missing_vars = [var for var in target_vars if var not in available_vars]

    if missing_vars:
        print(f"Warning: The following target variables were not found in the data: {', '.join(missing_vars)}")
        print(f"Available variables: {', '.join(available_vars)}")

    # Filter to only process available variables
    valid_vars = [var for var in target_vars if var in available_vars]

    if not valid_vars:
        raise ValueError(f"None of the specified target variables were found in the data. Available: {', '.join(available_vars)}")

    # Generate plots for each valid target variable
    for var in valid_vars:
        print(f"\nPlotting variable: {var}")

        # Filter matched labels for current variable
        var_labels = matched_labels[matched_labels['variable'] == var].copy()
        var_labels = var_labels.drop_duplicates(subset='sample_id')

        if var_labels.empty:
            print(f"Warning: No data found for variable '{var}', skipping...")
            continue

        # Auto-detect color type
        known_color_type = detect_color_type(var_labels['known_label'])
        predicted_color_type = detect_color_type(var_labels['predicted_label'])

        print(f"  Auto-detected color types - Known: {known_color_type}, Predicted: {predicted_color_type}")

        try:
            # Plot 1: Known labels
            print(f"  Creating known labels plot for {var}...")
            fig_known = plot_dim_reduced(
                matrix=embeddings,
                labels=var_labels['known_label'],
                method=args.method,
                color_type=known_color_type
            )

            output_path_known = output_dir / f"{output_name_base}_{var}_known.{args.format}"
            print(f"  Saving known labels plot to: {output_path_known.name}")
            fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight')

            # Plot 2: Predicted labels
            print(f"  Creating predicted labels plot for {var}...")
            fig_predicted = plot_dim_reduced(
                matrix=embeddings,
                labels=var_labels['predicted_label'],
                method=args.method,
                color_type=predicted_color_type
            )

            output_path_predicted = output_dir / f"{output_name_base}_{var}_predicted.{args.format}"
            print(f"  Saving predicted labels plot to: {output_path_predicted.name}")
            fig_predicted.save(output_path_predicted, dpi=args.dpi, bbox_inches='tight')

            print(f"  ✓ Successfully created plots for variable '{var}'")

        except Exception as e:
            print(f"  ✗ Error creating plots for variable '{var}': {e}")
            continue

    print(f"\nDimensionality reduction plots completed for {len(valid_vars)} variable(s)!")


def generate_km_plots(survival_data, label_data, args, output_dir, output_name_base):
    """Generate Kaplan-Meier plots"""
    print("Generating Kaplan-Meier curves of risk subtypes...")

    # Reset index and rename the index column to sample_id
    survival_data = survival_data.reset_index()
    if survival_data.columns[0] != 'sample_id':
        survival_data = survival_data.rename(columns={survival_data.columns[0]: 'sample_id'})

    # Convert survival event column to binary (0/1) based on event_value
    # Check if the event column exists
    if args.surv_event_var not in survival_data.columns:
        raise ValueError(f"Column '{args.surv_event_var}' not found in survival data")

    # Convert to string for comparison to handle mixed types
    survival_data[args.surv_event_var] = survival_data[args.surv_event_var].astype(str)
    event_value_str = str(args.event_value)

    # Create binary event column (1 if matches event_value, 0 otherwise)
    survival_data[f'{args.surv_event_var}_binary'] = (
        survival_data[args.surv_event_var] == event_value_str
    ).astype(int)
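    # e.g. with --event_value '1:DECEASED', rows whose event column equals the
    # string '1:DECEASED' become 1 and everything else becomes 0.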

    # Keep rows for the survival event variable whose class_label matches the event value
    label_data['class_label'] = label_data['class_label'].astype(str)

    label_data = label_data[(label_data['variable'] == args.surv_event_var) & (label_data['class_label'] == event_value_str)]

    # Check survival data
    for col in [args.surv_time_var, args.surv_event_var]:
        if col not in survival_data.columns:
            raise ValueError(f"Column '{col}' not found in survival data")

    # Merge survival data with labels
    df_deceased = pd.merge(survival_data, label_data, on='sample_id', how='inner')

    if df_deceased.empty:
        raise ValueError("No matching samples found after merging survival and label data.")

    # Get risk scores
    risk_scores = df_deceased['probability'].values

    # Compute groups (e.g., median split)
    quantiles = np.quantile(risk_scores, [0.5])
    groups = np.digitize(risk_scores, quantiles)
    group_labels = ['low_risk' if g == 0 else 'high_risk' for g in groups]
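    # np.digitize against the median maps scores below the median to 0 ('low_risk')
    # and scores at or above it to 1 ('high_risk');
    # e.g. np.digitize([0.1, 0.9], [0.5]) -> [0, 1].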

    fig_known = plot_kaplan_meier_curves(
        durations=df_deceased[args.surv_time_var],
        events=df_deceased[f'{args.surv_event_var}_binary'],
        categorical_variable=group_labels
    )

    output_path_known = output_dir / f"{output_name_base}_km_risk_subtypes.{args.format}"
    print(f"Saving Kaplan-Meier plot to: {output_path_known.absolute()}")
    fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight')

    print("Kaplan-Meier plot saved successfully!")


def generate_cox_plots(model, clinical_train, clinical_test, omics_train, omics_test, args, output_dir, output_name_base):
    """Generate Cox proportional hazards plots"""
    print("Generating Cox proportional hazards analysis...")

    # Parse clinical variables
    clinical_vars = [var.strip() for var in args.clinical_variables.split(',')]

    # Validate that survival variables are included
    required_vars = [args.surv_time_var, args.surv_event_var]
    for var in required_vars:
        if var not in clinical_vars:
            clinical_vars.append(var)

    print(f"Using clinical variables: {', '.join(clinical_vars)}")

    # Filter datasets for clinical variables
    if all(var in clinical_train.columns and var in clinical_test.columns for var in clinical_vars):
        df_clin_train = clinical_train[clinical_vars]
        df_clin_test = clinical_test[clinical_vars]
        # Drop rows with NaN in clinical variables
        df_clin_train = df_clin_train.dropna(subset=clinical_vars)
        df_clin_test = df_clin_test.dropna(subset=clinical_vars)
    else:
        raise ValueError(f"Not all clinical variables found in datasets. Available in train dataset: {clinical_train.columns.tolist()}, Available in test dataset: {clinical_test.columns.tolist()}")

    # Combine
    df_clin = pd.concat([df_clin_train, df_clin_test], axis=0)

    # Get top survival markers
    print(f"Extracting top {args.top_features} important features for {args.surv_event_var}...")
    try:
        imp = get_important_features(model,
                                     var=args.surv_event_var,
                                     top=args.top_features
                                     )['name'].unique().tolist()
        print(f"Top features: {', '.join(imp)}")
    except Exception as e:
        raise ValueError(f"Error getting important features: {e}") from e

    # Extract feature data from omics datasets
    try:
        omics_test = omics_test.loc[omics_test.index.isin(imp)]
        omics_train = omics_train.loc[omics_train.index.isin(imp)]
        # Drop rows with NaN in omics datasets
        omics_test = omics_test.dropna(subset=omics_test.columns)
        omics_train = omics_train.dropna(subset=omics_train.columns)

        df_imp = pd.concat([omics_train, omics_test], axis=1)
        df_imp = df_imp.T  # Transpose to have samples as rows

        print(f"Feature data shape: {df_imp.shape}")
    except Exception as e:
        raise ValueError(f"Error extracting feature subset: {e}") from e

    # Combine markers with clinical variables
    df = pd.merge(df_imp, df_clin, left_index=True, right_index=True)
    print(f"Combined data shape: {df.shape}")

    # Remove samples without survival endpoints
    initial_samples = len(df)
    df = df[df[args.surv_event_var].notna()]
    final_samples = len(df)
    print(f"Removed {initial_samples - final_samples} samples without survival data")

    if df.empty:
        raise ValueError("No samples remain after filtering for survival data")

    # Convert survival event column to binary (0/1) based on event_value
    # Convert to string for comparison to handle mixed types
    df[args.surv_event_var] = df[args.surv_event_var].astype(str)
    event_value_str = str(args.event_value)

    df[args.surv_event_var] = (
        df[args.surv_event_var] == event_value_str
    ).astype(int)

    # Build Cox model
    print(f"Building Cox model with time variable: {args.surv_time_var}, event variable: {args.surv_event_var}")
    try:
        coxm = build_cox_model(df,
                               duration_col=args.surv_time_var,
                               event_col=args.surv_event_var,
                               crossval=args.crossval,
                               n_splits=args.n_splits,
                               random_state=args.random_state)
        print("Cox model built successfully")
    except Exception as e:
        raise ValueError(f"Error building Cox model: {e}") from e

    # Generate hazard ratios plot
    try:
        print("Generating hazard ratios plot...")
        fig = plot_hazard_ratios(coxm)

        output_path = output_dir / f"{output_name_base}_hazard_ratios.{args.format}"
        print(f"Saving hazard ratios plot to: {output_path.absolute()}")
        fig.save(output_path, dpi=args.dpi, bbox_inches='tight')

        print("Cox proportional hazards analysis completed successfully!")

    except Exception as e:
        raise ValueError(f"Error generating hazard ratios plot: {e}") from e


def generate_plot_scatter(labels, args, output_dir, output_name_base):
    """Generate scatter plot of known vs predicted labels"""
    print("Generating scatter plots of known vs predicted labels...")

    # Parse target values from comma-separated string
    if args.target_value:
        target_values = [val.strip() for val in args.target_value.split(',')]
    else:
        # If no target values specified, use all unique variables
        target_values = labels['variable'].unique().tolist()

    print(f"Processing target values: {target_values}")

    successful_plots = 0
    skipped_plots = 0

    for target_value in target_values:
        print(f"\nProcessing target value: '{target_value}'")

        # Filter labels for the current target value
        target_labels = labels[labels['variable'] == target_value]

        if target_labels.empty:
            print(f"  Warning: No data found for target value '{target_value}' - skipping")
            skipped_plots += 1
            continue

        # Check if labels are numeric and convert
        true_values = pd.to_numeric(target_labels['known_label'], errors='coerce')
        predicted_values = pd.to_numeric(target_labels['predicted_label'], errors='coerce')

        if true_values.isna().all() or predicted_values.isna().all():
            print(f"No valid numeric values found for known or predicted labels in '{target_value}'")
            skipped_plots += 1
            continue

        try:
            print(f"  Generating scatter plot for '{target_value}'...")
            fig = plot_scatter(true_values, predicted_values)

            # Create output filename with target value
            safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
            if len(target_values) > 1:
                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
            else:
                output_filename = f"{output_name_base}.{args.format}"

            output_path = output_dir / output_filename
            print(f"  Saving scatter plot to: {output_path.absolute()}")
            fig.save(output_path, dpi=args.dpi, bbox_inches='tight')

            successful_plots += 1
            print(f"  Scatter plot for '{target_value}' generated successfully!")

        except Exception as e:
            print(f"  Error generating plot for '{target_value}': {str(e)}")
            skipped_plots += 1

    # Summary
    print("  Summary:")
    print(f"  Successfully generated: {successful_plots} plots")
    print(f"  Skipped: {skipped_plots} plots")

    if successful_plots == 0:
        raise ValueError("No scatter plots could be generated. Check your data and target values.")

    print("Scatter plot generation completed!")


def generate_label_concordance_heatmap(labels, args, output_dir, output_name_base):
    """Generate label concordance heatmap"""
    print("Generating label concordance heatmaps...")

    # Parse target values from comma-separated string
    if args.target_value:
        target_values = [val.strip() for val in args.target_value.split(',')]
    else:
        # If no target values specified, use all unique variables
        target_values = labels['variable'].unique().tolist()

    print(f"Processing target values: {target_values}")

    for target_value in target_values:
        print(f"\nProcessing target value: '{target_value}'")

        # Filter labels for the current target value
        target_labels = labels[labels['variable'] == target_value]

        if target_labels.empty:
            print(f"  Warning: No data found for target value '{target_value}' - skipping")
            continue

        true_values = target_labels['known_label'].tolist()
        predicted_values = target_labels['predicted_label'].tolist()

        try:
            print(f"  Generating heatmap for '{target_value}'...")
            fig = plot_label_concordance_heatmap(true_values, predicted_values)

            # Create output filename with target value
            safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
            if len(target_values) > 1:
                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
            else:
                output_filename = f"{output_name_base}.{args.format}"

            output_path = output_dir / output_filename
            print(f"  Saving heatmap to: {output_path.absolute()}")
            fig.savefig(output_path, dpi=args.dpi, bbox_inches='tight')
            # Close the figure only after it has been written to disk
            plt.close(fig)

        except Exception as e:
            print(f"  Error generating heatmap for '{target_value}': {str(e)}")
            continue

    print("Label concordance heatmap generated successfully!")


def generate_pr_curves(labels, args, output_dir, output_name_base):
    """Generate precision-recall curves"""
    print("Generating precision-recall curves...")

    # Parse target values from comma-separated string
    if args.target_value:
        target_values = [val.strip() for val in args.target_value.split(',')]
    else:
        # If no target values specified, use all unique variables
        target_values = labels['variable'].unique().tolist()

    print(f"Processing target values: {target_values}")

    for target_value in target_values:
        print(f"\nProcessing target value: '{target_value}'")

        # Filter labels for the current target value
        target_labels = labels[labels['variable'] == target_value]

        # Check if this is a regression problem (no class probabilities)
        prob_columns = target_labels['class_label'].unique()
        non_na_probs = target_labels['probability'].notna().sum()

        print(f"  Class labels found: {list(prob_columns)}")
        print(f"  Non-NaN probabilities: {non_na_probs}/{len(target_labels)}")

        # If most probabilities are NaN, this is likely a regression problem
        if non_na_probs < len(target_labels) * 0.1:  # Less than 10% valid probabilities
            print("  Detected regression problem - precision-recall curves not applicable")
            print(f"  Skipping '{target_value}' (use regression evaluation metrics instead)")
            continue
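        # e.g. a purely numeric regression target yields NaN in the probability
        # column for every row, so non_na_probs is 0 and the target is skipped.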

        # Debug: Check data quality
        total_rows = len(target_labels)
        missing_labels = target_labels['known_label'].isna().sum()
        missing_probs = target_labels['probability'].isna().sum()
        unique_samples = target_labels['sample_id'].nunique()
        unique_classes = target_labels['class_label'].nunique()

        print(f"  Data summary: {total_rows} total rows, {unique_samples} unique samples, {unique_classes} unique classes")
        print(f"  Missing data: {missing_labels} missing known_label, {missing_probs} missing probability")

        if missing_labels > 0:
            print(f"  Warning: Found {missing_labels} missing known_label values")
            missing_samples = target_labels[target_labels['known_label'].isna()]['sample_id'].unique()[:5]
            print(f"  Sample IDs with missing known_label: {list(missing_samples)}")

            # Remove rows with missing known_label
            target_labels = target_labels.dropna(subset=['known_label'])
            if target_labels.empty:
                print(f"  Error: No valid known_label data remaining for '{target_value}' - skipping")
                continue

        # 1. Pivot to wide format
        prob_df = target_labels.pivot(index='sample_id', columns='class_label', values='probability')
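        # Illustrative pivot (values hypothetical): long rows
        #   (s1, LumA, 0.9) and (s1, Basal, 0.1)
        # become one wide row indexed by s1 with columns LumA=0.9, Basal=0.1.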

        print(f"  After pivot: {prob_df.shape[0]} samples x {prob_df.shape[1]} classes")
        print(f"  Class columns: {list(prob_df.columns)}")

        # Check for NaN values in probability data
        nan_counts = prob_df.isna().sum()
        if nan_counts.any():
            print(f"  NaN counts per class: {dict(nan_counts)}")
            print(f"  Samples with any NaN: {prob_df.isna().any(axis=1).sum()}/{len(prob_df)}")

            # Drop only rows where ALL probabilities are NaN
            all_nan_rows = prob_df.isna().all(axis=1)
            if all_nan_rows.any():
                print(f"  Dropping {all_nan_rows.sum()} samples with all NaN probabilities")
                prob_df = prob_df[~all_nan_rows]

            remaining_nans = prob_df.isna().sum().sum()
            if remaining_nans > 0:
                print(f"  Warning: {remaining_nans} individual NaN values remain - filling with 0")
                prob_df = prob_df.fillna(0)

        if prob_df.empty:
            print(f"  Error: No valid probability data remaining for '{target_value}' - skipping")
            continue

        # 2. Get true labels
        true_labels_df = target_labels.drop_duplicates('sample_id')[['sample_id', 'known_label']].set_index('sample_id')

        # 3. Align indices - only keep samples that exist in both datasets
        common_indices = prob_df.index.intersection(true_labels_df.index)
        if len(common_indices) == 0:
            print(f"  Error: No common sample_ids between probability and true label data for '{target_value}' - skipping")
            continue

        print(f"  Found {len(common_indices)} samples with both probability and true label data")

        # Filter both datasets to common indices
        prob_df_aligned = prob_df.loc[common_indices]
        y_true = true_labels_df.loc[common_indices]['known_label']

        # 4. Final check for NaN values
        if y_true.isna().any():
            print(f"  Error: True labels still contain NaN after alignment for '{target_value}' - skipping")
            continue

        if prob_df_aligned.isna().any().any():
            print(f"  Error: Probability data still contains NaN after alignment for '{target_value}' - skipping")
            continue

        # 5. Convert categorical labels to integer labels
        # Create a mapping from class names to integers
        class_names = list(prob_df_aligned.columns)
        class_to_int = {class_name: i for i, class_name in enumerate(class_names)}
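        # e.g. columns ['Basal', 'Her2', 'LumA'] map to {'Basal': 0, 'Her2': 1, 'LumA': 2}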

        print(f"  Class mapping: {class_to_int}")

        # Convert true labels to integers
        y_true_np = y_true.map(class_to_int).to_numpy()
        y_probs_np = prob_df_aligned.to_numpy()

        print(f"  Data shape: y_true={y_true_np.shape}, y_probs={y_probs_np.shape}")
        print(f"  Unique true labels (integers): {set(y_true_np)}")
        print(f"  Class labels (columns): {class_names}")
        print(f"  Label distribution: {dict(zip(*np.unique(y_true_np, return_counts=True)))}")

        # Check for any unmapped labels (will be NaN)
        if pd.isna(y_true_np).any():
            print("  Error: Some true labels could not be mapped to class columns")
            unmapped_labels = set(y_true[y_true.map(class_to_int).isna()])
            print(f"  Unmapped labels: {unmapped_labels}")
            print(f"  Available classes: {class_names}")
            continue

        try:
            print(f"  Generating precision-recall curve for '{target_value}'...")
            fig = plot_pr_curves(y_true_np, y_probs_np)

            # Create output filename with target value
            safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
            if len(target_values) > 1:
                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
            else:
                output_filename = f"{output_name_base}.{args.format}"

            output_path = output_dir / output_filename
            print(f"  Saving precision-recall curve to: {output_path.absolute()}")
            fig.save(output_path, dpi=args.dpi, bbox_inches='tight')

        except Exception as e:
            print(f"  Error generating precision-recall curve for '{target_value}': {str(e)}")
            print(f"  Debug info - y_true type: {type(y_true_np)}, contains NaN: {pd.isna(y_true_np).any()}")
            print(f"  Debug info - y_probs type: {type(y_probs_np)}, contains NaN: {pd.isna(y_probs_np).any()}")
            continue

    print("Precision-recall curves generated successfully!")


def generate_roc_curves(labels, args, output_dir, output_name_base):
    """Generate ROC curves"""
    print("Generating ROC curves...")

    # Parse target values from comma-separated string
    if args.target_value:
        target_values = [val.strip() for val in args.target_value.split(',')]
    else:
        # If no target values specified, use all unique variables
        target_values = labels['variable'].unique().tolist()

    print(f"Processing target values: {target_values}")

    for target_value in target_values:
        print(f"\nProcessing target value: '{target_value}'")

        # Filter labels for the current target value
        target_labels = labels[labels['variable'] == target_value]

        # Check if this is a regression problem (no class probabilities)
        prob_columns = target_labels['class_label'].unique()
        non_na_probs = target_labels['probability'].notna().sum()

        print(f"  Class labels found: {list(prob_columns)}")
        print(f"  Non-NaN probabilities: {non_na_probs}/{len(target_labels)}")

        # If most probabilities are NaN, this is likely a regression problem
        if non_na_probs < len(target_labels) * 0.1:  # Less than 10% valid probabilities
            print("  Detected regression problem - ROC curves not applicable")
            print(f"  Skipping '{target_value}' (use regression evaluation metrics instead)")
            continue

        # Debug: Check data quality
        total_rows = len(target_labels)
        missing_labels = target_labels['known_label'].isna().sum()
        missing_probs = target_labels['probability'].isna().sum()
        unique_samples = target_labels['sample_id'].nunique()
        unique_classes = target_labels['class_label'].nunique()

        print(f"  Data summary: {total_rows} total rows, {unique_samples} unique samples, {unique_classes} unique classes")
        print(f"  Missing data: {missing_labels} missing known_label, {missing_probs} missing probability")

        if missing_labels > 0:
            print(f"  Warning: Found {missing_labels} missing known_label values")
            missing_samples = target_labels[target_labels['known_label'].isna()]['sample_id'].unique()[:5]
            print(f"  Sample IDs with missing known_label: {list(missing_samples)}")

            # Remove rows with missing known_label
            target_labels = target_labels.dropna(subset=['known_label'])
            if target_labels.empty:
                print(f"  Error: No valid known_label data remaining for '{target_value}' - skipping")
                continue

        # 1. Pivot to wide format
        prob_df = target_labels.pivot(index='sample_id', columns='class_label', values='probability')

        print(f"  After pivot: {prob_df.shape[0]} samples x {prob_df.shape[1]} classes")
        print(f"  Class columns: {list(prob_df.columns)}")

        # Check for NaN values in probability data
        nan_counts = prob_df.isna().sum()
        if nan_counts.any():
            print(f"  NaN counts per class: {dict(nan_counts)}")
            print(f"  Samples with any NaN: {prob_df.isna().any(axis=1).sum()}/{len(prob_df)}")

            # Drop only rows where ALL probabilities are NaN
            all_nan_rows = prob_df.isna().all(axis=1)
            if all_nan_rows.any():
                print(f"  Dropping {all_nan_rows.sum()} samples with all NaN probabilities")
                prob_df = prob_df[~all_nan_rows]

            remaining_nans = prob_df.isna().sum().sum()
            if remaining_nans > 0:
                print(f"  Warning: {remaining_nans} individual NaN values remain - filling with 0")
                prob_df = prob_df.fillna(0)

        if prob_df.empty:
            print(f"  Error: No valid probability data remaining for '{target_value}' - skipping")
            continue

        # 2. Get true labels
        true_labels_df = target_labels.drop_duplicates('sample_id')[['sample_id', 'known_label']].set_index('sample_id')

        # 3. Align indices - only keep samples that exist in both datasets
        common_indices = prob_df.index.intersection(true_labels_df.index)
        if len(common_indices) == 0:
            print(f"  Error: No common sample_ids between probability and true label data for '{target_value}' - skipping")
            continue

        print(f"  Found {len(common_indices)} samples with both probability and true label data")

        # Filter both datasets to common indices
        prob_df_aligned = prob_df.loc[common_indices]
        y_true = true_labels_df.loc[common_indices]['known_label']

        # 4. Final check for NaN values
        if y_true.isna().any():
            print(f"  Error: True labels still contain NaN after alignment for '{target_value}' - skipping")
            continue

        if prob_df_aligned.isna().any().any():
            print(f"  Error: Probability data still contains NaN after alignment for '{target_value}' - skipping")
            continue

        # 5. Convert categorical labels to integer labels
        # Create a mapping from class names to integers
        class_names = list(prob_df_aligned.columns)
        class_to_int = {class_name: i for i, class_name in enumerate(class_names)}

        print(f"  Class mapping: {class_to_int}")

        # Convert true labels to integers
        y_true_np = y_true.map(class_to_int).to_numpy()
        y_probs_np = prob_df_aligned.to_numpy()

        print(f"  Data shape: y_true={y_true_np.shape}, y_probs={y_probs_np.shape}")
        print(f"  Unique true labels (integers): {set(y_true_np)}")
        print(f"  Class labels (columns): {class_names}")
        print(f"  Label distribution: {dict(zip(*np.unique(y_true_np, return_counts=True)))}")

        # Check for any unmapped labels (will be NaN)
        if pd.isna(y_true_np).any():
            print("  Error: Some true labels could not be mapped to class columns")
            unmapped_labels = set(y_true[y_true.map(class_to_int).isna()])
            print(f"  Unmapped labels: {unmapped_labels}")
            print(f"  Available classes: {class_names}")
            continue

        try:
            print(f"  Generating ROC curve for '{target_value}'...")
            fig = plot_roc_curves(y_true_np, y_probs_np)

            # Create output filename with target value
            safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
            if len(target_values) > 1:
                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
            else:
                output_filename = f"{output_name_base}.{args.format}"

            output_path = output_dir / output_filename
            print(f"  Saving ROC curve to: {output_path.absolute()}")
            fig.save(output_path, dpi=args.dpi, bbox_inches='tight')

        except Exception as e:
            print(f"  Error generating ROC curve for '{target_value}': {str(e)}")
            print(f"  Debug info - y_true type: {type(y_true_np)}, contains NaN: {pd.isna(y_true_np).any()}")
            print(f"  Debug info - y_probs type: {type(y_probs_np)}, contains NaN: {pd.isna(y_probs_np).any()}")
            continue

    print("ROC curves generated successfully!")


def generate_box_plots(labels, args, output_dir, output_name_base):
    """Generate box plots for model predictions"""

    print("Generating box plots...")

    # Parse target values from comma-separated string
    if args.target_value:
        target_values = [val.strip() for val in args.target_value.split(',')]
    else:
        # If no target values specified, use all unique variables
        target_values = labels['variable'].unique().tolist()

    print(f"Processing target values: {target_values}")

    for target_value in target_values:
        print(f"\nProcessing target value: '{target_value}'")

        # Filter labels for the current target value
        target_labels = labels[labels['variable'] == target_value]

        if target_labels.empty:
            print(f"  Warning: No data found for target value '{target_value}' - skipping")
            continue

        # Check if this is a classification problem (has probabilities)
        prob_columns = target_labels['class_label'].unique()
        non_na_probs = target_labels['probability'].notna().sum()

        print(f"  Class labels found: {list(prob_columns)}")
        print(f"  Non-NaN probabilities: {non_na_probs}/{len(target_labels)}")

        # If most probabilities are NaN, this is likely a regression problem
        if non_na_probs < len(target_labels) * 0.1:  # Less than 10% valid probabilities
            print("  Detected regression problem - box plots not applicable")
            print(f"  Skipping '{target_value}' (use regression evaluation metrics instead)")
            continue

        # Debug: Check data quality
        total_rows = len(target_labels)
        missing_labels = target_labels['known_label'].isna().sum()
        missing_probs = target_labels['probability'].isna().sum()
        unique_samples = target_labels['sample_id'].nunique()
        unique_classes = target_labels['class_label'].nunique()

        print(f"  Data summary: {total_rows} total rows, {unique_samples} unique samples, {unique_classes} unique classes")
        print(f"  Missing data: {missing_labels} missing known_label, {missing_probs} missing probability")

        if missing_labels > 0:
            print(f"  Warning: Found {missing_labels} missing known_label values")
            missing_samples = target_labels[target_labels['known_label'].isna()]['sample_id'].unique()[:5]
            print(f"  Sample IDs with missing known_label: {list(missing_samples)}")

            # Remove rows with missing known_label
            target_labels = target_labels.dropna(subset=['known_label'])
            if target_labels.empty:
                print(f"  Error: No valid known_label data remaining for '{target_value}' - skipping")
                continue

        # Remove rows with missing data
        clean_data = target_labels.dropna(subset=['known_label', 'probability'])

        if clean_data.empty:
            print("  No valid data after cleaning - skipping")
            continue

        # Get unique classes
        classes = clean_data['class_label'].unique()

        for class_label in classes:
            print(f"  Generating box plot for class: {class_label}")

            # Filter for current class
            class_data = clean_data[clean_data['class_label'] == class_label]

            try:
                # Create the box plot
                fig = plot_boxplot(
                    categorical_x=class_data['known_label'],
                    numerical_y=class_data['probability'],
                    title_x='True Label',
                    title_y=f'Predicted Probability ({class_label})',
                )

                # Save the plot
                safe_class_name = str(class_label).replace('/', '_').replace('\\', '_').replace(' ', '_').replace(':', '_')
                safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
                output_filename = f"{output_name_base}_{safe_target_name}_{safe_class_name}.{args.format}"
                output_path = output_dir / output_filename

                print(f"  Saving box plot to: {output_path.absolute()}")
                fig.savefig(output_path, dpi=args.dpi, bbox_inches='tight')
                plt.close(fig)

            except Exception as e:
                print(f"  Error generating box plot for class '{class_label}': {str(e)}")
                continue


def main():
    """Main function to parse arguments and generate plots"""
    parser = argparse.ArgumentParser(description="Generate plots using flexynesis")

    # Required arguments
    parser.add_argument("--labels", type=str, required=False,
                        help="Path to labels file generated by flexynesis")

    # Plot type
    parser.add_argument("--plot_type", type=str, required=True,
                        choices=['dimred', 'kaplan_meier', 'cox', 'scatter', 'concordance_heatmap', 'pr_curve', 'roc_curve', 'box_plot'],
                        help="Type of plot to generate: 'dimred' for dimensionality reduction, 'kaplan_meier' for survival analysis, 'cox' for Cox proportional hazards analysis, 'scatter' for scatter plots, 'concordance_heatmap' for label concordance heatmaps, 'pr_curve' for precision-recall curves, 'roc_curve' for ROC curves, or 'box_plot' for box plots.")

    # Arguments for dimensionality reduction
    parser.add_argument("--embeddings", type=str,
                        help="Path to input data embeddings file (CSV or tabular format). Required for dimred plots.")
    parser.add_argument("--method", type=str, default='pca', choices=['pca', 'umap'],
                        help="Transformation method ('pca' or 'umap'). Default is 'pca'. Used for dimred plots.")
    parser.add_argument("--target_variables", type=str, required=False,
                        help="Comma-separated list of target variables to plot.")

    # Arguments for Kaplan-Meier
    parser.add_argument("--survival_data", type=str,
                        help="Path to survival data file with columns: duration and event. Required for kaplan_meier plots.")
    parser.add_argument("--surv_time_var", type=str, required=False,
                        help="Column name for survival time")
    parser.add_argument("--surv_event_var", type=str, required=False,
                        help="Column name for survival event")
    parser.add_argument("--event_value", type=str, required=False,
                        help="Value in event column that represents an event (e.g., 'DECEASED')")

    # Arguments for Cox analysis
    parser.add_argument("--model", type=str,
                        help="Path to trained flexynesis model (pickle file). Required for cox plots.")
    parser.add_argument("--clinical_train", type=str,
                        help="Path to training dataset (pickle file). Required for cox plots.")
    parser.add_argument("--clinical_test", type=str,
                        help="Path to test dataset (pickle file). Required for cox plots.")
    parser.add_argument("--omics_train", type=str, default=None,
                        help="Path to training omics dataset. Required for cox plots.")
    parser.add_argument("--omics_test", type=str, default=None,
                        help="Path to test omics dataset. Required for cox plots.")
    parser.add_argument("--clinical_variables", type=str,
                        help="Comma-separated list of clinical variables to include in Cox model (e.g., 'AGE,SEX,HISTOLOGICAL_DIAGNOSIS,STUDY')")
    parser.add_argument("--top_features", type=int, default=20,
                        help="Number of top important features to include in Cox model. Default is 20")
    parser.add_argument("--crossval", action='store_true',
                        help="If set, performs K-fold cross-validation and returns the average C-index. Default is False")
    parser.add_argument("--n_splits", type=int, default=5,
                        help="Number of folds for cross-validation. Default is 5")
    parser.add_argument("--random_state", type=int, default=42,
                        help="Random seed for reproducibility. Default is 42")

    # Arguments for scatter plot, heatmap, PR curves, ROC curves, and box plots
    parser.add_argument("--target_value", type=str, default=None,
                        help="Comma-separated target value(s) for scatter, concordance heatmap, PR/ROC curve, and box plots. Defaults to all variables in the labels file.")

    # Common arguments
    parser.add_argument("--output_dir", type=str, default='output',
                        help="Output directory. Default is 'output'")
    parser.add_argument("--output_name", type=str, default=None,
                        help="Output filename base")
    parser.add_argument("--format", type=str, default='jpg', choices=['png', 'pdf', 'svg', 'jpg'],
                        help="Output format for the plot. Default is 'jpg'")
    parser.add_argument("--dpi", type=int, default=300,
                        help="DPI for the output image. Default is 300")

    args = parser.parse_args()

    try:
        # Validate plot type
        if not args.plot_type:
            raise ValueError("Please specify a plot type using --plot_type")
        if args.plot_type not in ['dimred', 'kaplan_meier', 'cox', 'scatter', 'concordance_heatmap', 'pr_curve', 'roc_curve', 'box_plot']:
            raise ValueError(f"Invalid plot type: {args.plot_type}. Must be one of: 'dimred', 'kaplan_meier', 'cox', 'scatter', 'concordance_heatmap', 'pr_curve', 'roc_curve', 'box_plot'")

        # Validate plot type requirements
        if args.plot_type in ['dimred']:
            if not args.embeddings:
                raise ValueError("--embeddings is required when plot_type is 'dimred'")
            if not os.path.isfile(args.embeddings):
                raise FileNotFoundError(f"embeddings file not found: {args.embeddings}")
            if not args.labels:
                raise ValueError("--labels is required for dimensionality reduction plots")
            if not args.method:
                raise ValueError("--method is required for dimensionality reduction plots")
            if not args.target_variables:
                raise ValueError("--target_variables is required for dimensionality reduction plots")

        if args.plot_type in ['kaplan_meier']:
            if not args.survival_data:
                raise ValueError("--survival_data is required when plot_type is 'kaplan_meier'")
            if not os.path.isfile(args.survival_data):
                raise FileNotFoundError(f"Survival data file not found: {args.survival_data}")
            if not args.labels:
                raise ValueError("--labels is required for Kaplan-Meier plots")
            if not args.surv_time_var:
                raise ValueError("--surv_time_var is required for Kaplan-Meier plots")
            if not args.surv_event_var:
                raise ValueError("--surv_event_var is required for Kaplan-Meier plots")
            if not args.event_value:
                raise ValueError("--event_value is required for Kaplan-Meier plots")

        if args.plot_type in ['cox']:
            if not args.model:
                raise ValueError("--model is required when plot_type is 'cox'")
            if not os.path.isfile(args.model):
                raise FileNotFoundError(f"Model file not found: {args.model}")
            if not args.clinical_train:
                raise ValueError("--clinical_train is required when plot_type is 'cox'")
            if not os.path.isfile(args.clinical_train):
                raise FileNotFoundError(f"Training dataset file not found: {args.clinical_train}")
            if not args.clinical_test:
                raise ValueError("--clinical_test is required when plot_type is 'cox'")
            if not os.path.isfile(args.clinical_test):
                raise FileNotFoundError(f"Test dataset file not found: {args.clinical_test}")
            if not args.omics_train:
                raise ValueError("--omics_train is required when plot_type is 'cox'")
            if not os.path.isfile(args.omics_train):
                raise FileNotFoundError(f"Training omics dataset file not found: {args.omics_train}")
            if not args.omics_test:
                raise ValueError("--omics_test is required when plot_type is 'cox'")
            if not os.path.isfile(args.omics_test):
                raise FileNotFoundError(f"Test omics dataset file not found: {args.omics_test}")
            if not args.surv_time_var:
                raise ValueError("--surv_time_var is required for Cox plots")
            if not args.surv_event_var:
                raise ValueError("--surv_event_var is required for Cox plots")
            if not args.clinical_variables:
                raise ValueError("--clinical_variables is required for Cox plots")
            if not isinstance(args.top_features, int) or args.top_features <= 0:
                raise ValueError("--top_features must be a positive integer")
            if not args.event_value:
                raise ValueError("--event_value is required for Cox plots")
            if not isinstance(args.n_splits, int) or args.n_splits <= 0:
                raise ValueError("--n_splits must be a positive integer")
            if not isinstance(args.random_state, int):
                raise ValueError("--random_state must be an integer")

        if args.plot_type in ['scatter']:
            if not args.labels:
                raise ValueError("--labels is required for scatter plots")
            if not args.target_value:
                print("--target_value is not specified, using all unique variables from labels")
            if not os.path.isfile(args.labels):
                raise FileNotFoundError(f"Labels file not found: {args.labels}")

        if args.plot_type in ['concordance_heatmap']:
            if not args.labels:
                raise ValueError("--labels is required for concordance heatmap")
            if not args.target_value:
                print("--target_value is not specified, using all unique variables from labels")
            if not os.path.isfile(args.labels):
                raise FileNotFoundError(f"Labels file not found: {args.labels}")

        if args.plot_type in ['pr_curve']:
            if not args.labels:
                raise ValueError("--labels is required for precision-recall curves")
            if not args.target_value:
                print("--target_value is not specified, using all unique variables from labels")
            if not os.path.isfile(args.labels):
                raise FileNotFoundError(f"Labels file not found: {args.labels}")

        if args.plot_type in ['roc_curve']:
            if not args.labels:
                raise ValueError("--labels is required for ROC curves")
            if not args.target_value:
                print("--target_value is not specified, using all unique variables from labels")
            if not os.path.isfile(args.labels):
                raise FileNotFoundError(f"Labels file not found: {args.labels}")

        if args.plot_type in ['box_plot']:
            if not args.labels:
                raise ValueError("--labels is required for box plots")
            if not args.target_value:
                print("--target_value is not specified, using all unique variables from labels")
            if not os.path.isfile(args.labels):
                raise FileNotFoundError(f"Labels file not found: {args.labels}")

        # Validate other arguments
        if args.method not in ['pca', 'umap']:
            raise ValueError("Method must be 'pca' or 'umap'")

        # Create output directory
        output_dir = Path(args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        print(f"Output directory: {output_dir.absolute()}")

        # Generate output filename base
        if args.output_name:
            output_name_base = args.output_name
        else:
            if args.plot_type == 'dimred':
                embeddings_name = Path(args.embeddings).stem
                output_name_base = f"{embeddings_name}_{args.method}"
            elif args.plot_type == 'kaplan_meier':
                survival_name = Path(args.survival_data).stem
                output_name_base = f"{survival_name}_km"
            elif args.plot_type == 'cox':
                model_name = Path(args.model).stem
                output_name_base = f"{model_name}_cox"
            elif args.plot_type == 'scatter':
                labels_name = Path(args.labels).stem
                output_name_base = f"{labels_name}_scatter"
            elif args.plot_type == 'concordance_heatmap':
                labels_name = Path(args.labels).stem
                output_name_base = f"{labels_name}_concordance"
            elif args.plot_type == 'pr_curve':
                labels_name = Path(args.labels).stem
                output_name_base = f"{labels_name}_pr_curves"
            elif args.plot_type == 'roc_curve':
                labels_name = Path(args.labels).stem
                output_name_base = f"{labels_name}_roc_curves"
            elif args.plot_type == 'box_plot':
                labels_name = Path(args.labels).stem
                output_name_base = f"{labels_name}_box_plot"

        # Generate plots based on type
        if args.plot_type in ['dimred']:
            # Load labels
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)
            # Load embeddings data
            print(f"Loading embeddings from: {args.embeddings}")
            embeddings, sample_names = load_embeddings(args.embeddings)
            print(f"embeddings shape: {embeddings.shape}")

            # Match samples to embeddings
            matched_labels = match_samples_to_embeddings(sample_names, label_data)
            print(f"Successfully matched {len(matched_labels)} samples for dimensionality reduction")

            generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base)

        elif args.plot_type in ['kaplan_meier']:
            # Load labels
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)
            # Load survival data
            print(f"Loading survival data from: {args.survival_data}")
            survival_data = load_survival_data(args.survival_data)
            print(f"Survival data shape: {survival_data.shape}")

            generate_km_plots(survival_data, label_data, args, output_dir, output_name_base)

        elif args.plot_type in ['cox']:
            # Load model and datasets
            print(f"Loading model from: {args.model}")
            model = load_model(args.model)
            print(f"Loading training dataset from: {args.clinical_train}")
            clinical_train = load_omics(args.clinical_train)
            print(f"Loading test dataset from: {args.clinical_test}")
            clinical_test = load_omics(args.clinical_test)
            print(f"Loading training omics dataset from: {args.omics_train}")
            omics_train = load_omics(args.omics_train)
            print(f"Loading test omics dataset from: {args.omics_test}")
            omics_test = load_omics(args.omics_test)

            # Pass train/test omics in the order expected by generate_cox_plots
            generate_cox_plots(model, clinical_train, clinical_test, omics_train, omics_test, args, output_dir, output_name_base)

        elif args.plot_type in ['scatter']:
            # Load labels
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)

            generate_plot_scatter(label_data, args, output_dir, output_name_base)

        elif args.plot_type in ['concordance_heatmap']:
            # Load labels
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)

            generate_label_concordance_heatmap(label_data, args, output_dir, output_name_base)

        elif args.plot_type in ['pr_curve']:
            # Load labels
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)

            generate_pr_curves(label_data, args, output_dir, output_name_base)

        elif args.plot_type in ['roc_curve']:
            # Load labels
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)

            generate_roc_curves(label_data, args, output_dir, output_name_base)

        elif args.plot_type in ['box_plot']:
            # Load labels
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)

            generate_box_plots(label_data, args, output_dir, output_name_base)

        print("All plots generated successfully!")

    except (FileNotFoundError, ValueError, pd.errors.ParserError) as e:
        print(f"Error: {e}")
        return 1

    return 0


if __name__ == "__main__":
    exit(main())