Mercurial > repos > bgruening > cellpose
comparison cp_segmentation.py @ 3:c793edde4284 draft default tip
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/cellpose commit 3f2ba60f101c923896ca95ed62981fcbb0a5ced3
| author | bgruening |
|---|---|
| date | Fri, 12 Dec 2025 12:36:21 +0000 |
| parents | e5370bb71633 |
| children | |
comparison
equal
deleted
inserted
replaced
| 2:e5370bb71633 | 3:c793edde4284 |
|---|---|
| 12 # Load the remaining packages *after* adjusting `MKL_NUM_THREADS` (this likely necessary for it to take effect) | 12 # Load the remaining packages *after* adjusting `MKL_NUM_THREADS` (this likely necessary for it to take effect) |
| 13 import matplotlib.pyplot as plt | 13 import matplotlib.pyplot as plt |
| 14 import numpy as np | 14 import numpy as np |
| 15 import skimage.io | 15 import skimage.io |
| 16 import torch | 16 import torch |
| 17 from cellpose import models, plot, transforms | 17 from cellpose import models, plot |
| 18 | 18 |
| 19 # Apply PyTorch guidelines for reproducibility | 19 # Apply PyTorch guidelines for reproducibility |
| 20 torch.backends.cudnn.benchmark = True | 20 torch.backends.cudnn.benchmark = True |
| 21 torch.backends.cudnn.deterministic = True | 21 torch.backends.cudnn.deterministic = True |
| 22 torch.manual_seed(0) | 22 torch.manual_seed(0) |
| 23 | 23 |
| 24 | 24 |
| 25 def main(inputs, img_path, img_format, output_dir): | 25 def main(inputs, img_path, output_dir): |
| 26 """ | 26 """ |
| 27 Parameter | 27 Parameter |
| 28 --------- | 28 --------- |
| 29 inputs : str | 29 inputs : str |
| 30 File path to galaxy tool parameter | 30 File path to galaxy tool parameter |
| 31 img_path : str | 31 img_path : str |
| 32 File path for the input image | 32 File path for the input image |
| 33 img_format : str | |
| 34 One of the ['ome.tiff', 'tiff', 'png', 'jpg'] | |
| 35 output_dir : str | 33 output_dir : str |
| 36 Folder to save the outputs. | 34 Folder to save the outputs. |
| 37 """ | 35 """ |
| 38 warnings.simplefilter('ignore') | 36 warnings.simplefilter('ignore') |
| 39 np.random.seed(42) | 37 np.random.seed(42) |
| 40 with open(inputs, 'r') as param_handler: | 38 with open(inputs, 'r') as param_handler: |
| 41 params = json.load(param_handler) | 39 params = json.load(param_handler) |
| 42 | 40 |
| 43 gpu = params['use_gpu'] | 41 gpu = params['use_gpu'] |
| 44 model_type = params['model_type'] | 42 model_type = params['model_type'] |
| 45 chan = params['chan'] | |
| 46 chan2 = params['chan2'] | |
| 47 chan_first = params['chan_first'] | |
| 48 if chan is None: | |
| 49 channels = None | |
| 50 else: | |
| 51 channels = [int(chan), int(chan2) if chan2 is not None else None] | |
| 52 | |
| 53 options = params['options'] | 43 options = params['options'] |
| 54 | |
| 55 img = skimage.io.imread(img_path) | 44 img = skimage.io.imread(img_path) |
| 56 | 45 |
| 57 print(f"Image shape: {img.shape}") | 46 print(f"Image shape: {img.shape}") |
| 58 # transpose to Ly x Lx x nchann and reshape based on channels | |
| 59 if img_format.endswith('tiff'): | |
| 60 img = np.transpose(img, (1, 2, 0)) | |
| 61 img = transforms.reshape(img, channels=channels, chan_first=chan_first) | |
| 62 | 47 |
| 63 print(f"Image shape: {img.shape}") | |
| 64 model = models.Cellpose(gpu=gpu, model_type=model_type) | 48 model = models.Cellpose(gpu=gpu, model_type=model_type) |
| 65 masks, flows, styles, diams = model.eval(img, channels=channels, **options) | 49 masks, flows, styles, diams = model.eval(img, channels=[0, 0], **options) |
| 66 | 50 |
| 67 # save masks to tiff | 51 # save masks to tiff |
| 68 with warnings.catch_warnings(): | 52 with warnings.catch_warnings(): |
| 69 warnings.simplefilter("ignore") | 53 warnings.simplefilter("ignore") |
| 70 skimage.io.imsave(os.path.join(output_dir, 'cp_masks.tif'), | 54 skimage.io.imsave(os.path.join(output_dir, 'cp_masks.tif'), |
| 71 masks.astype(np.uint16)) | 55 masks.astype(np.uint16)) |
| 72 | 56 |
| 73 # make segmentation show # | 57 # make segmentation show # |
| 74 if params['show_segmentation']: | 58 if params['show_segmentation']: |
| 75 img = skimage.io.imread(img_path) | 59 img = skimage.io.imread(img_path) |
| 76 # uniform image | |
| 77 if img_format.endswith('tiff'): | |
| 78 img = np.transpose(img, (1, 2, 0)) | |
| 79 img = transforms.reshape(img, channels=channels, chan_first=chan_first) | |
| 80 | 60 |
| 81 maski = masks | 61 maski = masks |
| 82 flowi = flows[0] | 62 flowi = flows[0] |
| 83 fig = plt.figure(figsize=(12, 3)) | 63 fig = plt.figure(figsize=(8, 2)) |
| 84 # can save images (set save_dir=None if not) | 64 # can save images (set save_dir=None if not) |
| 85 plot.show_segmentation(fig, img, maski, flowi, channels=channels) | 65 plot.show_segmentation(fig, img, maski, flowi, channels=[0, 0]) |
| 86 fig.savefig(os.path.join(output_dir, 'segm_show.png'), dpi=300) | 66 fig.savefig(os.path.join(output_dir, 'segm_show.png'), dpi=300) |
| 87 plt.close(fig) | 67 plt.close(fig) |
| 88 | 68 |
| 89 | 69 |
| 90 if __name__ == '__main__': | 70 if __name__ == '__main__': |
| 91 aparser = argparse.ArgumentParser() | 71 aparser = argparse.ArgumentParser() |
| 92 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 72 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) |
| 93 aparser.add_argument("-p", "--img_path", dest="img_path") | 73 aparser.add_argument("-p", "--img_path", dest="img_path") |
| 94 aparser.add_argument("-f", "--img_format", dest="img_format") | |
| 95 aparser.add_argument("-O", "--output_dir", dest="output_dir") | 74 aparser.add_argument("-O", "--output_dir", dest="output_dir") |
| 96 args = aparser.parse_args() | 75 args = aparser.parse_args() |
| 97 | 76 |
| 98 main(args.inputs, args.img_path, args.img_format, args.output_dir) | 77 main(args.inputs, args.img_path, args.output_dir) |
