Mercurial > repos > bgruening > cellpose
view cp_segmentation.py @ 2:e5370bb71633 draft default tip
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/cellpose commit 5846f23282167df28e6cc00e1bda31523b66b5b4
author | bgruening |
---|---|
date | Sat, 15 Mar 2025 17:23:55 +0000 |
parents | 1857781df226 |
children |
line wrap: on
line source
import argparse
import json
import os
import warnings

# Make sure that `MKL_NUM_THREADS` is set to 1, to ensure reproducibility:
# https://forum.image.sc/t/reproducibility-how-we-spent-years-building-containers-and-then-mkl-decided-to-screw-our-results/109599
# `.get()` is used so that an *unset* variable is adjusted like any other
# non-'1' value instead of aborting the script with a `KeyError`.
if os.environ.get('MKL_NUM_THREADS') != '1':
    warnings.warn('MKL_NUM_THREADS must be set to 1 to ensure reproducibility, and will be adjusted accordingly for now.')
    os.environ['MKL_NUM_THREADS'] = '1'

# Load the remaining packages *after* adjusting `MKL_NUM_THREADS`
# (this is likely necessary for it to take effect)
import matplotlib.pyplot as plt
import numpy as np
import skimage.io
import torch
from cellpose import models, plot, transforms

# Apply PyTorch guidelines for reproducibility: cuDNN benchmarking must be
# *disabled*, otherwise cuDNN may select different algorithms between runs.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.manual_seed(0)


def _load_image(img_path, img_format, channels, chan_first, verbose=False):
    """Read the image and bring it into the layout that is passed to Cellpose.

    Parameters
    ----------
    img_path : str
        File path for the input image
    img_format : str
        Image format; formats ending in 'tiff' are transposed with axes
        (1, 2, 0) before reshaping
    channels : list or None
        Channel selection forwarded to `transforms.reshape`
    chan_first : bool
        Forwarded to `transforms.reshape`
    verbose : bool
        If true, print the image shape before and after the conversion

    Returns
    -------
    The array returned by `transforms.reshape`.
    """
    img = skimage.io.imread(img_path)
    if verbose:
        print(f"Image shape: {img.shape}")

    # transpose to Ly x Lx x nchann and reshape based on channels
    # A 2-D (single-channel) TIFF has no channel axis to move, and the
    # three-axis transpose would raise on it. It is passed on untransposed,
    # exactly like a 2-D array read from any other format.
    if img_format.endswith('tiff') and img.ndim != 2:
        img = np.transpose(img, (1, 2, 0))
    img = transforms.reshape(img, channels=channels, chan_first=chan_first)

    if verbose:
        print(f"Image shape: {img.shape}")
    return img


def main(inputs, img_path, img_format, output_dir):
    """Segment an image with Cellpose and write the results to `output_dir`.

    Two files are written: the label masks as 16-bit `cp_masks.tif` and the
    segmentation overview figure as `segm_show.png`.

    Parameters
    ----------
    inputs : str
        File path to galaxy tool parameter (a JSON file providing the keys
        'use_gpu', 'model_type', 'chan', 'chan2', 'chan_first', 'options')
    img_path : str
        File path for the input image
    img_format : str
        One of the ['ome.tiff', 'tiff', 'png', 'jpg']
    output_dir : str
        Folder to save the outputs.
    """
    warnings.simplefilter('ignore')
    np.random.seed(42)

    with open(inputs, 'r') as param_handler:
        params = json.load(param_handler)

    gpu = params['use_gpu']
    model_type = params['model_type']
    chan = params['chan']
    chan2 = params['chan2']
    chan_first = params['chan_first']
    if chan is None:
        channels = None
    else:
        channels = [int(chan), int(chan2) if chan2 is not None else None]

    # Remaining entries are forwarded verbatim as keyword arguments to `eval`
    options = params['options']

    img = _load_image(img_path, img_format, channels, chan_first, verbose=True)

    model = models.Cellpose(gpu=gpu, model_type=model_type)
    masks, flows, styles, diams = model.eval(img, channels=channels, **options)

    # save masks to tiff
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        skimage.io.imsave(os.path.join(output_dir, 'cp_masks.tif'),
                          masks.astype(np.uint16))

    # make segmentation show (currently always produced; the parameter file
    # is not consulted for this). The image is loaded afresh from disk rather
    # than reusing the array that was handed to `model.eval`.
    img = _load_image(img_path, img_format, channels, chan_first)

    fig = plt.figure(figsize=(12, 3))
    plot.show_segmentation(fig, img, masks, flows[0], channels=channels)
    fig.savefig(os.path.join(output_dir, 'segm_show.png'), dpi=300)
    plt.close(fig)


if __name__ == '__main__':
    aparser = argparse.ArgumentParser()
    # All four arguments are needed by `main`; marking them as required turns
    # an obscure failure on `None` into a clear usage error.
    aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
    aparser.add_argument("-p", "--img_path", dest="img_path", required=True)
    aparser.add_argument("-f", "--img_format", dest="img_format", required=True)
    aparser.add_argument("-O", "--output_dir", dest="output_dir", required=True)
    args = aparser.parse_args()

    main(args.inputs, args.img_path, args.img_format, args.output_dir)