Source code for nereus.models.mitgcm.mesh

"""MITgcm mesh loading.

Loads MITgcm grid files (``XC.data``, ``YC.data``, etc.) and produces
a standardized nereus ``xr.Dataset`` with flattened 1D coordinates.
"""

from __future__ import annotations

import os
import warnings
from pathlib import Path

import numpy as np
import xarray as xr
from numpy.typing import NDArray

from nereus.core.coordinates import EARTH_RADIUS
from nereus.core.mesh import (
    add_mesh_metadata,
    normalize_lon,
    should_use_dask,
)
from nereus.models.mitgcm.io import read_mds


def load_mesh(
    path: str | os.PathLike,
    *,
    grid_type: str = "C",
    use_dask: bool | None = None,
    mask_land: bool = False,
) -> xr.Dataset:
    """Load MITgcm mesh from a directory of MDS grid files.

    Parameters
    ----------
    path : str or Path
        Directory containing MITgcm grid files (``XC.data``, ``YC.data``,
        etc.).
    grid_type : {"C", "G"}
        Which grid point to use:

        - ``"C"``: Cell centers (reads ``XC``/``YC``).
        - ``"G"``: Cell corners (reads ``XG``/``YG``).
    use_dask : bool, optional
        Whether to use dask arrays. Auto-detects if None.
    mask_land : bool
        If True, derive land/ocean masks from ``hFacC`` (3D, per-level) or
        fall back to ``Depth == 0`` (surface only). The masks are stored as
        ``land_mask`` (1D boolean on ``npoints``, True = land) and, when
        ``hFacC`` is available, ``hFacC`` (float,
        ``(depth_level, npoints)``). If neither file is present a
        ``UserWarning`` is emitted and no mask is added. Default is False
        to preserve raw data access.

    Returns
    -------
    xr.Dataset
        Mesh dataset with ``lon``, ``lat``, ``area`` on the ``npoints``
        dimension. Optionally includes ``depth``, ``layer_thickness`` and
        ``bathymetry`` when the corresponding grid files exist. When
        ``mask_land=True``, also includes ``land_mask`` and (if available)
        ``hFacC``. Attributes include ``nx``, ``ny``, ``original_shape``.

    Raises
    ------
    ValueError
        If ``grid_type`` is neither ``"C"`` nor ``"G"``.

    Examples
    --------
    >>> mesh = nr.mitgcm.load_mesh("/path/to/run/")
    >>> mesh_masked = nr.mitgcm.load_mesh("/path/to/run/", mask_land=True)
    >>> # Use mask to filter data: ds["THETA"].where(~mesh["land_mask"])
    """
    path = Path(path)

    # Select coordinate file prefixes based on grid type
    if grid_type == "C":
        lon_prefix, lat_prefix = "XC", "YC"
    elif grid_type == "G":
        lon_prefix, lat_prefix = "XG", "YG"
    else:
        raise ValueError(f"Unknown grid_type: {grid_type!r}. Use 'C' or 'G'.")

    # Read 2D coordinate arrays
    _, lon_2d = read_mds(path / lon_prefix)
    _, lat_2d = read_mds(path / lat_prefix)
    lon_2d = lon_2d.astype(np.float64)
    lat_2d = lat_2d.astype(np.float64)
    ny, nx = lon_2d.shape

    # Read cell areas, or estimate them from the coordinates when RAC is
    # absent.
    # NOTE(review): RAC is read for both grid types; for grid_type="G" a
    # corner-point area file may be the intended source — confirm.
    if (path / "RAC.meta").exists():
        _, area_2d = read_mds(path / "RAC")
        area_flat = area_2d.ravel().astype(np.float64)
    else:
        area_flat = _estimate_structured_area(lon_2d, lat_2d)

    # Flatten and normalize (lat_2d is already float64 at this point)
    lon_flat = normalize_lon(lon_2d.ravel(), "pm180")
    lat_flat = lat_2d.ravel()
    npoints = len(lon_flat)

    use_dask_actual = should_use_dask(npoints, use_dask)
    if use_dask_actual:
        # Imported lazily so dask stays optional for eager use.
        import dask.array as da

        def _wrap(values):
            return da.from_array(values, chunks=-1)
    else:
        def _wrap(values):
            return values

    ds = xr.Dataset(
        {
            "lon": (
                ("npoints",),
                _wrap(lon_flat),
                {
                    "units": "degrees_east",
                    "long_name": "Longitude",
                    "standard_name": "longitude",
                },
            ),
            "lat": (
                ("npoints",),
                _wrap(lat_flat),
                {
                    "units": "degrees_north",
                    "long_name": "Latitude",
                    "standard_name": "latitude",
                },
            ),
            "area": (
                ("npoints",),
                _wrap(area_flat),
                {
                    "units": "m2",
                    "long_name": "Cell area",
                },
            ),
        },
        attrs={
            "nx": nx,
            "ny": ny,
            "original_shape": (ny, nx),
        },
    )

    # Depth levels (RC.data = center depths, DRF.data = layer thicknesses)
    if (path / "RC.meta").exists():
        _, rc = read_mds(path / "RC")
        # Stored as absolute values so that depth is positive downward.
        depth = np.abs(rc.ravel()).astype(np.float64)
        ds["depth"] = xr.DataArray(
            depth,
            dims=("depth_level",),
            attrs={
                "units": "m",
                "long_name": "Depth of layer centers",
                "positive": "down",
            },
        )

    if (path / "DRF.meta").exists():
        _, drf = read_mds(path / "DRF")
        layer_thickness = drf.ravel().astype(np.float64)
        ds["layer_thickness"] = xr.DataArray(
            layer_thickness,
            dims=("depth_level",),
            attrs={
                "units": "m",
                "long_name": "Layer thickness",
            },
        )

    # Bathymetry
    if (path / "Depth.meta").exists():
        _, bathy = read_mds(path / "Depth")
        bathy_flat = _wrap(bathy.ravel().astype(np.float64))
        ds["bathymetry"] = xr.DataArray(
            bathy_flat,
            dims=("npoints",),
            attrs={
                "units": "m",
                "long_name": "Bathymetry",
                "positive": "down",
            },
        )

    # Land/ocean masking
    if mask_land:
        if (path / "hFacC.meta").exists():
            _, hfac_3d = read_mds(path / "hFacC")
            hfac_3d = hfac_3d.astype(np.float64)
            if hfac_3d.ndim == 2:
                # A single-level grid may come back without a leading
                # level axis; add one so that indexing level 0 selects the
                # whole horizontal slab rather than its first row.
                hfac_3d = hfac_3d[np.newaxis, ...]
            nz = hfac_3d.shape[0]

            # Surface mask: True = land (no open fraction in the top level)
            land_mask = hfac_3d[0].ravel() == 0
            # hFacC flattened to (depth_level, npoints)
            hfac_flat = hfac_3d.reshape(nz, -1)

            ds["land_mask"] = xr.DataArray(
                land_mask,
                dims=("npoints",),
                attrs={
                    "long_name": "Land mask (True = land)",
                    "source": "hFacC",
                },
            )
            ds["hFacC"] = xr.DataArray(
                hfac_flat,
                dims=("depth_level", "npoints"),
                attrs={
                    "long_name": "Fraction of open cell at tracer point",
                    "source": "hFacC.data",
                    "comment": "0 = land, 1 = fully open ocean, "
                    "between 0 and 1 = partial cell",
                },
            )
        elif (path / "Depth.meta").exists():
            # Fallback: derive a surface-only mask from bathymetry
            if "bathymetry" not in ds:
                _, bathy = read_mds(path / "Depth")
                bathy_vals = bathy.ravel().astype(np.float64)
            else:
                bathy_vals = ds["bathymetry"].values
            land_mask = bathy_vals == 0
            ds["land_mask"] = xr.DataArray(
                land_mask,
                dims=("npoints",),
                attrs={
                    "long_name": "Land mask (True = land)",
                    "source": "Depth",
                    "comment": "Derived from Depth == 0; "
                    "2D only, no per-level masking",
                },
            )
        else:
            # Previously a silent no-op: callers asking for a mask would
            # only find out when ds["land_mask"] raised a KeyError.
            warnings.warn(
                f"mask_land=True but neither 'hFacC' nor 'Depth' grid "
                f"files were found in {path}; no land_mask was added.",
                UserWarning,
                stacklevel=2,
            )

    return add_mesh_metadata(ds, "mitgcm", path, use_dask=use_dask_actual)
def _estimate_structured_area(
    lon_2d: NDArray[np.floating],
    lat_2d: NDArray[np.floating],
) -> NDArray[np.floating]:
    """Estimate cell areas from coordinate arrays using finite differences.

    Each cell is approximated as ``R^2 * cos(lat) * dlon * dlat``, with the
    angular spacings taken from central differences in the interior and
    one-sided differences on the array edges.

    Parameters
    ----------
    lon_2d : ndarray
        2D longitude array (ny, nx), in degrees.
    lat_2d : ndarray
        2D latitude array (ny, nx), in degrees.

    Returns
    -------
    ndarray
        Flattened cell areas in m^2.

    Raises
    ------
    ValueError
        If the inputs are not 2D arrays of identical shape with at least
        two points along each axis (no spacing can be derived otherwise).
    """
    lon_2d = np.asarray(lon_2d, dtype=np.float64)
    lat_2d = np.asarray(lat_2d, dtype=np.float64)
    if lon_2d.ndim != 2 or lat_2d.shape != lon_2d.shape:
        raise ValueError(
            "lon_2d and lat_2d must be 2D arrays of identical shape, got "
            f"{lon_2d.shape} and {lat_2d.shape}"
        )
    if min(lon_2d.shape) < 2:
        raise ValueError(
            "At least two points per axis are required to estimate cell "
            f"areas, got shape {lon_2d.shape}"
        )

    # dx: spacing in longitude direction.
    # Unwrap each single-cell step across the +/-180 (or 0/360) seam
    # *before* averaging. Halving a two-cell span first would turn a
    # ~358 degree jump into ~179, which slips under the 180 degree
    # threshold and is never corrected.
    lon_step = np.diff(lon_2d, axis=1)
    lon_step = np.where(
        np.abs(lon_step) > 180,
        lon_step - np.sign(lon_step) * 360,
        lon_step,
    )
    dlon = np.empty_like(lon_2d)
    dlon[:, 0] = lon_step[:, 0]
    dlon[:, -1] = lon_step[:, -1]
    dlon[:, 1:-1] = (lon_step[:, :-1] + lon_step[:, 1:]) / 2
    dlon = np.abs(dlon)

    # dy: spacing in latitude direction (no periodic seam to handle)
    dlat = np.zeros_like(lat_2d)
    dlat[1:-1, :] = (lat_2d[2:, :] - lat_2d[:-2, :]) / 2
    dlat[0, :] = lat_2d[1, :] - lat_2d[0, :]
    dlat[-1, :] = lat_2d[-1, :] - lat_2d[-2, :]
    dlat = np.abs(dlat)

    dlon_rad = np.deg2rad(dlon)
    dlat_rad = np.deg2rad(dlat)
    lat_rad = np.deg2rad(lat_2d)

    area_2d = EARTH_RADIUS**2 * np.cos(lat_rad) * dlon_rad * dlat_rad
    return np.abs(area_2d).ravel()