diff cp_segmentation.py @ 3:c793edde4284 draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/cellpose commit 3f2ba60f101c923896ca95ed62981fcbb0a5ced3
author bgruening
date Fri, 12 Dec 2025 12:36:21 +0000
parents e5370bb71633
children
line wrap: on
line diff
--- a/cp_segmentation.py	Sat Mar 15 17:23:55 2025 +0000
+++ b/cp_segmentation.py	Fri Dec 12 12:36:21 2025 +0000
@@ -14,7 +14,7 @@
 import numpy as np
 import skimage.io
 import torch
-from cellpose import models, plot, transforms
+from cellpose import models, plot
 
 # Apply PyTorch guidelines for reproducibility
 torch.backends.cudnn.benchmark = True
@@ -22,7 +22,7 @@
 torch.manual_seed(0)
 
 
-def main(inputs, img_path, img_format, output_dir):
+def main(inputs, img_path, output_dir):
     """
     Parameter
     ---------
@@ -30,8 +30,6 @@
         File path to galaxy tool parameter
     img_path : str
         File path for the input image
-    img_format : str
-        One of the ['ome.tiff', 'tiff', 'png', 'jpg']
     output_dir : str
         Folder to save the outputs.
     """
@@ -42,27 +40,13 @@
 
     gpu = params['use_gpu']
     model_type = params['model_type']
-    chan = params['chan']
-    chan2 = params['chan2']
-    chan_first = params['chan_first']
-    if chan is None:
-        channels = None
-    else:
-        channels = [int(chan), int(chan2) if chan2 is not None else None]
-
     options = params['options']
-
     img = skimage.io.imread(img_path)
 
     print(f"Image shape: {img.shape}")
-    # transpose to Ly x Lx x nchann and reshape based on channels
-    if img_format.endswith('tiff'):
-        img = np.transpose(img, (1, 2, 0))
-        img = transforms.reshape(img, channels=channels, chan_first=chan_first)
 
-    print(f"Image shape: {img.shape}")
     model = models.Cellpose(gpu=gpu, model_type=model_type)
-    masks, flows, styles, diams = model.eval(img, channels=channels, **options)
+    masks, flows, styles, diams = model.eval(img, channels=[0, 0], **options)
 
     # save masks to tiff
     with warnings.catch_warnings():
@@ -73,16 +57,12 @@
     # make segmentation show #
     if params['show_segmentation']:
         img = skimage.io.imread(img_path)
-        # uniform image
-        if img_format.endswith('tiff'):
-            img = np.transpose(img, (1, 2, 0))
-            img = transforms.reshape(img, channels=channels, chan_first=chan_first)
 
         maski = masks
         flowi = flows[0]
-        fig = plt.figure(figsize=(12, 3))
+        fig = plt.figure(figsize=(8, 2))
         # can save images (set save_dir=None if not)
-        plot.show_segmentation(fig, img, maski, flowi, channels=channels)
+        plot.show_segmentation(fig, img, maski, flowi, channels=[0, 0])
         fig.savefig(os.path.join(output_dir, 'segm_show.png'), dpi=300)
         plt.close(fig)
 
@@ -91,8 +71,7 @@
     aparser = argparse.ArgumentParser()
     aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
     aparser.add_argument("-p", "--img_path", dest="img_path")
-    aparser.add_argument("-f", "--img_format", dest="img_format")
     aparser.add_argument("-O", "--output_dir", dest="output_dir")
     args = aparser.parse_args()
 
-    main(args.inputs, args.img_path, args.img_format, args.output_dir)
+    main(args.inputs, args.img_path, args.output_dir)