# Source code for nereus.diag.vertical

"""Vertical/ocean diagnostics for nereus.

This module provides functions for computing ocean diagnostics:
- surface_mean: Area-weighted mean for 2D fields (SST, SSS, etc.)
- volume_mean: Volume-weighted mean in a depth range
- heat_content: Ocean heat content
- find_closest_depth: Find index and value of closest depth to target
- interpolate_to_depth: Interpolate 3D data to target depths

All functions are dask-friendly: if inputs are dask arrays, the result
will be a lazy dask array that can be computed later with ``.compute()``.
"""

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

import numpy as np
from numpy.typing import NDArray

from nereus.core.types import get_array_data, is_dask_array, wrap_as_xarray

if TYPE_CHECKING:
    import xarray as xr

# Physical constants. These are the default values of the ``rho`` and ``cp``
# arguments of ``heat_content``.
RHO_SEAWATER = 1025.0  # seawater density in kg/m^3
CP_SEAWATER = 3990.0  # specific heat capacity in J/(kg·K) - consistent with FESOM2


def surface_mean(
    data: NDArray | "xr.DataArray",
    area: NDArray[np.floating],
    *,
    mask: NDArray[np.bool_] | None = None,
    as_xarray: bool = False,
) -> float | NDArray | "xr.DataArray":
    """Compute area-weighted mean of a 2D field (single level).

    This is commonly used for surface fields like SST, SSS, or for
    analyzing a single depth level.

    This function is dask-friendly: if inputs are dask arrays, the result
    will be a lazy dask array that can be computed later with ``.compute()``.

    Parameters
    ----------
    data : array_like
        2D data with shape (npoints,) or higher-dimensional with the last
        axis being npoints. For time series, shape would be (ntime, npoints).
        Regular-grid data with 2D spatial dimensions (e.g. (nlat, nlon) or
        (ntime, nlat, nlon)) is automatically flattened when nlat*nlon
        matches the area size.
    area : array_like
        Grid cell areas in m^2, shape (npoints,). Flattened to 1D before use.
    mask : array_like, optional
        Boolean mask for horizontal points, shape (npoints,). True = include.
    as_xarray : bool
        If True, return the result as an xarray DataArray with dimension
        names and coordinates preserved from the input (default False).

    Returns
    -------
    float or ndarray or dask.array or xr.DataArray
        Area-weighted mean. Returns float for 1D numpy input (npoints,),
        ndarray for higher-dimensional numpy input, dask array if inputs
        are dask, or xr.DataArray if as_xarray=True. Entries whose total
        valid weight is zero (everything masked or NaN) are NaN.

    Raises
    ------
    ValueError
        If data has two or more dimensions and neither its last axis nor
        the product of its last two axes matches the size of ``area``.

    Examples
    --------
    >>> # Mean SST
    >>> mean_sst = nr.surface_mean(sst, mesh.area)

    >>> # Mean SST in a region
    >>> mean_sst = nr.surface_mean(sst, mesh.area, mask=region_mask)

    >>> # With dask arrays (lazy computation)
    >>> mean_sst = nr.surface_mean(sst_dask, mesh.area)
    >>> mean_sst.compute()  # triggers actual computation
    """
    # Extract arrays, preserving dask
    data_arr = get_array_data(data)
    area_arr = get_array_data(area)
    is_lazy = is_dask_array(data)

    # Warn if dask data is mixed with large numpy arrays (causes graph bloat)
    if is_lazy and not is_dask_array(area_arr) and area_arr.nbytes > 10_000_000:
        warnings.warn(
            f"Data is a dask array but area ({area_arr.nbytes / 1e6:.1f} MB) is numpy. "
            "This can cause very large dask graphs. Consider loading all "
            "large arrays with dask (e.g., xr.open_dataset(..., chunks='auto')).",
            UserWarning,
            stacklevel=2,
        )

    # Flatten area to 1D
    if hasattr(area_arr, "ravel"):
        area_arr = area_arr.ravel()
    else:
        area_arr = np.asarray(area_arr).ravel()
    npoints_area = area_arr.shape[0]

    # Flatten 2D spatial dims (e.g. nlat, nlon) into single npoints axis if needed
    if data_arr.shape[-1] != npoints_area and data_arr.ndim >= 2:
        if data_arr.shape[-1] * data_arr.shape[-2] == npoints_area:
            data_arr = data_arr.reshape(data_arr.shape[:-2] + (-1,))
        else:
            raise ValueError(
                f"data last axis has {data_arr.shape[-1]} points but area has "
                f"{npoints_area}; last two data dims {data_arr.shape[-2:]} "
                f"also don't multiply to {npoints_area}."
            )

    # Build weights from area, applying mask if provided
    if mask is not None:
        mask_arr = get_array_data(mask)
        if hasattr(mask_arr, "ravel"):
            mask_arr = mask_arr.ravel()
        else:
            mask_arr = np.asarray(mask_arr).ravel()
        # Set weights to NaN where mask is False (will be ignored by nansum)
        weights = np.where(mask_arr, area_arr, np.nan)
    else:
        weights = area_arr

    # Compute weighted mean using nansum
    # NaN values in data or weights are automatically excluded
    weighted_sum = np.nansum(data_arr * weights, axis=-1)

    # For total weight, only count weights where data is valid
    # Use indicator trick: data * 0 + 1 gives 1 where valid, NaN where NaN
    valid_indicator = data_arr * 0 + 1
    total_weight = np.nansum(weights * valid_indicator, axis=-1)

    # Compute mean, handling zero weight. The quotient is evaluated for every
    # entry before np.where selects, so silence the 0/0 warning that the
    # zero-weight entries would otherwise raise; they are replaced by NaN.
    with np.errstate(divide="ignore", invalid="ignore"):
        result = np.where(total_weight > 0, weighted_sum / total_weight, np.nan)

    # Return appropriate type
    if as_xarray:
        return wrap_as_xarray(result, data, "surface_mean")
    if is_lazy:
        return result
    if np.ndim(result) == 0:
        return float(result)
    return result
def volume_mean(
    data: NDArray | "xr.DataArray",
    area: NDArray[np.floating],
    thickness: NDArray[np.floating],
    depth: NDArray[np.floating] | None = None,
    *,
    depth_min: float | None = None,
    depth_max: float | None = None,
    mask: NDArray[np.bool_] | None = None,
    as_xarray: bool = False,
) -> float | NDArray | "xr.DataArray":
    """Compute volume-weighted mean of a quantity.

    This function is dask-friendly: if inputs are dask arrays, the result
    will be a lazy dask array that can be computed later with ``.compute()``.

    Parameters
    ----------
    data : array_like
        3D data with shape (nlevels, npoints) or higher-dimensional with
        the last two axes being (nlevels, npoints). For time series, shape
        would be (ntime, nlevels, npoints).
    area : array_like
        Grid cell areas in m^2. Can be either:

        - 1D array of shape (npoints,) for surface area (uniform across depth)
        - 2D array of shape (nlevels, npoints) for depth-dependent area

        If 2D and has one extra level compared to data layers, the extra
        level is dropped with a warning (levels vs layers).
    thickness : array_like
        Layer thicknesses in meters, shape (nlevels, npoints) or (nlevels,)
        if uniform across points.
    depth : array_like, optional
        Depth of layer centers in meters (positive downward), shape
        (nlevels,). Required if depth_min or depth_max are specified.
    depth_min : float, optional
        Minimum depth to include (meters, positive downward).
    depth_max : float, optional
        Maximum depth to include (meters, positive downward).
    mask : array_like, optional
        Boolean mask for horizontal points, shape (npoints,). True = include.
    as_xarray : bool
        If True, return the result as an xarray DataArray with dimension
        names and coordinates preserved from the input (default False).

    Returns
    -------
    float or ndarray or dask.array or xr.DataArray
        Volume-weighted mean. Returns float for 2D numpy input
        (nlevels, npoints), ndarray for higher-dimensional numpy input,
        dask array if inputs are dask, or xr.DataArray if as_xarray=True.
        Entries whose total valid volume is zero are NaN.

    Raises
    ------
    ValueError
        If area is not 1D or 2D, if a 2D area has a level count other than
        that of data or one more, if thickness and data disagree on the
        number of levels, if depth_min/depth_max is given without depth,
        or if depth does not have one entry per level.

    Examples
    --------
    >>> # Mean temperature in upper 500m
    >>> mean_temp = nr.volume_mean(
    ...     temp, mesh.area, mesh.layer_thickness, mesh.depth,
    ...     depth_max=500
    ... )

    >>> # Mean salinity over full depth
    >>> mean_sal = nr.volume_mean(sal, mesh.area, mesh.layer_thickness)

    >>> # With dask arrays (lazy computation)
    >>> mean_temp = nr.volume_mean(temp_dask, mesh.area, mesh.layer_thickness)
    >>> mean_temp.compute()  # triggers actual computation
    """
    # Extract arrays, preserving dask
    data_arr = get_array_data(data)
    area_arr = get_array_data(area)
    thick_arr = get_array_data(thickness)
    is_lazy = is_dask_array(data)

    # Warn if dask data is mixed with large numpy arrays (causes graph bloat)
    if is_lazy:
        large_numpy_arrays = []
        # Check area - threshold ~10MB (large enough to cause issues)
        if not is_dask_array(area_arr) and area_arr.nbytes > 10_000_000:
            large_numpy_arrays.append(f"area ({area_arr.nbytes / 1e6:.1f} MB)")
        # Check thickness
        if not is_dask_array(thick_arr) and thick_arr.nbytes > 10_000_000:
            large_numpy_arrays.append(f"thickness ({thick_arr.nbytes / 1e6:.1f} MB)")
        if large_numpy_arrays:
            warnings.warn(
                f"Data is a dask array but {', '.join(large_numpy_arrays)} "
                f"{'is' if len(large_numpy_arrays) == 1 else 'are'} numpy. "
                "This can cause very large dask graphs. Consider loading all "
                "large arrays with dask (e.g., xr.open_dataset(..., chunks='auto')).",
                UserWarning,
                stacklevel=2,
            )

    # Get number of levels from data
    nlev_data = data_arr.shape[-2]

    # Handle area: can be 1D (npoints,) or 2D (nlevels, npoints)
    if area_arr.ndim == 1:
        # Surface area only - will broadcast later
        area_is_2d = False
    elif area_arr.ndim == 2:
        nlev_area = area_arr.shape[0]
        area_is_2d = True
        # Check if area has one extra level (levels vs layers mismatch)
        if nlev_area != nlev_data:
            diff = nlev_area - nlev_data
            if diff != 1:
                raise ValueError(
                    f"area has {nlev_area} vertical levels but data has {nlev_data}; "
                    "only area having one extra level is supported (levels vs layers)."
                )
            warnings.warn(
                f"area has one more vertical level than data; "
                f"using the first {nlev_data} levels of area to match data "
                "(levels vs layers).",
                UserWarning,
                stacklevel=2,
            )
            area_arr = area_arr[:nlev_data, :]
    else:
        raise ValueError(f"area must be 1D or 2D, got {area_arr.ndim}D")

    # Handle thickness - need to broadcast if 1D
    nlevels = thick_arr.shape[0]
    if thick_arr.ndim == 1:
        # Use broadcasting instead of np.broadcast_to for dask compatibility
        thick_2d = thick_arr[:, np.newaxis]  # Shape: (nlevels, 1)
    else:
        thick_2d = thick_arr

    # Validate dimensions
    if nlevels != nlev_data:
        raise ValueError(
            f"thickness has {nlevels} levels but data has {nlev_data}"
        )

    # For dask arrays, ensure thickness is also dask to avoid graph bloat
    # When numpy arrays are broadcast with dask arrays, they get embedded
    # in every task, causing massive graph sizes
    if is_lazy and not is_dask_array(thick_2d):
        import dask.array as da

        # Get chunks from data_arr for the last two axes (nlevels, npoints)
        data_chunks = data_arr.chunks
        level_chunks = data_chunks[-2]  # chunks along nlevels axis
        point_chunks = data_chunks[-1]  # chunks along npoints axis

        if thick_2d.shape[-1] == 1:
            # thick_2d is (nlevels, 1) - broadcast along points
            thick_2d = da.from_array(thick_2d, chunks=(level_chunks, 1))
        else:
            # thick_2d is (nlevels, npoints)
            thick_2d = da.from_array(thick_2d, chunks=(level_chunks, point_chunks))

    # Build depth mask if needed (this is small, keep as numpy)
    use_depth_range = depth_min is not None or depth_max is not None
    level_mask = np.ones(nlevels, dtype=np.float64)
    if use_depth_range:
        if depth is None:
            raise ValueError("depth array required when using depth_min/depth_max")
        depth_arr = np.asarray(get_array_data(depth)).ravel()
        # A depth vector of the wrong length would either fail later with an
        # opaque broadcasting error or silently broadcast against the data.
        if depth_arr.shape[0] != nlevels:
            raise ValueError(
                f"depth has {depth_arr.shape[0]} levels but data has {nlev_data}"
            )
        if depth_min is not None:
            level_mask = level_mask * (depth_arr >= depth_min)
        if depth_max is not None:
            level_mask = level_mask * (depth_arr <= depth_max)

    # Compute cell volumes: thickness * area
    # Shape: (nlevels, npoints) or broadcasts to it
    if area_is_2d:
        volumes = thick_2d * area_arr
    else:
        volumes = thick_2d * area_arr[np.newaxis, :]

    # Apply depth mask by setting excluded levels to NaN
    if use_depth_range:
        level_mask_nan = np.where(level_mask, 1.0, np.nan)
        volumes = volumes * level_mask_nan[:, np.newaxis]

    # Apply horizontal mask by setting excluded points to NaN
    if mask is not None:
        horiz_mask = get_array_data(mask)
        if hasattr(horiz_mask, "ravel"):
            horiz_mask = horiz_mask.ravel()
        else:
            horiz_mask = np.asarray(horiz_mask).ravel()
        horiz_mask_nan = np.where(horiz_mask, 1.0, np.nan)
        volumes = volumes * horiz_mask_nan

    # Compute weighted mean using nansum
    # NaN values in data or volumes are automatically excluded
    weighted_sum = np.nansum(data_arr * volumes, axis=(-2, -1))

    # For total volume, only count volumes where data is valid
    # Use indicator trick: data * 0 + 1 gives 1 where valid, NaN where NaN
    valid_indicator = data_arr * 0 + 1
    total_volume = np.nansum(volumes * valid_indicator, axis=(-2, -1))

    # Compute mean, handling zero volume. The quotient is evaluated for every
    # entry before np.where selects, so silence the 0/0 warning that the
    # zero-volume entries would otherwise raise; they are replaced by NaN.
    with np.errstate(divide="ignore", invalid="ignore"):
        result = np.where(total_volume > 0, weighted_sum / total_volume, np.nan)

    # Return appropriate type
    if as_xarray:
        return wrap_as_xarray(result, data, "volume_mean", skip_dims=2)
    if is_lazy:
        return result
    if np.ndim(result) == 0:
        return float(result)
    return result
def heat_content(
    temperature: NDArray | "xr.DataArray",
    area: NDArray[np.floating],
    thickness: NDArray[np.floating],
    depth: NDArray[np.floating] | None = None,
    *,
    depth_min: float | None = None,
    depth_max: float | None = None,
    reference_temp: float = 0.0,
    mask: NDArray[np.bool_] | None = None,
    rho: float = RHO_SEAWATER,
    cp: float = CP_SEAWATER,
    output: str = "total",
    as_xarray: bool = False,
) -> float | NDArray | "xr.DataArray":
    """Compute ocean heat content.

    Heat content can be computed as either:

    - Total (default): OHC = rho * cp * sum(T * thickness * area) in Joules
    - Map: OHC = rho * cp * sum_z(T * thickness) in J/m² at each point

    This function is dask-friendly: if inputs are dask arrays, the result
    will be a lazy dask array that can be computed later with ``.compute()``.

    Parameters
    ----------
    temperature : array_like
        Temperature in degrees Celsius, shape (nlevels, npoints) or higher
        dimensional with the last two axes being (nlevels, npoints).
    area : array_like
        Grid cell areas in m^2. Can be either:

        - 1D array of shape (npoints,) for surface area (uniform across depth)
        - 2D array of shape (nlevels, npoints) for depth-dependent area

        If 2D and has one extra level compared to data layers, the extra
        level is dropped with a warning (levels vs layers).
        Note: area is not used when output="map".
    thickness : array_like
        Layer thicknesses in meters, shape (nlevels, npoints) or (nlevels,).
    depth : array_like, optional
        Depth of layer centers in meters (positive downward), shape
        (nlevels,). Required if depth_min or depth_max are specified.
    depth_min : float, optional
        Minimum depth to include (meters, positive downward).
    depth_max : float, optional
        Maximum depth to include (meters, positive downward).
    reference_temp : float
        Reference temperature for heat content calculation. Default 0°C.
    mask : array_like, optional
        Boolean mask for horizontal points. True = include.
    rho : float
        Seawater density in kg/m^3. Default 1025.
    cp : float
        Specific heat capacity in J/(kg·K). Default 3990.
    output : str
        Output type. One of:

        - "total": Total heat content in Joules (scalar per timestep)
        - "map": Heat content per unit area in J/m² (2D field at each point)

        Default is "total".
    as_xarray : bool
        If True, return the result as an xarray DataArray with dimension
        names and coordinates preserved from the input (default False).

    Returns
    -------
    float or ndarray or dask.array or xr.DataArray
        If output="total": Ocean heat content in Joules.
        If output="map": Heat content per unit area in J/m², shape (npoints,)
        or (..., npoints) for higher-dimensional input. Horizontally masked
        points are 0 in the map.
        Returns a dask array if inputs are dask, or xr.DataArray if
        as_xarray=True.

    Raises
    ------
    ValueError
        If output is not "total" or "map", if (for output="total") area is
        not 1D or 2D or a 2D area has a level count other than that of the
        data or one more, if thickness and temperature disagree on the
        number of levels, if depth_min/depth_max is given without depth,
        or if depth does not have one entry per level.

    Examples
    --------
    >>> # Total ocean heat content
    >>> ohc = nr.heat_content(temp, mesh.area, mesh.layer_thickness)

    >>> # Heat content in upper 700m
    >>> ohc_700 = nr.heat_content(
    ...     temp, mesh.area, mesh.layer_thickness, mesh.depth,
    ...     depth_max=700
    ... )

    >>> # Heat content map (J/m² at each point, like FESOM2 output)
    >>> ohc_map = nr.heat_content(
    ...     temp, mesh.area, mesh.layer_thickness,
    ...     output="map"
    ... )

    >>> # With dask arrays (lazy computation)
    >>> ohc = nr.heat_content(temp_dask, mesh.area, mesh.layer_thickness)
    >>> ohc.compute()  # triggers actual computation
    """
    # Validate output parameter
    if output not in ("total", "map"):
        raise ValueError(f"output must be 'total' or 'map', got '{output}'")

    # Extract arrays, preserving dask
    temp_arr = get_array_data(temperature)
    area_arr = get_array_data(area)
    thick_arr = get_array_data(thickness)
    is_lazy = is_dask_array(temperature)

    # Warn if dask data is mixed with large numpy arrays (causes graph bloat)
    if is_lazy:
        large_numpy_arrays = []
        # Check area - threshold ~10MB (large enough to cause issues).
        # Area never enters the computation for output="map", so it cannot
        # bloat the graph there and must not trigger the warning.
        if (
            output == "total"
            and not is_dask_array(area_arr)
            and area_arr.nbytes > 10_000_000
        ):
            large_numpy_arrays.append(f"area ({area_arr.nbytes / 1e6:.1f} MB)")
        # Check thickness
        if not is_dask_array(thick_arr) and thick_arr.nbytes > 10_000_000:
            large_numpy_arrays.append(f"thickness ({thick_arr.nbytes / 1e6:.1f} MB)")
        if large_numpy_arrays:
            warnings.warn(
                f"Data is a dask array but {', '.join(large_numpy_arrays)} "
                f"{'is' if len(large_numpy_arrays) == 1 else 'are'} numpy. "
                "This can cause very large dask graphs. Consider loading all "
                "large arrays with dask (e.g., xr.open_dataset(..., chunks='auto')).",
                UserWarning,
                stacklevel=2,
            )

    # Get number of levels from data
    nlev_data = temp_arr.shape[-2]

    # Handle area: can be 1D (npoints,) or 2D (nlevels, npoints)
    # Only needed for output="total"
    area_is_2d = False
    if output == "total":
        if area_arr.ndim == 1:
            area_is_2d = False
        elif area_arr.ndim == 2:
            nlev_area = area_arr.shape[0]
            area_is_2d = True
            # Check if area has one extra level (levels vs layers mismatch)
            if nlev_area != nlev_data:
                diff = nlev_area - nlev_data
                if diff != 1:
                    raise ValueError(
                        f"area has {nlev_area} vertical levels but data has {nlev_data}; "
                        "only area having one extra level is supported (levels vs layers)."
                    )
                warnings.warn(
                    f"area has one more vertical level than data; "
                    f"using the first {nlev_data} levels of area to match data "
                    "(levels vs layers).",
                    UserWarning,
                    stacklevel=2,
                )
                area_arr = area_arr[:nlev_data, :]
        else:
            raise ValueError(f"area must be 1D or 2D, got {area_arr.ndim}D")

    # Handle thickness - need to broadcast if 1D
    nlevels = thick_arr.shape[0]
    if thick_arr.ndim == 1:
        thick_2d = thick_arr[:, np.newaxis]  # Shape: (nlevels, 1)
    else:
        thick_2d = thick_arr

    # Validate dimensions
    if nlevels != nlev_data:
        raise ValueError(
            f"thickness has {nlevels} levels but data has {nlev_data}"
        )

    # Build depth mask if needed (this is small, keep as numpy)
    use_depth_range = depth_min is not None or depth_max is not None
    level_mask = np.ones(nlevels, dtype=np.float64)
    if use_depth_range:
        if depth is None:
            raise ValueError("depth array required when using depth_min/depth_max")
        depth_arr = np.asarray(get_array_data(depth)).ravel()
        # A depth vector of the wrong length would either fail later with an
        # opaque broadcasting error or silently broadcast against the data.
        if depth_arr.shape[0] != nlevels:
            raise ValueError(
                f"depth has {depth_arr.shape[0]} levels but data has {nlev_data}"
            )
        if depth_min is not None:
            level_mask = level_mask * (depth_arr >= depth_min)
        if depth_max is not None:
            level_mask = level_mask * (depth_arr <= depth_max)

    # Compute heat content: rho * cp * sum((T - T_ref) * thickness [* area])
    temp_anomaly = temp_arr - reference_temp

    # For dask arrays, ensure thickness is also dask to avoid graph bloat
    # When numpy arrays are broadcast with dask arrays, they get embedded
    # in every task, causing massive graph sizes
    if is_lazy and not is_dask_array(thick_2d):
        import dask.array as da

        # Get chunks from temp_arr for the last two axes (nlevels, npoints)
        temp_chunks = temp_arr.chunks
        level_chunks = temp_chunks[-2]  # chunks along nlevels axis
        point_chunks = temp_chunks[-1]  # chunks along npoints axis

        if thick_2d.shape[-1] == 1:
            # thick_2d is (nlevels, 1) - broadcast along points
            thick_2d = da.from_array(thick_2d, chunks=(level_chunks, 1))
        else:
            # thick_2d is (nlevels, npoints)
            thick_2d = da.from_array(thick_2d, chunks=(level_chunks, point_chunks))

    # Flatten the horizontal mask once; both output modes use it
    horiz_mask = None
    if mask is not None:
        horiz_mask = get_array_data(mask)
        if hasattr(horiz_mask, "ravel"):
            horiz_mask = horiz_mask.ravel()
        else:
            horiz_mask = np.asarray(horiz_mask).ravel()

    if output == "total":
        # Compute cell volumes: thickness * area
        if area_is_2d:
            volumes = thick_2d * area_arr
        else:
            volumes = thick_2d * area_arr[np.newaxis, :]

        # Apply depth mask by setting excluded levels to NaN
        if use_depth_range:
            level_mask_nan = np.where(level_mask, 1.0, np.nan)
            volumes = volumes * level_mask_nan[:, np.newaxis]

        # Apply horizontal mask by setting excluded points to NaN
        if horiz_mask is not None:
            horiz_mask_nan = np.where(horiz_mask, 1.0, np.nan)
            volumes = volumes * horiz_mask_nan

        # Sum over last two axes (nlevels, npoints) using nansum
        # NaN values in temp_anomaly or volumes are automatically excluded
        heat = np.nansum(temp_anomaly * volumes, axis=(-2, -1))
        result = rho * cp * heat

        # Return appropriate type
        if as_xarray:
            return wrap_as_xarray(result, temperature, "heat_content", skip_dims=2)
        if is_lazy:
            return result
        if np.ndim(result) == 0:
            return float(result)
        return result

    # output == "map"
    # Compute heat content per unit area: rho * cp * sum_z(T * thickness)
    thick_masked = thick_2d

    # Apply depth mask by setting excluded levels to NaN
    if use_depth_range:
        level_mask_nan = np.where(level_mask, 1.0, np.nan)
        thick_masked = thick_masked * level_mask_nan[:, np.newaxis]

    # Sum over vertical axis only (second to last axis) using nansum
    # NaN values in temp_anomaly or thickness are automatically excluded
    heat_per_area = np.nansum(temp_anomaly * thick_masked, axis=-2)

    # Apply horizontal mask (use 0 for masked points, not NaN, for map output)
    if horiz_mask is not None:
        heat_per_area = heat_per_area * horiz_mask.astype(np.float64)

    result = rho * cp * heat_per_area

    # Return appropriate type
    if as_xarray:
        return _wrap_map_as_xarray(result, temperature, "heat_content")
    if is_lazy:
        return result
    return np.asarray(result)
def _wrap_map_as_xarray(
    result: NDArray,
    source_data: NDArray | "xr.DataArray",
    default_name: str,
) -> "xr.DataArray":
    """Wrap a partial-reduction result (depth removed, spatial kept) as xarray.

    For heat_content(output="map"), the input has shape (..., nlevels, npoints)
    and the result has shape (..., npoints). The leading dims plus the last
    spatial dim should be preserved.

    Parameters
    ----------
    result : array_like
        Reduced data. Dask arrays are wrapped as-is so the returned
        DataArray stays lazy; anything else is converted with ``np.asarray``.
    source_data : array_like or xr.DataArray
        The original input. If it exposes ``dims``, its dimension names
        (minus the second-to-last), matching coordinates, name and attrs
        are copied onto the output.
    default_name : str
        Name used when ``source_data`` has no name or is not an xarray object.

    Returns
    -------
    xr.DataArray
        ``result`` labelled with dims/coords taken from ``source_data``, or
        with generic ``dim_0``, ``dim_1``, ... names when no dims are available.
    """
    import xarray as xr

    # np.asarray on a dask array triggers the full computation, which would
    # defeat the laziness promised to callers passing dask input.
    result_arr = result if is_dask_array(result) else np.asarray(result)

    if hasattr(source_data, "dims"):
        # Input dims: e.g. ("time", "level", "npoints")
        # Result dims: e.g. ("time", "npoints") — skip the second-to-last
        input_dims = list(source_data.dims)
        # Keep all dims except the second-to-last (depth/level axis)
        output_dims = input_dims[:-2] + input_dims[-1:]
        output_coords = {
            d: source_data.coords[d] for d in output_dims if d in source_data.coords
        }
        var_name = source_data.name or default_name
        var_attrs = dict(source_data.attrs)
    else:
        output_dims = [f"dim_{i}" for i in range(result_arr.ndim)]
        output_coords = {}
        var_name = default_name
        var_attrs = {}

    return xr.DataArray(
        result_arr,
        dims=output_dims,
        coords=output_coords,
        name=var_name,
        attrs=var_attrs,
    )
def find_closest_depth(
    depths: NDArray[np.floating] | list | "xr.DataArray",
    target: float,
) -> tuple[int, float]:
    """Locate the depth level nearest to a requested depth.

    Handy when several models use different vertical grids: it tells you
    which level of each model corresponds to a given depth, and how far
    that level actually is from the requested value.

    Parameters
    ----------
    depths : array_like
        Depth values (typically positive downward in meters). The input is
        flattened to 1D before searching.
    target : float
        Depth to match.

    Returns
    -------
    tuple[int, float]
        ``(index, value)``: position of the nearest depth in the flattened
        input and the depth found there. On ties the first occurrence wins
        (``np.argmin`` semantics).

    Examples
    --------
    >>> depths = [0, 10, 25, 50, 100, 200, 500, 1000]
    >>> idx, val = nr.find_closest_depth(depths, 100)
    >>> print(f"Index: {idx}, Depth: {val}m")
    Index: 4, Depth: 100.0m

    >>> # Check how far model depth is from target
    >>> idx, val = nr.find_closest_depth(depths, 75)
    >>> print(f"Closest depth: {val}m, difference: {abs(val - 75)}m")
    Closest depth: 50.0m, difference: 25.0m
    """
    # Unwrap xarray/list input into a flat numpy vector
    flat_depths = np.asarray(get_array_data(depths)).ravel()

    # The nearest level is the one with the smallest absolute offset
    offsets = np.abs(flat_depths - target)
    nearest = int(np.argmin(offsets))

    return nearest, float(flat_depths[nearest])
def interpolate_to_depth(
    data: NDArray | "xr.DataArray",
    lon: NDArray[np.floating] | "xr.DataArray" | None,
    lat: NDArray[np.floating] | "xr.DataArray" | None,
    model_depths: NDArray[np.floating] | list | "xr.DataArray",
    target_depths: NDArray[np.floating] | list | float,
) -> NDArray | tuple[NDArray, NDArray, NDArray]:
    """Interpolate 3D data to target depth levels using linear interpolation.

    Performs column-wise linear interpolation from model depth levels to
    specified target depths. Values outside the model depth range are
    extrapolated, and a warning is issued whenever any target lies outside
    that range.

    This function is dask-friendly: if inputs are dask arrays, the result
    will be a lazy dask array that can be computed later with ``.compute()``.

    Parameters
    ----------
    data : array_like
        3D data with shape (nlevels, npoints) or higher-dimensional with
        the last two axes being (nlevels, npoints). For time series, shape
        would be (ntime, nlevels, npoints). Integer or boolean data is
        promoted to float64 so interpolated values are not truncated.
    lon : array_like or None
        Longitude coordinates, shape (npoints,). If provided along with lat,
        these are returned with the result for convenience. Pass None if
        not needed.
    lat : array_like or None
        Latitude coordinates, shape (npoints,). If provided along with lon,
        these are returned with the result for convenience. Pass None if
        not needed.
    model_depths : array_like
        Depth levels of the input data in meters (positive downward),
        shape (nlevels,).
    target_depths : array_like or float
        Target depth(s) to interpolate to in meters. Can be a single value
        or an array of depths.

    Returns
    -------
    ndarray or tuple
        If lon and lat are None: Interpolated data with shape
        (ntargets, npoints) or (..., ntargets, npoints) for
        higher-dimensional input. If target_depths is a scalar, ntargets=1.
        If lon and lat are provided: Tuple of (interpolated_data, lon, lat),
        with lon and lat flattened to 1D numpy arrays.

    Raises
    ------
    ValueError
        If data has fewer than two dimensions, or if its second-to-last
        axis does not match the length of model_depths.

    Warns
    -----
    UserWarning
        If any target depth lies outside the range of model_depths.

    Examples
    --------
    >>> # Interpolate temperature to 100m depth (without coordinates)
    >>> temp_100m = nr.interpolate_to_depth(temp, None, None, mesh.depth, 100)

    >>> # Interpolate to multiple standard depths
    >>> standard_depths = [10, 50, 100, 200, 500, 1000]
    >>> temp_interp = nr.interpolate_to_depth(temp, None, None, mesh.depth, standard_depths)

    >>> # With coordinates for plotting
    >>> temp_100m, lon, lat = nr.interpolate_to_depth(
    ...     temp, mesh.lon, mesh.lat, mesh.depth, 100
    ... )
    >>> nr.plot(temp_100m.squeeze(), lon, lat)

    >>> # Compare models at the same depth
    >>> temp_model1 = nr.interpolate_to_depth(temp1, None, None, depths1, 100)
    >>> temp_model2 = nr.interpolate_to_depth(temp2, None, None, depths2, 100)
    """
    # Extract arrays, preserving dask
    data_arr = get_array_data(data)
    is_lazy = is_dask_array(data)

    # Handle model depths
    depth_arr = np.asarray(get_array_data(model_depths)).ravel()
    nlevels = len(depth_arr)

    # Validate data shape
    if data_arr.ndim < 2:
        raise ValueError(
            "data must have at least 2 dimensions (nlevels, npoints), "
            f"got {data_arr.ndim}D"
        )
    if data_arr.shape[-2] != nlevels:
        raise ValueError(
            f"data has {data_arr.shape[-2]} levels but model_depths has {nlevels}"
        )

    # Interpolated values are generally fractional. Integer/boolean input
    # would otherwise be written into an integer buffer (and declared with an
    # integer dtype on the dask path), truncating the result.
    if not np.issubdtype(data_arr.dtype, np.inexact):
        data_arr = data_arr.astype(np.float64)

    # Handle target depths - ensure array
    target_arr = np.atleast_1d(np.asarray(target_depths)).ravel()
    ntargets = len(target_arr)

    # Check for extrapolation
    depth_min, depth_max = depth_arr.min(), depth_arr.max()
    targets_below = target_arr[target_arr < depth_min]
    targets_above = target_arr[target_arr > depth_max]
    if len(targets_below) > 0 or len(targets_above) > 0:
        extrap_msg = []
        if len(targets_below) > 0:
            extrap_msg.append(
                f"{len(targets_below)} target(s) shallower than model minimum ({depth_min}m)"
            )
        if len(targets_above) > 0:
            extrap_msg.append(
                f"{len(targets_above)} target(s) deeper than model maximum ({depth_max}m)"
            )
        warnings.warn(
            f"Extrapolation required: {'; '.join(extrap_msg)}. "
            "Results may be unreliable outside model depth range.",
            UserWarning,
            stacklevel=2,
        )

    # Get shape information
    npoints = data_arr.shape[-1]
    leading_dims = data_arr.shape[:-2]  # e.g., (ntime,) or ()

    # Reshape data to (nbatch, nlevels, npoints) for uniform processing
    if len(leading_dims) == 0:
        # Shape: (nlevels, npoints)
        data_3d = data_arr[np.newaxis, :, :]  # (1, nlevels, npoints)
    else:
        # Shape: (..., nlevels, npoints)
        nbatch = int(np.prod(leading_dims))
        data_3d = data_arr.reshape(nbatch, nlevels, npoints)

    # Perform linear interpolation column by column
    # For each target depth, find bracketing levels and interpolate
    if is_lazy:
        import dask.array as da

        # Interpolation requires all depth levels at once, so rechunk
        # to have a single chunk along the levels axis
        data_3d = data_3d.rechunk({1: -1})

        # Capture depth_arr and target_arr in a closure so they are not
        # passed to map_blocks as block arguments
        def _interp_chunk(data_chunk, _depths=depth_arr, _targets=target_arr):
            return _linear_interp_vectorized(data_chunk, _depths, _targets)

        # Get chunks from data (after rechunking)
        data_chunks = data_3d.chunks

        result = da.map_blocks(
            _interp_chunk,
            data_3d,
            dtype=data_3d.dtype,
            chunks=(data_chunks[0], (ntargets,), data_chunks[2]),
        )
    else:
        result = _linear_interp_vectorized(data_3d, depth_arr, target_arr)

    # Reshape result to match input dimensions
    if len(leading_dims) == 0:
        result = result[0, :, :]  # Remove batch dimension: (ntargets, npoints)
    else:
        result = result.reshape(*leading_dims, ntargets, npoints)

    # Handle coordinate returns
    if lon is not None and lat is not None:
        lon_arr = get_array_data(lon)
        lat_arr = get_array_data(lat)
        if hasattr(lon_arr, "ravel"):
            lon_arr = lon_arr.ravel()
        if hasattr(lat_arr, "ravel"):
            lat_arr = lat_arr.ravel()
        return result, np.asarray(lon_arr), np.asarray(lat_arr)

    return result
def _linear_interp_vectorized( data: NDArray, depths: NDArray, targets: NDArray, ) -> NDArray: """Vectorized linear interpolation for depth profiles. Parameters ---------- data : ndarray Shape (nbatch, nlevels, npoints). depths : ndarray Shape (nlevels,), must be monotonic. targets : ndarray Shape (ntargets,). Returns ------- ndarray Shape (nbatch, ntargets, npoints). """ nbatch, nlevels, npoints = data.shape ntargets = len(targets) # Check if depths are monotonically increasing or decreasing if depths[0] > depths[-1]: # Depths are decreasing, flip for interpolation depths = depths[::-1] data = data[:, ::-1, :] # Output array result = np.empty((nbatch, ntargets, npoints), dtype=data.dtype) for t_idx, target in enumerate(targets): # Find bracketing indices # np.searchsorted returns index where target would be inserted idx_upper = np.searchsorted(depths, target) if idx_upper == 0: # Target is above/at shallowest level - extrapolate using first two levels idx_lower, idx_upper = 0, 1 elif idx_upper >= nlevels: # Target is below deepest level - extrapolate using last two levels idx_lower, idx_upper = nlevels - 2, nlevels - 1 else: idx_lower = idx_upper - 1 # Get bracketing depths and data z_lower = depths[idx_lower] z_upper = depths[idx_upper] data_lower = data[:, idx_lower, :] # (nbatch, npoints) data_upper = data[:, idx_upper, :] # (nbatch, npoints) # Linear interpolation weight if z_upper != z_lower: weight = (target - z_lower) / (z_upper - z_lower) else: weight = 0.0 # Interpolate result[:, t_idx, :] = data_lower + weight * (data_upper - data_lower) return result