Mercurial > repos > bgruening > bioimage_inference
comparison main.py @ 3:bc28236f407b draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/bioimaging commit e08711c242a340a1671dfca35f52d3724086e968
| author | bgruening |
|---|---|
| date | Wed, 26 Feb 2025 10:27:28 +0000 |
| parents | 9f9ae2ac7820 |
| children | 2b61d8fcfa52 |
comparison
equal
deleted
inserted
replaced
| 2:0c0de5546fe1 | 3:bc28236f407b |
|---|---|
| 5 import argparse | 5 import argparse |
| 6 | 6 |
| 7 import imageio | 7 import imageio |
| 8 import numpy as np | 8 import numpy as np |
| 9 import torch | 9 import torch |
| | 10 import torch.nn.functional as F |
| 10 | 11 |
| 11 | 12 |
| 12 def find_dim_order(user_in_shape, input_image): | 13 def dynamic_resize(image: torch.Tensor, target_shape: tuple): |
| 13 """ | 14 """ |
| 14 Find the correct order of input image's | 15 Resize an input tensor dynamically to the target shape. |
| 15 shape. For a few models, the order of input size | 16 |
| 16 mentioned in the RDF.yaml file is reversed compared | 17 Parameters: |
| 17 to the input image's original size. If it is reversed, | 18 - image: Input tensor with shape (C, D1, D2, ..., DN) (any number of spatial dims) |
| 18 transpose the image to find correct order of image's | 19 - target_shape: Tuple specifying the target shape (C', D1', D2', ..., DN') |
| 19 dimensions. | 20 |
| | 21 Returns: |
| | 22 - Resized tensor with target shape target_shape. |
| 20 """ | 23 """ |
| 21 image_shape = list(input_image.shape) | 24 # Extract input shape |
| 22 # reverse the input shape provided from RDF.yaml file | 25 input_shape = image.shape |
| 23 correct_order = user_in_shape.split(",")[::-1] | 26 num_dims = len(input_shape) # Includes channels and spatial dimensions |
| 24 # remove 1s from the original dimensions | 27 |
| 25 correct_order = [int(i) for i in correct_order if i != "1"] | 28 # Ensure target shape matches the number of dimensions |
| 26 if (correct_order[0] == image_shape[-1]) and (correct_order != image_shape): | 29 if len(target_shape) != num_dims: |
| 27 input_image = torch.tensor(input_image.transpose()) | 30 raise ValueError( |
| 28 return input_image, correct_order | 31 f"Target shape {target_shape} must match input dimensions {num_dims}" |
| | 32 ) |
| | 33 |
| | 34 # Extract target channels and spatial sizes |
| | 35 target_channels = target_shape[0] # First element is the target channel count |
| | 36 target_spatial_size = target_shape[1:] # Remaining elements are spatial dimensions |
| | 37 |
| | 38 # Add batch dim (N=1) for resizing |
| | 39 image = image.unsqueeze(0) |
| | 40 |
| | 41 # Choose the best interpolation mode based on dimensionality |
| | 42 if num_dims == 4: |
| | 43 interp_mode = "trilinear" |
| | 44 elif num_dims == 3: |
| | 45 interp_mode = "bilinear" |
| | 46 elif num_dims == 2: |
| | 47 interp_mode = "bicubic" |
| | 48 else: |
| | 49 interp_mode = "nearest" |
| | 50 |
| | 51 # Resize spatial dimensions dynamically |
| | 52 image = F.interpolate( |
| | 53 image, size=target_spatial_size, mode=interp_mode, align_corners=False |
| | 54 ) |
| | 55 |
| | 56 # Adjust channels if necessary |
| | 57 current_channels = image.shape[1] |
| | 58 |
| | 59 if target_channels > current_channels: |
| | 60 # Expand channels by repeating existing ones |
| | 61 expand_factor = target_channels // current_channels |
| | 62 remainder = target_channels % current_channels |
| | 63 image = image.repeat(1, expand_factor, *[1] * (num_dims - 1)) |
| | 64 |
| | 65 if remainder > 0: |
| | 66 extra_channels = image[ |
| | 67 :, :remainder, ... |
| | 68 ] # Take the first few channels to match target |
| | 69 image = torch.cat([image, extra_channels], dim=1) |
| | 70 |
| | 71 elif target_channels < current_channels: |
| | 72 # Reduce channels by averaging adjacent ones |
| | 73 image = image[:, :target_channels, ...] # Simply slice to reduce channels |
| | 74 return image.squeeze(0) # Remove batch dimension before returning |
| 29 | 75 |
| 30 | 76 |
| 31 if __name__ == "__main__": | 77 if __name__ == "__main__": |
| 32 arg_parser = argparse.ArgumentParser() | 78 arg_parser = argparse.ArgumentParser() |
| 33 arg_parser.add_argument("-im", "--imaging_model", required=True, help="Input BioImage model") | 79 arg_parser.add_argument( |
| 34 arg_parser.add_argument("-ii", "--image_file", required=True, help="Input image file") | 80 "-im", "--imaging_model", required=True, help="Input BioImage model" |
| 35 arg_parser.add_argument("-is", "--image_size", required=True, help="Input image file's size") | 81 ) |
| | 82 arg_parser.add_argument( |
| | 83 "-ii", "--image_file", required=True, help="Input image file" |
| | 84 ) |
| | 85 arg_parser.add_argument( |
| | 86 "-is", "--image_size", required=True, help="Input image file's size" |
| | 87 ) |
| | 88 arg_parser.add_argument( |
| | 89 "-ia", "--image_axes", required=True, help="Input image file's axes" |
| | 90 ) |
| 36 | 91 |
| 37 # get argument values | 92 # get argument values |
| 38 args = vars(arg_parser.parse_args()) | 93 args = vars(arg_parser.parse_args()) |
| 39 model_path = args["imaging_model"] | 94 model_path = args["imaging_model"] |
| 40 input_image_path = args["image_file"] | 95 input_image_path = args["image_file"] |
| | 96 input_size = args["image_size"] |
| 41 | 97 |
| 42 # load all embedded images in TIF file | 98 # load all embedded images in TIF file |
| 43 test_data = imageio.v3.imread(input_image_path, index="...") | 99 test_data = imageio.v3.imread(input_image_path, index="...") |
| | 100 test_data = test_data.astype(np.float32) |
| 44 test_data = np.squeeze(test_data) | 101 test_data = np.squeeze(test_data) |
| 45 test_data = test_data.astype(np.float32) | |
| 46 | 102 |
| 47 # assess the correct dimensions of TIF input image | 103 target_image_dim = input_size.split(",")[::-1] |
| 48 input_image_shape = args["image_size"] | 104 target_image_dim = [int(i) for i in target_image_dim if i != "1"] |
| 49 im_test_data, shape_vals = find_dim_order(input_image_shape, test_data) | 105 target_image_dim = tuple(target_image_dim) |
| | 106 |
| | 107 exp_test_data = torch.tensor(test_data) |
| | 108 # check if image dimensions are reversed |
| | 109 reversed_order = list(reversed(range(exp_test_data.dim()))) |
| | 110 exp_test_data_T = exp_test_data.permute(*reversed_order) |
| | 111 if exp_test_data_T.shape == target_image_dim: |
| | 112 exp_test_data = exp_test_data_T |
| | 113 if exp_test_data.shape != target_image_dim: |
| | 114 for i in range(len(target_image_dim) - exp_test_data.dim()): |
| | 115 exp_test_data = exp_test_data.unsqueeze(i) |
| | 116 try: |
| | 117 exp_test_data = dynamic_resize(exp_test_data, target_image_dim) |
| | 118 except Exception as e: |
| | 119 raise RuntimeError(f"Error during resizing: {e}") from e |
| | 120 |
| | 121 current_dimension = len(exp_test_data.shape) |
| | 122 input_axes = args["image_axes"] |
| | 123 target_dimension = len(input_axes) |
| | 124 # expand input image based on the number of target dimensions |
| | 125 for i in range(target_dimension - current_dimension): |
| | 126 exp_test_data = torch.unsqueeze(exp_test_data, i) |
| 50 | 127 |
| 51 # load model | 128 # load model |
| 52 model = torch.load(model_path) | 129 model = torch.load(model_path) |
| 53 model.eval() | 130 model.eval() |
| 54 | |
| 55 # find the number of dimensions required by the model | |
| 56 target_dimension = 0 | |
| 57 for param in model.named_parameters(): | |
| 58 target_dimension = len(param[1].shape) | |
| 59 break | |
| 60 current_dimension = len(list(im_test_data.shape)) | |
| 61 | |
| 62 # update the dimensions of input image if the required image by | |
| 63 # the model is smaller | |
| 64 slices = tuple(slice(0, s_val) for s_val in shape_vals) | |
| 65 | |
| 66 # apply the slices to the reshaped_input | |
| 67 im_test_data = im_test_data[slices] | |
| 68 exp_test_data = torch.tensor(im_test_data) | |
| 69 | |
| 70 # expand input image's dimensions | |
| 71 for i in range(target_dimension - current_dimension): | |
| 72 exp_test_data = torch.unsqueeze(exp_test_data, i) | |
| 73 | 131 |
| 74 # make prediction | 132 # make prediction |
| 75 pred_data = model(exp_test_data) | 133 pred_data = model(exp_test_data) |
| 76 pred_data_output = pred_data.detach().numpy() | 134 pred_data_output = pred_data.detach().numpy() |
| 77 | 135 |
