bgruening/bioimage_inference: main.py @ 3:bc28236f407b (draft)
planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/bioimaging commit e08711c242a340a1671dfca35f52d3724086e968
author:   bgruening
date:     Wed, 26 Feb 2025 10:27:28 +0000
parents:  9f9ae2ac7820
children: 2b61d8fcfa52
""" Predict images using AI models from BioImage.IO """ import argparse import imageio import numpy as np import torch import torch.nn.functional as F def dynamic_resize(image: torch.Tensor, target_shape: tuple): """ Resize an input tensor dynamically to the target shape. Parameters: - image: Input tensor with shape (C, D1, D2, ..., DN) (any number of spatial dims) - target_shape: Tuple specifying the target shape (C', D1', D2', ..., DN') Returns: - Resized tensor with target shape target_shape. """ # Extract input shape input_shape = image.shape num_dims = len(input_shape) # Includes channels and spatial dimensions # Ensure target shape matches the number of dimensions if len(target_shape) != num_dims: raise ValueError( f"Target shape {target_shape} must match input dimensions {num_dims}" ) # Extract target channels and spatial sizes target_channels = target_shape[0] # First element is the target channel count target_spatial_size = target_shape[1:] # Remaining elements are spatial dimensions # Add batch dim (N=1) for resizing image = image.unsqueeze(0) # Choose the best interpolation mode based on dimensionality if num_dims == 4: interp_mode = "trilinear" elif num_dims == 3: interp_mode = "bilinear" elif num_dims == 2: interp_mode = "bicubic" else: interp_mode = "nearest" # Resize spatial dimensions dynamically image = F.interpolate( image, size=target_spatial_size, mode=interp_mode, align_corners=False ) # Adjust channels if necessary current_channels = image.shape[1] if target_channels > current_channels: # Expand channels by repeating existing ones expand_factor = target_channels // current_channels remainder = target_channels % current_channels image = image.repeat(1, expand_factor, *[1] * (num_dims - 1)) if remainder > 0: extra_channels = image[ :, :remainder, ... ] # Take the first few channels to match target image = torch.cat([image, extra_channels], dim=1) elif target_channels < current_channels: # Reduce channels by averaging adjacent ones image = image[:, :target_channels, ...] 


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "-im", "--imaging_model", required=True, help="Input BioImage model"
    )
    arg_parser.add_argument(
        "-ii", "--image_file", required=True, help="Input image file"
    )
    arg_parser.add_argument(
        "-is", "--image_size", required=True, help="Input image file's size"
    )
    arg_parser.add_argument(
        "-ia", "--image_axes", required=True, help="Input image file's axes"
    )

    # Get argument values
    args = vars(arg_parser.parse_args())
    model_path = args["imaging_model"]
    input_image_path = args["image_file"]
    input_size = args["image_size"]

    # Load all images embedded in the TIF file; index=... (Ellipsis) tells
    # imageio to read every ndimage in the file and stack them
    test_data = imageio.v3.imread(input_image_path, index=...)
    test_data = test_data.astype(np.float32)
    test_data = np.squeeze(test_data)

    # Build the target shape: reverse the comma-separated sizes and drop
    # singleton dimensions
    target_image_dim = input_size.split(",")[::-1]
    target_image_dim = [int(i) for i in target_image_dim if i != "1"]
    target_image_dim = tuple(target_image_dim)

    exp_test_data = torch.tensor(test_data)

    # Check whether the image axes are stored in reverse order: if the
    # transposed shape matches the target shape, use the transposed tensor
    reversed_order = list(reversed(range(exp_test_data.dim())))
    exp_test_data_T = exp_test_data.permute(*reversed_order)
    if exp_test_data_T.shape == target_image_dim:
        exp_test_data = exp_test_data_T

    if exp_test_data.shape != target_image_dim:
        # Prepend singleton dimensions until the rank matches the target shape
        for i in range(len(target_image_dim) - exp_test_data.dim()):
            exp_test_data = exp_test_data.unsqueeze(i)
        try:
            exp_test_data = dynamic_resize(exp_test_data, target_image_dim)
        except Exception as e:
            raise RuntimeError(f"Error during resizing: {e}") from e

    current_dimension = len(exp_test_data.shape)
    input_axes = args["image_axes"]
    target_dimension = len(input_axes)

    # Expand the input image to match the number of axes the model expects
    for i in range(target_dimension - current_dimension):
        exp_test_data = torch.unsqueeze(exp_test_data, i)

    # Load the full serialized model object and switch to inference mode
    model = torch.load(model_path)
    model.eval()

    # Make a prediction
    pred_data = model(exp_test_data)
    pred_data_output = pred_data.detach().numpy()

    # Save the raw predicted image matrix
    np.save("output_predicted_image_matrix.npy", pred_data_output)

    # Post-process the prediction so it can be saved as a TIF file
    pred_data = torch.squeeze(pred_data)
    pred_numpy = pred_data.detach().numpy()

    # Write the predicted TIF image to file
    imageio.v3.imwrite("output_predicted_image.tif", pred_numpy, extension=".tif")
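
A minimal example invocation, assuming a serialized model model.pt and a TIF image input.tif (file names and argument values here are hypothetical; in practice Galaxy supplies these arguments when it runs the tool):

    python main.py -im model.pt -ii input.tif -is 256,256,1 -ia bcyx

The script writes two outputs to the working directory: output_predicted_image_matrix.npy, the raw prediction array, and output_predicted_image.tif, the squeezed prediction saved as a TIF image.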