# Source code for marEx.track

"""
MarEx-Track: Marine Extreme Event Identification, Tracking, and Splitting/Merging Module

MarEx identifies and tracks extreme events in oceanographic data across time,
supporting both structured (regular grid) and unstructured datasets. It can identify
discrete objects at single time points and track them as evolving events through time,
seamlessly handling splitting and merging.

This package provides algorithms to:

* Identify binary objects in spatial data at each time step
* Track these objects across time to form coherent events
* Handle merging and splitting of objects over time
* Calculate and maintain object/event properties through time
* Filter by size criteria to focus on significant events

Key terminology:

* Object: A connected region in binary data at a single time point
* Event: One or more objects tracked through time and identified as the same entity
"""

import gc
import logging
import os
import shutil
import time
import warnings
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union

import dask.array as dsa
import numpy as np
import xarray as xr
from dask import persist
from dask.base import is_dask_collection
from dask.distributed import wait
from dask_image.ndmeasure import label
from dask_image.ndmorph import binary_closing as binary_closing_dask
from dask_image.ndmorph import binary_opening as binary_opening_dask
from numba import jit, njit, prange
from numpy.typing import NDArray
from scipy.ndimage import binary_closing, binary_opening
from scipy.sparse import coo_matrix, csr_matrix, eye
from scipy.sparse.csgraph import connected_components
from skimage.measure import regionprops_table

from ._dependencies import warn_missing_dependency
from .exceptions import ConfigurationError, TrackingError, create_coordinate_error, create_data_validation_error
from .logging_config import configure_logging, get_logger, log_dask_info, log_memory_usage, log_timing

# Module-level logger, obtained through the package's logging helper
logger = get_logger(__name__)

# JAX is an optional dependency: when it cannot be imported, ``jnp`` is bound
# to NumPy instead and a missing-dependency warning is issued.
try:
    import jax.numpy as jnp
except ImportError:
    jnp = np  # type: ignore[misc]  # Alias for jnp when JAX not available
    warn_missing_dependency("jax", "Some functionality")


# ============================
# Main Tracker Class
# ============================


[docs] class tracker: """ Tracker identifies and tracks arbitrary binary objects in spatial data through time. The tracker supports both structured (regular grid) and unstructured data, and seamlessly handles splitting & merging of objects. It identifies connected regions in binary data at each time step, and tracks these as evolving events through time. Main workflow: 1. Preprocessing: Fill spatiotemporal holes, filter small objects 2. Object identification: Label connected components at each time 3. Tracking: Determine object correspondences across time 4. Optional splitting & merging: Handle complex event evolution Parameters ---------- data_bin : xarray.DataArray Binary field of extreme points to group, label, and track (True = object, False = background) Must represent and underlying `dask` array. mask : xarray.DataArray Binary mask indicating valid regions (True = valid, False = invalid) R_fill : int The radius of the kernel used in morphological opening & closing, relating to the largest hole/gap that can be filled. In units of grid cells. area_filter_quartile : float, optional The fraction of the smallest objects to discard, i.e. the quantile defining the smallest area object retained. Quantile must be in (0-1) (e.g., 0.25 removes smallest 25%). Mutually exclusive with area_filter_absolute. Default is 0.5 if neither parameter is provided. area_filter_absolute : int, optional The minimum area (in grid cells) for an object to be retained. Mutually exclusive with area_filter_quartile. Use this for fixed minimum area thresholds (e.g., 10 cells minimum). temp_dir : str, optional Path to temporary directory for storing intermediate results T_fill : int, default=2 The permissible temporal gap (in days) between objects for tracking continuity to be maintained (must be even) allow_merging : bool, default=True Allow objects to split and merge across time. 
Apply splitting & merging criteria, track merge events, and maintain original identities of merged objects across time. N.B.: `False` reverts to classical `ndmeasure.label` with simplar time connectivity, i.e. Scannell et al. nn_partitioning : bool, default=False Implement a better partitioning of merged child objects based on closest parent cell. `False` reverts to using parent centroids to determine partitioning between new child objects, i.e. Di Sun & Bohai Zhang 2023. N.B.: Centroid-based partitioning has major problems with small merging objects suddenly obtaining unrealistically-large (and often disjoint) fractions of the larger object. overlap_threshold : float, default=0.5 The fraction of the smaller object's area that must overlap with the larger object's area to be considered the same event and continue tracking with the same ID. unstructured_grid : bool, default=False Whether data is on an unstructured grid dimensions : dict, default={"time": "time", "x": "lon", "y": "lat"} Mapping of dimensions to names in the data coordinates : dict, optional Coordinate names for unstructured grids. Should contain 'x' and 'y' keys for x and y coordinates. May also contain 'time' if the time coordinate name is different from the dimension name. neighbours : xarray.DataArray, optional For unstructured grid, indicates connectivity between cells cell_areas : xarray.DataArray, optional For unstructured grid, area of each cell (required). For structured grid, area of each cell (optional). If not provided, defaults to 1.0 for each cell (resulting in cell counts as areas). Note: Overridden by grid_resolution if provided for structured grids. grid_resolution : float, optional Grid resolution in degrees for structured grids only (ignored for unstructured grids). When provided, automatically calculates cell areas using spherical geometry. Overrides any provided cell_areas parameter. 
max_iteration : int, default=40 Maximum number of iterations for merging/splitting algorithm checkpoint : str, default='None' Checkpoint strategy ('save', 'load', or None) debug : int, default=0 Debug level (0-2) verbose : bool, optional Enable verbose logging with detailed progress information. If None, uses current global logging configuration. quiet : bool, optional Enable quiet logging with minimal output (warnings and errors only). If None, uses current global logging configuration. Note: quiet takes precedence over verbose if both are True. regional_mode : bool, default=False Enable regional mode for non-global coordinate ranges. When True, coordinate_units must be specified. coordinate_units : str, optional Coordinate units when regional_mode=True. Must be either 'degrees' or 'radians'. Examples -------- Basic tracking of marine heatwave events from preprocessed data: >>> import xarray as xr >>> import marEx >>> >>> # Load preprocessed extreme events data >>> processed = xr.open_dataset('extreme_events.nc', chunks={}) >>> extreme_events = processed.extreme_events # Boolean array >>> mask = processed.mask # Ocean/land mask >>> >>> # Initialise tracker with basic parameters >>> tracker = marEx.tracker( ... extreme_events, ... mask, ... R_fill=8, # Fill holes up to 8 grid cells ... area_filter_quartile=0.5 # Remove smallest 50% of objects ... allow_merging=False # Basic tracking without splitting/merging ... ) >>> >>> # Run tracking algorithm >>> events = tracker.run() >>> print(f"Identified {events.ID.max().compute()} distinct events") Identified 1247 distinct events Using automatic grid area calculation from resolution: >>> # For regular lat/lon grids, automatically calculate physical areas >>> grid_tracker = marEx.tracker( ... extreme_events, ... mask, ... R_fill=8, ... area_filter_quartile=0.5, ... grid_resolution=0.25 # Grid resolution in degrees ... 
) >>> # Cell areas are calculated automatically using spherical geometry >>> grid_events = grid_tracker.run() Advanced tracking with merging and splitting enabled: >>> # More sophisticated tracking with temporal gap filling >>> advanced_tracker = marEx.tracker( ... extreme_events, ... mask, ... R_fill=12, # Larger spatial gap filling ... T_fill=4, # Fill up to 4-day temporal gaps ... area_filter_quartile=0.25, # More aggressive size filtering ... allow_merging=True, # Enable split/merge detection ... overlap_threshold=0.3 # Lower threshold for object linking ... ) >>> >>> events_advanced, merges_log = advanced_tracker.run(return_merges=True) >>> print(events_advanced.data_vars) Data variables: event (time, lat, lon) int32 dask.array<chunksize=(25, 180, 360)> event_centroid (time, lat, lon) int32 dask.array<chunksize=(25, 180, 360)> ID_field (time, lat, lon) int32 dask.array<chunksize=(25, 180, 360)> global_ID (time, ID) int32 dask.array<chunksize=(25, 1247)> area (time, ID) float32 dask.array<chunksize=(25, 1247)> centroid (component, time, ID) float64 dask.array<chunksize=(2, 25, 1247)> presence (time, ID) bool dask.array<chunksize=(25, 1247)> time_start (ID) datetime64[ns] dask.array<chunksize=(1247,)> time_end (ID) datetime64[ns] dask.array<chunksize=(1247,)> merge_ledger (time, ID, sibling_ID) int32 dask.array<chunksize=(25, 1247, 10)> Processing unstructured ocean model data (ICON): >>> # Load ICON ocean model data with connectivity >>> icon_data = xr.open_dataset('icon_extremes.nc', chunks={}) >>> icon_extremes = icon_data.extreme_events # (time, ncells) >>> icon_mask = icon_data.mask >>> neighbours = icon_data.neighbours # Cell connectivity >>> cell_areas = icon_data.cell_areas # Physical areas >>> >>> # Track events on unstructured grid >>> unstructured_tracker = marEx.tracker( ... icon_extremes, ... icon_mask, ... R_fill=5, # 5-neighbor radius for gap filling ... area_filter_quartile=0.6, # Remove 60% of smallest events ... 
unstructured_grid=True, # Enable unstructured mode ... dimensions={"x": "ncells"}, # Must specify the name of the spatial dimension ... coordinates={"x": "lon", "y": "lat"}, # Spatial coordinate names ... neighbours=neighbours, # Required for unstructured ... cell_areas=cell_areas # Required for area calculations ... ) >>> unstructured_events = unstructured_tracker.run() Memory management and checkpointing for large datasets: >>> # Use checkpointing for very large datasets >>> large_tracker = marEx.tracker( ... extreme_events, ... mask, ... R_fill=8, ... area_filter_quartile=0.5, ... temp_dir='/scratch/user/tracking_temp', # Temporary storage ... checkpoint='save' # Save intermediate results ... ) >>> # Processing can be resumed if interrupted >>> large_events = large_tracker.run() Comparing different filtering strategies: >>> # Conservative filtering - keep more events >>> conservative = marEx.tracker( ... extreme_events, mask, R_fill=5, area_filter_quartile=0.1 ... ) >>> conservative_events = conservative.run() >>> >>> # Aggressive filtering - focus on largest events >>> aggressive = marEx.tracker( ... extreme_events, mask, R_fill=15, area_filter_quartile=0.8 ... ) >>> aggressive_events = aggressive.run() >>> >>> print(f"Conservative: {conservative_events.ID.max().compute()} events") >>> print(f"Aggressive: {aggressive_events.ID.max().compute()} events") Using absolute area filtering instead of percentile-based: >>> # Filter objects smaller than 25 grid cells >>> absolute_tracker = marEx.tracker( ... extreme_events, mask, R_fill=8, area_filter_absolute=25 ... 
) >>> absolute_events = absolute_tracker.run() >>> >>> # Default behavior (area_filter_quartile=0.5) when no parameters provided >>> default_tracker = marEx.tracker(extreme_events, mask, R_fill=8) >>> default_events = default_tracker.run() # Uses quartile=0.5 filtering Using physical cell areas for structured grids: >>> # Load data with irregular grid cell areas >>> grid_areas = xr.open_dataset('grid_areas.nc').cell_area # (lat, lon) in m² >>> >>> # Track events using physical areas instead of cell counts >>> physical_tracker = marEx.tracker( ... extreme_events, ... mask, ... R_fill=8, ... area_filter_quartile=0.5, ... cell_areas=grid_areas # Physical areas in m² ... ) >>> events = physical_tracker.run() >>> # Now events.area contains physical areas in m² instead of cell counts Integration with full marEx workflow: >>> # Complete workflow from raw data to tracked events >>> raw_sst = xr.open_dataset('sst_data.nc', chunks={}).sst.chunk({'time': 30}) >>> >>> # Step 1: Preprocess to identify extremes >>> processed = marEx.preprocess_data(raw_sst, threshold_percentile=95) >>> >>> # Step 2: Track extreme events >>> tracker = marEx.tracker( ... processed.extreme_events, ... processed.mask, ... R_fill=8, ... area_filter_quartile=0.5 ... ) >>> tracked_events = tracker.run() """
[docs] def __init__( self, data_bin: xr.DataArray, mask: xr.DataArray, R_fill: Union[int, float], area_filter_quartile: Optional[float] = None, area_filter_absolute: Optional[int] = None, temp_dir: Optional[str] = None, T_fill: int = 2, allow_merging: bool = True, nn_partitioning: bool = False, overlap_threshold: float = 0.5, unstructured_grid: bool = False, dimensions: Optional[Dict[str, str]] = None, coordinates: Optional[Dict[str, str]] = None, neighbours: Optional[xr.DataArray] = None, cell_areas: Optional[xr.DataArray] = None, grid_resolution: Optional[float] = None, max_iteration: int = 40, checkpoint: Optional[Literal["save", "load", "None"]] = None, debug: int = 0, verbose: Optional[bool] = None, quiet: Optional[bool] = None, regional_mode: bool = False, coordinate_units: Optional[Literal["degrees", "radians"]] = None, ) -> None: """Initialise the tracker with parameters and data.""" # Configure logging if verbose/quiet parameters are provided if verbose is not None or quiet is not None: configure_logging(verbose=verbose, quiet=quiet) # Store logging preferences self.verbose = verbose self.quiet = quiet # Log tracker initialisation logger.info("Initialising MarEx tracker") logger.info(f"Grid type: {'unstructured' if unstructured_grid else 'structured'}") logger.info( f"Parameters: R_fill={R_fill}, T_fill={T_fill}, " f"area_filter_quartile={area_filter_quartile}, area_filter_absolute={area_filter_absolute}" ) logger.debug( f"Tracking options: allow_merging={allow_merging}, nn_partitioning={nn_partitioning}, " f"overlap_threshold={overlap_threshold}" ) # Log input data info log_dask_info(logger, data_bin, "Binary input data") log_memory_usage(logger, "Tracker initialisation") self.data_bin = data_bin # Store coordinate parameters self.regional_mode = regional_mode self.coordinate_units = coordinate_units # Unify coordinate system: degrees dimensions = dimensions or {} self.timedim = dimensions.get("time", "time") self.xdim = dimensions.get("x", "lon") 
self.ydim: Optional[str] = dimensions.get("y", "lat") if unstructured_grid: self.timecoord = coordinates["time"] if coordinates and "time" in coordinates else self.timedim self.xcoord = coordinates["x"] if coordinates and "x" in coordinates else "lon" self.ycoord = coordinates["y"] if coordinates and "y" in coordinates else "lat" else: coordinates = coordinates or {} self.timecoord = coordinates.get("time", self.timedim) self.xcoord = coordinates.get("x", self.xdim) self.ycoord = coordinates.get("y", self.ydim) self.lat_init = data_bin[self.ycoord].persist() # Save in original units self.lon_init = data_bin[self.xcoord].persist() self._unify_coordinates() self.mask = mask self.R_fill = int(R_fill) self.T_fill = T_fill # Resolve area filtering parameters self._resolve_area_filtering_parameters(area_filter_quartile, area_filter_absolute) self.allow_merging = allow_merging self.nn_partitioning = nn_partitioning self.overlap_threshold = overlap_threshold self.lat = data_bin[self.ycoord].persist() self.lon = data_bin[self.xcoord].persist() if data_bin.chunks is not None: self.timechunks = data_bin.chunks[data_bin.dims.index(self.timedim)][0] else: raise create_data_validation_error( "Data must be chunked", details="The input data_bin must have chunk information", suggestions=["Use data_bin.chunk({'time': 10}) to chunk the data"], ) self.unstructured_grid = unstructured_grid self.checkpoint = checkpoint self.debug = debug logger.debug(f"Dimensions: time={self.timedim}, x={self.xdim}, y={self.ydim}") logger.debug(f"Coordinates: time={self.timecoord}, x={self.xcoord}, y={self.ycoord}") # Extract data_bin metadata to inherit if hasattr(self.data_bin, "attrs") and self.data_bin.attrs: self.data_attrs = self.data_bin.attrs.copy() else: self.data_attrs = {} # Input validation and preparation self._validate_inputs(neighbours, cell_areas, grid_resolution, temp_dir) # Handle cell_areas for both structured and unstructured grids if self.unstructured_grid: # Validation already done 
in _validate_inputs pass else: # Handle structured grids if grid_resolution is not None: # Calculate cell areas from grid resolution using spherical geometry logger.info(f"Calculating cell areas from grid resolution: {grid_resolution} degrees") # Earth radius in km R_earth = 6378.0 # Get coordinate arrays (should be in degrees) lat_coords = data_bin[self.ycoord] # Convert to radians lat_r = np.radians(lat_coords) dlat = np.radians(grid_resolution) dlon = np.radians(grid_resolution) # Calculate grid areas using spherical geometry # Area = R² * |sin(lat + dlat/2) - sin(lat - dlat/2)| * dlon grid_area = (R_earth**2 * np.abs(np.sin(lat_r + dlat / 2) - np.sin(lat_r - dlat / 2)) * dlon).astype(np.float32) # Check if cell_areas was originally provided (and warn about override) if cell_areas is not None: logger.warning("grid_resolution parameter overrides provided cell_areas for structured grid") cell_areas = grid_area elif cell_areas is None: # Create unit cell areas (resulting in cell counts) if self.ydim is None: raise ValueError("ydim should not be None for structured grids") cell_areas = xr.ones_like(data_bin.isel({self.timedim: 0}), dtype=np.float32) logger.info("No cell_areas provided for structured grid - using unit areas (cell counts)") else: # Validation already done in _validate_inputs logger.info("Using provided cell_areas for structured grid") # Store cell_areas for both grid types self.cell_area = cell_areas.astype(np.float32).persist() if self.unstructured_grid: # Remove coordinate variables for unstructured self.cell_area = self.cell_area.drop_vars({self.ycoord, self.xcoord}.intersection(set(cell_areas.coords))) self.mean_cell_area = float(cell_areas.mean().compute().item()) else: # For structured grids, calculate mean cell area self.mean_cell_area = float(cell_areas.mean().compute().item()) # Special setup for unstructured grids if unstructured_grid: # Validation already done in _validate_inputs self._setup_unstructured_grid(temp_dir, neighbours, 
cell_areas, max_iteration) self._configure_warnings()
def _validate_inputs( self, neighbours: Optional[xr.DataArray] = None, cell_areas: Optional[xr.DataArray] = None, grid_resolution: Optional[float] = None, temp_dir: Optional[str] = None, ) -> None: """Validate input parameters and data.""" if self.regional_mode and self.unstructured_grid: raise NotImplementedError("regional_mode is not yet implemented for unstructured grids") # For unstructured grids, adjust dimensions if self.unstructured_grid: self.ydim = None if (self.timedim, self.xdim) != self.data_bin.dims: try: self.data_bin = self.data_bin.transpose(self.timedim, self.xdim) except Exception: raise create_data_validation_error( "Invalid dimensions for unstructured data", details=f"Expected 2D array with dimensions ({self.timedim}, {self.xdim}), got {list(self.data_bin.dims)}", suggestions=[ "Ensure data has time and cell dimensions only", "Check dimension mapping in function call", ], data_info={ "actual_dims": list(self.data_bin.dims), "expected_dims": [self.timedim, self.xdim], }, ) else: # For structured grids, ensure 3D data if (self.timedim, self.ydim, self.xdim) != self.data_bin.dims: try: self.data_bin = self.data_bin.transpose(self.timedim, self.ydim, self.xdim) except Exception: raise create_data_validation_error( "Invalid dimensions for gridded data", details=( f"Expected 3D array with dimensions ({self.timedim}, {self.ydim}, {self.xdim}), " f"got {list(self.data_bin.dims)}" ), suggestions=[ "Ensure data has time, latitude, and longitude dimensions", "Check dimension mapping and coordinate names", ], data_info={ "actual_dims": list(self.data_bin.dims), "expected_dims": [self.timedim, self.ydim, self.xdim], }, ) # Check if self.timecoord, self.xcoord, and self.ycoord are in data_bin coords: if ( self.timecoord not in self.data_bin.coords or self.xcoord not in self.data_bin.coords or self.ycoord not in self.data_bin.coords ): raise create_data_validation_error( "Missing required coordinates in unstructured data", details=( f"Expected coordinates 
({self.timecoord}, {self.xcoord}, {self.ycoord}), " f"but found {list(self.data_bin.coords)}" ), suggestions=[ "Ensure data_bin contains time, x, and y coordinates", "Check coordinate names in the dataset", "Specify coordinates in the tracker initialisation with `coordinates` parameter.", ], data_info={ "actual_coords": list(self.data_bin.coords), "expected_coords": [self.timecoord, self.xcoord, self.ycoord], }, ) # Check if timecoord is an index of timedim if self.timecoord != self.timedim and ( self.timedim not in self.data_bin.indexes or self.data_bin.indexes[self.timedim].name != self.timecoord ): logger.warning( f"timecoord '{self.timecoord}' is not an index of timedim '{self.timedim}'. " f"Setting '{self.timecoord}' as index for dimension '{self.timedim}'" ) self.data_bin = self.data_bin.set_index({self.timedim: self.timecoord}) # Check data type and structure if self.data_bin.data.dtype != bool: raise create_data_validation_error( "Input DataArray must be binary (boolean type)", details=f"Found dtype {self.data_bin.data.dtype}, expected bool", suggestions=[ "Convert data using da > threshold for binary events", "Use xr.where(condition, True, False) for boolean conversion", ], data_info={ "actual_dtype": str(self.data_bin.data.dtype), "expected_dtype": "bool", }, ) # Validate required parameters for unstructured grids if self.unstructured_grid: if temp_dir is None: raise create_data_validation_error( "temp_dir is required for unstructured grids", details="Unstructured grid processing requires a temporary directory", suggestions=["Provide a temp_dir parameter when using unstructured_grid=True"], ) if neighbours is None: raise create_data_validation_error( "neighbours array is required for unstructured grids", details="Unstructured grid processing requires cell connectivity information", suggestions=["Provide a neighbours parameter when using unstructured_grid=True"], ) if cell_areas is None: raise create_data_validation_error( "cell_areas array is required for 
unstructured grids", details="Unstructured grid processing requires cell area information", suggestions=["Provide a cell_areas parameter when using unstructured_grid=True"], ) else: # For structured grids, cell_areas is optional if cell_areas is not None: # Validate dimensions if provided expected_spatial_dims = {self.ydim, self.xdim} if set(cell_areas.dims) != expected_spatial_dims: raise create_data_validation_error( "Invalid cell_areas dimensions for structured grid", details=f"Expected spatial dimensions {expected_spatial_dims}, got {set(cell_areas.dims)}", suggestions=["Ensure cell_areas matches the spatial dimensions of your data"], ) # Validate grid_resolution parameter if grid_resolution is not None: if self.unstructured_grid: raise create_data_validation_error( "grid_resolution parameter is not supported for unstructured grids", details="Grid resolution calculation requires structured (lat/lon) coordinates", suggestions=["Use cell_areas parameter directly for unstructured grids"], ) if not isinstance(grid_resolution, (int, float)) or grid_resolution <= 0: raise create_data_validation_error( "grid_resolution must be a positive number", details=f"Received grid_resolution={grid_resolution}", suggestions=["Provide a positive float value representing grid resolution in degrees"], ) if not is_dask_collection(self.data_bin.data): raise create_data_validation_error( "Input DataArray must be Dask-backed", details="Tracking requires chunked data for efficient processing", suggestions=[ "Convert to Dask: data_bin = data_bin.chunk({'time': 10})", "Load with chunking: xr.open_dataset('file.nc', chunks={})", ], data_info={"data_type": type(self.data_bin.data).__name__}, ) if self.mask.data.dtype != bool: raise create_data_validation_error( "Mask must be binary (boolean type)", details=f"Found mask dtype {self.mask.data.dtype}, expected bool", suggestions=["Convert mask using mask > 0 or mask.astype(bool)"], data_info={"mask_dtype": str(self.mask.data.dtype)}, ) if not 
self.mask.any().compute().item(): raise create_data_validation_error( "Mask contains only False values", details="Mask should indicate valid regions with True values", suggestions=[ "Check mask orientation - it should mark valid (ocean) regions as True", "Invert mask if needed: mask = ~mask", "Create ocean mask from land mask", ], ) # Check chunking for spatial dimensions self._validate_spatial_chunking() # Validate resolved area filtering parameters if not self._use_absolute_filtering: # Quartile-based filtering validation if (self.area_filter_quartile < 0) or (self.area_filter_quartile > 1): raise ConfigurationError( "Invalid area_filter_quartile value", details=f"Value {self.area_filter_quartile} is outside valid range [0, 1]", suggestions=[ "Use values between 0.0 and 1.0", "Use 0.25 to filter smallest 25% of events", "Use 0.5 to keep only larger events", ], context={ "provided_value": self.area_filter_quartile, "valid_range": [0, 1], }, ) else: # Absolute filtering validation if self.area_filter_absolute <= 0: raise ConfigurationError( "Invalid area_filter_absolute value", details=f"area_filter_absolute={self.area_filter_absolute} must be positive", suggestions=[ "Set area_filter_absolute to a positive integer (e.g., 5, 10, 50)", ], context={ "area_filter_absolute": self.area_filter_absolute, }, ) if self.T_fill % 2 != 0: raise ConfigurationError( "T_fill must be even for temporal symmetry", details=f"Provided T_fill={self.T_fill} is odd", suggestions=["Use even values: 2, 4, 6, 8, etc."], context={"provided_value": self.T_fill, "requirement": "even number"}, ) def _resolve_area_filtering_parameters( self, area_filter_quartile: Optional[float], area_filter_absolute: Optional[int] ) -> None: """Resolve area filtering parameters and set internal state.""" # Count non-None parameters provided_params = sum(x is not None for x in [area_filter_quartile, area_filter_absolute]) if provided_params == 0: # Default case: use quartile-based filtering 
self.area_filter_quartile = 0.5 self.area_filter_absolute = 0 self._use_absolute_filtering = False elif provided_params == 1: # Single parameter provided - use it if area_filter_quartile is not None: self.area_filter_quartile = area_filter_quartile self.area_filter_absolute = 0 self._use_absolute_filtering = False else: # area_filter_absolute is not None self.area_filter_quartile = 0.0 # Set for compatibility self.area_filter_absolute = area_filter_absolute self._use_absolute_filtering = True else: # Both provided - error raise ConfigurationError( "Cannot specify both area filtering parameters", details="area_filter_quartile and area_filter_absolute are mutually exclusive", suggestions=[ "Use area_filter_quartile for percentile-based filtering (e.g., 0.25 for smallest 25%)", "Use area_filter_absolute for fixed minimum area (e.g., 10 for minimum 10 cells)", "Omit both parameters to use default quartile filtering (0.5)", ], context={ "area_filter_quartile": area_filter_quartile, "area_filter_absolute": area_filter_absolute, }, ) def _validate_spatial_chunking(self) -> None: """Validate that spatial dimensions are in single chunks for apply_ufunc operations.""" rechunk_needed = False rechunk_dims = {} # Check xdim chunking in data_bin if self.xdim in self.data_bin.chunksizes: xdim_chunks = self.data_bin.chunksizes[self.xdim] if len(xdim_chunks) > 1: warnings.warn( f"Spatial dimension '{self.xdim}' has multiple chunks ({len(xdim_chunks)} chunks). " f"This will cause issues with apply_ufunc operations. Rechunking to single chunk." f"Consider directly loading dataset with proper chunking to optimise performance.", UserWarning, stacklevel=3, ) rechunk_needed = True rechunk_dims[self.xdim] = -1 # Check ydim chunking for structured grids if self.ydim is not None and self.ydim in self.data_bin.chunksizes: ydim_chunks = self.data_bin.chunksizes[self.ydim] if len(ydim_chunks) > 1: warnings.warn( f"Spatial dimension '{self.ydim}' has multiple chunks ({len(ydim_chunks)} chunks). 
" f"This will cause issues with apply_ufunc operations. Rechunking to single chunk." f"Consider directly loading dataset with proper chunking to optimise performance.", UserWarning, stacklevel=3, ) rechunk_needed = True rechunk_dims[self.ydim] = -1 # Rechunk data_bin if needed if rechunk_needed: logger.info(f"Rechunking spatial dimensions: {rechunk_dims}") self.data_bin = self.data_bin.chunk(rechunk_dims) # Check mask spatial dimensions for single chunks mask_rechunk_needed = False mask_rechunk_dims = {} # Check xdim chunking in mask if self.mask.chunks is not None and self.xdim in self.mask.chunksizes: xdim_chunks = self.mask.chunksizes[self.xdim] if len(xdim_chunks) > 1: warnings.warn( f"Mask spatial dimension '{self.xdim}' has multiple chunks ({len(xdim_chunks)} chunks). " f"This will cause issues with apply_ufunc operations. Rechunking to single chunk.", UserWarning, stacklevel=3, ) mask_rechunk_needed = True mask_rechunk_dims[self.xdim] = -1 # Check ydim chunking in mask for structured grids if self.ydim is not None and self.mask.chunks is not None and self.ydim in self.mask.chunksizes: ydim_chunks = self.mask.chunksizes[self.ydim] if len(ydim_chunks) > 1: warnings.warn( f"Mask spatial dimension '{self.ydim}' has multiple chunks ({len(ydim_chunks)} chunks). " f"This will cause issues with apply_ufunc operations. 
Rechunking to single chunk.", UserWarning, stacklevel=3, ) mask_rechunk_needed = True mask_rechunk_dims[self.ydim] = -1 # Rechunk mask if needed if mask_rechunk_needed: logger.info(f"Rechunking mask spatial dimensions: {mask_rechunk_dims}") self.mask = self.mask.chunk(mask_rechunk_dims) # Check coordinate spatial dimensions for single chunks coord_rechunk_needed = False coord_rechunk_dims = {} # Check xdim chunking in lon coordinate if self.lon.chunks is not None and self.xdim in self.lon.chunksizes: # pragma: no cover xdim_chunks = self.lon.chunksizes[self.xdim] if len(xdim_chunks) > 1: warnings.warn( f"Longitude coordinate spatial dimension '{self.xdim}' has multiple chunks ({len(xdim_chunks)} chunks). " f"This will cause issues with apply_ufunc operations. Rechunking to single chunk.", UserWarning, stacklevel=3, ) coord_rechunk_needed = True coord_rechunk_dims[self.xdim] = -1 # Check ydim chunking in lat coordinate for structured grids if self.ydim is not None and self.lat.chunks is not None and self.ydim in self.lat.chunksizes: # pragma: no cover ydim_chunks = self.lat.chunksizes[self.ydim] if len(ydim_chunks) > 1: warnings.warn( f"Latitude coordinate spatial dimension '{self.ydim}' has multiple chunks ({len(ydim_chunks)} chunks). " f"This will cause issues with apply_ufunc operations. 
Rechunking to single chunk.", UserWarning, stacklevel=3, ) coord_rechunk_needed = True coord_rechunk_dims[self.ydim] = -1 # Rechunk coordinates if needed if coord_rechunk_needed: # pragma: no cover logger.info(f"Rechunking coordinate spatial dimensions: {coord_rechunk_dims}") self.lat = self.lat.chunk(coord_rechunk_dims).persist() self.lon = self.lon.chunk(coord_rechunk_dims).persist() def _validate_unstructured_chunking(self, neighbours: xr.DataArray, cell_areas: xr.DataArray) -> None: """Validate that neighbours and cell_areas are in single chunks for unstructured grids.""" # Check neighbours spatial dimensions for single chunks neighbours_rechunk_needed = False neighbours_rechunk_dims = {} # Check xdim chunking in neighbours if self.xdim in neighbours.chunksizes: xdim_chunks = neighbours.chunksizes[self.xdim] if len(xdim_chunks) > 1: warnings.warn( f"Neighbours spatial dimension '{self.xdim}' has multiple chunks ({len(xdim_chunks)} chunks). " f"This will cause issues with apply_ufunc operations. Rechunking to single chunk.", UserWarning, stacklevel=4, ) neighbours_rechunk_needed = True neighbours_rechunk_dims[self.xdim] = -1 # Check nv dimension chunking in neighbours if "nv" in neighbours.chunksizes: nv_chunks = neighbours.chunksizes["nv"] if len(nv_chunks) > 1: warnings.warn( f"Neighbours dimension 'nv' has multiple chunks ({len(nv_chunks)} chunks). " f"This will cause issues with apply_ufunc operations. Rechunking to single chunk.", UserWarning, stacklevel=4, ) neighbours_rechunk_needed = True neighbours_rechunk_dims["nv"] = -1 # Check cell_areas spatial dimensions for single chunks cell_areas_rechunk_needed = False cell_areas_rechunk_dims = {} # Check xdim chunking in cell_areas if self.xdim in cell_areas.chunksizes: xdim_chunks = cell_areas.chunksizes[self.xdim] if len(xdim_chunks) > 1: warnings.warn( f"Cell areas spatial dimension '{self.xdim}' has multiple chunks ({len(xdim_chunks)} chunks). " f"This will cause issues with apply_ufunc operations. 
Rechunking to single chunk.", UserWarning, stacklevel=4, ) cell_areas_rechunk_needed = True cell_areas_rechunk_dims[self.xdim] = -1 # Apply rechunking if needed if neighbours_rechunk_needed: logger.info(f"Rechunking neighbours spatial dimensions: {neighbours_rechunk_dims}") # Note: We don't store the rechunked neighbours directly since it's a parameter # The caller should handle this if needed if cell_areas_rechunk_needed: logger.info(f"Rechunking cell_areas spatial dimensions: {cell_areas_rechunk_dims}") # Note: We don't store the rechunked cell_areas directly since it's a parameter # The caller should handle this if needed def _unify_coordinates(self) -> None: if self.regional_mode: if self.coordinate_units is None: raise create_coordinate_error( "coordinate_units must be specified when regional_mode=True", suggestions=[ "Set coordinate_units='degrees' for degree-based coordinates", "Set coordinate_units='radians' for radian-based coordinates", ], ) if self.coordinate_units not in ["degrees", "radians"]: raise create_coordinate_error( f"Invalid coordinate_units '{self.coordinate_units}'", details="coordinate_units must be either 'degrees' or 'radians'", suggestions=["Use coordinate_units='degrees' or coordinate_units='radians'"], ) else: # Check if coordinate_units is explicitly specified if self.coordinate_units is not None: if self.coordinate_units not in ["degrees", "radians"]: raise create_coordinate_error( f"Invalid coordinate_units '{self.coordinate_units}'", details="coordinate_units must be either 'degrees' or 'radians'", suggestions=["Use coordinate_units='degrees' or coordinate_units='radians'"], ) # Use explicitly specified coordinate units else: # Auto-detect coordinate units for global data lon = self.data_bin[self.xcoord] lon_range = float(lon.max()) - float(lon.min()) # Check for degrees (range close to 360) if abs(lon_range - 360.0) <= 1.0: self.coordinate_units = "degrees" # Check for radians (range close to 2π) elif abs(lon_range - 2 * np.pi) <= 
0.02: self.coordinate_units = "radians" # If neither, throw error else: raise create_coordinate_error( f"Cannot auto-detect coordinate units from range {lon_range:.3f}", details=(f"Expected ranges: ~360 degrees or ~{2*np.pi:.3f} radians. " f"Found range: {lon_range:.3f}"), suggestions=[ "Use regional_mode=True with coordinate_units specified for regional data", "Specify coordinate_units='degrees' or coordinate_units='radians' explicitly", "Check that your coordinate values are correct", "Verify x-dimension coordinate ranges", ], context={"detected_range": lon_range, "xdim": self.xcoord}, ) # Convert lat & lon to degrees if self.coordinate_units == "radians": self.data_bin[self.xcoord] = self.data_bin[self.xcoord] * 180.0 / np.pi self.data_bin[self.ycoord] = self.data_bin[self.ycoord] * 180.0 / np.pi def _remap_coordinates(self, events_ds: xr.Dataset) -> xr.Dataset: """Remap coordinates to original lat/lon values after processing. Map centroids from lat=[-180,180] back into original lat/lon units & range. 
""" # Re-assign original coordinates from original marEx input events_ds = events_ds.assign_coords({self.ycoord: self.lat_init.compute(), self.xcoord: self.lon_init.compute()}) if "centroid" in events_ds.data_vars: # Remap centroids to original coordinate system # (lat, lon) currently in degrees [-90,90], [-180,180] centroids = events_ds["centroid"].persist() # Split into components centroids_lat = centroids.isel(component=0) # [-90, 90] degrees centroids_lon = centroids.isel(component=1) # [-180, 180] degrees # Get original coordinate bounds lon_min = float(self.lon_init.min().compute().item()) lon_max = float(self.lon_init.max().compute().item()) # Convert units and adjust ranges if self.coordinate_units == "radians": # Convert from degrees to radians centroids_lat = centroids_lat * np.pi / 180.0 # Now in [-π/2, π/2] centroids_lon = centroids_lon * np.pi / 180.0 # Now in [-π, π] # Check if original longitude was in [0, 2π] range if lon_min >= 0 and lon_max > np.pi: # Shift from [-π, π] to [0, 2π] centroids_lon = xr.where(centroids_lon < 0, centroids_lon + 2 * np.pi, centroids_lon) else: # Coordinates remain in degrees # Check if original longitude was in [0, 360] range if lon_min >= 0 and lon_max > 180: # Shift from [-180, 180] to [0, 360] centroids_lon = xr.where(centroids_lon < 0, centroids_lon + 360, centroids_lon) # Reassemble centroids with remapped coordinates centroids_remapped = xr.concat([centroids_lat, centroids_lon], dim="component") # Update the dataset events_ds["centroid"] = centroids_remapped return events_ds def _setup_unstructured_grid( self, temp_dir: str, neighbours: xr.DataArray, cell_areas: xr.DataArray, max_iteration: int, ) -> None: """Set up special handling for unstructured grids.""" if not temp_dir: raise ConfigurationError( "Missing temporary directory for unstructured processing", details="Unstructured grids require temporary storage for memory efficiency", suggestions=[ "Provide temp_dir parameter: tracker(..., 
temp_dir='/tmp/marex')", "Ensure directory has sufficient space and write permissions", ], ) self.scratch_dir = temp_dir # Clear any existing temporary storage if os.path.exists(f"{self.scratch_dir}/marEx_temp_field.zarr/"): shutil.rmtree(f"{self.scratch_dir}/marEx_temp_field.zarr/") # Remove coordinate variables to avoid memory issues self.data_bin = self.data_bin.drop_vars({self.ycoord, self.xcoord}) self.mask = self.mask.drop_vars({self.ycoord, self.xcoord}) self.lat = self.lat.drop_vars(self.lat.coords) self.lon = self.lon.drop_vars(self.lon.coords) neighbours = neighbours.drop_vars({self.ycoord, self.xcoord, "nv"}.intersection(set(neighbours.coords))) self.max_iteration = max_iteration # Validate spatial chunking for unstructured grid data self._validate_unstructured_chunking(neighbours, cell_areas) # Initialise dilation array for unstructured grid self.neighbours_int = neighbours.astype(np.int32) - 1 # Convert to 0-based indexing # Validate neighbour array structure if self.neighbours_int.shape[0] != 3: raise create_data_validation_error( "Invalid neighbour array for triangular grid", details=f"Expected shape (3, ncells), got {self.neighbours_int.shape}", suggestions=[ "Ensure triangular grid connectivity", "Check neighbour array from grid file", "Verify unstructured grid format", ], data_info={ "actual_shape": self.neighbours_int.shape, "expected_shape": "(3, ncells)", }, ) if self.neighbours_int.dims != ("nv", self.xdim): raise create_data_validation_error( "Invalid neighbour array dimensions", details=f"Expected dimensions ('nv', '{self.xdim}'), got {self.neighbours_int.dims}", suggestions=[ "Check dimension names in grid file", "Verify coordinate mapping", ], data_info={ "actual_dims": self.neighbours_int.dims, "expected_dims": ("nv", self.xdim), }, ) # Construct sparse dilation matrix self._build_sparse_dilation_matrix() def _build_sparse_dilation_matrix(self) -> None: """Build sparse matrix for efficient dilation operations on unstructured grid.""" # 
Create row and column indices for sparse matrix row_indices = jnp.repeat(jnp.arange(self.neighbours_int.shape[1]), 3) col_indices = self.neighbours_int.data.compute().T.flatten() # Filter out negative values (invalid connections) valid_mask = col_indices >= 0 row_indices = row_indices[valid_mask] col_indices = col_indices[valid_mask] # Create the sparse matrix for dilation ncells = self.neighbours_int.shape[1] dilate_coo = coo_matrix( (jnp.ones_like(row_indices, dtype=bool), (row_indices, col_indices)), shape=(ncells, ncells), ) self.dilate_sparse = csr_matrix(dilate_coo) # Add identity matrix to include self-connections identity = eye(self.neighbours_int.shape[1], dtype=bool, format="csr") self.dilate_sparse = self.dilate_sparse + identity logger.info("Finished constructing the sparse dilation matrix") def _configure_warnings(self) -> None: """Configure warning and logging suppression based on debug level.""" logger.debug(f"Configuring warnings and logging for debug level: {self.debug}") if self.debug < 2: # Configure logging warning filters logging.getLogger("distributed.scheduler").setLevel(logging.ERROR) def filter_dask_warnings(record): # pragma: no cover msg = str(record.msg) if self.debug == 0: # Suppress both run_spec and large graph warnings if any( pattern in msg for pattern in [ "Detected different `run_spec`", "Sending large graph", "This may cause some slowdown", ] ): return False return True else: # Suppress only run_spec warnings if "Detected different `run_spec`" in msg: return False return True logging.getLogger("distributed.scheduler").addFilter(filter_dask_warnings) # Configure Python warnings if self.debug == 0: warnings.filterwarnings("ignore", category=UserWarning, module="distributed.client") warnings.filterwarnings( "ignore", message=".*Sending large graph.*\n.*This may cause some slowdown.*", category=UserWarning, ) # ============================ # Main Public Methods # ============================
def run(
    self, return_merges: bool = False, checkpoint: Optional[str] = None
) -> Union[xr.Dataset, Tuple[xr.Dataset, xr.Dataset]]:
    """
    Execute preprocessing, tracking and statistics in sequence.

    The three stages are delegated to ``run_preprocess``, ``run_tracking`` and
    ``run_stats_attributes``; each is wrapped in a timing/memory log context.

    Parameters
    ----------
    return_merges : bool, default=False
        If True (and merging is enabled on the instance), also return the
        merge events dataset
    checkpoint : str, optional
        Checkpoint strategy forwarded to ``run_preprocess``

    Returns
    -------
    events_ds : xarray.Dataset
        Dataset containing tracked events and their properties
    merges_ds : xarray.Dataset, optional
        Dataset with merge event information (only if ``return_merges=True``
        and ``self.allow_merging`` is set)
    """
    logger.info("Starting complete tracking pipeline")
    log_memory_usage(logger, "Pipeline start")

    # Progress counter used only for the log messages
    total_steps = 3
    current_step = 1

    # Stage 1: preprocess the binary data
    logger.info(f"Step {current_step}/{total_steps}: Data preprocessing")
    with log_timing(logger, "Data preprocessing", log_memory=True, show_progress=True):
        data_bin_preprocessed, object_stats = self.run_preprocess(checkpoint=checkpoint)
    current_step += 1

    # Stage 2: identification and tracking
    logger.info(f"Step {current_step}/{total_steps}: Object identification and tracking")
    with log_timing(
        logger,
        "Object identification and tracking",
        log_memory=True,
        show_progress=True,
    ):
        events_ds, merges_ds, N_events_final = self.run_tracking(data_bin_preprocessed)
    current_step += 1

    # Stage 3: statistics and final attributes
    logger.info(f"Step {current_step}/{total_steps}: Computing event statistics and attributes")
    with log_timing(
        logger,
        "Computing event statistics and attributes",
        log_memory=True,
        show_progress=True,
    ):
        events_ds = self.run_stats_attributes(events_ds, merges_ds, object_stats, N_events_final)

    logger.info(f"Tracking pipeline completed successfully - {N_events_final} events identified")
    logger.debug(f"Final dataset dimensions: {events_ds.dims}")
    log_memory_usage(logger, "Pipeline completion")

    # The merge dataset is only handed back when merging was actually tracked
    if not (self.allow_merging and return_merges):
        logger.debug("Returning events dataset only")
        return events_ds

    logger.debug("Returning both events and merge datasets")
    return events_ds, merges_ds
def run_preprocess(self, checkpoint: Optional[str] = None) -> Tuple[xr.DataArray, Tuple[float, int, int, float, float, float]]:
    """
    Preprocess binary data to prepare for tracking.

    This performs morphological operations to fill holes/gaps in both space and time,
    then filters small objects according to the area_filter_quartile or area_filter_absolute.

    Parameters
    ----------
    checkpoint : str, optional
        Checkpoint strategy override; falls back to ``self.checkpoint`` when
        empty. ``"load"`` returns the previously saved data and statistics
        without recomputing; any value containing ``"save"`` writes the result
        to ``self.scratch_dir`` and returns the reloaded copies.

    Returns
    -------
    data_bin_filtered : xarray.DataArray
        Preprocessed binary data
    object_stats : tuple
        Statistics about the preprocessing, in the order
        (total_area_IDed, N_objects_prefiltered, N_objects_filtered,
        area_threshold, accepted_area_fraction, preprocessed_area_fraction)
    """
    if not checkpoint:
        checkpoint = self.checkpoint

    def load_data_from_checkpoint() -> xr.DataArray:
        """Load preprocessed data from checkpoint files."""
        data_bin_preprocessed: xr.DataArray = xr.open_zarr(
            f"{self.scratch_dir}/marEx_checkpoint_proc_bin.zarr",
            chunks={self.timedim: self.timechunks},
        )["data_bin_preproc"]
        return data_bin_preprocessed

    def load_stats_from_checkpoint() -> Tuple[float, int, int, float, float, float]:
        """Load the preprocessing statistics saved alongside the checkpoint."""
        stat_keys = [
            "total_area_IDed",
            "N_objects_prefiltered",
            "N_objects_filtered",
            "area_threshold",
            "accepted_area_fraction",
            "preprocessed_area_fraction",
        ]
        # np.load on an .npz returns an NpzFile that keeps the archive open; each
        # entry is read into memory on access, so the file can be closed as soon
        # as all values have been pulled out.
        with np.load(f"{self.scratch_dir}/marEx_checkpoint_stats.npz") as object_stats_npz:
            object_stats = [object_stats_npz[key] for key in stat_keys]
        return tuple(object_stats)  # type: ignore[return-value]

    if checkpoint == "load":
        logger.info("Loading preprocessed data from checkpoint")
        return load_data_from_checkpoint(), load_stats_from_checkpoint()

    # Compute area of initial binary data
    logger.debug("Computing area of initial binary data")
    raw_area = self.compute_area(self.data_bin)
    logger.debug(f"Initial raw area: {raw_area}")

    # Fill small holes & gaps between objects
    logger.info(f"Filling spatial holes with radius R_fill={self.R_fill}")
    with log_timing(logger, "Spatial hole filling"):
        data_bin_filled = self.fill_holes(self.data_bin)
        del self.data_bin  # Free memory
    log_memory_usage(logger, "After spatial hole filling", logging.DEBUG)

    # Fill small time-gaps between objects
    logger.info(f"Filling temporal gaps with T_fill={self.T_fill}")
    with log_timing(logger, "Temporal gap filling"):
        data_bin_filled = self.fill_time_gaps(data_bin_filled).persist()
    log_memory_usage(logger, "After temporal gap filling", logging.DEBUG)

    # Remove small objects
    logger.info("Filtering small objects")
    with log_timing(logger, "Small object filtering"):
        (
            data_bin_filtered,
            area_threshold,
            object_areas,
            N_objects_prefiltered,
            N_objects_filtered,
        ) = self.filter_small_objects(data_bin_filled)
        del data_bin_filled  # Free memory
    logger.info(f"Filtered {N_objects_prefiltered} -> {N_objects_filtered} objects (threshold: {area_threshold})")
    log_memory_usage(logger, "After object filtering", logging.DEBUG)

    # Persist preprocessed data &/or Save checkpoint
    if checkpoint and "save" in checkpoint:
        logger.info("Saving preprocessed data to checkpoint")
        with log_timing(logger, "Checkpoint saving"):
            time.sleep(5)
            data_bin_filtered.name = "data_bin_preproc"
            data_bin_filtered.to_zarr(
                f"{self.scratch_dir}/marEx_checkpoint_proc_bin.zarr", mode="w"
            )  # N.B.: This needs to be done without .persist() due to dask to_zarr tuple bug...
            data_bin_filtered = load_data_from_checkpoint()
    else:
        logger.debug("Persisting preprocessed data in memory")
        data_bin_filtered = data_bin_filtered.persist()
        wait(data_bin_filtered)

    # Compute area of processed data
    processed_area = self.compute_area(data_bin_filtered)

    # Compute statistics
    object_areas = object_areas.compute()
    total_area_IDed = float(object_areas.sum().item())

    accepted_area = float(object_areas.where(object_areas > area_threshold, drop=True).sum().item())
    accepted_area_fraction = accepted_area / total_area_IDed

    total_hobday_area = float(raw_area.sum().compute().item())
    total_processed_area = float(processed_area.sum().compute().item())
    preprocessed_area_fraction = total_hobday_area / total_processed_area

    object_stats = (
        total_area_IDed,
        N_objects_prefiltered,
        N_objects_filtered,
        area_threshold,
        accepted_area_fraction,
        preprocessed_area_fraction,
    )

    # Save checkpoint
    if checkpoint and "save" in checkpoint:
        np.savez(
            f"{self.scratch_dir}/marEx_checkpoint_stats.npz",
            total_area_IDed=total_area_IDed,
            N_objects_prefiltered=N_objects_prefiltered,
            N_objects_filtered=N_objects_filtered,
            area_threshold=area_threshold,
            accepted_area_fraction=accepted_area_fraction,
            preprocessed_area_fraction=preprocessed_area_fraction,
        )

        # Reload to refresh the dask graph
        data_bin_filtered = load_data_from_checkpoint()
        object_stats = load_stats_from_checkpoint()

    return data_bin_filtered, object_stats
def run_tracking(self, data_bin_preprocessed: xr.DataArray) -> Tuple[xr.Dataset, xr.Dataset, int]:
    """
    Turn preprocessed binary data into tracked events.

    Parameters
    ----------
    data_bin_preprocessed : xarray.DataArray
        Preprocessed binary data

    Returns
    -------
    events_ds : xarray.Dataset
        Dataset containing tracked events
    merges_ds : xarray.Dataset
        Dataset with merge information (empty when tracking without merging)
    N_events_final : int
        Final number of unique events
    """
    needs_full_tracking = self.allow_merging or self.unstructured_grid

    if not needs_full_tracking:
        # Track without merging or splitting: plain labelling with time connectivity
        events_da, _, N_events_final = self.identify_objects(data_bin_preprocessed, time_connectivity=True)
        events_ds = xr.Dataset({"ID_field": events_da})
        merges_ds = xr.Dataset()
    else:
        # Track with merging & splitting
        events_ds, merges_ds, N_events_final = self.track_objects(data_bin_preprocessed)

    # Every non-positive (filler) ID is reset to 0
    id_field = events_ds.ID_field
    events_ds["ID_field"] = id_field.where(events_ds.ID_field > 0, drop=False, other=0)

    # Restore the original time coordinate name when it differs from the dimension name
    coord_was_renamed = self.timecoord != self.timedim
    if coord_was_renamed and self.timedim in events_ds.coords and self.timecoord not in events_ds.coords:
        # Re-attach the time values under the original coordinate name
        original_time = events_ds.coords[self.timedim]
        events_ds = events_ds.assign_coords({self.timecoord: original_time})

        # Drop the dimension coordinate so the time axis is not duplicated
        if self.timedim in events_ds.coords and self.timecoord in events_ds.coords:
            events_ds = events_ds.drop_vars(self.timedim)

    logger.info("Finished tracking all extreme events!")
    return events_ds, merges_ds, N_events_final
def run_stats_attributes(
    self,
    events_ds: xr.Dataset,
    merges_ds: xr.Dataset,
    object_stats: Tuple[float, int, int, float, float, float],
    N_events_final: int,
) -> xr.Dataset:
    """
    Attach run statistics and configuration to the events dataset.

    Also prints a short summary, restores the original coordinates (via
    ``_remap_coordinates``) and rechunks the time dimension to size 1.

    Parameters
    ----------
    events_ds : xarray.Dataset
        Dataset containing tracked events
    merges_ds : xarray.Dataset
        Dataset with merge information (read only when ``self.allow_merging``)
    object_stats : tuple
        Preprocessed object statistics as returned by ``run_preprocess``
    N_events_final : int
        Final number of events

    Returns
    -------
    events_ds : xarray.Dataset
        Dataset with added statistics and attributes
    """
    (
        total_area_IDed,
        N_objects_prefiltered,
        N_objects_filtered,
        area_threshold,
        accepted_area_fraction,
        preprocessed_area_fraction,
    ) = object_stats

    # General attributes (insertion order matches the order they are written out)
    events_ds.attrs.update(
        {
            "allow_merging": int(self.allow_merging),
            "N_objects_prefiltered": int(N_objects_prefiltered),
            "N_objects_filtered": int(N_objects_filtered),
            "N_events_final": int(N_events_final),
            "R_fill": self.R_fill,
            "T_fill": self.T_fill,
            "area_filter_quartile": self.area_filter_quartile,
            "area_threshold (cells)": area_threshold,
            "accepted_area_fraction": accepted_area_fraction,
            "preprocessed_area_fraction": preprocessed_area_fraction,
        }
    )

    # Summary for the user
    print("Tracking Statistics:")
    print(f" Binary Hobday to Processed Area Fraction: {preprocessed_area_fraction}")
    print(f" Total Object Area IDed (cells): {total_area_IDed}")
    print(f" Number of Initial Pre-Filtered Objects: {N_objects_prefiltered}")
    print(f" Number of Final Filtered Objects: {N_objects_filtered}")
    print(f" Area Cutoff Threshold (cells): {int(area_threshold)}")
    print(f" Accepted Area Fraction: {accepted_area_fraction}")
    print(f" Total Events Tracked: {N_events_final}")

    if self.allow_merging:
        # Merge configuration
        events_ds.attrs["overlap_threshold"] = self.overlap_threshold
        events_ds.attrs["nn_partitioning"] = int(self.nn_partitioning)

        # Merge summary
        events_ds.attrs["total_merges"] = len(merges_ds.merge_ID)
        events_ds.attrs["multi_parent_merges"] = int((merges_ds.n_parents > 2).sum().item())

        print(f" Total Merging Events Recorded: {events_ds.attrs['total_merges']}")

    # Metadata inherited from the input data_bin (may overwrite keys set above)
    events_ds.attrs.update(self.data_attrs)

    # Put lat & lon back as coordinates and remap centroids
    events_ds = self._remap_coordinates(events_ds)

    # One time step per chunk for better post-processing
    return events_ds.chunk({self.timedim: 1})
# ============================ # Data Processing Methods # ============================
def compute_area(self, data_bin: xr.DataArray) -> xr.DataArray:
    """
    Sum the binary field over its spatial dimension(s).

    Parameters
    ----------
    data_bin : xarray.DataArray
        Binary data

    Returns
    -------
    area : xarray.DataArray
        Total area with the spatial dimensions reduced away (units: pixels for
        structured grid, matching cell_area for unstructured)
    """
    if not self.unstructured_grid:
        # Structured grid: every True cell counts as one pixel
        return data_bin.sum(dim=[self.ydim, self.xdim])

    # Unstructured grid: weight each cell by its area before summing
    weighted = data_bin * self.cell_area
    return weighted.sum(dim=[self.xdim])
def fill_holes(self, data_bin: xr.DataArray, R_fill: Optional[int] = None) -> xr.DataArray:
    """
    Apply a morphological closing followed by an opening.

    Closing (dilation then erosion) fills small gaps; opening (erosion then
    dilation) removes small isolated objects.

    Parameters
    ----------
    data_bin : xarray.DataArray
        Binary data to process
    R_fill : int, optional
        Fill radius override; defaults to ``self.R_fill``

    Returns
    -------
    data_bin_filled : xarray.DataArray
        Binary data with holes/gaps filled
    """
    R_fill = self.R_fill if R_fill is None else R_fill

    if self.unstructured_grid:
        # Unstructured grid: morphology via the sparse dilation matrix.
        # The CSR components are wrapped in DataArrays so they can be passed through
        # apply_ufunc -- needed for correct memory management.
        csr = self.dilate_sparse
        sp_data = xr.DataArray(csr.data, dims="sp_data")
        indices = xr.DataArray(csr.indices, dims="indices")
        indptr = xr.DataArray(csr.indptr, dims="indptr")

        def binary_open_close(
            bitmap_binary: NDArray[np.bool_],
            sp_data: NDArray[np.bool_],
            indices: NDArray[np.int32],
            indptr: NDArray[np.int32],
            mask: NDArray[np.bool_],
        ) -> NDArray[np.bool_]:
            """
            Closing then opening on the unstructured grid.

            Each dilation is a sparse boolean matrix power; each erosion is the
            negated dilation of the negated field.
            """
            land = ~mask

            # --- Closing (fills small gaps) ---
            # Dilation
            bitmap_binary = sparse_bool_power(bitmap_binary, sp_data, indices, indptr, R_fill)
            # Land is set True so the shoreline is not artificially eroded
            bitmap_binary[:, land] = True
            # Erosion
            bitmap_binary = ~sparse_bool_power(~bitmap_binary, sp_data, indices, indptr, R_fill)

            # --- Opening (removes small objects) ---
            # Land is set True so the shoreline is not artificially eroded
            bitmap_binary[:, land] = True
            # Erosion
            bitmap_binary = ~sparse_bool_power(~bitmap_binary, sp_data, indices, indptr, R_fill)
            # Dilation
            bitmap_binary = sparse_bool_power(bitmap_binary, sp_data, indices, indptr, R_fill)

            return bitmap_binary

        data_bin = xr.apply_ufunc(
            binary_open_close,
            data_bin,
            sp_data,
            indices,
            indptr,
            self.mask,
            input_core_dims=[
                [self.xdim],
                ["sp_data"],
                ["indices"],
                ["indptr"],
                [self.xdim],
            ],
            output_core_dims=[[self.xdim]],
            output_dtypes=[np.bool_],
            vectorize=False,
            dask_gufunc_kwargs={
                "output_sizes": {self.xdim: data_bin.sizes[self.xdim]},
            },
            dask="parallelized",
        )

    else:
        # Structured grid: dask-powered morphological operations
        use_dask_morph = True

        # Disk-shaped structuring element
        y, x = np.ogrid[-R_fill : R_fill + 1, -R_fill : R_fill + 1]
        r = x**2 + y**2
        se_kernel = r < (R_fill**2) + 1

        # Pad/trim width for the closing+opening morphology. Closing has reach 2*R_fill and
        # the subsequent opening another 2*R_fill, so the total reach is 4*R_fill; padding by
        # only 2*R_fill leaves an under-padded artifact around the periodic-longitude seam.
        diameter = 4 * R_fill

        # Periodic wrap for global data, edge replication for regional data
        mode = "edge" if self.regional_mode else "wrap"

        if use_dask_morph:
            # With R_fill == 0 no morphological operations are needed
            if R_fill != 0:
                # Pad to avoid edge effects
                data_bin = data_bin.pad({self.ydim: diameter, self.xdim: diameter}, mode=mode)
                data_coords = data_bin.coords
                data_dims = data_bin.dims

                # The kernel gets a leading singleton axis so it acts per time step.
                # N.B.: There may be a bug in constructing the dask task graph when the
                # dask array is extracted and then re-embedded into an xarray DataArray
                kernel_3d = se_kernel[np.newaxis, :, :]
                data_bin = binary_closing_dask(data_bin.data, structure=kernel_3d)
                data_bin = binary_opening_dask(data_bin, structure=se_kernel[np.newaxis, :, :])

                # Back to xarray.DataArray, then trim the padding
                data_bin = xr.DataArray(data_bin, coords=data_coords, dims=data_dims)
                trim = slice(diameter, -diameter)
                data_bin = data_bin.isel({self.ydim: trim, self.xdim: slice(diameter, -diameter)})

        else:  # pragma: no cover

            def binary_open_close(
                bitmap_binary: NDArray[np.bool_],
            ) -> NDArray[np.bool_]:
                """Apply binary closing then opening to one padded 2D slice."""
                bitmap_binary_padded = np.pad(
                    bitmap_binary,
                    ((diameter, diameter), (diameter, diameter)),
                    mode=mode,
                )
                s1 = binary_closing(bitmap_binary_padded, se_kernel, iterations=1)
                s2 = binary_opening(s1, se_kernel, iterations=1)
                return s2[diameter:-diameter, diameter:-diameter]

            data_bin = xr.apply_ufunc(
                binary_open_close,
                data_bin,
                input_core_dims=[[self.ydim, self.xdim]],
                output_core_dims=[[self.ydim, self.xdim]],
                output_dtypes=[data_bin.dtype],
                vectorize=True,
                dask="parallelized",
            )

        # Mask out edge features from morphological operations
        data_bin = data_bin.where(self.mask, drop=False, other=False)

    return data_bin
def fill_time_gaps(self, data_bin: xr.DataArray) -> xr.DataArray:
    """
    Close short gaps along the time dimension.

    A binary closing (dilation then erosion) with a 1-D temporal kernel is
    applied, followed by a spatial ``fill_holes`` pass at half the fill radius.
    When ``self.T_fill`` is 0 the input is returned unchanged.

    Parameters
    ----------
    data_bin : xarray.DataArray
        Binary data to process

    Returns
    -------
    data_bin_filled : xarray.DataArray
        Binary data with temporal gaps filled
    """
    if self.T_fill == 0:
        return data_bin

    # Temporal structuring element: length T_fill + 1 fills holes of at most T_fill steps
    kernel_size = self.T_fill + 1
    time_kernel = np.ones(kernel_size, dtype=bool)

    # Give the kernel singleton axes for every spatial dimension
    # (one for an unstructured grid, two for a structured grid)
    if self.ydim is not None:
        time_kernel = time_kernel[:, np.newaxis, np.newaxis]
    else:
        time_kernel = time_kernel[:, np.newaxis]

    # Pad with False in time so the closing has no edge effects
    padded = data_bin.pad({self.timedim: kernel_size}, mode="constant", constant_values=False)

    # Temporal closing on the underlying dask array
    closed_dask_array = binary_closing_dask(padded.data, structure=time_kernel)

    # Re-wrap as a DataArray carrying the padded metadata
    data_bin_filled = xr.DataArray(
        closed_dask_array,
        coords=padded.coords,
        dims=padded.dims,
        attrs=padded.attrs,
    )

    # Strip the temporal padding again
    unpad = slice(kernel_size, -kernel_size)
    data_bin_filled = data_bin_filled.isel({self.timedim: unpad}).persist()

    # The temporal closing can open new spatial holes; fill them with half the radius
    return self.fill_holes(data_bin_filled, R_fill=self.R_fill // 2)
def refresh_dask_graph(self, data_bin: xr.DataArray) -> xr.DataArray:
    """
    Reset the Dask graph by writing the data to zarr and reading it back.

    This works around a memory leak in Dask where "Unmanaged Memory" builds up
    within loops.

    Parameters
    ----------
    data_bin : xarray.DataArray
        Data to refresh (its ``name`` is set to ``"temp"``)

    Returns
    -------
    data_new : xarray.DataArray
        The same data, lazily re-opened from the scratch zarr store
    """
    logger.debug("Refreshing Dask task graph...")

    # Write under a fixed variable name so it can be looked up after reloading
    data_bin.name = "temp"
    data_bin.to_zarr(f"{self.scratch_dir}/marEx_temp_field.zarr", mode="w")

    # Drop the local reference to the old graph before re-opening
    del data_bin
    gc.collect()

    reloaded = xr.open_zarr(f"{self.scratch_dir}/marEx_temp_field.zarr", chunks={})
    return reloaded.temp
def filter_small_objects(self, data_bin: xr.DataArray) -> Tuple[xr.DataArray, float, xr.DataArray, int, int]:
    """
    Drop objects whose area falls below a threshold.

    The threshold is ``self.area_filter_absolute`` when absolute filtering is
    enabled, otherwise the ``self.area_filter_quartile`` percentile of the
    object areas.

    Parameters
    ----------
    data_bin : xarray.DataArray
        Binary data to filter

    Returns
    -------
    data_bin_filtered : xarray.DataArray
        Binary data with small objects removed
    area_threshold : float
        Area threshold used for filtering
    object_areas : xarray.DataArray
        Areas of all objects pre-filtering
    N_objects_prefiltered : int
        Number of objects before filtering
    N_objects_filtered : int
        Number of objects after filtering

    Raises
    ------
    TrackingError
        If no objects are available for the area-based filtering
    """
    # Label the binary data independently per time step (no time connectivity)
    object_id_field, _, N_objects_unfiltered = self.identify_objects(data_bin, time_connectivity=False)

    if self.unstructured_grid:
        # The largest ID sizes the per-time-step output arrays
        # Note: identify_objects() starts at ID=0 for every time slice
        max_ID = int(object_id_field.max().compute().item())

        def count_cluster_sizes(
            id_slice: NDArray[np.int32],
        ) -> Tuple[NDArray[np.int32], NDArray[np.int32]]:
            """Count the cells of each positive cluster ID, zero-padded to max_ID."""
            unique, counts = np.unique(id_slice[id_slice > 0], return_counts=True)
            n_found = len(counts)
            padded_sizes = np.zeros(max_ID, dtype=np.int32)
            padded_unique = np.zeros(max_ID, dtype=np.int32)
            padded_sizes[:n_found] = counts
            padded_unique[:n_found] = unique
            return padded_sizes, padded_unique

        # Cluster sizes and their IDs for every time step
        cluster_sizes, unique_cluster_IDs = xr.apply_ufunc(
            count_cluster_sizes,
            object_id_field,
            input_core_dims=[[self.xdim]],
            output_core_dims=[["ID"], ["ID"]],
            dask_gufunc_kwargs={"output_sizes": {"ID": max_ID}},
            output_dtypes=(np.int32, np.int32),
            vectorize=True,
            dask="parallelized",
        )
        cluster_sizes, unique_cluster_IDs = persist(cluster_sizes, unique_cluster_IDs)

        # Pre-filter tiny objects for performance (greatly reduces the input to the percentile)
        min_cells = 5 if self._use_absolute_filtering else 50
        cluster_sizes_filtered_dask = cluster_sizes.where(cluster_sizes > min_cells).data

        cluster_areas_mask = dsa.isfinite(cluster_sizes_filtered_dask)
        object_areas = cluster_sizes_filtered_dask[cluster_areas_mask].compute()

        # The object count is taken after the pre-filter
        N_objects_unfiltered = len(object_areas)
        if N_objects_unfiltered == 0:  # pragma: no cover
            raise TrackingError(
                "No objects found for area-based filtering",
                details={
                    "objects_count": N_objects_unfiltered,
                    "area_filter_quartile": self.area_filter_quartile,
                    "grid_type": "unstructured",
                },
                suggestions=[
                    "Check if input data contains any extreme events",
                    "Verify that preprocessing parameters are appropriate",
                    "Consider lowering the extreme threshold percentile",
                ],
            )

        # Area threshold
        if self._use_absolute_filtering:
            area_threshold = self.area_filter_absolute
        else:
            area_threshold = np.percentile(object_areas, self.area_filter_quartile * 100)

        N_objects_filtered = np.sum(object_areas > area_threshold)

        def filter_area_binary(cluster_IDs_0: NDArray[np.int32], keep_IDs_0: NDArray[np.int32]) -> NDArray[np.bool_]:
            """Mark cells that belong to a cluster whose ID is kept."""
            positive_keep = keep_IDs_0[keep_IDs_0 > 0]
            return np.isin(cluster_IDs_0, positive_keep)

        # IDs of clusters above the threshold (0 elsewhere), then the filtered binary field
        keep_IDs = xr.where(cluster_sizes > area_threshold, unique_cluster_IDs, 0)

        data_bin_filtered = xr.apply_ufunc(
            filter_area_binary,
            object_id_field,
            keep_IDs,
            input_core_dims=[[self.xdim], ["ID"]],
            output_core_dims=[[self.xdim]],
            output_dtypes=[data_bin.dtype],
            vectorize=True,
            dask="parallelized",
        )

        object_areas = cluster_sizes  # Store pre-filtered areas

    else:
        # Structured grid: areas come from the object property table
        object_props = self.calculate_object_properties(object_id_field)
        object_areas = object_props.area
        object_ids = object_props.ID

        if len(object_areas) == 0:  # pragma: no cover
            raise TrackingError(
                "No objects found for area-based filtering",
                details={
                    "objects_count": len(object_areas),
                    "area_filter_quartile": self.area_filter_quartile,
                    "grid_type": "structured",
                },
                suggestions=[
                    "Check if input data contains any extreme events",
                    "Verify that preprocessing parameters are appropriate",
                    "Consider lowering the extreme threshold percentile",
                ],
            )

        # Area threshold
        if self._use_absolute_filtering:
            area_threshold = self.area_filter_absolute
        else:
            area_threshold = np.percentile(object_areas, self.area_filter_quartile * 100.0)

        # IDs at or above the threshold are kept; everything else is marked -1
        object_ids_keep = xr.where(object_areas >= area_threshold, object_ids, -1)
        object_ids_keep[0] = -1  # Don't keep ID=0

        # Filtered binary data
        data_bin_filtered = object_id_field.isin(object_ids_keep)

        # Number of kept (positive) IDs
        N_objects_filtered = int(object_ids_keep.where(object_ids_keep > 0).count().item())

    return (
        data_bin_filtered,
        area_threshold,
        object_areas,
        N_objects_unfiltered,
        N_objects_filtered,
    )
# ============================ # Object Identification Methods # ============================
def identify_objects(self, data_bin: xr.DataArray, time_connectivity: bool) -> Tuple[xr.DataArray, None, int]:
    """
    Identify connected regions in binary data.

    Parameters
    ----------
    data_bin : xarray.DataArray
        Binary data to identify objects in
    time_connectivity : bool
        Whether to connect objects across time (structured grids only)

    Returns
    -------
    object_id_field : xarray.DataArray
        Field of integer IDs for each object (named ``"ID_field"``, 0 = no object)
    None : NoneType
        Placeholder for compatibility with track_objects
    N_objects : int
        Number of objects identified. For unstructured grids this is the
        placeholder value 1, because IDs restart at every time step.

    Raises
    ------
    ConfigurationError
        If ``time_connectivity`` is requested on an unstructured grid.
    """
    if self.unstructured_grid:
        # The resulting ID field for unstructured grid will start at 0 for each time-slice,
        # which differs from structured grid where IDs are unique across time.

        if time_connectivity:  # pragma: no cover
            raise ConfigurationError(
                "Time connectivity not supported for unstructured grids",
                details="Automatic time connectivity computation requires regular grids",
                suggestions=[
                    "Set time_connectivity=False for unstructured data",
                    "Manually specify connectivity if needed",
                ],
            )

        # Cluster with sparse-graph connected components on the cell-neighbour graph
        def cluster_true_values(arr: NDArray[np.bool_], neighbours_int: NDArray[np.int32]) -> NDArray[np.int32]:
            """Cluster connected True values in binary data on unstructured grid."""
            t, n = arr.shape
            labels = np.full((t, n), -1, dtype=np.int32)

            # Loop-invariant: which neighbour slots actually hold a neighbour (-1 = none)
            has_neighbour = neighbours_int != -1

            for i in range(t):
                active = arr[i]

                # Get indices of True values
                true_indices = np.where(active)[0].astype(np.int32)
                n_true = len(true_indices)

                # Lookup table: cell index -> compact index among the True cells (-1 if False)
                compact = np.full(n, -1, dtype=np.int32)
                compact[true_indices] = np.arange(n_true, dtype=np.int32)

                # Edges whose neighbour exists and is itself True
                valid_mask = has_neighbour & active[neighbours_int]
                row_ind, col_ind = np.where(valid_mask)

                # Map both edge endpoints to compact indices in one vectorised step
                # and keep only edges where both endpoints are True cells
                src = compact[neighbours_int[row_ind, col_ind]]
                dst = compact[col_ind]
                keep = (src >= 0) & (dst >= 0)
                src = src[keep]
                dst = dst[keep]

                # Create graph and find connected components
                graph = csr_matrix(
                    (
                        np.ones(len(src), dtype=np.int32),
                        (src, dst),
                    ),
                    shape=(n_true, n_true),
                )
                _, labels_true = connected_components(csgraph=graph, directed=False, return_labels=True)
                labels[i, true_indices] = labels_true

            return labels + 1  # Add 1 so 0 represents no object

        # Apply mask and cluster
        data_bin = data_bin.where(self.mask, other=False)

        object_id_field = xr.apply_ufunc(
            cluster_true_values,
            data_bin,
            self.neighbours_int,
            input_core_dims=[[self.xdim], ["nv", self.xdim]],
            output_core_dims=[[self.xdim]],
            output_dtypes=[np.int32],
            dask_gufunc_kwargs={
                "output_sizes": {self.xdim: data_bin.sizes[self.xdim]},
            },
            vectorize=False,
            dask="parallelized",
        )

        # Ensure ID = 0 on invalid regions
        object_id_field = object_id_field.where(self.mask, other=0)
        object_id_field = object_id_field.persist()
        object_id_field = object_id_field.rename("ID_field")
        N_objects = 1  # Placeholder (IDs aren't unique across time)

    else:  # Structured Grid
        # Create connectivity kernel for labeling
        neighbours = np.zeros((3, 3, 3))

        if time_connectivity:
            # ID objects in 3D (i.e. space & time)
            neighbours[:, :, :] = 1  # +-1 in time, _and also diagonal in time_ -- i.e. edges can touch
        else:
            # ID objects only in 2D (i.e. space): objects are never connected across time
            neighbours[1, :, :] = 1  # All 8 neighbours, but ignore time

        # Cluster & label binary data
        # Apply dask-powered ndimage & persist in memory
        if self.regional_mode:
            object_id_field, N_objects = label(
                data_bin,
                structure=neighbours,
            )
        else:
            object_id_field, N_objects = label(
                data_bin,
                structure=neighbours,
                wrap_axes=(2,),  # Wrap in x-direction !
            )
        results = persist(object_id_field, N_objects)
        object_id_field, N_objects = results

        N_objects = N_objects.compute()

        # Convert to DataArray with same coordinates as input
        object_id_field = (
            xr.DataArray(
                object_id_field,
                coords=data_bin.coords,
                dims=data_bin.dims,
                attrs=data_bin.attrs,
            )
            .rename("ID_field")
            .astype(np.int32)
        )

    return object_id_field, None, N_objects
def calculate_centroid(
    self,
    binary_mask: NDArray[np.bool_],
    original_centroid: Optional[Tuple[float, float]] = None,
    edge_width: int = 100,
) -> Tuple[float, float]:
    """
    Calculate object centroid, handling edge cases for periodic boundaries.

    Parameters
    ----------
    binary_mask : numpy.ndarray
        2D binary array where True indicates the object (dimensions are (y,x))
    original_centroid : tuple, optional
        (y_centroid, x_centroid) from regionprops_table. When omitted, the
        centroid is computed from ``binary_mask``.
    edge_width : int, default=100
        Non-negative number of columns at each x-edge within which an object
        counts as "near" the boundary.

    Returns
    -------
    tuple
        (y_centroid, x_centroid) in pixel coordinates
    """
    if self.regional_mode:  # pragma: no cover
        # We don't need to adjust centroids for periodic boundaries
        if original_centroid is not None:
            return original_centroid
        # No precomputed centroid supplied: compute the plain pixel centroid
        # instead of returning None (which would violate the return type)
        y_indices, x_indices = np.nonzero(binary_mask)
        return (np.mean(y_indices), np.mean(x_indices))

    n_x = binary_mask.shape[1]

    # Check if object is near either edge of x dimension
    # (max(..., 0) keeps the right-hand window well-defined when edge_width >= n_x or is 0)
    near_left_BC = np.any(binary_mask[:, :edge_width])
    near_right_BC = np.any(binary_mask[:, max(n_x - edge_width, 0) :])

    if original_centroid is None:  # pragma: no cover
        # Calculate y centroid from scratch
        y_indices = np.nonzero(binary_mask)[0]
        y_centroid = np.mean(y_indices)
    else:
        y_centroid = original_centroid[0]

    # If object is near both edges, recalculate x-centroid to handle wrapping
    # N.B.: We calculate _near_ rather than touching, to catch the edge case where the
    #       object may be split and straddling the boundary !
    if near_left_BC and near_right_BC:
        # Shift x coordinates in the right half of the domain by one period so that
        # the two pieces of a wrapped object sit next to each other before averaging
        x_indices = np.nonzero(binary_mask)[1]
        x_indices_adj = x_indices.copy()
        right_side = x_indices > n_x // 2
        x_indices_adj[right_side] -= n_x

        x_centroid = np.mean(x_indices_adj)
        if x_centroid < 0:  # Ensure centroid is positive
            x_centroid += n_x

    elif original_centroid is None:  # pragma: no cover
        # Calculate x-centroid from scratch
        x_indices = np.nonzero(binary_mask)[1]
        x_centroid = np.mean(x_indices)

    else:
        x_centroid = original_centroid[1]

    return (y_centroid, x_centroid)
def calculate_object_properties(self, object_id_field: xr.DataArray, properties: Optional[List[str]] = None) -> xr.Dataset:
    """
    Calculate properties of objects from ID field.

    Parameters
    ----------
    object_id_field : xarray.DataArray
        Field containing object IDs (0 = no object)
    properties : list, optional
        List of properties to calculate (defaults to ['label', 'area']).
        On unstructured grids, area and centroid are always computed.

    Returns
    -------
    object_props : xarray.Dataset
        Dataset containing calculated properties with 'ID' dimension. When
        'centroid' is requested, the two centroid components are combined
        into a single 'centroid' variable with a 'component' dimension.
    """
    # Set default properties
    if properties is None:
        properties = ["label", "area"]

    # Ensure 'label' is included (builds a new list; the caller's list is not mutated)
    if "label" not in properties:
        properties = ["label"] + properties  # 'label' is actually 'ID' within regionprops

    check_centroids = "centroid" in properties

    if self.unstructured_grid:
        # Compute properties on unstructured grid

        # Convert lat/lon to radians
        lat_rad = np.radians(self.lat)
        lon_rad = np.radians(self.lon)

        # Broadcast coordinate arrays to match object_id_field shape for vectorisation
        lat_rad_broadcast, _ = xr.broadcast(lat_rad, object_id_field)
        lon_rad_broadcast, _ = xr.broadcast(lon_rad, object_id_field)
        cell_area_broadcast, _ = xr.broadcast(self.cell_area, object_id_field)

        # Calculate buffer size for IDs in chunks
        max_ID = int(object_id_field.max().compute().item()) + 1
        # Handle case where object_id_field may not have time dimension (e.g., single time slice)
        if self.timedim in object_id_field.dims:
            time_steps = object_id_field.sizes[self.timedim]
        else:
            # For single time slice, use 1 as time steps
            time_steps = 1
        ID_buffer_size = max(int(max_ID / time_steps) * 4 + 2, max_ID)

        def object_properties_chunk(
            ids: NDArray[np.int32],
            lat: NDArray[np.float32],
            lon: NDArray[np.float32],
            area: NDArray[np.float32],
            buffer_IDs: bool = True,
        ) -> Tuple[NDArray[np.float32], NDArray[np.int32]]:
            """
            Calculate object properties for a chunk of data.

            Returns a (3, n) array of [area, centroid_lat, centroid_lon] rows and the
            matching ID vector. With ``buffer_IDs`` both are zero-padded to
            ``ID_buffer_size`` so every chunk has the same output shape.
            """
            # Find valid IDs
            valid_mask = ids > 0
            ids_chunk = np.unique(ids[valid_mask])
            n_ids = len(ids_chunk)

            if n_ids == 0:
                # No objects in this chunk
                if buffer_IDs:
                    result = np.zeros((3, ID_buffer_size), dtype=np.float32)
                    padded_ids = np.zeros(ID_buffer_size, dtype=np.int32)
                    return result, padded_ids
                else:  # pragma: no cover
                    result = np.zeros((3, 0), dtype=np.float32)
                    padded_ids = np.array([], dtype=np.int32)
                    return result, padded_ids

            # Map IDs to consecutive indices
            mapped_indices = np.searchsorted(ids_chunk, ids[valid_mask]).astype(np.int32)

            # Pre-allocate arrays
            areas = np.zeros(n_ids, dtype=np.float32)
            weighted_x = np.zeros(n_ids, dtype=np.float32)
            weighted_y = np.zeros(n_ids, dtype=np.float32)
            weighted_z = np.zeros(n_ids, dtype=np.float32)

            # Convert to Cartesian for centroid calculation
            cos_lat = np.cos(lat[valid_mask])
            x = cos_lat * np.cos(lon[valid_mask])
            y = cos_lat * np.sin(lon[valid_mask])
            z = np.sin(lat[valid_mask])

            # Compute areas
            valid_areas = area[valid_mask]
            np.add.at(areas, mapped_indices, valid_areas)

            # Compute weighted coordinates
            np.add.at(weighted_x, mapped_indices, valid_areas * x)
            np.add.at(weighted_y, mapped_indices, valid_areas * y)
            np.add.at(weighted_z, mapped_indices, valid_areas * z)

            # Clean intermediate arrays
            del x, y, z, cos_lat, valid_areas

            # Normalise vectors
            norm = np.sqrt(weighted_x**2 + weighted_y**2 + weighted_z**2)
            norm = np.where(norm > 0, norm, 1)  # Avoid division by zero

            weighted_x /= norm
            weighted_y /= norm
            weighted_z /= norm

            # Convert back to lat/lon
            centroid_lat = np.degrees(np.arcsin(np.clip(weighted_z, -1, 1)))
            centroid_lon = np.degrees(np.arctan2(weighted_y, weighted_x))

            # Fix longitude range to [-180, 180]
            centroid_lon = np.where(
                centroid_lon > 180.0,
                centroid_lon - 360.0,
                np.where(centroid_lon < -180.0, centroid_lon + 360.0, centroid_lon),
            )

            assert areas.shape == (n_ids,)
            assert centroid_lat.shape == (n_ids,)
            assert centroid_lon.shape == (n_ids,)

            if buffer_IDs:
                # Create padded output arrays
                result = np.zeros((3, ID_buffer_size), dtype=np.float32)
                padded_ids = np.zeros(ID_buffer_size, dtype=np.int32)

                # Fill arrays up to n_ids
                result[0, :n_ids] = areas
                result[1, :n_ids] = centroid_lat
                result[2, :n_ids] = centroid_lon
                padded_ids[:n_ids] = ids_chunk
            else:  # pragma: no cover
                result = np.vstack((areas, centroid_lat, centroid_lon))
                padded_ids = ids_chunk

            return result, padded_ids

        # Process single time or multiple times
        # If time dimension doesn't exist, treat as single time slice
        if self.timedim not in object_id_field.dims or object_id_field.sizes[self.timedim] == 1:  # pragma: no cover
            props_np, ids = object_properties_chunk(
                object_id_field.values,
                lat_rad_broadcast.values,
                lon_rad_broadcast.values,
                cell_area_broadcast.values,
                buffer_IDs=False,
            )
            props = xr.DataArray(props_np, dims=["prop", "out_id"])

        else:
            # Process in parallel
            props_buffer, ids_buffer = xr.apply_ufunc(
                object_properties_chunk,
                object_id_field,
                lat_rad_broadcast,
                lon_rad_broadcast,
                cell_area_broadcast,
                input_core_dims=[
                    [self.xdim],
                    [self.xdim],
                    [self.xdim],
                    [self.xdim],
                ],
                output_core_dims=[["prop", "out_id"], ["out_id"]],
                output_dtypes=[np.float32, np.int32],
                dask_gufunc_kwargs={"output_sizes": {"prop": 3, "out_id": ID_buffer_size}},
                vectorize=True,
                dask="parallelized",
            )
            results = persist(props_buffer, ids_buffer)
            props_buffer, ids_buffer = results
            ids_buffer = ids_buffer.compute().values.reshape(-1)

            # Get valid IDs (non-zero); zeros are padding from the fixed-size buffers
            valid_ids_mask = ids_buffer > 0

            # Check if we have any valid IDs before stacking
            if np.any(valid_ids_mask):
                ids = ids_buffer[valid_ids_mask]
                props = props_buffer.stack(combined=(self.timedim, "out_id")).isel(combined=valid_ids_mask)
            else:  # pragma: no cover
                # No valid IDs found
                ids = np.array([], dtype=np.int32)
                props = xr.DataArray(np.zeros((3, 0), dtype=np.float32), dims=["prop", "out_id"])

        # Create object properties dataset
        if len(ids) > 0:
            object_props = (
                xr.Dataset(
                    {
                        "area": ("out_id", props.isel(prop=0).data),
                        "centroid-0": ("out_id", props.isel(prop=1).data),
                        "centroid-1": ("out_id", props.isel(prop=2).data),
                    },
                    coords={"ID": ("out_id", ids)},
                )
                .set_index(out_id="ID")
                .rename({"out_id": "ID"})
            )
        else:  # pragma: no cover
            # Create empty dataset with correct structure
            object_props = xr.Dataset(
                {
                    "area": ("ID", []),
                    "centroid-0": ("ID", []),
                    "centroid-1": ("ID", []),
                },
                coords={"ID": []},
            )

    else:  # Structured grid approach
        # N.B.: These operations are simply done on a pixel grid
        #       i.e. with no cartesian conversion
        #       (therefore, polar regions are doubly biased)

        # Define function to calculate properties for each chunk
        def object_properties_chunk(
            ids: NDArray[np.int32],
        ) -> Dict[str, List[Union[int, float]]]:
            """Calculate object properties for a chunk of data."""
            # Use regionprops_table for standard properties
            props_slice = regionprops_table(ids, properties=properties)

            # Handle centroid calculation for objects that wrap around edges
            if check_centroids and not self.regional_mode and len(props_slice["label"]) > 0:
                # Get original centroids
                centroids = list(zip(props_slice["centroid-0"], props_slice["centroid-1"]))
                centroids_wrapped = []

                # Process each object
                for ID_idx, ID in enumerate(props_slice["label"]):
                    binary_mask = ids == ID
                    centroids_wrapped.append(self.calculate_centroid(binary_mask, centroids[ID_idx]))

                # Update centroid values
                props_slice["centroid-0"] = [c[0] for c in centroids_wrapped]
                props_slice["centroid-1"] = [c[1] for c in centroids_wrapped]

            return props_slice

        # Process single time or multiple times
        # If time dimension doesn't exist, treat as single time slice
        if self.timedim not in object_id_field.dims or object_id_field.sizes[self.timedim] == 1:
            # Drop a length-1 time axis so regionprops sees a 2-D (y, x) image;
            # a 3-D input would yield (time, y, x) centroids and a 3-D mask
            single_slice = object_id_field
            if self.timedim in single_slice.dims:
                single_slice = single_slice.isel({self.timedim: 0})
            object_props = object_properties_chunk(single_slice.values)
            object_props = xr.Dataset({key: (["ID"], value) for key, value in object_props.items()})
        else:
            # Run in parallel: one dict of property lists per time step
            object_props = xr.apply_ufunc(
                object_properties_chunk,
                object_id_field,
                input_core_dims=[[self.ydim, self.xdim]],
                output_core_dims=[[]],
                output_dtypes=[object],
                vectorize=True,
                dask="parallelized",
            )

            # Concatenate and convert to dataset
            object_props = xr.concat(
                [xr.Dataset({key: (["ID"], value) for key, value in item.items()}) for item in object_props.values],
                dim="ID",
            )

        # Set ID as coordinate
        object_props = object_props.set_index(ID="label")

    # Combine centroid components into a single variable
    if "centroid" in properties and "centroid-0" in object_props and "centroid-1" in object_props:
        object_props["centroid"] = xr.concat(
            [object_props["centroid-0"], object_props["centroid-1"]],
            dim="component",
        )
        object_props = object_props.drop_vars(["centroid-0", "centroid-1"])

    return object_props
# ============================ # Overlap and Tracking Methods # ============================
def check_overlap_slice(self, ids_t0: NDArray[np.int32], ids_next: NDArray[np.int32]) -> NDArray[Union[np.float32, np.int32]]:
    """
    Find overlapping objects between two consecutive time slices.

    Parameters
    ----------
    ids_t0 : numpy.ndarray
        Object IDs at current time (0 = no object)
    ids_next : numpy.ndarray
        Object IDs at next time (0 = no object)

    Returns
    -------
    numpy.ndarray
        Array of shape (n_overlaps, 3) with [id_t0, id_next, overlap_area].
        The overlap is a pixel count on structured grids and a summed cell
        area on unstructured grids.
    """
    # Cells carrying an object at both times
    both_labelled = (ids_t0 > 0) & (ids_next > 0)

    if not both_labelled.any():
        empty_dtype = np.float32 if self.unstructured_grid else np.int32
        return np.empty((0, 3), dtype=empty_dtype)

    prev_ids = ids_t0[both_labelled]
    next_ids = ids_next[both_labelled]

    # Encode each (prev, next) pair as one int64 key: prev * radix + next.
    # radix exceeds every next-ID, so the encoding is reversible; a 1-D
    # np.unique on keys is cheaper than a row-wise unique on pairs.
    radix = max(ids_t0.max(), ids_next.max() + 1).astype(np.int64)
    keys = prev_ids.astype(np.int64) * radix + next_ids.astype(np.int64)

    if self.unstructured_grid:
        # Accumulate cell areas per distinct pair
        unique_keys, inverse = np.unique(keys, return_inverse=True)
        inverse = inverse.astype(np.int32)  # Ensure int32 for serialisation
        cell_weights = self.cell_area.values[both_labelled]
        overlap = np.zeros(len(unique_keys), dtype=np.float32)
        np.add.at(overlap, inverse, cell_weights)
    else:
        # Pixel count per distinct pair
        unique_keys, counts = np.unique(keys, return_counts=True)
        overlap = counts.astype(np.int32)

    # Decode the keys back into the two ID columns
    decoded_prev, decoded_next = np.divmod(unique_keys, radix)

    return np.column_stack(
        (
            decoded_prev.astype(np.int32),
            decoded_next.astype(np.int32),
            overlap,
        )
    )
def find_overlapping_objects(self, object_id_field: xr.DataArray) -> NDArray[Union[np.float32, np.int32]]:
    """
    Find all overlapping objects across time.

    Parameters
    ----------
    object_id_field : xarray.DataArray
        Field containing object IDs

    Returns
    -------
    overlap_objects_list_unique : (N x 3) numpy.ndarray
        Array of object ID pairs that overlap across time, with overlap area
        The object in the first column precedes the second column in time.
        The third column contains:

        * For structured grid: number of overlapping pixels (int32)
        * For unstructured grid: total overlapping area in m^2 (float32)
    """
    # Check just for overlap with next time slice.
    #  Keep a running list of all object IDs that overlap
    # (the last time step is compared against an all-zero field, i.e. no overlaps)
    object_id_field_next = object_id_field.shift({self.timedim: -1}, fill_value=0)

    # Calculate overlaps in parallel
    input_dims = [self.xdim] if self.unstructured_grid else [self.ydim, self.xdim]
    overlap_object_pairs_list = xr.apply_ufunc(
        self.check_overlap_slice,
        object_id_field,
        object_id_field_next,
        input_core_dims=[input_dims, input_dims],
        output_core_dims=[[]],
        vectorize=True,
        dask="parallelized",
        output_dtypes=[object],
    ).persist()

    # Concatenate all pairs from different chunks
    all_pairs_with_areas = np.concatenate(overlap_object_pairs_list.values)

    output_dtype = np.float32 if self.unstructured_grid else np.int32

    # Nothing overlaps anywhere: return an explicitly shaped empty result
    if len(all_pairs_with_areas) == 0:
        return np.empty((0, 3), dtype=output_dtype)

    # Get unique pairs and their indices
    unique_pairs, inverse_indices = np.unique(all_pairs_with_areas[:, :2], axis=0, return_inverse=True)
    # Flatten: some NumPy 2.0.x releases return a 2-D inverse when `axis` is given,
    # which would not line up with the 1-D area column in np.add.at below
    inverse_indices = inverse_indices.reshape(-1).astype(np.int32)  # Ensure int32 for serialisation

    # Sum the overlap areas using the inverse indices
    total_summed_areas = np.zeros(len(unique_pairs), dtype=output_dtype)
    np.add.at(total_summed_areas, inverse_indices, all_pairs_with_areas[:, 2])

    # Stack the pairs with their summed areas
    overlap_objects_list_unique = np.column_stack((unique_pairs, total_summed_areas))

    return overlap_objects_list_unique
def enforce_overlap_threshold(
    self,
    overlap_objects_list: NDArray[Union[np.float32, np.int32]],
    object_props: xr.Dataset,
) -> NDArray[Union[np.float32, np.int32]]:
    """
    Filter object pairs based on overlap threshold.

    A pair is kept when its overlap, divided by the area of the smaller of the
    two objects, is at least ``self.overlap_threshold``. Pairs that reference
    an ID missing from ``object_props`` are discarded.

    Parameters
    ----------
    overlap_objects_list : (N x 3) numpy.ndarray
        Array of object ID pairs with overlap area
    object_props : xarray.Dataset
        Object properties including area

    Returns
    -------
    overlap_objects_list_filtered : (M x 3) numpy.ndarray
        Filtered array of object ID pairs that meet the overlap threshold
    """
    empty_dtype = np.float32 if self.unstructured_grid else np.int32

    if len(overlap_objects_list) == 0:
        return np.empty((0, 3), dtype=empty_dtype)

    # Filter out overlaps where either ID doesn't exist in object_props
    # (vectorised membership test on both ID columns at once)
    existing_ids = np.asarray(object_props.ID.values)
    valid_mask = np.isin(overlap_objects_list[:, :2], existing_ids).all(axis=1)

    if not np.any(valid_mask):
        return np.empty((0, 3), dtype=empty_dtype)

    valid_overlaps = overlap_objects_list[valid_mask]

    # Calculate overlap fractions relative to the smaller object of each pair
    areas_0 = object_props["area"].sel(ID=valid_overlaps[:, 0]).values
    areas_1 = object_props["area"].sel(ID=valid_overlaps[:, 1]).values
    min_areas = np.minimum(areas_0, areas_1)
    overlap_fractions = valid_overlaps[:, 2].astype(float) / min_areas

    # An overlap cannot exceed the smaller object; flag it if the inputs disagree
    if np.any(overlap_fractions > 1.0):
        logger.warning(f"Found {np.sum(overlap_fractions > 1.0)} overlap fractions > 1.0")
        logger.warning(f"Max overlap fraction: {overlap_fractions.max()}")

    # Filter by threshold
    threshold_mask = overlap_fractions >= self.overlap_threshold
    overlap_objects_list_filtered = valid_overlaps[threshold_mask]

    return overlap_objects_list_filtered
def consolidate_object_ids(
    self, data_t_minus_2: xr.DataArray, data_t_minus_1: xr.DataArray, object_props: xr.Dataset, timestep: int
) -> Tuple[xr.DataArray, xr.Dataset]:
    """
    Consolidate object IDs between t-2 and t-1 to ensure consistent tracking.

    When one object at t-2 overlaps several objects at t-1 (above the overlap
    threshold), all of those t-1 objects are relabelled to the ID of the first
    one, the redundant IDs are removed from ``object_props`` and the surviving
    object's area and centroid are recalculated.

    Parameters
    ----------
    data_t_minus_2 : xr.DataArray
        Object field at timestep t-2
    data_t_minus_1 : xr.DataArray
        Object field at timestep t-1 (a relabelled copy is returned)
    object_props : xr.Dataset
        Object properties dataset (an updated dataset is returned)
    timestep : int
        Current timestep number for logging purposes

    Returns
    -------
    data_t_minus_1_consolidated : xr.DataArray
        Updated t-1 field with consolidated IDs
    object_props_updated : xr.Dataset
        Updated object properties with merged/deleted objects

    Notes
    -----
    - Uses self.overlap_threshold for determining consolidation eligibility
    - Updates object properties by recalculating for consolidated objects
    - Removes redundant child objects from object_props
    """
    # Find overlaps between t-2 and t-1
    backward_overlaps = self.check_overlap_slice(data_t_minus_2.values, data_t_minus_1.values)

    if len(backward_overlaps) == 0:
        return data_t_minus_1, object_props

    backward_overlaps = self.enforce_overlap_threshold(backward_overlaps, object_props)

    if len(backward_overlaps) == 0:  # pragma: no cover
        return data_t_minus_1, object_props

    # Find parent IDs that connect to multiple children (partition boundary jumps)
    parent_ids, parent_counts = np.unique(backward_overlaps[:, 0], return_counts=True)
    splitting_parents = parent_ids[parent_counts > 1]

    if len(splitting_parents) == 0:
        return data_t_minus_1, object_props

    # Track ID mappings for logging
    id_mappings = {}  # child_id -> surviving (first) child_id

    for parent_id in splitting_parents:
        # Skip if parent doesn't exist in properties
        if parent_id not in object_props.ID.values:
            continue

        # Get all children for this parent
        child_mask = backward_overlaps[:, 0] == parent_id
        children_for_parent = backward_overlaps[child_mask, 1].astype(int)

        # Consolidate all children to use first child_id
        if len(children_for_parent) > 1:
            first_child_id = int(children_for_parent[0])

            # Skip if first child doesn't exist in properties
            # (e.g. it was already merged away while handling an earlier parent)
            if first_child_id not in object_props.ID.values:
                continue

            # Rename all other children to first_child_id
            for child_id in children_for_parent[1:]:
                child_id = int(child_id)

                # Skip if child doesn't exist in properties
                if child_id not in object_props.ID.values:
                    continue

                # Rename child_id to first_child_id in data_t_minus_1
                data_t_minus_1 = data_t_minus_1.where(data_t_minus_1 != child_id, first_child_id)

                # Remove redundant child_id from object_props
                # (existence was verified just above)
                object_props = object_props.drop_sel(ID=child_id)

                # Track the mapping
                id_mappings[child_id] = first_child_id

            # Recalculate properties for the consolidated object
            consolidated_mask = data_t_minus_1 == first_child_id
            if consolidated_mask.any():
                # Create temporary field with only this object for property calculation
                temp_field = xr.where(consolidated_mask, first_child_id, 0)
                consolidated_props = self.calculate_object_properties(temp_field, properties=["area", "centroid"])

                if first_child_id in consolidated_props.ID:
                    # Update first child properties with consolidated values
                    for var_name in ["area", "centroid"]:
                        if var_name in consolidated_props:
                            object_props[var_name].loc[{"ID": first_child_id}] = consolidated_props[var_name].sel(
                                ID=first_child_id
                            )

    # Report what was merged (the mapping and timestep exist for this purpose)
    if id_mappings:
        logger.debug(f"Timestep {timestep}: consolidated {len(id_mappings)} object ID(s) (child -> kept): {id_mappings}")

    return data_t_minus_1, object_props
def compute_id_time_dict(
    self,
    da: xr.DataArray,
    child_objects: Union[List[int], NDArray[np.int32]],
    max_objects: int,
    all_objects: bool = True,
) -> Dict[int, int]:
    """
    Generate lookup table mapping object IDs to their time index.

    Parameters
    ----------
    da : xarray.DataArray
        Field of object IDs
    child_objects : list or array
        Object IDs to include in the dictionary (used when ``all_objects`` is False)
    max_objects : int
        Maximum number of objects
    all_objects : bool, default=True
        Whether to process all objects (IDs ``0 .. max_objects - 1``) or just child_objects

    Returns
    -------
    time_index_map : dict
        Dictionary mapping object IDs to the first time index at which they
        appear. An ID that never appears maps to 0, because ``argmax`` over an
        all-False series returns 0.
    """
    # Estimate max objects per time (generous 100x head-room over the mean;
    # unique_pad fails if a time slice holds more distinct values than this)
    est_objects_per_time_max = int(max_objects / da[self.timedim].shape[0] * 100)

    def unique_pad(x: NDArray[np.int32]) -> NDArray[np.int32]:
        """Extract unique values and pad to fixed size."""
        uniq = np.unique(x)
        result = np.zeros(est_objects_per_time_max, dtype=x.dtype)  # Pad output to maximum size
        result[: len(uniq)] = uniq
        return result

    # Get unique IDs for each time slice
    input_dims = [self.xdim] if self.unstructured_grid else [self.ydim, self.xdim]
    unique_ids_by_time = xr.apply_ufunc(
        unique_pad,
        da,
        input_core_dims=[input_dims],
        output_core_dims=[["unique_values"]],
        dask="parallelized",
        vectorize=True,
        dask_gufunc_kwargs={"output_sizes": {"unique_values": est_objects_per_time_max}},
    )

    # Set up IDs to search for
    if not all_objects:
        # Just search for the specified child objects
        search_ids = xr.DataArray(child_objects, dims=["child_id"], coords={"child_id": child_objects})
    else:
        # Search for all possible IDs
        search_ids = xr.DataArray(
            np.arange(max_objects, dtype=np.int32),
            dims=["child_id"],
            coords={"child_id": np.arange(max_objects, dtype=np.int32)},
        ).chunk(
            {"child_id": 10000}
        )  # Chunk for better parallelism

    # Find the first time index where each ID appears
    time_indices = (
        (unique_ids_by_time == search_ids).any(dim=["unique_values"]).argmax(dim=self.timedim).compute().astype(np.int32)
    )

    # Convert to dictionary for fast lookup.
    # Pull plain Python scalars out once instead of iterating the DataArrays,
    # which would allocate a new xarray object per ID.
    child_ids = time_indices["child_id"].values.tolist()
    first_times = time_indices.values.tolist()
    time_index_map = {int(id_val): int(idx) for id_val, idx in zip(child_ids, first_times)}

    return time_index_map
# ============================ # Event Tracking Methods # ============================
def track_objects(self, data_bin: xr.DataArray) -> Tuple[xr.Dataset, xr.Dataset, int]:
    """
    Track objects through time to form events.

    Pipeline: label objects per time step, compute their properties, resolve
    splits and merges, then cluster linked objects into events and relabel.

    Parameters
    ----------
    data_bin : xarray.DataArray
        Preprocessed binary data in which connected regions are identified.

    Returns
    -------
    split_merged_events_ds : xarray.Dataset
        Dataset containing tracked events
    merge_events : xarray.Dataset
        Dataset with merge information
    N_events : int
        Final number of events (maximum value of the final ID field)
    """
    # --- 1. Label objects independently at every time step ---
    identified = self.identify_objects(data_bin, time_connectivity=False)
    object_id_field = identified[0].persist()
    del data_bin
    logger.info("Finished object identification")

    # --- 2. Unstructured grids: IDs restart each time step, so offset them ---
    if self.unstructured_grid:
        # Offset for a time step = total number of IDs used by all earlier steps
        per_step_max = object_id_field.max(dim=self.xdim)
        cumsum_ids = per_step_max.cumsum(self.timedim).shift({self.timedim: 1}, fill_value=0)
        shifted_ids = object_id_field + cumsum_ids
        object_id_field = xr.where(object_id_field > 0, shifted_ids, 0)
        object_id_field = self.refresh_dask_graph(object_id_field)
        logger.info(f"Finished assigning c. {cumsum_ids.max().compute().values} globally unique object IDs")

    # --- 3. Per-object area and centroid ---
    object_props = self.calculate_object_properties(object_id_field, properties=["area", "centroid"]).persist()
    wait(object_props)
    logger.info("Finished calculating object properties")

    # --- 4. Splitting & merging (parallel implementation for unstructured grids) ---
    if self.unstructured_grid:
        split_and_merge = self.split_and_merge_objects_parallel
    else:
        split_and_merge = self.split_and_merge_objects
    object_id_field, object_props, overlap_objects_list, merge_events = split_and_merge(object_id_field, object_props)
    logger.info("Finished splitting and merging objects")

    # Persist results (This helps avoid block-wise task fusion run_spec issues with dask)
    object_id_field, object_props, overlap_objects_list, merge_events = persist(
        object_id_field, object_props, overlap_objects_list, merge_events
    )

    # --- 5. Cluster linked objects into events with globally unique IDs ---
    split_merged_events_ds = self.cluster_rename_objects_and_props(
        object_id_field, object_props, overlap_objects_list, merge_events
    )

    # --- 6. Rechunk: chunked along time, every other dimension in one piece ---
    chunk_dict = {self.timedim: self.timechunks}
    for whole_dim in ("ID", "component", "sibling_ID", self.xdim):
        chunk_dict[whole_dim] = -1
    if not self.unstructured_grid:
        chunk_dict[self.ydim] = -1
    split_merged_events_ds = split_merged_events_ds.chunk(chunk_dict)
    logger.info("Finished clustering and renaming objects into coherent consistent events")

    # --- 7. Largest event ID is reported as the number of events ---
    N_events = split_merged_events_ds.ID_field.max().compute().data

    return split_merged_events_ds, merge_events, N_events
[docs] def cluster_rename_objects_and_props( self, object_id_field_unique: xr.DataArray, object_props: xr.Dataset, overlap_objects_list: NDArray[np.int32], merge_events: xr.Dataset, ) -> xr.Dataset: """ Cluster the object pairs and relabel to determine final event IDs. Parameters ---------- object_id_field_unique : xarray.DataArray Field of unique object IDs. IDs must not be repeated across time. object_props : xarray.Dataset Properties of each object that also need to be relabeled. overlap_objects_list : (N x 2) numpy.ndarray Array of object ID pairs that indicate which objects are in the same event. The object in the first column precedes the second column in time. merge_events : xarray.Dataset Information about merge events Returns ------- split_merged_events_ds : xarray.Dataset Dataset with relabeled events and their properties. ID = 0 indicates no object. """ # Cluster the overlap_pairs into groups of IDs that are actually the same object # Get IDs from overlap pairs # Step 1: Find all IDs that actually exist in the data max_ID = int(object_id_field_unique.max().compute().values.item()) # Get unique IDs from overlap list if len(overlap_objects_list) > 0: overlap_ids = np.unique(overlap_objects_list[:, :2].flatten()) overlap_ids = overlap_ids[overlap_ids > 0] # Remove 0 (background) else: overlap_ids = np.array([], dtype=np.int32) # pragma: no cover # Get unique IDs from object_id_field field_ids = np.unique(object_id_field_unique.compute().values) field_ids = field_ids[field_ids > 0] # Remove 0 (background) # Combine and get all valid IDs all_valid_ids = np.unique(np.concatenate([overlap_ids, field_ids])) logger.info(f"Found {len(all_valid_ids)} valid object IDs (out of max ID {max_ID})") # Step 2: Create dense mapping: original_ID -> dense_index # This ensures continuous indices for connected_components original_to_dense = {int(original_id): dense_idx for dense_idx, original_id in enumerate(all_valid_ids)} dense_to_original = {dense_idx: int(original_id) for 
original_id, dense_idx in original_to_dense.items()} n_valid = len(all_valid_ids) # Step 3: Convert overlap pairs to dense indices if len(overlap_objects_list) > 0: # Map to dense indices overlap_pairs_dense = np.array( [ [original_to_dense[int(pair[0])], original_to_dense[int(pair[1])]] for pair in overlap_objects_list if int(pair[0]) in original_to_dense and int(pair[1]) in original_to_dense ] ) # Create sparse graph with dense indices row_indices, col_indices = overlap_pairs_dense.T data = np.ones(len(overlap_pairs_dense), dtype=np.bool_) graph = csr_matrix((data, (row_indices, col_indices)), shape=(n_valid, n_valid), dtype=np.bool_) else: graph = csr_matrix((n_valid, n_valid), dtype=np.bool_) # pragma: no cover # Step 4: Solve for connected components (on dense graph) num_components, component_IDs_dense = connected_components(csgraph=graph, directed=False, return_labels=True) logger.info(f"Identified {num_components} connected components (events)") # Step 5: Create lookup from original IDs to event IDs # Event IDs will be continuous: 1, 2, 3, ... 
num_components original_to_event = {} for dense_idx, event_id in enumerate(component_IDs_dense): original_id = dense_to_original[dense_idx] original_to_event[original_id] = event_id + 1 # +1 so events start at 1, not 0 # Step 6: Create full lookup array for fast remapping ID_to_cluster_index_array = np.full(max_ID + 1, 0, dtype=np.int32) # 0 = background for original_id, event_id in original_to_event.items(): ID_to_cluster_index_array[original_id] = np.int32(event_id) # Convert to DataArray for apply_ufunc # N.B.: **Need to pass da into apply_ufunc, otherwise it doesn't manage the memory correctly # with large shared-mem numpy arrays** ID_to_cluster_index_da = xr.DataArray( ID_to_cluster_index_array, dims="ID", coords={"ID": np.arange(max_ID + 1, dtype=np.int32)}, ) def map_IDs_to_indices(block: NDArray[np.int32], ID_to_cluster_index_array: NDArray[np.int32]) -> NDArray[np.int32]: """Map original IDs to cluster indices.""" mask = block > 0 new_block = np.zeros_like(block, dtype=np.int32) new_block[mask] = ID_to_cluster_index_array[block[mask]] return new_block # Apply the mapping input_dims = [self.xdim] if self.unstructured_grid else [self.ydim, self.xdim] split_merged_relabeled_object_id_field = xr.apply_ufunc( map_IDs_to_indices, object_id_field_unique, ID_to_cluster_index_da, input_core_dims=[input_dims, ["ID"]], output_core_dims=[input_dims], vectorize=True, dask="parallelized", output_dtypes=[np.int32], ).persist() # Relabel the object_props to match the new IDs (and add time dimension) max_new_ID = num_components + 1 # New IDs range from 0 to max_new_ID new_ids = np.arange(1, max_new_ID + 1, dtype=np.int32) # Create new object_props dataset - use dimension coordinate for time data time_coord_data = object_id_field_unique.coords[self.timedim].data object_props_extended = xr.Dataset(coords={"ID": new_ids, self.timecoord: (self.timedim, time_coord_data)}) # Create mapping from new IDs to the original IDs _at the corresponding time_ valid_new_ids = 
split_merged_relabeled_object_id_field > 0 original_ids_field = object_id_field_unique.where(valid_new_ids) new_ids_field = split_merged_relabeled_object_id_field.where(valid_new_ids) if not self.unstructured_grid: original_ids_field = original_ids_field.stack(z=(self.ydim, self.xdim), create_index=False) new_ids_field = new_ids_field.stack(z=(self.ydim, self.xdim), create_index=False) new_id_to_idx = {id_val: idx for idx, id_val in enumerate(new_ids)} def process_timestep(orig_ids: NDArray[np.int32], new_ids_t: NDArray[np.int32]) -> NDArray[np.int32]: """Process a single timestep to create ID mapping.""" result = np.zeros(len(new_id_to_idx), dtype=np.int32) valid_mask = new_ids_t > 0 # Get valid points for this timestep if not valid_mask.any(): return result orig_valid = orig_ids[valid_mask] new_valid = new_ids_t[valid_mask] if len(orig_valid) == 0: return result unique_pairs = np.unique(np.column_stack((orig_valid, new_valid)), axis=0) # Create mapping for orig_id, new_id in unique_pairs: if new_id in new_id_to_idx: result[new_id_to_idx[new_id]] = orig_id return result # Process in parallel input_dim = [self.xdim] if self.unstructured_grid else ["z"] global_id_mapping = ( xr.apply_ufunc( process_timestep, original_ids_field, new_ids_field, input_core_dims=[input_dim, input_dim], output_core_dims=[["ID"]], vectorize=True, dask="parallelized", output_dtypes=[np.int32], dask_gufunc_kwargs={"output_sizes": {"ID": len(new_ids)}}, ) .assign_coords(ID=new_ids) .compute() ) # Store original ID mapping object_props_extended["global_ID"] = global_id_mapping # Post-condition: Now, e.g. 
global_id_mapping.sel(ID=10) # --> Given the new ID (10), returns corresponding original_id at every time # Transfer all properties from original object_props dummy = object_props.isel(ID=0) * np.nan # Add vale of ID = 0 to this coordinate ID object_props = xr.concat([dummy.assign_coords(ID=0), object_props], dim="ID") for var_name in object_props.data_vars: # Filter global_id_mapping to only include IDs that exist in object_props existing_ids = set(object_props.ID.values) valid_mapping_mask = global_id_mapping.isin(existing_ids) # Only select existing IDs valid_global_mapping = global_id_mapping.where(valid_mapping_mask, drop=True) if len(valid_global_mapping.ID) == 0: # No valid IDs - create empty result temp = object_props[var_name].isel(ID=slice(0, 0)) else: temp = ( object_props[var_name] .sel(ID=valid_global_mapping.rename({"ID": "new_id"})) .drop_vars("ID") .rename({"new_id": "ID"}) ) if var_name == "ID": temp = temp.astype(np.int32) else: temp = temp.astype(np.float32) object_props_extended[var_name] = temp # Map the merge_events using the old IDs to be from dimensions (merge_ID, parent_idx) # --> new merge_ledger with dimensions (time, ID, sibling_ID) # i.e. for each merge_ID --> merge_parent_IDs gives the old IDs --> map to new ID using ID_to_cluster_index_da # --> merge_time old_parent_IDs = xr.where(merge_events.parent_IDs > 0, merge_events.parent_IDs, 0) new_IDs_parents = ID_to_cluster_index_da.sel(ID=old_parent_IDs) # Replace the coordinate merge_ID in new_IDs_parents with merge_time. 
# merge_events.merge_time gives merge_time for each merge_ID new_IDs_parents_t = ( new_IDs_parents.assign_coords({"merge_time": merge_events.merge_time}) .drop_vars("ID") .swap_dims({"merge_ID": "merge_time"}) .persist() ) # Map new_IDs_parents_t into a new data array with dimensions time, ID, and sibling_ID merge_ledger = ( xr.full_like(global_id_mapping, fill_value=-1) .chunk({self.timedim: self.timechunks}) .expand_dims({"sibling_ID": new_IDs_parents_t.parent_idx.shape[0]}) .copy() ) # Wrapper for processing/mapping mergers in parallel def process_time_group( time_block: xr.DataArray, IDs_data: NDArray[np.int32], IDs_coords: Dict[str, Any], ) -> xr.DataArray: """Process all mergers for a single block of timesteps.""" result = xr.full_like(time_block, -1) # Get unique times in this block # time_block might not have the coordinate, so get it from the dimension index if self.timecoord in time_block.coords: unique_times = np.unique(time_block.coords[self.timecoord]) else: # Fall back to using the dimension index unique_times = np.unique(time_block[self.timedim]) for time_val in unique_times: # Get IDs for this time time_mask = IDs_coords["merge_time"] == time_val if not np.any(time_mask): continue IDs_at_time = IDs_data[time_mask] # Handle single merger case if IDs_at_time.ndim == 1: valid_mask = IDs_at_time > 0 if np.any(valid_mask): # Create expanded array for sibling_ID dimension expanded_IDs = np.broadcast_to(IDs_at_time, (len(time_block.sibling_ID), len(IDs_at_time))) result.loc[{self.timedim: time_val, "ID": IDs_at_time[valid_mask]}] = expanded_IDs[:, valid_mask] # Handle multiple mergers case else: for merger_IDs in IDs_at_time: valid_mask = merger_IDs > 0 if np.any(valid_mask): expanded_IDs = np.broadcast_to( merger_IDs, (len(time_block.sibling_ID), len(merger_IDs)), ) result.loc[{self.timedim: time_val, "ID": merger_IDs[valid_mask]}] = expanded_IDs[:, valid_mask] return result # Map blocks in parallel merge_ledger = xr.map_blocks( process_time_group, 
merge_ledger, args=(new_IDs_parents_t.values, new_IDs_parents_t.coords), template=merge_ledger, ) # Format merge ledger merge_ledger = merge_ledger.rename("merge_ledger").transpose(self.timedim, "ID", "sibling_ID").persist() # Add start and end time indices for each ID valid_presence = object_props_extended["global_ID"] > 0 # i.e. where there is valid data object_props_extended["presence"] = valid_presence object_props_extended["time_start"] = valid_presence[self.timecoord][ valid_presence.argmax(dim=self.timedim).astype(np.int32) ] object_props_extended["time_end"] = valid_presence[self.timecoord][ ((valid_presence.sizes[self.timedim] - 1) - (valid_presence[::-1]).argmax(dim=self.timedim)).astype(np.int32) ] # Recompute area & centroid (now that the IDs have been consolidated & merged & made continuous) if "area" in object_props_extended.data_vars or "centroid" in object_props_extended.data_vars: logger.info("Recalculating area and centroid properties for potentially disjoint events...") def calculate_area_centroid_for_slice( slice_data: NDArray[np.int32], cell_areas_slice: NDArray[np.float32], present_mask: NDArray[np.bool_], all_event_ids: NDArray[np.int32], lat_vals: NDArray[np.float32], lon_vals: NDArray[np.float32], is_unstructured: bool, regional_mode: bool, ) -> Tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]: """ Calculate area and area-weighted centroid for IDs present at this timestep. Returns three arrays with full ID dimension (NaN for absent IDs). 
Parameters ---------- slice_data : array Spatial field of event IDs for this timestep cell_areas_slice : array Spatial field of cell areas present_mask : array 1D boolean array indicating which IDs are present (length = n_IDs) all_event_ids : array All event IDs (length = n_IDs) """ n_ids = len(all_event_ids) # Initialise output arrays with NaN areas = np.full(n_ids, np.nan, dtype=np.float32) centroid_lats = np.full(n_ids, np.nan, dtype=np.float32) centroid_lons = np.full(n_ids, np.nan, dtype=np.float32) # Get indices of IDs that are present at this timestep present_indices = np.where(present_mask)[0] if len(present_indices) == 0: return areas, centroid_lats, centroid_lons if is_unstructured: # Unstructured grid: area-weighted centroid using spherical geometry # Convert to radians for Cartesian calculation lat_rad = np.radians(lat_vals) lon_rad = np.radians(lon_vals) # Process each present ID for id_idx in present_indices: event_id = all_event_ids[id_idx] mask = slice_data == event_id if not np.any(mask): continue # pragma: no cover # Calculate physical area areas_masked = cell_areas_slice[mask] total_area = np.sum(areas_masked) areas[id_idx] = total_area # Calculate area-weighted centroid using spherical geometry cos_lat = np.cos(lat_rad[mask]) x = cos_lat * np.cos(lon_rad[mask]) y = cos_lat * np.sin(lon_rad[mask]) z = np.sin(lat_rad[mask]) # Weighted average in Cartesian coordinates weighted_x = np.sum(areas_masked * x) weighted_y = np.sum(areas_masked * y) weighted_z = np.sum(areas_masked * z) # Normalise norm = np.sqrt(weighted_x**2 + weighted_y**2 + weighted_z**2) if norm > 0: weighted_x /= norm weighted_y /= norm weighted_z /= norm # Convert back to lat/lon centroid_lat = np.degrees(np.arcsin(np.clip(weighted_z, -1, 1))) centroid_lon = np.degrees(np.arctan2(weighted_y, weighted_x)) # Fix longitude range to [-180, 180] if centroid_lon > 180: centroid_lon -= 360 # pragma: no cover elif centroid_lon < -180: centroid_lon += 360 # pragma: no cover 
centroid_lats[id_idx] = centroid_lat centroid_lons[id_idx] = centroid_lon else: # Structured grid: area-weighted centroid with periodic boundary handling ny, nx = slice_data.shape # Process each present ID for id_idx in present_indices: event_id = all_event_ids[id_idx] # Get binary mask for this event binary_mask = slice_data == event_id if not np.any(binary_mask): continue # pragma: no cover # Get indices where object exists y_indices, x_indices = np.nonzero(binary_mask) # Get cell areas for these indices pixel_areas = cell_areas_slice[binary_mask] total_area = np.sum(pixel_areas) areas[id_idx] = total_area # Calculate area-weighted y centroid (latitude) centroid_y_pix = np.sum(y_indices * pixel_areas) / total_area # Calculate area-weighted x centroid (longitude) - handle wrapping if needed if not regional_mode: # Check if object is near both edges (wrapping around periodic boundary) near_left = np.any(x_indices < 100) near_right = np.any(x_indices >= nx - 100) if near_left and near_right: # Object wraps around - adjust coordinates x_adjusted = x_indices.copy().astype(np.float64) right_side = x_indices > nx / 2 x_adjusted[right_side] -= nx # Area-weighted mean with adjusted coordinates centroid_x_pix = np.sum(x_adjusted * pixel_areas) / total_area # Ensure centroid is positive if centroid_x_pix < 0: centroid_x_pix += nx else: # No wrapping - standard area-weighted calculation centroid_x_pix = np.sum(x_indices * pixel_areas) / total_area else: # Regional mode - no wrapping, area-weighted centroid_x_pix = np.sum(x_indices * pixel_areas) / total_area # Convert pixel indices to coordinate values centroid_lat = np.interp(centroid_y_pix, np.arange(len(lat_vals)), lat_vals) centroid_lon = np.interp(centroid_x_pix, np.arange(len(lon_vals)), lon_vals) centroid_lats[id_idx] = centroid_lat centroid_lons[id_idx] = centroid_lon return areas, centroid_lats, centroid_lons # Prepare spatial dimensions spatial_dims = [self.xdim] if self.unstructured_grid else [self.ydim, 
self.xdim] # Ensure cell_area has correct dimensions for apply_ufunc if not self.unstructured_grid and self.cell_area.ndim == 1: # Broadcast 1D latitude-dependent cell areas to 2D (lat, lon) template = split_merged_relabeled_object_id_field.isel({self.timedim: 0}, drop=True) cell_area_broadcast, _ = xr.broadcast(self.cell_area, template) else: cell_area_broadcast = self.cell_area # Apply calculation in parallel across time slices logger.info("Computing area and centroid properties in parallel...") areas_computed, centroid_lats_computed, centroid_lons_computed = xr.apply_ufunc( calculate_area_centroid_for_slice, split_merged_relabeled_object_id_field, cell_area_broadcast, # Broadcasted to match spatial dimensions object_props_extended.presence, # Boolean mask of which IDs are present at each time object_props_extended.ID, self.lat, # Latitude coordinate values self.lon, # Longitude coordinate values kwargs={"is_unstructured": self.unstructured_grid, "regional_mode": self.regional_mode}, input_core_dims=[ spatial_dims, spatial_dims, ["ID"], ["ID"], [self.ydim] if not self.unstructured_grid else [self.xdim], [self.xdim], ], output_core_dims=[["ID"], ["ID"], ["ID"]], vectorize=True, dask="parallelized", output_dtypes=[np.float32, np.float32, np.float32], ) results = persist(areas_computed, centroid_lats_computed, centroid_lons_computed) areas_computed, centroid_lats_computed, centroid_lons_computed = results # Update area with proper dimension ordering (time, ID) object_props_extended["area"] = areas_computed.transpose(self.timedim, "ID") # Combine lat/lon centroids along component dimension new_centroid = xr.concat([centroid_lats_computed, centroid_lons_computed], dim="component") new_centroid = new_centroid.assign_coords(component=[0, 1]) # Update centroid with proper dimension ordering (component, time, ID) object_props_extended["centroid"] = new_centroid.transpose("component", self.timedim, "ID") logger.info("Property recalculation complete.") # Combine all 
components into final dataset split_merged_relabeled_events_ds = xr.merge( [ split_merged_relabeled_object_id_field.rename("ID_field"), object_props_extended, merge_ledger, ] ) # Remove the last ID -- it is all 0s (because we added an extra padding one above) return split_merged_relabeled_events_ds.isel(ID=slice(0, -1))
# ============================ # Splitting and Merging Methods # ============================
def split_and_merge_objects(
    self, object_id_field_unique: xr.DataArray, object_props: xr.Dataset
) -> Tuple[xr.DataArray, xr.Dataset, NDArray[np.int32], xr.Dataset]:
    """
    Implement object splitting and merging logic.

    The ID field is walked one time chunk at a time and, inside each chunk, one
    timestep at a time. For every timestep ``t``:

    1. IDs at ``t-1`` are consolidated via ``consolidate_object_ids`` (using ``t-2``
       as reference) whenever ``t-1`` lies in the chunk currently held in memory.
    2. Overlaps between parents (``t-1``) and children (``t``) are computed and
       filtered with ``enforce_overlap_threshold``.
    3. Every child that overlaps more than one parent is partitioned into one piece
       per parent. The first piece keeps the child's ID, each further piece gets a
       newly allocated ID. Partitioning is nearest-neighbour based when
       ``self.nn_partitioning`` is set, otherwise centroid based.
    4. ``object_props`` is updated for the relabelled pieces and the merge is recorded.
    5. Steps 2-4 are repeated (at most 10 passes) until no child at ``t`` has more
       than one parent.

    After all chunks are processed the overlaps are recomputed for the whole field
    and children that still have several parents are reported through log warnings
    (no exception is raised).

    Parameters
    ----------
    object_id_field_unique : xarray.DataArray
        Field of unique object IDs. IDs are required to be monotonically increasing with time.
        Chunk boundaries are read from ``.chunks[0]`` and applied along ``self.timedim``,
        i.e. time is expected to be the leading (dask-chunked) dimension.
    object_props : xarray.Dataset
        Properties of each object (indexed by ``ID``; ``area`` and ``centroid`` are read here)

    Returns
    -------
    tuple
        (object_id_field, object_props, overlap_objects_list, merge_events)

        * ``object_id_field`` : the relabelled ID field (persisted)
        * ``object_props`` : updated object properties (persisted)
        * ``overlap_objects_list`` : final overlap pairs, first two columns only
          (parent ID, child ID)
        * ``merge_events`` : dataset over ``merge_ID`` with ``parent_IDs``, ``child_IDs``,
          ``overlap_areas``, ``merge_time``, ``n_parents`` and ``n_children``;
          ragged entries are padded with ``-1`` (see ``attrs["fill_value"]``)
    """
    # Upper bound on relabelling passes per timestep (guards against infinite loops)
    max_iterations = 10

    # Find overlapping objects
    overlap_objects_list = self.find_overlapping_objects(
        object_id_field_unique
    )  # List object pairs that overlap by at least overlap_threshold percent
    overlap_objects_list = self.enforce_overlap_threshold(overlap_objects_list, object_props)
    logger.info("Finished finding overlapping objects")

    # Initialise merge tracking lists
    merge_times = []  # When the merge occurred
    merge_child_ids = []  # Resulting child ID
    merge_parent_ids = []  # List of parent IDs that merged
    merge_areas = []  # Areas of overlap
    next_new_id = int(object_props.ID.max().item()) + 1  # Start new IDs after highest existing ID

    Nx = object_id_field_unique[self.xdim].size
    object_id_field_unique = object_id_field_unique.persist()

    # Processed chunks not yet written back, as (start, end, data) tuples.
    # The most recent one is always kept: the next chunk needs its last two timesteps.
    updated_chunks = []

    # Process each time chunk with timestep-first approach
    time_chunks = object_id_field_unique.chunks[0]
    n_chunks = len(time_chunks)
    chunk_boundaries = np.cumsum([0] + list(time_chunks))

    for chunk_idx in range(n_chunks):
        # Extract and load an entire chunk into memory
        chunk_start = chunk_boundaries[chunk_idx]
        chunk_end = chunk_boundaries[chunk_idx + 1]
        # Ensure we don't exceed array bounds
        chunk_end = min(chunk_end, object_id_field_unique.sizes[self.timedim])

        chunk_data = object_id_field_unique.isel({self.timedim: slice(chunk_start, chunk_end)}).compute()

        # Process each timestep within chunk sequentially
        for relative_t in range(chunk_data.sizes[self.timedim]):
            absolute_t = chunk_start + relative_t

            # Get data slices for current timestep
            data_t = chunk_data.isel({self.timedim: relative_t})

            # Get previous timesteps for consolidation and partitioning
            if relative_t > 1:
                # Need both t-1 and t-2 for consolidation
                data_t_minus_2 = chunk_data.isel({self.timedim: relative_t - 2})
                data_t_minus_1 = chunk_data.isel({self.timedim: relative_t - 1})
            elif relative_t == 1:
                # t-1 is in current chunk, t-2 might be in previous chunk
                data_t_minus_1 = chunk_data.isel({self.timedim: 0})  # relative_t - 1 = 0
                if updated_chunks:
                    _, _, last_chunk_data = updated_chunks[-1]
                    data_t_minus_2 = last_chunk_data[-1]  # Last timestep from previous chunk
                else:
                    data_t_minus_2 = xr.full_like(data_t, 0)
            else:
                # relative_t == 0, get both from previous chunk if available
                if updated_chunks:
                    _, _, last_chunk_data = updated_chunks[-1]
                    if len(last_chunk_data) >= 2:
                        data_t_minus_2 = last_chunk_data[-2]
                        data_t_minus_1 = last_chunk_data[-1]
                    elif len(last_chunk_data) == 1:
                        data_t_minus_2 = xr.full_like(data_t, 0)
                        data_t_minus_1 = last_chunk_data[-1]
                    else:
                        data_t_minus_2 = xr.full_like(data_t, 0)
                        data_t_minus_1 = xr.full_like(data_t, 0)
                else:
                    # Very first timestep of the record: no history, use all-background fields
                    data_t_minus_2 = xr.full_like(data_t, 0)
                    data_t_minus_1 = xr.full_like(data_t, 0)

            # ID Consolidation of objects at t-1
            if relative_t > 0:  # Only consolidate if we have meaningful t-1 and t-2
                data_t_minus_1, object_props = self.consolidate_object_ids(
                    data_t_minus_2, data_t_minus_1, object_props, absolute_t - 1
                )
                # Update the chunk with consolidated data whenever t-1 is in current chunk
                chunk_data[{self.timedim: relative_t - 1}] = data_t_minus_1

            # Normal overlap detection and partitioning (now with consolidated IDs)
            # Calculate overlaps for this timestep
            # Here, parents are at previous time=t-1 (LHS), children are at current time=t (RHS)
            timestep_overlaps = self.check_overlap_slice(data_t_minus_1.values, data_t.values)
            timestep_overlaps = self.enforce_overlap_threshold(timestep_overlaps, object_props)

            # Iterative processing within timestep=t until convergence
            # Only modifies data_t, which contains the children to be partitioned/relabelled
            timestep_converged = False
            iteration = 0

            while not timestep_converged and iteration < max_iterations:  # Prevent infinite loops
                # Find merging objects for current timestep
                unique_children, children_counts = np.unique(timestep_overlaps[:, 1], return_counts=True)
                merging_children = unique_children[children_counts > 1]

                if len(merging_children) == 0:
                    timestep_converged = True
                    continue

                # Process all merging objects in this timestep
                # Parents live at t-1 and are left untouched; only the children at t are relabelled
                for child_id in merging_children:
                    # Get mask of child object
                    child_mask_2d = (data_t == child_id).values

                    # Find all pairs involving this child
                    child_mask = timestep_overlaps[:, 1] == child_id
                    child_where = np.flatnonzero(child_mask)  # Row indices of those pairs
                    merge_group = timestep_overlaps[child_mask]

                    # Get parent objects (LHS) that overlap with this child object
                    parent_ids = merge_group[:, 0]
                    num_parents = len(parent_ids)

                    # Create new IDs for the other half of the child object & record in the merge ledger
                    new_object_id = np.arange(next_new_id, next_new_id + (num_parents - 1), dtype=np.int32)
                    next_new_id += num_parents - 1

                    # Replace the 2nd+ child in the overlap objects list with the new child ID
                    timestep_overlaps[child_where[1:], 1] = new_object_id
                    child_ids = np.concatenate((np.array([child_id]), new_object_id))

                    # Record merge event - extract time value using dimension name
                    merge_times.append(data_t.coords[self.timedim].values)
                    merge_child_ids.append(child_ids)
                    merge_parent_ids.append(parent_ids)
                    merge_areas.append(timestep_overlaps[child_mask, 2])

                    # Relabel the Original Child Object ID Field to account for the New ID:
                    # Get parent centroids for partitioning
                    parent_centroids = object_props.sel(ID=parent_ids).centroid.values.T

                    # Partition the child object based on parent associations
                    if self.nn_partitioning:
                        # Nearest-neighbor partitioning
                        # --> For every (Original) Child Cell in the ID Field, Find the closest (t-1) Parent _Cell_
                        if self.unstructured_grid:
                            # Prepare parent masks
                            parent_masks = np.zeros((len(parent_ids), data_t_minus_1.shape[0]), dtype=bool)
                            for idx, parent_id in enumerate(parent_ids):
                                parent_masks[idx] = (data_t_minus_1 == parent_id).values

                            # Calculate maximum search distance
                            max_area = np.max(object_props.sel(ID=parent_ids).area.values) / self.mean_cell_area
                            max_distance = int(np.sqrt(max_area) * 2.0)

                            # Use optimised unstructured partitioning
                            new_labels = partition_nn_unstructured(
                                child_mask_2d,
                                parent_masks,
                                child_ids,
                                parent_centroids,
                                self.neighbours_int.values,
                                self.lat.values,  # Need to pass these as NumPy arrays for JIT compatibility
                                self.lon.values,
                                max_distance=max(max_distance, 20) * 2,  # Set minimum threshold, in cells
                            )
                        else:
                            # Prepare parent masks for structured grid
                            parent_masks = np.zeros(
                                (
                                    len(parent_ids),
                                    data_t_minus_1.shape[0],
                                    data_t_minus_1.shape[1],
                                ),
                                dtype=bool,
                            )
                            for idx, parent_id in enumerate(parent_ids):
                                parent_masks[idx] = (data_t_minus_1 == parent_id).values

                            # Calculate maximum search distance
                            max_area = np.max(object_props.sel(ID=parent_ids).area.values)
                            max_distance = int(np.sqrt(max_area) * 3.0)  # Use 3x the max blob radius

                            # Use optimised structured grid partitioning
                            new_labels = partition_nn_grid(
                                child_mask_2d,
                                parent_masks,
                                child_ids,
                                parent_centroids,
                                Nx,
                                max_distance=max(max_distance, 40),  # Set minimum threshold, in cells
                                wrap=not self.regional_mode,  # Turn longitude periodic wrapping off when in regional mode
                            )
                    else:
                        # Centroid-based partitioning
                        # --> For every (Original) Child Cell in the ID Field, Find the closest (t-1) Parent _Centroid_
                        if self.unstructured_grid:
                            new_labels = partition_centroid_unstructured(
                                child_mask_2d,
                                parent_centroids,
                                child_ids,
                                self.lat.values,
                                self.lon.values,
                            )
                        else:
                            # Calculate distances to each parent centroid
                            distances = wrapped_euclidian_distance_mask_parallel(
                                child_mask_2d, parent_centroids, Nx, not self.regional_mode
                            )

                            # Assign based on closest parent
                            new_labels = child_ids[np.argmin(distances, axis=1).astype(np.int32)]

                    # Update values in data_t and assign the updated slice back to the chunk
                    temp = np.zeros_like(data_t)
                    temp[child_mask_2d] = new_labels
                    data_t = data_t.where(~child_mask_2d, temp)
                    chunk_data[{self.timedim: relative_t}] = data_t

                    # Update the Properties of the N Children Objects
                    new_child_props = self.calculate_object_properties(data_t, properties=["area", "centroid"])

                    # Update the object_props DataArray: (but first, check if the original children still exists)
                    if child_id in new_child_props.ID:
                        # Update existing entry
                        object_props.loc[{"ID": child_id}] = new_child_props.sel(ID=child_id)
                    else:
                        # Delete child_id: The object has split/morphed such that it doesn't get a partition of this child...
                        object_props = object_props.drop_sel(
                            ID=child_id
                        )  # N.B.: This means that the IDs are no longer continuous...
                        logger.info(f"Deleted child_id {child_id} because parents have split/morphed")

                    # Add the properties for the N-1 other new child ID
                    # (only those new IDs that actually received cells in the partition)
                    new_object_ids_still = new_child_props.ID.where(new_child_props.ID.isin(new_object_id), drop=True).ID
                    object_props = xr.concat(
                        [object_props, new_child_props.sel(ID=new_object_ids_still)],
                        dim="ID",
                    )

                    missing_ids = set(new_object_id) - set(new_object_ids_still.values)
                    if len(missing_ids) > 0:
                        logger.warning(
                            f"Missing newly created child_ids {missing_ids} "
                            f"because parents have split/morphed in the meantime..."
                        )

                # After processing all merging objects in this iteration
                # Recalculate overlaps to check for newly viable merges
                timestep_overlaps = self.check_overlap_slice(data_t_minus_1.values, data_t.values)
                timestep_overlaps = self.enforce_overlap_threshold(timestep_overlaps, object_props)

                iteration += 1

            if not timestep_converged:
                # The loop stopped on the iteration cap. The last pass may nevertheless have
                # resolved everything, so inspect the freshly recomputed overlaps and only
                # warn when children with more than one parent really remain.
                _, remaining_counts = np.unique(timestep_overlaps[:, 1], return_counts=True)
                if np.any(remaining_counts > 1):
                    logger.warning(
                        f"Resolving mergers at timestep {absolute_t} did not converge after {max_iterations} iterations"
                    )

        # End-of-chunk consolidation: consolidate the last timestep if chunk has multiple timesteps
        # (inside the loop above, a timestep is only consolidated once its successor is visited)
        if chunk_data.sizes[self.timedim] >= 2:
            # Get last and second-to-last timesteps
            last_t_data = chunk_data.isel({self.timedim: -1})
            second_last_t_data = chunk_data.isel({self.timedim: -2})

            # Consolidate last timestep using second-to-last as reference
            consolidated_last, object_props = self.consolidate_object_ids(
                second_last_t_data, last_t_data, object_props, chunk_end - 1
            )

            # Update the last timestep in chunk
            chunk_data[{self.timedim: -1}] = consolidated_last

        # Store the processed chunk
        updated_chunks.append(
            (
                chunk_start,
                chunk_end,
                chunk_data[: (chunk_end - chunk_start)],
            )
        )

        if chunk_idx % 10 == 0:
            logger.info(f"Processing splitting and merging in chunk {chunk_idx} of {n_chunks}")

            # Periodically update main array to manage memory
            if len(updated_chunks) > 1:  # Keep the last chunk for potential reference
                for start, end, processed_chunk_data in updated_chunks[:-1]:
                    object_id_field_unique[{self.timedim: slice(start, end)}] = processed_chunk_data
                updated_chunks = updated_chunks[-1:]  # Keep only the last chunk
                object_id_field_unique = object_id_field_unique.persist()

    # Apply final chunk updates
    for start, end, processed_chunk_data in updated_chunks:
        object_id_field_unique[{self.timedim: slice(start, end)}] = processed_chunk_data
    object_id_field_unique = object_id_field_unique.persist()

    # Recompute final overlapping objects
    overlap_objects_list = self.find_overlapping_objects(object_id_field_unique)
    overlap_objects_list = self.enforce_overlap_threshold(overlap_objects_list, object_props)
    logger.info("Finished final overlapping objects search")

    # Check for duplicate children (multiple parents per child)
    if len(overlap_objects_list) > 0:
        child_ids = overlap_objects_list[:, 1]  # RHS column (children)
        unique_children, child_counts = np.unique(child_ids, return_counts=True)

        # Find children with multiple parents
        duplicate_children = unique_children[child_counts > 1]

        # Enhanced validation with comprehensive spatial and temporal information
        if len(duplicate_children) > 0:
            logger.warning(f"There is {len(duplicate_children)} potentially problematic children:")

            # Log problematic child IDs (time info not available at this stage)
            logger.warning(f"Children IDs: {duplicate_children[:10].tolist()}")

            # Detailed analysis of each problematic child
            for child_id in duplicate_children[:5]:  # Limit to first 5 for readability
                # Find all parent-child relationships for this child
                child_relationships = overlap_objects_list[overlap_objects_list[:, 1] == child_id]
                parent_ids = child_relationships[:, 0]
                overlap_areas = child_relationships[:, 2]

                logger.warning(f"\n--- Details for child ID {child_id} ---")
                logger.warning(f"Number of parents: {len(parent_ids)}")
                logger.warning(f"Parent IDs: {parent_ids.tolist()}")
                logger.warning(f"Raw overlap areas: {overlap_areas.tolist()}")

                # Get child object properties if available
                # (diagnostics are best-effort: any failure is logged and must not abort tracking)
                try:
                    if child_id in object_props.ID.values:
                        child_area = object_props.sel(ID=child_id).area.values.item()
                        child_centroid = object_props.sel(ID=child_id).centroid.values
                        logger.warning(f"Child total area: {child_area}")
                        logger.warning(f"Child centroid: {child_centroid}")

                        # Calculate overlap fractions for each parent
                        overlap_fractions = []
                        parent_areas = []
                        for i, parent_id in enumerate(parent_ids):
                            if parent_id in object_props.ID.values:
                                parent_area = object_props.sel(ID=parent_id).area.values.item()
                                parent_areas.append(parent_area)
                                # Calculate overlap fraction based on smaller object
                                min_area = min(child_area, parent_area)
                                overlap_fraction = float(overlap_areas[i]) / min_area
                                overlap_fractions.append(overlap_fraction)
                            else:
                                parent_areas.append("N/A")
                                overlap_fractions.append("N/A")

                        logger.warning(f"Parent areas: {parent_areas}")
                        logger.warning(f"Overlap fractions: {overlap_fractions}")

                        # Check for suspicious patterns
                        total_overlap_area = sum(overlap_areas)
                        logger.warning(f"Sum of overlap areas: {total_overlap_area}")
                        logger.warning(f"Sum/Child area ratio: {total_overlap_area/child_area:.3f}")

                        # Flag potential issues
                        valid_fractions = [f for f in overlap_fractions if isinstance(f, (int, float))]
                        if valid_fractions and max(valid_fractions) > 1.0:
                            logger.warning(f"WARNING: Overlap fraction > 1.0 detected (max: {max(valid_fractions):.3f})")

                        if total_overlap_area > child_area * 1.1:  # Allow 10% tolerance
                            logger.warning(
                                f"WARNING: Total overlap exceeds child area by {(total_overlap_area/child_area - 1)*100:.1f}%"
                            )
                    else:
                        logger.warning(f"Child ID {child_id} not found in object_props")

                except Exception as e:
                    logger.warning(f"Error analysing child ID {child_id}: {str(e)}")

                # Try to find timestep information by checking where this child appears
                try:
                    child_timesteps = []
                    for t_idx in range(object_id_field_unique.sizes[self.timedim]):
                        time_slice = object_id_field_unique.isel({self.timedim: t_idx})
                        if (time_slice == child_id).any():
                            time_coord = time_slice.coords[self.timedim].values
                            child_timesteps.append((t_idx, time_coord))

                    if child_timesteps:
                        logger.warning(f"Child appears at timesteps: {child_timesteps}")
                    else:
                        logger.warning("Child timestep information not found")
                except Exception as e:
                    logger.warning(f"Error finding timestep for child ID {child_id}: {str(e)}")

                logger.warning("--- End detailed analysis ---\n")

            # Log summary information as warnings instead of raising error
            logger.warning("=" * 80)
            logger.warning("Tracker Warning: Multiple parents for single child detected after splitting/merging")
            logger.warning(f"Details: {len(duplicate_children)} children have multiple parents")
            logger.warning("Note: This is likely due to consolidation of IDs after splitting/merging")
            logger.warning(" and still is the correct behaviour (as per the tracking overlap logic")
            logger.warning(" applied to disjoint objects that will be grouped together.)")
            logger.warning("=" * 80)
        else:
            logger.info(f"Validation passed: All {len(unique_children)} children have unique parents")
    else:
        logger.info("No overlaps found - validation skipped")

    # Process merge events into a dataset
    # Handle case where there are no merge events
    if merge_parent_ids and merge_child_ids:
        max_parents = max(len(ids) for ids in merge_parent_ids)
        max_children = max(len(ids) for ids in merge_child_ids)
    else:
        max_parents = 1  # Default minimum size
        max_children = 1

    # Convert lists to padded numpy arrays (-1 marks unused slots)
    parent_ids_array = np.full((len(merge_parent_ids), max_parents), -1, dtype=np.int32)
    child_ids_array = np.full((len(merge_child_ids), max_children), -1, dtype=np.int32)
    overlap_areas_array = np.full((len(merge_areas), max_parents), -1, dtype=np.int32)

    for i, parents in enumerate(merge_parent_ids):
        parent_ids_array[i, : len(parents)] = parents

    for i, children in enumerate(merge_child_ids):
        child_ids_array[i, : len(children)] = children

    for i, areas in enumerate(merge_areas):
        overlap_areas_array[i, : len(areas)] = areas

    # Create merge events dataset
    merge_events = xr.Dataset(
        {
            "parent_IDs": (("merge_ID", "parent_idx"), parent_ids_array),
            "child_IDs": (("merge_ID", "child_idx"), child_ids_array),
            "overlap_areas": (("merge_ID", "parent_idx"), overlap_areas_array),
            "merge_time": ("merge_ID", merge_times),
            "n_parents": (
                "merge_ID",
                np.array([len(p) for p in merge_parent_ids], dtype=np.int8),
            ),
            "n_children": (
                "merge_ID",
                np.array([len(c) for c in merge_child_ids], dtype=np.int8),
            ),
        },
        attrs={"fill_value": -1},
    )

    object_props = object_props.persist()

    return (
        object_id_field_unique,
        object_props,
        overlap_objects_list[:, :2],  # Only return first 2 columns (ID pairs)
        merge_events,
    )
def split_and_merge_objects_parallel(
    self, object_id_field_unique: xr.DataArray, object_props: xr.Dataset
) -> Tuple[xr.DataArray, xr.Dataset, NDArray[np.int32], xr.Dataset]:
    """
    Optimised parallel implementation of object splitting and merging.

    This version is specifically designed for unstructured grids with more efficient
    memory handling and better parallelism than the standard split_and_merge_objects
    method. It processes data in chunks, handles merging events, and efficiently
    updates object IDs.

    The work is organised as an iterative loop: every iteration processes all
    currently-known merging objects chunk-by-chunk (``process_chunk``), consolidates
    the temporary IDs handed out inside the chunks into globally unique IDs, writes
    the relabelled field to a scratch zarr store, and collects the objects at chunk
    boundaries that still need processing in the next iteration.

    Parameters
    ----------
    object_id_field_unique : xarray.DataArray
        Field of unique object IDs
    object_props : xarray.Dataset
        Properties of each object

    Returns
    -------
    tuple
        (object_id_field, object_props, overlap_objects_list, merge_events)

    Raises
    ------
    TrackingError
        If a child has more than ``MAX_PARENTS`` parents, a timestep needs more than
        ``MAX_MERGES`` merges, more than ``MAX_MERGES`` objects spill over a chunk
        boundary, or merging objects remain after ``self.max_iteration`` iterations.
    """
    # Constants for memory allocation
    MAX_MERGES = 20  # Maximum number of merges per timestep
    MAX_PARENTS = 10  # Maximum number of parents per merge
    MAX_CHILDREN = MAX_PARENTS

    def process_chunk(
        chunk_data_m1_full: NDArray[np.int32],
        chunk_data_p1_full: NDArray[np.int32],
        merging_objects: NDArray[np.int64],
        next_id_start: NDArray[np.int64],
        lat: NDArray[np.float32],
        lon: NDArray[np.float32],
        area: NDArray[np.float32],
        neighbours_int: NDArray[np.int32],
    ) -> Tuple[
        NDArray[np.int32],  # merge_child_ids
        NDArray[np.int32],  # merge_parent_ids
        NDArray[np.float32],  # merge_areas
        NDArray[np.int16],  # merge_counts
        NDArray[np.bool_],  # has_merge
        NDArray[np.uint8],  # updates_array
        NDArray[np.int32],  # updates_ids
        NDArray[np.int32],  # final_merging_objects
    ]:
        """
        Process a single chunk of merging objects.

        This function handles the complex batch processing of splitting and merging
        objects across timesteps within a single chunk. It finds overlapping objects,
        determines parent-child relationships, and creates new IDs as needed.

        Parameters
        ----------
        chunk_data_m1_full : numpy.ndarray
            Data from previous timestep (t-1) and current timestep (t)
        chunk_data_p1_full : numpy.ndarray
            Data from next timestep (t+1)
        merging_objects : (n_time, max_merges) numpy.ndarray
            IDs of objects to process
        next_id_start : (n_time, max_merges) numpy.ndarray
            Starting ID values for new objects
        lat, lon : numpy.ndarray
            Latitude/longitude arrays
        area : numpy.ndarray
            Cell area array
        neighbours_int : numpy.ndarray
            Neighbor connectivity array

        Returns
        -------
        tuple
            Contains merge events, object updates, and newly created objects

        Raises
        ------
        TrackingError
            If the fixed-size output buffers (``MAX_PARENTS``, ``MAX_MERGES``) would overflow.
        """
        # Fix Broadcasted dimensions of inputs:
        #    Remove extra dimension if present while preserving time chunks
        #    N.B.: This is a weird artefact/choice of xarray apply_ufunc broadcasting...
        #           (i.e. 'nv' dimension gets injected into all the other arrays!)
        chunk_data_m1 = chunk_data_m1_full.squeeze()[0].astype(np.int32).copy()
        chunk_data = chunk_data_m1_full.squeeze()[1].astype(np.int32).copy()
        del chunk_data_m1_full  # Free memory immediately

        chunk_data_p1 = chunk_data_p1_full.astype(np.int32).copy()
        # Remove any singleton dimensions except time and space
        while chunk_data_p1.ndim > 2:
            chunk_data_p1 = chunk_data_p1.squeeze(axis=-1)
        del chunk_data_p1_full

        # Extract and prepare input arrays
        lat = lat.squeeze().astype(np.float32)
        lon = lon.squeeze().astype(np.float32)
        area = area.squeeze().astype(np.float32)
        next_id_start = next_id_start.squeeze()

        # Handle neighbours_int with correct dimensions (nv, ncells)
        neighbours_int = neighbours_int.squeeze()
        if neighbours_int.shape[1] != lat.shape[0]:
            neighbours_int = neighbours_int.T

        # Handle multiple merging objects - ensure proper dimensionality
        merging_objects = merging_objects.squeeze()
        if merging_objects.ndim == 1:
            merging_objects = merging_objects[:, None]  # Add dimension for max_merges

        # Pre-convert lat/lon to Cartesian coordinates for efficiency
        x = (np.cos(np.radians(lat)) * np.cos(np.radians(lon))).astype(np.float32)
        y = (np.cos(np.radians(lat)) * np.sin(np.radians(lon))).astype(np.float32)
        z = np.sin(np.radians(lat)).astype(np.float32)

        # Pre-allocate output arrays
        n_time = chunk_data_p1.shape[0]
        n_points = chunk_data_p1.shape[1]

        merge_child_ids = np.full((n_time, MAX_MERGES, MAX_PARENTS), -1, dtype=np.int32)
        merge_parent_ids = np.full((n_time, MAX_MERGES, MAX_PARENTS), -1, dtype=np.int32)
        merge_areas = np.full((n_time, MAX_MERGES, MAX_PARENTS), -1, dtype=np.float32)
        merge_counts = np.zeros(n_time, dtype=np.int16)  # Number of merges per timestep

        # 255 is the "no update" sentinel in updates_array; valid slots are 0..254
        updates_array = np.full((n_time, n_points), 255, dtype=np.uint8)
        updates_ids = np.full((n_time, 255), -1, dtype=np.int32)
        has_merge = np.zeros(n_time, dtype=np.bool_)

        # Prepare merging objects list for each timestep
        merging_objects_list = [list(merging_objects[i][merging_objects[i] > 0]) for i in range(merging_objects.shape[0])]
        final_merging_objects = np.full((n_time, MAX_MERGES), -1, dtype=np.int32)
        final_merge_count = 0

        # Process each timestep
        data_p1 = []
        for t in range(n_time):
            next_new_id = next_id_start[t]  # Use the offset for this timestep

            # Get current time slice data
            if t == 0:
                data_m1 = chunk_data_m1
                data_t = chunk_data
                del chunk_data_m1, chunk_data  # Free memory
            else:
                data_m1 = data_t  # Previous data_t becomes data_m1
                data_t = data_p1  # Previous data_p1 becomes data_t
            data_p1 = chunk_data_p1[t]

            # Process each merging object at this timestep
            while merging_objects_list[t]:
                child_id = merging_objects_list[t].pop(0)

                # Get child mask and identify overlapping parents
                child_mask = data_t == child_id

                # Find parent objects that overlap with this child
                potential_parents = np.unique(data_m1[child_mask])

                parent_iterator = 0
                parent_masks_uint = np.full(n_points, 255, dtype=np.uint8)
                parent_centroids = np.full((MAX_PARENTS, 2), -1.0e10, dtype=np.float32)
                parent_ids = np.full(MAX_PARENTS, -1, dtype=np.int32)
                parent_areas = np.zeros(MAX_PARENTS, dtype=np.float32)
                overlap_areas = np.zeros(MAX_PARENTS, dtype=np.float32)
                n_parents = 0

                # Find all unique parent IDs with significant overlap
                for parent_id in potential_parents[potential_parents > 0]:
                    if n_parents >= MAX_PARENTS:  # pragma: no cover
                        raise TrackingError(
                            "Too many parent objects for tracking",
                            details=f"Child {child_id} at timestep {t} has {n_parents} parents (limit: {MAX_PARENTS})",
                            suggestions=[
                                "Increase overlap_threshold to reduce fragmentation",
                                "Apply stronger area filtering",
                            ],
                            context={
                                "child_id": child_id,
                                "timestep": t,
                                "n_parents": n_parents,
                                "limit": MAX_PARENTS,
                            },
                        )

                    parent_mask = data_m1 == parent_id
                    if np.any(parent_mask & child_mask):
                        # Calculate overlap area and check if it's large enough
                        area_0 = area[parent_mask].sum()  # Parent area
                        area_1 = area[child_mask].sum()  # Child area
                        min_area = np.minimum(area_0, area_1)
                        overlap_area = area[parent_mask & child_mask].sum()

                        # Skip if overlap is below threshold
                        if overlap_area / min_area < self.overlap_threshold:
                            continue

                        # Record parent information
                        parent_masks_uint[parent_mask] = parent_iterator
                        parent_ids[n_parents] = parent_id
                        overlap_areas[n_parents] = overlap_area

                        # Calculate area-weighted centroid for this parent
                        mask_area = area[parent_mask]
                        weighted_coords = np.array(
                            [
                                np.sum(mask_area * x[parent_mask]),
                                np.sum(mask_area * y[parent_mask]),
                                np.sum(mask_area * z[parent_mask]),
                            ],
                            dtype=np.float32,
                        )
                        norm = np.sqrt(np.sum(weighted_coords * weighted_coords))

                        # Convert back to lat/lon
                        parent_centroids[n_parents, 0] = np.degrees(np.arcsin(weighted_coords[2] / norm))
                        parent_centroids[n_parents, 1] = np.degrees(np.arctan2(weighted_coords[1], weighted_coords[0]))

                        # Fix longitude range to [-180, 180]
                        if parent_centroids[n_parents, 1] > 180:
                            parent_centroids[n_parents, 1] -= 360
                        elif parent_centroids[n_parents, 1] < -180:
                            parent_centroids[n_parents, 1] += 360

                        parent_areas[n_parents] = area_0
                        parent_iterator += 1
                        n_parents += 1

                # Need at least 2 parents for merging
                if n_parents < 2:
                    continue

                # Create new IDs for each partition
                new_child_ids = np.arange(next_new_id, next_new_id + (n_parents - 1), dtype=np.int32)
                child_ids = np.concatenate((np.array([child_id]), new_child_ids))

                # Record merge event
                curr_merge_idx = merge_counts[t]
                # The merge buffers hold MAX_MERGES slots (indices 0..MAX_MERGES-1), so an
                # index equal to MAX_MERGES is already out of bounds.
                if curr_merge_idx >= MAX_MERGES:  # pragma: no cover
                    attempted_merges = int(curr_merge_idx) + 1
                    raise TrackingError(
                        "Too many merge operations",
                        details=f"Timestep {t} requires {attempted_merges} merges (limit: {MAX_MERGES})",
                        suggestions=[
                            "Increase area_filter_quartile to reduce small objects",
                            "Consider adjusting tracking parameters",
                        ],
                        context={
                            "timestep": t,
                            "merge_count": attempted_merges,
                            "limit": MAX_MERGES,
                        },
                    )

                merge_child_ids[t, curr_merge_idx, :n_parents] = child_ids[:n_parents]
                merge_parent_ids[t, curr_merge_idx, :n_parents] = parent_ids[:n_parents]
                merge_areas[t, curr_merge_idx, :n_parents] = overlap_areas[:n_parents]
                merge_counts[t] += 1
                has_merge[t] = True

                # Partition the child object based on parent associations
                if self.nn_partitioning:
                    # Estimate maximum search distance based on object size
                    max_area = parent_areas.max() / self.mean_cell_area
                    max_distance = int(np.sqrt(max_area) * 2.0)

                    # Use optimised nearest-neighbor partitioning
                    new_labels_uint = partition_nn_unstructured_optimised(
                        child_mask.copy(),
                        parent_masks_uint.copy(),
                        parent_centroids,
                        neighbours_int.copy(),
                        lat,
                        lon,
                        max_distance=max(max_distance, 20) * 2,
                    )
                    # Returned 'new_labels_uint' is just the index of the child_ids
                    new_labels = child_ids[new_labels_uint]

                    # Help garbage collection
                    new_labels_uint = None
                else:
                    # Use centroid-based partitioning
                    new_labels = partition_centroid_unstructured(child_mask, parent_centroids, child_ids, lat, lon)

                # Update slice data for subsequent merging in process_chunk
                data_t[child_mask] = new_labels

                # Record which cells get which new IDs for later updates
                spatial_indices_all = np.where(child_mask)[0].astype(np.int32)
                child_mask = None  # Free memory
                gc.collect()

                # Record update information for each new ID
                for new_id in child_ids[1:]:
                    # First still-unused (-1) slot in updates_ids for this timestep
                    update_idx = np.where(updates_ids[t] == -1)[0].astype(np.int32)[0]
                    updates_ids[t, update_idx] = new_id
                    updates_array[t, spatial_indices_all[new_labels == new_id]] = update_idx

                next_new_id += n_parents - 1

                # Find all child objects in the next timestep that overlap with our newly labeled regions
                new_merging_list = []
                for new_id in child_ids:
                    parent_mask = data_t == new_id
                    if np.any(parent_mask):
                        area_0 = area[parent_mask].sum()
                        potential_children = np.unique(data_p1[parent_mask])

                        for potential_child in potential_children[potential_children > 0]:
                            potential_child_mask = data_p1 == potential_child
                            area_1 = area[potential_child_mask].sum()
                            min_area = min(area_0, area_1)
                            overlap_area = area[parent_mask & potential_child_mask].sum()

                            if overlap_area / min_area > self.overlap_threshold:
                                new_merging_list.append(potential_child)

                # Add newly found merging objects to processing queue
                if t < n_time - 1:
                    # Add to next timestep in this chunk
                    for new_object_id in new_merging_list:
                        if new_object_id not in merging_objects_list[t + 1]:
                            merging_objects_list[t + 1].append(new_object_id)
                else:
                    # Record for next chunk
                    for new_object_id in new_merging_list:
                        # Objects already queued for the next chunk need no extra slot
                        if np.any(final_merging_objects[t][:final_merge_count] == new_object_id):
                            continue

                        # Only MAX_MERGES slots exist; check before writing a new one
                        if final_merge_count >= MAX_MERGES:  # pragma: no cover
                            attempted_final = final_merge_count + 1
                            raise TrackingError(
                                "Excessive merge operations detected",
                                details=f"Final merge count {attempted_final} exceeds limit {MAX_MERGES} at timestep {t}",
                                suggestions=[
                                    "Increase area_filter_quartile to reduce small objects",
                                    "Consider adjusting tracking parameters",
                                ],
                                context={
                                    "timestep": t,
                                    "final_merge_count": attempted_final,
                                    "limit": MAX_MERGES,
                                },
                            )

                        final_merging_objects[t][final_merge_count] = new_object_id
                        final_merge_count += 1

        return (
            merge_child_ids,
            merge_parent_ids,
            merge_areas,
            merge_counts,
            has_merge,
            updates_array,
            updates_ids,
            final_merging_objects,
        )

    def update_object_id_field_inplace(
        object_id_field: xr.DataArray,
        id_lookup: Dict[int, int],
        updates_array: xr.DataArray,
        updates_ids: xr.DataArray,
        has_merge: xr.DataArray,
    ) -> xr.DataArray:  # pragma: no cover
        """
        Update the object field with chunk results using xarray operations.

        This is memory efficient as it avoids creating full copies of the object_id_field.

        Parameters
        ----------
        object_id_field : xarray.DataArray
            The full object field to update
        id_lookup : dict
            Dictionary mapping temporary IDs to new IDs
        updates_array : xarray.DataArray
            Array indicating which spatial indices to update
        updates_ids : xarray.DataArray
            The new IDs to assign to updated indices
        has_merge : xarray.DataArray
            Boolean indicating whether each timestep has merges

        Returns
        -------
        xarray.DataArray
            Updated object field
        """
        # Quick return if no merges to update
        if not has_merge.any():
            return object_id_field

        def update_timeslice(
            data: NDArray[np.int32],
            updates: NDArray[np.uint8],
            update_ids: NDArray[np.int32],
            lookup_values: NDArray[np.int32],
        ) -> NDArray[np.int32]:
            """Process a single timeslice."""
            # Extract valid update IDs
            valid_ids = update_ids[update_ids > -1]
            if len(valid_ids) == 0:
                return data

            # Create result array starting with original values
            result = data.copy()

            # Apply each update: slot `idx` in `updates` refers to the idx-th valid ID
            for idx, update_id in enumerate(valid_ids):
                mask = updates == idx
                if mask.any():
                    result = np.where(mask, lookup_values[update_id], result)

            return result

        # Convert lookup dict to array for vectorized access
        max_id = max(id_lookup.keys()) + 1
        lookup_array = np.full(max_id, -1, dtype=np.int32)
        for temp_id, new_id in id_lookup.items():
            lookup_array[temp_id] = new_id

        # Apply updates in parallel
        result = xr.apply_ufunc(
            update_timeslice,
            object_id_field,
            updates_array,
            updates_ids,
            kwargs={"lookup_values": lookup_array},
            input_core_dims=[[self.xdim], [self.xdim], ["update_idx"]],
            output_core_dims=[[self.xdim]],
            vectorize=True,
            dask="parallelized",
            output_dtypes=[np.int32],
        )

        return result

    def update_object_id_field_zarr(
        self,
        object_id_field: xr.DataArray,
        id_lookup: Dict[int, int],
        updates_array: xr.DataArray,
        updates_ids: xr.DataArray,
        has_merge: xr.DataArray,
    ) -> xr.DataArray:
        """
        Update object field using a temporary zarr store for better memory efficiency.

        This approach minimises memory usage by writing changes directly to disk,
        allowing for more efficient parallel processing of large datasets.

        Parameters
        ----------
        object_id_field : xarray.DataArray
            The object field to update
        id_lookup : dict
            Dictionary mapping temporary IDs to new IDs
        updates_array : xarray.DataArray
            Array indicating which spatial indices to update
        updates_ids : xarray.DataArray
            The new IDs to assign to updated indices
        has_merge : xarray.DataArray
            Boolean indicating whether each timestep has merges

        Returns
        -------
        xarray.DataArray
            Updated object field from zarr store
        """
        # Early return if no merges to save memory
        if not bool(has_merge.any().compute().item()):
            return object_id_field

        zarr_path = f"{self.scratch_dir}/marEx_temp_field.zarr/"

        # Initialise zarr store if needed
        if not os.path.exists(zarr_path):
            object_id_field.name = "temp"
            object_id_field.to_zarr(zarr_path, mode="w")

        def update_time_chunk(ds_chunk: xr.Dataset, lookup_dict: Dict[int, int]) -> xr.DataArray:
            """Process a single chunk with optimised memory usage."""
            # Skip processing if no merges in this chunk
            needs_update = bool(ds_chunk["has_merge"].any().compute().item())
            if not needs_update:
                return ds_chunk["object_field"]

            # Extract data from the chunk
            chunk_data = ds_chunk["object_field"]
            chunk_updates = ds_chunk["updates"]
            chunk_update_ids = ds_chunk["update_ids"]

            # Get zarr region indices
            time_idx_start = int(ds_chunk["time_indices"].values[0])
            time_idx_end = int(ds_chunk["time_indices"].values[-1]) + 1

            updated_chunk = chunk_data.copy()

            # Process each time slice in the chunk
            for t in range(chunk_data.sizes[self.timedim]):
                # Get update information for this time
                updates_slice = chunk_updates.isel({self.timedim: t}).values
                update_ids_slice = chunk_update_ids.isel({self.timedim: t}).values

                # Get valid update IDs
                valid_mask = update_ids_slice > -1
                if not np.any(valid_mask):
                    continue

                valid_ids = update_ids_slice[valid_mask]

                # Get the time slice data and apply updates
                result_slice = updated_chunk.isel({self.timedim: t})

                for idx, update_id in enumerate(valid_ids):
                    mask = updates_slice == idx
                    if np.any(mask):
                        # Fall back to the temporary ID if it has no global mapping
                        new_id = lookup_dict.get(int(update_id), update_id)
                        result_slice = xr.where(mask, new_id, result_slice)

                # Store updated slice
                updated_chunk[t] = result_slice

            # Write the updated chunk directly to zarr
            updated_chunk.name = "temp"
            updated_chunk.to_zarr(
                zarr_path,
                region={self.timedim: slice(time_idx_start, time_idx_end)},
            )

            return chunk_data  # Return original data for dask graph consistency

        # Create time indices for slicing
        time_coords = object_id_field[self.timecoord].values
        time_indices = np.arange(len(time_coords), dtype=np.int32)
        time_index_da = xr.DataArray(time_indices, dims=[self.timedim], coords={self.timecoord: time_coords})

        # Create dataset with all necessary components
        ds = xr.Dataset(
            {
                "object_field": object_id_field,
                "updates": updates_array,
                "update_ids": updates_ids,
                "time_indices": time_index_da,
                "has_merge": has_merge,
            }
        ).chunk({self.timedim: self.timechunks})

        # Process chunks in parallel
        result = xr.map_blocks(
            update_time_chunk,
            ds,
            kwargs={"lookup_dict": id_lookup},
            template=object_id_field,
        )

        # Force computation to ensure all writes complete
        result = result.persist()
        wait(result)

        # Release resources
        del result, ds, object_id_field
        gc.collect()

        # Load the updated data from zarr store
        object_id_field_new = xr.open_zarr(zarr_path, chunks={self.timedim: self.timechunks}).temp

        return object_id_field_new

    def merge_objects_parallel_iteration(
        object_id_field_unique: xr.DataArray,
        merging_objects: Set[int],
        global_id_counter: int,
    ) -> Tuple[
        xr.DataArray,  # updated_field
        Tuple[
            NDArray[np.int32],
            NDArray[np.int32],
            NDArray[np.float32],
            NDArray[np.int32],
        ],  # merge_data
        Set[int],  # new_merging_objects
        int,  # updated_counter
    ]:
        """
        Perform a single iteration of the parallel merging process.

        This function handles one complete batch of merging objects across all timesteps,
        updating object IDs and tracking merge events.

        Parameters
        ----------
        object_id_field_unique : xarray.DataArray
            Field of unique object IDs
        merging_objects : set
            Set of object IDs to process in this iteration
        global_id_counter : int
            Current counter for assigning new global IDs

        Returns
        -------
        tuple
            (updated_field, merge_data, new_merging_objects, updated_counter)
        """
        n_time = len(object_id_field_unique[self.timecoord])

        # Pre-allocate arrays for this iteration
        child_ids_iter = np.full(
            (n_time, MAX_MERGES, MAX_CHILDREN), -1, dtype=np.int32
        )  # List of child ID arrays for this time
        parent_ids_iter = np.full(
            (n_time, MAX_MERGES, MAX_PARENTS), -1, dtype=np.int32
        )  # List of parent ID arrays for this time
        merge_areas_iter = np.full((n_time, MAX_MERGES, MAX_PARENTS), -1, dtype=np.float32)  # List of areas for this time
        merge_counts_iter = np.zeros(n_time, dtype=np.int32)

        # Prepare neighbour information
        neighbours_int = self.neighbours_int.chunk({self.xdim: -1, "nv": -1})

        # `iteration` is read from the enclosing main loop below
        logger.info(f"Processing Parallel Iteration {iteration + 1} with {len(merging_objects)} Merging Objects...")

        # Pre-compute the child_time_idx for merging_objects
        time_index_map = self.compute_id_time_dict(object_id_field_unique, list(merging_objects), global_id_counter)

        logger.debug("Finished Mapping Children to Time Indices")

        # Create uniform array of merging objects for each timestep
        max_merges = max(len([b for b in merging_objects if time_index_map.get(b, -1) == t]) for t in range(n_time))
        uniform_merging_objects_array = np.zeros((n_time, max_merges), dtype=np.int32)
        for t in range(n_time):
            objects_at_t = [b for b in merging_objects if time_index_map.get(b, -1) == t]
            if objects_at_t:  # Only fill if there are objects at this time
                uniform_merging_objects_array[t, : len(objects_at_t)] = np.array(objects_at_t, dtype=np.int32)

        # Create DataArrays for parallel processing
        merging_objects_da = xr.DataArray(
            uniform_merging_objects_array,
            dims=[self.timedim, "merges"],
            coords={self.timecoord: object_id_field_unique[self.timecoord]},
        )

        # Calculate ID offsets for each timestep to ensure unique IDs
        next_id_offsets = np.arange(n_time, dtype=np.int64) * max_merges * self.timechunks + global_id_counter
        # N.B.: We also need to account for possibility of newly-split objects subsequently creating
        #       more than max_merges by the end of the iteration through the chunk
        # !!! This is likely the root cause of any errors such as "ID needs to be contiguous/continuous/full/unrepeated"
        next_id_offsets_da = xr.DataArray(
            next_id_offsets,
            dims=[self.timedim],
            coords={self.timecoord: object_id_field_unique[self.timecoord]},
        )

        # Create shifted arrays for time connectivity
        object_id_field_unique_p1 = object_id_field_unique.shift({self.timedim: -1}, fill_value=0)
        object_id_field_unique_m1 = object_id_field_unique.shift({self.timedim: 1}, fill_value=0)

        # Align chunks for better parallel processing
        object_id_field_unique_m1 = object_id_field_unique_m1.chunk({self.timedim: self.timechunks})
        object_id_field_unique_p1 = object_id_field_unique_p1.chunk({self.timedim: self.timechunks})
        merging_objects_da = merging_objects_da.chunk({self.timedim: self.timechunks})
        next_id_offsets_da = next_id_offsets_da.chunk({self.timedim: self.timechunks})

        # Process chunks in parallel
        results = xr.apply_ufunc(
            process_chunk,
            object_id_field_unique_m1,
            object_id_field_unique_p1,
            merging_objects_da,
            next_id_offsets_da,
            self.lat,
            self.lon,
            self.cell_area,
            neighbours_int,
            input_core_dims=[
                [self.xdim],
                [self.xdim],
                ["merges"],
                [],
                [self.xdim],
                [self.xdim],
                [self.xdim],
                ["nv", self.xdim],
            ],
            output_core_dims=[
                ["merge", "parent"],
                ["merge", "parent"],
                ["merge", "parent"],
                [],
                [],
                [self.xdim],
                ["update_idx"],
                ["merge"],
            ],
            output_dtypes=[
                np.int32,
                np.int32,
                np.float32,
                np.int16,
                np.bool_,
                np.uint8,
                np.int32,
                np.int32,
            ],
            dask_gufunc_kwargs={
                "output_sizes": {
                    "merge": MAX_MERGES,
                    "parent": MAX_PARENTS,
                    "update_idx": 255,
                }
            },
            vectorize=False,
            dask="parallelized",
        )

        # Unpack and persist results
        (
            merge_child_ids,
            merge_parent_ids,
            merge_areas,
            merge_counts,
            has_merge,
            updates_array,
            updates_ids,
            final_merging_objects,
        ) = results

        results = persist(
            merge_child_ids,
            merge_parent_ids,
            merge_areas,
            merge_counts,
            has_merge,
            updates_array,
            updates_ids,
            final_merging_objects,
        )
        (
            merge_child_ids,
            merge_parent_ids,
            merge_areas,
            merge_counts,
            has_merge,
            updates_array,
            updates_ids,
            final_merging_objects,
        ) = results

        # Get time indices where merges occurred
        has_merge = has_merge.compute()
        time_indices = np.where(has_merge)[0].astype(np.int32)

        # Clean up temporary arrays to save memory
        del (
            object_id_field_unique_p1,
            object_id_field_unique_m1,
            merging_objects_da,
            next_id_offsets_da,
        )
        gc.collect()

        logger.debug("Finished Batch Processing Step")

        # ====== Global Consolidation of Data ======

        # 1. Collect all temporary IDs and create global mapping
        all_temp_ids = np.unique(merge_child_ids.where(merge_child_ids >= global_id_counter, other=0).compute().values)
        all_temp_ids = all_temp_ids[all_temp_ids > 0]  # Remove the 0

        if not len(all_temp_ids):  # If no temporary IDs exist
            id_lookup = {}
        else:
            # Create mapping from temporary to permanent IDs
            id_lookup = {
                temp_id: np.int32(new_id)
                for temp_id, new_id in zip(
                    all_temp_ids,
                    range(global_id_counter, global_id_counter + len(all_temp_ids)),
                )
            }
            global_id_counter += len(all_temp_ids)

        logger.debug("Finished Consolidation Step 1: Temporary ID Mapping")

        # 2. Update object ID field with new IDs
        update_on_disk = True  # This is more memory efficient because it refreshes the dask graph every iteration

        if update_on_disk:
            object_id_field_unique = update_object_id_field_zarr(
                self,
                object_id_field_unique,
                id_lookup,
                updates_array,
                updates_ids,
                has_merge,
            )
        else:  # pragma: no cover
            object_id_field_unique = update_object_id_field_inplace(
                object_id_field_unique,
                id_lookup,
                updates_array,
                updates_ids,
                has_merge,
            )
            object_id_field_unique = object_id_field_unique.chunk(
                {self.timedim: self.timechunks}
            )  # Rechunk to avoid accumulating chunks...

        # Clean up arrays no longer needed
        del updates_array, updates_ids
        gc.collect()

        logger.debug("Finished Consolidation Step 2: Data Field Update")

        # 3. Update merge events
        new_merging_objects = set()
        merge_counts = merge_counts.compute()

        for t in time_indices:
            count = int(merge_counts.isel({self.timedim: t}).item())
            if count > 0:
                merge_counts_iter[t] = count

                # Extract valid IDs and areas for each merge event
                for merge_idx in range(count):
                    # Get child IDs
                    child_ids = merge_child_ids.isel({self.timedim: t, "merge": merge_idx}).compute().values
                    child_ids = child_ids[child_ids >= 0]

                    # Get parent IDs and areas
                    parent_ids = merge_parent_ids.isel({self.timedim: t, "merge": merge_idx}).compute().values
                    areas = merge_areas.isel({self.timedim: t, "merge": merge_idx}).compute().values
                    valid_mask = parent_ids >= 0
                    parent_ids = parent_ids[valid_mask]
                    areas = areas[valid_mask]

                    # Map temporary IDs to permanent IDs
                    mapped_child_ids = [id_lookup.get(int(id_.item()), int(id_.item())) for id_ in child_ids]
                    mapped_parent_ids = [id_lookup.get(int(id_.item()), int(id_.item())) for id_ in parent_ids]

                    # Store in pre-allocated arrays
                    child_ids_iter[t, merge_idx, : len(mapped_child_ids)] = mapped_child_ids
                    parent_ids_iter[t, merge_idx, : len(mapped_parent_ids)] = mapped_parent_ids
                    merge_areas_iter[t, merge_idx, : len(areas)] = areas

        # Process final merging objects for next iteration
        final_merging_objects = final_merging_objects.compute().values
        final_merging_objects = final_merging_objects[final_merging_objects > 0]
        mapped_final_objects = [id_lookup.get(id_, id_) for id_ in final_merging_objects]
        new_merging_objects.update(mapped_final_objects)

        logger.debug("Finished Consolidation Step 3: Merge List Dictionary Consolidation")

        # Clean up memory
        del merge_child_ids, merge_parent_ids, merge_areas, merge_counts, has_merge
        gc.collect()

        return (
            object_id_field_unique,
            (child_ids_iter, parent_ids_iter, merge_areas_iter, merge_counts_iter),
            new_merging_objects,
            global_id_counter,
        )

    # ============================
    # Main Loop for Parallel Merging
    # ============================

    # Find overlapping objects
    overlap_objects_list = self.find_overlapping_objects(
        object_id_field_unique
    )  # List object pairs that overlap by at least overlap_threshold percent
    overlap_objects_list = self.enforce_overlap_threshold(overlap_objects_list, object_props)
    logger.info("Finished finding overlapping objects")

    # Find initial merging objects: children that appear with more than one parent
    unique_children, children_counts = np.unique(overlap_objects_list[:, 1], return_counts=True)
    merging_objects = set(unique_children[children_counts > 1].astype(np.int32))
    del overlap_objects_list

    # Process chunks iteratively until no new merging objects remain
    iteration = 0
    processed_chunks = set()
    global_id_counter = int(object_props.ID.max().item()) + 1

    # Initialise global merge event tracking
    global_child_ids = []
    global_parent_ids = []
    global_merge_areas = []
    global_merge_tidx = []

    while merging_objects and iteration < self.max_iteration:
        (
            object_id_field_new,
            merge_data_iter,
            new_merging_objects,
            global_id_counter,
        ) = merge_objects_parallel_iteration(object_id_field_unique, merging_objects, global_id_counter)
        child_ids_iter, parent_ids_iter, merge_areas_iter, merge_counts_iter = merge_data_iter

        # Consolidate merge events from this iteration
        for t, count in enumerate(merge_counts_iter):
            if count > 0:
                for merge_idx in range(count):
                    # Extract valid children
                    children = child_ids_iter[t, merge_idx]
                    children = children[children >= 0]

                    # Extract valid parents and areas
                    parents = parent_ids_iter[t, merge_idx]
                    areas = merge_areas_iter[t, merge_idx]
                    valid_mask = parents >= 0
                    parents = parents[valid_mask]
                    areas = areas[valid_mask]

                    # Record valid merge events
                    if len(children) > 0 and len(parents) > 0:
                        global_child_ids.append(children)
                        global_parent_ids.append(parents)
                        global_merge_areas.append(areas)
                        global_merge_tidx.append(t)

        # Prepare for next iteration - only process objects not already handled
        merging_objects = new_merging_objects - processed_chunks
        processed_chunks.update(new_merging_objects)
        iteration += 1

        # Update the object field
        object_id_field_unique = object_id_field_new
        del object_id_field_new

    # Only report non-convergence when work is actually left over: the loop may also
    # exit because the final permitted iteration emptied the queue.
    if merging_objects and iteration >= self.max_iteration:  # pragma: no cover
        raise TrackingError(
            "Maximum iterations reached in tracking algorithm",
            details=f"Algorithm failed to converge after {self.max_iteration} iterations",
            suggestions=[
                "Increase max_iteration parameter",
                "Increase area_filter_quartile to reduce small objects",
                "Consider adjusting tracking parameters",
            ],
            context={
                "max_iteration": self.max_iteration,
                "reached_iteration": iteration,
            },
        )

    # Process the collected merge events
    times = object_id_field_unique[self.timecoord].values

    # Find maximum dimensions for arrays
    # Handle case where there are no merge events
    if global_parent_ids and global_child_ids:
        max_parents = max(len(ids) for ids in global_parent_ids)
        max_children = max(len(ids) for ids in global_child_ids)
    else:
        max_parents = 1  # Default minimum size
        max_children = 1

    # Create padded arrays for merge events (-1 marks unused entries)
    parent_ids_array = np.full((len(global_parent_ids), max_parents), -1, dtype=np.int32)
    child_ids_array = np.full((len(global_child_ids), max_children), -1, dtype=np.int32)
    overlap_areas_array = np.full(
        (len(global_merge_areas), max_parents),
        -1,
        dtype=np.float32 if self.unstructured_grid else np.int32,
    )

    # Fill arrays with merge data
    for i, parents in enumerate(global_parent_ids):
        parent_ids_array[i, : len(parents)] = parents

    for i, children in enumerate(global_child_ids):
        child_ids_array[i, : len(children)] = children

    for i, areas in enumerate(global_merge_areas):
        overlap_areas_array[i, : len(areas)] = areas

    # Create merge events dataset
    merge_events = xr.Dataset(
        {
            "parent_IDs": (("merge_ID", "parent_idx"), parent_ids_array),
            "child_IDs": (("merge_ID", "child_idx"), child_ids_array),
            "overlap_areas": (("merge_ID", "parent_idx"), overlap_areas_array),
            "merge_time": ("merge_ID", times[global_merge_tidx]),
            "n_parents": (
                "merge_ID",
                np.array([len(p) for p in global_parent_ids], dtype=np.int8),
            ),
            "n_children": (
                "merge_ID",
                np.array([len(c) for c in global_child_ids], dtype=np.int8),
            ),
        },
        attrs={"fill_value": -1},
    )

    # Recompute object properties and overlaps after all merging
    object_id_field_unique = object_id_field_unique.persist(optimize_graph=True)
    object_props = self.calculate_object_properties(object_id_field_unique, properties=["area", "centroid"]).persist(
        optimize_graph=True
    )

    # Recompute overlaps based on final object configuration
    overlap_objects_list = self.find_overlapping_objects(object_id_field_unique)
    overlap_objects_list = self.enforce_overlap_threshold(overlap_objects_list, object_props)
    overlap_objects_list = overlap_objects_list[:, :2].astype(np.int32)

    return (
        object_id_field_unique,
        object_props,
        overlap_objects_list,
        merge_events,
    )
""" MarEx Helper Functions These are the remaining implementations of helper functions for the MarEx package, providing optimised algorithms for partitioning, distance calculations, and spatial operations on both structured and unstructured grids. """
@jit(nopython=True, parallel=True, fastmath=True)
def wrapped_euclidian_distance_mask_parallel(
    mask_values: NDArray[np.bool_],
    parent_centroids_values: NDArray[np.float64],
    Nx: int,
    wrap: bool,
) -> NDArray[np.float64]:  # pragma: no cover
    """
    Compute distances from every True mask cell to every parent centroid.

    Distances are plain Euclidean distances in index space. When ``wrap`` is set,
    the x separation is folded into the range [-Nx/2, Nx/2] so that the x dimension
    behaves as periodic.

    Parameters
    ----------
    mask_values : np.ndarray
        2D boolean array where True indicates points to calculate distances for
    parent_centroids_values : np.ndarray
        Array of shape (n_parents, 2) containing (y, x) coordinates of parent centroids
    Nx : int
        Size of the x-dimension for periodic boundary wrapping
    wrap : bool
        Whether to treat x-dimension as periodic and wrap

    Returns
    -------
    distances : np.ndarray
        Array of shape (n_true_points, n_parents) with minimum distances
    """
    # Coordinates of the True cells, in the order returned by np.nonzero
    rows, cols = np.nonzero(mask_values)
    n_cells = len(rows)
    n_centroids = len(parent_centroids_values)

    # Column views of the centroid table: (y, x)
    centroid_rows = parent_centroids_values[:, 0]
    centroid_cols = parent_centroids_values[:, 1]

    half_width = Nx / 2
    distances = np.empty((n_cells, n_centroids), dtype=np.float64)

    # One independent row of the output per True cell
    for cell in prange(n_cells):
        row_gap = rows[cell] - centroid_rows
        col_gap = cols[cell] - centroid_cols

        if wrap:
            # Fold separations larger than half the domain back across the seam
            col_gap = np.where(col_gap > half_width, col_gap - Nx, col_gap)
            col_gap = np.where(col_gap < -half_width, col_gap + Nx, col_gap)

        distances[cell] = np.sqrt(row_gap * row_gap + col_gap * col_gap)

    return distances
@jit(nopython=True, fastmath=True)
def create_grid_index_arrays(
    points_y: NDArray[np.int32],
    points_x: NDArray[np.int32],
    grid_size: int,
    ny: int,
    nx: int,
) -> Tuple[NDArray[np.int32], NDArray[np.int32]]:  # pragma: no cover
    """
    Bin points into a coarse grid to build a simple spatial index.

    The (ny, nx) domain is tiled with square cells of side ``grid_size``; each point
    index is appended to the list of the cell it falls in. Points whose cell index
    would lie beyond the last cell are clamped into the last cell.

    Parameters
    ----------
    points_y, points_x : np.ndarray
        Coordinates of points to index
    grid_size : int
        Size of each grid cell
    ny, nx : int
        Dimensions of the overall grid

    Returns
    -------
    grid_points : np.ndarray
        3D array mapping grid cells to point indices (unused slots hold -1)
    grid_counts : np.ndarray
        2D array with count of points in each grid cell
    """
    n_points = len(points_y)

    # Number of cells along each axis (ceiling division)
    cells_y = (ny + grid_size - 1) // grid_size
    cells_x = (nx + grid_size - 1) // grid_size

    # Each cell is sized for the worst case of holding every point
    capacity = n_points
    grid_counts = np.zeros((cells_y, cells_x), dtype=np.int32)
    grid_points = np.full((cells_y, cells_x, capacity), -1, dtype=np.int32)

    for point in range(n_points):
        cell_row = min(points_y[point] // grid_size, cells_y - 1)
        cell_col = min(points_x[point] // grid_size, cells_x - 1)

        slot = grid_counts[cell_row, cell_col]
        if slot >= capacity:
            continue

        grid_points[cell_row, cell_col, slot] = point
        grid_counts[cell_row, cell_col] += 1

    return grid_points, grid_counts
@jit(nopython=True, fastmath=True)
def wrapped_euclidian_distance_points(
    y1: float, x1: float, y2: float, x2: float, nx: int, half_nx: float, wrap: bool
) -> float:  # pragma: no cover
    """
    Euclidean distance between two points, optionally periodic in x.

    Parameters
    ----------
    y1, x1 : float
        Coordinates of the first point
    y2, x2 : float
        Coordinates of the second point
    nx : int
        Period of the x dimension
    half_nx : float
        Half of ``nx``; x separations beyond this magnitude are wrapped
    wrap : bool
        Whether to apply the periodic correction in x

    Returns
    -------
    float
        Distance between the points, using the shorter way around in x when
        ``wrap`` is True
    """
    sep_y = y1 - y2
    sep_x = x1 - x2

    # At most one correction is applied: shift by one period towards zero
    if wrap and sep_x > half_nx:
        sep_x -= nx
    elif wrap and sep_x < -half_nx:
        sep_x += nx

    return np.sqrt(sep_y * sep_y + sep_x * sep_x)
@jit(nopython=True, parallel=True, fastmath=True)
def partition_nn_grid(
    child_mask: NDArray[np.bool_],
    parent_masks: NDArray[np.bool_],
    child_ids: NDArray[np.int32],
    parent_centroids: NDArray[np.float64],
    Nx: int,
    max_distance: int = 20,
    wrap: bool = True,
) -> NDArray[np.int32]:  # pragma: no cover
    """
    Partition a child object based on nearest parent object points.

    Each True cell of the child mask is assigned to the parent that owns the
    closest cell found in a local grid-bucketed search. Child cells for which
    the local search finds no parent cell within ``max_distance`` fall back to
    the parent with the nearest centroid.

    Parameters
    ----------
    child_mask : np.ndarray
        2D binary mask of the child object
    parent_masks : np.ndarray
        Binary masks of the parent objects, indexed by parent along the first
        axis; each mask is searched with ``np.nonzero`` in the same (y, x)
        index space as ``child_mask``
    child_ids : np.ndarray
        IDs to assign to the partitions; ``child_ids[k]`` is given to cells
        assigned to parent ``k``
    parent_centroids : np.ndarray
        Array of shape (n_parents, 2) with (y, x) parent centroids, used only
        for the fallback assignment
    Nx : int
        Size of x dimension for periodic boundaries
    max_distance : int, default=20
        Parent cells farther than this from a child cell are ignored by the
        grid search. It also sets the bucket size, ``max(2, max_distance // 4)``.
    wrap : bool, default=True
        Whether to apply periodic boundary conditions in the x dimension when
        measuring distances

    Returns
    -------
    new_labels : np.ndarray
        1D array with one assigned ID per True cell of ``child_mask``, in
        ``np.nonzero(child_mask)`` order

    Notes
    -----
    Only the 3x3 block of grid buckets around a child cell is searched. The
    bucket neighbourhood is taken modulo the number of buckets in both y and x
    irrespective of ``wrap``; the ``wrap`` flag only affects the distance
    calculation itself.
    """
    ny, nx = child_mask.shape
    half_Nx = Nx / 2
    n_parents = len(parent_masks)
    # Bucket side length for the spatial index (never smaller than 2 cells)
    grid_size = max(2, max_distance // 4)

    y_indices, x_indices = np.nonzero(child_mask)
    n_child_points = len(y_indices)

    # Per-child-cell state: best distance so far, owning parent index, and a
    # flag marking cells that coincide with a parent cell
    min_distances = np.full(n_child_points, np.inf)
    parent_assignments = np.zeros(n_child_points, dtype=np.int32)
    found_close = np.zeros(n_child_points, dtype=np.bool_)

    for parent_idx in range(n_parents):
        py, px = np.nonzero(parent_masks[parent_idx])

        if len(py) == 0:  # Skip empty parents
            continue

        # Create grid index for this parent
        n_grids_y = (ny + grid_size - 1) // grid_size
        n_grids_x = (nx + grid_size - 1) // grid_size
        grid_points, grid_counts = create_grid_index_arrays(py, px, grid_size, ny, nx)

        # Process child points in parallel
        # (each iteration reads and writes only its own child_idx entries)
        for child_idx in prange(n_child_points):
            # Once a cell has matched a parent cell exactly it is never
            # reconsidered, so the first such parent in index order keeps it
            if found_close[child_idx]:  # Skip if we already found an exact match
                continue

            child_y, child_x = y_indices[child_idx], x_indices[child_idx]
            grid_y = min(child_y // grid_size, n_grids_y - 1)
            grid_x = min(child_x // grid_size, n_grids_x - 1)

            min_dist_to_parent = np.inf

            # Check nearby grid cells
            for dy in range(-1, 2):
                grid_y_check = (grid_y + dy) % n_grids_y

                for dx in range(-1, 2):
                    grid_x_check = (grid_x + dx) % n_grids_x

                    # Process points in this grid cell
                    n_points = grid_counts[grid_y_check, grid_x_check]

                    for p_idx in range(n_points):
                        point_idx = grid_points[grid_y_check, grid_x_check, p_idx]
                        # -1 marks an unused slot in the index
                        if point_idx == -1:
                            break

                        dist = wrapped_euclidian_distance_points(child_y, child_x, py[point_idx], px[point_idx], Nx, half_Nx, wrap)

                        if dist > max_distance:
                            continue

                        if dist < min_dist_to_parent:
                            min_dist_to_parent = dist

                        if dist < 1e-6:  # Found exact same point (within numerical precision)
                            min_dist_to_parent = dist
                            found_close[child_idx] = True
                            break

                    # Propagate the exact-match exit out of the bucket loops
                    if found_close[child_idx]:
                        break

                if found_close[child_idx]:
                    break

            # Update assignment if this parent is closer
            # (strict comparison: on a tie the earlier parent is kept)
            if min_dist_to_parent < min_distances[child_idx]:
                min_distances[child_idx] = min_dist_to_parent
                parent_assignments[child_idx] = parent_idx

    # Handle any unassigned points using centroids
    unassigned = min_distances == np.inf
    if np.any(unassigned):
        for child_idx in np.nonzero(unassigned)[0]:
            child_y, child_x = y_indices[child_idx], x_indices[child_idx]
            min_dist = np.inf
            best_parent = 0

            for parent_idx in range(n_parents):
                # Calculate distance to centroid with periodic boundary conditions
                dist = wrapped_euclidian_distance_points(
                    child_y,
                    child_x,
                    parent_centroids[parent_idx, 0],
                    parent_centroids[parent_idx, 1],
                    Nx,
                    half_Nx,
                    wrap,
                )

                if dist < min_dist:
                    min_dist = dist
                    best_parent = parent_idx

            parent_assignments[child_idx] = best_parent

    # Convert from parent indices to child_ids
    new_labels = child_ids[parent_assignments]

    return new_labels
@jit(nopython=True, fastmath=True)
def partition_nn_unstructured(
    child_mask: NDArray[np.bool_],
    parent_masks: NDArray[np.bool_],
    child_ids: NDArray[np.int32],
    parent_centroids: NDArray[np.float64],
    neighbours_int: NDArray[np.int32],
    lat: NDArray[np.float32],
    lon: NDArray[np.float32],
    max_distance: int = 20,
) -> NDArray[np.int32]:  # pragma: no cover
    """
    Partition a child object on an unstructured grid based on nearest parent points.

    Cells where a parent overlaps the child are claimed by that parent first.
    The parents' visited sets are then grown outwards along the cell
    connectivity, one round per unit of ``current_distance``, and a newly
    reached cell is claimed by the parent that reaches it first. Child cells
    still unclaimed after the traversal are given to the parent whose centroid
    is nearest in great-circle distance.

    Parameters
    ----------
    child_mask : np.ndarray
        1D boolean array where True indicates points in the child object
    parent_masks : np.ndarray
        2D boolean array of shape (n_parents, n_points) where True indicates
        points in each parent object
    child_ids : np.ndarray
        1D array containing the IDs to assign to each partition of the child
        object; ``child_ids[k]`` is given to cells assigned to parent ``k``
    parent_centroids : np.ndarray
        Array of shape (n_parents, 2) containing (lat, lon) coordinates of
        parent centroids in degrees
    neighbours_int : np.ndarray
        2D array of shape (3, n_points) containing indices of neighbouring
        cells for each point; negative entries are treated as "no neighbour"
    lat, lon : np.ndarray
        Latitude/longitude arrays in degrees
    max_distance : int, default=20
        Maximum number of expansion rounds performed before falling back to
        centroid distances

    Returns
    -------
    new_labels : np.ndarray
        1D array containing the assigned child_ids for each True point in
        child_mask, in ``np.where(child_mask)`` order
    """
    # Force contiguous arrays in memory for optimal vectorised performance
    child_mask = np.ascontiguousarray(child_mask)
    parent_masks = np.ascontiguousarray(parent_masks)

    n_points = len(child_mask)
    n_parents = len(parent_masks)

    # Pre-allocate arrays
    # distances: round at which a cell was claimed (inf = unclaimed)
    # parent_assignments: index of the claiming parent (-1 = unclaimed)
    # visited: per-parent set of cells already reached by that parent
    distances = np.full(n_points, np.inf, dtype=np.float32)
    parent_assignments = np.full(n_points, -1, dtype=np.int32)
    visited = np.zeros((n_parents, n_points), dtype=np.bool_)

    # Initialise with direct overlaps
    for parent_idx in range(n_parents):
        overlap_mask = parent_masks[parent_idx] & child_mask
        if np.any(overlap_mask):
            visited[parent_idx, overlap_mask] = True
            # Where several parents overlap the same cell, the lowest parent
            # index keeps it: later parents only take still-unclaimed cells
            unclaimed_overlap = distances[overlap_mask] == np.inf
            if np.any(unclaimed_overlap):
                overlap_points = np.where(overlap_mask)[0].astype(np.int32)
                valid_points = overlap_points[unclaimed_overlap]
                distances[valid_points] = 0
                parent_assignments[valid_points] = parent_idx

    # Pre-compute trig values for efficiency
    lat_rad = np.deg2rad(lat)
    lon_rad = np.deg2rad(lon)
    cos_lat = np.cos(lat_rad)

    # Graph traversal for remaining points - expanding from parent frontiers
    current_distance = 0
    any_unassigned = np.any(child_mask & (parent_assignments == -1))

    while current_distance < max_distance and any_unassigned:
        current_distance += 1
        updates_made = False

        for parent_idx in range(n_parents):
            # Get current frontier points
            # (this is the parent's whole visited set, not just its outer ring)
            # NOTE(review): this looks like a view into `visited`, so cells
            # added below would join the frontier within the same round - confirm
            frontier_mask = visited[parent_idx]
            if not np.any(frontier_mask):
                continue

            # Process neighbors
            for i in range(3):  # For each neighbor direction
                neighbors = neighbours_int[i, frontier_mask]
                valid_neighbors = neighbors >= 0
                if not np.any(valid_neighbors):
                    continue

                valid_points = neighbors[valid_neighbors]
                unvisited = ~visited[parent_idx, valid_points]
                new_points = valid_points[unvisited]

                if len(new_points) > 0:
                    visited[parent_idx, new_points] = True
                    # Only cells with a strictly larger recorded distance are
                    # (re)claimed, so cells claimed in an earlier round, or in
                    # this round by an earlier parent, are left alone
                    update_mask = distances[new_points] > current_distance

                    if np.any(update_mask):
                        points_to_update = new_points[update_mask]
                        distances[points_to_update] = current_distance
                        parent_assignments[points_to_update] = parent_idx
                        updates_made = True

        # Stop early once a full round claims nothing new
        if not updates_made:
            break

        any_unassigned = np.any(child_mask & (parent_assignments == -1))

    # Handle remaining unassigned points using great circle distances
    unassigned_mask = child_mask & (parent_assignments == -1)
    if np.any(unassigned_mask):
        parent_lat_rad = np.deg2rad(parent_centroids[:, 0])
        parent_lon_rad = np.deg2rad(parent_centroids[:, 1])
        cos_parent_lat = np.cos(parent_lat_rad)

        unassigned_points = np.where(unassigned_mask)[0].astype(np.int32)
        for point in unassigned_points:
            # Vectorised haversine calculation (over all parent centroids at once);
            # dist is the central angle in radians, which is sufficient for argmin
            dlat = parent_lat_rad - lat_rad[point]
            dlon = parent_lon_rad - lon_rad[point]
            a = np.sin(dlat / 2) ** 2 + cos_lat[point] * cos_parent_lat * np.sin(dlon / 2) ** 2
            dist = 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a))
            parent_assignments[point] = np.argmin(dist).astype(np.int32)

    # Return only the assignments for points in child_mask
    child_points = np.where(child_mask)[0].astype(np.int32)
    return child_ids[parent_assignments[child_points]]
@jit(nopython=True, fastmath=True)
def partition_nn_unstructured_optimised(
    child_mask: NDArray[np.bool_],
    parent_frontiers: NDArray[np.uint8],
    parent_centroids: NDArray[np.float64],
    neighbours_int: NDArray[np.int32],
    lat: NDArray[np.float32],
    lon: NDArray[np.float32],
    max_distance: int = 20,
) -> NDArray[np.uint8]:  # pragma: no cover
    """
    Memory-optimised nearest neighbor partitioning for unstructured grids.

    Instead of one boolean visited array per parent, a single 1D array records
    for every cell which parent has reached it (255 meaning "none yet"). The
    parents' territories are grown along the cell connectivity in rounds, and
    child cells that remain unreached are assigned to the parent whose
    centroid is nearest in great-circle distance.

    Parameters
    ----------
    child_mask : np.ndarray
        1D boolean array indicating which cells belong to the child object
    parent_frontiers : np.ndarray
        1D uint8 array with parent indices (255 for unvisited points). The
        number of parents is taken as the largest value below 255, plus one,
        so at least one entry must be below 255.
    parent_centroids : np.ndarray
        Array of shape (n_parents, 2) containing (lat, lon) coordinates;
        converted with ``np.deg2rad``, i.e. expected in degrees
    neighbours_int : np.ndarray
        2D array of shape (3, n_points) containing indices of neighboring
        cells; negative entries are treated as "no neighbour"
    lat, lon : np.ndarray
        1D arrays of latitude/longitude in degrees
    max_distance : int, default=20
        Maximum number of expansion rounds performed before falling back to
        centroid distances

    Returns
    -------
    result : np.ndarray
        1D array containing the assigned parent indices for points in
        child_mask (parent indices, not child IDs), in mask order
    """
    # Create working copies to ensure memory cleanup
    # (the caller's input arrays are therefore never modified)
    parent_frontiers_working = parent_frontiers.copy()
    child_mask_working = child_mask.copy()

    # Number of parents inferred from the labels present (255 is the sentinel)
    n_parents = np.max(parent_frontiers_working[parent_frontiers_working < 255]) + 1

    # Graph traversal - expanding frontiers
    current_distance = 0
    any_unassigned = np.any(child_mask_working & (parent_frontiers_working == 255))

    while current_distance < max_distance and any_unassigned:
        current_distance += 1
        updates_made = False

        for parent_idx in range(n_parents):
            # Skip if no frontier points for this parent
            if not np.any(parent_frontiers_working == parent_idx):
                continue

            # Process neighbours for current parent's frontier
            # (the frontier is re-evaluated for each direction, so cells
            # claimed for direction i are already included for direction i + 1)
            for i in range(3):
                neighbors = neighbours_int[i, parent_frontiers_working == parent_idx]
                valid_neighbors = neighbors >= 0
                if not np.any(valid_neighbors):
                    continue

                valid_points = neighbors[valid_neighbors]
                unvisited = parent_frontiers_working[valid_points] == 255
                if not np.any(unvisited):
                    continue

                # Update new frontier points
                # (only sentinel cells are claimed, so the first parent to
                # reach a cell keeps it)
                new_points = valid_points[unvisited]
                parent_frontiers_working[new_points] = parent_idx

                # Progress is counted only when a child cell was claimed
                if np.any(child_mask_working[new_points]):
                    updates_made = True

        # Stop early once a full round claims no further child cells
        if not updates_made:
            break

        any_unassigned = np.any(child_mask_working & (parent_frontiers_working == 255))

    # Handle remaining unassigned points using great circle distances
    unassigned_mask = child_mask_working & (parent_frontiers_working == 255)
    if np.any(unassigned_mask):
        # Pre-compute parent coordinates in radians
        parent_lat_rad = np.deg2rad(parent_centroids[:, 0])
        parent_lon_rad = np.deg2rad(parent_centroids[:, 1])
        cos_parent_lat = np.cos(parent_lat_rad)

        # Process each unassigned point
        unassigned_points = np.where(unassigned_mask)[0].astype(np.int32)
        for point in unassigned_points:
            # Haversine central angle from this cell to every parent centroid
            dlat = parent_lat_rad - np.deg2rad(lat[point])
            dlon = parent_lon_rad - np.deg2rad(lon[point])
            a = np.sin(dlat / 2) ** 2 + np.cos(np.deg2rad(lat[point])) * cos_parent_lat * np.sin(dlon / 2) ** 2
            dist = 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a))
            parent_frontiers_working[point] = np.int32(np.argmin(dist))

    # Extract result for child points only
    result = parent_frontiers_working[child_mask_working].copy()

    # Explicitly clear working arrays to help with memory management
    parent_frontiers_working = None
    child_mask_working = None

    return result
@jit(nopython=True, parallel=True, fastmath=True)
def partition_centroid_unstructured(
    child_mask: NDArray[np.bool_],
    parent_centroids: NDArray[np.float64],
    child_ids: NDArray[np.int32],
    lat: NDArray[np.float32],
    lon: NDArray[np.float32],
) -> NDArray[np.int32]:  # pragma: no cover
    """
    Split a child object between parents by nearest centroid on an unstructured grid.

    Every cell of the child object receives the ID associated with the parent
    whose centroid is closest in great-circle (haversine) distance.

    Parameters
    ----------
    child_mask : np.ndarray
        1D boolean array marking the cells of the child object
    parent_centroids : np.ndarray
        Array of shape (n_parents, 2) with (lat, lon) centroid coordinates in
        degrees
    child_ids : np.ndarray
        IDs to hand out; ``child_ids[k]`` goes to cells nearest to parent ``k``
    lat, lon : np.ndarray
        Cell latitude/longitude arrays in degrees

    Returns
    -------
    new_labels : np.ndarray
        1D array with the same length as ``child_mask`` and the dtype of
        ``child_ids``. Cells inside the child mask hold their assigned ID;
        cells outside it are left at 0.
    """
    total_cells = len(child_mask)
    parent_count = len(parent_centroids)

    # Work in radians throughout
    cell_lat = np.deg2rad(lat)
    cell_lon = np.deg2rad(lon)
    centroid_rad = np.deg2rad(parent_centroids)

    labels = np.zeros(total_cells, dtype=child_ids.dtype)

    # Cells are independent of one another, so the outer loop is parallel
    for cell in prange(total_cells):
        if child_mask[cell]:
            nearest = 0
            nearest_dist = np.inf

            for parent in range(parent_count):
                centroid_lat = centroid_rad[parent, 0]
                delta_lat = centroid_lat - cell_lat[cell]
                delta_lon = centroid_rad[parent, 1] - cell_lon[cell]

                # Haversine formula; the result is the central angle in radians
                hav = np.sin(delta_lat / 2) ** 2 + np.cos(cell_lat[cell]) * np.cos(centroid_lat) * np.sin(delta_lon / 2) ** 2
                angle = 2 * np.arctan2(np.sqrt(hav), np.sqrt(1 - hav))

                # Strict comparison keeps the lowest parent index on ties
                if angle < nearest_dist:
                    nearest_dist = angle
                    nearest = parent

            labels[cell] = child_ids[nearest]

    return labels
@njit(fastmath=True, parallel=True)
def sparse_bool_power(
    vec: NDArray[np.bool_],
    sp_data: NDArray[np.bool_],
    indices: NDArray[np.int32],
    indptr: NDArray[np.int32],
    exponent: int,
) -> NDArray[np.bool_]:  # pragma: no cover
    """
    Repeatedly apply a boolean CSR matrix to a boolean array.

    Each application is a boolean matrix product: output entry (i, k) becomes
    True when row ``i`` of the matrix has a True stored value in some column
    ``c`` for which the current state is True at (c, k). The original module
    documentation describes this as a replacement for scipy+Dask sparse powers
    used in morphological operations on unstructured grids.

    Parameters
    ----------
    vec : np.ndarray
        2D boolean array to multiply; it is transposed on entry, so its last
        axis must match the matrix dimension
    sp_data, indices, indptr : np.ndarray
        The sparse matrix in CSR form (stored values, column indices, row
        pointers)
    exponent : int
        Number of times to apply the matrix

    Returns
    -------
    np.ndarray
        Result of (sparse_matrix ^ exponent) * vec, transposed back to the
        orientation of ``vec``
    """
    state = vec.T
    n_rows = indptr.size - 1
    n_cols = state.shape[1]

    current = state.copy()

    for _ in range(exponent):
        following = np.zeros((n_rows, n_cols), dtype=np.bool_)

        # Each output row is written by exactly one iteration
        for row in prange(n_rows):
            row_start = indptr[row]
            row_stop = indptr[row + 1]

            for entry in range(row_start, row_stop):
                # Explicitly stored False values contribute nothing
                if not sp_data[entry]:
                    continue

                source_row = indices[entry]
                for col in range(n_cols):
                    if current[source_row, col]:
                        following[row, col] = True

        current = following

    return current.T
def regional_tracker(
    data_bin: xr.DataArray,
    mask: xr.DataArray,
    coordinate_units: Literal["degrees", "radians"],
    R_fill: Union[int, float],
    area_filter_quartile: Optional[float] = None,
    area_filter_absolute: Optional[int] = None,
    **kwargs,
) -> "tracker":
    """
    Build a ``tracker`` pre-configured for regional (non-global) data.

    This is a thin convenience wrapper: it forwards every argument to
    ``tracker`` and additionally forces ``regional_mode=True``. Unlike the
    general constructor it makes ``coordinate_units`` a required argument,
    since (per the upstream documentation) unit auto-detection may fail for
    regional coordinate ranges.

    Parameters
    ----------
    data_bin : xr.DataArray
        Binary data to identify and track objects in (True = object,
        False = background)
    mask : xr.DataArray
        Binary mask indicating valid regions (True = valid, False = invalid)
    coordinate_units : {'degrees', 'radians'}
        Units of the coordinate system; must be given explicitly
    R_fill : int or float
        Radius for filling holes/gaps in the spatial domain (in grid cells)
    area_filter_quartile : float, optional
        Quantile (0-1) for filtering the smallest objects (e.g. 0.25 removes
        the smallest 25%). Mutually exclusive with ``area_filter_absolute``.
        Passed through as given; the upstream docs state that a default of 0.5
        applies when neither filter is supplied, which is handled by
        ``tracker`` rather than here.
    area_filter_absolute : int, optional
        Minimum area (in grid cells) for an object to be retained. Mutually
        exclusive with ``area_filter_quartile``.
    **kwargs
        Additional parameters passed to the ``tracker`` class. Supplying
        ``regional_mode`` here conflicts with the value this function sets.

    Returns
    -------
    tracker
        Tracker instance constructed with ``regional_mode=True``

    Examples
    --------
    Regional data with degree coordinates and quantile-based filtering:

    >>> import marEx
    >>> regional = marEx.regional_tracker(
    ...     extreme_events,
    ...     mask,
    ...     coordinate_units='degrees',
    ...     R_fill=5,
    ...     area_filter_quartile=0.3
    ... )
    >>> events = regional.run()

    Model output with radian coordinates:

    >>> regional = marEx.regional_tracker(
    ...     extreme_events,
    ...     mask,
    ...     coordinate_units='radians',
    ...     R_fill=8,
    ...     area_filter_quartile=0.5
    ... )
    >>> events = regional.run()

    Absolute area filtering (keep only features larger than 15 grid cells):

    >>> regional = marEx.regional_tracker(
    ...     extreme_events,
    ...     mask,
    ...     coordinate_units='degrees',
    ...     R_fill=5,
    ...     area_filter_absolute=15
    ... )
    >>> events = regional.run()
    """
    # Explicit keywords (rather than a merged dict) so that a clashing entry
    # in **kwargs is rejected by Python instead of being silently overridden
    configured_tracker = tracker(
        regional_mode=True,
        coordinate_units=coordinate_units,
        data_bin=data_bin,
        mask=mask,
        R_fill=R_fill,
        area_filter_quartile=area_filter_quartile,
        area_filter_absolute=area_filter_absolute,
        **kwargs,
    )
    return configured_tracker