# Source code for marEx.plotX.base

"""Base classes and utilities for the plotX visualisation system.

This module provides the core infrastructure for plotting marine extreme event data,
supporting both structured and unstructured grids with comprehensive configuration
and animation capabilities.
"""

import gc
import shutil
import subprocess
import tempfile
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import dask
import numpy as np
import xarray as xr
from numpy.typing import NDArray

from ..exceptions import DependencyError, VisualisationError
from ..logging_config import configure_logging, get_logger

# Get module logger, named after this module's import path via ``__name__``
logger = get_logger(__name__)

# Handle optional dependencies for plotting
try:
    import cartopy.crs as ccrs
    import cartopy.feature as cfeature

    HAS_CARTOPY = True
except ImportError:
    HAS_CARTOPY = False
    ccrs = None
    cfeature = None

try:
    import matplotlib.pyplot as plt
    from matplotlib.axes import Axes
    from matplotlib.cm import ScalarMappable
    from matplotlib.colorbar import Colorbar
    from matplotlib.colors import BoundaryNorm, ListedColormap, Normalize
    from matplotlib.figure import Figure

    HAS_MATPLOTLIB = True
except ImportError:
    HAS_MATPLOTLIB = False
    plt = None
    Axes = None
    ScalarMappable = None
    Colorbar = None
    BoundaryNorm = None
    ListedColormap = None
    Normalize = None
    Figure = None

try:
    from PIL import Image

    HAS_PIL = True
except ImportError:
    HAS_PIL = False
    Image = None


def _check_plotting_dependencies() -> None:
    """Ensure matplotlib and cartopy are available for plotting.

    Delegates to the package's ``require_dependencies`` helper (imported at
    call time), passing a feature label used in its error reporting.
    """
    from .._dependencies import require_dependencies as _require

    needed_packages = ["matplotlib", "cartopy"]
    _require(needed_packages, "Plotting functionality")


@dataclass
class PlotConfig:
    """Configuration class for plot parameters

    Attributes:
        title: Plot title
        var_units: Variable units for colorbar label
        issym: Whether data is symmetric (centers colormap at 0)
        cmap: Colormap name or ListedColormap object
        cperc: Percentile range for automatic color limits [min, max] (default [4, 96])
        clim: Manual color limits (vmin, vmax)
        show_colorbar: Whether to display colorbar (forced to False when plot_IDs is True)
        grid_lines: Whether to display grid lines
        grid_labels: Whether to display grid labels
        dimensions: Mapping of conceptual to actual dimension names
            (default {"time": "time", "y": "lat", "x": "lon"})
        coordinates: Mapping of conceptual to actual coordinate names
            (default {"time": "time", "y": "lat", "x": "lon"})
        norm: Custom normalization (BoundaryNorm or Normalize)
        plot_IDs: Whether to plot object IDs with random colors
        extend: Colorbar extension ('neither', 'both', 'min', 'max')
        verbose: Enable verbose logging
        quiet: Enable quiet logging
        projection: Cartopy projection for map plots (default Robinson when
            cartopy is installed, otherwise left as None)
        framerate: Frames per second for animations (default 10)
    """

    title: Optional[str] = None
    var_units: str = ""
    issym: bool = False
    cmap: Optional[Union[str, ListedColormap]] = None
    cperc: Optional[List[int]] = None
    clim: Optional[Tuple[float, float]] = None
    show_colorbar: bool = True
    grid_lines: bool = True
    grid_labels: bool = False
    dimensions: Optional[Dict[str, str]] = None
    coordinates: Optional[Dict[str, str]] = None
    norm: Optional[Union[BoundaryNorm, Normalize]] = None
    plot_IDs: bool = False
    extend: str = "both"
    verbose: Optional[bool] = None
    quiet: Optional[bool] = None
    projection: Optional[Any] = None
    framerate: int = 10

    def __post_init__(self) -> None:
        """Initialise default values and configure logging."""
        # None sentinels give every instance its own fresh list/dicts instead of
        # sharing one mutable default.
        if self.cperc is None:
            self.cperc = [4, 96]
        if self.dimensions is None:
            self.dimensions = {"time": "time", "y": "lat", "x": "lon"}
        if self.coordinates is None:
            self.coordinates = {"time": "time", "y": "lat", "x": "lon"}
        # ID plots never show a colorbar.
        if self.plot_IDs:
            self.show_colorbar = False
        # Only build the default projection when cartopy imported successfully;
        # otherwise keep None so the missing dependency is reported by the
        # plotter's dependency check rather than an AttributeError on ``None``.
        if self.projection is None and ccrs is not None:
            self.projection = ccrs.Robinson()
        # Configure logging if verbose/quiet parameters are provided
        if self.verbose is not None or self.quiet is not None:
            configure_logging(verbose=self.verbose, quiet=self.quiet)
def _raise_if_names_missing(
    present: Any,
    mapping: Dict[str, str],
    singular: str,
    plural: str,
) -> None:
    """Raise VisualisationError if any non-time name in ``mapping`` is absent from ``present``.

    Shared implementation for the dimension and coordinate validators below.

    Args:
        present: Container of names available in the data (supports ``in`` and iteration).
        mapping: Mapping of conceptual name -> actual name expected in the data.
        singular: Singular noun used in messages (e.g. ``"dimension"``).
        plural: Plural noun used in messages and context keys (e.g. ``"dimensions"``).

    Raises:
        VisualisationError: Listing every missing required name.
    """
    # The "time" entry is optional for plotting - only spatial names are required.
    missing = [
        f"'{actual}' (for {concept})"
        for concept, actual in mapping.items()
        if concept != "time" and actual not in present
    ]
    if not missing:
        return

    available = list(present)
    raise VisualisationError(
        f"Missing required {plural}: {', '.join(missing)}",
        details=f"Dataset has {plural}: {available}",
        suggestions=[
            f"Check {singular} names in your data",
            f"Update the '{plural}' parameter to match your data structure",
            f"Available {plural}: {available}",
        ],
        context={
            f"missing_{plural}": missing,
            f"available_{plural}": available,
            f"provided_{plural}": mapping,
        },
    )


def _validate_dimensions_exist(da: xr.DataArray, dimensions: Dict[str, str]) -> None:
    """Validate that required dimensions exist in the dataset. Time dimension is optional.

    Raises:
        VisualisationError: If any non-time dimension in ``dimensions`` is not in ``da.dims``.
    """
    _raise_if_names_missing(da.dims, dimensions, "dimension", "dimensions")


def _validate_coordinates_exist(da: xr.DataArray, coordinates: Dict[str, str]) -> None:
    """Validate that required coordinates exist in the dataset. Time coordinate is optional.

    Raises:
        VisualisationError: If any non-time coordinate in ``coordinates`` is not in ``da.coords``.
    """
    _raise_if_names_missing(da.coords, coordinates, "coordinate", "coordinates")
class PlotterBase:
    """Base class for all plotters providing common functionality.

    This class provides the core infrastructure for plotting marine extreme event data,
    including parameter setup, map features, colorbars, and animation capabilities.
    Subclasses supply the actual drawing by implementing :meth:`plot`.
    """

    def __init__(
        self,
        xarray_obj: xr.DataArray,
        dimensions: Optional[Dict[str, str]] = None,
        coordinates: Optional[Dict[str, str]] = None,
    ) -> None:
        """Initialise the plotter with data and coordinate mappings.

        Args:
            xarray_obj: The data to plot
            dimensions: Mapping of conceptual to actual dimension names
                (default ``{"time": "time", "y": "lat", "x": "lon"}``)
            coordinates: Mapping of conceptual to actual coordinate names
                (same default as ``dimensions``)

        Raises:
            VisualisationError: If a required (non-time) dimension or coordinate
                is missing from ``xarray_obj``.
        """
        _check_plotting_dependencies()
        self.da = xarray_obj

        # Set default dimensions and coordinates if not provided
        if dimensions is None:
            dimensions = {"time": "time", "y": "lat", "x": "lon"}
        if coordinates is None:
            coordinates = {"time": "time", "y": "lat", "x": "lon"}
        self.dimensions = dimensions
        self.coordinates = coordinates

        # Validate dimensions and coordinates exist in the data (time is optional)
        _validate_dimensions_exist(self.da, self.dimensions)
        _validate_coordinates_exist(self.da, self.coordinates)

        # Cache common features so every plot reuses the same 50m feature objects
        self._land = cfeature.LAND.with_scale("50m")
        self._coastlines = cfeature.COASTLINE.with_scale("50m")

    def _setup_common_params(
        self, config: PlotConfig
    ) -> Tuple[
        Union[str, ListedColormap],
        Optional[Union[BoundaryNorm, Normalize]],
        Optional[Tuple[float, float]],
        str,
        str,
    ]:
        """Centralise common parameter setup.

        For ID plots (``config.plot_IDs``) a categorical ID colormap is built and
        ``self.da`` is replaced by a copy where values <= 0 become NaN, so the
        background is not drawn.

        Returns:
            ``(cmap, norm, clim, var_units, extend)``
        """
        self.setup_plot_params()

        if config.plot_IDs:
            cmap, norm, var_units = self.setup_id_plot_params(config.cmap)
            self.da = self.da.where(self.da > 0)  # Fill value to NaN (get rid of 0s)
            return cmap, norm, None, var_units, "neither"

        if config.cmap is None:
            cmap = "RdBu_r" if config.issym else "viridis"
        else:
            cmap = config.cmap
        norm = config.norm
        clim = config.clim
        if clim is None and norm is None:
            clim = self.clim_robust(self._sample_for_clim().values, config.issym, config.cperc)
        return cmap, norm, clim, config.var_units, config.extend

    def _sample_for_clim(self) -> xr.DataArray:
        """Return every 10th time step (when a time dimension exists) for colour-limit estimation.

        Sampling avoids loading the entire time series into memory.
        """
        time_dim = self.dimensions.get("time", "time")
        if time_dim in self.da.dims:
            return self.da.isel({time_dim: slice(None, None, 10)})
        return self.da

    def _setup_axes(self, ax: Optional[Axes] = None, projection: Optional[Any] = None) -> Tuple[Figure, Axes]:
        """Create or use existing axes with projection.

        Args:
            ax: Existing axes to reuse; if None a new 7x5 inch figure is created.
            projection: Projection for new axes (default Robinson).

        Returns:
            ``(fig, ax)``
        """
        if ax is not None:
            return ax.get_figure(), ax
        # Use provided projection or default to Robinson
        proj = projection if projection is not None else ccrs.Robinson()
        fig = plt.figure(figsize=(7, 5))
        ax = fig.add_subplot(projection=proj)
        return fig, ax

    def _add_map_features(self, ax: Axes, grid_lines: bool = True, grid_labels: bool = True) -> None:
        """Add land, coastlines and (optionally) dashed grid lines to ``ax``."""
        ax.add_feature(self._land, facecolor="darkgrey", zorder=2)
        ax.add_feature(self._coastlines, linewidth=0.5, zorder=3)
        if grid_lines:
            ax.gridlines(
                crs=ccrs.PlateCarree(),
                draw_labels=grid_labels,
                linewidth=1,
                color="gray",
                alpha=0.5,
                linestyle="--",
                zorder=4,
            )

    def _setup_colorbar(
        self,
        fig: Figure,
        im: Union[ScalarMappable, Any],
        show_colorbar: bool,
        var_units: str,
        extend: str = "both",
        position: Optional[List[float]] = None,
    ) -> Optional[Colorbar]:
        """Set up colorbar with common parameters.

        Args:
            fig: Figure that receives the colorbar.
            im: Mappable the colorbar describes.
            show_colorbar: If False, nothing is drawn and None is returned.
            var_units: Label for the colorbar (skipped when empty).
            extend: Colorbar extension mode.
            position: ``[left, bottom, width, height]`` in figure coordinates for a
                dedicated colorbar axes (column plots); None attaches to the plot axes.

        Returns:
            The created colorbar, or None when ``show_colorbar`` is False.
        """
        if not show_colorbar:
            return None

        if position is not None:
            # For column plots
            cb = fig.colorbar(im, cax=fig.add_axes(position), extend=extend)
        else:
            # For single plots: attach to the axes that holds the mappable, not
            # whichever axes pyplot considers current (they differ when the caller
            # passed their own ``ax``). Plain ScalarMappables have no axes.
            parent_ax = getattr(im, "axes", None)
            if parent_ax is None:
                parent_ax = plt.gca()
            cb = fig.colorbar(im, shrink=0.6, ax=parent_ax, extend=extend)

        if var_units:
            cb.ax.set_ylabel(var_units, fontsize=10)
        cb.ax.tick_params(labelsize=10)
        return cb

    def _get_title(self, time_index: int, col_name: str, dimensions: Optional[Dict[str, str]] = None) -> str:
        """Generate a panel title: a ``YYYY-MM-DD`` date for the time dimension, else ``name=value``."""
        if dimensions is None:
            dimensions = self.dimensions

        if col_name == dimensions["time"]:
            time_coord = self.coordinates.get("time", "time")
            return f"{self.da[time_coord].isel({col_name: time_index}).dt.strftime('%Y-%m-%d').values}"
        return f"{col_name}={self.da[col_name].isel({col_name: time_index}).values}"

    def single_plot(self, config: PlotConfig, ax: Optional[Axes] = None) -> Tuple[Figure, Axes, Any]:
        """Make a single plot with given configuration.

        Args:
            config: Plot configuration.
            ax: Existing axes to draw into; a new figure and axes are created if None.

        Returns:
            ``(fig, ax, im)`` where ``im`` is the mappable returned by :meth:`plot`.
        """
        cmap, norm, clim, var_units, extend = self._setup_common_params(config)
        fig, ax = self._setup_axes(ax, config.projection)

        # Call implementation-specific plot function
        ax, im = self.plot(ax=ax, cmap=cmap, clim=clim, norm=norm)

        if config.title:
            ax.set_title(config.title, size=12)

        self._setup_colorbar(fig, im, config.show_colorbar, var_units, extend)
        self._add_map_features(ax, config.grid_lines, config.grid_labels)
        return fig, ax, im

    def multi_plot(
        self, config: PlotConfig, col: str = "time", col_wrap: int = 3
    ) -> Tuple[Figure, NDArray[Any]]:  # pragma: no cover
        """Make wrapped subplots with given configuration.

        Args:
            config: Plot configuration shared by all panels.
            col: Dimension to facet over (one panel per index).
            col_wrap: Maximum number of panels per row.

        Returns:
            ``(fig, axes)`` where ``axes`` is a flat array of every created axes
            (unused trailing axes are removed from the figure).
        """
        import copy

        npanels = self.da[col].size
        nrows = int(np.ceil(npanels / col_wrap))
        ncols = min(npanels, col_wrap)

        cmap, norm, clim, var_units, extend = self._setup_common_params(config)

        fig = plt.figure(figsize=(6 * ncols, 3 * nrows))
        # squeeze=False keeps a 2-D array even for a single panel, so flatten() always works.
        axes = fig.subplots(nrows, ncols, squeeze=False, subplot_kw={"projection": config.projection}).flatten()

        # Reuse one plotter for all panels. A shallow copy keeps the caller's
        # dimension/coordinate mappings and any subclass state (e.g. grid file
        # paths) instead of re-running __init__ with the default lat/lon mappings.
        panel_plotter = copy.copy(self)

        for i, ax in enumerate(axes):
            if i >= npanels:
                fig.delaxes(ax)
                continue

            panel_config = PlotConfig(
                title=self._get_title(i, col, config.dimensions),
                cmap=cmap,
                clim=clim,
                show_colorbar=False,
                grid_labels=False,
                norm=norm,
                plot_IDs=False,
                extend=extend,
                dimensions=config.dimensions,
                coordinates=config.coordinates,
                projection=config.projection,
            )
            panel_plotter.da = self.da.isel({col: i})
            panel_plotter.single_plot(panel_config, ax=ax)

        # Add single colorbar for all panels
        if config.show_colorbar:
            fig.subplots_adjust(right=0.9)
            if norm is None and clim is not None:
                norm = Normalize(vmin=clim[0], vmax=clim[1])
            sm = ScalarMappable(cmap=cmap, norm=norm)
            sm.set_array([])
            self._setup_colorbar(fig, sm, True, var_units, extend, position=[0.92, 0.15, 0.02, 0.7])

        return fig, axes

    @staticmethod
    def _select_time(arr: Optional[xr.DataArray], time_dim: str, time_ind: int) -> Optional[xr.DataArray]:
        """Return ``arr`` at ``time_ind`` along ``time_dim``, or None if absent or not selectable."""
        if arr is None:
            return None
        try:
            return arr.isel({time_dim: time_ind})
        except Exception as err:  # best-effort overlay: skip it rather than abort the movie
            logger.debug("Overlay not selectable at %s=%d: %s", time_dim, time_ind, err)
            return None

    def animate(
        self,
        config: PlotConfig,
        plot_dir: Union[str, Path] = "./",
        file_name: Optional[str] = None,
        centroids: Optional[xr.DataArray] = None,
        object_ids: Optional[xr.DataArray] = None,
    ) -> Optional[str]:  # pragma: no cover
        """Create an animation from time series data

        Args:
            config: Plot configuration (including framerate for animation, default 10 fps)
            plot_dir: Directory to save animation files (created if missing)
            file_name: Name for the output animation file; ``.mp4`` is appended
                unless already present (default ``movie_<da.name>.mp4``)
            centroids: Optional DataArray containing centroid data with dimensions (component, time, ID)
            object_ids: Optional DataArray containing object ID field with integers > 0 for drawing contour outlines

        Returns:
            Path of the written movie, or None if ffmpeg is not installed.

        Raises:
            subprocess.CalledProcessError: If ffmpeg fails. Temporary frames are
                removed in every case.
        """
        # Check if PIL is available for image processing
        from .._dependencies import require_dependencies

        require_dependencies(["pillow"], "Animation functionality")

        # Check if ffmpeg is installed
        if shutil.which("ffmpeg") is None:
            warnings.warn(
                "ffmpeg executable not found in system PATH. Cannot create animation.\n"
                "Please install ffmpeg using one of the following methods:\n"
                "  - Linux: sudo apt install ffmpeg (Ubuntu/Debian) or sudo yum install ffmpeg (CentOS/RHEL)\n"
                "  - Conda: conda install -c conda-forge ffmpeg\n"
                "Alternatively, use matplotlib for animation in Jupyter notebooks.",
                stacklevel=2,
            )
            return None

        plot_dir = Path(plot_dir)
        plot_dir.mkdir(parents=True, exist_ok=True)

        # Accept names with or without the extension (avoids "name.mp4.mp4").
        file_name = str(file_name) if file_name else f"movie_{self.da.name}"
        if not file_name.lower().endswith(".mp4"):
            file_name = f"{file_name}.mp4"
        output_file = plot_dir / file_name

        # Set up plotting parameters
        cmap, norm, clim, var_units, extend = self._setup_common_params(config)
        plot_params = {
            "cmap": cmap,
            "norm": norm,
            "clim": clim,
            "var_units": var_units,
            "extend": extend,
            "show_colorbar": config.show_colorbar,
            "grid_labels": config.grid_labels,
            # Optional keys honoured by make_frame (it falls back to Robinson / "lon" / "lat").
            "projection": config.projection,
            "x_coord": self.coordinates.get("x", "lon"),
            "y_coord": self.coordinates.get("y", "lat"),
        }

        # Set up grid information if needed
        grid_info = None
        if hasattr(self, "fpath_tgrid") or hasattr(self, "fpath_ckdtree"):
            grid_info = {
                "type": "unstructured",
                "tgrid_path": getattr(self, "fpath_tgrid", None),
                "ckdtree_path": getattr(self, "fpath_ckdtree", None),
                "res": 0.3,
            }

        time_dim = config.dimensions["time"] if config.dimensions else "time"
        time_coord = config.coordinates.get("time", time_dim) if config.coordinates else time_dim
        # Format all frame titles once instead of once per frame.
        time_strings = self.da[time_coord].dt.strftime("%Y-%m-%d").values

        # Use dask's scratch directory for temporary frames
        dask_temp = dask.config.get("temporary-directory", default=None)
        if dask_temp is None:
            dask_temp = tempfile.gettempdir()
        temp_dir = Path(tempfile.mkdtemp(prefix="marex_animate_", dir=dask_temp))

        try:
            # Generate frames using dask for parallel processing
            delayed_tasks = []
            for time_ind in range(self.da.sizes[time_dim]):
                # Fresh copy of plot_params per frame to avoid shared references
                frame_params = dict(plot_params)
                frame_params["time_str"] = str(time_strings[time_ind])
                frame_params["centroids"] = self._select_time(centroids, time_dim, time_ind)
                frame_params["object_ids"] = self._select_time(object_ids, time_dim, time_ind)
                data_slice = self.da.isel({time_dim: time_ind})
                delayed_tasks.append(make_frame(data_slice, time_ind, temp_dir, frame_params, grid_info))

            # Process frames in batches to manage memory efficiently
            batch_size = 200
            for start in range(0, len(delayed_tasks), batch_size):
                dask.compute(*delayed_tasks[start : start + batch_size])
                # Force garbage collection between batches to release memory
                gc.collect()

            # Create movie using ffmpeg
            subprocess.run(
                [
                    "ffmpeg",
                    "-y",
                    "-threads",
                    "0",
                    "-framerate",
                    str(config.framerate),
                    "-i",
                    str(temp_dir / "time_%04d.jpg"),
                    "-c:v",
                    "libx264",
                    "-preset",
                    "fast",
                    "-crf",
                    "22",
                    "-pix_fmt",
                    "yuv420p",
                    "-movflags",
                    "+faststart",
                    str(output_file),
                ],
                check=True,
            )
        finally:
            # Clean up temporary frames directory even if rendering/encoding failed
            shutil.rmtree(temp_dir, ignore_errors=True)

        return str(output_file)

    def clim_robust(self, data: NDArray[Any], issym: bool, percentiles: Optional[List[int]] = None) -> NDArray[np.float64]:
        """Compute robust colour limits from data percentiles.

        Args:
            data: Values to summarise; NaNs are ignored.
            issym: If True, return limits symmetric about zero using the largest
                absolute percentile value.
            percentiles: ``[lower, upper]`` percentiles (default ``[2, 98]``). When
                ``lower`` is 0 (and ``issym`` is False) the lower limit is pinned to 0.

        Returns:
            Two-element array ``[vmin, vmax]``.
        """
        if percentiles is None:
            percentiles = [2, 98]

        clim = np.nanpercentile(data, percentiles)
        if issym:
            clim = np.abs(clim).max()
            clim = np.array([-clim, clim])
        elif percentiles[0] == 0:
            clim = np.array([0, clim[1]])
        return clim

    def setup_plot_params(self) -> None:
        """Set up common plotting parameters (global rcParams: no TeX, serif font)."""
        plt.rc("text", usetex=False)
        plt.rc("font", family="serif")

    def setup_id_plot_params(self, cmap: Optional[Union[str, ListedColormap]] = None) -> Tuple[ListedColormap, BoundaryNorm, str]:
        """Set up parameters for plotting IDs.

        One colour bin is created per integer ID in ``1..max(da)``.

        Args:
            cmap: Colormap (object or registered name). If None, a reproducible
                random colour per ID is generated.

        Returns:
            ``(cmap, norm, "ID")`` with ``norm`` a BoundaryNorm centred on integers.
        """
        # Use min=1 and max from data without computing all unique values
        max_val = float(self.da.max().values)
        # Guarantee at least one bin so all-NaN / object-free data does not crash
        max_id = max(int(max_val), 1) if np.isfinite(max_val) else 1
        bounds = np.arange(1, max_id + 2) - 0.5
        n_bins = len(bounds) - 1

        if cmap is None:
            # Private RNG with the same seed reproduces the previous colours without
            # reseeding NumPy's global random state.
            rng = np.random.RandomState(42)
            cmap = ListedColormap(rng.random_sample(size=(n_bins, 3)))
        elif isinstance(cmap, str):
            # A colormap name has no ``.N``; resolve it to a Colormap object.
            cmap = plt.get_cmap(cmap)

        norm = BoundaryNorm(bounds, cmap.N)
        return cmap, norm, "ID"

    def plot(
        self,
        ax: Axes,
        cmap: Union[str, ListedColormap] = "viridis",
        clim: Optional[Tuple[float, float]] = None,
        norm: Optional[Union[BoundaryNorm, Normalize]] = None,
    ) -> Tuple[Axes, Any]:
        """Abstract method to be implemented by subclasses.

        Implementations draw ``self.da`` on ``ax`` and return ``(ax, mappable)``.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError("Subclasses must implement plot method")
def _frame_plot_kwargs(plot_params: Dict[str, Any]) -> Dict[str, Any]:
    """Build the colour-mapping keyword arguments shared by pcolormesh and tripcolor."""
    kwargs: Dict[str, Any] = {"transform": ccrs.PlateCarree(), "cmap": plot_params["cmap"]}
    if plot_params.get("norm") is not None:
        kwargs["norm"] = plot_params["norm"]
    elif plot_params.get("clim") is not None:
        kwargs["vmin"] = plot_params["clim"][0]
        kwargs["vmax"] = plot_params["clim"][1]
    return kwargs


def _draw_frame_field(
    ax: Any,
    data_slice: xr.DataArray,
    data_np: NDArray[Any],
    plot_params: Dict[str, Any],
    grid_info: Optional[Dict[str, Any]],
) -> Any:
    """Draw one frame's data field on ``ax`` (regular or unstructured grid); return the mappable."""
    kwargs = _frame_plot_kwargs(plot_params)

    if not (grid_info and grid_info.get("type") == "unstructured"):
        # Regular grid plotting
        lon = data_slice[plot_params.get("x_coord", "lon")].values
        lat = data_slice[plot_params.get("y_coord", "lat")].values
        return ax.pcolormesh(lon, lat, data_np, shading="auto", **kwargs)

    try:
        from .unstructured import _load_ckdtree, _load_triangulation
    except ImportError as e:
        raise DependencyError(
            "Unstructured plotting dependencies missing",
            details=str(e),
            suggestions=[
                "Install plotting dependencies: pip install marEx[plot]",
                "Check that scipy and matplotlib are available",
                "Verify unstructured grid support is properly installed",
            ],
            context={"missing_dependency": str(e), "plot_type": "unstructured"},
        ) from e

    if grid_info.get("ckdtree_path"):
        # Use cached ckdtree data to regrid onto a regular lat/lon mesh
        ckdt_data = _load_ckdtree(grid_info["ckdtree_path"], grid_info.get("res", 0.3))
        grid_data = data_np[ckdt_data["indices"]].reshape(ckdt_data["lat"].size, ckdt_data["lon"].size)
        return ax.pcolormesh(
            ckdt_data["lon"], ckdt_data["lat"], np.ma.masked_invalid(grid_data), shading="auto", **kwargs
        )

    if grid_info.get("tgrid_path"):
        # tripcolor only accepts shading="flat"/"gouraud"; "auto" would raise ValueError.
        triang = _load_triangulation(grid_info["tgrid_path"])
        return ax.tripcolor(triang, np.ma.masked_invalid(data_np), **kwargs)

    raise VisualisationError(
        "Unstructured grid information missing",
        details="grid_info has neither 'ckdtree_path' nor 'tgrid_path' set",
        suggestions=[
            "Set the triangulation ('fpath_tgrid') or ckdtree ('fpath_ckdtree') path on the plotter before animating",
        ],
        context={"grid_info": grid_info},
    )


def _draw_object_outlines(
    ax: Any,
    data_slice: xr.DataArray,
    object_ids: xr.DataArray,
    grid_info: Optional[Dict[str, Any]],
    plot_params: Dict[str, Any],
) -> None:
    """Outline all cells with object ID > 0 in white (regular grids only)."""
    if grid_info and grid_info.get("type") == "unstructured":
        # Contouring unstructured data would need interpolation to a regular grid; not supported.
        return
    object_mask = np.asarray(object_ids.values) > 0
    if not object_mask.any():
        return
    lon = data_slice[plot_params.get("x_coord", "lon")].values
    lat = data_slice[plot_params.get("y_coord", "lat")].values
    # A 0.5 level on the binary mask traces object boundaries (all IDs treated alike)
    ax.contour(
        lon,
        lat,
        object_mask.astype(float),
        levels=[0.5],
        colors=["white"],
        linewidths=1.5,
        transform=ccrs.PlateCarree(),
        zorder=6,
    )


def _draw_centroids(ax: Any, data_np: NDArray[Any], centroids: xr.DataArray) -> None:
    """Scatter the centroids of objects whose IDs appear in this frame's data."""
    ids_present = np.unique(data_np)
    ids_present = ids_present[ids_present > 0]  # Remove background (0) and NaN
    if ids_present.size == 0:
        return

    # centroids layout: (component, ID) where component 0=lat, 1=lon
    centroid_vals = np.asarray(centroids.values)
    on_frame = np.isin(np.asarray(centroids.ID.values), ids_present)
    lats = centroid_vals[0, on_frame]
    lons = centroid_vals[1, on_frame]
    valid = ~(np.isnan(lats) | np.isnan(lons))
    if not valid.any():
        return

    ax.scatter(
        lons[valid],
        lats[valid],
        c="black",
        s=20,
        marker="o",
        edgecolors="white",
        linewidth=1.5,
        transform=ccrs.PlateCarree(),
        zorder=5,  # Plot above data but below grid lines
        alpha=0.8,
    )


def _save_even_sized(src: Path, dst: Path) -> None:
    """Copy image ``src`` to ``dst`` cropped to even width/height, as needed for video encoding."""
    with Image.open(str(src)) as image:
        width, height = image.size
        # Dropping at most one edge pixel avoids resampling the whole frame.
        image.crop((0, 0, width - width % 2, height - height % 2)).save(str(dst))


@dask.delayed
def make_frame(
    data_slice: xr.DataArray,
    time_ind: int,
    temp_dir: Path,
    plot_params: Dict[str, Any],
    grid_info: Optional[Dict[str, Any]] = None,
) -> str:  # pragma: no cover
    """Create a single frame for movies - minimise memory usage with dask

    Args:
        data_slice: The data for this specific frame
        time_ind: Frame index
        temp_dir: Directory for temporary files
        plot_params: Dict containing plotting parameters. Keys read: ``cmap``
            (required), ``norm``, ``clim``, ``time_str``, ``object_ids``,
            ``centroids``, ``show_colorbar``, ``extend``, ``var_units``,
            ``grid_labels`` and optionally ``projection`` (default Robinson),
            ``x_coord``/``y_coord`` (default ``"lon"``/``"lat"``).
        grid_info: Dict containing grid paths and settings for unstructured data

    Returns:
        The frame file name ``time_{time_ind:04d}.jpg`` written into ``temp_dir``.

    Raises:
        DependencyError: If unstructured-grid helpers cannot be imported.
        VisualisationError: If unstructured ``grid_info`` has no grid path.
    """
    # Set up plotting parameters
    plt.rc("text", usetex=False)
    plt.rc("font", family="serif")

    projection = plot_params.get("projection")
    if projection is None:
        projection = ccrs.Robinson()

    # Build the figure without pyplot's global "current figure": dask may render
    # frames concurrently in threads, and plt.figure()/plt.axes() could attach
    # axes to another thread's figure.
    fig = Figure(figsize=(7, 5))
    ax = fig.add_subplot(projection=projection)

    data_np = data_slice.values
    im = _draw_frame_field(ax, data_slice, data_np, plot_params, grid_info)
    ax.set_title(plot_params.get("time_str", f"Frame {time_ind}"), size=12)

    # Overlays are best-effort: log and continue rather than failing the frame
    object_ids = plot_params.get("object_ids")
    if object_ids is not None:
        try:
            _draw_object_outlines(ax, data_slice, object_ids, grid_info, plot_params)
        except Exception as err:
            logger.warning("Skipping object ID outlines for frame %d: %s", time_ind, err)

    centroids = plot_params.get("centroids")
    if centroids is not None:
        try:
            _draw_centroids(ax, data_np, centroids)
        except Exception as err:
            logger.warning("Skipping centroids for frame %d: %s", time_ind, err)

    if plot_params.get("show_colorbar"):
        cb = fig.colorbar(im, shrink=0.6, ax=ax, extend=plot_params.get("extend", "both"))
        if plot_params.get("var_units"):
            cb.ax.set_ylabel(plot_params["var_units"], fontsize=10)
        cb.ax.tick_params(labelsize=10)

    ax.add_feature(cfeature.LAND.with_scale("50m"), facecolor="darkgrey", zorder=2)
    ax.add_feature(cfeature.COASTLINE.with_scale("50m"), linewidth=0.5, zorder=3)
    ax.gridlines(
        crs=ccrs.PlateCarree(),
        draw_labels=plot_params.get("grid_labels", False),
        linewidth=1,
        color="gray",
        alpha=0.5,
        linestyle="--",
        zorder=4,
    )

    # Save and process frame
    filename = f"time_{time_ind:04d}.jpg"
    temp_file = temp_dir / f"temp_{filename}"
    fig.savefig(str(temp_file), dpi=300, bbox_inches="tight")
    _save_even_sized(temp_file, temp_dir / filename)
    temp_file.unlink()
    return filename