"""Regridding interpolator for unstructured to regular grid conversion.
This module provides the RegridInterpolator class for efficiently regridding
unstructured data (like FESOM, ICON) to regular lat/lon grids.
"""
from __future__ import annotations
import warnings
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal
import numpy as np
from numpy.typing import NDArray
from scipy.spatial import cKDTree
from nereus.core.coordinates import lonlat_to_cartesian, meters_to_chord
from nereus.core.grids import (
create_regular_grid,
extract_coordinates,
flatten_spatial,
prepare_coordinates,
)
if TYPE_CHECKING:
import xarray as xr
def _normalize_lon(lon: NDArray[np.floating], center: float) -> NDArray[np.floating]:
"""Normalize longitudes to a 360-degree range centered on *center*.
Maps every value into [center - 180, center + 180).
"""
return (lon - (center - 180)) % 360 + (center - 180)
@dataclass
class RegridInterpolator:
    """Pre-computed interpolation for fast repeated regridding.

    This class computes and stores interpolation weights for regridding
    unstructured data to a regular grid. The computation is done once
    during initialization, allowing fast repeated application.

    Parameters
    ----------
    source_lon : array_like
        Source grid longitude coordinates in degrees.
    source_lat : array_like
        Source grid latitude coordinates in degrees.
    resolution : float or tuple of int
        Target grid resolution. If float, specifies degrees per cell.
        If tuple (nlon, nlat), specifies number of grid points.
    method : {"nearest", "idw", "linear", "cubic"}
        Interpolation method. "nearest" uses nearest-neighbor lookup via
        KDTree (fast). "idw" uses inverse distance weighting with up to 8
        nearest neighbors (fast, smooth). "linear" uses Delaunay
        triangulation with barycentric interpolation (slower but
        smoother). "cubic" uses Clough-Tocher C1 interpolation for the
        smoothest results. Source longitudes are automatically
        normalized to match the target grid's ``lon_bounds`` so that any
        input convention (0-360 or -180-180) works transparently.
    influence_radius : float
        Maximum influence radius in meters. Target points farther than this
        from every source point are set to ``fill_value``. Default is 80 km.
    lon_bounds : tuple of float
        Target grid longitude bounds. Default is (-180, 180).
    lat_bounds : tuple of float
        Target grid latitude bounds. Default is (-90, 90).

    Attributes
    ----------
    target_lon : ndarray
        Target grid longitude coordinates (2D).
    target_lat : ndarray
        Target grid latitude coordinates (2D).
    indices : ndarray
        Index of the nearest source point for each target point.
    distances : ndarray
        Distances from target to nearest source points (in chord units).
    valid_mask : ndarray
        Boolean mask of valid target points within influence radius.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported methods (raised during
        construction).

    Examples
    --------
    >>> interpolator = RegridInterpolator(mesh_lon, mesh_lat, resolution=1.0)
    >>> regridded = interpolator(data)
    >>> regridded.shape
    (180, 360)

    Use linear interpolation for smoother results:

    >>> interpolator = RegridInterpolator(
    ...     mesh_lon, mesh_lat, resolution=1.0, method="linear"
    ... )
    """

    source_lon: NDArray[np.floating]
    source_lat: NDArray[np.floating]
    resolution: float | tuple[int, int] = 1.0
    method: Literal["nearest", "idw", "linear", "cubic"] = "nearest"
    influence_radius: float = 80_000.0
    lon_bounds: tuple[float, float] = (-180.0, 180.0)
    lat_bounds: tuple[float, float] = (-90.0, 90.0)

    # Computed attributes (initialized in __post_init__)
    target_lon: NDArray[np.floating] = field(init=False, repr=False)
    target_lat: NDArray[np.floating] = field(init=False, repr=False)
    indices: NDArray[np.intp] = field(init=False, repr=False)
    distances: NDArray[np.floating] = field(init=False, repr=False)
    valid_mask: NDArray[np.bool_] = field(init=False, repr=False)
    _tree: cKDTree = field(init=False, repr=False)
    _delaunay: Any = field(init=False, repr=False, default=None)
    _source_2d: NDArray[np.floating] | None = field(
        init=False, repr=False, default=None
    )
    _idw_weights: NDArray[np.floating] | None = field(
        init=False, repr=False, default=None
    )
    _idw_indices: NDArray[np.intp] | None = field(
        init=False, repr=False, default=None
    )

    # Plain (unannotated) class attributes: not dataclass fields.
    _METHODS = ("nearest", "idw", "linear", "cubic")
    _IDW_NEIGHBORS = 8

    def __post_init__(self) -> None:
        """Validate the method and pre-compute all interpolation state.

        Raises
        ------
        ValueError
            If ``self.method`` is not a supported interpolation method.
        """
        # Fail fast on a typo instead of after the (expensive) KDTree build.
        if self.method not in self._METHODS:
            raise ValueError(f"Unknown method: {self.method!r}")

        # Prepare source coordinates: handle 1D/2D and validate
        self.source_lon, self.source_lat = prepare_coordinates(
            self.source_lon, self.source_lat
        )

        # Create target grid
        self.target_lon, self.target_lat = create_regular_grid(
            self.resolution,
            lon_bounds=self.lon_bounds,
            lat_bounds=self.lat_bounds,
        )
        target_shape = self.target_lon.shape

        # Nearest-neighbour search is done on the unit sphere (Cartesian),
        # which avoids any longitude-seam problems.
        source_xyz = np.column_stack(
            lonlat_to_cartesian(self.source_lon, self.source_lat)
        )
        self._tree = cKDTree(source_xyz)
        target_xyz = np.column_stack(
            lonlat_to_cartesian(self.target_lon.ravel(), self.target_lat.ravel())
        )

        distances, indices = self._tree.query(target_xyz, k=1)
        self.distances = distances.reshape(target_shape)
        self.indices = indices.reshape(target_shape)

        # Distances are chord lengths, so compare against the chord length
        # equivalent of the influence radius.
        self.valid_mask = self.distances <= meters_to_chord(self.influence_radius)

        if self.method == "idw":
            self._idw_weights, self._idw_indices = self._compute_idw(
                target_xyz, target_shape
            )

        # Build Delaunay triangulation for linear/cubic interpolation
        if self.method in ("linear", "cubic"):
            from scipy.spatial import Delaunay

            # Normalize source longitudes to match the target grid range
            # so that e.g. 0-360 source works with -180-180 target.
            lon_center = (self.lon_bounds[0] + self.lon_bounds[1]) / 2
            source_lon_norm = _normalize_lon(self.source_lon, lon_center)
            self._source_2d = np.column_stack([source_lon_norm, self.source_lat])
            self._delaunay = Delaunay(self._source_2d)

    def _compute_idw(
        self,
        target_xyz: NDArray[np.floating],
        target_shape: tuple[int, ...],
    ) -> tuple[NDArray[np.floating], NDArray[np.intp]]:
        """Return normalized 1/d**2 weights and neighbour indices per target.

        Both arrays have shape ``target_shape + (k,)``, where ``k`` is
        ``_IDW_NEIGHBORS`` clamped to the number of source points.
        """
        # cKDTree pads missing neighbours with index n (out of range) and an
        # infinite distance when k exceeds the number of points, so clamp k.
        k = min(self._IDW_NEIGHBORS, self._tree.n)
        dists, idxs = self._tree.query(target_xyz, k=k)
        # reshape also restores the trailing axis that query drops for k=1.
        dists = dists.reshape(target_shape + (k,))
        idxs = idxs.reshape(target_shape + (k,))

        # A target that coincides with a source point takes that point's
        # value: weight 1 on the first exact neighbour, 0 on the others.
        exact = dists == 0.0
        has_exact = exact.any(axis=-1, keepdims=True)
        first_exact = exact & (np.cumsum(exact, axis=-1) == 1)
        with np.errstate(divide="ignore"):
            inv_d2 = 1.0 / dists**2
        weights = np.where(has_exact, first_exact.astype(np.float64), inv_d2)

        # Normalize so weights sum to 1 (guarding against an all-zero row).
        weight_sum = weights.sum(axis=-1, keepdims=True)
        weight_sum = np.where(weight_sum == 0.0, 1.0, weight_sum)
        return weights / weight_sum, idxs

    def __call__(
        self,
        data: NDArray | "xr.DataArray",
        fill_value: float = np.nan,
    ) -> NDArray[np.floating]:
        """Apply interpolation to data.

        Parameters
        ----------
        data : array_like
            Data to interpolate. Can be:
            - 1D array of shape (npoints,)
            - 2D array of shape (nlevels, npoints) or (ntime, npoints)
            - ND array with last axis = npoints
        fill_value : float
            Value for invalid points outside influence radius.

        Returns
        -------
        ndarray
            Regridded float64 data. Shape depends on input:
            - 1D input: (nlat, nlon)
            - 2D input: (extra_dim, nlat, nlon)
            - ND input: (*leading_dims, nlat, nlon)

        Raises
        ------
        ValueError
            If ``data`` is 0-dimensional.
        """
        # Unwrap xarray (or anything exposing .values) to a plain array.
        if hasattr(data, "values"):
            data = data.values
        data = np.asarray(data)

        if data.ndim == 0:
            raise ValueError(
                "data must have at least one dimension (the spatial points axis)"
            )
        if data.ndim == 1:
            return self._interpolate_1d(data, fill_value)

        # Collapse all leading axes into one, interpolate each field, and
        # restore the leading shape at the end.
        target_shape = self.target_lon.shape
        leading_shape = data.shape[:-1]
        n_fields = int(np.prod(leading_shape))
        data_flat = data.reshape(n_fields, data.shape[-1])
        result = np.empty((n_fields,) + target_shape, dtype=np.float64)
        for i, row in enumerate(data_flat):
            result[i] = self._interpolate_1d(row, fill_value)
        return result.reshape(leading_shape + target_shape)

    def _interpolate_1d(
        self,
        data: NDArray[np.floating],
        fill_value: float,
    ) -> NDArray[np.floating]:
        """Interpolate one field (1D, length npoints) onto the target grid."""
        if self.method == "nearest":
            # Always cast, so a NaN fill can be stored even for integer input.
            result = data[self.indices].astype(np.float64)
        elif self.method == "idw":
            result = np.sum(self._idw_weights * data[self._idw_indices], axis=-1)
        elif self.method in ("linear", "cubic"):
            result = self._interpolate_delaunay(data, fill_value)
        else:
            raise ValueError(f"Unknown method: {self.method!r}")
        # Mask targets farther than influence_radius from every source point.
        result[~self.valid_mask] = fill_value
        return result

    def _interpolate_delaunay(
        self,
        data: NDArray[np.floating],
        fill_value: float,
    ) -> NDArray[np.floating]:
        """Linear or Clough-Tocher interpolation on the source triangulation."""
        from scipy.interpolate import CloughTocher2DInterpolator, LinearNDInterpolator

        target_2d = np.column_stack(
            [self.target_lon.ravel(), self.target_lat.ravel()]
        )
        if self.method == "linear":
            interp = LinearNDInterpolator(self._delaunay, data, fill_value=fill_value)
        else:
            valid_src = np.isfinite(data)
            if valid_src.all():
                # Reuse the pre-built triangulation.
                interp = CloughTocher2DInterpolator(
                    self._delaunay, data, fill_value=fill_value
                )
            else:
                # Non-finite values would poison the gradient estimates, so
                # re-triangulate on the finite subset only.
                interp = CloughTocher2DInterpolator(
                    self._source_2d[valid_src],
                    data[valid_src],
                    fill_value=fill_value,
                )
        return interp(target_2d).reshape(self.target_lon.shape)

    @property
    def shape(self) -> tuple[int, int]:
        """Shape of target grid (nlat, nlon)."""
        return self.target_lon.shape
def regrid(
    data: NDArray | "xr.DataArray",
    lon: NDArray[np.floating] | None = None,
    lat: NDArray[np.floating] | None = None,
    resolution: float | tuple[int, int] = 1.0,
    method: Literal["nearest", "idw", "linear", "cubic"] = "nearest",
    influence_radius: float = 80_000.0,
    fill_value: float = np.nan,
    lon_bounds: tuple[float, float] = (-180.0, 180.0),
    lat_bounds: tuple[float, float] = (-90.0, 90.0),
    as_xarray: bool = False,
) -> tuple[NDArray[np.floating] | "xr.DataArray", RegridInterpolator]:
    """Regrid unstructured data to regular grid.

    This is a convenience function that creates a RegridInterpolator and
    applies it. For repeated regridding with the same source grid, create
    a RegridInterpolator once and reuse it.

    Supports multi-dimensional data where the last axis contains the spatial
    points. For example:
    - 1D data (npoints,): single field
    - 2D data (nlevels, npoints): multi-level unstructured data (e.g., FESOM, ICON)
    - ND data (*dims, npoints): arbitrary leading dimensions

    Coordinate arrays can be:
    - 1D arrays of same size: unstructured mesh coordinates (used directly)
    - 1D arrays of different sizes: regular grid side coordinates (meshgrid created)
    - 2D arrays of same shape: full coordinate arrays (raveled to 1D)
    A warning is issued whenever coordinate transformations are applied.

    If lon/lat are not provided and data is an xarray DataArray, the function
    will attempt to extract coordinates automatically by looking for common
    coordinate names (lon/lat, longitude/latitude, x/y, etc.).

    Parameters
    ----------
    data : array_like
        Data to interpolate. Last axis must be npoints (matching coordinates).
        Can be 1D (npoints,), 2D (nlevels, npoints), or ND (*dims, npoints).
        If xarray DataArray, coordinates may be extracted automatically.
    lon : array_like, optional
        Source grid longitude coordinates. Can be 1D or 2D array.
        If None, will attempt to extract from data (xarray only).
    lat : array_like, optional
        Source grid latitude coordinates. Can be 1D or 2D array.
        If None, will attempt to extract from data (xarray only).
    resolution : float or tuple of int
        Target grid resolution.
    method : {"nearest", "idw", "linear", "cubic"}
        Interpolation method. "nearest" uses nearest-neighbor lookup.
        "idw" uses inverse distance weighting (fast, smooth). "linear"
        uses Delaunay triangulation with barycentric interpolation.
        "cubic" uses Clough-Tocher C1 interpolation (smoothest).
    influence_radius : float
        Maximum influence radius in meters.
    fill_value : float
        Value for invalid points.
    lon_bounds : tuple of float
        Target grid longitude bounds.
    lat_bounds : tuple of float
        Target grid latitude bounds.
    as_xarray : bool, default False
        If True, wrap the regridded array in an ``xr.DataArray`` with
        ``lat`` and ``lon`` as 1-D dimension coordinates. Leading
        dimensions (e.g. time, depth) and the coordinates defined only on
        them are preserved when the input is an ``xr.DataArray``. The
        return type of the tuple's first element changes from ``NDArray``
        to ``xr.DataArray``.

    Returns
    -------
    regridded : ndarray or xr.DataArray
        Regridded data. Returns ``xr.DataArray`` when ``as_xarray=True``,
        otherwise ``ndarray``.
    interpolator : RegridInterpolator
        The interpolator used (can be reused for other variables).

    Raises
    ------
    ValueError
        If coordinates are missing or their sizes do not match the data.
    """
    # Extract coordinates from xarray if not provided
    if lon is None or lat is None:
        extracted_lon, extracted_lat = extract_coordinates(data)
        if lon is None:
            lon = extracted_lon
        if lat is None:
            lat = extracted_lat

    # Validate that we have coordinates
    if lon is None or lat is None:
        raise ValueError(
            "lon and lat coordinates are required. Either provide them explicitly "
            "or use an xarray DataArray with recognizable coordinate names "
            "(lon/lat, longitude/latitude, x/y, etc.)."
        )

    # Unwrap xarray DataArray
    data_values = data.values if hasattr(data, "values") else np.asarray(data)
    lon_arr = np.asarray(lon)
    lat_arr = np.asarray(lat)

    # Unstructured data: lon and lat have the SAME size matching data's last dim.
    # Regular grids: lon and lat have DIFFERENT sizes matching data's last two dims.
    if lon_arr.ndim == 1 and lat_arr.ndim == 1:
        if lon_arr.size == lat_arr.size:
            # NOTE: a regular grid with len(lon) == len(lat) cannot be told
            # apart from an unstructured mesh here; it is treated as unstructured.
            npoints = data_values.shape[-1] if data_values.ndim >= 1 else data_values.size
            if lon_arr.size != npoints:
                raise ValueError(
                    f"Coordinate size ({lon_arr.size}) must match data's last dimension ({npoints}). "
                    f"Data shape: {data_values.shape}"
                )
            # Coordinates are ready, data stays as-is
        else:
            # Regular grid with side coordinates: data is (..., nlat, nlon)
            if data_values.ndim < 2:
                raise ValueError(
                    f"For regular grid coordinates (lon size {lon_arr.size}, lat size {lat_arr.size}), "
                    f"data must be at least 2D, got shape {data_values.shape}"
                )
            nlat, nlon = data_values.shape[-2], data_values.shape[-1]
            if lon_arr.size != nlon or lat_arr.size != nlat:
                raise ValueError(
                    f"Coordinate sizes (lon: {lon_arr.size}, lat: {lat_arr.size}) must match "
                    f"data's last two dimensions (nlat: {nlat}, nlon: {nlon}). "
                    f"Data shape: {data_values.shape}"
                )
            warnings.warn(
                f"Creating meshgrid from 1D lon ({lon_arr.size}) and lat ({lat_arr.size}) "
                f"for data (shape {data_values.shape}), then raveling spatial dimensions.",
                stacklevel=2,
            )
            lon_arr, lat_arr = np.meshgrid(lon_arr, lat_arr)
            lon_arr = lon_arr.ravel()
            lat_arr = lat_arr.ravel()
            data_values = flatten_spatial(data_values)
    else:
        # 2D coordinates - prepare them and handle data accordingly
        lon_arr, lat_arr = prepare_coordinates(lon_arr, lat_arr)
        # Ravel data if it matches the original 2D coordinate shape
        if data_values.ndim >= 2 and data_values.shape[-2:] == np.asarray(lon).shape:
            data_values = flatten_spatial(data_values)

    interpolator = RegridInterpolator(
        source_lon=lon_arr,
        source_lat=lat_arr,
        resolution=resolution,
        method=method,
        influence_radius=influence_radius,
        lon_bounds=lon_bounds,
        lat_bounds=lat_bounds,
    )
    regridded = interpolator(data_values, fill_value=fill_value)

    if as_xarray:
        regridded = _wrap_as_dataarray(regridded, data, interpolator)
    return regridded, interpolator


def _wrap_as_dataarray(
    regridded: NDArray[np.floating],
    data: Any,
    interpolator: RegridInterpolator,
) -> "xr.DataArray":
    """Wrap *regridded* in an ``xr.DataArray`` with 1-D ``lat``/``lon`` coords.

    When *data* is an ``xr.DataArray`` its leading dimension names, the
    coordinates defined only on those dimensions, its name and its attrs are
    carried over; otherwise generic ``dim_<i>`` dimension names are used.
    """
    import xarray as xr

    # regridded is (*leading, nlat, nlon): the spatial axes are always last.
    n_leading = regridded.ndim - 2
    lat_1d = interpolator.target_lat[:, 0]
    lon_1d = interpolator.target_lon[0, :]

    if hasattr(data, "dims"):
        leading_dim_names = list(data.dims[:n_leading])
        leading = set(leading_dim_names)
        # Keep index coords and auxiliary coords that live only on leading
        # dims; spatial coords are replaced by the new lat/lon below.
        leading_coords = {
            name: coord
            for name, coord in data.coords.items()
            if coord.dims and set(coord.dims) <= leading and name not in ("lat", "lon")
        }
        var_name = data.name or "data"
        var_attrs = dict(data.attrs)
    else:
        leading_dim_names = [f"dim_{i}" for i in range(n_leading)]
        leading_coords = {}
        var_name = "data"
        var_attrs = {}

    coords = {
        **leading_coords,
        "lat": ("lat", lat_1d, {"units": "degrees_north", "standard_name": "latitude"}),
        "lon": ("lon", lon_1d, {"units": "degrees_east", "standard_name": "longitude"}),
    }
    return xr.DataArray(
        regridded,
        dims=(*leading_dim_names, "lat", "lon"),
        coords=coords,
        name=var_name,
        attrs=var_attrs,
    )