# Source code for finam_vtk.tools

"""VTK helper classes and functions"""

import json
from datetime import datetime, timedelta

import finam as fm
import numpy as np
import pyvista as pv
from cftime import num2pydate
from finam.data.grid_tools import INV_VTK_TYPE_MAP, VTK_CELL_DIM, get_cells_matrix

# Maps pyvista's FieldAssociation enum members to the string labels used
# throughout this module for "point", "cell" and "field" data arrays.
ASSOCIATION = {
    pv.FieldAssociation.POINT: "point",
    pv.FieldAssociation.CELL: "cell",
    pv.FieldAssociation.NONE: "field",
}
"""dict: Field associations map."""

# Ordered from coarsest to finest; get_time_unit relies on this ordering
# (via list.index) when checking unit/step compatibility.
TIME_UNITS = ["days", "hours", "minutes", "seconds"]
"""list: supported time units."""

# One timedelta per supported unit, used to express a step as a multiple
# of the unit (step / TIME_DELTAS[unit]).
TIME_DELTAS = {
    "days": timedelta(days=1),
    "hours": timedelta(hours=1),
    "minutes": timedelta(minutes=1),
    "seconds": timedelta(seconds=1),
}
"""dict: time deltas associated with supported units."""


def _to_pure_datetime(dt):
    return datetime(
        year=dt.year,
        month=dt.month,
        day=dt.day,
        hour=dt.hour,
        minute=dt.minute,
        second=dt.second,
        microsecond=dt.microsecond,
        tzinfo=dt.tzinfo,
        fold=dt.fold,
    )


def get_time_units(reference_date, time_unit):
    """
    Generate CF conform time units.

    Parameters
    ----------
    reference_date : datetime.datetime
        Reference datetime to determine times.
    time_unit : str
        Unit of the timesteps (e.g. "seconds", "hours", "days", ...)

    Returns
    -------
    str
        CF time units.
    """
    stamp = reference_date.isoformat(sep=" ")
    return " since ".join((time_unit, stamp))


def generate_times(time_values, reference_date, time_unit, calendar="standard"):
    """
    Generate datetime list from raw timesteps.

    Parameters
    ----------
    time_values : iterable
        list of timesteps.
    reference_date : datetime.datetime
        Reference datetime to determine times.
    time_unit : str
        Unit of the timesteps (e.g. "seconds", "hours", "days", ...)
    calendar : str, optional
        Describes the calendar used in the time calculations.
        All the values currently defined in the CF metadata convention are supported.
        Valid calendars **'standard', 'gregorian', 'proleptic_gregorian'
        'noleap', '365_day', '360_day', 'julian', 'all_leap', '366_day'**
        by default "standard"

    Returns
    -------
    list of datetime.datetime
        The generated time points.
    """
    cf_units = get_time_units(reference_date, time_unit)
    decoded = num2pydate(time_values, units=cf_units, calendar=calendar)
    # normalize cftime's output to plain datetime objects
    return [_to_pure_datetime(moment) for moment in decoded]


def convert_data(data, masked):
    """
    Convert data to numpy array.

    Parameters
    ----------
    data : arraylike
        The input data.
    masked : bool
        Whether the data should be masked at invalid values.

    Returns
    -------
    numpy.ndarray
        The converted data.
    """
    if masked:
        # mask nan/inf entries instead of keeping them as plain values
        return np.ma.masked_invalid(data)
    return np.asarray(data)


def mesh_type(mesh):
    """
    Return the class name of a pyvista mesh object.

    Parameters
    ----------
    mesh : pyvista.DataSet
        Pyvista pointset or grid.

    Returns
    -------
    str
        Name of the underlying vtk data set class.
    """
    return type(mesh).__name__


def is_unstructured(mesh):
    """
    Whether a pyvista mesh is unstructured.

    Parameters
    ----------
    mesh : pyvista.DataSet
        Pyvista pointset or grid

    Returns
    -------
    bool
        Flag for unstructured meshes.
    """
    # everything that is not one of these structured/point types is
    # treated as unstructured
    structured_types = {"RectilinearGrid", "ImageData", "PointSet"}
    return mesh_type(mesh) not in structured_types


def prepare_unstructured_mesh(mesh, remove_low_dim_cells=True):
    """
    Prepare an unstructured pyvista mesh.

    Parameters
    ----------
    mesh : pyvista.DataSet
        Pyvista pointset or grid
    remove_low_dim_cells : bool, optional
        Whether to remove lower dimension cells (like vertices or edges),
        by default True

    Returns
    -------
    pyvista.DataSet
        The prepared mesh.
    """
    grid = mesh.cast_to_unstructured_grid()
    if not remove_low_dim_cells:
        return grid
    # keep only cell types with the highest spatial dimension present
    types_present = np.unique(grid.celltypes)
    type_dims = VTK_CELL_DIM[types_present]
    keep_types = types_present[type_dims == max(type_dims)]
    if len(keep_types) < len(types_present):
        keep_ids = np.nonzero(np.isin(grid.celltypes, keep_types))[0]
        grid = grid.extract_cells(keep_ids)
    return grid


def grid_from_pyvista(
    mesh,
    return_mesh=False,
    remove_low_dim_cells=True,
    minimal_spatial_dim=True,
    **kwargs,
):
    """
    Generate finam grid from a pyvista object.

    Parameters
    ----------
    mesh : pyvista.DataSet
        Pyvista pointset or grid
    return_mesh : bool, optional
        Whether to return the mesh again. Useful if remove_low_dim_cells is True,
        or mesh was cast to unstructured, by default False
    remove_low_dim_cells : bool, optional
        Whether to remove lower dimension cells (like vertices or edges),
        by default True
    minimal_spatial_dim : bool, optional
        Force grid to have minimal dimension (e.g. 2D grids are defined as 3D in vtk),
        by default True
    **kwargs
        Keyword arguments forwarded to the finam Grid class used.

    Returns
    -------
    finam.Grid
        The resulting finam grid.

    Raises
    ------
    ValueError
        When the grid has unsupported cell types.
    ValueError
        When the mesh type is unsupported by finam.
    """
    mtype = mesh_type(mesh)
    if is_unstructured(mesh):
        mesh = prepare_unstructured_mesh(mesh, remove_low_dim_cells)
        cell_types = INV_VTK_TYPE_MAP[mesh.celltypes]
        # validate cell types BEFORE building the cells matrix, so
        # unsupported types (mapped to -1) raise a clear error instead of
        # corrupting/failing the matrix construction
        t_mask = cell_types == -1
        if np.any(t_mask):
            # bug fix: index the celltypes array (was called like a function)
            t_err = np.unique(mesh.celltypes[t_mask])
            msg = f"finam-VTK: mesh holds unsupported cell types ({t_err})"
            raise ValueError(msg)
        cells = get_cells_matrix(cell_types, mesh.cells)
        points = np.array(mesh.points, copy=True)
        p_count, p_dim = points.shape
        if minimal_spatial_dim and p_count > 1:
            # drop trailing spatial axes that are constantly zero
            while p_dim > 1 and np.all(np.isclose(points[:, p_dim - 1], 0)):
                p_dim -= 1
        grid = fm.UnstructuredGrid(
            points=points[:, :p_dim], cells=cells, cell_types=cell_types, **kwargs
        )
    elif mtype == "ImageData":
        grid = fm.UniformGrid(
            dims=mesh.dimensions, spacing=mesh.spacing, origin=mesh.origin, **kwargs
        )
    elif mtype == "RectilinearGrid":
        grid = fm.RectilinearGrid(axes=[mesh.x, mesh.y, mesh.z], **kwargs)
    elif mtype == "PointSet":
        grid = fm.UnstructuredPoints(points=mesh.points, **kwargs)
    else:
        msg = f"finam-VTK: Unknown mesh type: {mtype}"
        raise ValueError(msg)
    if return_mesh:
        return grid, mesh
    return grid


def read_vtk_grid(path, remove_low_dim_cells=True, minimal_spatial_dim=True, **kwargs):
    """
    Read a finam grid from a VTK file.

    Parameters
    ----------
    path : pathlike
        Path to the vtk file.
    remove_low_dim_cells : bool, optional
        Whether to remove lower dimension cells (like vertices or edges),
        by default True
    minimal_spatial_dim : bool, optional
        Force grid to have minimal dimension (e.g. 2D grids are defined as 3D in vtk),
        by default True
    **kwargs
        Keyword arguments forwarded to the finam Grid class used.

    Returns
    -------
    finam.Grid
        The resulting finam grid.
    """
    mesh = pv.read(path)
    return grid_from_pyvista(
        mesh=mesh,
        remove_low_dim_cells=remove_low_dim_cells,
        minimal_spatial_dim=minimal_spatial_dim,
        **kwargs,
    )
def _grids_compatible(grid, ref_grid): """ Check grid compatibility while tolerating shape-comparison errors. Parameters ---------- grid : finam.Grid Grid to compare. ref_grid : finam.Grid Reference grid to compare against. Returns ------- bool Whether both grids are compatible. """ try: return grid.compatible_with(ref_grid, check_location=False) except ValueError: return False def _get_reference_grid(in_infos, data_arrays): """ Get the reference grid for a set of writer inputs. Parameters ---------- in_infos : dict Input infos keyed by variable name. data_arrays : list of DataArray Writer data arrays. Returns ------- finam.Grid The first non-field grid, or a dummy grid for field-only outputs. """ for var in data_arrays: ref_grid = in_infos[var.name].grid if not isinstance(ref_grid, fm.NoGrid): return ref_grid return fm.UniformGrid(dims=(0, 0, 0)) def _create_pv_mesh(ref_grid, legacy=False): """ Create a PyVista mesh from a FINAM grid. Parameters ---------- ref_grid : finam.Grid Reference FINAM grid. legacy : bool, optional Whether to use the legacy VTK file extension, by default False. Returns ------- tuple Tuple of ``(pv_mesh, is_structured, file_ext)``. 
""" is_structured = isinstance(ref_grid, fm.data.StructuredGrid) if isinstance(ref_grid, fm.UniformGrid): dim = ref_grid.dim dimensions = ref_grid.dims + (1,) * (3 - dim) origin = ref_grid.origin + (0.0,) * (3 - dim) spacing = ref_grid.spacing + (0.0,) * (3 - dim) pv_mesh = pv.ImageData(dimensions=dimensions, spacing=spacing, origin=origin) file_ext = ".vti" elif isinstance(ref_grid, fm.RectilinearGrid): dim = ref_grid.dim axes = ref_grid.axes x = np.ascontiguousarray(axes[0]) y = np.ascontiguousarray(axes[1] if dim > 1 else np.array([0.0])) z = np.ascontiguousarray(axes[2] if dim > 2 else np.array([0.0])) pv_mesh = pv.RectilinearGrid(x, y, z) file_ext = ".vtr" else: points = np.zeros((ref_grid.point_count, 3), dtype=float) points[:, : ref_grid.dim] = ref_grid.points cells = ref_grid.cells_definition cell_types = fm.data.grid_tools.VTK_TYPE_MAP[ref_grid.cell_types] pv_mesh = pv.UnstructuredGrid(cells, cell_types, points) file_ext = ".vtu" if legacy: file_ext = ".vtk" return pv_mesh, is_structured, file_ext def _prepare_writer_arrays(data_arrays, in_infos, ref_grid, writer_name): """ Attach grids and validate array compatibility for a writer. Parameters ---------- data_arrays : list of DataArray Writer data arrays. in_infos : dict Input infos keyed by variable name. ref_grid : finam.Grid Reference grid used for the output. writer_name : str Name used in compatibility error messages. """ for var in data_arrays: grid = in_infos[var.name].grid var.info_kwargs["grid"] = grid var.set_association_by_grid() if not var.association == "field" and not _grids_compatible(grid, ref_grid): raise ValueError(f"{writer_name}: inputs have incompatible grids.") def _set_mesh_data(pv_mesh, data_arrays, data, is_structured): """ Set field, point, or cell data on a PyVista mesh. Parameters ---------- pv_mesh : pyvista.DataSet Target mesh. data_arrays : list of DataArray Writer data arrays. data : dict Input data keyed by variable name. 
is_structured : bool Whether the mesh is structured and needs canonical reshaping. """ for var in data_arrays: grid = var.get_grid() out = fm.data.strip_time(data[var.name], grid).magnitude if not var.association == "field" and is_structured: out = grid.to_canonical(out).reshape(-1, order="F") if var.association == "field": pv_mesh.field_data[var.name] = out elif var.association == "cell": pv_mesh.cell_data[var.name] = out else: pv_mesh.point_data[var.name] = out def _timestep_path(file_path, file_name, step_counter, file_digits, file_ext): """ Build the output path for a timestep VTK file. Parameters ---------- file_path : pathlib.Path Base file path without the numeric suffix. file_name : str File stem used for the output files. step_counter : int Current timestep index. file_digits : int Width of the numeric suffix. file_ext : str Output file extension. Returns ------- pathlib.Path The path for the timestep file. """ return file_path.with_name(f"{file_name}{step_counter:0{file_digits}}").with_suffix( file_ext )
class DataArray:
    """
    Specifications for a VTK data array.

    Parameters
    ----------
    name : str
        Data array name in the VTK file.
    association : str, optional
        Indicate how data is associated ("point", "cell", or "field").
        Can be None to be determined.
    **info_kwargs
        Optional keyword arguments to instantiate an Info object
        (i.e. 'grid' and 'meta').
        Used to overwrite meta data, to change units or to provide a
        desired grid specification.
    """

    def __init__(self, name, association=None, **info_kwargs):
        self.name = name
        self.association = association
        self.info_kwargs = info_kwargs

    def get_meta(self):
        """Get the meta-data dictionary of this data array."""
        # copy the stored 'meta' dict; updating it in place would silently
        # alter self.info_kwargs["meta"] on every call (bug fix)
        meta = dict(self.info_kwargs.get("meta", {}))
        meta.update(
            {
                k: v
                for k, v in self.info_kwargs.items()
                if k not in ["time", "grid", "meta"]
            }
        )
        return meta

    def get_grid(self):
        """Get the grid of this data array."""
        return self.info_kwargs.get("grid", None)

    def set_association_by_grid(self, grid=None):
        """Set the data association by a grid specification."""
        grid = self.get_grid() if grid is None else grid
        if grid is None:
            raise ValueError("DataArray: no grid specified to determine association.")
        if isinstance(grid, fm.NoGrid):
            self.association = "field"
        elif grid.data_location == fm.Location.CELLS:
            self.association = "cell"
        else:
            self.association = "point"

    def __repr__(self):
        name, association = self.name, self.association
        return f"DataArray({name=}, {association=}, **{self.info_kwargs})"
def create_data_array_list(data_arrays): """ Create a list of DataArray instances. Parameters ---------- data_arrays : list of str or DataArray List containing DataArray instances or names. Returns ------- list of DataArray List containing only DataArray instances. """ return [ arr if isinstance(arr, DataArray) else DataArray(arr) for arr in data_arrays ] def extract_data_arrays(mesh, data_arrays=None): """ Extract the data_array information from a pyvista mesh. Parameters ---------- mesh : pyvista.DataSet Pyvista mesh. data_arrays : list of DataArray or str, optional List of desired data_arrays given by name or a :class:`DataArray` instance. By default, all data_arrays present in the mesh. Returns ------- data_arrays : list of DataArray Data array information. """ if data_arrays is None: data_arrays = [] for name in mesh.field_data: data_arrays.append(DataArray(name=name, association="field")) for name in mesh.point_data: data_arrays.append(DataArray(name=name, association="point")) for name in mesh.cell_data: data_arrays.append(DataArray(name=name, association="cell")) else: data_arrays = create_data_array_list(data_arrays) for var in data_arrays: association = ASSOCIATION[mesh.get_array_association(var.name)] if var.association is None: var.association = association elif var.association != association: msg = ( f"{var.name}: data is associated with '{association}', " f"but '{var.association}' was given." ) raise ValueError(msg) return data_arrays def needs_masking(data): """ Whether data has non-finite values (infs, nans). Parameters ---------- data : arraylike Data to check Returns ------- bool Finity status. """ return np.any(~(np.isfinite(data))) def _is_int(value, **kwargs): return np.isclose(value, np.around(value), **kwargs) def get_time_unit(step, time_unit=None): """ Determine time units from time step delta. Parameters ---------- step : timedelta The desired time step time_unit : str or None, optional The desired time unit. 
If not given, will be determined. By default : None Returns ------- str The time unit (days, hours, minutes, seconds) Raises ------ ValueError if provided step is not compatible with supported units ValueError if provided time unit is not supported ValueError if provided time unit is not compatible with provided step """ ref_unit = None for unit in TIME_UNITS: if _is_int(step / TIME_DELTAS[unit]): ref_unit = unit break if ref_unit is None: raise ValueError("VTK: time step is not compatible with supported units.") if time_unit is not None: if time_unit not in TIME_UNITS: raise ValueError(f"VTK: time unit '{time_unit}' not supported.") if TIME_UNITS.index(time_unit) < TIME_UNITS.index(ref_unit): raise ValueError( f"VTK: time unit '{time_unit}' not compatible with time step." ) return time_unit return ref_unit class _DateTimeEncoder(json.JSONEncoder): def default(self, o): if isinstance(o, datetime): return o.isoformat() return super().default(o) def save_dict_to_json(file_path, **data_dict): """ Save a dictionary with datetime objects to a JSON file. Parameters ---------- file_path : pathlike The path to the JSON file. data_dict : dict The dictionary to save. """ with open(file_path, "w", encoding="utf-8") as json_file: json.dump(data_dict, json_file, cls=_DateTimeEncoder) def _datetime_parser(dct): for key, value in dct.items(): if key == "reference_date": try: dct[key] = datetime.fromisoformat(value) except (ValueError, TypeError): pass return dct
def read_aux_file(file_path):
    """
    Load a dictionary from a JSON file and convert 'reference_date'
    strings back to datetime objects.

    Parameters
    ----------
    file_path : pathlike
        The path to the JSON file.

    Returns
    -------
    dict
        The loaded dictionary with datetime objects for 'reference_date'.
    """
    with open(file_path, "r", encoding="utf-8") as json_file:
        return json.load(json_file, object_hook=_datetime_parser)
def write_pvd_file(path, vtk_files, time_steps, time_units=None):
    """
    Generates a PVD file linking VTK files with their corresponding time steps.

    Parameters
    ----------
    path : pathlike
        Name of the PVD file to be created.
    vtk_files : list of str
        List of paths to VTK files.
    time_steps : list of float
        List of time steps corresponding to each VTK file.
    time_units : str, optional
        The units of measurement for time steps added as attribute to the collection.
        Follows CF convention and should be like "days since 2023-09-20 00:00:00".
        Default is None.
    """
    if len(vtk_files) != len(time_steps):
        raise ValueError("The number of VTK files must match the number of time steps.")
    collection_open = (
        f'<Collection time_units="{time_units}">\n' if time_units else "<Collection>\n"
    )
    datasets = [
        f'  <DataSet timestep="{time}" file="{file_name}" />\n'
        for file_name, time in zip(vtk_files, time_steps)
    ]
    parts = [
        '<?xml version="1.0"?>\n',
        '<VTKFile type="Collection" version="0.1">\n',
        collection_open,
        *datasets,
        "</Collection>\n</VTKFile>\n",
    ]
    with open(path, "w", encoding="utf-8") as f:
        f.write("".join(parts))