# Source code for marEx.detect

"""
MarEx-Detect: Marine Extremes Detection Module

Preprocessing toolkit for marine extremes identification from scalar oceanographic data.
Converts raw time series into standardised anomalies and identifies extreme events
(e.g., Marine Heatwaves using Sea Surface Temperature).

Core capabilities:

* Four anomaly methodologies: Detrended Harmonic, Shifting Baseline, Fixed Baseline and Detrended Fixed Baseline
* Two definitions for extreme events: Global Extreme and Hobday Extreme
* Threshold-based extreme event identification
* Efficient processing of both structured (gridded) and unstructured data

Compatible data formats:

* Structured data: 3D arrays (time, lat, lon)
* Unstructured data: 2D arrays (time, cell)
"""

import logging
import warnings
from typing import Dict, List, Literal, Optional, Tuple

import dask
import flox.xarray
import numpy as np
import pandas as pd
import xarray as xr
from dask import persist
from dask.base import is_dask_collection
from numpy.lib.stride_tricks import sliding_window_view
from numpy.typing import NDArray
from xhistogram.xarray import histogram

# Coordinate validation imports removed
from .exceptions import ConfigurationError, create_data_validation_error
from .helper import checkpoint_to_zarr, fix_dask_tuple_array
from .logging_config import configure_logging, get_logger, log_dask_info, log_memory_usage, log_timing

# Module-level logger, obtained through the package's logging configuration helper
logger = get_logger(__name__)

# Suppress noisy distributed logging: only ERROR and above pass for the shuffle scheduler plugin logger
logging.getLogger("distributed.shuffle._scheduler_plugin").setLevel(logging.ERROR)


# ============================
# Validation Functions
# ============================


def _validate_dimensions_exist(da: xr.DataArray, dimensions: Dict[str, str]) -> None:
    """
    Check that every dimension name referenced by ``dimensions`` is present on ``da``.

    Parameters
    ----------
    da : xarray.DataArray
        Data array whose ``dims`` are inspected
    dimensions : dict
        Mapping of conceptual dimensions (keys) to actual dimension names (values)

    Raises
    ------
    DataValidationError
        If one or more of the mapped dimension names is not in ``da.dims``
    """
    # Collect a human-readable entry for each mapped name that the data lacks
    absent = [f"'{name}' (for {role})" for role, name in dimensions.items() if name not in da.dims]

    if not absent:
        return

    present = list(da.dims)
    raise create_data_validation_error(
        f"Missing required dimensions: {', '.join(absent)}",
        details=f"Dataset has dimensions: {present}",
        suggestions=[
            "Check dimension names in your data",
            "Update the 'dimensions' parameter to match your data structure",
            f"Available dimensions: {present}",
        ],
        data_info={
            "missing_dimensions": absent,
            "available_dimensions": present,
            "provided_dimensions": dimensions,
        },
    )


def _validate_coordinates_exist(da: xr.DataArray, coordinates: Dict[str, str]) -> None:
    """
    Check that every coordinate name referenced by ``coordinates`` is present on ``da``.

    Parameters
    ----------
    da : xarray.DataArray
        Data array whose ``coords`` are inspected
    coordinates : dict
        Mapping of conceptual coordinates (keys) to actual coordinate names (values)

    Raises
    ------
    DataValidationError
        If one or more of the mapped coordinate names is not in ``da.coords``
    """
    # Collect a human-readable entry for each mapped name that the data lacks
    absent = [f"'{name}' (for {role})" for role, name in coordinates.items() if name not in da.coords]

    if not absent:
        return

    present = list(da.coords.keys())
    raise create_data_validation_error(
        f"Missing required coordinates: {', '.join(absent)}",
        details=f"Dataset has coordinates: {present}",
        suggestions=[
            "Check coordinate names in your data",
            "Update the 'coordinates' parameter to match your data structure",
            f"Available coordinates: {present}",
        ],
        data_info={
            "missing_coordinates": absent,
            "available_coordinates": present,
            "provided_coordinates": coordinates,
        },
    )


def _infer_dims_coords(
    da: xr.DataArray, dimensions: Optional[Dict[str, str]], coordinates: Optional[Dict[str, str]]
) -> Tuple[Dict[str, str], Dict[str, str]]:
    """
    Determine full set of dimensions and coordinates for the DataArray.
    Sets default (standard) dimension and coordinate names if unspecified.

    This function ensures the dimensions dictionary includes required keys and coordinates
    are properly set based on data structure. It validates that all specified dimensions
    and coordinates exist in the dataset. The input dictionaries are never modified in
    place; new dictionaries are returned whenever defaults are added.

    Parameters
    ----------
    da : xarray.DataArray
        Input data array to infer dimensions and coordinates for
    dimensions : dict, optional
        Mapping of conceptual dimensions to actual dimension names.
        If None, defaults to ``{"time": "time", "x": "lon", "y": "lat"}``.
        If given without a "time" key, ``"time": "time"`` is added.
    coordinates : dict, optional
        Mapping of conceptual coordinates to actual coordinate names.
        If None and a "y" dimension is present (gridded data), a copy of ``dimensions`` is used.
        If None and no "y" dimension is present (unstructured data), an error is raised.
        If given without a "time" key, the time dimension name is added.

    Returns
    -------
    tuple
        Tuple of (dimensions, coordinates) dictionaries with defaults applied

    Raises
    ------
    DataValidationError
        If coordinates are not given for unstructured data, or if any specified
        dimension or coordinate does not exist in the dataset
    """
    if dimensions is None:
        dimensions = {"time": "time", "x": "lon", "y": "lat"}

    if "time" not in dimensions:
        dimensions = {"time": "time", **dimensions}  # Permit partial default dimensions --> "time"

    # Handle coordinates parameter based on data structure
    if coordinates is None:
        if "y" not in dimensions:
            # Unstructured (2D) data - requires explicit coordinate specification
            logger.error("Coordinates parameter required for unstructured data")

            # Use .get(): a partial mapping without "x" must still produce the validation
            # error below rather than a KeyError while its message is being built.
            x_dim = dimensions.get("x")
            if x_dim is not None:
                x_hint = f"Your x dimension '{x_dim}' needs associated coordinate names"
            else:
                x_hint = "No 'x' dimension was specified - add it to the 'dimensions' parameter"

            raise create_data_validation_error(
                "Coordinates parameter must be explicitly specified for unstructured data",
                details="Unstructured data requires coordinate names for x and y spatial coordinates",
                suggestions=[
                    "Specify coordinates parameter with spatial coordinate names",
                    "Example: coordinates={'time': 'time', 'x': 'lon', 'y': 'lat'}",
                    x_hint,
                    "If data is gridded, ensure 'y' dimension is also specified",
                ],
                data_info={
                    "data_structure": "unstructured (2D)",
                    "dimensions": dimensions,
                    "missing_coordinates": "x and y spatial coordinates",
                },
            )
        else:
            # Gridded (3D) data - copy dimensions to coordinates
            coordinates = dimensions.copy()
            logger.debug("Gridded data detected - copying dimensions to coordinates")
    else:
        # Coordinates provided but ensure time coordinate is included if missing
        if "time" not in coordinates:
            coordinates = {"time": dimensions.get("time", "time"), **coordinates}
            logger.debug("Added default time coordinate to provided coordinates")

    # Validate dimensions and coordinates exist in dataset
    logger.debug("Validating dimensions and coordinates")
    _validate_dimensions_exist(da, dimensions)
    _validate_coordinates_exist(da, coordinates)

    return dimensions, coordinates


def _validate_data_values(da: xr.DataArray, dimensions: Dict[str, str]) -> None:
    """
    Ensure that every location which is finite in the first time step stays finite for all time steps.

    The first time step defines the spatial mask of valid locations. Locations that are
    non-finite there are treated as masked and are ignored by the check.

    Parameters
    ----------
    da : xarray.DataArray
        Input data array to validate
    dimensions : dict
        Mapping of conceptual dimensions to actual dimension names (the "time" entry is used)

    Raises
    ------
    DataValidationError
        If the first time step holds no finite value at all, or if any location that is
        finite in the first time step contains NaN or infinite values at another time step
    """
    time_dim = dimensions["time"]

    # Valid locations are those holding a finite value in the first time step
    valid_location = np.isfinite(da.isel({time_dim: 0}))

    # Nothing to validate against if the whole first time step is invalid
    if not valid_location.any().compute():
        raise create_data_validation_error(
            "Dataset contains no valid (finite) data",
            details="All values in the first time step are NaN or infinite",
            suggestions=[
                "Check your input data for data quality issues",
                "Verify the data was loaded correctly",
                "Check for issues in data preprocessing steps",
            ],
            data_info={
                "total_values": int(da.size),
                "total_spatial_locations": int(np.prod([da.sizes[d] for d in da.dims if d != time_dim])),
            },
        )

    # Reduce over time before masking, so the mask is applied to a purely
    # spatial array and never has to be broadcast along the time dimension
    bad_count_per_location = (~np.isfinite(da)).sum(dim=time_dim)
    bad_count_in_valid = bad_count_per_location.where(valid_location, 0)

    # A single reduction tells whether any valid location is affected
    worst_count = bad_count_in_valid.max().compute()

    if worst_count > 0:
        # Gather the statistics only on the failure path
        n_bad_values = int(bad_count_in_valid.sum().compute())
        n_valid_locations = int(valid_location.sum().compute())
        n_bad_locations = int((bad_count_in_valid > 0).sum().compute())
        n_time_steps = int(da.sizes[time_dim])

        raise create_data_validation_error(
            f"Dataset contains {n_bad_values} invalid values in {n_bad_locations} ocean locations",
            details=(
                f"Found invalid data across time series. Worst location has {int(worst_count)} "
                f"invalid time steps out of {n_time_steps}."
            ),
            suggestions=[
                "Remove or interpolate NaN/infinite values before preprocessing",
                "Check data quality and loading procedures",
                "Consider using data.fillna() or data.interpolate_na() methods",
                "Verify coordinate/dimension alignment in your dataset",
                "For ocean data, ensure land mask is properly applied before preprocessing",
            ],
            data_info={
                "total_invalid_values_in_ocean": n_bad_values,
                "locations_affected": n_bad_locations,
                "total_ocean_locations": n_valid_locations,
                "max_invalid_at_one_location": int(worst_count),
                "total_time_steps": n_time_steps,
                "percentage_affected": f"{100.0 * n_bad_locations / n_valid_locations:.2f}%",
            },
        )


# ============================
# Methodology Selection
# ============================


def preprocess_data(
    da: xr.DataArray,
    method_anomaly: Literal[
        "detrend_harmonic", "shifting_baseline", "fixed_baseline", "detrend_fixed_baseline"
    ] = "shifting_baseline",
    method_extreme: Literal["global_extreme", "hobday_extreme"] = "hobday_extreme",
    threshold_percentile: float = 95,
    window_year_baseline: int = 15,  # for shifting_baseline
    smooth_days_baseline: int = 21,  # "
    window_days_hobday: int = 11,  # for hobday_extreme
    window_spatial_hobday: Optional[int] = None,  # "
    std_normalise: bool = False,  # for detrend_harmonic
    detrend_orders: Optional[List[int]] = None,  # "
    force_zero_mean: bool = True,  # "
    reference_period: Optional[Tuple[int, int]] = None,  # for fixed_baseline & detrend_fixed_baseline
    method_percentile: Literal["exact", "approximate"] = "approximate",
    precision: float = 0.01,
    max_anomaly: float = 5.0,
    dask_chunks: Optional[Dict[str, int]] = None,
    dimensions: Optional[Dict[str, str]] = None,
    coordinates: Optional[Dict[str, str]] = None,
    neighbours: Optional[xr.DataArray] = None,
    cell_areas: Optional[xr.DataArray] = None,
    use_temp_checkpoints: bool = False,
    verbose: Optional[bool] = None,
    quiet: Optional[bool] = None,
) -> xr.Dataset:
    """
    Complete preprocessing pipeline for marine extreme event identification.

    Supports separate methods for anomaly computation and extreme identification:

    Anomaly Methods:

    * 'detrend_harmonic': Detrending with harmonics and polynomials
      -- more efficient, but biases statistics
    * 'shifting_baseline': Rolling climatology using previous window_year_baseline years
      -- more "correct", but shortens time series by window_year_baseline years
    * 'fixed_baseline': Daily climatology using full time series
      -- does _not_ remove climate trends !
    * 'detrend_fixed_baseline': Polynomial detrending followed by fixed daily climatology
      -- keeps full time-series of data, but does not account for trends in the timing of
      seasonal transitions (which may appear as extremes)

    Extreme Methods:

    * 'global_extreme': Global-in-time threshold value
    * 'hobday_extreme': Local day-of-year specific thresholds with windowing

    Parameters
    ----------
    da : xarray.DataArray
        Raw input data. Must be Dask-backed (chunked).
    method_anomaly : str, default='shifting_baseline'
        Anomaly computation method ('detrend_harmonic', 'shifting_baseline', 'fixed_baseline',
        or 'detrend_fixed_baseline').
    method_extreme : str, default='hobday_extreme'
        Extreme identification method ('global_extreme' or 'hobday_extreme').
    threshold_percentile : float, default=95
        Percentile threshold for extreme event detection.
    window_year_baseline : int, default=15
        Number of previous years for rolling climatology (shifting_baseline method only).
        The dataset must span more than this many calendar years, because the first
        window_year_baseline years are removed from the output.
    smooth_days_baseline : int, default=21
        Days for smoothing rolling climatology (shifting_baseline method only).
    window_days_hobday : int, default=11
        Window size for day-of-year threshold calculation (hobday_extreme method only).
    window_spatial_hobday : int, default=None
        Spatial window size (2D centred window) for the day-of-year threshold calculation
        (hobday_extreme method only).
    std_normalise : bool, default=False
        Whether to standardise anomalies by rolling standard deviation (detrend_harmonic only).
    detrend_orders : list, default=[1]
        Polynomial orders for detrending (detrend_harmonic and detrend_fixed_baseline methods).
        Default is 1st order (linear) detrend. `[1,2]` e.g. would use a linear+quadratic detrending.
    force_zero_mean : bool, default=True
        Whether to enforce zero mean in detrended anomalies
        (detrend_harmonic and detrend_fixed_baseline methods).
    reference_period : tuple of (int, int), optional
        Year range (start_year, end_year) inclusive for computing the daily climatology
        (fixed_baseline and detrend_fixed_baseline only). If None (default), uses all available years.
        Anomalies are computed for the full time series regardless.
        Example: reference_period=(1990, 2020) computes the climatology from 1990-2020
        but outputs anomalies for the entire input time range.
    method_percentile : str, default='approximate'
        Method for percentile calculation ('exact' or 'approximate') for both global_extreme &
        hobday_extreme methods.
        N.B.: Using the exact percentile calculation requires both careful/thoughtful chunking &
        sufficient memory, in consideration of the limitations inherent to distributed
        parallel I/O & processing.
    precision : float, default=0.01
        Precision for histogram bins in approximate percentile method.
    max_anomaly : float, default=5.0
        Maximum anomaly value for histogram binning in the approximate percentile method.
    dask_chunks : dict, default={"time": 25}
        Chunking specification for distributed computation. Only the time chunk size is read
        (looked up under the actual time dimension name, then under "time", else 10).
    dimensions : dict, default={"time": "time", "x": "lon", "y": "lat"}
        Mapping of dimensions to names in the data.
    coordinates : dict, optional
        Mapping of coordinates to names in the data. Defaults to the dimensions mapping for
        gridded data; must be given explicitly for unstructured data (no "y" dimension).
    neighbours : xarray.DataArray, optional
        Neighbour connectivity for spatial clustering.
    cell_areas : xarray.DataArray, optional
        Cell areas for weighted spatial statistics.
    use_temp_checkpoints : bool, default=False
        Enable checkpointing to temporary zarr stores to break Dask graph dependencies.
        When True, intermediate results (anomalies, thresholds, extremes) are saved to temporary
        zarr files and immediately reloaded, preventing expensive recomputations.
        Recommended for large datasets on HPC systems where the 2D histogram computation is expensive.
        Temporary files are automatically cleaned up after reloading.
    verbose : bool, default=None
        Enable verbose logging with detailed progress information.
        If None, uses current global logging configuration.
    quiet : bool, default=None
        Enable quiet logging with minimal output (warnings and errors only).
        If None, uses current global logging configuration.
        Note: quiet takes precedence over verbose if both are True.

    Returns
    -------
    xarray.Dataset
        Processed dataset with anomalies and extreme event identification

    Raises
    ------
    DataValidationError
        If the input is not Dask-backed, if dimensions/coordinates are missing, if unmasked
        data contains non-finite values, or if the dataset does not span more than
        window_year_baseline years when using the shifting_baseline method.
    ConfigurationError
        If reference_period is given for a method other than 'fixed_baseline' or
        'detrend_fixed_baseline'.

    Examples
    --------
    Basic usage with gridded SST data for marine heatwave detection:

    >>> import xarray as xr
    >>> import marEx
    >>>
    >>> # Load and chunk SST data
    >>> sst = xr.open_dataset('sst_data.nc', chunks={}).sst.chunk({'time': 30})
    >>>
    >>> # Basic preprocessing with default shifting baseline method
    >>> result = marEx.preprocess_data(sst, threshold_percentile=90)
    >>> print(result)
    <xarray.Dataset>
    Dimensions:         (time: 1461, lat: 180, lon: 360)
    Data variables:
        dat_anomaly     (time, lat, lon) float32 dask.array<chunksize=(30, 180, 360)>
        mask            (lat, lon) bool dask.array<chunksize=(180, 360)>
        extreme_events  (time, lat, lon) bool dask.array<chunksize=(30, 180, 360)>
        thresholds      (lat, lon) float32 dask.array<chunksize=(180, 360)>

    >>> # Check which locations have extreme events
    >>> print(f"Total extreme events: {result.extreme_events.sum().compute()}")
    Total extreme events: 15847

    Using shifting baseline method for more accurate climatology:

    >>> # Requires more than 15 years of data by default
    >>> result_shifting = marEx.preprocess_data(
    ...     sst,
    ...     method_anomaly="shifting_baseline",
    ...     window_year_baseline=10,  # Use shorter window if needed
    ...     smooth_days_baseline=31   # Longer smoothing window
    ... )
    >>> # Note: First 10 years will be removed from output

    Using Hobday extreme method with day-of-year specific thresholds:

    >>> result_hobday = marEx.preprocess_data(
    ...     sst,
    ...     method_extreme="hobday_extreme",
    ...     window_days_hobday=11,  # 11-day window for each day-of-year
    ...     threshold_percentile=95
    ... )
    >>> print(result_hobday.thresholds.dims)
    ('dayofyear', 'lat', 'lon')

    Previous configuration (marEx v2.0 default) with polynomial detrending and standardisation:

    >>> result_advanced = marEx.preprocess_data(
    ...     sst,
    ...     method_anomaly="detrend_harmonic",
    ...     detrend_orders=[1, 2],  # Linear and quadratic trends
    ...     std_normalise=True,     # Standardise by rolling std
    ...     force_zero_mean=True,
    ...     threshold_percentile=95
    ... )
    >>> # Result includes both raw and standardised anomalies
    >>> print('dat_stn' in result_advanced)
    True

    Processing unstructured data:

    >>> # For ICON ocean model data
    >>> icon_sst = xr.open_dataset('icon_sst.nc', chunks={}).to.chunk({'time': 50})
    >>> result_unstructured = marEx.preprocess_data(
    ...     icon_sst,
    ...     dimensions={"x": "ncells"},  # Must specify the name of the spatial dimension
    ...     coordinates={"x": "lon", "y": "lat"},  # Required for unstructured data
    ...     dask_chunks={"time": 50}
    ... )

    Error handling - insufficient data for shifting baseline:

    >>> short_data = sst.isel(time=slice(0, 1000))  # Only ~3 years
    >>> try:
    ...     result = marEx.preprocess_data(
    ...         short_data,
    ...         method_anomaly="shifting_baseline",
    ...         window_year_baseline=15
    ...     )
    ... except ValueError as e:
    ...     print(f"Error: {e}")
    Error: Insufficient data for shifting_baseline method...

    Performance considerations with chunking:

    >>> # For large datasets, adjust chunking for memory management
    >>> large_sst = sst.chunk({"time": 25, "lat": 90, "lon": 180})
    >>> result = marEx.preprocess_data(
    ...     large_sst,
    ...     dask_chunks={"time": 25},
    ...     method_percentile="approximate"  # Use approximate method (Default) for long time-series calculations
    ... )

    Integration with tracking workflow:

    >>> # Preprocess data then track events
    >>> processed = marEx.preprocess_data(sst, threshold_percentile=95)
    >>> tracker = marEx.tracker(
    ...     processed.extreme_events,
    ...     processed.mask,
    ...     R_fill=8,
    ...     area_filter_quartile=0.5
    ... )
    >>> events = tracker.run()
    >>> print(f"Identified {events.event.max().compute()} distinct events")

    Simple fixed baseline approach:

    >>> # Basic daily climatology across all years
    >>> result_fixed = marEx.preprocess_data(
    ...     sst,
    ...     method_anomaly="fixed_baseline",
    ...     threshold_percentile=95
    ... )
    >>> # Uses all available data for climatology computation

    Combined trend removal and fixed climatology:

    >>> # Remove long-term trends then compute daily climatology
    >>> result_combined = marEx.preprocess_data(
    ...     sst,
    ...     method_anomaly="detrend_fixed_baseline",
    ...     detrend_orders=[1],  # Linear trend
    ...     threshold_percentile=95,
    ...     force_zero_mean=True
    ... )
    >>> # Balances trend removal with simple climatology
    """
    # Set default values for mutable parameters
    if detrend_orders is None:
        detrend_orders = [1]
    if dask_chunks is None:
        dask_chunks = {"time": 25}

    # Configure logging if verbose/quiet parameters are provided
    if verbose is not None or quiet is not None:
        configure_logging(verbose=verbose, quiet=quiet)

    # Log preprocessing start with parameters
    logger.info(f"Starting data preprocessing - Method: {method_anomaly} -> {method_extreme}")
    logger.info(f"Parameters: percentile={threshold_percentile}%, method_percentile={method_percentile}")
    logger.debug(
        f"Anomaly method parameters: window_year={window_year_baseline}, smooth_days={smooth_days_baseline}, "
        + f"std_normalise={std_normalise}, detrend_orders={detrend_orders}, force_zero_mean={force_zero_mean}"
    )
    logger.debug(f"Extreme method parameters: window_days_hobday={window_days_hobday}")

    # Log input data info
    log_dask_info(logger, da, "Input data")
    log_memory_usage(logger, "Initial memory state")

    # Infer and validate dimensions and coordinates
    dimensions, coordinates = _infer_dims_coords(da, dimensions, coordinates)

    # Check if input data is dask-backed
    if not is_dask_collection(da.data):
        logger.error("Input DataArray is not Dask-backed - preprocessing requires chunked data")
        raise create_data_validation_error(
            "Input DataArray must be Dask-backed",
            details="Preprocessing requires chunked data for efficient computation",
            suggestions=[
                "Convert to Dask array: da = da.chunk({'time': 30})",
                "Load with chunking: xr.open_dataset('file.nc', chunks={'time': 30})",
            ],
            data_info={"data_type": type(da.data).__name__, "shape": da.shape},
        )

    # Validate reference_period before triggering any computation
    if reference_period is not None and method_anomaly not in ("fixed_baseline", "detrend_fixed_baseline"):
        raise ConfigurationError(
            f"reference_period is not supported for method_anomaly='{method_anomaly}'",
            details="reference_period is only applicable to 'fixed_baseline' and 'detrend_fixed_baseline' methods",
            suggestions=[
                "Remove the reference_period parameter, or",
                "Use method_anomaly='fixed_baseline' or 'detrend_fixed_baseline'",
            ],
        )

    # Validate that all unmasked data is valid (finite values only)
    logger.debug("Validating data values for NaN/infinite values")
    _validate_data_values(da, dimensions)

    # NOTE: this changes the process-wide Dask configuration and is not reverted afterwards
    logger.debug("Enabling Dask large chunk splitting for preprocessing")
    dask.config.set({"array.slicing.split_large_chunks": True})

    # Step 1: Compute anomalies
    with log_timing(
        logger,
        f"Anomaly computation using {method_anomaly} method",
        log_memory=True,
        show_progress=True,
    ):
        logger.debug(
            f"Computing anomalies with parameters: method={method_anomaly}, "
            f"std_normalise={std_normalise}, force_zero_mean={force_zero_mean}"
        )
        ds = compute_normalised_anomaly(
            da.astype(np.float32),
            method_anomaly,
            dimensions,
            coordinates,
            window_year_baseline,
            smooth_days_baseline,
            std_normalise,
            detrend_orders,
            force_zero_mean,
            reference_period,
            use_temp_checkpoints,
        )
        log_memory_usage(logger, "After anomaly computation", logging.DEBUG)

    # For shifting baseline, remove first window_year_baseline years (insufficient climatology data)
    if method_anomaly == "shifting_baseline":
        min_year = int(ds[coordinates["time"]].dt.year.min().values.item())
        max_year = int(ds[coordinates["time"]].dt.year.max().values.item())
        total_years = max_year - min_year + 1

        logger.info(f"Shifting baseline data validation: {total_years} years available ({min_year}-{max_year})")

        # The first window_year_baseline calendar years are discarded below, so the data must
        # span strictly more years than the window; otherwise nothing would remain after
        # trimming (e.g. 2000-2014 with a 15-year window keeps only years >= 2015).
        if total_years <= window_year_baseline:
            logger.error(
                f"Insufficient data: {total_years} years available, "
                f"more than {window_year_baseline} required"
            )
            raise create_data_validation_error(
                "Insufficient data for shifting_baseline method",
                details=(
                    f"Dataset spans {total_years} years but requires more than {window_year_baseline} years "
                    f"(the first {window_year_baseline} years are removed from the output)"
                ),
                suggestions=[
                    "Use more years of data to meet minimum requirement",
                    f"Reduce window_year_baseline parameter (currently {window_year_baseline})",
                    "Consider using detrend_fixed_baseline or detrend_harmonic method instead",
                ],
                data_info={
                    "available_years": int(total_years),
                    "required_years": int(window_year_baseline) + 1,
                },
            )

        start_year = int(min_year + window_year_baseline)
        logger.info(f"Trimming data to start from {start_year} (removing first {window_year_baseline} years)")
        time_sel = (ds[coordinates["time"]].dt.year >= start_year).compute()
        ds = ds.isel({dimensions["time"]: time_sel})

    # Break graph after expensive anomaly computation
    # Only needed for shifting baseline, because this is the most expensive method
    # Other methods result in an odd chunking structure that cannot be checkpointed easily
    if use_temp_checkpoints and method_anomaly == "shifting_baseline":
        logger.debug("Checkpointing anomaly dataset to break graph dependencies")
        ds = checkpoint_to_zarr(ds, name="anomalies", timedim=dimensions["time"])

    anomalies = ds.dat_anomaly

    # Step 2: Identify extreme events (both methods now return consistent tuple structures)
    with log_timing(
        logger,
        f"Extreme event identification using {method_extreme} method",
        log_memory=True,
        show_progress=True,
    ):
        logger.debug(
            f"Identifying extremes with parameters: method={method_extreme}, "
            f"percentile={threshold_percentile}%, method_percentile={method_percentile}"
        )
        extremes, thresholds = identify_extremes(
            anomalies,
            method_extreme,
            threshold_percentile,
            dimensions,
            coordinates,
            window_days_hobday,
            window_spatial_hobday,
            method_percentile,
            precision,
            max_anomaly,
            use_temp_checkpoints,
        )
        log_memory_usage(logger, "After extreme identification", logging.DEBUG)

    # Add extreme events and thresholds to dataset
    # (persisted together so both results share a single graph execution)
    extremes, thresholds = persist(extremes, thresholds)
    ds["extreme_events"] = extremes
    ds["thresholds"] = thresholds

    # Handle standardised anomalies if requested (only for detrend_harmonic)
    if std_normalise and method_anomaly == "detrend_harmonic":
        logger.info("Processing standardised anomalies for extreme identification")
        with log_timing(
            logger,
            "Standardised extreme identification",
            log_memory=True,
            show_progress=True,
        ):
            extremes_stn, thresholds_stn = identify_extremes(
                ds.dat_stn,
                method_extreme,
                threshold_percentile,
                dimensions,
                coordinates,
                window_days_hobday,
                window_spatial_hobday,
                method_percentile,
                precision,
                max_anomaly,
                use_temp_checkpoints,
            )

        # Break graph after standardised extremes computation
        if use_temp_checkpoints:
            logger.debug("Checkpointing standardised extremes and thresholds to break graph dependencies")
            extremes_stn = checkpoint_to_zarr(extremes_stn, name="extremes_stn", timedim=dimensions["time"])
            # NOTE(review): "dayofyear" is only chunked as a dimension for hobday_extreme further
            # below - confirm checkpoint_to_zarr tolerates timedim="dayofyear" with global_extreme.
            thresholds_stn = checkpoint_to_zarr(thresholds_stn, name="thresholds_stn", timedim="dayofyear")

        ds["extreme_events_stn"] = extremes_stn
        ds["thresholds_stn"] = thresholds_stn

    # Add optional spatial metadata
    if neighbours is not None:
        logger.debug("Adding neighbour connectivity data")
        chunk_dict = {dim: -1 for dim in neighbours.dims}
        ds["neighbours"] = neighbours.astype(np.int32).chunk(chunk_dict)
        if "nv" in neighbours.dims:
            ds = ds.assign_coords(nv=neighbours.nv)

    if cell_areas is not None:
        logger.debug("Adding cell area data")
        chunk_dict = {dim: -1 for dim in cell_areas.dims}
        ds["cell_areas"] = cell_areas.astype(np.float32).chunk(chunk_dict)

    # Add processing parameters to metadata
    ds.attrs.update(
        {
            "method_anomaly": method_anomaly,
            "method_extreme": method_extreme,
            "threshold_percentile": threshold_percentile,
            "preprocessing_steps": _get_preprocessing_steps(
                method_anomaly,
                method_extreme,
                std_normalise,
                detrend_orders,
                window_year_baseline,
                smooth_days_baseline,
                window_days_hobday,
                window_spatial_hobday,
                reference_period,
            ),
        }
    )

    # Add method-specific parameters
    if method_anomaly == "detrend_harmonic":
        ds.attrs.update(
            {
                "detrend_orders": detrend_orders,
                "force_zero_mean": force_zero_mean,
                "std_normalise": std_normalise,
            }
        )
    elif method_anomaly == "shifting_baseline":
        ds.attrs.update(
            {
                "window_year_baseline": window_year_baseline,
                "smooth_days_baseline": smooth_days_baseline,
            }
        )
    elif method_anomaly == "fixed_baseline":
        attrs = {}
        if reference_period is not None:
            # Stored as a list rather than a tuple
            attrs["reference_period"] = list(reference_period)
        ds.attrs.update(attrs)
    elif method_anomaly == "detrend_fixed_baseline":
        attrs = {
            "detrend_orders": detrend_orders,
            "force_zero_mean": force_zero_mean,
        }
        if reference_period is not None:
            attrs["reference_period"] = list(reference_period)
        ds.attrs.update(attrs)

    if method_extreme == "hobday_extreme":
        ds.attrs.update({"window_days_hobday": window_days_hobday})

    ds.attrs.update({"method_percentile": method_percentile, "precision": precision, "max_anomaly": max_anomaly})

    # Final rechunking: spatial dimensions in a single chunk, time (and dayofyear) chunked
    time_chunks = dask_chunks.get(dimensions["time"], dask_chunks.get("time", 10))
    logger.debug(f"Final rechunking with time chunks: {time_chunks}")
    chunk_dict = {dimensions[dim]: -1 for dim in ["x", "y"] if dim in dimensions}
    chunk_dict[dimensions["time"]] = time_chunks
    if method_extreme == "hobday_extreme":
        chunk_dict["dayofyear"] = time_chunks
    ds = ds.chunk(chunk_dict)

    # Clear encoding metadata that may conflict with actual Dask chunks
    # This encoding carries over from checkpointing and can cause chunk misalignment errors
    logger.debug("Clearing encoding metadata for Dask-backed variables")
    for var in ds.data_vars:
        if hasattr(ds[var].data, "chunks"):  # Only for Dask-backed variables
            if hasattr(ds[var], "encoding") and "chunks" in ds[var].encoding:
                del ds[var].encoding["chunks"]

    # Fix encoding issue with saving when calendar & units attribute is present
    if "calendar" in ds[coordinates["time"]].attrs:  # pragma: no cover
        logger.debug("Removing calendar attribute for Zarr compatibility")
        del ds[coordinates["time"]].attrs["calendar"]
    if "units" in ds[coordinates["time"]].attrs:  # pragma: no cover
        logger.debug("Removing units attribute for Zarr compatibility")
        del ds[coordinates["time"]].attrs["units"]

    logger.info("Persisting final dataset and optimising task graph")
    with log_timing(
        logger,
        "Dataset persistence and optimisation",
        log_memory=True,
        show_progress=True,
    ):
        ds = ds.persist(optimize_graph=True)
        ds["thresholds"] = ds.thresholds.compute()  # Patch for a dask-Zarr bug that has problems saving this data array...
        ds["mask"] = ds.mask.compute()
        ds["dat_anomaly"] = fix_dask_tuple_array(ds.dat_anomaly)

        # Patch for same dask-Zarr bug: materialise *any* remaining Dask-backed coordinates.
        # Auxiliary (non-index) coordinates (e.g. gridded `x`/`y` or unstructured `lon`/`lat`)
        # become Dask-backed via `.chunk()` and otherwise retain tuple chunk references that
        # break distributed serialisation on save. This is seen as an
        # "AttributeError: 'tuple' object has no attribute 'size'" from dask.array.store.
        for coord_name in list(ds.coords):
            if is_dask_collection(ds[coord_name].data):
                ds[coord_name] = ds[coord_name].compute()

        if "neighbours" in ds.data_vars:
            ds["neighbours"] = ds.neighbours.compute()
        if "cell_areas" in ds.data_vars:
            ds["cell_areas"] = ds.cell_areas.compute()

        log_memory_usage(logger, "After dataset persistence", logging.DEBUG)

    # Final success reporting with summary
    # Convert to a plain int so the log line shows the count itself, not a DataArray repr
    extreme_count = int(ds.extreme_events.sum().compute())
    logger.info(f"Preprocessing completed successfully - {extreme_count} extreme events identified")
    logger.debug(f"Final dataset shape: {ds.dims}")
    log_dask_info(logger, ds, "Final preprocessed dataset")

    return ds
def _get_preprocessing_steps( method_anomaly: str, method_extreme: str, std_normalise: bool, detrend_orders: List[int], window_year_baseline: int, smooth_days_baseline: int, window_days_hobday: int, window_spatial_hobday: Optional[int], reference_period: Optional[Tuple[int, int]] = None, ) -> List[str]: """Generate preprocessing steps description based on selected methods.""" steps = [] if method_anomaly == "detrend_harmonic": steps.append(f"Removed polynomial trend orders={detrend_orders} & seasonal cycle") if std_normalise: steps.append("Normalised by 30-day rolling STD") elif method_anomaly == "shifting_baseline": steps.append(f"Rolling climatology using {window_year_baseline} years") steps.append(f"Smoothed with {smooth_days_baseline}-day window") elif method_anomaly == "fixed_baseline": if reference_period is not None: steps.append(f"Daily climatology computed from {reference_period[0]}-{reference_period[1]}") else: steps.append("Daily climatology computed from full time series") elif method_anomaly == "detrend_fixed_baseline": steps.append(f"Removed polynomial trend orders={detrend_orders}") if reference_period is not None: steps.append(f"Daily climatology computed from detrended data ({reference_period[0]}-{reference_period[1]})") else: steps.append("Daily climatology computed from detrended data") # Extreme method steps if method_extreme == "global_extreme": steps.append("Global percentile threshold applied to all days") elif method_extreme == "hobday_extreme": if window_spatial_hobday is not None: steps.append( f"Day-of-year thresholds with {window_days_hobday} day window & {window_spatial_hobday} spatial neighbours" ) else: steps.append(f"Day-of-year thresholds with {window_days_hobday} day window") return steps
def compute_normalised_anomaly(
    da: xr.DataArray,
    method_anomaly: Literal[
        "detrend_harmonic", "shifting_baseline", "fixed_baseline", "detrend_fixed_baseline"
    ] = "shifting_baseline",
    dimensions: Optional[Dict[str, str]] = None,
    coordinates: Optional[Dict[str, str]] = None,
    window_year_baseline: int = 15,  # for shifting_baseline
    smooth_days_baseline: int = 21,  # "
    std_normalise: bool = False,  # for detrend_harmonic
    detrend_orders: Optional[List[int]] = None,  # "
    force_zero_mean: bool = True,  # "
    reference_period: Optional[Tuple[int, int]] = None,  # for fixed_baseline & detrend_fixed_baseline
    use_temp_checkpoints: bool = False,
    verbose: Optional[bool] = None,
    quiet: Optional[bool] = None,
) -> xr.Dataset:
    """
    Generate normalised anomalies using the specified methodology.

    This is a dispatcher: it resolves the dimension/coordinate mappings, checks
    that ``reference_period`` is only combined with a method that uses it, and
    hands the data to the implementation selected by ``method_anomaly``.

    Parameters
    ----------
    da : xarray.DataArray
        Input data with dimensions matching the 'dimensions' parameter
    method_anomaly : str, default='shifting_baseline'
        Anomaly computation method. Options:

        - 'detrend_harmonic': Detrending with harmonics and polynomials (efficient, biased)
        - 'shifting_baseline': Rolling climatology (accurate, shortens time series)
        - 'fixed_baseline': Daily climatology using full time series
          (keeps long-term trends in the anomaly)
        - 'detrend_fixed_baseline': Polynomial detrending + fixed climatology
          (does not shorten time series, keeps trends in seasonal timing in the anomaly)
    dimensions : dict, optional
        Mapping of conceptual dimensions to actual dimension names in the data
    coordinates : dict, optional
        Mapping of conceptual coordinates to actual coordinate names in the data
    window_year_baseline : int, default=15
        Number of years for rolling climatology (shifting_baseline only)
    smooth_days_baseline : int, default=21
        Days for smoothing rolling climatology (shifting_baseline only)
    std_normalise : bool, default=False
        Whether to normalise by 30-day rolling standard deviation (detrend_harmonic only)
    detrend_orders : list, default=[1]
        Polynomial orders for trend removal (detrend_harmonic and detrend_fixed_baseline only).
        ``None`` is replaced by ``[1]``.
    force_zero_mean : bool, default=True
        Explicitly enforce zero mean in final anomalies
        (detrend_harmonic and detrend_fixed_baseline only)
    reference_period : tuple of (int, int), optional
        Year range (start_year, end_year) inclusive for computing the daily climatology
        (fixed_baseline and detrend_fixed_baseline only). If None (default), uses all
        available years. Anomalies are computed for the full time series regardless.
    use_temp_checkpoints : bool, default=False
        Forwarded to the shifting_baseline implementation; the other methods do not receive it
    verbose : bool, optional
        Passed to ``configure_logging`` when either ``verbose`` or ``quiet`` is not None
    quiet : bool, optional
        Passed to ``configure_logging`` when either ``verbose`` or ``quiet`` is not None

    Returns
    -------
    xarray.Dataset
        Dataset containing anomalies, mask, and metadata

    Raises
    ------
    ConfigurationError
        If ``reference_period`` is given with a method other than 'fixed_baseline' or
        'detrend_fixed_baseline', or if ``method_anomaly`` is not a known method

    Examples
    --------
    Basic anomaly computation (default shifting baseline):

    >>> import xarray as xr
    >>> import marEx
    >>>
    >>> # Load chunked SST data
    >>> sst = xr.open_dataset('sst_data.nc', chunks={}).sst.chunk({'time': 30})
    >>>
    >>> # Compute anomalies using shifting baseline (default)
    >>> result = marEx.compute_normalised_anomaly(sst)
    >>> print(result.data_vars)
    Data variables:
        dat_anomaly  (time, lat, lon) float32 dask.array<chunksize=(30, 180, 360)>
        mask         (lat, lon) bool dask.array<chunksize=(180, 360)>

    >>> # Check that anomalies have approximately zero mean
    >>> print(f"Mean anomaly: {result.dat_anomaly.mean().compute():.6f}")
    Mean anomaly: 0.000023

    Previous configuration (marEx v2.0 default) of detrended baseline with higher-order
    polynomials and standardisation. Note: marEx v3.0+ uses shifting_baseline as the
    default method:

    >>> result_advanced = marEx.compute_normalised_anomaly(
    ...     sst,
    ...     method_anomaly="detrend_harmonic",
    ...     detrend_orders=[1, 2, 3],  # Linear, quadratic, cubic trends
    ...     std_normalise=True,        # Add standardised anomalies
    ...     force_zero_mean=True
    ... )
    >>> print(result_advanced.data_vars)
    Data variables:
        dat_anomaly  (time, lat, lon) float32 dask.array<chunksize=(30, 180, 360)>
        mask         (lat, lon) bool dask.array<chunksize=(180, 360)>
        dat_stn      (time, lat, lon) float32 dask.array<chunksize=(30, 180, 360)>
        STD          (dayofyear, lat, lon) float32 dask.array<chunksize=(366, 180, 360)>

    >>> # Standardised anomalies have unit variance
    >>> print(f"STD of standardised anomalies: {result_advanced.dat_stn.std().compute():.3f}")

    Accurate shifting baseline method for climate-aware anomalies:

    >>> result_shifting = marEx.compute_normalised_anomaly(
    ...     sst,
    ...     method_anomaly="shifting_baseline",
    ...     window_year_baseline=10,  # Use 10-year rolling climatology
    ...     smooth_days_baseline=31   # 31-day smoothing window
    ... )
    >>> # Anomalies computed relative to recent past climatology

    Processing unstructured data:

    >>> # ICON ocean model with ncells dimension
    >>> icon_data = xr.open_dataset('icon_sst.nc', chunks={}).to.chunk({'time': 25})
    >>> result_unstructured = marEx.compute_normalised_anomaly(
    ...     icon_data,
    ...     dimensions={"time": "time", "x": "ncells"},
    ...     coordinates={"time": "time", "x": "lon", "y": "lat"},
    ... )
    >>> print(result_unstructured.dims)
    Frozen({'time': 1461, 'ncells': 83886})

    Comparison of methods - detrended vs shifting baseline:

    >>> # Detrended baseline - faster, slight bias
    >>> detrended = marEx.compute_normalised_anomaly(
    ...     sst, method_anomaly="detrend_harmonic"
    ... )
    >>>
    >>> # Shifting baseline - slower, more accurate
    >>> shifting = marEx.compute_normalised_anomaly(
    ...     sst, method_anomaly="shifting_baseline",
    ...     window_year_baseline=15
    ... )
    >>>
    >>> # Compare anomaly magnitudes
    >>> print(f"Detrended RMS: {detrended.dat_anomaly.std().compute():.3f}")
    >>> print(f"Shifting RMS: {shifting.dat_anomaly.std().compute():.3f}")

    Fixed baseline climatology:

    >>> # Use full time series for daily climatology
    >>> result_fixed = marEx.compute_normalised_anomaly(
    ...     sst,
    ...     method_anomaly="fixed_baseline"
    ... )
    >>> # Climatology computed from all available years

    Fixed baseline with a restricted reference period:

    >>> # Compute climatology from 1990-2020 only, but output anomalies for all years
    >>> result_ref = marEx.compute_normalised_anomaly(
    ...     sst,
    ...     method_anomaly="fixed_baseline",
    ...     reference_period=(1990, 2020)
    ... )

    Fixed detrended baseline:

    >>> # Remove long-term trends then compute fixed climatology
    >>> result_fixed_detrended = marEx.compute_normalised_anomaly(
    ...     sst,
    ...     method_anomaly="detrend_fixed_baseline",
    ...     detrend_orders=[1],  # Remove linear trend
    ...     force_zero_mean=True
    ... )
    >>> # Combines trend removal with fixed climatology
    """
    # The signature uses None instead of a mutable list default; resolve it here
    detrend_orders = [1] if detrend_orders is None else detrend_orders

    # Leave the logging configuration alone unless the caller asked to change it
    if not (verbose is None and quiet is None):
        configure_logging(verbose=verbose, quiet=quiet)

    logger.debug(f"Computing normalised anomaly using {method_anomaly} method")

    # Infer and validate dimensions and coordinates
    dimensions, coordinates = _infer_dims_coords(da, dimensions, coordinates)

    # reference_period only has meaning for the two fixed-climatology methods
    methods_with_reference = ("fixed_baseline", "detrend_fixed_baseline")
    if reference_period is not None and method_anomaly not in methods_with_reference:
        raise ConfigurationError(
            f"reference_period is not supported for method_anomaly='{method_anomaly}'",
            details="reference_period is only applicable to 'fixed_baseline' and 'detrend_fixed_baseline' methods",
            suggestions=[
                "Remove the reference_period parameter, or",
                "Use method_anomaly='fixed_baseline' or 'detrend_fixed_baseline'",
            ],
        )

    # Dispatch: every recognised method returns directly from its own branch
    if method_anomaly == "detrend_harmonic":
        logger.debug(
            f"Detrended baseline parameters: std_normalise={std_normalise}, orders={detrend_orders}, zero_mean={force_zero_mean}"
        )
        return _compute_anomaly_detrended(da, std_normalise, detrend_orders, dimensions, coordinates, force_zero_mean)

    if method_anomaly == "shifting_baseline":
        logger.debug(
            f"Shifting baseline parameters: window_years={window_year_baseline}, smooth_days={smooth_days_baseline}"
        )
        return _compute_anomaly_shifting_baseline(
            da, window_year_baseline, smooth_days_baseline, dimensions, coordinates, use_temp_checkpoints
        )

    if method_anomaly == "fixed_baseline":
        logger.debug(f"Fixed baseline parameters: reference_period={reference_period}")
        return _compute_anomaly_fixed_baseline(da, dimensions, coordinates, reference_period)

    if method_anomaly == "detrend_fixed_baseline":
        logger.debug(
            f"Fixed detrended baseline parameters: orders={detrend_orders}, "
            f"zero_mean={force_zero_mean}, reference_period={reference_period}"
        )
        return _compute_anomaly_detrend_fixed_baseline(
            da, detrend_orders, dimensions, coordinates, force_zero_mean, reference_period
        )

    # Nothing matched: report the unknown method
    logger.error(f"Unknown anomaly method: {method_anomaly}")
    raise ConfigurationError(
        f"Unknown anomaly method '{method_anomaly}'",
        details="Invalid method_anomaly parameter",
        suggestions=[
            "Use 'detrend_harmonic' for efficient processing with trend and harmonic removal",
            "Use 'shifting_baseline' for accurate climatology (requires more data)",
            "Use 'fixed_baseline' to remove a single daily climatology across all years "
            "(keeps any long-term trend in the anomaly)",
            "Use 'detrend_fixed_baseline' for trend removal followed by fixed climatology",
        ],
        context={
            "provided_method": method_anomaly,
            "valid_methods": ["detrend_harmonic", "shifting_baseline", "fixed_baseline", "detrend_fixed_baseline"],
        },
    )
def identify_extremes(
    da: xr.DataArray,
    method_extreme: Literal["global_extreme", "hobday_extreme"] = "hobday_extreme",
    threshold_percentile: float = 95,
    dimensions: Optional[Dict[str, str]] = None,
    coordinates: Optional[Dict[str, str]] = None,
    window_days_hobday: int = 11,  # for hobday_extreme
    window_spatial_hobday: Optional[int] = None,  # for hobday_extreme
    method_percentile: Literal["exact", "approximate"] = "approximate",
    precision: float = 0.01,
    max_anomaly: float = 5.0,
    use_temp_checkpoints: bool = False,
    verbose: Optional[bool] = None,
    quiet: Optional[bool] = None,
) -> Tuple[xr.DataArray, xr.DataArray]:
    """
    Identify extreme events exceeding a percentile threshold using specified method.

    All parameter combinations are validated first; the data are then handed to
    the implementation selected by ``method_extreme``.

    Parameters
    ----------
    da : xarray.DataArray
        DataArray containing anomalies
    method_extreme : str, default='hobday_extreme'
        Method for threshold calculation ('global_extreme' or 'hobday_extreme')
    threshold_percentile : float, default=95
        Percentile threshold (e.g., 95 for 95th percentile). Values below 60 are
        rejected when ``method_percentile='approximate'``.
    dimensions : dict, optional
        Mapping of dimensions to names in the data
    coordinates : dict, optional
        Mapping of coordinates to names in the data
    window_days_hobday : int, default=11
        Window for day-of-year threshold (hobday_extreme only). Must be odd.
    window_spatial_hobday : int, default=None
        Window for day-of-year threshold spatial clustering (hobday_extreme only).
        Must be odd. When left as None with 'hobday_extreme' on a structured grid
        (a 'y' dimension present in the data), it is set to 5.
    method_percentile : str, default='approximate'
        Method for percentile computation ('exact' or 'approximate')
    precision : float, default=0.01
        Precision for histogram bins in approximate method
    max_anomaly : float, default=5.0
        Maximum anomaly value for histogram binning
    use_temp_checkpoints : bool, default=False
        Forwarded to the hobday_extreme implementation; global_extreme does not receive it
    verbose : bool, optional
        Passed to ``configure_logging`` when either ``verbose`` or ``quiet`` is not None
    quiet : bool, optional
        Passed to ``configure_logging`` when either ``verbose`` or ``quiet`` is not None

    Returns
    -------
    tuple
        Tuple of (extremes, thresholds) where extremes is a boolean array
        identifying extreme events and thresholds contains the threshold values used

    Raises
    ------
    ConfigurationError
        If ``method_percentile`` or ``method_extreme`` is unknown; if ``precision`` or
        ``max_anomaly`` differ from their defaults with ``method_percentile='exact'``;
        if ``threshold_percentile`` is below 60 with the approximate method; if
        ``window_spatial_hobday`` is given for an unstructured grid, for a method other
        than 'hobday_extreme', or with ``method_percentile='exact'``; or if a Hobday
        window size is even.

    Examples
    --------
    Basic extreme identification with global thresholds:

    >>> import xarray as xr
    >>> import marEx
    >>>
    >>> # Load anomaly data (from compute_normalised_anomaly)
    >>> anomalies = xr.open_dataset('anomalies.nc', chunks={}).dat_anomaly
    >>>
    >>> # Identify extreme events using global-in-time 95th percentile
    >>> extremes, thresholds = marEx.identify_extremes(
    ...     anomalies,
    ...     method_extreme="global_extreme",
    ...     threshold_percentile=95
    ... )
    >>> print(f"Extreme events shape: {extremes.shape}")
    Extreme events shape: (1461, 180, 360)
    >>> print(f"Thresholds shape: {thresholds.shape}")
    Thresholds shape: (180, 360)

    >>> # Count total extreme events
    >>> total_extremes = extremes.sum().compute()
    >>> print(f"Total extreme events: {total_extremes}")

    Using day-of-year specific thresholds (cf. Hobday et al. 2016 method):

    >>> # More sophisticated threshold calculation
    >>> extremes_hobday, thresholds_hobday = marEx.identify_extremes(
    ...     anomalies,
    ...     method_extreme="hobday_extreme",
    ...     threshold_percentile=95,
    ...     window_days_hobday=11,    # 11-day window around each day-of-year
    ...     window_spatial_hobday=3   # 3x3 spatial window for clustering percentile calculation
    ... )
    >>> print(f"Hobday thresholds shape: {thresholds_hobday.shape}")
    Hobday thresholds shape: (366, 180, 360)

    >>> # Compare seasonal variation in thresholds
    >>> summer_threshold = thresholds_hobday.sel(dayofyear=200).mean()
    >>> winter_threshold = thresholds_hobday.sel(dayofyear=50).mean()
    >>> print(f"Summer vs Winter thresholds: {summer_threshold:.3f} vs {winter_threshold:.3f}")

    Comparison of exact vs approximate percentile methods:

    >>> # Approximate method (faster, default)
    >>> extremes_approx, thresh_approx = marEx.identify_extremes(
    ...     anomalies, method_percentile="approximate"
    ... )
    >>>
    >>> # Exact method (slower & memory intensive)
    >>> extremes_exact, thresh_exact = marEx.identify_extremes(
    ...     anomalies, method_percentile="exact"
    ... )
    >>>
    >>> # Compare threshold precision — ~0.005C
    >>> threshold_diff = (thresh_exact - thresh_approx).std().compute()
    >>> print(f"Threshold difference (exact vs approx): {threshold_diff:.6f}")

    Different percentile thresholds for varying event rarity:

    >>> # Conservative threshold - very extreme events only
    >>> extremes_99, _ = marEx.identify_extremes(
    ...     anomalies, threshold_percentile=99
    ... )
    >>>
    >>> # Moderate threshold - more frequent events
    >>> extremes_90, _ = marEx.identify_extremes(
    ...     anomalies, threshold_percentile=90
    ... )
    >>>
    >>> # Compare event frequency
    >>> print(f"99th percentile events: {extremes_99.sum().compute()}")
    >>> print(f"90th percentile events: {extremes_90.sum().compute()}")

    Processing unstructured data:

    >>> # ICON ocean model data
    >>> icon_anomalies = xr.open_dataset('icon_anomalies.nc', chunks={}).dat_anomaly
    >>> extremes_unstructured, thresholds_unstructured = marEx.identify_extremes(
    ...     icon_anomalies,
    ...     dimensions={"time": "time", "x": "ncells"},
    ...     coordinates={"time": "time", "x": "lon", "y": "lat"},
    ...     threshold_percentile=95
    ... )
    >>> print(f"Unstructured extremes shape: {extremes_unstructured.shape}")

    Advanced Hobday method with custom temporal window:

    >>> # Longer temporal window for smoother thresholds
    >>> extremes_smooth, thresholds_smooth = marEx.identify_extremes(
    ...     anomalies,
    ...     method_extreme="hobday_extreme",
    ...     window_days_hobday=31,  # Longer smoothing window
    ...     threshold_percentile=95
    ... )
    >>>
    >>> # Compare threshold smoothness
    >>> std_11day = thresholds_hobday.std(dim='dayofyear').mean().compute()
    >>> std_31day = thresholds_smooth.std(dim='dayofyear').mean().compute()
    >>> print(f"Threshold variability: 11-day={std_11day:.3f}, 31-day={std_31day:.3f}")
    """
    # Configure logging if verbose/quiet parameters are provided
    if verbose is not None or quiet is not None:
        configure_logging(verbose=verbose, quiet=quiet)

    logger.debug(f"Identifying extremes using {method_extreme} method - {threshold_percentile}th percentile")

    # Infer and validate dimensions and coordinates
    dimensions, coordinates = _infer_dims_coords(da, dimensions, coordinates)

    # A grid counts as structured when a 'y' dimension is mapped and present in the data
    is_structured = "y" in dimensions and dimensions["y"] in da.dims

    # Validate method_percentile parameter
    valid_methods = ["exact", "approximate"]
    if method_percentile not in valid_methods:
        logger.error(f"Unknown method_percentile: {method_percentile}")
        raise ConfigurationError(
            f"Unknown method_percentile '{method_percentile}'",
            details="Invalid method_percentile parameter",
            suggestions=[
                "Use 'exact' for precise percentile computation (memory intensive)",
                "Use 'approximate' for efficient histogram-based computation (default)",
            ],
            context={
                "provided_method": method_percentile,
                "valid_methods": valid_methods,
            },
        )

    # Validate parameter compatibility for exact percentile method
    if method_percentile == "exact":
        default_precision = 0.01
        default_max_anomaly = 5.0

        # Check if precision parameter was explicitly set to a non-default value
        if precision != default_precision:
            logger.error(f"Invalid parameter: precision={precision} with method_percentile='exact'")
            raise ConfigurationError(
                "Parameter 'precision' cannot be used with method_percentile='exact'",
                details=(
                    f"The precision parameter (precision={precision}) is only used by the approximate "
                    "histogram method and is ignored when using exact percentile computation"
                ),
                suggestions=[
                    "Remove the 'precision' parameter when using method_percentile='exact'",
                    "Use method_percentile='approximate' if you want to control histogram precision",
                ],
                context={
                    "method_percentile": method_percentile,
                    "provided_precision": precision,
                    "default_precision": default_precision,
                },
            )

        # Check if max_anomaly parameter was explicitly set to a non-default value
        if max_anomaly != default_max_anomaly:
            logger.error(f"Invalid parameter: max_anomaly={max_anomaly} with method_percentile='exact'")
            raise ConfigurationError(
                "Parameter 'max_anomaly' cannot be used with method_percentile='exact'",
                details=(
                    f"The max_anomaly parameter (max_anomaly={max_anomaly}) is only used by the approximate "
                    "histogram method and is ignored when using exact percentile computation"
                ),
                suggestions=[
                    "Remove the 'max_anomaly' parameter when using method_percentile='exact'",
                    "Use method_percentile='approximate' if you want to control histogram binning range",
                ],
                context={
                    "method_percentile": method_percentile,
                    "provided_max_anomaly": max_anomaly,
                    "default_max_anomaly": default_max_anomaly,
                },
            )

    # Validate percentile parameter when using approximate method
    if threshold_percentile < 60 and method_percentile == "approximate":
        logger.error(f"Invalid percentile threshold: {threshold_percentile}% with method_percentile='approximate'")
        raise ConfigurationError(
            f"Percentile threshold {threshold_percentile}% is not supported with method_percentile='approximate'",
            details=(
                "Low percentile thresholds (<60%) produce undefined and unsupported behaviour "
                "when using approximate histogram methods"
            ),
            suggestions=[
                "Use method_percentile='exact' for percentiles below 60%",
                "Use a higher percentile threshold (≥60%) with method_percentile='approximate'",
                "Consider if such low percentiles are appropriate for extreme event identification",
            ],
            context={
                "threshold_percentile": threshold_percentile,
                "method_percentile": method_percentile,
                "min_supported_percentile": 60,
            },
        )

    # Validate window_spatial_hobday parameter (only when the caller supplied one)
    if window_spatial_hobday is not None:
        # Check if window_spatial_hobday is specified for unstructured grid
        if not is_structured:
            logger.error(f"window_spatial_hobday={window_spatial_hobday} specified for unstructured grid")
            raise ConfigurationError(
                "window_spatial_hobday is not supported for unstructured grids",
                details=(
                    "Spatial smoothing with window_spatial_hobday requires structured grids with both x and y dimensions. "
                    "Unstructured grids do not support spatial window operations due to computational and memory "
                    "limitations of the algorithms."
                ),
                suggestions=[
                    "Remove the window_spatial_hobday parameter for unstructured grids",
                    "Use structured grid data if spatial smoothing is required",
                    "Set window_spatial_hobday=None to use default behavior",
                ],
                context={
                    "grid_type": "unstructured",
                    "window_spatial_hobday": window_spatial_hobday,
                    "dimensions": dimensions,
                    "available_dims": list(da.dims),
                },
            )

        # Check if window_spatial_hobday is specified when hobday_extreme is not used
        if method_extreme != "hobday_extreme":
            logger.error(f"window_spatial_hobday={window_spatial_hobday} specified with method_extreme='{method_extreme}'")
            raise ConfigurationError(
                "window_spatial_hobday can only be used with method_extreme='hobday_extreme'",
                details=(
                    "The window_spatial_hobday parameter is only implemented for the Hobday extreme method. "
                    "Other extreme methods do not support spatial smoothing due to computational and memory "
                    "limitations of the algorithms."
                ),
                suggestions=[
                    "Remove the window_spatial_hobday parameter when using method_extreme='global_extreme'",
                    "Use method_extreme='hobday_extreme' if spatial smoothing is required",
                    "Set window_spatial_hobday=None to use default behavior",
                ],
                context={
                    "method_extreme": method_extreme,
                    "window_spatial_hobday": window_spatial_hobday,
                    "compatible_methods": ["hobday_extreme"],
                },
            )

        # Check if window_spatial_hobday is specified when method_percentile is "exact"
        if method_percentile == "exact":
            logger.error(f"window_spatial_hobday={window_spatial_hobday} specified with method_percentile='exact'")
            raise ConfigurationError(
                "window_spatial_hobday is not supported with method_percentile='exact'",
                details=(
                    "The window_spatial_hobday parameter is only implemented for the approximate percentile method. "
                    "Exact percentile computation does not support spatial smoothing due to computational and memory "
                    "limitations of the algorithms."
                ),
                suggestions=[
                    "Remove the window_spatial_hobday parameter when using method_percentile='exact'",
                    "Use method_percentile='approximate' if spatial smoothing is required",
                    "Set window_spatial_hobday=None to use default behavior",
                ],
                context={
                    "method_percentile": method_percentile,
                    "window_spatial_hobday": window_spatial_hobday,
                    "compatible_methods": ["approximate"],
                },
            )

    # Validate that window parameters are odd numbers (only for hobday_extreme method)
    if method_extreme == "hobday_extreme" and window_days_hobday is not None and window_days_hobday % 2 == 0:
        logger.error(f"window_days_hobday={window_days_hobday} is not an odd number")
        raise ConfigurationError(
            "window_days_hobday must be an odd number",
            details=(
                "Window parameters require odd numbers to ensure symmetric windows around a central point. "
                f"window_days_hobday={window_days_hobday} is even, which would create asymmetric temporal windows."
            ),
            suggestions=[
                f"Use window_days_hobday={window_days_hobday + 1} or {window_days_hobday - 1}",
                "Choose an odd number",
            ],
            context={
                "window_days_hobday": window_days_hobday,
                "is_odd": False,
            },
        )

    # Set default spatial window (only for hobday_extreme method)
    # NOTE(review): this default is also applied when method_percentile='exact', although an
    # explicit window_spatial_hobday is rejected for 'exact' above — confirm that the exact
    # branch of _identify_extremes_hobday handles a non-None spatial window as intended.
    if method_extreme == "hobday_extreme" and window_spatial_hobday is None and is_structured:
        window_spatial_hobday = 5  # Default to 5x5 spatial window for structured grids

    if method_extreme == "hobday_extreme" and window_spatial_hobday is not None and window_spatial_hobday % 2 == 0:
        logger.error(f"window_spatial_hobday={window_spatial_hobday} is not an odd number")
        raise ConfigurationError(
            "window_spatial_hobday must be an odd number",
            details=(
                "Window parameters require odd numbers to ensure symmetric windows around a central point. "
                f"window_spatial_hobday={window_spatial_hobday} is even, which would create asymmetric spatial windows."
            ),
            suggestions=[
                # Suggest the neighbouring odd values of the parameter that actually failed
                f"Use window_spatial_hobday={window_spatial_hobday + 1} or {window_spatial_hobday - 1}",
                "Choose an odd number.",
            ],
            context={
                "window_spatial_hobday": window_spatial_hobday,
                "is_odd": False,
            },
        )

    if method_extreme == "global_extreme":
        logger.debug(f"Global extreme method - method_percentile={method_percentile}")
        return _identify_extremes_constant(da, threshold_percentile, method_percentile, dimensions, precision, max_anomaly)
    elif method_extreme == "hobday_extreme":
        logger.debug(f"Hobday extreme method - window_days={window_days_hobday}, method_percentile={method_percentile}")
        return _identify_extremes_hobday(
            da,
            threshold_percentile,
            window_days_hobday,
            window_spatial_hobday,
            method_percentile,
            dimensions,
            coordinates,
            precision,
            max_anomaly,
            use_temp_checkpoints,
        )
    else:
        logger.error(f"Unknown extreme method: {method_extreme}")
        raise ConfigurationError(
            f"Unknown extreme method '{method_extreme}'",
            details="Invalid method_extreme parameter",
            suggestions=[
                "Use 'global_extreme' for efficient constant percentile threshold",
                "Use 'hobday_extreme' for day-of-year specific thresholds",
            ],
            context={
                "provided_method": method_extreme,
                "valid_methods": ["global_extreme", "hobday_extreme"],
            },
        )
# =============================================== # Shifting Baseline Anomaly Method (New Method) # ===============================================
def rolling_climatology(
    da: xr.DataArray,
    window_year_baseline: int = 15,
    dimensions: Optional[Dict[str, str]] = None,
    coordinates: Optional[Dict[str, str]] = None,
    use_temp_checkpoints: bool = False,
) -> xr.DataArray:
    """
    Compute rolling climatology efficiently using flox cohorts.

    Uses the previous `window_year_baseline` years of data and reassembles it to match
    the original data structure. Years without enough previous data will be filled with NaN.

    Parameters
    ----------
    da : xarray.DataArray
        Input data with time coordinate
    window_year_baseline : int, default=15
        Number of years to include in each climatology window
    dimensions : dict, optional
        Mapping of dimensions to names in the data
    coordinates : dict, optional
        Mapping of coordinates to names in the data
    use_temp_checkpoints : bool, default=False
        When True, the intermediate (target_year, dayofyear) climatologies are passed
        through ``checkpoint_to_zarr`` before being mapped back onto the time axis

    Returns
    -------
    xarray.DataArray
        Rolling climatology with same shape as input data

    Examples
    --------
    Basic rolling climatology computation:

    >>> import xarray as xr
    >>> import marEx
    >>>
    >>> # Load 20 years of SST data
    >>> sst = xr.open_dataset('sst_data.nc', chunks={}).sst.chunk({'time': 30})
    >>>
    >>> # Compute 15-year rolling climatology
    >>> climatology = marEx.rolling_climatology(sst, window_year_baseline=15)
    >>> print(climatology.shape)
    (7305, 180, 360)  # Same as input
    >>>
    >>> # First 15 years will be NaN (insufficient history)
    >>> print(f"NaN values in first year: {climatology.isel(time=slice(0, 365)).isnull().all().compute()}")
    True

    Shorter window for datasets with limited time span:

    >>> # For datasets with only 10 years, use shorter window
    >>> short_climatology = marEx.rolling_climatology(
    ...     sst, window_year_baseline=5
    ... )
    >>> # First 5 years will be NaN instead of 15

    Processing unstructured data:

    >>> # ICON ocean model data
    >>> icon_sst = xr.open_dataset('icon_sst.nc', chunks={}).to.chunk({'time': 25})
    >>> icon_climatology = marEx.rolling_climatology(
    ...     icon_sst,
    ...     dimensions={"time": "time", "x": "ncells"},
    ...     coordinates={"time": "time", "x": "lon", "y": "lat"}
    ... )
    >>> print(icon_climatology.dims)
    Frozen({'time': 7305, 'ncells': 83886})

    Comparing with fixed climatology:

    >>> # Fixed climatology (traditional approach)
    >>> fixed_clim = sst.groupby(sst.time.dt.dayofyear).mean()
    >>>
    >>> # Rolling climatology (adaptive approach)
    >>> rolling_clim = marEx.rolling_climatology(sst)
    >>>
    >>> # Rolling climatology adapts to climate change
    >>> clim_2000 = rolling_clim.sel(time='2000').mean()
    >>> clim_2020 = rolling_clim.sel(time='2020').mean()
    >>> print(f"Climate change signal: {(clim_2020 - clim_2000).compute():.3f} °C")

    Memory considerations for large datasets:

    >>> # Ensure appropriate chunking for memory efficiency
    >>> large_sst = sst.chunk({'time': 30, 'lat': 45, 'lon': 90})
    >>> large_climatology = marEx.rolling_climatology(large_sst)
    >>> # Output maintains input chunking structure
    """
    # Infer and validate dimensions and coordinates
    dimensions, coordinates = _infer_dims_coords(da, dimensions, coordinates)
    timedim = dimensions["time"]
    time_coord = coordinates["time"]

    # Remember the incoming chunk layout so the result can be returned with the same one
    input_chunks = dict(zip(da.dims, da.chunks))

    # Attach calendar year and day-of-year as auxiliary coordinates
    year_coord = da[time_coord].dt.year
    doy_coord = da[time_coord].dt.dayofyear
    da = da.assign_coords({"year": year_coord, "dayofyear": doy_coord})

    # Materialise the calendar information as plain numpy arrays
    year_coord, doy_coord = persist(year_coord, doy_coord)
    year_values = year_coord.values
    doy_values = doy_coord.values
    all_years = np.unique(year_values)
    earliest_year = int(all_years.min().item())

    # Target years before this one do not have a full window of history behind them
    first_full_target = earliest_year + window_year_baseline

    # Long-form grouping: one entry per (time step, target year it contributes to).
    # A time step from year Y feeds target years T with Y < T <= Y + window_year_baseline.
    contrib_positions = []
    contrib_targets = []
    contrib_doys = []
    for position, (raw_year, raw_doy) in enumerate(zip(year_values, doy_values)):
        # Plain Python ints avoid numpy dtype surprises in the comparisons below
        source_year = int(raw_year)
        source_doy = int(raw_doy)

        within_window = (all_years > source_year) & (all_years <= source_year + window_year_baseline)
        targets = all_years[within_window]
        targets = targets[targets >= first_full_target]

        n_repeat = len(targets)
        contrib_positions += [position] * n_repeat
        contrib_targets += targets.tolist()
        contrib_doys += [source_doy] * n_repeat

    # Convert to numpy arrays with explicit dtypes
    position_index = np.array(contrib_positions, dtype=np.int32)
    target_year_labels = np.array(contrib_targets, dtype=np.int32)
    dayofyear_labels = np.array(contrib_doys, dtype=np.int32)

    # Gather the contributing time steps onto a dedicated long-form dimension
    long_timedim = f"{timedim}_contrib"
    long_form = da.isel({timedim: position_index}).rename({timedim: long_timedim})

    # Grouping variables live on the long-form dimension
    by_target_year = xr.DataArray(target_year_labels, dims=[long_timedim], name="target_year")
    by_dayofyear = xr.DataArray(dayofyear_labels, dims=[long_timedim], name="dayofyear")

    # Mean over every (target_year, dayofyear) group; groups without data become NaN
    climatologies = flox.xarray.xarray_reduce(
        long_form,
        by_target_year,
        by_dayofyear,
        dim=long_timedim,
        func="nanmean",
        expected_groups=(all_years, np.arange(1, 367, dtype=np.int32)),
        isbin=(False, False),
        dtype=np.float32,
        fill_value=np.nan,
    )
    climatologies = climatologies.chunk({"dayofyear": -1})

    if use_temp_checkpoints:
        logger.debug("Checkpointing climatologies to break graph dependencies")
        climatologies = checkpoint_to_zarr(climatologies, name="climatologies", timedim=timedim)

    # Positional index of each time step's own year within the target_year axis
    year_lookup = pd.Series(range(len(all_years)), index=all_years)
    year_positions = year_lookup[year_values].values

    # Pick, for every time step, the climatology of its (year, day-of-year);
    # day-of-year is 1-based while the dayofyear axis is indexed from 0
    year_indexer = xr.DataArray(year_positions, dims=[timedim])
    doy_indexer = xr.DataArray(doy_values - 1, dims=[timedim])
    mapped = climatologies.isel(target_year=year_indexer, dayofyear=doy_indexer)

    # Drop the grouping coordinates and restore the input chunk layout
    mapped = mapped.drop_vars(["target_year", "dayofyear"])
    return mapped.chunk(input_chunks)
def smoothed_rolling_climatology(
    da: xr.DataArray,
    window_year_baseline: int = 15,
    smooth_days_baseline: int = 21,
    dimensions: Optional[Dict[str, str]] = None,
    coordinates: Optional[Dict[str, str]] = None,
    use_temp_checkpoints: bool = False,
) -> xr.DataArray:
    """
    Compute a smoothed rolling climatology using the previous `window_year_baseline`
    years of data and reassemble it to match the original data structure.
    Years without enough previous data will be filled with NaN.

    The input is smoothed with a centred rolling mean along time first, and the
    rolling climatology is then computed from the smoothed data.

    Parameters
    ----------
    da : xarray.DataArray
        Input data with time coordinate
    window_year_baseline : int, default=15
        Number of years to include in each climatology window
    smooth_days_baseline : int, default=21
        Number of days for temporal smoothing window
    dimensions : dict, optional
        Mapping of dimensions to names in the data
    coordinates : dict, optional
        Mapping of coordinates to names in the data
    use_temp_checkpoints : bool, default=False
        Forwarded unchanged to ``rolling_climatology``

    Returns
    -------
    xarray.DataArray
        Smoothed rolling climatology with same shape as input data

    Examples
    --------
    Basic smoothed rolling climatology:

    >>> import xarray as xr
    >>> import marEx
    >>>
    >>> # Load SST data
    >>> sst = xr.open_dataset('sst_data.nc', chunks={}).sst.chunk({'time': 30})
    >>>
    >>> # Compute smoothed rolling climatology
    >>> smooth_clim = marEx.smoothed_rolling_climatology(
    ...     sst,
    ...     window_year_baseline=15,
    ...     smooth_days_baseline=21
    ... )
    >>> print(smooth_clim.shape)
    (7305, 180, 360)

    Comparing different smoothing windows:

    >>> # Short smoothing - more day-to-day variability
    >>> clim_short = marEx.smoothed_rolling_climatology(
    ...     sst, smooth_days_baseline=7
    ... )
    >>>
    >>> # Long smoothing - smoother seasonal cycle
    >>> clim_long = marEx.smoothed_rolling_climatology(
    ...     sst, smooth_days_baseline=61
    ... )
    >>>
    >>> # Compare variability
    >>> var_short = clim_short.std(dim='time').mean().compute()
    >>> var_long = clim_long.std(dim='time').mean().compute()
    >>> print(f"Variability: short={var_short:.3f}, long={var_long:.3f}")

    Climatology for anomaly computation:

    >>> # Compute smoothed climatology then anomalies
    >>> climatology = marEx.smoothed_rolling_climatology(sst)
    >>> anomalies = sst - climatology
    >>>
    >>> # Check that anomalies have reasonable properties
    >>> print(f"Anomaly mean: {anomalies.mean().compute():.6f}")
    >>> print(f"Anomaly std: {anomalies.std().compute():.3f}")

    Unstructured data processing:

    >>> # ICON ocean data
    >>> icon_sst = xr.open_dataset('icon_sst.nc', chunks={}).to.chunk({'time': 25})
    >>> icon_smooth_clim = marEx.smoothed_rolling_climatology(
    ...     icon_sst,
    ...     dimensions={"time": "time", "x": "ncells"},
    ...     coordinates={"time": "time", "x": "lon", "y": "lat"},
    ...     window_year_baseline=10,
    ...     smooth_days_baseline=31
    ... )

    Effect of smoothing on seasonal cycle:

    >>> # Raw rolling climatology (no temporal smoothing)
    >>> raw_clim = marEx.rolling_climatology(sst, window_year_baseline=15)
    >>>
    >>> # Smoothed rolling climatology
    >>> smooth_clim = marEx.smoothed_rolling_climatology(
    ...     sst, window_year_baseline=15, smooth_days_baseline=21
    ... )
    >>>
    >>> # Compare seasonal cycle smoothness
    >>> # Extract annual cycle for a point
    >>> point_raw = raw_clim.isel(lat=90, lon=180).sel(time='2010')
    >>> point_smooth = smooth_clim.isel(lat=90, lon=180).sel(time='2010')
    >>>
    >>> print(f"Raw climatology range: {(point_raw.max() - point_raw.min()).compute():.3f}")
    >>> print(f"Smooth climatology range: {(point_smooth.max() - point_smooth.min()).compute():.3f}")

    Performance considerations:

    >>> # Efficient implementation smooths raw data first, then computes climatology
    >>> # This is more memory-efficient than smoothing the climatology
    >>> large_sst = sst.chunk({'time': 25, 'lat': 45, 'lon': 90})
    >>> efficient_clim = marEx.smoothed_rolling_climatology(large_sst)
    """
    # Infer and validate dimensions and coordinates
    dimensions, coordinates = _infer_dims_coords(da, dimensions, coordinates)
    timedim = dimensions["time"]

    # N.B.: It is more efficient (chunking-wise) to smooth the raw data rather than the climatology
    input_chunks = dict(zip(da.dims, da.chunks))
    running_mean = da.rolling({timedim: smooth_days_baseline}, center=True).mean()
    # Restore the input chunk layout after the rolling operation, then store as float32
    da_smoothed = running_mean.chunk(input_chunks).astype(np.float32)

    return rolling_climatology(da_smoothed, window_year_baseline, dimensions, coordinates, use_temp_checkpoints)
def _compute_anomaly_shifting_baseline(
    da: xr.DataArray,
    window_year_baseline: int = 15,
    smooth_days_baseline: int = 21,
    dimensions: Optional[Dict[str, str]] = None,
    coordinates: Optional[Dict[str, str]] = None,
    use_temp_checkpoints: bool = False,
) -> xr.Dataset:
    """
    Compute anomalies using shifting baseline method with smoothed rolling climatology.

    Parameters
    ----------
    da : xarray.DataArray
        Input data with a time dimension
    window_year_baseline : int, default=15
        Baseline window, forwarded unchanged to ``smoothed_rolling_climatology``
    smooth_days_baseline : int, default=21
        Temporal smoothing window, forwarded unchanged to ``smoothed_rolling_climatology``
    dimensions : dict, optional
        Mapping of conceptual dimensions to names in the data (inferred when omitted)
    coordinates : dict, optional
        Mapping of conceptual coordinates to names in the data (inferred when omitted)
    use_temp_checkpoints : bool, default=False
        Forwarded unchanged to ``smoothed_rolling_climatology``

    Returns
    -------
    xarray.Dataset
        Dataset containing anomalies (``dat_anomaly``) and mask (``mask``)
    """
    # Infer and validate dimensions and coordinates
    dimensions, coordinates = _infer_dims_coords(da, dimensions, coordinates)

    # Compute smoothed rolling climatology
    climatology_smoothed = smoothed_rolling_climatology(
        da, window_year_baseline, smooth_days_baseline, dimensions, coordinates, use_temp_checkpoints
    )

    # Compute anomaly as difference from climatology
    anomalies = da - climatology_smoothed

    # Create ocean/land mask from first time step (True where the first step is finite);
    # the scalar time coordinate left behind by isel is dropped so the mask is time-independent
    mask = np.isfinite(da.isel({dimensions["time"]: 0})).drop_vars({coordinates["time"]})

    # Build output dataset
    return xr.Dataset({"dat_anomaly": anomalies, "mask": mask})


# ==========================
# Hobday Extreme Definition
# ==========================


def _identify_extremes_hobday(
    da: xr.DataArray,
    threshold_percentile: float = 95,
    window_days_hobday: int = 11,
    window_spatial_hobday: Optional[int] = None,
    method_percentile: Literal["exact", "approximate"] = "approximate",
    dimensions: Optional[Dict[str, str]] = None,
    coordinates: Optional[Dict[str, str]] = None,
    precision: float = 0.01,
    max_anomaly: float = 5.0,
    use_temp_checkpoints: bool = False,
) -> Tuple[xr.DataArray, xr.DataArray]:
    """
    Identify extreme events using day-of-year (i.e. climatological percentile threshold).

    For each spatial point and day-of-year, computes the p-th percentile of values within
    a window_days_hobday day window across all years. This implements the standard
    methodology for marine heatwave detection threshold calculation.

    Parameters
    ----------
    da : xarray.DataArray
        Anomaly data with dimensions (time, lat, lon)
        Must be chunked with time dimension unbounded (time: -1)
    threshold_percentile : float, default=95
        Percentile to compute (0-100)
    window_days_hobday : int, default=11
        Window in days. The exact method uses ``window_days_hobday // 2`` days on either
        side of each day-of-year, wrapping around the 366-day calendar
    window_spatial_hobday : int, default=None
        Window size in cells. Only forwarded to the approximate (histogram) method;
        the exact method does not use it
    method_percentile : str, default='approximate'
        Method for percentile computation ('exact' or 'approximate')
    dimensions : dict, optional
        Mapping of conceptual dimensions to names in the data (inferred when omitted)
    coordinates : dict, optional
        Mapping of conceptual coordinates to names in the data (inferred when omitted)
    precision : float, default=0.01
        Precision for histogram bins in approximate method
    max_anomaly : float, default=5.0
        Maximum anomaly value for histogram binning
    use_temp_checkpoints : bool, default=False
        If True, thresholds and extremes are each passed through ``checkpoint_to_zarr``

    Returns
    -------
    tuple
        (extreme_bool, thresholds)

        extreme_bool : xarray.DataArray
            Boolean mask indicating extreme events (True for extreme days), i.e. where the
            anomaly is greater than or equal to the threshold of its day-of-year
        thresholds : xarray.DataArray
            Threshold values with dimensions (dayofyear, lat, lon)
    """
    # Callers normally pass already-inferred mappings; infer them only when one is missing
    # so that the documented None defaults do not fail on the first lookup below.
    if dimensions is None or coordinates is None:
        dimensions, coordinates = _infer_dims_coords(da, dimensions, coordinates)

    # Check whether there are sufficient samples above the threshold
    N_years = np.unique(da[coordinates["time"]].dt.year).size
    N_samples = N_years * window_days_hobday * (window_spatial_hobday if window_spatial_hobday is not None else 1) ** 2
    N_above_threshold = N_samples * (1.0 - threshold_percentile / 100.0)
    if N_above_threshold < 50:
        logger.warning(
            f"Not enough samples for accurate extreme detection: {N_above_threshold} < 50. "
            "Consider using a lower threshold_percentile, increasing your time-series size, "
            "increasing the window_days_hobday, or using a larger window_spatial_hobday. "
            "If your time-series is very short, consider using method_percentile='exact'."
        )

    # Add day-of-year coordinate (compute it to avoid chunked groupby issues),
    # keeping the chunk layout of the incoming array
    doy_labels = da[coordinates["time"]].dt.dayofyear.compute()
    input_chunks = dict(zip(da.dims, da.chunks))
    da = da.assign_coords(dayofyear=doy_labels).chunk(input_chunks).persist()

    # Group by day-of-year and compute percentile
    if method_percentile == "exact":
        # Use apply_ufunc to compute DOY percentiles per spatial chunk in pure numpy.
        da_ufunc = da.chunk({dimensions["time"]: -1})
        dayofyear_vals = da_ufunc[coordinates["time"]].dt.dayofyear.values
        half_w = window_days_hobday // 2

        # Pre-compute boolean masks (which time indices contribute to each DOY).
        # The window wraps around the calendar, e.g. DOY 1 also draws from DOY 366.
        offsets = np.arange(-half_w, half_w + 1)
        doy_masks = [np.isin(dayofyear_vals, ((doy - 1 + offsets) % 366) + 1) for doy in range(1, 367)]

        def _doy_percentiles(data, doy_masks, percentile):
            """Compute per-DOY percentiles. data: (*spatial, time) -> (*spatial, 366)."""
            result = np.full(data.shape[:-1] + (366,), np.nan, dtype=np.float32)
            for i, mask in enumerate(doy_masks):
                # DOYs with no contributing time steps stay NaN
                if mask.any():
                    result[..., i] = np.nanpercentile(data[..., mask], percentile, axis=-1)
            return result

        thresholds = xr.apply_ufunc(
            _doy_percentiles,
            da_ufunc,
            input_core_dims=[[dimensions["time"]]],
            output_core_dims=[["dayofyear"]],
            dask="parallelized",
            kwargs={"doy_masks": doy_masks, "percentile": threshold_percentile},
            output_dtypes=[np.float32],
            dask_gufunc_kwargs={"output_sizes": {"dayofyear": 366}},
        )

        # Assign dayofyear coordinate values and move dayofyear to first dimension
        thresholds = thresholds.assign_coords(dayofyear=np.arange(1, 367)).transpose("dayofyear", ...)
    else:
        # Optimised histogram approximation method
        thresholds = _compute_histogram_quantile_2d(
            da,
            threshold_percentile / 100.0,
            window_days_hobday=window_days_hobday,
            window_spatial_hobday=window_spatial_hobday,
            dimensions=dimensions,
            precision=precision,
            max_anomaly=max_anomaly,
            use_temp_checkpoints=use_temp_checkpoints,
        )

    if use_temp_checkpoints:
        logger.debug("Checkpointing thresholds to break graph dependencies")
        thresholds = checkpoint_to_zarr(thresholds, name="thresholds", timedim="dayofyear")

    # Extract spatial chunk sizes from input data for alignment
    # Use most common chunk size to handle irregular chunks robustly
    spatial_chunks = {}
    for dim_key in ["x", "y"]:
        if dim_key in dimensions:
            dim_name = dimensions[dim_key]
            if dim_name in da.dims:
                dim_index = da.dims.index(dim_name)
                chunks_tuple = da.chunks[dim_index]
                # Get the most common chunk size (handles irregular chunks better)
                spatial_chunks[dim_name] = max(set(chunks_tuple), key=chunks_tuple.count)

    # Drop time coordinate/dimension to avoid conflicts when comparing with data grouped by dayofyear
    coords_to_drop = []
    if coordinates["time"] in thresholds.coords:
        coords_to_drop.append(coordinates["time"])
    if dimensions["time"] in thresholds.coords and dimensions["time"] not in thresholds.dims:
        coords_to_drop.append(dimensions["time"])
    if "time" in thresholds.coords and "time" not in thresholds.dims:
        coords_to_drop.append("time")
    if coords_to_drop:
        thresholds = thresholds.drop_vars(coords_to_drop)

    # Rechunk thresholds BEFORE comparison to align with input data
    # This eliminates expensive implicit rechunking during the groupby operation
    logger.debug(f"Aligning threshold chunks to match input data spatial chunks: {spatial_chunks}")
    thresholds = thresholds.chunk(spatial_chunks)

    # Compare anomalies to day-of-year specific thresholds
    # Assign dayofyear coordinate and use UniqueGrouper for chunked arrays
    da = da.assign_coords(dayofyear=da[coordinates["time"]].dt.dayofyear)
    extremes = da.groupby(dayofyear=xr.groupers.UniqueGrouper(labels=np.arange(1, 367))) >= thresholds

    # Drop unnecessary dayofyear coordinate
    if "dayofyear" in extremes.coords:
        extremes = extremes.drop_vars("dayofyear")

    # Rechunk to fix irregular time chunks created by groupby operation
    # Zarr requires uniform chunks, so we rechunk to match input data's time chunks
    time_dim_index = da.dims.index(dimensions["time"])
    time_chunk_size = max(set(da.chunks[time_dim_index]), key=da.chunks[time_dim_index].count)
    rechunk_dict = {dimensions["time"]: time_chunk_size}
    rechunk_dict.update(spatial_chunks)

    logger.debug(f"Rechunking extremes to fix irregular chunks from groupby: {rechunk_dict}")
    extremes = extremes.chunk(rechunk_dict)

    if use_temp_checkpoints:
        logger.debug("Checkpointing extremes to break graph dependencies")
        extremes = checkpoint_to_zarr(extremes, name="extremes", timedim=dimensions["time"])

    return extremes, thresholds


# ===============================================
# Detrended Baseline Anomaly Method (Old Method)
# ===============================================
def add_decimal_year(da: xr.DataArray, dim: str = "time", coord: Optional[str] = None) -> xr.DataArray:
    """
    Attach a 'decimal_year' coordinate (year plus elapsed fraction of that year).

    The fraction is the number of whole days since 1 January divided by the
    number of days in that calendar year, so leap years are handled and any
    time-of-day component is ignored.

    Parameters
    ----------
    da : xarray.DataArray
        Input data with datetime coordinate
    dim : str, optional
        Name of the time dimension the new coordinate is attached to
    coord : str, optional
        Name of the time coordinate to read; falls back to ``dim`` when None

    Returns
    -------
    xarray.DataArray
        Input data with added 'decimal_year' coordinate
    """

    def _first_of_january(year_values):
        # Build the 1 January timestamp for each entry of an index of years
        return pd.to_datetime(year_values.astype(str) + "-01-01")

    # Read timestamps from the explicit coordinate if given, else from the dimension
    source_name = dim if coord is None else coord
    timestamps = pd.to_datetime(da[source_name])
    years = timestamps.year

    this_january = _first_of_january(years)
    next_january = _first_of_january(years + 1)

    # Whole days since the start of the year, and the length of that year in days
    days_into_year = (timestamps - this_january).days
    days_in_year = (next_january - this_january).days

    return da.assign_coords(decimal_year=(dim, years + days_into_year / days_in_year))
def _compute_anomaly_detrended(
    da: xr.DataArray,
    std_normalise: bool = False,
    detrend_orders: Optional[List[int]] = None,
    dimensions: Optional[Dict[str, str]] = None,
    coordinates: Optional[Dict[str, str]] = None,
    force_zero_mean: bool = True,
    remove_harmonics: bool = True,
) -> xr.Dataset:
    """
    Generate normalised anomalies by removing trends, seasonal cycles, and optionally
    standardising by local temporal variability using the detrended baseline method.

    A linear model (constant, polynomial trend terms and, optionally, annual and
    semi-annual harmonics) is fitted to every time series with a pseudo-inverse
    and subtracted from the data.

    Parameters
    ----------
    da : xarray.DataArray
        Input data with time coordinate
    std_normalise : bool, default=False
        Whether to standardise anomalies by temporal variability
    detrend_orders : list, optional
        Polynomial orders for trend removal (default: [1] for linear)
    dimensions : dict, optional
        Mapping of dimensions to names in the data
    coordinates : dict, optional
        Mapping of coordinates to names in the data
    force_zero_mean : bool, default=True
        Whether to enforce zero mean in detrended data
    remove_harmonics : bool, default=True
        Whether to remove seasonal harmonics (annual and semi-annual cycles)

    Returns
    -------
    xarray.Dataset
        Dataset containing ``dat_anomaly`` and ``mask``, plus ``dat_stn`` and
        ``STD`` when ``std_normalise`` is True

    Raises
    ------
    ConfigurationError
        If ``detrend_orders`` is empty or contains an order smaller than 1
    """
    # Infer and validate dimensions and coordinates
    dimensions, coordinates = _infer_dims_coords(da, dimensions, coordinates)

    # Default detrend_orders to linear if not specified
    if detrend_orders is None:
        detrend_orders = [1]

    # Validate detrend_orders is not empty and contains valid values
    if not detrend_orders:
        raise ConfigurationError(
            "detrend_orders cannot be empty",
            details="At least one polynomial order must be specified for detrending",
            suggestions=[
                "Use detrend_orders=[1] for linear detrending",
                "Use detrend_orders=[1, 2] for linear + quadratic detrending",
                "Remove detrend_orders optional parameter to use default [1]",
            ],
        )

    # Validate all orders are positive integers
    invalid_orders = [order for order in detrend_orders if order < 1]
    if invalid_orders:
        raise ConfigurationError(
            f"Invalid polynomial orders: {invalid_orders}",
            details="Polynomial orders must be positive integers (≥ 1)",
            suggestions=[
                "Use only positive integers for polynomial orders",
                "Common values: [1] for linear, [1,2] for linear+quadratic",
                f"Remove invalid orders: {invalid_orders}",
            ],
        )

    da = da.astype(np.float32)

    # Ensure time is the first dimension for efficient processing
    if da.dims[0] != dimensions["time"]:
        da = da.transpose(dimensions["time"], ...)

    # Warn if using higher-order detrending without linear component.
    # Routed through the module logger rather than stdout so that applications
    # embedding this library control where the message goes.
    if 1 not in detrend_orders and len(detrend_orders) > 1:
        logger.warning("Higher-order detrending without linear term may be unstable")

    # Add decimal year for trend modelling
    da = add_decimal_year(da, dim=dimensions["time"], coord=coordinates["time"])
    dy = da.decimal_year.compute()

    # Build model matrix with constant term, trends, and seasonal harmonics
    model_components = [np.ones(len(dy))]  # Constant term

    # Add polynomial trend terms (time is centred on its mean)
    centered_time = da.decimal_year - np.mean(dy)
    for order in detrend_orders:
        model_components.append(centered_time**order)

    # Add annual and semi-annual cycles (harmonics) if requested
    if remove_harmonics:
        model_components.extend(
            [
                np.sin(2 * np.pi * dy),  # Annual sine
                np.cos(2 * np.pi * dy),  # Annual cosine
                np.sin(4 * np.pi * dy),  # Semi-annual sine
                np.cos(4 * np.pi * dy),  # Semi-annual cosine
            ]
        )

    # Convert to numpy array for matrix operations: shape (n_coeffs, n_time)
    model = np.array(model_components)

    # Orthogonalise model components against the constant term for numerical stability
    for i in range(1, model.shape[0]):
        model[i] = model[i] - np.mean(model[i]) * model[0]

    # Compute pseudo-inverse for model fitting
    pmodel = np.linalg.pinv(model)
    n_coeffs = len(model_components)
    coeff_index = np.arange(1, n_coeffs + 1)
    time_values = da[coordinates["time"]].values

    # Convert model matrices to xarray, chunked in time like the input data
    model_da = xr.DataArray(
        model.T,
        dims=[dimensions["time"], "coeff"],
        coords={
            dimensions["time"]: time_values,
            "coeff": coeff_index,
        },
    ).chunk({dimensions["time"]: da.chunks[0]})

    pmodel_da = xr.DataArray(
        pmodel.T,
        dims=["coeff", dimensions["time"]],
        coords={
            "coeff": coeff_index,
            dimensions["time"]: time_values,
        },
    ).chunk({dimensions["time"]: da.chunks[0]})

    # Fit model to data - use the actual dimensions of the result
    dot_result = pmodel_da.dot(da)

    # For dot product result, dimensions match input data's spatial dimensions
    spatial_dims = [dim for dim in da.dims if dim != dimensions["time"]]
    result_dims = ["coeff"] + spatial_dims

    # Build coordinates for the result
    result_coords = {"coeff": coeff_index}
    for dim in spatial_dims:
        if dim in da.coords:
            result_coords[dim] = da.coords[dim]

    model_fit_da = xr.DataArray(dot_result, dims=result_dims, coords=result_coords)

    # Remove trend and seasonal cycle
    da_detrend = (da.drop_vars({"decimal_year"}) - model_da.dot(model_fit_da).astype(np.float32)).persist()

    # Force zero mean if requested
    if force_zero_mean:
        da_detrend = da_detrend - da_detrend.mean(dim=dimensions["time"])

    # Create ocean/land mask from first time step
    # Handle both spatial (3D) and time-series (1D) data
    spatial_dims = [dim for dim in ["x", "y"] if dim in dimensions]

    if spatial_dims:
        # Spatial data - create 2D/3D mask
        chunk_dict_mask = {dimensions[dim]: -1 for dim in spatial_dims}
        mask_temp = np.isfinite(da.isel({dimensions["time"]: 0})).chunk(chunk_dict_mask)

        # Drop time-related coordinates to create spatial mask
        vars_to_drop = []
        if "decimal_year" in mask_temp.coords:
            vars_to_drop.append("decimal_year")
        if dimensions["time"] in mask_temp.coords:
            vars_to_drop.append(dimensions["time"])
        if coordinates["time"] in mask_temp.coords:
            vars_to_drop.append(coordinates["time"])
        mask = mask_temp.drop_vars(vars_to_drop) if vars_to_drop else mask_temp
    else:
        # 1D time series - create scalar mask indicating if any finite values exist
        chunk_dict_mask = {}  # Empty for 1D case
        mask = xr.DataArray(np.any(np.isfinite(da.values)), dims=[], attrs={"description": "Time series validity mask"})

    # Initialise output dataset
    data_vars = {"dat_anomaly": da_detrend, "mask": mask}

    # Ensure all original coordinates are preserved in the dataset
    # (coordinates that share a name with a data variable are left out)
    coords_to_preserve = {
        coord_name: da.coords[coord_name] for coord_name in da.coords if coord_name not in data_vars
    }

    # Standardise anomalies by temporal variability if requested
    if std_normalise:
        # Calculate day-of-year standard deviation using cohorts
        std_day = flox.xarray.xarray_reduce(
            da_detrend,
            da_detrend[coordinates["time"]].dt.dayofyear,
            dim=dimensions["time"],
            func="std",
            isbin=False,
            method="cohorts",
            dtype=np.float32,
        )

        # Calculate 30-day rolling standard deviation with annual wrapped padding
        std_day_wrap = std_day.pad(dayofyear=16, mode="wrap")
        std_rolling = np.sqrt((std_day_wrap**2).rolling(dayofyear=30, center=True).mean()).isel(
            dayofyear=slice(16, 366 + 16)
        )

        # Divide anomalies by rolling standard deviation
        # Replace any zeros or extremely small values with NaN to avoid division warnings
        std_rolling_safe = std_rolling.where(std_rolling > 1e-10, np.nan)
        da_detrend = da_detrend.assign_coords(dayofyear=da_detrend[coordinates["time"]].dt.dayofyear)
        da_stn = da_detrend.groupby(dayofyear=xr.groupers.UniqueGrouper(labels=np.arange(1, 367))) / std_rolling_safe

        # Drop dayofyear coordinate to avoid merge conflicts
        if "dayofyear" in da_stn.coords:
            da_stn = da_stn.drop_vars("dayofyear")

        # Rechunk data for efficient processing
        chunk_dict_std = chunk_dict_mask.copy()
        chunk_dict_std["dayofyear"] = -1

        da_stn = da_stn.chunk(chunk_dict_mask)
        std_rolling = std_rolling.chunk(chunk_dict_std)

        # Add standardised data to output
        data_vars["dat_stn"] = da_stn
        data_vars["STD"] = std_rolling

    # Build output dataset; decimal_year was only a helper for the fit
    return xr.Dataset(data_vars=data_vars, coords=coords_to_preserve).drop_vars("decimal_year")


def _compute_anomaly_fixed_baseline(
    da: xr.DataArray,
    dimensions: Optional[Dict[str, str]] = None,
    coordinates: Optional[Dict[str, str]] = None,
    reference_period: Optional[Tuple[int, int]] = None,
) -> xr.Dataset:
    """
    Compute anomalies using fixed baseline method with full time series climatology.

    This method computes a daily climatology using all available years in the dataset
    (or a specified reference period), then subtracts this climatology from the
    original data to obtain anomalies.

    Parameters
    ----------
    da : xarray.DataArray
        Input data with time coordinate
    dimensions : dict, optional
        Mapping of dimensions to names in the data
    coordinates : dict, optional
        Mapping of coordinates to names in the data
    reference_period : tuple of (int, int), optional
        Year range (start_year, end_year) inclusive for computing the daily climatology.
        If None (default), uses all available years. Anomalies are still computed
        for the full time series.

    Returns
    -------
    xarray.Dataset
        Dataset containing anomalies and mask

    Raises
    ------
    ConfigurationError
        If ``reference_period`` is reversed or selects no time steps
    """
    # Infer and validate dimensions and coordinates
    dimensions, coordinates = _infer_dims_coords(da, dimensions, coordinates)

    # Select data for climatology computation (optionally restricted to reference period)
    if reference_period is not None:
        start_year, end_year = reference_period
        if start_year > end_year:
            raise ConfigurationError(
                f"Invalid reference_period: start year ({start_year}) must be <= end year ({end_year})",
                details="The reference_period tuple must be (start_year, end_year) with start_year <= end_year",
                suggestions=[f"Swap the order: use reference_period=({end_year}, {start_year})"],
            )

        years = da[coordinates["time"]].dt.year
        year_mask = (years >= start_year) & (years <= end_year)
        da_for_clim = da.isel({dimensions["time"]: year_mask})

        if da_for_clim.sizes[dimensions["time"]] == 0:
            data_min_year = int(years.min().values)
            data_max_year = int(years.max().values)
            raise ConfigurationError(
                f"No data found in reference_period ({start_year}, {end_year})",
                details=f"Dataset spans {data_min_year}-{data_max_year} but no timesteps fall within the specified period",
                suggestions=[
                    f"Adjust reference_period to overlap with data range ({data_min_year}-{data_max_year})",
                    "Set reference_period=None to use the full time series",
                ],
            )

        logger.debug(
            f"Using reference_period ({start_year}-{end_year}): "
            f"{da_for_clim.sizes[dimensions['time']]} of {da.sizes[dimensions['time']]} timesteps"
        )
    else:
        da_for_clim = da

    # Compute daily climatology using flox for efficiency
    logger.debug("Computing daily climatology across %s", "reference period" if reference_period else "all years")
    daily_climatology = flox.xarray.xarray_reduce(
        da_for_clim,
        da_for_clim[coordinates["time"]].dt.dayofyear,
        dim=dimensions["time"],
        func="nanmean",
        isbin=False,
        method="cohorts",
        dtype=np.float32,
    ).persist()

    # Compute anomalies by subtracting daily climatology from original data
    logger.debug("Computing anomalies by subtracting daily climatology")
    da = da.assign_coords(dayofyear=da[coordinates["time"]].dt.dayofyear)
    anomalies = da.groupby(dayofyear=xr.groupers.UniqueGrouper(labels=np.arange(1, 367))) - daily_climatology
    anomalies = anomalies.astype(np.float32)

    # Drop dayofyear coordinate to avoid merge conflicts
    if "dayofyear" in anomalies.coords:
        anomalies = anomalies.drop_vars("dayofyear")

    # Create ocean/land mask from first time step
    # Handle both spatial (3D) and time-series (1D) data
    spatial_dims = [dim for dim in ["x", "y"] if dim in dimensions]

    if spatial_dims:
        # Spatial data - create 2D/3D mask
        # NOTE(review): the mask is taken from `da` after `dayofyear` was attached,
        # so a scalar `dayofyear` coordinate may remain on it - confirm this is intended.
        chunk_dict_mask = {dimensions[dim]: -1 for dim in spatial_dims}
        mask = np.isfinite(da.isel({dimensions["time"]: 0})).drop_vars({coordinates["time"]}).chunk(chunk_dict_mask)
    else:
        # 1D time series - create scalar mask indicating if any finite values exist
        mask = xr.DataArray(np.any(np.isfinite(da.values)), dims=[], attrs={"description": "Time series validity mask"})

    # Build output dataset
    return xr.Dataset({"dat_anomaly": anomalies, "mask": mask})


def _compute_anomaly_detrend_fixed_baseline(
    da: xr.DataArray,
    detrend_orders: Optional[List[int]] = None,
    dimensions: Optional[Dict[str, str]] = None,
    coordinates: Optional[Dict[str, str]] = None,
    force_zero_mean: bool = True,
    reference_period: Optional[Tuple[int, int]] = None,
) -> xr.Dataset:
    """
    Compute anomalies using fixed detrended baseline method.

    This method first removes polynomial trends (without harmonics) from the data,
    then removes a daily climatology from the detrended signal. The trend removal
    always uses the full time series; only the climatology step respects
    reference_period.

    Parameters
    ----------
    da : xarray.DataArray
        Input data with time coordinate
    detrend_orders : list, optional
        Polynomial orders for trend removal (default: [1] for linear)
    dimensions : dict, optional
        Mapping of dimensions to names in the data
    coordinates : dict, optional
        Mapping of coordinates to names in the data
    force_zero_mean : bool, default=True
        Whether to enforce zero mean in detrended data
    reference_period : tuple of (int, int), optional
        Year range (start_year, end_year) inclusive for computing the daily climatology.
        If None (default), uses all available years. Only affects the climatology
        step, not the polynomial detrending.

    Returns
    -------
    xarray.Dataset
        Dataset containing anomalies and mask
    """
    # Infer and validate dimensions and coordinates
    dimensions, coordinates = _infer_dims_coords(da, dimensions, coordinates)

    logger.debug(f"Removing polynomial trends of orders: {detrend_orders}")

    # Step 1: Remove polynomial trends (without harmonics) using _compute_anomaly_detrended
    detrended_result = _compute_anomaly_detrended(
        da=da,
        std_normalise=False,
        detrend_orders=detrend_orders,
        dimensions=dimensions,
        coordinates=coordinates,
        force_zero_mean=force_zero_mean,
        remove_harmonics=False,  # Only remove trends, not harmonics
    )["dat_anomaly"].persist()

    # Step 2: Compute daily climatology and anomalies using _compute_anomaly_fixed_baseline
    logger.debug("Computing daily climatology and anomalies from detrended data")
    return _compute_anomaly_fixed_baseline(
        da=detrended_result,
        dimensions=dimensions,
        coordinates=coordinates,
        reference_period=reference_period,
    )


def _rolling_histogram_quantile(
    hist_chunk: NDArray[np.int32],
    window_days_hobday: int,
    q: float,
    bin_centers: NDArray[np.float64],
) -> NDArray[np.float32]:
    """
    Efficiently compute quantile thresholds from histogram data using vectorised numpy operations.

    The histogram is summed over a cyclic day-of-year window, and the quantile
    is located by linear interpolation between bin centres in cumulative-count
    space.

    Parameters
    ----------
    hist_chunk : numpy.ndarray
        Histogram data with shape (dayofyear, da_bin)
    window_days_hobday : int
        Rolling window size for day-of-year smoothing (must be >= 1). Odd windows
        are centred; even windows take one more day before than after.
    q : float
        Quantile to compute (0-1)
    bin_centers : numpy.ndarray
        Bin centre values for interpolation

    Returns
    -------
    numpy.ndarray
        Quantile thresholds with shape (dayofyear,), float32. Days whose window
        holds no counts are NaN.

    Raises
    ------
    ValueError
        If ``window_days_hobday`` is smaller than 1
    """
    if window_days_hobday < 1:
        raise ValueError(f"window_days_hobday must be >= 1, got {window_days_hobday}")

    n_doy, n_bins = hist_chunk.shape
    eps = 1e-10

    # Pad histogram cyclically along day-of-year. The total padding is exactly
    # window - 1 so the rolling sum always returns n_doy rows; slicing with
    # `[-pad:]` would return the whole array when pad == 0 (window of 1).
    pad_before = window_days_hobday // 2
    pad_after = window_days_hobday - 1 - pad_before
    hist_pad = np.pad(hist_chunk, ((pad_before, pad_after), (0, 0)), mode="wrap")

    # Apply rolling sum using stride tricks
    windowed_view = sliding_window_view(hist_pad, window_days_hobday, axis=0)
    hist_windowed = np.sum(windowed_view, axis=-1)

    # Count-based interpolation (rather than interpolating CDF in probability space)
    # Calculate cumulative counts (not normalized CDF)
    cumsum = np.cumsum(hist_windowed, axis=1, dtype=np.int32)
    total_counts = cumsum[:, -1]  # Total count for each day
    has_data = total_counts > 0

    # Calculate the exact position where the quantile should be
    # For n samples, the q-th quantile is at position q*(n-1)
    # It is q*n here since we're working with cumulative counts
    quantile_position = q * total_counts

    # Index of the first bin whose cumulative count exceeds the target position.
    # For a non-decreasing row this equals searchsorted(..., side="right"),
    # i.e. the number of entries <= target, evaluated for all days at once.
    idx_upper = (cumsum <= quantile_position[:, np.newaxis]).sum(axis=1)
    idx_upper = np.where(has_data, idx_upper, 0)

    # Clip to valid range
    idx_upper = np.clip(idx_upper, 0, n_bins - 1)
    idx_lower = np.maximum(0, idx_upper - 1)

    # Extract values for vectorised interpolation
    doy_indices = np.arange(n_doy)

    # Get cumulative counts at the boundaries
    count_lower = cumsum[doy_indices, idx_lower]
    count_upper = cumsum[doy_indices, idx_upper]

    # Bin centers for interpolation
    bin_lower = bin_centers[idx_lower]
    bin_upper = bin_centers[idx_upper]

    # Compute interpolation fraction based on counts
    count_diff = count_upper - count_lower
    safe_diff = np.where(count_diff > eps, count_diff, 1.0)
    frac = np.where(count_diff > eps, (quantile_position - count_lower) / safe_diff, 0.5)  # If no difference, use midpoint

    # Linear interpolation between bin centers
    threshold = bin_lower + frac * (bin_upper - bin_lower)

    # Handle edge cases
    # If total_counts is 0, return NaN
    threshold = np.where(has_data, threshold, np.nan)

    # If the quantile falls in the first bin, use the first bin center
    threshold = np.where((idx_upper == 0) & has_data, bin_centers[0], threshold)

    return threshold.astype(np.float32)


def _compute_histogram_quantile_2d(
    da: xr.DataArray,
    q: float,
    window_days_hobday: int = 11,
    window_spatial_hobday: Optional[int] = None,
    bin_edges: Optional[NDArray[np.float64]] = None,
    dimensions: Optional[Dict[str, str]] = None,
    precision: float = 0.01,
    max_anomaly: float = 5.0,
    use_temp_checkpoints: bool = False,
) -> xr.DataArray:
    """
    Efficiently compute quantiles using binned histograms optimised for extreme values.
    Uses fine-grained bins for positive anomalies and a single bin for negative values.

    Parameters
    ----------
    da : xarray.DataArray
        Input data array (must carry a ``dayofyear`` coordinate)
    q : float
        Quantile to compute (0-1)
    window_days_hobday : int, default=11
        Rolling window size for day-of-year quantiles
    window_spatial_hobday : int, default=None
        Spatial window size for day-of-year quantiles
    bin_edges : numpy.ndarray, optional
        Custom bin edges for histogram computation
    dimensions : dict, optional
        Dimension mapping dictionary. Although the default is None, the body
        indexes it directly, so a mapping with at least "time" and "x" is required.
    precision : float, default=0.01
        Precision for positive anomaly bins
    max_anomaly : float, default=5.0
        Maximum anomaly value for binning
    use_temp_checkpoints : bool, default=False
        Whether to checkpoint intermediate results via ``checkpoint_to_zarr``

    Returns
    -------
    xarray.DataArray
        Computed quantile value for each spatial location
    """
    if bin_edges is None:
        # Create optimised asymmetric bins
        bin_edges = np.concatenate(
            [[-np.inf], np.arange(-precision, max_anomaly + precision, precision, dtype=np.float32)], dtype=np.float32
        )

    bin_centers_array = (bin_edges[1:] + bin_edges[:-1]) / 2
    bin_centers_array[0] = 0.0  # Centre of the open-ended first bin is set to 0

    bin_centers = xr.DataArray(
        bin_centers_array.astype(np.float32),
        dims=["da_bin"],
        coords={"da_bin": np.arange(len(bin_centers_array), dtype=np.uint16)},
        name="bin_centers",
    )

    chunk_dict = {dimensions["time"]: -1}
    chunk_dict[dimensions["x"]] = 16
    if "y" in dimensions:
        chunk_dict[dimensions["y"]] = 16

    da_bin = (
        xr.DataArray(
            np.digitize(da.data, bin_edges) - 1,  # -1 so first bin is 0
            dims=da.dims,
            coords=da.coords,
            name="da_bin",
        )
        .chunk(chunk_dict)
        .astype(np.uint16)
    )

    if use_temp_checkpoints:
        logger.debug("Checkpointing binned data to break graph dependencies")
        da_bin = checkpoint_to_zarr(da_bin, name="da_bin", timedim=dimensions["time"]).chunk(chunk_dict)

    # Construct 2D histogram using flox (in doy & anomaly)
    hist_raw = flox.xarray.xarray_reduce(
        da_bin,
        da_bin.dayofyear,
        da_bin,
        dim=[dimensions["time"]],
        func="count",
        expected_groups=(np.arange(1, 367, dtype=np.uint16), np.arange(len(bin_edges) - 1, dtype=np.uint16)),
        isbin=(False, False),
        dtype=np.uint16,
        fill_value=0,
    )
    hist_raw.name = None

    # Apply spatial-kernel smoothing to the histogram
    if window_spatial_hobday is not None and window_spatial_hobday > 1:
        pad_size = window_spatial_hobday // 2
        lon_dim, lat_dim = dimensions.get("x"), dimensions.get("y")

        hist_rolled = hist_raw

        # Periodic padding in longitude, rolling sum, then trim back to the original extent
        if lon_dim in hist_raw.dims:
            hist_rolled = hist_rolled.pad({lon_dim: pad_size}, mode="wrap")
            hist_rolled = hist_rolled.rolling({lon_dim: window_spatial_hobday}, center=True, min_periods=1).sum()
            hist_rolled = hist_rolled.isel({lon_dim: slice(pad_size, pad_size + hist_raw.sizes[lon_dim])})

        # Standard (non-periodic) rolling sum in latitude
        if lat_dim in hist_raw.dims:
            hist_rolled = hist_rolled.rolling({lat_dim: window_spatial_hobday}, center=True, min_periods=1).sum()

        hist_raw = hist_rolled

    def _compute_quantile_with_params(hist_chunk, bin_centers_chunk):
        # Bind the window size and quantile so apply_ufunc only passes arrays
        return _rolling_histogram_quantile(hist_chunk, window_days_hobday, q, bin_centers_chunk)

    # Rechunk histogram so core dimensions are unchunked for apply_ufunc
    # Create chunk dict for hist_raw that preserves spatial chunks but drops time
    hist_chunk_dict = {dimensions["x"]: chunk_dict.get(dimensions["x"], 16), "dayofyear": -1, "da_bin": -1}
    if "y" in dimensions:
        hist_chunk_dict[dimensions["y"]] = chunk_dict.get(dimensions["y"], 16)
    hist_raw = hist_raw.chunk(hist_chunk_dict)

    # Apply the optimised computation using apply_ufunc
    threshold = xr.apply_ufunc(
        _compute_quantile_with_params,
        hist_raw,
        bin_centers,
        input_core_dims=[["dayofyear", "da_bin"], ["da_bin"]],
        output_core_dims=[["dayofyear"]],
        dask="parallelized",
        vectorize=True,
        output_dtypes=[np.float32],
        dask_gufunc_kwargs={"output_sizes": {"dayofyear": 366}},
        keep_attrs=True,
    )

    if use_temp_checkpoints:
        logger.debug("Checkpointing threshold to break graph dependencies")
        threshold = checkpoint_to_zarr(threshold, name="threshold")

    # Drop time coordinate to avoid conflicts when comparing with data grouped by dayofyear
    if dimensions["time"] in threshold.coords:
        threshold = threshold.drop_vars(dimensions["time"])

    # Set threshold to NaN for spatial points that are NaN at the first time step
    nan_mask = da.isel({dimensions["time"]: 0}).isnull().compute()
    threshold = threshold.where(~nan_mask).persist()

    # Validate threshold values against bounds
    upper_bound = bin_edges[-2]
    lower_bound = bin_edges[3]  # We want this to be positive so that constant=0 anomalies will not be "extreme"

    # Check if any values are too high (comparisons with NaN are False, so NaNs are ignored)
    too_high = threshold > upper_bound
    if too_high.any():
        warnings.warn(
            f"Quantile values exceed expected range: max={threshold.max().compute():.4f} > {upper_bound:.4f}. "
            f"Consider increasing max_anomaly parameter (currently {max_anomaly:.2f}) or using a lower percentile threshold.",
            UserWarning,
            stacklevel=2,
        )

    # Check if any values are too low (comparisons with NaN are False, so NaNs are ignored)
    too_low = threshold < lower_bound
    if too_low.any():
        warnings.warn(
            f"Quantile values below expected range in some locations: min={threshold.min().compute():.4f} < {lower_bound:.4f}. "
            "This is likely due to a constant anomaly in certain (e.g. due to sea ice). "
            "Double check the computed threshold values are correct.",
            UserWarning,
            stacklevel=2,
        )
        # Set too low values to lower bound -- This is to ensure that constant=0 anomalies will not be "extreme"
        threshold = threshold.where(~too_low, lower_bound)

    return threshold


def _compute_histogram_quantile_1d(
    da: xr.DataArray,
    q: float,
    dim: str = "time",
    bin_edges: Optional[NDArray[np.float64]] = None,
    precision: float = 0.01,
    max_anomaly: float = 5.0,
) -> xr.DataArray:
    """
    Efficiently compute quantiles using binned histograms optimised for extreme values.
    Uses fine-grained bins for positive anomalies and a single bin for negative values.
    Improved robust interpolation handles empty bins in the tails.

    Parameters
    ----------
    da : xarray.DataArray
        Input data array (its ``name`` determines the histogram bin dimension,
        ``f"{da.name}_bin"``)
    q : float
        Quantile to compute (0-1)
    dim : str, optional
        Dimension along which to compute quantile
    bin_edges : numpy.ndarray, optional
        Custom bin edges for histogram computation
    precision : float, default=0.01
        Precision for positive anomaly bins
    max_anomaly : float, default=5.0
        Maximum anomaly value for binning

    Returns
    -------
    xarray.DataArray
        Computed quantile value for each spatial location
    """
    if bin_edges is None:
        # Create optimised asymmetric bins
        bin_edges = np.concatenate([[-np.inf], np.arange(-precision, max_anomaly + precision, precision)])

    # Name of the bin dimension produced by the histogram
    bin_dim = f"{da.name}_bin"

    # Compute histogram
    hist = histogram(da, bins=[bin_edges], dim=[dim]).persist()

    # Convert to PDF and CDF (small offset avoids division by zero for empty histograms)
    hist_sum = hist.sum(dim=bin_dim) + 1e-10
    pdf = hist / hist_sum
    cdf = pdf.cumsum(dim=bin_dim).persist()

    # Get bin centers
    bin_centers = (bin_edges[1:] + bin_edges[:-1]) / 2
    bin_centers[0] = 0.0  # Set negative bin centre to 0
    n_centers = len(bin_centers)

    eps = 1e-10

    # Find bins for interpolation
    # Find first bin where CDF >= (q - eps) - this becomes upper bound
    cdf_above_q = cdf >= (q - eps)
    idx_upper = cdf_above_q.argmax(dim=bin_dim)

    # Get CDF value one point to the left of idx_upper
    idx_before_upper = xr.where(idx_upper - 1 > 0, idx_upper - 1, 0)

    # Extract the target CDF value (avoiding negative indexing issues)
    idx_before_upper_computed = idx_before_upper.compute()
    cdf_target = cdf.isel({bin_dim: idx_before_upper_computed})

    # Find idx_lower: first bin where CDF > cdf_target
    cdf_above_target = cdf > cdf_target
    idx_lower = cdf_above_target.argmax(dim=bin_dim)

    # Ensure bounds are valid
    idx_lower = xr.where(idx_lower < 0, 0, xr.where(idx_lower > n_centers - 2, n_centers - 2, idx_lower))
    idx_upper = xr.where(idx_upper < 1, 1, xr.where(idx_upper > n_centers - 1, n_centers - 1, idx_upper))

    # Extract CDF and bin values for interpolation
    idx_lower_computed = idx_lower.compute()
    idx_upper_computed = idx_upper.compute()

    cdf_lower = cdf.isel({bin_dim: idx_lower_computed})
    cdf_upper = cdf.isel({bin_dim: idx_upper_computed})
    bin_lower = bin_centers[idx_lower_computed]
    bin_upper = bin_centers[idx_upper_computed]

    # Robust interpolation with proper handling of degenerate cases
    denom = cdf_upper - cdf_lower

    # Handle exact matches and zero denominators
    exact_match = (xr.ufuncs.fabs(cdf_lower - q) < eps).persist()
    zero_denom = (xr.ufuncs.fabs(denom) <= eps).persist()

    # Standard interpolation
    frac = (q - cdf_lower) / xr.where(xr.ufuncs.fabs(denom) > eps, denom, 1.0)
    threshold = bin_lower + frac * (bin_upper - bin_lower)

    # For exact matches, use the lower bin center
    threshold = xr.where(exact_match, bin_lower, threshold)

    # For zero denominator without exact match, use bin midpoint
    no_exact_match = zero_denom & ~exact_match
    threshold = xr.where(no_exact_match, (bin_lower + bin_upper) / 2, threshold)

    # Set threshold to NaN for spatial points that contain NaN values
    nan_mask = da.isnull().any(dim=dim)
    threshold = threshold.where(~nan_mask).drop_vars(bin_dim).persist()

    # Validate threshold against bounds
    upper_bound = bin_edges[-2]
    lower_bound = bin_edges[3]  # We want this to be positive so that constant=0 anomalies will not be "extreme"

    # Check if any values are too high (ignore NaN values)
    too_high = (threshold > upper_bound) & threshold.notnull()
    if too_high.any():
        warnings.warn(
            f"Quantile values exceed expected range: max={threshold.max().compute():.4f} > {upper_bound:.4f}. "
            f"Consider increasing max_anomaly parameter (currently {max_anomaly:.2f}) or using a lower percentile threshold.",
            UserWarning,
            stacklevel=2,
        )

    # Check if any values are too low (ignore NaN values)
    too_low = (threshold < lower_bound) & threshold.notnull()
    if too_low.any():
        warnings.warn(
            f"Quantile values below expected range in some locations: min={threshold.min().compute():.4f} < {lower_bound:.4f}. "
            "This is likely due to a constant anomaly in certain (e.g. due to sea ice). "
            "Double check the computed threshold values are correct.",
            UserWarning,
            stacklevel=2,
        )
        # Set too low values to lower bound -- This is to ensure that constant=0 anomalies will not be "extreme"
        threshold = threshold.where(~too_low, lower_bound).persist()

    return threshold


# ======================================
# Constant (in time) Extreme Definition
# ======================================


def _identify_extremes_constant(
    da: xr.DataArray,
    threshold_percentile: float = 95,
    method_percentile: Literal["exact", "approximate"] = "approximate",
    dimensions: Optional[Dict[str, str]] = None,
    precision: float = 0.01,
    max_anomaly: float = 5.0,
) -> Tuple[xr.DataArray, xr.DataArray]:
    """
    Identify extreme events exceeding a constant (in time) percentile threshold.
    i.e. There is 1 threshold for each spatial point, computed across all time.

    Parameters
    ----------
    da : xarray.DataArray
        Input anomaly data
    threshold_percentile : float, default=95
        Percentile (0-100) defining the threshold
    method_percentile : {"exact", "approximate"}, default="approximate"
        "exact" uses ``DataArray.quantile`` (memory-intensive); any other value
        uses the histogram-based approximation
    dimensions : dict, optional
        Dimension mapping dictionary. Although the default is None, the body
        indexes it directly, so a mapping with at least "time" is required.
    precision : float, default=0.01
        Bin precision for the histogram-based method
    max_anomaly : float, default=5.0
        Maximum anomaly value for binning in the histogram-based method

    Returns
    -------
    tuple of (xarray.DataArray, xarray.DataArray)
        The boolean mask of extreme events and the thresholds used.
    """
    if method_percentile == "exact":
        # Compute exact percentile (memory-intensive)
        # Determine appropriate chunk size based on data dimensions
        if "y" in dimensions:
            rechunk_size = "auto"
        else:
            # Multiple of 100 close to 1.5 * sqrt(ncells). Floored at 100 because the
            # truncation yields 0 (an invalid chunk size) for small unstructured grids.
            rechunk_size = max(100, 100 * int(np.sqrt(da[dimensions["x"]].size) * 1.5 / 100))
        # N.B.: If this rechunk_size is too small, then dask will be overwhelmed by the number of tasks
        chunk_dict = {dimensions[dim]: rechunk_size for dim in ["x", "y"] if dim in dimensions}
        chunk_dict[dimensions["time"]] = -1
        da_rechunk = da.chunk(chunk_dict)

        # Calculate threshold
        threshold = da_rechunk.quantile(threshold_percentile / 100.0, dim=dimensions["time"])
    else:
        # Use an efficient histogram-based method with specified accuracy
        threshold = _compute_histogram_quantile_1d(
            da, threshold_percentile / 100.0, dim=dimensions["time"], precision=precision, max_anomaly=max_anomaly
        )

    # Clean up coordinates if needed
    if "quantile" in threshold.coords:
        threshold = threshold.drop_vars("quantile")

    # Ensure spatial dimensions are fully loaded for efficient comparison
    spatial_chunks = {dimensions[dim]: -1 for dim in ["x", "y"] if dim in dimensions}
    threshold = threshold.chunk(spatial_chunks).persist()

    # Create boolean mask for values exceeding threshold
    extremes = da >= threshold

    # Clean up coordinates if needed
    if "quantile" in extremes.coords:
        extremes = extremes.drop_vars("quantile")

    # Restore the chunking of the input data
    extremes = extremes.astype(bool).chunk(dict(zip(da.dims, da.chunks))).persist()

    return extremes, threshold