Source code for nereus.models.fesom.mesh

"""FESOM2 mesh loading.

This module provides functionality for loading FESOM2 meshes as xr.Dataset
objects with standardized variable names.
"""

from __future__ import annotations

import os
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
import xarray as xr
from numpy.typing import NDArray

from nereus.core.coordinates import EARTH_RADIUS, compute_element_centers
from nereus.core.mesh import (
    add_mesh_metadata,
    normalize_lon,
    should_use_dask,
    validate_mesh,
)

if TYPE_CHECKING:
    pass


[docs] def load_mesh( path: str | os.PathLike, *, use_dask: bool | None = None, ) -> xr.Dataset: """Load FESOM mesh from directory or NetCDF file. Parameters ---------- path : str or Path Path to mesh directory (containing nod2d.out, etc.) or to fesom.mesh.diag.nc file. use_dask : bool, optional Whether to use dask arrays. If None, auto-detects based on mesh size (>1M points triggers dask). Returns ------- xr.Dataset Standardized mesh dataset with: - lon, lat: Node coordinates (npoints,) - area: Node cluster area in m^2 (npoints,) - triangles: 0-indexed triangle connectivity (nelem, 3) - lon_tri, lat_tri: Element center coordinates (nelem,) - depth: Layer center depths in meters (nz,) - depth_bounds: Layer interfaces (nz, 2) - layer_thickness: Layer thickness in meters (nz,) Plus original FESOM variables with their native names. Examples -------- >>> mesh = nr.fesom.load_mesh("/path/to/mesh") >>> print(f"Mesh has {mesh.sizes['npoints']} nodes") >>> area = mesh["area"] >>> lon = mesh["lon"].values """ path = Path(path) if _is_netcdf(path): return _load_from_netcdf(path, use_dask=use_dask) else: return _load_from_ascii(path, use_dask=use_dask)
def _is_netcdf(path: Path) -> bool: """Check if path is a netCDF file.""" if path.is_file(): return path.suffix in (".nc", ".nc4") return False def _load_from_netcdf(filepath: Path, use_dask: bool | None = None) -> xr.Dataset: """Load mesh from fesom.mesh.diag.nc file. Parameters ---------- filepath : Path Path to netCDF file. use_dask : bool, optional Whether to use dask arrays. Returns ------- xr.Dataset Standardized mesh dataset. """ # Open with xarray to get dimensions with xr.open_dataset(filepath) as ds_orig: npoints = ds_orig.sizes.get("nod2", ds_orig.sizes.get("nod_n", len(ds_orig["lon"]))) use_dask_actual = should_use_dask(npoints, use_dask) # Reopen with appropriate chunking if use_dask_actual: ds_orig = xr.open_dataset(filepath, chunks={}) else: ds_orig = xr.open_dataset(filepath) # Build standardized dataset ds = xr.Dataset() # --- First, copy ALL original variables with dimension renaming --- # Map original dimensions to standardized/renamed dimensions # nod2 -> npoints (standardized name for node dimension) # nz stays as nz (original 48 levels for interfaces) dim_map = { "nod2": "npoints", "nod_n": "npoints", } # Copy all original variables (preserve dask arrays if enabled) for var_name in ds_orig.data_vars: var = ds_orig[var_name] # Rename dimensions according to map new_dims = tuple(dim_map.get(d, d) for d in var.dims) # Use .data to preserve dask arrays, .values would force computation var_data = var.data if use_dask_actual else var.values ds[var_name] = xr.DataArray( var_data, dims=new_dims, attrs=var.attrs, ) # Copy original coordinates (with dimension renaming) for coord_name in ds_orig.coords: coord = ds_orig[coord_name] new_dims = tuple(dim_map.get(d, d) for d in coord.dims) coord_data = coord.data if use_dask_actual else coord.values ds.coords[coord_name] = xr.DataArray( coord_data, dims=new_dims if new_dims else (coord_name,), attrs=coord.attrs, ) # --- Now add standardized variables --- # Standardized lon/lat (normalized) lon_data = ds_orig["lon"].values if not use_dask_actual else ds_orig["lon"].data lat_data = ds_orig["lat"].values if not use_dask_actual else ds_orig["lat"].data # Normalize longitude to [-180, 180] if not use_dask_actual: lon_data = normalize_lon(lon_data, "pm180") else: import dask.array as da lon_data = da.map_blocks(lambda x: normalize_lon(x, "pm180"), lon_data, dtype=np.float64) # Override lon with normalized version ds["lon"] = xr.DataArray( lon_data, dims=("npoints",), attrs={ "units": "degrees_east", "long_name": "Longitude", "standard_name": "longitude", }, ) ds["lat"] = xr.DataArray( lat_data, dims=("npoints",), attrs={ "units": "degrees_north", "long_name": "Latitude", "standard_name": "latitude", }, ) # --- Standardized area (surface level) and 3D mask --- area_var = None for name in ["nod_area", "cluster_area", "area"]: if name in ds_orig: area_var = name break if area_var: # Use .data to preserve dask arrays if enabled area_raw = ds_orig[area_var].data if use_dask_actual else ds_orig[area_var].values # nod_area may have shape (nz, nod2) - use surface level for area if area_raw.ndim == 2: area_data = area_raw[0, :] # Surface level # Create 3D nod_area with NaN where area == 0 (below bottom/land) if use_dask_actual: import dask.array as da # Use da.where to replace zeros with NaN (lazy operation) nod_area_nans = da.where(area_raw == 0, np.nan, area_raw.astype(np.float64)) else: nod_area_nans = area_raw.astype(np.float64, copy=True) nod_area_nans[nod_area_nans == 0] = np.nan ds["nod_area_nans"] = xr.DataArray( nod_area_nans, dims=("nz", "npoints"), attrs={ "units": "m2", "long_name": "Node cluster area (3D, NaN below bottom)", "comment": "Derived from nod_area with zeros replaced by NaN", }, ) else: area_data = area_raw ds["area"] = xr.DataArray( area_data, dims=("npoints",), attrs={ "units": "m2", "long_name": "Node cluster area", }, ) else: area_data = None # --- Standardized triangles (0-indexed, shape nelem x 3) --- tri_var = None for name in ["face_nodes", "elem", "triangles"]: if name in ds_orig: tri_var = name break if tri_var: tri_data = ds_orig[tri_var].values # Convert from 1-indexed to 0-indexed if needed if tri_data.min() >= 1: tri_data = tri_data - 1 # Ensure shape is (nelem, 3) if tri_data.shape[0] == 3 and tri_data.shape[1] != 3: tri_data = tri_data.T ds["triangles"] = xr.DataArray( tri_data, dims=("nelem", "three"), attrs={ "long_name": "Triangle connectivity (0-indexed)", "cf_role": "face_node_connectivity", "start_index": 0, }, ) # Compute element centers lon_np = ds["lon"].values if not use_dask_actual else ds["lon"].compute().values lat_np = ds["lat"].values if not use_dask_actual else ds["lat"].compute().values lon_tri, lat_tri = compute_element_centers(lon_np, lat_np, tri_data) ds["lon_tri"] = xr.DataArray( lon_tri, dims=("nelem",), attrs={ "units": "degrees_east", "long_name": "Element center longitude", }, ) ds["lat_tri"] = xr.DataArray( lat_tri, dims=("nelem",), attrs={ "units": "degrees_north", "long_name": "Element center latitude", }, ) # Compute area from triangles if not available if area_data is None: area_data = _compute_cluster_area(lon_np, lat_np, tri_data) ds["area"] = xr.DataArray( area_data, dims=("npoints",), attrs={ "units": "m2", "long_name": "Node cluster area (computed)", }, ) # --- Standardized depth levels --- # FESOM uses nz1 for layer centers (47 levels) and nz for layer interfaces (48 levels) # Original nz is kept as-is; standardized depth uses 'depth_level' dimension depth_centers = None depth_interfaces = None if "nz1" in ds_orig.coords: depth_centers = ds_orig["nz1"].values if "nz" in ds_orig.coords: depth_interfaces = ds_orig["nz"].values # Create standardized depth variables with 'depth_level' dimension (layer centers) if depth_centers is not None: ds["depth"] = xr.DataArray( depth_centers, dims=("depth_level",), attrs={ "units": "m", "long_name": "Depth of layer centers", "positive": "down", }, ) # Create depth_bounds and layer_thickness from interfaces if available if depth_interfaces is not None and len(depth_interfaces) > 1: depth_bounds = np.column_stack([ depth_interfaces[:-1], depth_interfaces[1:], ]) ds["depth_bounds"] = xr.DataArray( depth_bounds, dims=("depth_level", "nv"), attrs={ "units": "m", "long_name": "Layer depth bounds", }, ) layer_thickness = np.diff(depth_interfaces) ds["layer_thickness"] = xr.DataArray( layer_thickness, dims=("depth_level",), attrs={ "units": "m", "long_name": "Layer thickness", }, ) elif depth_interfaces is not None and len(depth_interfaces) > 1: # No layer centers, compute from interfaces depth_centers = (depth_interfaces[:-1] + depth_interfaces[1:]) / 2 ds["depth"] = xr.DataArray( depth_centers, dims=("depth_level",), attrs={ "units": "m", "long_name": "Depth of layer centers", "positive": "down", }, ) depth_bounds = np.column_stack([ depth_interfaces[:-1], depth_interfaces[1:], ]) ds["depth_bounds"] = xr.DataArray( depth_bounds, dims=("depth_level", "nv"), attrs={ "units": "m", "long_name": "Layer depth bounds", }, ) layer_thickness = np.diff(depth_interfaces) ds["layer_thickness"] = xr.DataArray( layer_thickness, dims=("depth_level",), attrs={ "units": "m", "long_name": "Layer thickness", }, ) # --- Global attributes --- ds.attrs.update(ds_orig.attrs) ds_orig.close() return add_mesh_metadata(ds, "fesom", filepath, use_dask=use_dask_actual) def _load_from_ascii(mesh_dir: Path, use_dask: bool | None = None) -> xr.Dataset: """Load mesh from ASCII files. Expects mesh directory with: - nod2d.out: Node coordinates - elem2d.out: Triangle connectivity - aux3d.out: Vertical levels (optional) - mesh.diag.nc or fesom.mesh.diag.nc: Area data (optional) Parameters ---------- mesh_dir : Path Path to mesh directory. use_dask : bool, optional Whether to use dask arrays. Returns ------- xr.Dataset Standardized mesh dataset. """ # --- Load node coordinates --- nod2d_file = mesh_dir / "nod2d.out" if not nod2d_file.exists(): # Try netCDF fallback nc_file = mesh_dir / "fesom.mesh.diag.nc" if nc_file.exists(): return _load_from_netcdf(nc_file, use_dask=use_dask) raise FileNotFoundError( f"Could not find mesh files in {mesh_dir}. " "Expected nod2d.out or fesom.mesh.diag.nc" ) with open(nod2d_file) as f: n2d = int(f.readline().strip()) data = np.loadtxt(f, usecols=(1, 2)) lon_data = data[:, 0].astype(np.float64) lat_data = data[:, 1].astype(np.float64) # Normalize longitude lon_data = normalize_lon(lon_data, "pm180") use_dask_actual = should_use_dask(n2d, use_dask) if use_dask_actual: import dask.array as da lon_data = da.from_array(lon_data, chunks=-1) lat_data = da.from_array(lat_data, chunks=-1) ds = xr.Dataset() ds["lon"] = xr.DataArray( lon_data, dims=("npoints",), attrs={ "units": "degrees_east", "long_name": "Longitude", "standard_name": "longitude", }, ) ds["lat"] = xr.DataArray( lat_data, dims=("npoints",), attrs={ "units": "degrees_north", "long_name": "Latitude", "standard_name": "latitude", }, ) # --- Load triangles --- elem_file = mesh_dir / "elem2d.out" tri_data = None if elem_file.exists(): with open(elem_file) as f: n_elem = int(f.readline().strip()) elem_data = np.loadtxt(f, dtype=np.int32) # Convert from 1-indexed to 0-indexed tri_data = elem_data[:, :3] - 1 ds["triangles"] = xr.DataArray( tri_data, dims=("nelem", "three"), attrs={ "long_name": "Triangle connectivity (0-indexed)", "cf_role": "face_node_connectivity", "start_index": 0, }, ) # Compute element centers lon_np = lon_data if not use_dask_actual else lon_data.compute() lat_np = lat_data if not use_dask_actual else lat_data.compute() lon_tri, lat_tri = compute_element_centers(lon_np, lat_np, tri_data) ds["lon_tri"] = xr.DataArray( lon_tri, dims=("nelem",), attrs={ "units": "degrees_east", "long_name": "Element center longitude", }, ) ds["lat_tri"] = xr.DataArray( lat_tri, dims=("nelem",), attrs={ "units": "degrees_north", "long_name": "Element center latitude", }, ) # --- Compute area from triangles --- # Note: ASCII meshes don't have 3D nod_area, so nod_area_nans is not available. # Use NetCDF mesh loader if you need nod_area_nans for depth masking. area_data = None if tri_data is not None: # Compute from triangles lon_np = ds["lon"].values if not use_dask_actual else ds["lon"].compute().values lat_np = ds["lat"].values if not use_dask_actual else ds["lat"].compute().values area_data = _compute_cluster_area(lon_np, lat_np, tri_data) if area_data is None: # Rough approximation earth_area = 4 * np.pi * EARTH_RADIUS**2 area_data = np.full(n2d, earth_area / n2d) if use_dask_actual: import dask.array as da area_data = da.from_array(area_data, chunks=-1) ds["area"] = xr.DataArray( area_data, dims=("npoints",), attrs={ "units": "m2", "long_name": "Node cluster area", }, ) # --- Load vertical levels --- aux3d_file = mesh_dir / "aux3d.out" if aux3d_file.exists(): with open(aux3d_file) as f: nlev = int(f.readline().strip()) depth_interfaces = np.array([float(f.readline().strip()) for _ in range(nlev)]) # Layer centers depth_centers = 0.5 * (depth_interfaces[:-1] + depth_interfaces[1:]) nlevels = len(depth_centers) ds["depth"] = xr.DataArray( depth_centers, dims=("depth_level",), attrs={ "units": "m", "long_name": "Depth of layer centers", "positive": "down", }, ) # Depth bounds depth_bounds = np.column_stack([ depth_interfaces[:-1], depth_interfaces[1:], ]) ds["depth_bounds"] = xr.DataArray( depth_bounds, dims=("depth_level", "nv"), attrs={ "units": "m", "long_name": "Layer depth bounds", }, ) # Layer thickness layer_thickness = np.diff(depth_interfaces) ds["layer_thickness"] = xr.DataArray( layer_thickness, dims=("depth_level",), attrs={ "units": "m", "long_name": "Layer thickness", }, ) return add_mesh_metadata(ds, "fesom", mesh_dir, use_dask=use_dask_actual) def _compute_cluster_area( lon: NDArray[np.floating], lat: NDArray[np.floating], triangles: NDArray[np.integer], ) -> NDArray[np.floating]: """Compute cluster area from triangles. Distributes 1/3 of each triangle's area to its vertices. Parameters ---------- lon : array_like Node longitudes in degrees. lat : array_like Node latitudes in degrees. triangles : array_like Triangle connectivity (nelem, 3), 0-indexed. Returns ------- ndarray Cluster area for each node in m^2. """ n2d = len(lon) area = np.zeros(n2d, dtype=np.float64) for tri in triangles: tri_area = _compute_triangle_area(lon[tri], lat[tri]) area[tri] += tri_area / 3 return area def _compute_triangle_area( lon: NDArray[np.floating], lat: NDArray[np.floating], ) -> float: """Compute approximate area of spherical triangle. Parameters ---------- lon : array_like Longitude of 3 vertices in degrees. lat : array_like Latitude of 3 vertices in degrees. Returns ------- float Triangle area in m^2. """ # Convert to radians lon_rad = np.deg2rad(lon) lat_rad = np.deg2rad(lat) # Convert to Cartesian on unit sphere x = np.cos(lat_rad) * np.cos(lon_rad) y = np.cos(lat_rad) * np.sin(lon_rad) z = np.sin(lat_rad) # Edge vectors v1 = np.array([x[1] - x[0], y[1] - y[0], z[1] - z[0]]) v2 = np.array([x[2] - x[0], y[2] - y[0], z[2] - z[0]]) # Cross product magnitude gives 2 * area on unit sphere cross = np.cross(v1, v2) area = 0.5 * np.linalg.norm(cross) * EARTH_RADIUS**2 return float(area)
[docs] def open_dataset( data_path: str | os.PathLike, mesh: xr.Dataset | None = None, mesh_path: str | os.PathLike | None = None, ) -> xr.Dataset: """Open a FESOM2 data file with mesh information. Parameters ---------- data_path : str or path-like Path to the data file (NetCDF). mesh : xr.Dataset, optional Pre-loaded mesh dataset. If not provided, mesh_path must be specified. mesh_path : str or path-like, optional Path to mesh directory. Ignored if mesh is provided. Returns ------- xr.Dataset Dataset with mesh coordinates attached. Examples -------- >>> mesh = nr.fesom.load_mesh("/meshes/core2") >>> ds = nr.fesom.open_dataset("temp.fesom.2010.nc", mesh=mesh) >>> ds = nr.fesom.open_dataset("temp.fesom.2010.nc", mesh_path="/meshes/core2") """ # Load mesh if not provided if mesh is None: if mesh_path is None: raise ValueError("Either mesh or mesh_path must be provided") mesh = load_mesh(mesh_path) # Open dataset ds = xr.open_dataset(data_path) # Get coordinate arrays lon_data = mesh["lon"].values lat_data = mesh["lat"].values # Add mesh coordinates based on dimension names if "nod2" in ds.dims: ds = ds.assign_coords( lon=("nod2", lon_data), lat=("nod2", lat_data), ) elif "nodes_2d" in ds.dims: ds = ds.assign_coords( lon=("nodes_2d", lon_data), lat=("nodes_2d", lat_data), ) elif "npoints" in ds.dims: ds = ds.assign_coords( lon=("npoints", lon_data), lat=("npoints", lat_data), ) # Add depth coordinates if applicable if "depth" in mesh: depth_data = mesh["depth"].values nlev = len(depth_data) if "nz" in ds.dims and ds.sizes["nz"] == nlev: ds = ds.assign_coords(depth=("nz", depth_data)) elif "nz1" in ds.dims and ds.sizes["nz1"] == nlev: ds = ds.assign_coords(depth=("nz1", depth_data)) return ds