"""
Gridded data visualisation module for regular rectangular grids.
Provides specialised plotting capabilities for structured oceanographic data
with lat/lon coordinates on regular grids (3D arrays: time, lat, lon).
"""
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import xarray as xr
try:
import cartopy.crs as ccrs
from matplotlib.axes import Axes
from matplotlib.collections import QuadMesh
from matplotlib.colors import BoundaryNorm, Normalize
HAS_PLOTTING_DEPS = True
except ImportError:
# These will be checked in the base class
ccrs = None
Axes = None
QuadMesh = None
BoundaryNorm = None
Normalize = None
HAS_PLOTTING_DEPS = False
from ..logging_config import get_logger, log_timing
from .base import PlotterBase
# Module-level logger, obtained from the package's logging configuration
# (get_logger) and named after this module.
logger = get_logger(__name__)
class GriddedPlotter(PlotterBase):
    """Plotter for structured oceanographic data on regular rectangular grids.

    The wrapped DataArray is expected to have 1-D longitude and latitude
    dimensions named by ``self.dimensions["x"]`` and ``self.dimensions["y"]``.
    An optional time dimension (``self.dimensions["time"]``) is dropped before
    plotting when it has length one.
    """

    def __init__(
        self,
        xarray_obj: xr.DataArray,
        dimensions: Optional[Dict[str, str]] = None,
        coordinates: Optional[Dict[str, str]] = None,
    ) -> None:
        """Initialise GriddedPlotter.

        Parameters
        ----------
        xarray_obj : xr.DataArray
            The gridded field to plot.
        dimensions : dict, optional
            Mapping of logical dimension keys (this class reads "x", "y" and
            "time") to dimension names in ``xarray_obj``. Forwarded unchanged
            to ``PlotterBase``.
        coordinates : dict, optional
            Coordinate-name mapping, forwarded unchanged to ``PlotterBase``.
        """
        super().__init__(xarray_obj, dimensions, coordinates)

    def wrap_lon(self, data: xr.DataArray) -> xr.DataArray:
        """Handle periodic boundary in longitude by adding a column of data.

        If the longitude axis spans (almost) the full 360 degrees but does not
        already repeat its first point, the first longitude column is appended
        again at ``lon[0] + 360``. For a descending axis the new coordinate is
        ``lon[0] - 360``. This closes the seam in the plotted field instead of
        leaving a gap.

        Parameters
        ----------
        data : xr.DataArray
            Field with a 1-D longitude dimension named ``self.dimensions["x"]``.

        Returns
        -------
        xr.DataArray
            ``data`` with one extra longitude column when wrapping is needed.
            Otherwise ``data`` unchanged: fewer than two longitudes, a regional
            grid, or a grid that is already closed.
        """
        x_dim = self.dimensions["x"]
        lon = data[x_dim].values
        # The grid spacing cannot be estimated from fewer than two longitudes.
        if lon.size < 2:
            return data

        spacing = abs(lon[1] - lon[0])
        gap = 360 - (lon.max() - lon.min())
        # Wrap only global grids whose seam is still open:
        # - a gap smaller than two cells means the grid spans the globe;
        # - a gap below half a cell means the last longitude already repeats
        #   the first (e.g. 0..360), and another copy would duplicate it.
        # Written as a positive test so NaN longitudes never trigger a wrap.
        needs_wrap = 0.5 * spacing <= abs(gap) < 2 * spacing
        if not needs_wrap:
            return data

        # Continue the axis in its own direction so it stays monotonic.
        step = 360 if lon[1] > lon[0] else -360
        new_lon = np.append(lon, lon[0] + step)
        # Selecting with a list keeps x as a length-1 dimension, so concat
        # simply appends along it.
        first_column = data.isel({x_dim: [0]})
        wrapped = xr.concat([data, first_column], dim=x_dim)
        return wrapped.assign_coords({x_dim: new_lon})

    def plot(
        self,
        ax: Axes,
        cmap: Union[str, Any] = "viridis",
        clim: Optional[Tuple[float, float]] = None,
        norm: Optional[Union[BoundaryNorm, Normalize]] = None,
    ) -> Tuple[Axes, QuadMesh]:
        """Implement plotting for gridded (i.e. regular grid) data.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Target axes. The data is drawn with a PlateCarree transform, so
            this is presumably a cartopy GeoAxes.
        cmap : str or Colormap, default "viridis"
            Colormap passed to ``pcolormesh``.
        clim : (float, float), optional
            ``(vmin, vmax)`` colour limits. Ignored when ``norm`` is given.
        norm : BoundaryNorm or Normalize, optional
            Colour normalisation. Takes precedence over ``clim``.

        Returns
        -------
        (Axes, QuadMesh)
            The axes and the mesh returned by ``ax.pcolormesh``.
        """
        logger.debug(f"Plotting gridded data with shape {self.da.shape}")
        with log_timing(logger, "Gridded plot rendering", show_progress=True):
            x_dim = self.dimensions["x"]
            y_dim = self.dimensions["y"]
            time_dim = self.dimensions["time"]

            data = self.wrap_lon(self.da)
            # Remove a singleton time dimension so a 2-D field is plotted.
            if time_dim in data.dims and data.sizes[time_dim] == 1:
                data = data.squeeze(dim=time_dim)
            # pcolormesh expects C with shape (len(lats), len(lons)); enforce
            # that axis order regardless of the source array's layout.
            data = data.transpose(..., y_dim, x_dim)

            plot_kwargs = {
                "transform": ccrs.PlateCarree(),
                "cmap": cmap,
                "shading": "auto",
            }
            if norm is not None:
                plot_kwargs["norm"] = norm
            elif clim is not None:
                plot_kwargs["vmin"], plot_kwargs["vmax"] = clim[0], clim[1]

            lons = data[x_dim].values
            lats = data[y_dim].values
            values = data.values
            logger.debug(f"Rendering plot with {len(lons)} x {len(lats)} grid points")
            # pcolormesh is used rather than imshow, which has dimension
            # issues with cartopy.
            im = ax.pcolormesh(lons, lats, values, **plot_kwargs)
        return ax, im