# Source code for nereus.regrid.interpolator

"""Regridding interpolator for unstructured to regular grid conversion.

This module provides the RegridInterpolator class for efficiently regridding
unstructured data (like FESOM, ICON) to regular lat/lon grids.
"""

from __future__ import annotations

import warnings
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal

import numpy as np
from numpy.typing import NDArray
from scipy.spatial import cKDTree

from nereus.core.coordinates import lonlat_to_cartesian, meters_to_chord
from nereus.core.grids import (
    create_regular_grid,
    extract_coordinates,
    flatten_spatial,
    prepare_coordinates,
)

if TYPE_CHECKING:
    import xarray as xr


def _normalize_lon(lon: NDArray[np.floating], center: float) -> NDArray[np.floating]:
    """Normalize longitudes to a 360-degree range centered on *center*.

    Maps every value into [center - 180, center + 180).
    """
    return (lon - (center - 180)) % 360 + (center - 180)


@dataclass
class RegridInterpolator:
    """Pre-computed interpolation for fast repeated regridding.

    This class computes and stores interpolation weights for regridding
    unstructured data to a regular grid. The computation is done once
    during initialization, allowing fast repeated application.

    Parameters
    ----------
    source_lon : array_like
        Source grid longitude coordinates in degrees.
    source_lat : array_like
        Source grid latitude coordinates in degrees.
    resolution : float or tuple of int
        Target grid resolution. If float, specifies degrees per cell.
        If tuple (nlon, nlat), specifies number of grid points.
    method : {"nearest", "idw", "linear", "cubic"}
        Interpolation method. "nearest" uses nearest-neighbor lookup via
        KDTree (fast). "idw" uses inverse distance weighting with up to 8
        nearest neighbors (fast, smooth); fewer neighbors are used when the
        source grid has fewer than 8 points. "linear" uses Delaunay
        triangulation with barycentric interpolation (slower but smoother).
        "cubic" uses Clough-Tocher C1 interpolation for the smoothest
        results. For "linear" and "cubic", source longitudes are
        automatically normalized to match the target grid's ``lon_bounds``
        so that any input convention (0-360 or -180-180) works
        transparently.
    influence_radius : float
        Maximum influence radius in meters. Points beyond this distance
        from any source point are masked. Default is 80 km.
    lon_bounds : tuple of float
        Target grid longitude bounds. Default is (-180, 180).
    lat_bounds : tuple of float
        Target grid latitude bounds. Default is (-90, 90).

    Attributes
    ----------
    target_lon : ndarray
        Target grid longitude coordinates (2D).
    target_lat : ndarray
        Target grid latitude coordinates (2D).
    indices : ndarray
        Source indices for each target point.
    distances : ndarray
        Distances from target to source points (in chord units).
    valid_mask : ndarray
        Boolean mask of valid target points within influence radius.

    Examples
    --------
    >>> interpolator = RegridInterpolator(mesh_lon, mesh_lat, resolution=1.0)
    >>> regridded = interpolator(data)
    >>> regridded.shape
    (180, 360)

    Use linear interpolation for smoother results:

    >>> interpolator = RegridInterpolator(
    ...     mesh_lon, mesh_lat, resolution=1.0, method="linear"
    ... )
    """

    source_lon: NDArray[np.floating]
    source_lat: NDArray[np.floating]
    resolution: float | tuple[int, int] = 1.0
    method: Literal["nearest", "idw", "linear", "cubic"] = "nearest"
    influence_radius: float = 80_000.0
    lon_bounds: tuple[float, float] = (-180.0, 180.0)
    lat_bounds: tuple[float, float] = (-90.0, 90.0)

    # Computed attributes (initialized in __post_init__)
    target_lon: NDArray[np.floating] = field(init=False, repr=False)
    target_lat: NDArray[np.floating] = field(init=False, repr=False)
    indices: NDArray[np.intp] = field(init=False, repr=False)
    distances: NDArray[np.floating] = field(init=False, repr=False)
    valid_mask: NDArray[np.bool_] = field(init=False, repr=False)
    _tree: cKDTree = field(init=False, repr=False)
    _delaunay: Any = field(init=False, repr=False, default=None)
    _source_2d: NDArray[np.floating] | None = field(
        init=False, repr=False, default=None
    )
    _idw_weights: NDArray[np.floating] | None = field(
        init=False, repr=False, default=None
    )
    _idw_indices: NDArray[np.intp] | None = field(
        init=False, repr=False, default=None
    )

    # Upper bound on neighbors used by the "idw" method. Deliberately left
    # unannotated so the dataclass machinery does not treat it as a field.
    _IDW_NEIGHBORS = 8

    def __post_init__(self) -> None:
        """Initialize interpolation weights."""
        # Prepare source coordinates: handle 1D/2D and validate
        self.source_lon, self.source_lat = prepare_coordinates(
            self.source_lon, self.source_lat
        )

        # Create target grid
        self.target_lon, self.target_lat = create_regular_grid(
            self.resolution,
            lon_bounds=self.lon_bounds,
            lat_bounds=self.lat_bounds,
        )
        target_shape = self.target_lon.shape

        # Build the KDTree on Cartesian (unit sphere) source coordinates
        source_xyz = np.column_stack(
            lonlat_to_cartesian(self.source_lon, self.source_lat)
        )
        self._tree = cKDTree(source_xyz)

        # Convert target coordinates to Cartesian
        target_xyz = np.column_stack(
            lonlat_to_cartesian(self.target_lon.ravel(), self.target_lat.ravel())
        )

        # Nearest neighbor of every target point, reshaped to the target grid
        nearest_dist, nearest_idx = self._tree.query(target_xyz, k=1)
        self.distances = nearest_dist.reshape(target_shape)
        self.indices = nearest_idx.reshape(target_shape)

        # Create valid mask based on influence radius
        max_chord = meters_to_chord(self.influence_radius)
        self.valid_mask = self.distances <= max_chord

        # Pre-compute IDW weights
        if self.method == "idw":
            # Never ask for more neighbors than there are source points:
            # cKDTree.query pads missing neighbors with index ``n``, which
            # would be out of range when indexing the data later on.
            k = min(self._IDW_NEIGHBORS, self._tree.n)
            dists, idxs = self._tree.query(target_xyz, k=k)
            # For k == 1 the query returns 1-D arrays; the reshape restores
            # the trailing neighbor axis in every case.
            dists = dists.reshape(target_shape + (k,))
            self._idw_weights = self._idw_weights_from_distances(dists)
            self._idw_indices = idxs.reshape(target_shape + (k,))

        # Build Delaunay triangulation for linear/cubic interpolation
        if self.method in ("linear", "cubic"):
            from scipy.spatial import Delaunay

            # Normalize source longitudes to match the target grid range
            # so that e.g. 0-360 source works with -180-180 target.
            lon_center = (self.lon_bounds[0] + self.lon_bounds[1]) / 2
            source_lon_norm = _normalize_lon(self.source_lon, lon_center)
            self._source_2d = np.column_stack([source_lon_norm, self.source_lat])
            self._delaunay = Delaunay(self._source_2d)

    @staticmethod
    def _idw_weights_from_distances(
        dists: NDArray[np.floating],
    ) -> NDArray[np.floating]:
        """Return normalized inverse-distance-squared weights.

        Parameters
        ----------
        dists : ndarray
            Neighbor distances with the neighbor axis last.

        Returns
        -------
        ndarray
            Weights of the same shape as *dists* that sum to 1 along the
            last axis. Where a target coincides with a source point
            (distance 0), the first such neighbor gets weight 1 and all
            other neighbors get weight 0.
        """
        exact = dists == 0.0
        has_exact = exact.any(axis=-1, keepdims=True)
        # Select only the first exact match per target point
        first_exact = exact & (np.cumsum(exact, axis=-1) == 1)

        # 1/d^2 is infinite at exact matches; those entries are replaced by
        # the one-hot weights below, so the division warning is suppressed.
        with np.errstate(divide="ignore"):
            inv_d2 = 1.0 / dists**2
        weights = np.where(has_exact, first_exact.astype(np.float64), inv_d2)

        # Normalize so weights sum to 1 (guarding against an all-zero row)
        weight_sum = weights.sum(axis=-1, keepdims=True)
        weight_sum = np.where(weight_sum == 0.0, 1.0, weight_sum)
        return weights / weight_sum

    def __call__(
        self,
        data: NDArray | "xr.DataArray",
        fill_value: float = np.nan,
    ) -> NDArray[np.floating]:
        """Apply interpolation to data.

        Parameters
        ----------
        data : array_like
            Data to interpolate. Can be:
            - 1D array of shape (npoints,)
            - 2D array of shape (nlevels, npoints) or (ntime, npoints)
            - ND array with last axis = npoints
        fill_value : float
            Value for invalid points outside influence radius.

        Returns
        -------
        ndarray
            Regridded data. Shape depends on input:
            - 1D input: (nlat, nlon)
            - 2D input: (extra_dim, nlat, nlon)
            - ND input: (*leading_dims, nlat, nlon)

        Raises
        ------
        ValueError
            If the last axis of *data* does not match the number of source
            points, or if the interpolation method is unknown.
        """
        # Handle xarray DataArray
        if hasattr(data, "values"):
            data = data.values
        data = np.asarray(data)

        # A mismatched last axis would otherwise surface as an IndexError
        # or, worse, as silently wrong values.
        n_source = self._tree.n
        if data.ndim == 0 or data.shape[-1] != n_source:
            raise ValueError(
                f"Data's last dimension must match the number of source points "
                f"({n_source}), got data shape {data.shape}"
            )

        if data.ndim == 1:
            return self._interpolate_1d(data, fill_value)

        # Collapse all leading dimensions, interpolate slice by slice, then
        # restore the leading dimensions in front of the target grid shape.
        target_shape = self.target_lon.shape
        data_flat = data.reshape(-1, n_source)
        result_flat = np.empty(
            (data_flat.shape[0],) + target_shape, dtype=np.float64
        )
        for i, values in enumerate(data_flat):
            result_flat[i] = self._interpolate_1d(values, fill_value)
        return result_flat.reshape(data.shape[:-1] + target_shape)

    def _target_points(self) -> NDArray[np.floating]:
        """Return target grid points as an (npoints, 2) lon/lat array."""
        return np.column_stack(
            [self.target_lon.ravel(), self.target_lat.ravel()]
        )

    def _interpolate_1d(
        self,
        data: NDArray[np.floating],
        fill_value: float,
    ) -> NDArray[np.floating]:
        """Interpolate 1D data array."""
        if self.method == "nearest":
            result = data[self.indices]
            # Integer and boolean arrays cannot store NaN, so they must be
            # promoted before masking even when the fill value is NaN.
            if not np.isnan(fill_value) or result.dtype.kind in "biu":
                result = result.astype(np.float64)
            result[~self.valid_mask] = fill_value

        elif self.method == "idw":
            result = np.sum(
                self._idw_weights * data[self._idw_indices], axis=-1
            )
            result[~self.valid_mask] = fill_value

        elif self.method == "linear":
            from scipy.interpolate import LinearNDInterpolator

            interp = LinearNDInterpolator(
                self._delaunay, data, fill_value=fill_value
            )
            result = interp(self._target_points()).reshape(self.target_lon.shape)
            # Apply distance-based valid_mask on top
            result[~self.valid_mask] = fill_value

        elif self.method == "cubic":
            from scipy.interpolate import CloughTocher2DInterpolator

            valid_src = np.isfinite(data)
            if valid_src.all():
                interp = CloughTocher2DInterpolator(
                    self._delaunay, data, fill_value=fill_value
                )
            else:
                # Non-finite source values are dropped and the remaining
                # points are re-triangulated for this call only.
                interp = CloughTocher2DInterpolator(
                    self._source_2d[valid_src],
                    data[valid_src],
                    fill_value=fill_value,
                )
            result = interp(self._target_points()).reshape(self.target_lon.shape)
            # Apply distance-based valid_mask on top
            result[~self.valid_mask] = fill_value

        else:
            raise ValueError(f"Unknown method: {self.method!r}")

        return result

    @property
    def shape(self) -> tuple[int, int]:
        """Shape of target grid (nlat, nlon)."""
        return self.target_lon.shape
def regrid(
    data: NDArray | "xr.DataArray",
    lon: NDArray[np.floating] | None = None,
    lat: NDArray[np.floating] | None = None,
    resolution: float | tuple[int, int] = 1.0,
    method: Literal["nearest", "idw", "linear", "cubic"] = "nearest",
    influence_radius: float = 80_000.0,
    fill_value: float = np.nan,
    lon_bounds: tuple[float, float] = (-180.0, 180.0),
    lat_bounds: tuple[float, float] = (-90.0, 90.0),
    as_xarray: bool = False,
) -> tuple[NDArray[np.floating], RegridInterpolator]:
    """Regrid data onto a regular lat/lon grid in a single call.

    Convenience wrapper that builds a RegridInterpolator from the given
    coordinates and applies it to *data*. When the same source grid is
    regridded repeatedly, build the interpolator once and call it directly
    (it is also returned from this function for that purpose).

    The spatial points live on the last axis of *data*:

    - 1D data (npoints,): single field
    - 2D data (nlevels, npoints): multi-level unstructured data
    - ND data (*dims, npoints): arbitrary leading dimensions

    Accepted coordinate layouts:

    - 1D lon and lat of the same size: unstructured mesh, used as-is
    - 1D lon and lat of different sizes: axes of a regular grid; a meshgrid
      is built and the last two data axes are flattened (a warning is
      emitted)
    - anything else (e.g. 2D arrays): handed to ``prepare_coordinates``;
      data whose last two axes match the shape of *lon* is flattened

    Parameters
    ----------
    data : array_like
        Data to interpolate. Objects exposing ``.values`` (such as an
        xarray DataArray) are unwrapped first.
    lon : array_like, optional
        Source longitudes. When omitted, ``extract_coordinates(data)`` is
        used to look them up.
    lat : array_like, optional
        Source latitudes. When omitted, ``extract_coordinates(data)`` is
        used to look them up.
    resolution : float or tuple of int
        Target grid resolution.
    method : {"nearest", "idw", "linear", "cubic"}
        Interpolation method passed to RegridInterpolator.
    influence_radius : float
        Maximum influence radius in meters.
    fill_value : float
        Value for invalid points.
    lon_bounds : tuple of float
        Target grid longitude bounds.
    lat_bounds : tuple of float
        Target grid latitude bounds.
    as_xarray : bool, default False
        If True, wrap the result in an ``xr.DataArray`` with 1-D ``lat``
        and ``lon`` dimension coordinates. Leading dimension names, their
        coordinates, the variable name and attributes are carried over when
        *data* exposes ``dims``; otherwise leading dimensions are named
        ``dim_0``, ``dim_1``, ...

    Returns
    -------
    regridded : ndarray or xr.DataArray
        Regridded data (``xr.DataArray`` when ``as_xarray=True``).
    interpolator : RegridInterpolator
        The interpolator used (can be reused for other variables).

    Raises
    ------
    ValueError
        If coordinates are neither given nor extractable, or if their sizes
        do not match the trailing dimension(s) of *data*.
    """
    # Fill in whichever coordinate the caller left out
    if lon is None or lat is None:
        found_lon, found_lat = extract_coordinates(data)
        lon = found_lon if lon is None else lon
        lat = found_lat if lat is None else lat

    # Extraction may also come back empty-handed
    if lon is None or lat is None:
        raise ValueError(
            "lon and lat coordinates are required. Either provide them explicitly "
            "or use an xarray DataArray with recognizable coordinate names "
            "(lon/lat, longitude/latitude, x/y, etc.)."
        )

    # Unwrap DataArray-like inputs to a plain array
    data_values = data.values if hasattr(data, "values") else np.asarray(data)
    lon_arr = np.asarray(lon)
    lat_arr = np.asarray(lat)

    if lon_arr.ndim != 1 or lat_arr.ndim != 1:
        # Non-1D coordinates: normalize them, and flatten the data when its
        # trailing two axes have the shape of the original lon array.
        lon_arr, lat_arr = prepare_coordinates(lon_arr, lat_arr)
        if data_values.ndim >= 2:
            original_coord_shape = np.asarray(lon).shape
            if data_values.shape[-2:] == original_coord_shape:
                data_values = flatten_spatial(data_values)
    elif lon_arr.size == lat_arr.size:
        # Unstructured mesh: coordinates pair up one-to-one with the last
        # data axis, so only the sizes need checking.
        npoints = data_values.shape[-1] if data_values.ndim >= 1 else data_values.size
        if lon_arr.size != npoints:
            raise ValueError(
                f"Coordinate size ({lon_arr.size}) must match data's last dimension ({npoints}). "
                f"Data shape: {data_values.shape}"
            )
    else:
        # Regular grid described by its two axes; data is (..., nlat, nlon)
        if data_values.ndim < 2:
            raise ValueError(
                f"For regular grid coordinates (lon size {lon_arr.size}, lat size {lat_arr.size}), "
                f"data must be at least 2D, got shape {data_values.shape}"
            )
        nlat, nlon = data_values.shape[-2:]
        if lon_arr.size != nlon or lat_arr.size != nlat:
            raise ValueError(
                f"Coordinate sizes (lon: {lon_arr.size}, lat: {lat_arr.size}) must match "
                f"data's last two dimensions (nlat: {nlat}, nlon: {nlon}). "
                f"Data shape: {data_values.shape}"
            )
        # stacklevel=2 attributes the warning to the caller of regrid()
        warnings.warn(
            f"Creating meshgrid from 1D lon ({lon_arr.size}) and lat ({lat_arr.size}) "
            f"for data (shape {data_values.shape}), then raveling spatial dimensions.",
            stacklevel=2,
        )
        lon_mesh, lat_mesh = np.meshgrid(lon_arr, lat_arr)
        lon_arr = lon_mesh.ravel()
        lat_arr = lat_mesh.ravel()
        data_values = flatten_spatial(data_values)

    interpolator = RegridInterpolator(
        source_lon=lon_arr,
        source_lat=lat_arr,
        resolution=resolution,
        method=method,
        influence_radius=influence_radius,
        lon_bounds=lon_bounds,
        lat_bounds=lat_bounds,
    )
    regridded = interpolator(data_values, fill_value=fill_value)

    if not as_xarray:
        return regridded, interpolator

    import xarray as xr

    # The target grid is regular, so one column / one row carries the axes
    lat_1d = interpolator.target_lat[:, 0]
    lon_1d = interpolator.target_lon[0, :]
    n_leading = regridded.ndim - 2

    if hasattr(data, "dims"):
        # Carry over metadata of the leading (non-spatial) dimensions
        leading_dim_names = list(data.dims[:n_leading])
        leading_coords = {
            dim: data.coords[dim] for dim in leading_dim_names if dim in data.coords
        }
        var_name = data.name or "data"
        var_attrs = dict(data.attrs)
    else:
        leading_dim_names = [f"dim_{i}" for i in range(n_leading)]
        leading_coords = {}
        var_name = "data"
        var_attrs = {}

    coords = {
        **leading_coords,
        "lat": ("lat", lat_1d, {"units": "degrees_north", "standard_name": "latitude"}),
        "lon": ("lon", lon_1d, {"units": "degrees_east", "standard_name": "longitude"}),
    }
    regridded = xr.DataArray(
        regridded,
        dims=(*leading_dim_names, "lat", "lon"),
        coords=coords,
        name=var_name,
        attrs=var_attrs,
    )
    return regridded, interpolator