view filter.py @ 2:b2d9c92bc431 draft

planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tree/master/tools/2d_simple_filter/ commit a6fd77be465068f709a71d377900da99becf94d8
author imgteam
date Fri, 12 Dec 2025 21:18:04 +0000
parents
children 5ab62693dca5
line wrap: on
line source

import argparse
import json
from typing import (
    Any,
    Callable,
)

import giatools
import numpy as np
import scipy.ndimage as ndi
from skimage.morphology import disk


def image_astype(img: giatools.Image, dtype: np.dtype) -> giatools.Image:
    """
    Return a new image holding the pixel data of ``img`` converted to ``dtype``.

    The axes, original axes, and metadata objects of ``img`` are passed through
    to the new image unchanged; only the pixel array is converted.
    """
    converted = img.data.astype(dtype)
    return giatools.Image(
        data=converted,
        axes=img.axes,
        original_axes=img.original_axes,
        metadata=img.metadata,
    )


# Dispatch table mapping each supported filter name (the ``filter_type`` value of
# the JSON config) to a callable that takes the input image plus that filter's
# remaining config parameters and returns the filtered image.
filters = {
    # Gaussian smoothing (``order == 0``) or Gaussian derivative (``order > 0``).
    # Derivatives have signed responses, so the data is converted to float first;
    # plain smoothing keeps the input dtype.
    'gaussian': lambda img, sigma, order=0, axis=None: (
        apply_2d_filter(
            ndi.gaussian_filter,
            img if order == 0 else image_astype(img, float),
            sigma=sigma,
            order=order,
            axes=axis,
        )
    ),
    'uniform': lambda img, size: (
        apply_2d_filter(ndi.uniform_filter, img, size=size)
    ),
    # Median over a disk-shaped footprint of the given radius
    'median': lambda img, radius: (
        apply_2d_filter(ndi.median_filter, img, footprint=disk(radius))
    ),
    # Prewitt and Sobel are derivative filters with signed responses. SciPy writes
    # their result in the input dtype, so unsigned-integer inputs would wrap
    # around; convert to float first, as is done for Gaussian derivatives.
    'prewitt': lambda img, axis: (
        apply_2d_filter(ndi.prewitt, image_astype(img, float), axis=axis)
    ),
    'sobel': lambda img, axis: (
        apply_2d_filter(ndi.sobel, image_astype(img, float), axis=axis)
    ),
}


def apply_2d_filter(
    filter_impl: Callable[..., np.ndarray],
    img: giatools.Image,
    **kwargs: Any,
) -> giatools.Image:
    """
    Apply the 2-D filter to the 2-D/3-D, potentially multi-frame and multi-channel image.

    The image data is indexed as ``(Q, T, Z, <2 spatial axes>, C)``. Every 2-D plane
    obtained by fixing one index along each of the Q, T, Z, and C axes is filtered
    independently via ``filter_impl(plane, **kwargs)``.

    :param filter_impl: 2-D filter that takes a 2-D array plus ``kwargs`` and returns
        an array with the same shape as its input.
    :param img: Input image whose data has the Q, T, Z axes first, the C axis last,
        and exactly two axes in between.
    :param kwargs: Keyword arguments forwarded verbatim to ``filter_impl``.
    :return: New image with the filtered data and the axes of ``img``. The data dtype
        is whatever ``filter_impl`` returned for the first plane.
    """
    data = img.data
    result_data = None
    for q, t, z, c in np.ndindex(
        data.shape[ 0],  # Q axis
        data.shape[ 1],  # T axis
        data.shape[ 2],  # Z axis
        data.shape[-1],  # C axis
    ):
        # Index of one 2-D plane. A plain tuple works on every Python 3 version,
        # unlike star-unpacking inside a subscript (``np.s_[*idx, ...]``, 3.11+).
        sl = (q, t, z, Ellipsis, c)
        arr = data[sl]
        assert arr.ndim == 2  # sanity check, should always be True

        # Perform 2-D filtering
        res = filter_impl(arr, **kwargs)
        if result_data is None:
            # Allocate lazily so the output adopts the dtype produced by the filter
            result_data = np.empty(data.shape, res.dtype)
        result_data[sl] = res

    # Return results
    return giatools.Image(result_data, img.axes)


def apply_filter(
    input_filepath: str,
    output_filepath: str,
    filter_type: str,
    **kwargs: Any,
):
    """
    Read an image, apply the selected filter, and write the result as TIFF.

    :param input_filepath: Path of the image, read with ``giatools.Image.read``.
    :param output_filepath: Path the filtered image is written to (``tifffile`` backend).
    :param filter_type: Key into the module-level ``filters`` table.
    :param kwargs: Filter-specific parameters forwarded to the selected filter.
    """
    source = giatools.Image.read(input_filepath)

    chosen = filters[filter_type]
    filtered = chosen(source, **kwargs)
    # Reorder the result's axes to match the axes recorded on the input image
    filtered = filtered.normalize_axes_like(source.original_axes)

    # Carry over the input metadata before writing
    filtered.metadata = source.metadata
    filtered.write(output_filepath, backend='tifffile')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('input', type=str, help='Input image filepath')
    parser.add_argument('output', type=str, help='Output image filepath (TIFF)')
    parser.add_argument('params', type=str, help='Filter parameters filepath (JSON)')
    args = parser.parse_args()

    # Read the config file. JSON is UTF-8 by specification, so do not depend on
    # the platform's locale encoding.
    with open(args.params, encoding='utf-8') as cfgf:
        cfg = json.load(cfgf)

    # The config supplies ``filter_type`` plus the filter-specific parameters
    apply_filter(
        args.input,
        args.output,
        **cfg,
    )