Source code for blimp.processing.segment

from typing import Union, Optional, List
from pathlib import Path
import logging

from aicsimageio import AICSImage
import numpy as np
import mahotas as mh

from blimp.utils import get_channel_names

logger = logging.getLogger(__name__)


def segment_nuclei_cellpose(
    intensity_image: AICSImage,
    nuclei_channel: int = 0,
    pretrained_model: Union[str, Path, None] = None,
    diameter: Optional[int] = None,
    threshold: float = 0,
    flow_threshold: float = 0.4,
    normalize: Union[bool, dict] = True,
    gpu: bool = False,
) -> AICSImage:
    """Segment nuclei in every timepoint of a 2D image stack with cellpose 4.

    A single ``CellposeModel`` is built up front and then applied to the
    nuclear channel of each timepoint independently.

    Parameters
    ----------
    intensity_image
        intensity image in 5D format "TCZYX"; only Z=1 is supported
    nuclei_channel
        index of the channel that holds the nuclear stain
    pretrained_model
        path to a custom pretrained model; when None, cellpose's default
        model is used (logged as "cpsam")
    diameter
        expected nucleus diameter in pixels; when None cellpose estimates it
    threshold
        forwarded to cellpose as ``cellprob_threshold`` (float between
        [-6, +6]); objects below it are discarded
    flow_threshold
        flow error threshold used by cellpose to filter masks
    normalize
        normalization settings forwarded to cellpose (bool or dict)
    gpu
        whether cellpose should use GPU acceleration, by default False

    Returns
    -------
    AICSImage
        label image of shape (T, 1, 1, Y, X) with a single channel named
        "Nuclei" and the physical pixel sizes of `intensity_image`

    Raises
    ------
    ValueError
        if `intensity_image` has more than one Z-plane
    """
    from cellpose import models

    # Only single-plane (2D) input is handled here.
    n_planes = intensity_image.dims.Z
    if n_planes > 1:
        raise ValueError(
            "segment_nuclei_cellpose only supports 2D images (Z=1). "
            f"Input image has Z={n_planes}. "
            "For 3D segmentation, use cellpose with do_3D=True directly."
        )

    # Build the model once and reuse it for every timepoint.
    model_kwargs = {"gpu": gpu}
    if pretrained_model is None:
        logger.debug("Initializing cellpose with default cpsam model")
    else:
        logger.debug(f"Initializing cellpose with pretrained model {str(pretrained_model)}")
        model_kwargs["pretrained_model"] = str(pretrained_model)
    cellpose_model = models.CellposeModel(**model_kwargs)

    n_timepoints = intensity_image.dims.T
    last_timepoint = n_timepoints - 1

    def _segment_timepoint(t: int) -> np.ndarray:
        # Run cellpose on the nuclear channel of one timepoint and return its masks.
        logger.debug(f"Segmenting nuclei at timepoint {t}/{last_timepoint}")
        plane = intensity_image.get_image_data("YX", C=nuclei_channel, T=t, Z=0)
        # Give the YX plane a leading channel axis (-> CYX); cellpose expands
        # the channel dimension itself.
        evaluation = cellpose_model.eval(
            np.expand_dims(plane, axis=0),
            channel_axis=0,
            diameter=diameter,
            flow_threshold=flow_threshold,
            cellprob_threshold=threshold,
            normalize=normalize,
            do_3D=False,
        )
        # cellpose 4 returns (masks, flows, styles); only the masks are kept.
        return evaluation[0]

    per_timepoint_masks = [_segment_timepoint(t) for t in range(n_timepoints)]

    # (T, Y, X) -> (T, C=1, Z=1, Y, X)
    label_stack = np.stack(per_timepoint_masks)[:, None, None, :, :]
    return AICSImage(
        label_stack,
        channel_names=["Nuclei"],
        physical_pixel_sizes=intensity_image.physical_pixel_sizes,
    )
def expand_objects_watershed(
    seeds_image: np.ndarray, background_image: np.ndarray, intensity_image: np.ndarray
) -> np.ndarray:
    """Expand objects.

    Expands objects in `seeds_image` using a watershed transform on `intensity_image`.

    Parameters
    ----------
    seeds_image:
        labeled objects that should be expanded; labels are expected to be
        consecutive (``1..n``)
    background_image:
        regions in the image that should be considered background and should
        not be part of an object after expansion; its labels must be larger
        than the number of seed objects, because regions carrying such labels
        are discarded after the watershed
    intensity_image:
        grayscale image; pixel intensities determine how far individual
        objects are expanded. ``np.invert`` is applied to it, so it must have
        an integer (typically unsigned) dtype

    Returns
    -------
    numpy.ndarray
        expanded objects, labeled with the labels of their seeds
    """
    # We compute the watershed transform using the seeds of the primary
    # objects and the additional seeds for the background regions. The
    # background regions will compete with the foreground regions and
    # thereby work as a stop criterion for expansion of primary objects.
    labels = np.where(seeds_image != 0, seeds_image, background_image)
    regions = mh.cwatershed(np.invert(intensity_image), labels)

    # Remove background regions (every label above the number of seeds).
    n_objects = len(np.unique(seeds_image[seeds_image > 0]))
    regions[regions > n_objects] = 0

    # Ensure objects are separated
    lines = mh.labeled.borders(regions)
    regions[lines] = 0

    # Close holes in objects. Each (dilated) hole is filled with the largest
    # positive label it touches; holes that touch no object are left alone
    # (previously this case raised an IndexError).
    foreground_mask = regions > 0
    holes = np.logical_xor(mh.close_holes(foreground_mask), foreground_mask)
    holes = mh.morph.dilate(holes)
    holes_labeled, n_holes = mh.label(holes)
    for i in range(1, n_holes + 1):
        hole_mask = holes_labeled == i
        candidate_labels = np.unique(regions[hole_mask])
        candidate_labels = candidate_labels[candidate_labels > 0]
        if candidate_labels.size:
            regions[hole_mask] = candidate_labels[-1]

    # Remove objects that are obviously too small, i.e. not larger than the
    # smallest seed (this could happen when we remove certain parts of objects
    # after the watershed region growing). Seed pixels are restored below.
    primary_sizes = mh.labeled.labeled_size(seeds_image)
    if len(primary_sizes) > 1:
        min_size = np.min(primary_sizes[1:]) + 1
        regions = mh.labeled.filter_labeled(regions, min_size=min_size)[0]

    # Remove regions that don't overlap with seed objects and assign
    # correct labels to the other regions, i.e. those of the corresponding seeds.
    new_label_image, n_new_labels = mh.labeled.relabel(regions)
    lut = np.zeros(np.max(new_label_image) + 1, new_label_image.dtype)
    for i in range(1, n_new_labels + 1):
        orig_labels = seeds_image[new_label_image == i]
        orig_labels = orig_labels[orig_labels > 0]
        orig_count = np.bincount(orig_labels)
        orig_unique = np.where(orig_count)[0]
        if orig_unique.size == 1:
            lut[i] = orig_unique[0]
        elif orig_unique.size > 1:
            logger.warning("objects overlap after expansion: %s", ", ".join(map(str, orig_unique)))
            # Majority vote: the seed covering most pixels of the region wins.
            lut[i] = np.where(orig_count == np.max(orig_count))[0][0]
    expanded_image = lut[new_label_image]

    # Ensure that seed objects are fully contained within expanded objects.
    # A comparison is used instead of ``seeds - expanded > 0`` so unsigned
    # label dtypes cannot wrap around and wipe out expanded pixels.
    # NOTE(review): seed pixels whose expanded label is *larger* than the seed
    # label are not overwritten (behavior kept from the original) - confirm.
    index = seeds_image > expanded_image
    expanded_image[index] = seeds_image[index]

    return expanded_image


def segment_secondary(
    primary_label_image: np.ndarray,
    intensity_image: np.ndarray,
    contrast_threshold: float,
    min_threshold: Optional[float] = None,
    max_threshold: Optional[float] = None,
) -> np.ndarray:
    """Segment Secondary.

    Detects secondary objects in an image by expanding the primary objects
    encoded in `primary_label_image`. The outlines of secondary objects are
    determined based on the watershed transform of `intensity_image` using
    the primary objects in `primary_label_image` as seeds.

    Parameters
    ----------
    primary_label_image: numpy.ndarray[numpy.int32]
        2D labeled array encoding primary objects, which serve as seeds for
        watershed transform; labels must be consecutive (``1..n``)
    intensity_image: numpy.ndarray[numpy.uint8 or numpy.uint16]
        2D grayscale array that serves as gradient for watershed transform;
        optimally this image is enhanced with a low-pass filter
    contrast_threshold: float
        contrast threshold passed to ``mahotas.thresholding.bernsen`` (radius 5),
        the locally adaptive thresholding used to separate foreground from
        background
    min_threshold: float, optional
        minimal foreground value; pixels below `min_threshold` are considered
        background
    max_threshold: float, optional
        maximal foreground value; pixels above `max_threshold` are considered
        foreground

    Returns
    -------
    numpy.ndarray
        secondary_label_image. If `primary_label_image` contains no
        background pixels it is returned unchanged (the same object); if it
        contains no objects an all-zero int32 array is returned.

    Raises
    ------
    ValueError
        if the primary objects are not consecutively labeled, or if there are
        too many objects to encode the background label in 32 bits

    Note
    ----
    Setting `min_threshold` and `max_threshold` to the same value reduces to
    manual thresholding.
    """
    has_background = bool(np.any(primary_label_image == 0))

    if not has_background:
        secondary_label_image = primary_label_image
    else:
        # We use adaptive thresholding to determine background regions,
        # i.e. regions in the intensity_image that should not be covered by
        # secondary objects.
        n_objects = len(np.unique(primary_label_image))  # includes background 0
        logger.info(f"primary label image has {n_objects -1} objects")
        if np.max(primary_label_image) != n_objects - 1:
            raise ValueError("Objects are not consecutively labeled, please relabel before secondary segmentation.")

        # SB: Added a catch for images with no primary objects
        # note that background is an 'object'
        if n_objects > 1:
            # True marks background pixels.
            background_mask = np.asarray(
                mh.thresholding.bernsen(intensity_image, 5, contrast_threshold), dtype=bool
            )
            if min_threshold is not None:
                logger.info(f"set lower threshold level to {min_threshold}")
                background_mask[intensity_image < min_threshold] = True
            if max_threshold is not None:
                logger.info(f"set upper threshold level to {max_threshold}")
                background_mask[intensity_image > max_threshold] = False

            if n_objects >= 2147483646:
                raise ValueError(f"Number of objects ({n_objects}) exceeds 32-bit datatype.")
            # All background pixels share one label above every seed label so
            # that expand_objects_watershed can discard them afterwards.
            background_label_image = np.zeros(background_mask.shape, dtype=np.int32)
            background_label_image[background_mask] = n_objects + 1

            logger.info("detect secondary objects via watershed transform")
            secondary_label_image = expand_objects_watershed(
                primary_label_image, background_label_image, intensity_image
            )
        else:
            logger.info("skipping secondary segmentation")
            secondary_label_image = np.zeros(primary_label_image.shape, dtype=np.int32)

    # Count non-background labels (also correct when no background is present).
    n_secondary = int(np.count_nonzero(np.unique(secondary_label_image)))
    logger.info(f"identified {n_secondary} objects")

    return secondary_label_image


def _select_measure_indices(
    label_image: AICSImage,
    measure_object: Optional[Union[int, str, List[Union[int, str]]]],
    parent_object: Union[int, str],
    action: str,
) -> tuple:
    """Return ``(parent_object_index, measure_indices)`` for a parent/child operation.

    When `measure_object` is None every channel except the parent is
    selected; `action` prefixes the info message logged in that case.
    The parent channel is never part of the returned child indices.
    """
    parent_object_name = get_channel_names(label_image, parent_object)[0]
    parent_object_index = label_image.channel_names.index(parent_object_name)

    if measure_object is None:
        # Process all channels except parent_object_index
        measure_indices = [i for i in range(label_image.dims.C) if i != parent_object_index]
        measure_names = [label_image.channel_names[i] for i in measure_indices]
        logger.info(
            f"{action} for all objects except parent object {parent_object_name} "
            f"(index {parent_object_index}): {measure_names}"
        )
    else:
        # Convert measure_object to list of names and indices
        measure_names = get_channel_names(label_image, measure_object)
        measure_indices = [label_image.channel_names.index(name) for name in measure_names]
        if parent_object_index in measure_indices:
            logger.warning(f"Parent object '{parent_object_name}' is also in measure_objects. Skipping it.")
            measure_indices = [i for i in measure_indices if i != parent_object_index]

    if not measure_indices:
        logger.warning("No valid measure objects to process after removing parent object.")

    return parent_object_index, measure_indices


def _load_label_arrays(
    label_image: AICSImage,
    measure_object_index: int,
    parent_object_index: int,
    timepoint: int,
    is_2d: bool,
) -> tuple:
    """Return ``(child_labels_copy, parent_labels)`` for one timepoint.

    Arrays are "YX" when `is_2d` (reading Z=0) and "ZYX" otherwise. The child
    array is a copy so it can be edited without touching `label_image`.
    """
    if is_2d:
        label_array = label_image.get_image_data("YX", C=measure_object_index, T=timepoint, Z=0).copy()
        parent_label_array = label_image.get_image_data("YX", C=parent_object_index, T=timepoint, Z=0)
    else:
        label_array = label_image.get_image_data("ZYX", C=measure_object_index, T=timepoint).copy()
        parent_label_array = label_image.get_image_data("ZYX", C=parent_object_index, T=timepoint)
    return label_array, parent_label_array


def _store_label_array(
    label_image: AICSImage,
    measure_object_index: int,
    timepoint: int,
    label_array: np.ndarray,
    is_2d: bool,
    in_place: bool,
    new_label_stack: Optional[np.ndarray],
) -> None:
    """Write `label_array` back into the TCZYX target (in place or into the copy)."""
    # NOTE(review): in-place mode assumes `label_image.data` returns the same
    # cached array on every access - confirm for lazily loaded readers.
    target = label_image.data if in_place else new_label_stack
    if is_2d:
        # For 2D, update the specific slice
        target[timepoint, measure_object_index, 0, :, :] = label_array
    else:
        # For 3D, update the entire volume
        target[timepoint, measure_object_index, :, :, :] = label_array


def _children_spanning_multiple_parents(
    label_array: np.ndarray, parent_label_array: np.ndarray
) -> np.ndarray:
    """Return the sorted child labels that overlap more than one non-zero parent label."""
    overlap = (label_array > 0) & (parent_label_array > 0)
    if not np.any(overlap):
        return np.empty(0, dtype=label_array.dtype)
    # Distinct (child, parent) pairs; a child listed with >1 parent is a conflict.
    pairs = np.stack((label_array[overlap], parent_label_array[overlap]), axis=1)
    distinct_pairs = np.unique(pairs, axis=0)
    children, n_parents = np.unique(distinct_pairs[:, 0], return_counts=True)
    return children[n_parents > 1]


def resolve_multi_parent_objects(
    label_image: AICSImage,
    measure_object: Optional[Union[int, str, List[Union[int, str]]]] = None,
    parent_object: Union[int, str] = 0,
    timepoint: int = 0,
    in_place: bool = True,
) -> Optional[AICSImage]:
    """
    Resolve child objects that span multiple parent objects by removing pixels
    to ensure each child object is fully contained within a single parent.

    When a child object overlaps with multiple parent objects, it is assigned
    to the parent with which it has the largest overlap (ties go to the
    smallest parent label). All of its pixels outside that parent, including
    pixels on parent background, are set to 0.

    Parameters
    ----------
    label_image
        The labeled image containing objects in separate channels.
    measure_object
        The child object(s) to be resolved. Can be channel index, channel name,
        or list of indices/names. If None (default), resolve conflicts for all
        channels except parent_object.
    parent_object
        The parent object channel, can be index or channel name, by default 0.
    timepoint
        Timepoint at which to resolve objects, by default 0.
    in_place
        If True, modify the input label_image in place. If False, return a new
        AICSImage with resolved objects, by default True.

    Returns
    -------
    AICSImage | None
        If in_place=False, returns a new AICSImage with resolved child objects
        (an unmodified copy when there is nothing to resolve).
        If in_place=True, returns None and modifies the input label_image.
    """
    parent_object_index, measure_indices = _select_measure_indices(
        label_image, measure_object, parent_object, "Resolving multi-parent conflicts"
    )

    # If not in_place, work on a copy of the data
    new_label_stack = None if in_place else label_image.data.copy()

    for current_measure_index in measure_indices:
        _resolve_single_measure_object(
            label_image,
            current_measure_index,
            parent_object_index,
            timepoint,
            in_place,
            new_label_stack,
        )

    if in_place:
        return None
    return AICSImage(
        new_label_stack,
        channel_names=label_image.channel_names,
        physical_pixel_sizes=label_image.physical_pixel_sizes,
    )


def _resolve_single_measure_object(
    label_image: AICSImage,
    measure_object_index: int,
    parent_object_index: int,
    timepoint: int,
    in_place: bool,
    new_label_stack: Optional[np.ndarray] = None,
) -> None:
    """
    Helper function to resolve multi-parent conflicts for a single measure object channel.

    Parameters
    ----------
    label_image
        The labeled image containing objects in separate channels.
    measure_object_index
        Index of the channel containing child objects to be resolved.
    parent_object_index
        Index of the channel containing parent objects.
    timepoint
        Timepoint at which to resolve objects.
    in_place
        If True, modify the input label_image in place.
    new_label_stack
        If not in_place, the copied data array to modify.
    """
    channel_name = label_image.channel_names[measure_object_index]
    is_2d = label_image.dims.Z == 1
    if is_2d:
        logger.debug(f"Processing channel {measure_object_index} ({channel_name}) in 2D.")
    else:
        logger.debug(
            f"Processing channel {measure_object_index} ({channel_name}) in 3D ({label_image.dims.Z} Z-planes)."
        )
    label_array, parent_label_array = _load_label_arrays(
        label_image, measure_object_index, parent_object_index, timepoint, is_2d
    )

    # Only children touching more than one parent need work; single-pixel
    # children can never be among them.
    conflicts_resolved = 0
    for child_label in _children_spanning_multiple_parents(label_array, parent_label_array):
        child_mask = label_array == child_label
        overlapping_parents, overlap_counts = np.unique(parent_label_array[child_mask], return_counts=True)
        is_parent = overlapping_parents > 0  # Remove background
        overlapping_parents = overlapping_parents[is_parent]
        overlap_counts = overlap_counts[is_parent]

        conflicts_resolved += 1
        logger.debug(
            f"Resolving child object {child_label} spanning {len(overlapping_parents)} parents: {overlapping_parents}"
        )

        # Parent with the largest overlap; argmax keeps the first (smallest) label on ties.
        best_parent_idx = np.argmax(overlap_counts)
        best_parent = overlapping_parents[best_parent_idx]
        logger.debug(f"Assigning to parent {best_parent} (overlap: {overlap_counts[best_parent_idx]} pixels)")

        # Remove child pixels that don't belong to the best parent
        label_array[child_mask & (parent_label_array != best_parent)] = 0

    logger.info(f"Resolved {conflicts_resolved} multi-parent conflicts for {channel_name} objects")

    _store_label_array(
        label_image, measure_object_index, timepoint, label_array, is_2d, in_place, new_label_stack
    )


def mask_child_objects_by_parent(
    label_image: AICSImage,
    measure_object: Optional[Union[int, str, List[Union[int, str]]]] = None,
    parent_object: Union[int, str] = 0,
    timepoint: int = 0,
    in_place: bool = True,
) -> Optional[AICSImage]:
    """
    Mask child objects by parent objects, removing any pixels that extend beyond parent boundaries.

    This function masks (sets to zero) any parts of child objects that extend outside
    their parent objects, ensuring all child objects are fully contained within parent
    boundaries. This is useful for enforcing parent-child relationships in multi-channel
    segmentation data.

    Parameters
    ----------
    label_image
        The labeled image containing objects in separate channels.
    measure_object
        The child object(s) to be masked. Can be channel index, channel name,
        or list of indices/names. If None (default), mask all channels except parent_object.
    parent_object
        The parent object channel used as a mask, can be index or channel name, by default 0.
    timepoint
        Timepoint at which to mask objects, by default 0.
    in_place
        If True, modify the input label_image in place. If False, return a new
        AICSImage with masked objects, by default True.

    Returns
    -------
    AICSImage | None
        If in_place=False, returns a new AICSImage with masked child objects
        (an unmodified copy when there is nothing to mask).
        If in_place=True, returns None and modifies the input label_image.

    Examples
    --------
    >>> # Mask all objects to be within cell boundaries
    >>> masked_labels = mask_child_objects_by_parent(
    ...     label_image,
    ...     parent_object='Cell',
    ...     in_place=False
    ... )

    >>> # Mask specific organelles to be within nuclei
    >>> mask_child_objects_by_parent(
    ...     label_image,
    ...     measure_object=['Organelle1', 'Organelle2'],
    ...     parent_object='Nucleus'
    ... )
    """
    parent_object_index, measure_indices = _select_measure_indices(
        label_image, measure_object, parent_object, "Masking child objects by parent"
    )

    # If not in_place, work on a copy of the data
    new_label_stack = None if in_place else label_image.data.copy()

    for current_measure_index in measure_indices:
        _mask_single_measure_object_by_parent(
            label_image,
            current_measure_index,
            parent_object_index,
            timepoint,
            in_place,
            new_label_stack,
        )

    if in_place:
        return None
    return AICSImage(
        new_label_stack,
        channel_names=label_image.channel_names,
        physical_pixel_sizes=label_image.physical_pixel_sizes,
    )


def _mask_single_measure_object_by_parent(
    label_image: AICSImage,
    measure_object_index: int,
    parent_object_index: int,
    timepoint: int,
    in_place: bool,
    new_label_stack: Optional[np.ndarray] = None,
) -> None:
    """
    Helper function to mask a single child object channel by parent objects.

    Removes pixels from child objects that extend outside their parent boundaries
    (i.e. child pixels where the parent channel is 0).

    Parameters
    ----------
    label_image
        The labeled image containing objects in separate channels.
    measure_object_index
        Index of the channel containing child objects to be masked.
    parent_object_index
        Index of the channel containing parent objects used as masks.
    timepoint
        Timepoint at which to mask objects.
    in_place
        If True, modify the input label_image in place.
    new_label_stack
        If not in_place, the copied data array to modify.
    """
    channel_name = label_image.channel_names[measure_object_index]
    is_2d = label_image.dims.Z == 1
    if is_2d:
        logger.debug(f"Masking object channel {measure_object_index} ({channel_name}) by parent in 2D.")
    else:
        logger.debug(
            f"Masking object channel {measure_object_index} ({channel_name}) by parent in 3D ({label_image.dims.Z} Z-planes)."
        )
    label_array, parent_label_array = _load_label_arrays(
        label_image, measure_object_index, parent_object_index, timepoint, is_2d
    )

    # Count objects before masking
    initial_objects = len(np.unique(label_array[label_array > 0]))
    initial_pixels = np.sum(label_array > 0)

    # Set all child object pixels outside parent objects to zero (mask them)
    label_array[(label_array > 0) & (parent_label_array == 0)] = 0

    # Count objects and pixels after masking
    final_objects = len(np.unique(label_array[label_array > 0]))
    final_pixels = np.sum(label_array > 0)
    removed_pixels = initial_pixels - final_pixels

    logger.info(
        f"Masked {channel_name} objects by parent: "
        f"{initial_objects} -> {final_objects} objects, removed {removed_pixels} pixels outside parent"
    )

    _store_label_array(
        label_image, measure_object_index, timepoint, label_array, is_2d, in_place, new_label_stack
    )