# Source code for nereus.plotting.transect

"""Vertical transect plotting for nereus.

This module provides functions for plotting vertical transects (cross-sections)
of 3D data along arbitrary paths.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import NDArray
from scipy.spatial import cKDTree

from nereus.core.coordinates import great_circle_path, lonlat_to_cartesian
from nereus.core.grids import extract_coordinates, flatten_spatial, prepare_coordinates

if TYPE_CHECKING:
    import xarray as xr
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure


def transect(
    data: NDArray | "xr.DataArray",
    lon: NDArray[np.floating] | None = None,
    lat: NDArray[np.floating] | None = None,
    depth: NDArray[np.floating] | None = None,
    start: tuple[float, float] | None = None,
    end: tuple[float, float] | None = None,
    *,
    n_points: int = 100,
    cmap: str = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    depth_lim: tuple[float, float] | None = None,
    invert_depth: bool = True,
    colorbar: bool = True,
    colorbar_label: str | None = None,
    title: str | None = None,
    figsize: tuple[float, float] | None = None,
    ax: "Axes | None" = None,
    **kwargs: Any,
) -> tuple["Figure", "Axes"]:
    """Plot vertical transect along a great circle path.

    The function accepts various coordinate formats and automatically
    transforms them to 1D arrays:

    - Both 1D with same size: used directly (no warning)
    - Both 2D with same shape: raveled to 1D
    - Both 1D with different sizes: meshgrid created, then raveled

    A warning is issued whenever coordinate transformations are applied.

    If lon/lat are not provided and data is an xarray DataArray, the function
    will attempt to extract coordinates automatically by looking for common
    coordinate names (lon/lat, longitude/latitude, x/y, etc.).

    Values along the path are taken from the nearest source point, found
    with a KD-tree query on the Cartesian coordinates produced by
    ``lonlat_to_cartesian``. No interpolation is performed. The horizontal
    axis is the cumulative great-circle distance along the path, in km.

    Parameters
    ----------
    data : array_like
        Data values with shape (nlevels, npoints) for 2D, or
        (nlevels, nlat, nlon) for 3D regular grids. 3D data is automatically
        reshaped to 2D. A 1D array is treated as a single level. If xarray
        DataArray, coordinates may be extracted automatically.
    lon : array_like, optional
        Longitude coordinates. Can be 1D or 2D array. If None, will attempt
        to extract from data (xarray only).
    lat : array_like, optional
        Latitude coordinates. Can be 1D or 2D array. If None, will attempt
        to extract from data (xarray only).
    depth : array_like
        1D array of depth levels (positive downward). Required.
    start : tuple of float
        Start point (lon, lat). Required.
    end : tuple of float
        End point (lon, lat). Required.
    n_points : int
        Number of points along the transect.
    cmap : str
        Colormap name.
    vmin, vmax : float, optional
        Color scale limits.
    depth_lim : tuple of float, optional
        Depth/height limits (min, max). If None, uses data range.
    invert_depth : bool
        Whether to invert vertical axis. Default True for ocean (0 at top,
        depth increases downward). Set False for atmosphere (height
        increases upward).
    colorbar : bool
        Whether to add a colorbar.
    colorbar_label : str, optional
        Label for the colorbar.
    title : str, optional
        Plot title.
    figsize : tuple of float, optional
        Figure size. Defaults to (12, 6) when a new figure is created.
    ax : Axes, optional
        Existing axes to plot on.
    **kwargs
        Additional arguments passed to pcolormesh.

    Returns
    -------
    fig : Figure
        The matplotlib Figure.
    ax : Axes
        The matplotlib Axes.

    Raises
    ------
    ValueError
        If lon/lat are neither given nor extractable from ``data``, or if
        ``depth``, ``start`` or ``end`` is missing.

    Examples
    --------
    >>> fig, ax = nr.transect(
    ...     temp, mesh.lon, mesh.lat, depth,
    ...     start=(-30, 60), end=(30, 60)
    ... )
    """
    # Fill in any missing coordinate from the data object (xarray only).
    if lon is None or lat is None:
        extracted_lon, extracted_lat = extract_coordinates(data)
        if lon is None:
            lon = extracted_lon
        if lat is None:
            lat = extracted_lat

    # Validate required parameters
    if lon is None or lat is None:
        raise ValueError(
            "lon and lat coordinates are required. Either provide them explicitly "
            "or use an xarray DataArray with recognizable coordinate names "
            "(lon/lat, longitude/latitude, x/y, etc.)."
        )
    if depth is None:
        raise ValueError("depth array is required for transect plots.")
    if start is None or end is None:
        raise ValueError("start and end points are required for transect plots.")

    # Unwrap xarray DataArray (or anything exposing .values) to an ndarray.
    if hasattr(data, "values"):
        data = data.values
    data = np.asarray(data)

    # Prepare coordinates: handle various array shapes and validate
    lon_arr, lat_arr = prepare_coordinates(lon, lat)
    depth_arr = np.asarray(depth).ravel()

    # Handle 3D data on regular grids: (depth, lat, lon) -> (depth, lat*lon).
    # This keeps data columns aligned with the flattened coordinates.
    if data.ndim == 3:
        data = flatten_spatial(data)

    # Sample points along the great circle between start and end.
    path_lon, path_lat = great_circle_path(
        start[0], start[1], end[0], end[1], n_points
    )

    # Nearest-neighbour lookup in 3D Cartesian space; this avoids the
    # distortion and antimeridian issues of querying raw lon/lat degrees.
    tree = cKDTree(np.column_stack(lonlat_to_cartesian(lon_arr, lat_arr)))
    path_xyz = np.column_stack(lonlat_to_cartesian(path_lon, path_lat))
    _, indices = tree.query(path_xyz, k=1)

    # Extract data along path
    if data.ndim == 1:
        # Single level: promote to (1, npath) so pcolormesh gets a 2D array.
        transect_data = data[indices].reshape(1, -1)
    else:
        # Multiple levels (nlevels, npoints)
        transect_data = data[:, indices]

    # Cumulative great-circle distance, computed from the actual path length
    # rather than assuming it has exactly n_points entries.
    distance = _along_path_distance_km(path_lon, path_lat)

    # Create figure if needed
    if ax is None:
        if figsize is None:
            figsize = (12, 6)
        fig, ax = plt.subplots(1, 1, figsize=figsize)
    else:
        fig = ax.get_figure()

    im = ax.pcolormesh(
        distance,
        depth_arr,
        transect_data,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        shading="auto",
        **kwargs,
    )

    # Configure axes
    ax.set_xlabel("Distance (km)")
    ax.set_ylabel("Depth (m)" if invert_depth else "Height (m)")

    if depth_lim:
        if invert_depth:
            # For ocean: 0 at top, max depth at bottom
            ax.set_ylim(depth_lim[1], depth_lim[0])
        else:
            # For atmosphere: 0 at bottom, max height at top
            ax.set_ylim(depth_lim)
    elif invert_depth:
        ax.invert_yaxis()

    if colorbar:
        cbar = fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02)
        if colorbar_label:
            cbar.set_label(colorbar_label)

    if title:
        ax.set_title(title)

    return fig, ax


def _along_path_distance_km(
    lon: NDArray[np.floating],
    lat: NDArray[np.floating],
    radius_km: float = 6371.0,
) -> NDArray[np.floating]:
    """Return the cumulative great-circle distance (km) along a lon/lat polyline.

    Each segment length is computed with the haversine formula on a sphere
    of radius ``radius_km`` (default 6371 km, the mean Earth radius).
    Longitude enters only through ``sin(dlon / 2) ** 2``, which is
    2*pi-periodic in ``dlon``. A segment that crosses the antimeridian
    (e.g. 179.9 -> -179.9) is therefore measured as a short hop, not
    ~360 degrees.

    Parameters
    ----------
    lon, lat : array_like
        Point coordinates in degrees. They are flattened and must have
        equal length.
    radius_km : float
        Sphere radius in km.

    Returns
    -------
    ndarray
        Array of the same length as the input. The first entry is 0, and
        each later entry is the distance travelled from the first point.
    """
    lon_rad = np.deg2rad(np.asarray(lon, dtype=float).ravel())
    lat_rad = np.deg2rad(np.asarray(lat, dtype=float).ravel())
    if lon_rad.size < 2:
        return np.zeros(lon_rad.size)

    dlon = np.diff(lon_rad)
    dlat = np.diff(lat_rad)
    hav = (
        np.sin(dlat / 2.0) ** 2
        + np.cos(lat_rad[:-1]) * np.cos(lat_rad[1:]) * np.sin(dlon / 2.0) ** 2
    )
    # Clip guards arcsin against rounding pushing hav slightly outside [0, 1].
    segments = 2.0 * radius_km * np.arcsin(np.sqrt(np.clip(hav, 0.0, 1.0)))
    return np.concatenate(([0.0], np.cumsum(segments)))