Source code for nereus.models.fesom.functions

"""FESOM-specific functions.

This module provides functions specific to FESOM mesh operations,
such as computing element centers and node-to-element interpolation.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Literal

import numpy as np
import xarray as xr
from numpy.typing import NDArray

from nereus.core.coordinates import compute_element_centers as _compute_element_centers

if TYPE_CHECKING:
    pass


[docs] def compute_element_centers(mesh: xr.Dataset) -> xr.Dataset: """Compute element center coordinates. Adds lon_tri and lat_tri variables to the mesh dataset. Parameters ---------- mesh : xr.Dataset FESOM mesh dataset with lon, lat, and triangles. Returns ------- xr.Dataset Mesh with lon_tri and lat_tri added. Examples -------- >>> mesh = nr.fesom.load_mesh(path) >>> mesh = nr.fesom.compute_element_centers(mesh) >>> lon_elem = mesh["lon_tri"] """ if "lon_tri" in mesh and "lat_tri" in mesh: return mesh if "triangles" not in mesh: raise ValueError("Mesh does not contain triangles") lon = mesh["lon"].values lat = mesh["lat"].values triangles = mesh["triangles"].values lon_tri, lat_tri = _compute_element_centers(lon, lat, triangles) mesh = mesh.copy() mesh["lon_tri"] = xr.DataArray( lon_tri, dims=("nelem",), attrs={ "units": "degrees_east", "long_name": "Element center longitude", }, ) mesh["lat_tri"] = xr.DataArray( lat_tri, dims=("nelem",), attrs={ "units": "degrees_north", "long_name": "Element center latitude", }, ) return mesh
[docs] def node_to_element( data: xr.DataArray | NDArray[np.floating], mesh: xr.Dataset, method: Literal["mean", "median", "min", "max"] = "mean", ) -> xr.DataArray | NDArray[np.floating]: """Interpolate data from nodes to element centers. For each triangle, computes an aggregate of its three vertex values. Parameters ---------- data : xr.DataArray or ndarray Node data with shape (..., npoints). mesh : xr.Dataset FESOM mesh dataset with triangles. method : {"mean", "median", "min", "max"} Aggregation method. Returns ------- xr.DataArray or ndarray Element data with shape (..., nelem). Examples -------- >>> mesh = nr.fesom.load_mesh(path) >>> temp_elem = nr.fesom.node_to_element(temp_node, mesh) """ if "triangles" not in mesh: raise ValueError("Mesh does not contain triangles") triangles = mesh["triangles"].values is_xarray = isinstance(data, xr.DataArray) if is_xarray: data_np = data.values orig_dims = data.dims else: data_np = np.asarray(data) orig_dims = None # Get the spatial dimension (last one) spatial_shape = data_np.shape[:-1] npoints = data_np.shape[-1] nelem = triangles.shape[0] # Get data at triangle vertices: shape (..., nelem, 3) vertex_data = data_np[..., triangles] # Aggregate if method == "mean": elem_data = np.mean(vertex_data, axis=-1) elif method == "median": elem_data = np.median(vertex_data, axis=-1) elif method == "min": elem_data = np.min(vertex_data, axis=-1) elif method == "max": elem_data = np.max(vertex_data, axis=-1) else: raise ValueError(f"Unknown method: {method}") if is_xarray: # Create new DataArray with element dimension new_dims = orig_dims[:-1] + ("nelem",) result = xr.DataArray( elem_data, dims=new_dims, attrs=data.attrs, ) # Copy coordinates except the spatial one for coord, coord_data in data.coords.items(): if coord not in data.dims[-1:]: result = result.assign_coords({coord: coord_data}) return result else: return elem_data
[docs] def element_to_node( data: xr.DataArray | NDArray[np.floating], mesh: xr.Dataset, method: Literal["mean", "sum"] = "mean", ) -> xr.DataArray | NDArray[np.floating]: """Interpolate data from element centers to nodes. For each node, aggregates values from all connected elements. Parameters ---------- data : xr.DataArray or ndarray Element data with shape (..., nelem). mesh : xr.Dataset FESOM mesh dataset with triangles. method : {"mean", "sum"} Aggregation method. Returns ------- xr.DataArray or ndarray Node data with shape (..., npoints). Examples -------- >>> mesh = nr.fesom.load_mesh(path) >>> temp_node = nr.fesom.element_to_node(temp_elem, mesh) """ if "triangles" not in mesh: raise ValueError("Mesh does not contain triangles") triangles = mesh["triangles"].values npoints = mesh.sizes["npoints"] nelem = triangles.shape[0] is_xarray = isinstance(data, xr.DataArray) if is_xarray: data_np = data.values orig_dims = data.dims else: data_np = np.asarray(data) orig_dims = None # Count how many elements each node belongs to node_count = np.zeros(npoints, dtype=np.float64) for tri in triangles: node_count[tri] += 1 # Shape handling for broadcasting spatial_shape = data_np.shape[:-1] # Initialize output node_data = np.zeros(spatial_shape + (npoints,), dtype=data_np.dtype) # Accumulate element values to nodes for i, tri in enumerate(triangles): node_data[..., tri] += data_np[..., i:i + 1] if method == "mean": # Avoid division by zero node_count = np.where(node_count > 0, node_count, 1) node_data = node_data / node_count elif method == "sum": pass else: raise ValueError(f"Unknown method: {method}") if is_xarray: new_dims = orig_dims[:-1] + ("npoints",) result = xr.DataArray( node_data, dims=new_dims, attrs=data.attrs, ) for coord, coord_data in data.coords.items(): if coord not in data.dims[-1:]: result = result.assign_coords({coord: coord_data}) return result else: return node_data
[docs] def mask_by_depth( data: xr.DataArray | NDArray[np.floating], nod_area_nans: xr.DataArray | NDArray[np.floating], ) -> NDArray[np.floating]: """Apply depth mask to data, setting invalid cells to NaN. This function masks ocean data based on bottom topography using ``nod_area_nans`` from the FESOM mesh, which contains NaN for cells below the ocean floor. Parameters ---------- data : xr.DataArray or ndarray Data to mask. Can be: - 2D array (nz, npoints) for 3D fields - 1D array (npoints,) for 2D fields at a single level nod_area_nans : xr.DataArray or ndarray 3D node area array with NaN where cells are below bottom/land. Use ``mesh.nod_area_nans`` or a slice of it (e.g., ``mesh.nod_area_nans[10, :]``). Must match data shape. Returns ------- ndarray Masked data as numpy array with NaN where nod_area_nans is NaN. Returns float64 to support NaN values. Examples -------- >>> mesh = nr.fesom.load_mesh(path) >>> # Mask 3D temperature field >>> temp_masked = nr.fesom.mask_by_depth(temp_3d, mesh.nod_area_nans) >>> >>> # Mask single level (e.g., level 10) >>> temp_lev10_masked = nr.fesom.mask_by_depth( ... temp_3d[10, :], mesh.nod_area_nans[10, :] ... ) Notes ----- The function is dimension-name-agnostic - it works purely by array shape. This avoids issues with inconsistent dimension names across different FESOM output files (nod2, npoints, ncells, etc.). For volume computations, you can use ``nod_area_nans`` directly since it already contains the area values with NaN for invalid cells. """ # Convert to numpy arrays if hasattr(data, "values"): data_np = data.values else: data_np = np.asarray(data) if hasattr(nod_area_nans, "values"): area_np = nod_area_nans.values else: area_np = np.asarray(nod_area_nans) # Validate shapes match if data_np.shape != area_np.shape: raise ValueError( f"Data shape {data_np.shape} does not match nod_area_nans shape {area_np.shape}. " "Ensure you're using the correct depth levels." ) # Create output array (float64 to support NaN) result = data_np.astype(np.float64, copy=True) # Apply mask: set NaN locations in area to NaN in result result[np.isnan(area_np)] = np.nan return result