Mercurial > repos > bgruening > flexynesis
comparison flexynesis_plot.py @ 3:525c661a7fdc draft default tip
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/flexynesis commit b2463fb68d0ae54864d87718ee72f5e063aa4587
author | bgruening |
---|---|
date | Tue, 24 Jun 2025 05:55:40 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
2:902e26dc8e81 | 3:525c661a7fdc |
---|---|
1 #!/usr/bin/env python | |
2 """Generate plots using flexynesis | |
3 This script generates dimensionality reduction plots, Kaplan-Meier survival curves, | |
4 and Cox proportional hazards models from data processed by flexynesis.""" | |
5 | |
6 import argparse | |
7 import os | |
8 from pathlib import Path | |
9 | |
10 import matplotlib.pyplot as plt | |
11 import numpy as np | |
12 import pandas as pd | |
13 import seaborn as sns | |
14 import torch | |
15 from flexynesis import ( | |
16 build_cox_model, | |
17 get_important_features, | |
18 plot_dim_reduced, | |
19 plot_hazard_ratios, | |
20 plot_kaplan_meier_curves, | |
21 plot_pr_curves, | |
22 plot_roc_curves, | |
23 plot_scatter | |
24 ) | |
25 from scipy.stats import kruskal, mannwhitneyu | |
26 | |
27 | |
def load_embeddings(embeddings_path):
    """Read an embeddings table from a delimited text file.

    The first column of the file becomes the row index. ``.csv`` files are
    parsed as comma-separated; ``.tsv``, ``.txt``, ``.tab`` and ``.tabular``
    files are parsed as tab-separated.

    Returns:
        A ``(table, row_labels)`` tuple: the DataFrame and its index as a list.

    Raises:
        ValueError: for an unsupported extension or any failure while reading.
    """
    tab_suffixes = ('.tsv', '.txt', '.tab', '.tabular')
    try:
        suffix = Path(embeddings_path).suffix.lower()
        if suffix != '.csv' and suffix not in tab_suffixes:
            raise ValueError(f"Unsupported file extension: {suffix}")

        delimiter = ',' if suffix == '.csv' else '\t'
        table = pd.read_csv(embeddings_path, sep=delimiter, index_col=0)
        return table, table.index.tolist()

    except Exception as e:
        # Every failure is reported uniformly as a ValueError naming the file.
        raise ValueError(f"Error loading embeddings from {embeddings_path}: {e}") from e
45 | |
46 | |
def load_labels(labels_input):
    """Load the predicted-labels table written by flexynesis.

    ``.csv`` files are parsed as comma-separated; ``.tsv``, ``.txt``, ``.tab``
    and ``.tabular`` files are parsed as tab-separated. The table must contain
    the columns ``sample_id``, ``variable``, ``class_label``, ``probability``,
    ``known_label`` and ``predicted_label``.

    Returns:
        The labels DataFrame, unmodified.

    Raises:
        ValueError: for an unsupported extension, a missing required column,
            or any failure while reading the file.
    """
    required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label']
    try:
        # Determine file extension
        file_ext = Path(labels_input).suffix.lower()

        if file_ext == '.csv':
            df = pd.read_csv(labels_input)
        elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']:
            df = pd.read_csv(labels_input, sep='\t')
        else:
            # Without this branch `df` stays unbound and the user sees a
            # confusing "local variable referenced" message instead.
            raise ValueError(f"Unsupported file extension: {file_ext}")

        if not all(col in df.columns for col in required_cols):
            raise ValueError(f"Labels file {labels_input} does not contain required columns: {required_cols}")
        return df

    except Exception as e:
        raise ValueError(f"Error loading labels from {labels_input}: {e}") from e
67 | |
68 | |
def load_survival_data(survival_path):
    """Read a survival/clinical table; the first column is used as the index.

    ``.csv`` is read comma-separated; ``.tsv``, ``.txt``, ``.tab`` and
    ``.tabular`` are read tab-separated.

    Raises:
        ValueError: for an unsupported extension or any failure while reading.
    """
    try:
        suffix = Path(survival_path).suffix.lower()

        # Pick the parser options that belong to the extension.
        if suffix == '.csv':
            read_options = {'index_col': 0}
        elif suffix in ('.tsv', '.txt', '.tab', '.tabular'):
            read_options = {'sep': '\t', 'index_col': 0}
        else:
            raise ValueError(f"Unsupported file extension: {suffix}")

        return pd.read_csv(survival_path, **read_options)

    except Exception as e:
        raise ValueError(f"Error loading survival data from {survival_path}: {e}") from e
85 | |
86 | |
def load_omics(omics_path):
    """Read an omics matrix; the first column is used as the index.

    ``.csv`` is read comma-separated; ``.tsv``, ``.txt``, ``.tab`` and
    ``.tabular`` are read tab-separated.

    Raises:
        ValueError: for an unsupported extension or any failure while reading.
    """
    separators = {
        '.csv': ',',
        '.tsv': '\t',
        '.txt': '\t',
        '.tab': '\t',
        '.tabular': '\t',
    }
    try:
        extension = Path(omics_path).suffix.lower()
        separator = separators.get(extension)
        if separator is None:
            raise ValueError(f"Unsupported file extension: {extension}")
        return pd.read_csv(omics_path, sep=separator, index_col=0)

    except Exception as e:
        raise ValueError(f"Error loading omics data from {omics_path}: {e}") from e
103 | |
104 | |
def load_model(model_path):
    """Deserialize a saved flexynesis model with ``torch.load``.

    NOTE: ``weights_only=False`` lets ``torch.load`` unpickle arbitrary Python
    objects, so this must only be used on model files from a trusted source.

    Raises:
        ValueError: if the file cannot be opened or deserialized.
    """
    try:
        with open(model_path, 'rb') as handle:
            return torch.load(handle, weights_only=False)
    except Exception as e:
        raise ValueError(f"Error loading model from {model_path}: {e}") from e
113 | |
114 | |
def match_samples_to_embeddings(sample_names, label_data):
    """Return the rows of ``label_data`` whose ``sample_id`` is in ``sample_names``.

    Row order and index of ``label_data`` are preserved; nothing is reordered
    to follow ``sample_names``.
    """
    in_embeddings = label_data['sample_id'].isin(sample_names)
    return label_data[in_embeddings]
119 | |
120 | |
def detect_color_type(labels_series):
    """Guess whether a label column should be coloured as categories or numbers.

    Returns ``'categorical'`` when there are no non-missing labels, when any
    label is not numeric, or when the numeric labels have fewer than 10
    distinct values or fewer than 10% distinct values. Otherwise returns
    ``'numerical'``.
    """
    observed = labels_series.dropna()
    if observed.empty:
        # Nothing to inspect: fall back to the default.
        return 'categorical'

    try:
        as_numbers = pd.to_numeric(observed, errors='coerce')
        if as_numbers.isna().any():
            # At least one label is not a number.
            return 'categorical'

        n_total = len(observed)
        n_distinct = len(observed.unique())

        # Few distinct values (absolutely or relative to the sample count)
        # are treated as discrete classes encoded as numbers.
        looks_discrete = n_distinct < 10 or (n_distinct / n_total) < 0.1
        return 'categorical' if looks_discrete else 'numerical'

    except Exception:
        return 'categorical'
150 | |
151 | |
def plot_label_concordance_heatmap(labels1, labels2, figsize=(12, 10)):
    """
    Draw a heatmap of the agreement between two label sets.

    The two label sequences are cross-tabulated and every row of the table
    (one row per value of ``labels1``) is divided by its row total, so each
    row of the heatmap sums to 1.

    Parameters:
    - labels1: The first set of labels (heatmap rows).
    - labels2: The second set of labels (heatmap columns).
    - figsize: Size passed to ``plt.figure``.

    Returns the current matplotlib figure.
    """
    row_labels = pd.Series(labels1, name='Labels Set 1')
    col_labels = pd.Series(labels2, name='Labels Set 2')

    # Contingency table of counts, then row-wise fractions.
    counts = pd.crosstab(row_labels, col_labels)
    row_totals = counts.sum(axis=1)
    fractions = counts.div(row_totals, axis=0)

    plt.figure(figsize=figsize)
    sns.heatmap(fractions, annot=True, cmap='viridis', linewidths=.5)
    plt.title('Concordance between label groups')

    return plt.gcf()
171 | |
172 | |
def plot_boxplot(categorical_x, numerical_y, title_x='Categories', title_y='Values', figsize=(10, 6), jittersize=4):
    """
    Draw a boxplot with jittered points of ``numerical_y`` grouped by ``categorical_x``.

    The plot is annotated with the p-value of a two-sided Mann-Whitney U test
    (exactly two groups) or a Kruskal-Wallis test (more than two groups).
    Observations with a missing category or a missing value are left out of
    the test. With fewer than two groups no test is run and the annotation
    says so instead of raising.

    Parameters:
    - categorical_x: group label per observation (x axis).
    - numerical_y: numeric value per observation (y axis).
    - title_x, title_y: axis labels, also used as the internal column names.
    - figsize: size passed to ``plt.figure``.
    - jittersize: marker size of the jittered points.

    Returns the current matplotlib figure.
    """
    df = pd.DataFrame({title_x: categorical_x, title_y: numerical_y})

    # Only complete observations take part in the test: a NaN category would
    # yield an empty group and a NaN value would turn the p-value into NaN.
    complete = df.dropna(subset=[title_x, title_y])
    samples = [
        values
        for _, values in complete.groupby(title_x, sort=False, observed=True)[title_y]
    ]

    if len(samples) == 2:
        _, p = mannwhitneyu(samples[0], samples[1], alternative='two-sided')
        annotation = f'Mann-Whitney U p = {p:.3e}'
    elif len(samples) > 2:
        _, p = kruskal(*samples)
        annotation = f'Kruskal-Wallis p = {p:.3e}'
    else:
        # kruskal() needs at least two groups; keep the plot and say why
        # there is no p-value.
        annotation = 'p-value not available (fewer than two groups)'

    # Create a boxplot with jittered points
    plt.figure(figsize=figsize)
    sns.boxplot(x=title_x, y=title_y, hue=title_x, data=df, palette='Set2', legend=False, fill=False)
    sns.stripplot(x=title_x, y=title_y, data=df, color='black', size=jittersize, jitter=True, dodge=True, alpha=0.4)

    # Labels and p-value annotation (top-left corner of the axes)
    plt.xlabel(title_x)
    plt.ylabel(title_y)
    plt.text(
        x=-0.4,
        y=plt.ylim()[1],
        s=annotation,
        verticalalignment='top',
        horizontalalignment='left',
        fontsize=12,
        bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor='gray')
    )

    plt.tight_layout()
    return plt.gcf()
212 | |
213 | |
def generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base):
    """Generate dimensionality reduction plots.

    For every variable listed in ``args.target_variables`` (comma-separated)
    two plots are written to ``output_dir``: one coloured by ``known_label``
    and one by ``predicted_label``. Files are named
    ``{output_name_base}_{var}_known.{args.format}`` and
    ``{output_name_base}_{var}_predicted.{args.format}``.

    When ``embeddings`` is a DataFrame, its rows are restricted to the samples
    that have a label for the variable and the labels are reordered to follow
    the embedding rows, so each point receives the label of its own sample.

    Parameters:
    - embeddings: embedding matrix; a DataFrame indexed by sample id enables alignment.
    - matched_labels: labels table with ``sample_id``, ``variable``,
      ``known_label`` and ``predicted_label`` columns.
    - args: namespace providing ``target_variables``, ``method``, ``format``, ``dpi``.
    - output_dir: ``pathlib.Path`` of the output directory.
    - output_name_base: filename prefix.

    Raises:
        ValueError: if none of the requested variables exists in ``matched_labels``.
    """

    # Parse target variables
    target_vars = [var.strip() for var in args.target_variables.split(',')]

    print(f"Generating {args.method.upper()} plots for {len(target_vars)} target variable(s): {', '.join(target_vars)}")

    # Check variables
    available_vars = matched_labels['variable'].unique()
    missing_vars = [var for var in target_vars if var not in available_vars]

    if missing_vars:
        print(f"Warning: The following target variables were not found in the data: {', '.join(missing_vars)}")
        print(f"Available variables: {', '.join(available_vars)}")

    # Filter to only process available variables
    valid_vars = [var for var in target_vars if var in available_vars]

    if not valid_vars:
        raise ValueError(f"None of the specified target variables were found in the data. Available: {', '.join(available_vars)}")

    # Generate plots for each valid target variable
    for var in valid_vars:
        print(f"\nPlotting variable: {var}")

        # One label row per sample for the current variable
        var_labels = matched_labels[matched_labels['variable'] == var].copy()
        var_labels = var_labels.drop_duplicates(subset='sample_id')

        if var_labels.empty:
            print(f"Warning: No data found for variable '{var}', skipping...")
            continue

        var_embeddings = embeddings
        if isinstance(embeddings, pd.DataFrame):
            # Labels are matched to matrix rows by position downstream, so put
            # both in the same sample order and drop unlabeled embedding rows.
            var_labels = var_labels.set_index('sample_id')
            var_embeddings = embeddings.loc[embeddings.index.isin(var_labels.index)]
            if var_embeddings.empty:
                print(f"Warning: No embedding rows match the samples of variable '{var}', skipping...")
                continue
            var_labels = var_labels.reindex(var_embeddings.index)

        # Auto-detect color type
        known_color_type = detect_color_type(var_labels['known_label'])
        predicted_color_type = detect_color_type(var_labels['predicted_label'])

        print(f" Auto-detected color types - Known: {known_color_type}, Predicted: {predicted_color_type}")

        try:
            # Plot 1: Known labels
            print(f" Creating known labels plot for {var}...")
            fig_known = plot_dim_reduced(
                matrix=var_embeddings,
                labels=var_labels['known_label'],
                method=args.method,
                color_type=known_color_type
            )

            output_path_known = output_dir / f"{output_name_base}_{var}_known.{args.format}"
            print(f" Saving known labels plot to: {output_path_known.name}")
            fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight')

            # Plot 2: Predicted labels
            print(f" Creating predicted labels plot for {var}...")
            fig_predicted = plot_dim_reduced(
                matrix=var_embeddings,
                labels=var_labels['predicted_label'],
                method=args.method,
                color_type=predicted_color_type
            )

            output_path_predicted = output_dir / f"{output_name_base}_{var}_predicted.{args.format}"
            print(f" Saving predicted labels plot to: {output_path_predicted.name}")
            fig_predicted.save(output_path_predicted, dpi=args.dpi, bbox_inches='tight')

            print(f" ✓ Successfully created plots for variable '{var}'")

        except Exception as e:
            # Best effort: a failing variable must not stop the remaining ones.
            print(f" ✗ Error creating plots for variable '{var}': {e}")
            continue

    print(f"\nDimensionality reduction plots completed for {len(valid_vars)} variable(s)!")
288 | |
289 | |
def generate_km_plots(survival_data, label_data, args, output_dir, output_name_base):
    """Generate Kaplan-Meier plots.

    Samples are split at the median of the predicted probability of the event
    class (``class_label == args.event_value`` for variable
    ``args.surv_event_var``) into ``low_risk`` and ``high_risk`` groups, and
    one Kaplan-Meier plot is saved as
    ``{output_name_base}_km_risk_subtypes.{args.format}``.

    Neither ``survival_data`` nor ``label_data`` is modified.

    Parameters:
    - survival_data: DataFrame indexed by sample id with the time and event columns.
    - label_data: labels table with ``sample_id``, ``variable``,
      ``class_label`` and ``probability`` columns.
    - args: namespace providing ``surv_time_var``, ``surv_event_var``,
      ``event_value``, ``format`` and ``dpi``.
    - output_dir: ``pathlib.Path`` of the output directory.
    - output_name_base: filename prefix.

    Raises:
        ValueError: if a survival column is missing or no sample is shared
            between the survival data and the filtered labels.
    """
    print("Generating Kaplan-Meier curves of risk subtypes...")

    # Reset index and rename the index column to sample_id
    # (reset_index returns a new frame, so the caller's data stays untouched)
    survival_data = survival_data.reset_index()
    if survival_data.columns[0] != 'sample_id':
        survival_data = survival_data.rename(columns={survival_data.columns[0]: 'sample_id'})

    # Validate both survival columns before any of them is used
    for col in (args.surv_event_var, args.surv_time_var):
        if col not in survival_data.columns:
            raise ValueError(f"Column '{col}' not found in survival data")

    # Compare as strings so mixed-type event columns are handled
    event_value_str = str(args.event_value)
    survival_data[args.surv_event_var] = survival_data[args.surv_event_var].astype(str)

    # Binary event column (1 if matches event_value, 0 otherwise)
    survival_data[f'{args.surv_event_var}_binary'] = (
        survival_data[args.surv_event_var] == event_value_str
    ).astype(int)

    # assign() builds a new frame instead of overwriting the caller's column
    label_data = label_data.assign(class_label=label_data['class_label'].astype(str))

    # Keep only the probability rows of the event class for the survival variable
    label_data = label_data[(label_data['variable'] == args.surv_event_var) & (label_data['class_label'] == event_value_str)]

    # Merge survival data with labels
    df_deceased = pd.merge(survival_data, label_data, on='sample_id', how='inner')

    if df_deceased.empty:
        raise ValueError("No matching samples found after merging survival and label data.")

    # Get risk scores
    risk_scores = df_deceased['probability'].values

    # Median split: scores at or above the median fall into the high-risk bin
    quantiles = np.quantile(risk_scores, [0.5])
    groups = np.digitize(risk_scores, quantiles)
    group_labels = ['low_risk' if g == 0 else 'high_risk' for g in groups]

    fig_known = plot_kaplan_meier_curves(
        durations=df_deceased[args.surv_time_var],
        events=df_deceased[f'{args.surv_event_var}_binary'],
        categorical_variable=group_labels
    )

    output_path_known = output_dir / f"{output_name_base}_km_risk_subtypes.{args.format}"
    print(f"Saving Kaplan-Meier plot to: {output_path_known.absolute()}")
    fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight')

    print("Kaplan-Meier plot saved successfully!")
348 | |
349 | |
def generate_cox_plots(model, clinical_train, clinical_test, omics_train, omics_test, args, output_dir, output_name_base):
    """Generate Cox proportional hazards plots.

    The top ``args.top_features`` features reported by
    ``get_important_features`` for ``args.surv_event_var`` are combined with
    the requested clinical variables, a Cox model is built with
    ``build_cox_model`` and its hazard ratios are saved as
    ``{output_name_base}_hazard_ratios.{args.format}``.

    Parameters:
    - model: trained model handed to ``get_important_features``.
    - clinical_train, clinical_test: clinical tables indexed by sample.
    - omics_train, omics_test: omics matrices with features as rows and
      samples as columns.
    - args: namespace providing ``clinical_variables`` (comma-separated, may
      be empty or None), ``surv_time_var``, ``surv_event_var``,
      ``event_value``, ``top_features``, ``crossval``, ``n_splits``,
      ``random_state``, ``format`` and ``dpi``.
    - output_dir: ``pathlib.Path`` of the output directory.
    - output_name_base: filename prefix.

    Raises:
        ValueError: if clinical variables are missing, no samples remain, or
            any analysis step fails (the original error is chained).
    """
    print("Generating Cox proportional hazards analysis...")

    # Parse clinical variables; tolerate None/empty input and drop blank or
    # repeated entries (repeats would create duplicate columns).
    requested = (getattr(args, 'clinical_variables', None) or '').split(',')
    clinical_vars = list(dict.fromkeys(var.strip() for var in requested if var.strip()))

    # The survival variables are always part of the model data
    for var in (args.surv_time_var, args.surv_event_var):
        if var not in clinical_vars:
            clinical_vars.append(var)

    print(f"Using clinical variables: {', '.join(clinical_vars)}")

    # filter datasets for clinical variables
    if not all(var in clinical_train.columns and var in clinical_test.columns for var in clinical_vars):
        raise ValueError(f"Not all clinical variables found in datasets. Available in train dataset: {clinical_train.columns.tolist()}, Available in test dataset: {clinical_test.columns.tolist()}")

    # Drop rows with NaN in clinical variables
    df_clin_train = clinical_train[clinical_vars].dropna(subset=clinical_vars)
    df_clin_test = clinical_test[clinical_vars].dropna(subset=clinical_vars)

    # Combine
    df_clin = pd.concat([df_clin_train, df_clin_test], axis=0)

    # Get top survival markers
    print(f"Extracting top {args.top_features} important features for {args.surv_event_var}...")
    try:
        imp = get_important_features(model,
                                     var=args.surv_event_var,
                                     top=args.top_features
                                     )['name'].unique().tolist()
        print(f"Top features: {', '.join(imp)}")
    except Exception as e:
        raise ValueError(f"Error getting important features: {e}") from e

    # Extract feature data from omics datasets
    try:
        omics_test = omics_test.loc[omics_test.index.isin(imp)]
        omics_train = omics_train.loc[omics_train.index.isin(imp)]
        # Drop features (rows) that have a NaN in any sample
        omics_test = omics_test.dropna(subset=omics_test.columns)
        omics_train = omics_train.dropna(subset=omics_train.columns)

        df_imp = pd.concat([omics_train, omics_test], axis=1)
        df_imp = df_imp.T  # Transpose to have samples as rows

        print(f"Feature data shape: {df_imp.shape}")
    except Exception as e:
        raise ValueError(f"Error extracting feature subset: {e}") from e

    # Combine markers with clinical variables
    df = pd.merge(df_imp, df_clin, left_index=True, right_index=True)
    print(f"Combined data shape: {df.shape}")

    # Remove samples without survival endpoints
    initial_samples = len(df)
    # copy() so the column assignment below does not write into a slice
    df = df[df[args.surv_event_var].notna()].copy()
    final_samples = len(df)
    print(f"Removed {initial_samples - final_samples} samples without survival data")

    if df.empty:
        raise ValueError("No samples remain after filtering for survival data")

    # Convert survival event column to binary (0/1) based on event_value;
    # compare as strings to handle mixed types
    event_value_str = str(args.event_value)
    df[args.surv_event_var] = (
        df[args.surv_event_var].astype(str) == event_value_str
    ).astype(int)

    # Build Cox model
    print(f"Building Cox model with time variable: {args.surv_time_var}, event variable: {args.surv_event_var}")
    try:
        coxm = build_cox_model(df,
                               duration_col=args.surv_time_var,
                               event_col=args.surv_event_var,
                               crossval=args.crossval,
                               n_splits=args.n_splits,
                               random_state=args.random_state)
        print("Cox model built successfully")
    except Exception as e:
        raise ValueError(f"Error building Cox model: {e}") from e

    # Generate hazard ratios plot
    try:
        print("Generating hazard ratios plot...")
        fig = plot_hazard_ratios(coxm)

        output_path = output_dir / f"{output_name_base}_hazard_ratios.{args.format}"
        print(f"Saving hazard ratios plot to: {output_path.absolute()}")
        fig.save(output_path, dpi=args.dpi, bbox_inches='tight')

        print("Cox proportional hazards analysis completed successfully!")

    except Exception as e:
        raise ValueError(f"Error generating hazard ratios plot: {e}") from e
452 | |
453 | |
def generate_plot_scatter(labels, args, output_dir, output_name_base):
    """Write one known-vs-predicted scatter plot per requested variable.

    Variables come from the comma-separated ``args.target_value``; when it is
    empty every value of ``labels['variable']`` is used. A variable is skipped
    when it has no rows, when its known or predicted labels contain no numeric
    value at all, or when plotting/saving fails. With several requested
    variables the variable name is appended to the output filename.

    Raises:
        ValueError: if not a single plot could be written.
    """
    print("Generating scatter plots of known vs predicted labels...")

    if not args.target_value:
        # Nothing requested explicitly: plot every variable in the table.
        requested = labels['variable'].unique().tolist()
    else:
        requested = [item.strip() for item in args.target_value.split(',')]

    print(f"Processing target values: {requested}")

    n_saved = 0
    n_skipped = 0

    for name in requested:
        print(f"\nProcessing target value: '{name}'")

        subset = labels[labels['variable'] == name]
        if subset.empty:
            print(f" Warning: No data found for target value '{name}' - skipping")
            n_skipped += 1
            continue

        # Non-numeric entries become NaN
        known = pd.to_numeric(subset['known_label'], errors='coerce')
        predicted = pd.to_numeric(subset['predicted_label'], errors='coerce')

        if known.isna().all() or predicted.isna().all():
            print(f"No valid numeric values found for known or predicted labels in '{name}'")
            n_skipped += 1
            continue

        try:
            print(f" Generating scatter plot for '{name}'...")
            fig = plot_scatter(known, predicted)

            # Path separators and blanks are not wanted inside a filename
            safe_name = name.replace('/', '_').replace('\\', '_').replace(' ', '_')
            output_filename = (
                f"{output_name_base}_{safe_name}.{args.format}"
                if len(requested) > 1
                else f"{output_name_base}.{args.format}"
            )

            output_path = output_dir / output_filename
            print(f" Saving scatter plot to: {output_path.absolute()}")
            fig.save(output_path, dpi=args.dpi, bbox_inches='tight')

            n_saved += 1
            print(f" Scatter plot for '{name}' generated successfully!")

        except Exception as e:
            print(f" Error generating plot for '{name}': {str(e)}")
            n_skipped += 1

    # Summary
    print(" Summary:")
    print(f" Successfully generated: {n_saved} plots")
    print(f" Skipped: {n_skipped} plots")

    if n_saved == 0:
        raise ValueError("No scatter plots could be generated. Check your data and target values.")

    print("Scatter plot generation completed!")
521 | |
522 | |
def generate_label_concordance_heatmap(labels, args, output_dir, output_name_base):
    """Generate label concordance heatmap.

    One heatmap of ``known_label`` against ``predicted_label`` is written per
    requested variable. Variables come from the comma-separated
    ``args.target_value``; when it is empty every value of
    ``labels['variable']`` is used. With several requested variables the
    variable name is appended to the output filename. A variable without rows,
    or whose plot fails, is reported and skipped.

    Parameters:
    - labels: labels table with ``variable``, ``known_label`` and
      ``predicted_label`` columns.
    - args: namespace providing ``target_value``, ``format`` and ``dpi``.
    - output_dir: ``pathlib.Path`` of the output directory.
    - output_name_base: filename prefix.
    """
    print("Generating label concordance heatmaps...")

    # Parse target values from comma-separated string
    if args.target_value:
        target_values = [val.strip() for val in args.target_value.split(',')]
    else:
        # If no target values specified, use all unique variables
        target_values = labels['variable'].unique().tolist()

    print(f"Processing target values: {target_values}")

    for target_value in target_values:
        print(f"\nProcessing target value: '{target_value}'")

        # Filter labels for the current target value
        target_labels = labels[labels['variable'] == target_value]

        if target_labels.empty:
            print(f" Warning: No data found for target value '{target_value}' - skipping")
            continue

        true_values = target_labels['known_label'].tolist()
        predicted_values = target_labels['predicted_label'].tolist()

        fig = None
        try:
            print(f" Generating heatmap for '{target_value}'...")
            fig = plot_label_concordance_heatmap(true_values, predicted_values)

            # Create output filename with target value
            safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
            if len(target_values) > 1:
                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
            else:
                output_filename = f"{output_name_base}.{args.format}"

            output_path = output_dir / output_filename
            print(f" Saving heatmap to: {output_path.absolute()}")
            fig.savefig(output_path, dpi=args.dpi, bbox_inches='tight')

        except Exception as e:
            print(f" Error generating heatmap for '{target_value}': {str(e)}")
            continue
        finally:
            # Close only after saving, and also on failure, so figures do not
            # pile up across variables.
            if fig is not None:
                plt.close(fig)

    print("Label concordance heatmap generated successfully!")
570 | |
571 | |
def generate_pr_curves(labels, args, output_dir, output_name_base):
    """Generate precision-recall curves.

    For every requested variable the long-format labels table (one row per
    sample and class) is pivoted into a samples x classes probability matrix,
    the known labels are encoded as the column positions of that matrix, and
    both are handed to ``plot_pr_curves``. The resulting figure is saved into
    ``output_dir``.

    Variables come from the comma-separated ``args.target_value``; when it is
    empty every value of ``labels['variable']`` is used. With several
    requested variables the variable name is appended to the output filename.
    Variables that look like regression targets, have no usable data, or fail
    while plotting are reported on stdout and skipped.

    Parameters:
    - labels: labels table with ``sample_id``, ``variable``, ``class_label``,
      ``probability`` and ``known_label`` columns.
    - args: namespace providing ``target_value``, ``format`` and ``dpi``.
    - output_dir: ``pathlib.Path`` of the output directory.
    - output_name_base: filename prefix.
    """
    print("Generating precision-recall curves...")

    # Parse target values from comma-separated string
    if args.target_value:
        target_values = [val.strip() for val in args.target_value.split(',')]
    else:
        # If no target values specified, use all unique variables
        target_values = labels['variable'].unique().tolist()

    print(f"Processing target values: {target_values}")

    for target_value in target_values:
        print(f"\nProcessing target value: '{target_value}'")

        # Filter labels for the current target value
        target_labels = labels[labels['variable'] == target_value]

        # Check if this is a regression problem (no class probabilities)
        prob_columns = target_labels['class_label'].unique()
        non_na_probs = target_labels['probability'].notna().sum()

        print(f" Class labels found: {list(prob_columns)}")
        print(f" Non-NaN probabilities: {non_na_probs}/{len(target_labels)}")

        # If most probabilities are NaN, this is likely a regression problem
        if non_na_probs < len(target_labels) * 0.1:  # Less than 10% valid probabilities
            print(" Detected regression problem - precision-recall curves not applicable")
            print(f" Skipping '{target_value}' (use regression evaluation metrics instead)")
            continue

        # Debug: Check data quality
        total_rows = len(target_labels)
        missing_labels = target_labels['known_label'].isna().sum()
        missing_probs = target_labels['probability'].isna().sum()
        unique_samples = target_labels['sample_id'].nunique()
        unique_classes = target_labels['class_label'].nunique()

        print(f" Data summary: {total_rows} total rows, {unique_samples} unique samples, {unique_classes} unique classes")
        print(f" Missing data: {missing_labels} missing known_label, {missing_probs} missing probability")

        if missing_labels > 0:
            print(f" Warning: Found {missing_labels} missing known_label values")
            # Show at most five affected sample ids
            missing_samples = target_labels[target_labels['known_label'].isna()]['sample_id'].unique()[:5]
            print(f" Sample IDs with missing known_label: {list(missing_samples)}")

            # Remove rows with missing known_label
            target_labels = target_labels.dropna(subset=['known_label'])
            if target_labels.empty:
                print(f" Error: No valid known_label data remaining for '{target_value}' - skipping")
                continue

        # 1. Pivot to wide format
        # NOTE(review): DataFrame.pivot raises when a (sample_id, class_label)
        # pair occurs more than once, and this call sits outside the try block
        # below - confirm the input never contains such duplicates.
        prob_df = target_labels.pivot(index='sample_id', columns='class_label', values='probability')

        print(f" After pivot: {prob_df.shape[0]} samples x {prob_df.shape[1]} classes")
        print(f" Class columns: {list(prob_df.columns)}")

        # Check for NaN values in probability data
        nan_counts = prob_df.isna().sum()
        if nan_counts.any():
            print(f" NaN counts per class: {dict(nan_counts)}")
            print(f" Samples with any NaN: {prob_df.isna().any(axis=1).sum()}/{len(prob_df)}")

            # Drop only rows where ALL probabilities are NaN
            all_nan_rows = prob_df.isna().all(axis=1)
            if all_nan_rows.any():
                print(f" Dropping {all_nan_rows.sum()} samples with all NaN probabilities")
                prob_df = prob_df[~all_nan_rows]

            # Partially missing rows are kept; the gaps count as probability 0
            remaining_nans = prob_df.isna().sum().sum()
            if remaining_nans > 0:
                print(f" Warning: {remaining_nans} individual NaN values remain - filling with 0")
                prob_df = prob_df.fillna(0)

        if prob_df.empty:
            print(f" Error: No valid probability data remaining for '{target_value}' - skipping")
            continue

        # 2. Get true labels (the first row seen for each sample)
        true_labels_df = target_labels.drop_duplicates('sample_id')[['sample_id', 'known_label']].set_index('sample_id')

        # 3. Align indices - only keep samples that exist in both datasets
        common_indices = prob_df.index.intersection(true_labels_df.index)
        if len(common_indices) == 0:
            print(f" Error: No common sample_ids between probability and true label data for '{target_value}' - skipping")
            continue

        print(f" Found {len(common_indices)} samples with both probability and true label data")

        # Filter both datasets to common indices
        prob_df_aligned = prob_df.loc[common_indices]
        y_true = true_labels_df.loc[common_indices]['known_label']

        # 4. Final check for NaN values
        if y_true.isna().any():
            print(f" Error: True labels still contain NaN after alignment for '{target_value}' - skipping")
            continue

        if prob_df_aligned.isna().any().any():
            print(f" Error: Probability data still contains NaN after alignment for '{target_value}' - skipping")
            continue

        # 5. Convert categorical labels to integer labels
        # Create a mapping from class names to integers; the integer is the
        # column position of the class in the probability matrix
        class_names = list(prob_df_aligned.columns)
        class_to_int = {class_name: i for i, class_name in enumerate(class_names)}

        print(f" Class mapping: {class_to_int}")

        # Convert true labels to integers (a label without a matching
        # class_label column maps to NaN)
        y_true_np = y_true.map(class_to_int).to_numpy()
        y_probs_np = prob_df_aligned.to_numpy()

        print(f" Data shape: y_true={y_true_np.shape}, y_probs={y_probs_np.shape}")
        print(f" Unique true labels (integers): {set(y_true_np)}")
        print(f" Class labels (columns): {class_names}")
        print(f" Label distribution: {dict(zip(*np.unique(y_true_np, return_counts=True)))}")

        # Check for any unmapped labels (will be NaN)
        if pd.isna(y_true_np).any():
            print(" Error: Some true labels could not be mapped to class columns")
            unmapped_labels = set(y_true[y_true.map(class_to_int).isna()])
            print(f" Unmapped labels: {unmapped_labels}")
            print(f" Available classes: {class_names}")
            continue

        try:
            print(f" Generating precision-recall curve for '{target_value}'...")
            fig = plot_pr_curves(y_true_np, y_probs_np)

            # Create output filename with target value
            safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_')
            if len(target_values) > 1:
                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
            else:
                output_filename = f"{output_name_base}.{args.format}"

            output_path = output_dir / output_filename
            print(f" Saving precision-recall curve to: {output_path.absolute()}")
            fig.save(output_path, dpi=args.dpi, bbox_inches='tight')

        except Exception as e:
            # Best effort: report the failure and go on with the next variable
            print(f" Error generating precision-recall curve for '{target_value}': {str(e)}")
            print(f" Debug info - y_true type: {type(y_true_np)}, contains NaN: {pd.isna(y_true_np).any()}")
            print(f" Debug info - y_probs type: {type(y_probs_np)}, contains NaN: {pd.isna(y_probs_np).any()}")
            continue

    # Printed even when every variable was skipped
    print("Precision-recall curves generated successfully!")
722 | |
723 | |
def generate_roc_curves(labels, args, output_dir, output_name_base):
    """Generate one ROC curve plot per target variable.

    Targets whose probabilities are almost all missing are treated as
    regression targets and skipped. A target that fails a later data check or
    whose plotting raises is reported on stdout and skipped, so the remaining
    targets are still processed.

    Args:
        labels: Long-format DataFrame with the columns ``variable``,
            ``sample_id``, ``class_label``, ``probability`` and
            ``known_label``.
        args: Parsed command-line arguments; ``target_value`` (optional
            comma-separated string), ``format`` and ``dpi`` are read.
        output_dir: ``pathlib.Path`` of the directory the plots are saved in.
        output_name_base: Filename stem; the sanitized target name is appended
            when more than one target is processed.

    Returns:
        None. Progress, warnings and errors are printed to stdout.
    """
    print("Generating ROC curves...")

    # Parse target values from comma-separated string
    if args.target_value:
        target_values = [val.strip() for val in args.target_value.split(',')]
    else:
        # If no target values specified, use all unique variables
        target_values = labels['variable'].unique().tolist()

    print(f"Processing target values: {target_values}")

    # Number of plots actually written; drives the final status message.
    generated = 0

    for target_value in target_values:
        print(f"\nProcessing target value: '{target_value}'")

        # Filter labels for the current target value
        target_labels = labels[labels['variable'] == target_value]

        # Same guard as generate_box_plots: an unknown target gets a clear
        # message instead of a misleading "no valid known_label" error.
        if target_labels.empty:
            print(f" Warning: No data found for target value '{target_value}' - skipping")
            continue

        # Check if this is a regression problem (no class probabilities)
        prob_columns = target_labels['class_label'].unique()
        non_na_probs = target_labels['probability'].notna().sum()

        print(f" Class labels found: {list(prob_columns)}")
        print(f" Non-NaN probabilities: {non_na_probs}/{len(target_labels)}")

        # If most probabilities are NaN, this is likely a regression problem
        if non_na_probs < len(target_labels) * 0.1:  # Less than 10% valid probabilities
            print(" Detected regression problem - ROC curves not applicable")
            print(f" Skipping '{target_value}' (use regression evaluation metrics instead)")
            continue

        # Debug: Check data quality
        total_rows = len(target_labels)
        missing_labels = target_labels['known_label'].isna().sum()
        missing_probs = target_labels['probability'].isna().sum()
        unique_samples = target_labels['sample_id'].nunique()
        unique_classes = target_labels['class_label'].nunique()

        print(f" Data summary: {total_rows} total rows, {unique_samples} unique samples, {unique_classes} unique classes")
        print(f" Missing data: {missing_labels} missing known_label, {missing_probs} missing probability")

        if missing_labels > 0:
            print(f" Warning: Found {missing_labels} missing known_label values")
            missing_samples = target_labels[target_labels['known_label'].isna()]['sample_id'].unique()[:5]
            print(f" Sample IDs with missing known_label: {list(missing_samples)}")

        # Remove rows with missing known_label
        target_labels = target_labels.dropna(subset=['known_label'])
        if target_labels.empty:
            print(f" Error: No valid known_label data remaining for '{target_value}' - skipping")
            continue

        # 1. Pivot to wide format (one row per sample, one column per class).
        # pivot() raises ValueError on duplicate (sample_id, class_label)
        # pairs; that is deliberately left to propagate to the caller.
        prob_df = target_labels.pivot(index='sample_id', columns='class_label', values='probability')

        print(f" After pivot: {prob_df.shape[0]} samples x {prob_df.shape[1]} classes")
        print(f" Class columns: {list(prob_df.columns)}")

        # Check for NaN values in probability data
        nan_counts = prob_df.isna().sum()
        if nan_counts.any():
            print(f" NaN counts per class: {dict(nan_counts)}")
            print(f" Samples with any NaN: {prob_df.isna().any(axis=1).sum()}/{len(prob_df)}")

            # Drop only rows where ALL probabilities are NaN
            all_nan_rows = prob_df.isna().all(axis=1)
            if all_nan_rows.any():
                print(f" Dropping {all_nan_rows.sum()} samples with all NaN probabilities")
                prob_df = prob_df[~all_nan_rows]

            remaining_nans = prob_df.isna().sum().sum()
            if remaining_nans > 0:
                print(f" Warning: {remaining_nans} individual NaN values remain - filling with 0")
                prob_df = prob_df.fillna(0)

        if prob_df.empty:
            print(f" Error: No valid probability data remaining for '{target_value}' - skipping")
            continue

        # 2. Get true labels (one per sample)
        true_labels_df = target_labels.drop_duplicates('sample_id')[['sample_id', 'known_label']].set_index('sample_id')

        # 3. Align indices - only keep samples that exist in both datasets
        common_indices = prob_df.index.intersection(true_labels_df.index)
        if len(common_indices) == 0:
            print(f" Error: No common sample_ids between probability and true label data for '{target_value}' - skipping")
            continue

        print(f" Found {len(common_indices)} samples with both probability and true label data")

        # Filter both datasets to common indices
        prob_df_aligned = prob_df.loc[common_indices]
        y_true = true_labels_df.loc[common_indices]['known_label']

        # 4. Final check for NaN values
        if y_true.isna().any():
            print(f" Error: True labels still contain NaN after alignment for '{target_value}' - skipping")
            continue

        if prob_df_aligned.isna().any().any():
            print(f" Error: Probability data still contains NaN after alignment for '{target_value}' - skipping")
            continue

        # 5. Convert categorical labels to integer labels: the integer of a
        # class is the position of its column in the probability matrix.
        class_names = list(prob_df_aligned.columns)
        class_to_int = {class_name: i for i, class_name in enumerate(class_names)}

        print(f" Class mapping: {class_to_int}")

        # Map once and reuse the result for both the array and the
        # unmapped-label report below.
        y_true_mapped = y_true.map(class_to_int)
        y_true_np = y_true_mapped.to_numpy()
        y_probs_np = prob_df_aligned.to_numpy()

        print(f" Data shape: y_true={y_true_np.shape}, y_probs={y_probs_np.shape}")
        print(f" Unique true labels (integers): {set(y_true_np)}")
        print(f" Class labels (columns): {class_names}")
        print(f" Label distribution: {dict(zip(*np.unique(y_true_np, return_counts=True)))}")

        # Labels without a matching probability column map to NaN
        if pd.isna(y_true_np).any():
            print(" Error: Some true labels could not be mapped to class columns")
            unmapped_labels = set(y_true[y_true_mapped.isna()])
            print(f" Unmapped labels: {unmapped_labels}")
            print(f" Available classes: {class_names}")
            continue

        try:
            print(f" Generating ROC curve for '{target_value}'...")
            fig = plot_roc_curves(y_true_np, y_probs_np)

            # Create output filename with target value; str() keeps this
            # working for non-string variable names read from the labels.
            safe_target_name = str(target_value).replace('/', '_').replace('\\', '_').replace(' ', '_')
            if len(target_values) > 1:
                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
            else:
                output_filename = f"{output_name_base}.{args.format}"

            output_path = output_dir / output_filename
            print(f" Saving ROC curve to: {output_path.absolute()}")
            fig.save(output_path, dpi=args.dpi, bbox_inches='tight')
            generated += 1

        except Exception as e:
            # Best effort: report the failure and move on to the next target.
            print(f" Error generating ROC curve for '{target_value}': {str(e)}")
            print(f" Debug info - y_true type: {type(y_true_np)}, contains NaN: {pd.isna(y_true_np).any()}")
            print(f" Debug info - y_probs type: {type(y_probs_np)}, contains NaN: {pd.isna(y_probs_np).any()}")
            continue

    # Only claim success when at least one plot was actually written.
    if generated:
        print("ROC curves generated successfully!")
    else:
        print("Warning: No ROC curves were generated")
874 | |
875 | |
def generate_box_plots(labels, args, output_dir, output_name_base):
    """Generate box plots of predicted probabilities grouped by true label.

    For every target variable and every class of that variable, one plot is
    written showing the predicted probability of the class (y axis) against
    the known label of the samples (x axis). Targets whose probabilities are
    almost all missing are treated as regression targets and skipped.

    Args:
        labels: Long-format DataFrame with the columns ``variable``,
            ``sample_id``, ``class_label``, ``probability`` and
            ``known_label``.
        args: Parsed command-line arguments; ``target_value`` (optional
            comma-separated string), ``format`` and ``dpi`` are read.
        output_dir: ``pathlib.Path`` of the directory the plots are saved in.
        output_name_base: Filename stem; the sanitized target and class names
            are appended to it.

    Returns:
        None. Progress, warnings and errors are printed to stdout.
    """

    print("Generating box plots...")

    # Parse target values from comma-separated string
    if args.target_value:
        target_values = [val.strip() for val in args.target_value.split(',')]
    else:
        # If no target values specified, use all unique variables
        target_values = labels['variable'].unique().tolist()

    print(f"Processing target values: {target_values}")

    for target_value in target_values:
        print(f"\nProcessing target value: '{target_value}'")

        # Filter labels for the current target value
        target_labels = labels[labels['variable'] == target_value]

        if target_labels.empty:
            print(f" Warning: No data found for target value '{target_value}' - skipping")
            continue

        # Check if this is a classification problem (has probabilities)
        prob_columns = target_labels['class_label'].unique()
        non_na_probs = target_labels['probability'].notna().sum()

        print(f" Class labels found: {list(prob_columns)}")
        print(f" Non-NaN probabilities: {non_na_probs}/{len(target_labels)}")

        # If most probabilities are NaN, this is likely a regression problem
        if non_na_probs < len(target_labels) * 0.1:  # Less than 10% valid probabilities
            print(" Detected regression problem - box plots not applicable")
            print(f" Skipping '{target_value}' (use regression evaluation metrics instead)")
            continue

        # Debug: Check data quality
        total_rows = len(target_labels)
        missing_labels = target_labels['known_label'].isna().sum()
        missing_probs = target_labels['probability'].isna().sum()
        unique_samples = target_labels['sample_id'].nunique()
        unique_classes = target_labels['class_label'].nunique()

        print(f" Data summary: {total_rows} total rows, {unique_samples} unique samples, {unique_classes} unique classes")
        print(f" Missing data: {missing_labels} missing known_label, {missing_probs} missing probability")

        if missing_labels > 0:
            print(f" Warning: Found {missing_labels} missing known_label values")
            missing_samples = target_labels[target_labels['known_label'].isna()]['sample_id'].unique()[:5]
            print(f" Sample IDs with missing known_label: {list(missing_samples)}")

        # Remove rows with missing known_label
        target_labels = target_labels.dropna(subset=['known_label'])
        if target_labels.empty:
            print(f" Error: No valid known_label data remaining for '{target_value}' - skipping")
            continue

        # known_label is already clean here, so only probability is left to filter.
        clean_data = target_labels.dropna(subset=['probability'])

        if clean_data.empty:
            print(" No valid data after cleaning - skipping")
            continue

        # Loop-invariant part of the filename; str() keeps this working for
        # non-string variable names read from the labels column.
        safe_target_name = str(target_value).replace('/', '_').replace('\\', '_').replace(' ', '_')

        for class_label in clean_data['class_label'].unique():
            print(f" Generating box plot for class: {class_label}")

            # Filter for current class
            class_data = clean_data[clean_data['class_label'] == class_label]

            try:
                # NOTE(review): plot_boxplot is not defined in this block; it is
                # presumably a helper from this module returning a matplotlib
                # figure (savefig/plt.close are used on its result) - confirm.
                fig = plot_boxplot(
                    categorical_x=class_data['known_label'],
                    numerical_y=class_data['probability'],
                    title_x='True Label',
                    title_y=f'Predicted Probability ({class_label})',
                )

                # Save the plot
                safe_class_name = str(class_label).replace('/', '_').replace('\\', '_').replace(' ', '_').replace(':', '_')
                output_filename = f"{output_name_base}_{safe_target_name}_{safe_class_name}.{args.format}"
                output_path = output_dir / output_filename

                print(f" Saving box plot to: {output_path.absolute()}")
                fig.savefig(output_path, dpi=args.dpi, bbox_inches='tight')
                plt.close(fig)

            except Exception as e:
                # Best effort: report the failure and move on to the next class.
                print(f" Error generating box plot for class '{class_label}': {str(e)}")
                continue
972 | |
973 | |
def main():
    """Parse command-line arguments, validate them and generate the plots.

    Returns:
        int: 0 on success; 1 when a ``FileNotFoundError``, ``ValueError`` or
        ``pandas.errors.ParserError`` is raised during validation, loading or
        plotting (its message is printed to stdout). Other exceptions
        propagate to the caller.
    """
    parser = argparse.ArgumentParser(description="Generate plots using flexynesis")

    # Labels file: needed by every plot type except 'cox' (enforced below)
    parser.add_argument("--labels", type=str, required=False,
                        help="Path to labels file generated by flexynesis")

    # Plot type
    parser.add_argument("--plot_type", type=str, required=True,
                        choices=['dimred', 'kaplan_meier', 'cox', 'scatter', 'concordance_heatmap', 'pr_curve', 'roc_curve', 'box_plot'],
                        help="Type of plot to generate: 'dimred' for dimensionality reduction, 'kaplan_meier' for survival analysis, 'cox' for Cox proportional hazards analysis, 'scatter' for scatter plots, 'concordance_heatmap' for label concordance heatmaps, 'pr_curve' for precision-recall curves, 'roc_curve' for ROC curves, or 'box_plot' for box plots.")

    # Arguments for dimensionality reduction
    parser.add_argument("--embeddings", type=str,
                        help="Path to input data embeddings file (CSV or tabular format). Required for dimred plots.")
    parser.add_argument("--method", type=str, default='pca', choices=['pca', 'umap'],
                        help="Transformation method ('pca' or 'umap'). Default is 'pca'. Used for dimred plots.")
    parser.add_argument("--target_variables", type=str, required=False,
                        help="Comma-separated list of target variables to plot.")

    # Arguments for Kaplan-Meier
    parser.add_argument("--survival_data", type=str,
                        help="Path to survival data file with columns: duration and event. Required for kaplan_meier plots.")
    parser.add_argument("--surv_time_var", type=str, required=False,
                        help="Column name for survival time")
    parser.add_argument("--surv_event_var", type=str, required=False,
                        help="Column name for survival event")
    parser.add_argument("--event_value", type=str, required=False,
                        help="Value in event column that represents an event (e.g., 'DECEASED')")

    # Arguments for Cox analysis
    parser.add_argument("--model", type=str,
                        help="Path to trained flexynesis model (pickle file). Required for cox plots.")
    parser.add_argument("--clinical_train", type=str,
                        help="Path to training dataset (pickle file). Required for cox plots.")
    parser.add_argument("--clinical_test", type=str,
                        help="Path to test dataset (pickle file). Required for cox plots.")
    parser.add_argument("--omics_train", type=str, default=None,
                        help="Path to training omics dataset. Required for cox plots.")
    parser.add_argument("--omics_test", type=str, default=None,
                        help="Path to test omics dataset. Required for cox plots.")
    parser.add_argument("--clinical_variables", type=str,
                        help="Comma-separated list of clinical variables to include in Cox model (e.g., 'AGE,SEX,HISTOLOGICAL_DIAGNOSIS,STUDY')")
    parser.add_argument("--top_features", type=int, default=20,
                        help="Number of top important features to include in Cox model. Default is 20")
    parser.add_argument("--crossval", action='store_true',
                        help="If True, performs K-fold cross-validation and returns average C-index. Default is False")
    parser.add_argument("--n_splits", type=int, default=5,
                        help="Number of folds for cross-validation. Default is 5")
    parser.add_argument("--random_state", type=int, default=42,
                        help="Random seed for reproducibility. Default is 42")

    # Arguments for scatter plot, heatmap, PR curves, ROC curves, and box plots
    parser.add_argument("--target_value", type=str, default=None,
                        help="Target variable(s) for scatter, concordance_heatmap, pr_curve, roc_curve and box_plot plots (comma-separated for multiple). If omitted, all variables in the labels file are used.")

    # Common arguments
    parser.add_argument("--output_dir", type=str, default='output',
                        help="Output directory. Default is 'output'")
    parser.add_argument("--output_name", type=str, default=None,
                        help="Output filename base")
    parser.add_argument("--format", type=str, default='jpg', choices=['png', 'pdf', 'svg', 'jpg'],
                        help="Output format for the plot. Default is 'jpg'")
    parser.add_argument("--dpi", type=int, default=300,
                        help="DPI for the output image. Default is 300")

    # argparse already enforces the allowed --plot_type / --method / --format
    # values and the integer type of the numeric options, so only the
    # requirements that depend on the chosen plot type are validated below.
    args = parser.parse_args()

    # Plot types that need nothing but the labels file, with the wording used
    # in their "--labels is required" message.
    labels_only_plots = {
        'scatter': "scatter plots",
        'concordance_heatmap': "concordance heatmap",
        'pr_curve': "precision-recall curves",
        'roc_curve': "ROC curves",
        'box_plot': "box plots",
    }

    try:
        # Validate plot type requirements
        if args.plot_type == 'dimred':
            if not args.embeddings:
                raise ValueError("--embeddings is required when plot_type is 'dimred'")
            if not os.path.isfile(args.embeddings):
                raise FileNotFoundError(f"embeddings file not found: {args.embeddings}")
            if not args.labels:
                raise ValueError("--labels is required for dimensionality reduction plots")
            if not os.path.isfile(args.labels):
                raise FileNotFoundError(f"Labels file not found: {args.labels}")
            if not args.target_variables:
                raise ValueError("--target_variables is required for dimensionality reduction plots")

        elif args.plot_type == 'kaplan_meier':
            if not args.survival_data:
                raise ValueError("--survival_data is required when plot_type is 'kaplan_meier'")
            if not os.path.isfile(args.survival_data):
                raise FileNotFoundError(f"Survival data file not found: {args.survival_data}")
            if not args.labels:
                raise ValueError("--labels is required for Kaplan-Meier plots")
            if not os.path.isfile(args.labels):
                raise FileNotFoundError(f"Labels file not found: {args.labels}")
            if not args.surv_time_var:
                raise ValueError("--surv_time_var is required for Kaplan-Meier plots")
            if not args.surv_event_var:
                raise ValueError("--surv_event_var is required for Kaplan-Meier plots")
            if not args.event_value:
                raise ValueError("--event_value is required for Kaplan-Meier plots")

        elif args.plot_type == 'cox':
            # (option name, description used in the "not found" message)
            cox_files = [
                ('model', "Model file"),
                ('clinical_train', "Training dataset file"),
                ('clinical_test', "Test dataset file"),
                ('omics_train', "Training omics dataset file"),
                ('omics_test', "Test omics dataset file"),
            ]
            for option, description in cox_files:
                path = getattr(args, option)
                if not path:
                    raise ValueError(f"--{option} is required when plot_type is 'cox'")
                if not os.path.isfile(path):
                    raise FileNotFoundError(f"{description} not found: {path}")
            for option in ('surv_time_var', 'surv_event_var', 'clinical_variables'):
                if not getattr(args, option):
                    raise ValueError(f"--{option} is required for Cox plots")
            if args.top_features <= 0:
                raise ValueError("--top_features must be a positive integer")
            if not args.event_value:
                raise ValueError("--event_value is required for Cox plots")
            if args.n_splits <= 0:
                raise ValueError("--n_splits must be a positive integer")

        elif args.plot_type in labels_only_plots:
            if not args.labels:
                raise ValueError(f"--labels is required for {labels_only_plots[args.plot_type]}")
            if not args.target_value:
                print("--target_value is not specified, using all unique variables from labels")
            if not os.path.isfile(args.labels):
                raise FileNotFoundError(f"Labels file not found: {args.labels}")

        # Create output directory
        output_dir = Path(args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        print(f"Output directory: {output_dir.absolute()}")

        # Generate output filename base: "<stem of the main input>_<suffix>"
        if args.output_name:
            output_name_base = args.output_name
        else:
            default_name_parts = {
                'dimred': (args.embeddings, args.method),
                'kaplan_meier': (args.survival_data, 'km'),
                'cox': (args.model, 'cox'),
                'scatter': (args.labels, 'scatter'),
                'concordance_heatmap': (args.labels, 'concordance'),
                'pr_curve': (args.labels, 'pr_curves'),
                'roc_curve': (args.labels, 'roc_curves'),
                'box_plot': (args.labels, 'box_plot'),
            }
            source_path, suffix = default_name_parts[args.plot_type]
            output_name_base = f"{Path(source_path).stem}_{suffix}"

        # Generate plots based on type
        if args.plot_type == 'dimred':
            # Load labels
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)
            # Load embeddings data
            print(f"Loading embeddings from: {args.embeddings}")
            embeddings, sample_names = load_embeddings(args.embeddings)
            print(f"embeddings shape: {embeddings.shape}")

            # Match samples to embeddings
            matched_labels = match_samples_to_embeddings(sample_names, label_data)
            print(f"Successfully matched {len(matched_labels)} samples for dimensionality reduction")

            generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base)

        elif args.plot_type == 'kaplan_meier':
            # Load labels
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)
            # Load survival data
            print(f"Loading survival data from: {args.survival_data}")
            survival_data = load_survival_data(args.survival_data)
            print(f"Survival data shape: {survival_data.shape}")

            generate_km_plots(survival_data, label_data, args, output_dir, output_name_base)

        elif args.plot_type == 'cox':
            # Load model and datasets
            print(f"Loading model from: {args.model}")
            model = load_model(args.model)
            print(f"Loading training dataset from: {args.clinical_train}")
            clinical_train = load_omics(args.clinical_train)
            print(f"Loading test dataset from: {args.clinical_test}")
            clinical_test = load_omics(args.clinical_test)
            print(f"Loading training omics dataset from: {args.omics_train}")
            omics_train = load_omics(args.omics_train)
            print(f"Loading test omics dataset from: {args.omics_test}")
            omics_test = load_omics(args.omics_test)

            # NOTE(review): omics_test is passed before omics_train, unlike the
            # clinical pair - confirm against generate_cox_plots' signature.
            generate_cox_plots(model, clinical_train, clinical_test, omics_test, omics_train, args, output_dir, output_name_base)

        else:
            # All remaining plot types only need the labels file.
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)

            if args.plot_type == 'scatter':
                generate_plot_scatter(label_data, args, output_dir, output_name_base)
            elif args.plot_type == 'concordance_heatmap':
                generate_label_concordance_heatmap(label_data, args, output_dir, output_name_base)
            elif args.plot_type == 'pr_curve':
                generate_pr_curves(label_data, args, output_dir, output_name_base)
            elif args.plot_type == 'roc_curve':
                generate_roc_curves(label_data, args, output_dir, output_name_base)
            elif args.plot_type == 'box_plot':
                generate_box_plots(label_data, args, output_dir, output_name_base)

        print("All plots generated successfully!")

    except (FileNotFoundError, ValueError, pd.errors.ParserError) as e:
        print(f"Error: {e}")
        return 1

    return 0
1279 | |
1280 | |
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status. SystemExit is
    # raised directly because the bare exit() helper is injected by the `site`
    # module for interactive use and is not guaranteed to exist in every run.
    raise SystemExit(main())