"""Core mesh utilities for nereus.
This module provides utilities for creating, validating, and normalizing
mesh datasets. Meshes are represented as xr.Dataset objects with
standardized variable names.
"""
from __future__ import annotations
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Literal
import numpy as np
import xarray as xr
from numpy.typing import NDArray
from nereus.core.coordinates import EARTH_RADIUS
if TYPE_CHECKING:
pass
# Point-count threshold for automatic dask usage: when the caller expresses no
# explicit preference, should_use_dask() returns True only for meshes with
# strictly more points than this value.
DASK_THRESHOLD_POINTS = 1_000_000
# Nereus mesh metadata version string.
# NOTE(review): where this is written into datasets is not visible in this
# part of the file -- presumably alongside the other mesh metadata; confirm.
MESH_VERSION = "1.0"
[docs]
def should_use_dask(npoints: int, use_dask: bool | None = None) -> bool:
    """Decide whether mesh arrays should be backed by dask.

    Parameters
    ----------
    npoints : int
        Number of mesh points.
    use_dask : bool, optional
        Explicit choice. When it is not None it is returned unchanged and
        ``npoints`` is ignored.

    Returns
    -------
    bool
        The explicit choice if one was given; otherwise True only when
        ``npoints`` is strictly greater than ``DASK_THRESHOLD_POINTS``.
    """
    if use_dask is None:
        # No explicit preference: fall back to the size heuristic.
        return npoints > DASK_THRESHOLD_POINTS
    return use_dask
[docs]
def normalize_lon(
    lon: NDArray[np.floating],
    convention: Literal["pm180", "0360"] = "pm180",
) -> NDArray[np.floating]:
    """Wrap longitudes into the requested convention.

    Parameters
    ----------
    lon : array_like
        Longitude values in degrees. They are converted to float64.
    convention : {"pm180", "0360"}
        Target convention:

        - "pm180": values are wrapped into the half-open interval
          [-180, 180), so an input of exactly 180 becomes -180.
        - "0360": values are wrapped into the half-open interval
          [0, 360), so an input of exactly 360 becomes 0.

    Returns
    -------
    ndarray
        Wrapped longitude values as float64.

    Raises
    ------
    ValueError
        If ``convention`` is not one of the supported names.
    """
    degrees = np.asarray(lon, dtype=np.float64)
    if convention == "0360":
        return np.mod(degrees, 360)
    if convention == "pm180":
        # Shift so the wrap point sits at zero, wrap, then shift back.
        return np.mod(degrees + 180, 360) - 180
    raise ValueError(f"Unknown convention: {convention}")
[docs]
def ensure_lon_pm180(ds: xr.Dataset) -> xr.Dataset:
    """Return a dataset whose 'lon' values lie within [-180, 180].

    If the dataset has no 'lon' variable, or every longitude already lies
    within [-180, 180], the input object is returned as is. Otherwise a copy
    is returned in which *all* longitudes have been wrapped with
    ``normalize_lon(..., "pm180")``; the input dataset is left untouched.

    Parameters
    ----------
    ds : xr.Dataset
        Mesh dataset, optionally containing a 'lon' variable in degrees.

    Returns
    -------
    xr.Dataset
        Dataset with normalized longitude. When a rewrite happens, the
        existing attributes of 'lon' are kept, and 'units', 'long_name' and
        'standard_name' are set to the standard longitude values.
    """
    if "lon" not in ds:
        return ds

    lon_var = ds["lon"]
    # .values materializes the longitudes as a NumPy array for the range test.
    lon_values = lon_var.values
    out_of_range = np.any(lon_values > 180) or np.any(lon_values < -180)
    if not out_of_range:
        return ds

    # Keep whatever metadata the variable already carries; only the three
    # standard longitude attributes are forced to their canonical values.
    lon_attrs = {
        **lon_var.attrs,
        "units": "degrees_east",
        "long_name": "Longitude",
        "standard_name": "longitude",
    }
    result = ds.copy()
    result["lon"] = (
        lon_var.dims,
        normalize_lon(lon_values, "pm180"),
        lon_attrs,
    )
    return result
[docs]
def validate_mesh(ds: xr.Dataset, strict: bool = False) -> list[str]:
    """Validate mesh dataset against nereus conventions.

    The following checks are made, in this order:

    1. 'lon', 'lat' and 'area' are all present.
    2. 'lon' and 'lat' share the same dimensions.
    3. Longitudes lie within [-180, 180].
    4. Latitudes lie within [-90, 90].
    5. The minimum of 'area' is strictly positive.
    6. If 'triangles' is present, its indices are non-negative and, when
       'lon' is present, smaller than the number of mesh points.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset to validate.
    strict : bool
        If True, raise ValueError on validation errors.

    Returns
    -------
    list of str
        One message per failed check; empty when the mesh is valid.

    Raises
    ------
    ValueError
        If strict=True and validation fails.
    """
    errors: list[str] = []
    has_lon = "lon" in ds
    has_lat = "lat" in ds

    # Required variables.
    errors.extend(
        f"Missing required variable: {name}"
        for name in ("lon", "lat", "area")
        if name not in ds
    )

    # lon/lat must be defined on the same dimensions.
    if has_lon and has_lat and ds["lon"].dims != ds["lat"].dims:
        errors.append(
            f"lon and lat have different dimensions: "
            f"{ds['lon'].dims} vs {ds['lat'].dims}"
        )

    # Longitude range.
    if has_lon:
        lon_min = float(ds["lon"].min())
        lon_max = float(ds["lon"].max())
        if lon_min < -180 or lon_max > 180:
            errors.append(
                f"Longitude out of [-180, 180] range: [{lon_min}, {lon_max}]"
            )

    # Latitude range.
    if has_lat:
        lat_min = float(ds["lat"].min())
        lat_max = float(ds["lat"].max())
        if lat_min < -90 or lat_max > 90:
            errors.append(
                f"Latitude out of [-90, 90] range: [{lat_min}, {lat_max}]"
            )

    # Cell areas must be strictly positive.
    if "area" in ds and float(ds["area"].min()) <= 0:
        errors.append("Area contains non-positive values")

    # Triangle connectivity, if present, must index existing points.
    if "triangles" in ds:
        tri_min = int(ds["triangles"].min())
        if tri_min < 0:
            errors.append(f"Triangles contain negative indices: min={tri_min}")
        if has_lon:
            # Fall back to the total element count of lon. len() would only
            # give the length of the first axis, which under-counts the
            # points of a multi-dimensional lon.
            npoints = ds.sizes.get("npoints", ds["lon"].size)
            tri_max = int(ds["triangles"].max())
            if tri_max >= npoints:
                errors.append(
                    f"Triangle index {tri_max} exceeds npoints={npoints}"
                )

    if strict and errors:
        raise ValueError("Mesh validation failed:\n" + "\n".join(errors))
    return errors
[docs]
def is_nereus_mesh(ds: xr.Dataset) -> bool:
    """Report whether a dataset carries the nereus mesh marker.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset to check.

    Returns
    -------
    bool
        True if the dataset attributes contain the key
        ``"nereus_mesh_type"`` (whatever its value), False otherwise.
    """
    dataset_attrs = ds.attrs
    return "nereus_mesh_type" in dataset_attrs
[docs]
def get_mesh_type(ds: xr.Dataset) -> str | None:
    """Look up the mesh type recorded in a dataset's attributes.

    Parameters
    ----------
    ds : xr.Dataset
        Mesh dataset.

    Returns
    -------
    str or None
        Value stored under the ``"nereus_mesh_type"`` attribute, or None
        when the attribute is absent (i.e. not a nereus mesh).
    """
    dataset_attrs = ds.attrs
    return dataset_attrs.get("nereus_mesh_type", None)
[docs]
def create_lonlat_mesh(
    resolution: float | tuple[float, float],
    *,
    lon_bounds: tuple[float, float] = (-180, 180),
    lat_bounds: tuple[float, float] = (-90, 90),
    use_dask: bool | None = None,
) -> xr.Dataset:
    """Create regular lon-lat mesh.

    Cell centers start half a grid step above the lower bound and advance by
    one step while they stay below the upper bound. The 2D grid is flattened
    with longitude varying fastest.

    Parameters
    ----------
    resolution : float or tuple
        Grid resolution in degrees. If tuple, (dlon, dlat). Python numbers
        and NumPy scalars are both accepted as a single resolution. Must be
        positive.
    lon_bounds : tuple
        Longitude bounds (min, max).
    lat_bounds : tuple
        Latitude bounds (min, max).
    use_dask : bool, optional
        Whether to use dask arrays. Auto-detects if None.

    Returns
    -------
    xr.Dataset
        Mesh dataset with flattened lon, lat, area on the 'npoints'
        dimension, plus the grid resolution and the number of lon/lat
        centers as attributes.

    Raises
    ------
    ValueError
        If either resolution component is not strictly positive (this
        includes NaN).
    """
    # np.ndim() == 0 recognizes Python numbers as well as NumPy scalars
    # (np.float32, np.int64, ...), which isinstance(..., (int, float)) misses.
    if np.ndim(resolution) == 0:
        dlon = dlat = float(resolution)
    else:
        dlon, dlat = resolution
    # A zero, negative or NaN step would silently yield an empty mesh or an
    # obscure np.arange failure, so reject it up front.
    if not (dlon > 0 and dlat > 0):
        raise ValueError(
            f"resolution must be positive, got dlon={dlon}, dlat={dlat}"
        )

    # Cell-center coordinates along each axis.
    lon_1d = np.arange(lon_bounds[0] + dlon / 2, lon_bounds[1], dlon)
    lat_1d = np.arange(lat_bounds[0] + dlat / 2, lat_bounds[1], dlat)

    # Expand to the full grid and flatten to one entry per cell.
    lon_2d, lat_2d = np.meshgrid(lon_1d, lat_1d)
    lon_flat = lon_2d.ravel()
    lat_flat = lat_2d.ravel()

    area_flat = _compute_lonlat_cell_area(lat_flat, dlon, dlat)

    npoints = len(lon_flat)
    use_dask_actual = should_use_dask(npoints, use_dask)
    if use_dask_actual:
        import dask.array as da

        # chunks=-1 keeps each array in a single chunk.
        lon_data = da.from_array(lon_flat, chunks=-1)
        lat_data = da.from_array(lat_flat, chunks=-1)
        area_data = da.from_array(area_flat, chunks=-1)
    else:
        lon_data = lon_flat
        lat_data = lat_flat
        area_data = area_flat

    ds = xr.Dataset(
        {
            "lon": (("npoints",), lon_data, {
                "units": "degrees_east",
                "long_name": "Longitude",
                "standard_name": "longitude",
            }),
            "lat": (("npoints",), lat_data, {
                "units": "degrees_north",
                "long_name": "Latitude",
                "standard_name": "latitude",
            }),
            "area": (("npoints",), area_data, {
                "units": "m2",
                "long_name": "Cell area",
            }),
        },
        attrs={
            "resolution_lon": dlon,
            "resolution_lat": dlat,
            "nlon": len(lon_1d),
            "nlat": len(lat_1d),
        },
    )
    return add_mesh_metadata(ds, "lonlat", use_dask=use_dask_actual)
[docs]
def mesh_from_arrays(
    lon: NDArray[np.floating],
    lat: NDArray[np.floating],
    *,
    area: NDArray[np.floating] | None = None,
    use_dask: bool | None = None,
) -> xr.Dataset:
    """Create mesh from existing coordinate arrays.

    Handles 1D, multi-dimensional, and regular grid side coordinates. If lon
    and lat are both 1D but have different sizes, they are treated as regular
    grid side coordinates and expanded via meshgrid before flattening (a
    warning is emitted). Multi-dimensional arrays are flattened directly.
    1D arrays of equal size are taken as a list of points.

    Parameters
    ----------
    lon : array_like
        Longitude coordinates in degrees. Wrapped with
        ``normalize_lon(..., "pm180")`` after flattening.
    lat : array_like
        Latitude coordinates in degrees.
    area : array_like, optional
        Cell areas in m^2, one per mesh point (any shape; flattened). If
        None, areas are computed from the median grid spacing for regular
        grid side coordinates with at least two values per axis, and roughly
        estimated otherwise.
    use_dask : bool, optional
        Whether to use dask arrays. Auto-detects if None.

    Returns
    -------
    xr.Dataset
        Mesh dataset with lon, lat, area on the 'npoints' dimension.

    Raises
    ------
    ValueError
        If lon and lat do not contain the same number of points after
        flattening, or if ``area`` does not contain one value per point.
    """
    lon = np.asarray(lon, dtype=np.float64)
    lat = np.asarray(lat, dtype=np.float64)

    # (dlon, dlat) when the inputs are regular grid side coordinates whose
    # spacing can actually be measured; None otherwise.
    grid_spacing = None
    if lon.ndim == 1 and lat.ndim == 1 and lon.size != lat.size:
        warnings.warn(
            f"Creating meshgrid from 1D lon ({lon.size}) and lat ({lat.size}), "
            "then raveling to 1D.",
            stacklevel=2,
        )
        # np.diff of a single value is empty and its median is NaN, which
        # would silently produce NaN areas; only trust the spacing when both
        # axes have at least two values.
        if lon.size > 1 and lat.size > 1:
            grid_spacing = (
                float(np.median(np.diff(lon))),
                float(np.median(np.diff(lat))),
            )
        lon, lat = np.meshgrid(lon, lat)

    # Flatten each array independently of its dimensionality (a no-op for
    # arrays that are already 1D).
    lon = lon.ravel()
    lat = lat.ravel()
    if lon.size != lat.size:
        raise ValueError(
            "lon and lat must contain the same number of points after "
            f"flattening, got {lon.size} and {lat.size}"
        )

    lon = normalize_lon(lon, "pm180")
    npoints = len(lon)
    use_dask_actual = should_use_dask(npoints, use_dask)

    if area is None:
        if grid_spacing is not None:
            area = _compute_lonlat_cell_area(lat, *grid_spacing)
        elif npoints == 0:
            # Nothing to estimate for an empty mesh.
            area = np.zeros(0, dtype=np.float64)
        else:
            area = _estimate_area(lon, lat)
    else:
        area = np.asarray(area, dtype=np.float64).ravel()
        if area.size != npoints:
            raise ValueError(
                f"area must contain one value per mesh point ({npoints}), "
                f"got {area.size}"
            )

    if use_dask_actual:
        import dask.array as da

        # chunks=-1 keeps each array in a single chunk.
        lon_data = da.from_array(lon, chunks=-1)
        lat_data = da.from_array(lat, chunks=-1)
        area_data = da.from_array(area, chunks=-1)
    else:
        lon_data = lon
        lat_data = lat
        area_data = area

    ds = xr.Dataset(
        {
            "lon": (("npoints",), lon_data, {
                "units": "degrees_east",
                "long_name": "Longitude",
                "standard_name": "longitude",
            }),
            "lat": (("npoints",), lat_data, {
                "units": "degrees_north",
                "long_name": "Latitude",
                "standard_name": "latitude",
            }),
            "area": (("npoints",), area_data, {
                "units": "m2",
                "long_name": "Cell area",
            }),
        },
    )
    return add_mesh_metadata(ds, "custom", use_dask=use_dask_actual)
def _compute_lonlat_cell_area(
    lat: NDArray[np.floating],
    dlon: float,
    dlat: float,
) -> NDArray[np.floating]:
    """Compute the spherical area of regular lon-lat grid cells.

    Each cell spans ``dlon`` degrees of longitude and ``dlat`` degrees of
    latitude centered on ``lat``. Cell edges reaching past a pole are clipped
    to +/-90 degrees before the area is evaluated.

    Parameters
    ----------
    lat : array_like
        Cell center latitudes in degrees.
    dlon : float
        Longitudinal grid spacing in degrees.
    dlat : float
        Latitudinal grid spacing in degrees.

    Returns
    -------
    ndarray
        Cell areas in m^2 (absolute values, so the sign of the spacings does
        not matter).
    """
    half_pi = np.pi / 2
    center = np.deg2rad(lat)
    width = np.deg2rad(dlon)
    half_height = np.deg2rad(dlat) / 2

    # Southern and northern cell edges, limited to the poles.
    south_edge = np.clip(center - half_height, -half_pi, half_pi)
    north_edge = np.clip(center + half_height, -half_pi, half_pi)

    # Area = R^2 * dlon * (sin(north edge) - sin(south edge))
    sine_band = np.sin(north_edge) - np.sin(south_edge)
    return np.abs(EARTH_RADIUS**2 * width * sine_band)
def _estimate_area(
    lon: NDArray[np.floating],
    lat: NDArray[np.floating],
) -> NDArray[np.floating]:
    """Estimate cell area from irregular grid (rough approximation).

    The Earth's total surface area is divided evenly among the points and
    then re-weighted by cos(latitude), normalized so the weights average to
    one. The estimates therefore always sum to the full surface of the
    Earth, i.e. the points are implicitly assumed to cover the whole globe.

    Parameters
    ----------
    lon : array_like
        Longitude in degrees. Only its length is used.
    lat : array_like
        Latitude in degrees.

    Returns
    -------
    ndarray
        Estimated cell areas in m^2; an empty array for empty input.
    """
    npoints = len(lon)
    if npoints == 0:
        # Avoid dividing the Earth's surface by zero points.
        return np.zeros(0, dtype=np.float64)

    earth_area = 4 * np.pi * EARTH_RADIUS**2
    avg_area = earth_area / npoints

    # Simple latitude weighting, normalized to a mean of one.
    weights = np.cos(np.deg2rad(lat))
    weights = weights / weights.mean()
    return avg_area * weights