"""Grid utilities for nereus.

Functions for creating regular grids, preparing coordinate arrays, and
computing grid-cell areas for regridding and plotting.
"""
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING, Literal
import numpy as np
from numpy.typing import NDArray
if TYPE_CHECKING:
import xarray as xr
# Candidate names for the longitude coordinate, in lookup priority order.
# Generic names come first, followed by model-specific variants.
LON_NAMES = (
    "lon", "longitude", "x", "nav_lon", "glon",
    "xt_ocean", "xu_ocean", "xh", "xq", "nod2d_lon",
)
# Candidate names for the latitude coordinate, in the same order as LON_NAMES.
LAT_NAMES = (
    "lat", "latitude", "y", "nav_lat", "glat",
    "yt_ocean", "yu_ocean", "yh", "yq", "nod2d_lat",
)
def _find_coord_by_names(
data: "xr.DataArray",
names: tuple[str, ...],
coord_type: str,
) -> NDArray | None:
"""Find a coordinate in DataArray by trying common names.
Parameters
----------
data : xr.DataArray
The DataArray to search for coordinates.
names : tuple of str
Tuple of possible coordinate names to try.
coord_type : str
Type of coordinate for error messages ("longitude" or "latitude").
Returns
-------
ndarray or None
The coordinate values if found, None otherwise.
"""
# First check coordinates (dims + coords)
all_coords = set(data.dims) | set(data.coords.keys())
for name in names:
# Try exact match (case-insensitive)
for coord_name in all_coords:
if coord_name.lower() == name.lower():
return np.asarray(data.coords[coord_name].values)
# Try attributes that might contain coordinate info
if hasattr(data, "attrs"):
for attr_name in ("coordinates", "grid_mapping"):
if attr_name in data.attrs:
coord_str = data.attrs[attr_name]
for name in names:
if name in coord_str.lower():
# Found reference, try to get from coords
for coord_name in all_coords:
if coord_name.lower() == name.lower():
return np.asarray(data.coords[coord_name].values)
return None
def prepare_coordinates(
    lon: NDArray[np.floating],
    lat: NDArray[np.floating],
    data_shape: tuple[int, ...] | None = None,
) -> tuple[NDArray[np.floating], NDArray[np.floating]]:
    """Prepare and validate coordinate arrays, converting to 1D if necessary.

    This function handles various coordinate formats and transforms them to 1D
    arrays. It issues warnings when transformations are applied.

    Supported inputs:

    * Both 1D with the same size: treated as flat point coordinates and
      returned unchanged. The exception is a square grid of side coordinates
      (see *data_shape*).
    * Both 2D with the same shape: raveled.
    * Both 1D side coordinates (``lon`` along columns, ``lat`` along rows):
      combined with ``np.meshgrid`` and raveled. This requires
      ``data_shape == (nlat, nlon)``.

    Parameters
    ----------
    lon : array_like
        Longitude coordinates, can be 1D or 2D.
    lat : array_like
        Latitude coordinates, can be 1D or 2D.
    data_shape : tuple of int, optional
        Expected 2D data shape ``(nlat, nlon)`` for side coordinates.
        It is required when 1D lon and lat have different sizes, where it is
        used to validate the sizes and build a meshgrid matching the data
        dimensions. When lon and lat both have size ``n > 1`` and
        ``data_shape == (n, n)``, they are also treated as side coordinates,
        because ``n`` flat points cannot describe ``n * n`` data values.

    Returns
    -------
    lon_1d : ndarray
        1D array of longitude coordinates.
    lat_1d : ndarray
        1D array of latitude coordinates.

    Raises
    ------
    ValueError
        If array dimensions are incompatible, sizes don't match, or
        *data_shape* is missing or not 2D for side coordinates.

    Examples
    --------
    >>> # 2D coordinates
    >>> lon_2d, lat_2d = np.meshgrid(np.arange(10), np.arange(5))
    >>> lon_1d, lat_1d = prepare_coordinates(lon_2d, lat_2d)
    >>> lon_1d.shape
    (50,)
    >>> # 1D coordinates with data shape
    >>> lon = np.arange(10)
    >>> lat = np.arange(5)
    >>> lon_1d, lat_1d = prepare_coordinates(lon, lat, data_shape=(5, 10))
    >>> lon_1d.shape
    (50,)
    >>> # Square grid of side coordinates
    >>> lon_1d, lat_1d = prepare_coordinates(
    ...     np.arange(3), np.arange(3), data_shape=(3, 3)
    ... )
    >>> lon_1d.shape
    (9,)
    """
    lon_arr = np.asarray(lon)
    lat_arr = np.asarray(lat)
    lon_ndim = lon_arr.ndim
    lat_ndim = lat_arr.ndim
    both_1d = lon_ndim == 1 and lat_ndim == 1

    # Equal-length 1D arrays are ambiguous on a square grid. For (n, n) data
    # with n-length side coordinates, taking them as n flat points would
    # silently mismatch the n * n data values, so route them to the meshgrid
    # path instead.
    square_side_coords = (
        both_1d
        and data_shape is not None
        and len(data_shape) == 2
        and lon_arr.size > 1
        and tuple(data_shape) == (lat_arr.size, lon_arr.size)
    )

    # Case 1: Both 1D with same size - already flat coordinates
    if both_1d and lon_arr.size == lat_arr.size and not square_side_coords:
        return lon_arr, lat_arr

    # Case 2: Both 2D with same shape - ravel them
    if lon_ndim == 2 and lat_ndim == 2:
        if lon_arr.shape != lat_arr.shape:
            raise ValueError(
                f"2D lon and lat arrays must have the same shape. "
                f"Got lon: {lon_arr.shape}, lat: {lat_arr.shape}"
            )
        warnings.warn(
            f"Raveling 2D lon/lat arrays (shape {lon_arr.shape}) to 1D.",
            stacklevel=3,
        )
        return lon_arr.ravel(), lat_arr.ravel()

    # Case 3: Both 1D side coordinates - need meshgrid
    if both_1d:
        if data_shape is None:
            raise ValueError(
                f"1D lon ({lon_arr.size}) and lat ({lat_arr.size}) arrays have "
                "different sizes; pass data_shape=(nlat, nlon) so they can be "
                "combined into a meshgrid."
            )
        if len(data_shape) != 2:
            raise ValueError(
                f"data_shape must be 2D for meshgrid creation, got {len(data_shape)}D"
            )
        ny, nx = data_shape
        if lon_arr.size != nx:
            raise ValueError(
                f"1D lon array size ({lon_arr.size}) must match data columns ({nx})."
            )
        if lat_arr.size != ny:
            raise ValueError(
                f"1D lat array size ({lat_arr.size}) must match data rows ({ny})."
            )
        warnings.warn(
            f"Creating meshgrid from 1D lon ({lon_arr.size}) and lat ({lat_arr.size}), "
            "then raveling to 1D.",
            stacklevel=3,
        )
        # meshgrid(lon, lat) yields (nlat, nlon) arrays, matching data_shape.
        lon_2d, lat_2d = np.meshgrid(lon_arr, lat_arr)
        return lon_2d.ravel(), lat_2d.ravel()

    # Case 4: Mixed 1D/2D
    if (lon_ndim == 1) != (lat_ndim == 1):
        raise ValueError(
            f"lon and lat must both be 1D or both be 2D. "
            f"Got lon: {lon_ndim}D, lat: {lat_ndim}D"
        )

    # Catch-all for unsupported dimensions
    raise ValueError(
        f"Unsupported coordinate dimensions: lon {lon_ndim}D, lat {lat_ndim}D. "
        "Supported: both 1D (same or different sizes) or both 2D (same shape)."
    )
def flatten_spatial(data):
    """Merge the trailing two (spatial) axes of *data* into one point axis.

    An input shaped ``(..., nlat, nlon)`` is returned shaped
    ``(..., nlat * nlon)``. Only ``.ndim``, ``.shape`` and ``.reshape`` are
    used, so NumPy and dask arrays are both accepted.

    Parameters
    ----------
    data : array_like
        Array with two or more dimensions.

    Returns
    -------
    array_like
        *data* reshaped so that its last two axes become a single axis.

    Raises
    ------
    ValueError
        If *data* has fewer than two dimensions.
    """
    ndim = data.ndim
    if ndim >= 2:
        # A plain 2D field collapses to 1D; otherwise keep the leading axes.
        new_shape = -1 if ndim == 2 else data.shape[:-2] + (-1,)
        return data.reshape(new_shape)
    raise ValueError(
        f"flatten_spatial requires at least 2 dimensions, got {ndim}D."
    )
def create_regular_grid(
    resolution: float | tuple[int, int] = 1.0,
    lon_bounds: tuple[float, float] = (-180.0, 180.0),
    lat_bounds: tuple[float, float] = (-90.0, 90.0),
    center: Literal["cell", "node"] = "cell",
) -> tuple[NDArray[np.floating], NDArray[np.floating]]:
    """Create a regular lon/lat grid.

    Parameters
    ----------
    resolution : float or tuple of int
        Grid resolution. If float, specifies degrees per grid cell; the
        number of cells along each axis is ``extent / resolution``.
        Floating-point noise is tolerated (``3.0 / 0.1`` gives 30 cells), and
        a partial trailing cell is dropped. If tuple (nlon, nlat), specifies
        the number of grid points.
    lon_bounds : tuple of float
        Longitude bounds (lon_min, lon_max) in degrees.
    lat_bounds : tuple of float
        Latitude bounds (lat_min, lat_max) in degrees.
    center : {"cell", "node"}
        Whether coordinates are at cell centers or nodes.
        "cell" means coordinates at center of grid boxes.
        "node" means coordinates at corners. Node coordinates include both
        bounds, so with a float resolution the spacing is
        ``extent / (n - 1)``.

    Returns
    -------
    lon, lat : tuple of ndarrays
        2D arrays of longitude and latitude coordinates, shape (nlat, nlon).

    Raises
    ------
    ValueError
        If *center* is not "cell" or "node", or if a float *resolution* is
        not positive or is coarser than the lon/lat extent.

    Examples
    --------
    >>> lon, lat = create_regular_grid(1.0)  # 1 degree resolution
    >>> lon.shape
    (180, 360)
    >>> lon, lat = create_regular_grid((360, 180))  # 360x180 grid
    >>> lon.shape
    (180, 360)
    """
    if center not in ("cell", "node"):
        raise ValueError(f"center must be 'cell' or 'node', got {center!r}.")

    lon_min, lon_max = lon_bounds
    lat_min, lat_max = lat_bounds

    if isinstance(resolution, (list, tuple)):
        nlon, nlat = resolution
    else:
        if resolution <= 0:
            raise ValueError(f"resolution must be positive, got {resolution}.")

        def _count(extent: float) -> int:
            # Snap ratios that are within rounding noise of an integer (e.g.
            # 3.0 / 0.1 == 29.999999999999996); otherwise truncate so a
            # partial trailing cell is dropped.
            ratio = extent / resolution
            nearest = round(ratio)
            if np.isclose(ratio, nearest, rtol=1e-9, atol=1e-9):
                return int(nearest)
            return int(ratio)

        nlon = _count(lon_max - lon_min)
        nlat = _count(lat_max - lat_min)
        if nlon < 1 or nlat < 1:
            raise ValueError(
                f"resolution {resolution} yields an empty grid for "
                f"lon_bounds={lon_bounds}, lat_bounds={lat_bounds}."
            )

    if center == "cell":
        # Cell centers: offset half a cell in from each bound
        dlon = (lon_max - lon_min) / nlon
        dlat = (lat_max - lat_min) / nlat
        lon_1d = np.linspace(lon_min + dlon / 2, lon_max - dlon / 2, nlon)
        lat_1d = np.linspace(lat_min + dlat / 2, lat_max - dlat / 2, nlat)
    else:
        # Node positions: both bounds included
        lon_1d = np.linspace(lon_min, lon_max, nlon)
        lat_1d = np.linspace(lat_min, lat_max, nlat)

    lon, lat = np.meshgrid(lon_1d, lat_1d)
    return lon, lat
def grid_cell_area(
    lon: NDArray[np.floating],
    lat: NDArray[np.floating],
    radius: float = 6_371_000.0,
) -> NDArray[np.floating]:
    """Compute area of regular grid cells.

    Parameters
    ----------
    lon : array_like
        2D (or 1D) array of longitude coordinates (cell centers).
    lat : array_like
        2D (or 1D) array of latitude coordinates (cell centers).
    radius : float
        Earth radius in meters.

    Returns
    -------
    ndarray
        Array of cell areas in square meters, with the shape of *lat*.

    Raises
    ------
    ValueError
        If there are fewer than two points along longitude or latitude, so
        the grid spacing cannot be inferred.

    Notes
    -----
    Assumes uniform spacing in lon and lat. For 2D input, longitude must vary
    along axis 1 and latitude along axis 0, as produced by
    ``np.meshgrid(lon_1d, lat_1d)``. Area of a spherical rectangle::

        A = R^2 * |sin(lat1) - sin(lat2)| * |lon2 - lon1|

    Cell edges are clipped to [-90, 90] degrees. A cell centred on a pole
    (e.g. a node grid) therefore gets the area of its half-cell rather than
    a near-zero value.
    """
    lon = np.asarray(lon)
    lat = np.asarray(lat)

    # Get grid spacing from the first two points along each axis
    if lon.ndim == 2:
        if lon.shape[1] < 2 or lat.shape[0] < 2:
            raise ValueError(
                "grid_cell_area needs at least 2 points along each axis to "
                f"infer spacing; got lon shape {lon.shape}, lat shape {lat.shape}."
            )
        dlon = np.abs(lon[0, 1] - lon[0, 0])
        dlat = np.abs(lat[1, 0] - lat[0, 0])
    else:
        if lon.size < 2 or lat.size < 2:
            raise ValueError(
                "grid_cell_area needs at least 2 points along each axis to "
                f"infer spacing; got lon size {lon.size}, lat size {lat.size}."
            )
        dlon = np.abs(lon[1] - lon[0])
        dlat = np.abs(lat[1] - lat[0])

    dlon_rad = np.deg2rad(dlon)
    lat_rad = np.deg2rad(lat)
    # Half grid spacing in lat
    half_dlat_rad = np.deg2rad(dlat / 2)

    # Northern/southern cell edges, kept on the sphere. Without clipping, a
    # cell centred at +/-90 has sin(90 + d) == sin(90 - d) and a ~zero area.
    half_pi = np.pi / 2
    upper = np.clip(lat_rad + half_dlat_rad, -half_pi, half_pi)
    lower = np.clip(lat_rad - half_dlat_rad, -half_pi, half_pi)

    area = radius**2 * np.abs(np.sin(upper) - np.sin(lower)) * dlon_rad
    return area
def expand_bounds_for_polar(
    lon_bounds: tuple[float, float],
    lat_bounds: tuple[float, float],
    factor: float = 1.414,  # sqrt(2)
) -> tuple[tuple[float, float], tuple[float, float]]:
    """Enlarge a lon/lat bounding box so a polar plot is filled without gaps.

    Each interval is scaled by *factor* about its own midpoint. The result
    is then clamped to [-180, 180] for longitude and [-90, 90] for latitude.

    Parameters
    ----------
    lon_bounds : tuple of float
        Original (lon_min, lon_max).
    lat_bounds : tuple of float
        Original (lat_min, lat_max).
    factor : float
        Scale applied to each interval's width. Default is approximately
        sqrt(2).

    Returns
    -------
    lon_bounds, lat_bounds : tuple of tuples
        The widened and clamped bounds.
    """

    def _widen(bounds: tuple[float, float], limit: float) -> tuple[float, float]:
        low, high = bounds
        mid = (low + high) / 2
        reach = (high - low) / 2 * factor
        return (max(-limit, mid - reach), min(limit, mid + reach))

    return _widen(lon_bounds, 180.0), _widen(lat_bounds, 90.0)