comparison cp_segmentation.py @ 3:c793edde4284 draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/cellpose commit 3f2ba60f101c923896ca95ed62981fcbb0a5ced3
author bgruening
date Fri, 12 Dec 2025 12:36:21 +0000
parents e5370bb71633
children
comparison
equal deleted inserted replaced
2:e5370bb71633 3:c793edde4284
12 # Load the remaining packages *after* adjusting `MKL_NUM_THREADS` (this likely necessary for it to take effect) 12 # Load the remaining packages *after* adjusting `MKL_NUM_THREADS` (this likely necessary for it to take effect)
13 import matplotlib.pyplot as plt 13 import matplotlib.pyplot as plt
14 import numpy as np 14 import numpy as np
15 import skimage.io 15 import skimage.io
16 import torch 16 import torch
17 from cellpose import models, plot, transforms 17 from cellpose import models, plot
18 18
19 # Apply PyTorch guidelines for reproducibility 19 # Apply PyTorch guidelines for reproducibility
20 torch.backends.cudnn.benchmark = True 20 torch.backends.cudnn.benchmark = True
21 torch.backends.cudnn.deterministic = True 21 torch.backends.cudnn.deterministic = True
22 torch.manual_seed(0) 22 torch.manual_seed(0)
23 23
24 24
25 def main(inputs, img_path, img_format, output_dir): 25 def main(inputs, img_path, output_dir):
26 """ 26 """
27 Parameter 27 Parameter
28 --------- 28 ---------
29 inputs : str 29 inputs : str
30 File path to galaxy tool parameter 30 File path to galaxy tool parameter
31 img_path : str 31 img_path : str
32 File path for the input image 32 File path for the input image
33 img_format : str
34 One of the ['ome.tiff', 'tiff', 'png', 'jpg']
35 output_dir : str 33 output_dir : str
36 Folder to save the outputs. 34 Folder to save the outputs.
37 """ 35 """
38 warnings.simplefilter('ignore') 36 warnings.simplefilter('ignore')
39 np.random.seed(42) 37 np.random.seed(42)
40 with open(inputs, 'r') as param_handler: 38 with open(inputs, 'r') as param_handler:
41 params = json.load(param_handler) 39 params = json.load(param_handler)
42 40
43 gpu = params['use_gpu'] 41 gpu = params['use_gpu']
44 model_type = params['model_type'] 42 model_type = params['model_type']
45 chan = params['chan']
46 chan2 = params['chan2']
47 chan_first = params['chan_first']
48 if chan is None:
49 channels = None
50 else:
51 channels = [int(chan), int(chan2) if chan2 is not None else None]
52
53 options = params['options'] 43 options = params['options']
54
55 img = skimage.io.imread(img_path) 44 img = skimage.io.imread(img_path)
56 45
57 print(f"Image shape: {img.shape}") 46 print(f"Image shape: {img.shape}")
58 # transpose to Ly x Lx x nchann and reshape based on channels
59 if img_format.endswith('tiff'):
60 img = np.transpose(img, (1, 2, 0))
61 img = transforms.reshape(img, channels=channels, chan_first=chan_first)
62 47
63 print(f"Image shape: {img.shape}")
64 model = models.Cellpose(gpu=gpu, model_type=model_type) 48 model = models.Cellpose(gpu=gpu, model_type=model_type)
65 masks, flows, styles, diams = model.eval(img, channels=channels, **options) 49 masks, flows, styles, diams = model.eval(img, channels=[0, 0], **options)
66 50
67 # save masks to tiff 51 # save masks to tiff
68 with warnings.catch_warnings(): 52 with warnings.catch_warnings():
69 warnings.simplefilter("ignore") 53 warnings.simplefilter("ignore")
70 skimage.io.imsave(os.path.join(output_dir, 'cp_masks.tif'), 54 skimage.io.imsave(os.path.join(output_dir, 'cp_masks.tif'),
71 masks.astype(np.uint16)) 55 masks.astype(np.uint16))
72 56
73 # make segmentation show # 57 # make segmentation show #
74 if params['show_segmentation']: 58 if params['show_segmentation']:
75 img = skimage.io.imread(img_path) 59 img = skimage.io.imread(img_path)
76 # uniform image
77 if img_format.endswith('tiff'):
78 img = np.transpose(img, (1, 2, 0))
79 img = transforms.reshape(img, channels=channels, chan_first=chan_first)
80 60
81 maski = masks 61 maski = masks
82 flowi = flows[0] 62 flowi = flows[0]
83 fig = plt.figure(figsize=(12, 3)) 63 fig = plt.figure(figsize=(8, 2))
84 # can save images (set save_dir=None if not) 64 # can save images (set save_dir=None if not)
85 plot.show_segmentation(fig, img, maski, flowi, channels=channels) 65 plot.show_segmentation(fig, img, maski, flowi, channels=[0, 0])
86 fig.savefig(os.path.join(output_dir, 'segm_show.png'), dpi=300) 66 fig.savefig(os.path.join(output_dir, 'segm_show.png'), dpi=300)
87 plt.close(fig) 67 plt.close(fig)
88 68
89 69
90 if __name__ == '__main__': 70 if __name__ == '__main__':
91 aparser = argparse.ArgumentParser() 71 aparser = argparse.ArgumentParser()
92 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) 72 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
93 aparser.add_argument("-p", "--img_path", dest="img_path") 73 aparser.add_argument("-p", "--img_path", dest="img_path")
94 aparser.add_argument("-f", "--img_format", dest="img_format")
95 aparser.add_argument("-O", "--output_dir", dest="output_dir") 74 aparser.add_argument("-O", "--output_dir", dest="output_dir")
96 args = aparser.parse_args() 75 args = aparser.parse_args()
97 76
98 main(args.inputs, args.img_path, args.img_format, args.output_dir) 77 main(args.inputs, args.img_path, args.output_dir)