Source code for nereus.plotting.maps

"""2D map plotting for unstructured data.

This module provides functions for plotting unstructured geophysical data
on various map projections.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import NDArray

from nereus.core.grids import extract_coordinates, prepare_input_arrays
from nereus.plotting.projections import (
    get_data_bounds_for_projection,
    get_projection,
    is_global_projection,
    is_polar_projection,
)
from nereus.regrid.cache import get_cache
from nereus.regrid.interpolator import RegridInterpolator

if TYPE_CHECKING:
    import xarray as xr
    from matplotlib.axes import Axes
    from matplotlib.colorbar import Colorbar
    from matplotlib.figure import Figure


def _resolve_coordinates(data: Any, lon: Any, lat: Any) -> tuple[Any, Any]:
    """Return ``(lon, lat)``, filling any missing one from ``data``.

    Missing coordinates are looked up with ``extract_coordinates(data)``;
    explicitly supplied coordinates always win over extracted ones.

    Raises
    ------
    ValueError
        If either coordinate is still missing after extraction.
    """
    if lon is None or lat is None:
        extracted_lon, extracted_lat = extract_coordinates(data)
        if lon is None:
            lon = extracted_lon
        if lat is None:
            lat = extracted_lat

    if lon is None or lat is None:
        raise ValueError(
            "lon and lat coordinates are required. Either provide them explicitly "
            "or use an xarray DataArray with recognizable coordinate names "
            "(lon/lat, longitude/latitude, x/y, etc.)."
        )
    return lon, lat


def _build_interpolator(
    lon_arr: Any,
    lat_arr: Any,
    *,
    resolution: float | tuple[int, int],
    method: str,
    influence_radius: float,
    lon_bounds: Any,
    lat_bounds: Any,
    use_cache: bool,
) -> RegridInterpolator:
    """Create a regridder for the source points, via the shared cache if requested."""
    options = {
        "resolution": resolution,
        "method": method,
        "influence_radius": influence_radius,
        "lon_bounds": lon_bounds,
        "lat_bounds": lat_bounds,
    }
    if use_cache:
        return get_cache().get_or_create(lon_arr, lat_arr, **options)
    return RegridInterpolator(source_lon=lon_arr, source_lat=lat_arr, **options)


def _default_figsize(projection: str | ccrs.Projection) -> tuple[float, float]:
    """Pick a figure size suited to the projection's shape."""
    if is_polar_projection(projection):
        return (8, 8)
    if is_global_projection(projection):
        return (12, 6)
    return (10, 6)


def _figure_and_axes(
    ax: Axes | None,
    projection: str | ccrs.Projection,
    proj: ccrs.Projection,
    figsize: tuple[float, float] | None,
) -> tuple[Figure, Axes]:
    """Return ``(fig, ax)``, creating a one-panel figure in ``proj`` if ``ax`` is None."""
    if ax is not None:
        # Caller-supplied axes: figsize is deliberately ignored.
        return ax.get_figure(), ax

    if figsize is None:
        figsize = _default_figsize(projection)
    fig, new_ax = plt.subplots(
        1,
        1,
        figsize=figsize,
        subplot_kw={"projection": proj},
    )
    return fig, new_ax


def _colorbar_label(data: Any, colorbar_label: str | None) -> Any:
    """Return the explicit label, else a truthy ``data.name``, else None."""
    if colorbar_label:
        return colorbar_label
    name = getattr(data, "name", None)
    return name if name else None


def plot(
    data: NDArray | xr.DataArray,
    lon: NDArray[np.floating] | None = None,
    lat: NDArray[np.floating] | None = None,
    *,
    projection: str | ccrs.Projection = "pc",
    extent: tuple[float, float, float, float] | None = None,
    resolution: float | tuple[int, int] = 1.0,
    interpolator: RegridInterpolator | None = None,
    method: Literal["nearest", "idw", "linear", "cubic"] = "nearest",
    influence_radius: float = 80_000.0,
    cmap: str = "viridis",
    vmin: float | None = None,
    vmax: float | None = None,
    coastlines: bool = True,
    land: bool = False,
    gridlines: bool = False,
    colorbar: bool = True,
    colorbar_label: str | None = None,
    title: str | None = None,
    figsize: tuple[float, float] | None = None,
    ax: Axes | None = None,
    use_cache: bool = True,
    **kwargs: Any,
) -> tuple[Figure, Axes, RegridInterpolator]:
    """Plot 2D map of unstructured data.

    This function regrids unstructured data to a regular grid and plots it
    on a map with the specified projection.

    The function accepts various input formats and automatically transforms
    them to 1D arrays for plotting:

    - All 1D arrays of same size: used directly (no warning)
    - 2D data with 2D lon/lat (same shape): all raveled to 1D
    - 1D data with 2D lon/lat: lon/lat raveled to match data
    - 2D data with 1D lon/lat: meshgrid created, then all raveled

    A warning is issued whenever array transformations are applied.

    If lon/lat are not provided and data is an xarray DataArray, the
    function will attempt to extract coordinates automatically by looking
    for common coordinate names (lon/lat, longitude/latitude, x/y, etc.).

    Parameters
    ----------
    data : array_like
        Data values at unstructured points. Can be 1D or 2D array.
        If xarray DataArray, coordinates may be extracted automatically.
    lon : array_like, optional
        Longitude coordinates. Can be 1D or 2D array.
        If None, will attempt to extract from data (xarray only).
    lat : array_like, optional
        Latitude coordinates. Can be 1D or 2D array.
        If None, will attempt to extract from data (xarray only).
    projection : str or Projection
        Map projection. Options: "pc", "rob", "merc", "npstere",
        "spstere", "moll", "ortho", "lcc", or a Cartopy Projection.
    extent : tuple of float, optional
        Map extent (lon_min, lon_max, lat_min, lat_max). Ignored for
        projections that ``is_global_projection`` reports as global, which
        always show the whole globe.
    resolution : float or tuple of int
        Grid resolution for regridding.
    interpolator : RegridInterpolator, optional
        Pre-computed interpolator. If None, one will be created. When
        given, ``resolution``, ``method``, ``influence_radius`` and
        ``use_cache`` are not used.
    method : {"nearest", "idw", "linear", "cubic"}
        Interpolation method. "nearest" uses nearest-neighbor lookup (fast).
        "idw" uses inverse distance weighting (fast, smooth).
        "linear" uses Delaunay triangulation with barycentric interpolation.
        "cubic" uses Clough-Tocher C1 interpolation (smoothest).
        Default is "nearest".
    influence_radius : float
        Maximum influence radius in meters for interpolation.
        Default is 80 km.
    cmap : str
        Colormap name.
    vmin, vmax : float, optional
        Color scale limits.
    coastlines : bool
        Whether to draw coastlines.
    land : bool
        Whether to fill land areas (drawn above the data, below coastlines).
    gridlines : bool
        Whether to draw gridlines (unlabelled on polar projections).
    colorbar : bool
        Whether to add a horizontal colorbar below the map.
    colorbar_label : str, optional
        Label for the colorbar. If not given, ``data.name`` is used when
        present and non-empty.
    title : str, optional
        Plot title.
    figsize : tuple of float, optional
        Figure size (width, height) in inches. Only used when ``ax`` is
        None; defaults to a projection-dependent size.
    ax : Axes, optional
        Existing axes to plot on. If None, creates new figure.
    use_cache : bool
        Whether to use the interpolator cache.
    **kwargs
        Additional arguments passed to pcolormesh.

    Returns
    -------
    fig : Figure
        The matplotlib Figure.
    ax : Axes
        The matplotlib Axes (GeoAxes).
    interpolator : RegridInterpolator
        The interpolator used (can be reused).

    Raises
    ------
    ValueError
        If lon/lat are neither given nor extractable from ``data``.

    Examples
    --------
    >>> fig, ax, interp = nr.plot(temp, mesh.lon, mesh.lat)
    >>> fig, ax, _ = nr.plot(salinity, mesh.lon, mesh.lat, interpolator=interp)
    """
    # Coordinates: explicit values win, otherwise pulled from xarray metadata.
    lon, lat = _resolve_coordinates(data, lon, lat)

    # Normalise the various accepted shapes to matching 1D arrays.
    data_values, lon_arr, lat_arr = prepare_input_arrays(data, lon, lat)

    proj = get_projection(projection)
    data_crs = ccrs.PlateCarree()

    # Restrict the target grid to the region the projection/extent can show.
    lon_bounds, lat_bounds = get_data_bounds_for_projection(projection, extent)

    if interpolator is None:
        interpolator = _build_interpolator(
            lon_arr,
            lat_arr,
            resolution=resolution,
            method=method,
            influence_radius=influence_radius,
            lon_bounds=lon_bounds,
            lat_bounds=lat_bounds,
            use_cache=use_cache,
        )

    regridded = interpolator(data_values)

    fig, ax = _figure_and_axes(ax, projection, proj, figsize)

    # Map framing. `is not None` instead of truthiness so a NumPy-array extent
    # does not raise "truth value of an array is ambiguous"; an empty sequence
    # is still treated as "no extent", as before.
    if is_global_projection(projection):
        ax.set_global()
    elif extent is not None and len(extent) > 0:
        ax.set_extent(extent, crs=data_crs)

    if land:
        ax.add_feature(
            cfeature.LAND,
            facecolor="lightgray",
            edgecolor="none",
            zorder=1,
        )

    # Data at zorder 0 so land (1) and coastlines (2) are drawn over it.
    im = ax.pcolormesh(
        interpolator.target_lon,
        interpolator.target_lat,
        regridded,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        transform=data_crs,
        zorder=0,
        **kwargs,
    )

    if coastlines:
        ax.coastlines(linewidth=0.5, color="black", zorder=2)

    if gridlines:
        ax.gridlines(
            draw_labels=not is_polar_projection(projection),
            linewidth=0.5,
            alpha=0.5,
        )

    if colorbar:
        cbar = fig.colorbar(im, ax=ax, orientation="horizontal", shrink=0.8, pad=0.05)
        label = _colorbar_label(data, colorbar_label)
        if label is not None:
            cbar.set_label(label)

    if title:
        ax.set_title(title)

    return fig, ax, interpolator