Source code for nereus.diag.hovmoller

"""Hovmoller diagram generation for nereus.

This module provides functions for computing and plotting Hovmoller diagrams
(time-depth or time-latitude plots).

The hovmoller function is dask-friendly: if inputs are dask arrays, the result
will be lazy dask arrays for both mode="depth" and mode="latitude".
"""

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Literal

import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import NDArray

from nereus.core.grids import extract_coordinates, flatten_spatial
from nereus.core.types import get_array_data, is_dask_array, wrap_as_xarray

if TYPE_CHECKING:
    import xarray as xr
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure


def _lat_bin_chunk(
    data_chunk: NDArray,
    area_1d: NDArray,
    bin_indices: NDArray,
    nlat: int,
) -> NDArray:
    """Process a chunk of time steps for latitude binning.

    Parameters
    ----------
    data_chunk : ndarray
        Data chunk with shape (ntime_chunk, npoints).
    area_1d : ndarray
        Area weights with shape (npoints,).
    bin_indices : ndarray
        Precomputed bin index for each point, shape (npoints,).
    nlat : int
        Number of latitude bins.

    Returns
    -------
    ndarray
        Binned results with shape (ntime_chunk, nlat).
    """
    ntime_chunk = data_chunk.shape[0]
    result = np.full((ntime_chunk, nlat), np.nan)

    for i in range(nlat):
        in_bin = bin_indices == i
        if not np.any(in_bin):
            continue

        # Extract data for this bin
        bin_data = data_chunk[:, in_bin]  # (ntime_chunk, npoints_in_bin)
        bin_area = area_1d[in_bin]

        # Compute valid mask
        valid = np.isfinite(bin_data)
        valid_area = np.where(valid, bin_area[np.newaxis, :], 0.0)
        total_area = np.sum(valid_area, axis=1)

        # Weighted sum
        data_filled = np.where(valid, bin_data, 0.0)
        weighted_sum = np.sum(data_filled * valid_area, axis=1)

        # Mean
        with np.errstate(divide="ignore", invalid="ignore"):
            result[:, i] = np.where(total_area > 0, weighted_sum / total_area, np.nan)

    return result


[docs] def hovmoller( data: NDArray | "xr.DataArray", area: NDArray[np.floating], time: NDArray | None = None, depth: NDArray[np.floating] | None = None, lat: NDArray[np.floating] | None = None, *, mode: Literal["depth", "latitude"] = "depth", lat_bins: NDArray[np.floating] | None = None, mask: NDArray[np.bool_] | None = None, as_xarray: bool = False, ) -> tuple[NDArray, NDArray, NDArray] | "xr.DataArray": """Compute Hovmoller diagram data. Computes area-weighted means at each time step, binned by either depth level or latitude. This function is dask-friendly: if inputs are dask arrays, the result will be lazy dask arrays. If lat is not provided and data is an xarray DataArray, the function will attempt to extract latitude coordinates automatically by looking for common coordinate names (lat, latitude, y, etc.). Parameters ---------- data : array_like Data array. For depth mode: shape (ntime, nlevels, npoints). For latitude mode: shape (ntime, npoints) or (ntime, nlevels, npoints). 4D arrays with shape (ntime, nlevels, nlat, nlon) are automatically flattened to (ntime, nlevels, nlat*nlon) in both modes. If xarray DataArray, lat coordinates may be extracted automatically. area : array_like Grid cell areas in m^2. Can be either: - 1D array of shape (npoints,) for surface area (uniform across depth) - 2D array of shape (nlevels, npoints) for depth-dependent area If 2D and has one extra level compared to data layers, the extra level is dropped with a warning (levels vs layers). time : array_like, optional Time coordinates. If None, uses integer indices. depth : array_like, optional Depth levels in meters. Required for mode="depth". lat : array_like, optional Latitude coordinates in degrees. Required for mode="latitude". If None and mode="latitude", will attempt to extract from data (xarray only). mode : {"depth", "latitude"} Type of Hovmoller diagram. lat_bins : array_like, optional Latitude bin edges for mode="latitude". Default is 5-degree bins. mask : array_like, optional Boolean mask for horizontal points, shape (npoints,). True = include. as_xarray : bool If True, return the result as an xarray DataArray with time and depth/latitude dimension coordinates instead of a 3-tuple (default False). Returns ------- tuple or xr.DataArray If as_xarray=False (default): Tuple of (time_out, y_out, data_out) where data_out has shape (ntime, ny). If as_xarray=True: xr.DataArray with dims ("time", "depth") or ("time", "latitude") and corresponding coordinates. Examples -------- >>> # Time-depth Hovmoller >>> t, z, hov = nr.hovmoller(temp, mesh.area, time, depth, mode="depth") >>> # Time-latitude Hovmoller >>> t, lat, hov = nr.hovmoller(sst, mesh.area, time, lat=mesh.lat, mode="latitude") >>> # With dask arrays (depth mode) >>> t, z, hov = nr.hovmoller(temp_dask, mesh.area, time, depth, mode="depth") >>> hov.compute() # triggers actual computation """ # Try to extract lat from xarray if not provided and needed for latitude mode if lat is None and mode == "latitude": _, extracted_lat = extract_coordinates(data) if extracted_lat is not None: lat = extracted_lat # Extract arrays, preserving dask for depth mode data_arr = get_array_data(data) area_arr = get_array_data(area) is_lazy = is_dask_array(data) # Apply mask to horizontal points if mask is not None: horiz_mask = get_array_data(mask) if hasattr(horiz_mask, "ravel"): horiz_mask = horiz_mask.ravel() else: horiz_mask = np.asarray(horiz_mask).ravel() else: horiz_mask = None if mode == "depth": if depth is None: raise ValueError("depth array required for mode='depth'") depth_arr = np.asarray(get_array_data(depth)).ravel() # Flatten 4D regular-grid data: (ntime, nlevels, nlat, nlon) -> (ntime, nlevels, npoints) if data_arr.ndim == 4: data_arr = flatten_spatial(data_arr) # Expect data shape: (ntime, nlevels, npoints) if data_arr.ndim == 2: # Assume (nlevels, npoints) - single timestep data_arr = data_arr[np.newaxis, :, :] ntime, nlevels, npoints = data_arr.shape # Handle area: can be 1D (npoints,) or 2D (nlevels, npoints) if area_arr.ndim == 1: if hasattr(area_arr, "ravel"): area_arr = area_arr.ravel() if area_arr.shape[0] != npoints: raise ValueError( f"area has {area_arr.shape[0]} points but data has {npoints}" ) area_is_2d = False elif area_arr.ndim == 2: nlev_area = area_arr.shape[0] area_is_2d = True # Check if area has one extra level (levels vs layers mismatch) if nlev_area != nlevels: diff = nlev_area - nlevels if diff != 1: raise ValueError( f"area has {nlev_area} vertical levels but data has {nlevels}; " "only area having one extra level is supported (levels vs layers)." ) warnings.warn( f"area has one more vertical level than data; " f"using the first {nlevels} levels of area to match data " "(levels vs layers).", UserWarning, stacklevel=2, ) area_arr = area_arr[:nlevels, :] else: raise ValueError(f"area must be 1D or 2D, got {area_arr.ndim}D") # Apply horizontal mask if provided if horiz_mask is not None: horiz_mask_float = horiz_mask.astype(np.float64) if area_is_2d: area_arr = area_arr * horiz_mask_float[np.newaxis, :] else: area_arr = area_arr * horiz_mask_float # Compute area-weighted mean at each depth level for each time # Vectorized approach for dask compatibility # data_arr shape: (ntime, nlevels, npoints) # area_arr shape: (npoints,) or (nlevels, npoints) # Get valid mask valid = np.isfinite(data_arr) # Prepare area for broadcasting if area_is_2d: # area_arr: (nlevels, npoints) -> (1, nlevels, npoints) area_broadcast = area_arr[np.newaxis, :, :] else: # area_arr: (npoints,) -> (1, 1, npoints) area_broadcast = area_arr[np.newaxis, np.newaxis, :] # Compute valid area (zero where data is NaN) valid_area = np.where(valid, area_broadcast, 0.0) # Replace NaN with 0 for summation data_filled = np.where(valid, data_arr, 0.0) # Sum over points (last axis) weighted_sum = np.sum(data_filled * valid_area, axis=-1) # (ntime, nlevels) total_area = np.sum(valid_area, axis=-1) # (ntime, nlevels) # Compute mean, handling zero area result = np.where(total_area > 0, weighted_sum / total_area, np.nan) # Time array if time is None: time_out = np.arange(ntime) else: time_out = np.asarray(get_array_data(time)) if as_xarray: return _wrap_hovmoller_as_xarray( result, time_out, depth_arr, "depth", data, ) return time_out, depth_arr, result elif mode == "latitude": if lat is None: raise ValueError("lat array required for mode='latitude'") # Get lat array as numpy (small array, safe to compute) lat_arr = np.asarray(get_array_data(lat)).ravel() # Set up latitude bins if lat_bins is None: lat_bins = np.arange(-90, 95, 5) # 5-degree bins lat_bins = np.asarray(lat_bins) lat_centers = 0.5 * (lat_bins[:-1] + lat_bins[1:]) nlat = len(lat_centers) # Precompute bin indices for each point (small numpy array) # digitize returns 1-based indices for bins, so subtract 1 # Points outside the range get clipped to valid bin indices bin_indices = np.digitize(lat_arr, lat_bins) - 1 bin_indices = np.clip(bin_indices, 0, nlat - 1) # Flatten 4D regular-grid data: (ntime, nlevels, nlat, nlon) -> (ntime, nlevels, npoints) if data_arr.ndim == 4: data_arr = flatten_spatial(data_arr) # Handle data shape - need to work with original array (possibly dask) if data_arr.ndim == 1: # Single timestep, single level (npoints,) data_arr = data_arr[np.newaxis, np.newaxis, :] elif data_arr.ndim == 2: # Could be (ntime, npoints) or (nlevels, npoints) # Assume (ntime, npoints) for latitude mode data_arr = data_arr[:, np.newaxis, :] ntime = data_arr.shape[0] npoints = data_arr.shape[-1] # For latitude mode, use surface area (first level if 2D) # Area needs to be numpy for the binning function if is_dask_array(area_arr): area_arr = np.asarray(area_arr) if area_arr.ndim == 2: area_1d = area_arr[0, :] else: area_1d = area_arr.ravel() # Apply horizontal mask if provided if horiz_mask is not None: area_1d = np.where(horiz_mask, area_1d, 0.0) # Vertically integrate first if 3D if data_arr.shape[1] > 1: # Vertical mean (simple average across levels) if is_lazy: import dask.array as da data_2d = da.nanmean(data_arr, axis=1) else: data_2d = np.nanmean(data_arr, axis=1) else: data_2d = data_arr[:, 0, :] # Process using dask or numpy if is_lazy: import dask.array as da # Try to use distributed client.scatter to avoid graph bloat # This sends the arrays to workers once, then tasks reference by future use_scatter = False try: from distributed import get_client client = get_client() # Scatter arrays to all workers (broadcast=True) area_future = client.scatter(area_1d, broadcast=True) bin_future = client.scatter(bin_indices, broadcast=True) use_scatter = True except (ImportError, ValueError): # No distributed client available, use regular approach pass if use_scatter: # With scattered data, use map_blocks with futures result = da.map_blocks( _lat_bin_chunk, data_2d, area_1d=area_future, bin_indices=bin_future, nlat=nlat, dtype=np.float64, drop_axis=1, new_axis=1, chunks=(data_2d.chunks[0], (nlat,)), ) else: # Fallback: use map_blocks with embedded arrays # This works fine for local schedulers result = da.map_blocks( _lat_bin_chunk, data_2d, area_1d=area_1d, bin_indices=bin_indices, nlat=nlat, dtype=np.float64, drop_axis=1, new_axis=1, chunks=(data_2d.chunks[0], (nlat,)), ) else: # Numpy path data_2d = np.asarray(data_2d) result = _lat_bin_chunk(data_2d, area_1d, bin_indices, nlat) # Time array if time is None: time_out = np.arange(ntime) else: time_out = np.asarray(get_array_data(time)) if as_xarray: return _wrap_hovmoller_as_xarray( result, time_out, lat_centers, "latitude", data, ) return time_out, lat_centers, result else: raise ValueError(f"Invalid mode: {mode}. Must be 'depth' or 'latitude'.")
def _wrap_hovmoller_as_xarray( result: NDArray, time_out: NDArray, y_out: NDArray, y_name: str, source_data: NDArray | "xr.DataArray", ) -> "xr.DataArray": """Wrap hovmoller result as an xarray DataArray. Parameters ---------- result : ndarray Hovmoller data, shape (ntime, ny). time_out : ndarray Time coordinate values. y_out : ndarray Depth or latitude coordinate values. y_name : str Name for the y dimension ("depth" or "latitude"). source_data : array_like Original input data for extracting name and attrs. """ import xarray as xr if hasattr(source_data, "name"): var_name = source_data.name or "data" else: var_name = "data" if hasattr(source_data, "attrs"): var_attrs = dict(source_data.attrs) else: var_attrs = {} y_attrs = {} if y_name == "depth": y_attrs = {"units": "m", "positive": "down"} elif y_name == "latitude": y_attrs = {"units": "degrees_north", "standard_name": "latitude"} return xr.DataArray( np.asarray(result), dims=("time", y_name), coords={ "time": ("time", time_out), y_name: (y_name, y_out, y_attrs), }, name=var_name, attrs=var_attrs, ) def _apply_y_scale( ax: "Axes", scale: Literal["sqrt", "power", "symlog"], scale_kw: dict[str, Any], ) -> None: """Apply non-linear y-axis scaling for depth plots. Parameters ---------- ax : Axes Matplotlib axes to modify. scale : str Scale type: "sqrt", "power", or "symlog". scale_kw : dict Scale-specific parameters. """ from matplotlib.scale import FuncScale if scale == "sqrt": # Square root transform: handles zero naturally, spreads surface layers def forward(x: NDArray) -> NDArray: return np.sqrt(np.maximum(x, 0)) def inverse(x: NDArray) -> NDArray: return x**2 ax.set_yscale(FuncScale(ax, (forward, inverse))) elif scale == "power": # Power transform with configurable exponent exponent = scale_kw.get("exponent", 0.4) if exponent <= 0 or exponent >= 1: raise ValueError( f"Power exponent must be between 0 and 1, got {exponent}" ) def forward(x: NDArray) -> NDArray: return np.power(np.maximum(x, 0), exponent) def inverse(x: NDArray) -> NDArray: return np.power(np.maximum(x, 0), 1.0 / exponent) ax.set_yscale(FuncScale(ax, (forward, inverse))) elif scale == "symlog": # Symmetric log: linear near zero, logarithmic further out linthresh = scale_kw.get("linthresh", 10.0) ax.set_yscale("symlog", linthresh=linthresh) else: raise ValueError(f"Unknown y_scale: {scale}")
[docs] def plot_hovmoller( time: NDArray, y: NDArray, data: NDArray, *, mode: Literal["depth", "latitude"] = "depth", cmap: str = "RdBu_r", vmin: float | None = None, vmax: float | None = None, colorbar: bool = True, colorbar_label: str | None = None, title: str | None = None, figsize: tuple[float, float] | None = None, ax: "Axes | None" = None, invert_y: bool | None = None, anomaly: bool = False, y_scale: Literal["linear", "sqrt", "power", "symlog"] = "linear", y_scale_kw: dict[str, Any] | None = None, **kwargs: Any, ) -> tuple["Figure", "Axes"]: """Plot a Hovmoller diagram. Parameters ---------- time : array_like Time coordinates. y : array_like Depth or latitude coordinates. data : array_like Hovmoller data, shape (ntime, ny). mode : {"depth", "latitude"} Type of diagram (affects axis labels and orientation). cmap : str Colormap name. vmin, vmax : float, optional Color scale limits. colorbar : bool Whether to add a colorbar. colorbar_label : str, optional Label for the colorbar. title : str, optional Plot title. figsize : tuple of float, optional Figure size. ax : Axes, optional Existing axes to plot on. invert_y : bool, optional Whether to invert y-axis. Default True for depth, False for latitude. anomaly : bool If True and mode="depth", plot anomaly relative to first time step (data - data[0, :]). Default False. y_scale : {"linear", "sqrt", "power", "symlog"} Vertical axis scaling for depth mode. Options: - "linear": No transformation (default) - "sqrt": Square root transform, gives more space to surface layers - "power": Power transform with configurable exponent (see y_scale_kw) - "symlog": Symmetric log scale, linear near zero then logarithmic y_scale_kw : dict, optional Additional parameters for y_scale: - For "power": {"exponent": 0.4} (default 0.4, smaller = more surface detail) - For "symlog": {"linthresh": 10} (linear threshold in meters, default 10) **kwargs Additional arguments passed to pcolormesh. Returns ------- fig : Figure The matplotlib Figure. ax : Axes The matplotlib Axes. Examples -------- >>> # Square root scaling for more surface detail >>> fig, ax = plot_hovmoller(time, depth, data, y_scale="sqrt") >>> # Power scaling with custom exponent (smaller = more surface detail) >>> fig, ax = plot_hovmoller(time, depth, data, y_scale="power", ... y_scale_kw={"exponent": 0.3}) >>> # Symmetric log: linear in upper 20m, log below >>> fig, ax = plot_hovmoller(time, depth, data, y_scale="symlog", ... y_scale_kw={"linthresh": 20}) """ time = np.asarray(time) y = np.asarray(y) data = np.asarray(data) # Validate dimensions ntime_data, ny_data = data.shape if len(time) != ntime_data: raise ValueError( f"time array has {len(time)} elements but data has {ntime_data} " f"time steps (data shape: {data.shape})" ) if len(y) != ny_data: raise ValueError( f"y array has {len(y)} elements but data has {ny_data} " f"y values (data shape: {data.shape}). " f"Make sure you're using the y coordinates returned by hovmoller(), " f"not the original mesh coordinates." ) # Compute anomaly if requested (only for depth mode) if anomaly and mode == "depth": data = data - data[0, :] # Create figure if needed if ax is None: if figsize is None: figsize = (12, 6) fig, ax = plt.subplots(1, 1, figsize=figsize) else: fig = ax.get_figure() # Compute explicit cell edges for the y-axis so pcolormesh cells are # well-defined. For depth mode the upper boundary is clamped to 0 # (the ocean surface) to avoid a negative edge that causes artefacts # with non-linear y-scales like sqrt. y_edges = np.empty(len(y) + 1) # Interior edges: midpoints between consecutive coordinates y_edges[1:-1] = 0.5 * (y[:-1] + y[1:]) if mode == "depth": # Surface boundary at 0 (not negative) y_edges[0] = max(0.0, y[0] - 0.5 * (y[1] - y[0])) else: y_edges[0] = y[0] - 0.5 * (y[1] - y[0]) y_edges[-1] = y[-1] + 0.5 * (y[-1] - y[-2]) # Compute time edges for shading="flat" (needs n+1 edges for n cells) time_num = np.arange(len(time), dtype=float) if np.issubdtype( time.dtype, np.datetime64 ) else time.astype(float) t_edges = np.empty(len(time) + 1) t_edges[1:-1] = 0.5 * (time_num[:-1] + time_num[1:]) t_edges[0] = time_num[0] - 0.5 * (time_num[1] - time_num[0]) t_edges[-1] = time_num[-1] + 0.5 * (time_num[-1] - time_num[-2]) if np.issubdtype(time.dtype, np.datetime64): # Convert back: compute offsets as timedeltas base = time[0] dt = time[1] - time[0] t_edges_dt = np.empty(len(time) + 1, dtype=time.dtype) t_edges_dt[1:-1] = time[:-1] + (time[1:] - time[:-1]) / 2 t_edges_dt[0] = time[0] - (time[1] - time[0]) / 2 t_edges_dt[-1] = time[-1] + (time[-1] - time[-2]) / 2 t_edges = t_edges_dt # Plot im = ax.pcolormesh( t_edges, y_edges, data.T, # Transpose so y is on vertical axis cmap=cmap, vmin=vmin, vmax=vmax, shading="flat", **kwargs, ) # Axis labels ax.set_xlabel("Time") if mode == "depth": ax.set_ylabel("Depth (m)") if invert_y is None: invert_y = True else: ax.set_ylabel("Latitude (°)") if invert_y is None: invert_y = False if invert_y: ax.invert_yaxis() # Apply y-axis scaling for depth mode if y_scale != "linear" and mode == "depth": scale_kw = y_scale_kw or {} _apply_y_scale(ax, y_scale, scale_kw) # Colorbar if colorbar: cbar = fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02) if colorbar_label: cbar.set_label(colorbar_label) if title: ax.set_title(title) return fig, ax