"""VTK helper classes and functions"""
import json
from datetime import datetime, timedelta
import finam as fm
import numpy as np
import pyvista as pv
from cftime import num2pydate
from finam.data.grid_tools import INV_VTK_TYPE_MAP, VTK_CELL_DIM, get_cells_matrix
# Map pyvista field-association enums to the string names used throughout
# this module ("point", "cell" or "field" data).
ASSOCIATION = {
    pv.FieldAssociation.POINT: "point",
    pv.FieldAssociation.CELL: "cell",
    pv.FieldAssociation.NONE: "field",
}
"""dict: Field associations map."""

# Units accepted by get_time_unit, ordered from coarsest to finest;
# the ordering is significant (index comparison in get_time_unit).
TIME_UNITS = ["days", "hours", "minutes", "seconds"]
"""list: supported time units."""

# One timedelta per supported unit, used to check step/unit compatibility.
TIME_DELTAS = {
    "days": timedelta(days=1),
    "hours": timedelta(hours=1),
    "minutes": timedelta(minutes=1),
    "seconds": timedelta(seconds=1),
}
"""dict: time deltas associated with supported units."""
def _to_pure_datetime(dt):
return datetime(
year=dt.year,
month=dt.month,
day=dt.day,
hour=dt.hour,
minute=dt.minute,
second=dt.second,
microsecond=dt.microsecond,
tzinfo=dt.tzinfo,
fold=dt.fold,
)
def get_time_units(reference_date, time_unit):
    """
    Generate CF conform time units.

    Parameters
    ----------
    reference_date : datetime.datetime
        Reference datetime to determine times.
    time_unit : str
        Unit of the timesteps (e.g. "seconds", "hours", "days", ...)

    Returns
    -------
    str
        CF time units ("<unit> since <iso-date>").
    """
    reference = reference_date.isoformat(sep=" ")
    return "{} since {}".format(time_unit, reference)
def generate_times(time_values, reference_date, time_unit, calendar="standard"):
    """
    Generate datetime list from raw timesteps.

    Parameters
    ----------
    time_values : iterable
        list of timesteps.
    reference_date : datetime.datetime
        Reference datetime to determine times.
    time_unit : str
        Unit of the timesteps (e.g. "seconds", "hours", "days", ...)
    calendar : str, optional
        Describes the calendar used in the time calculations.
        All the values currently defined in the CF metadata convention are supported.
        Valid calendars **'standard', 'gregorian', 'proleptic_gregorian'
        'noleap', '365_day', '360_day', 'julian', 'all_leap', '366_day'**
        by default "standard"

    Returns
    -------
    list of datetime.datetime
        The generated time points.
    """
    cf_units = get_time_units(reference_date, time_unit)
    raw_times = num2pydate(time_values, units=cf_units, calendar=calendar)
    # normalize cftime results to plain datetime objects
    return [_to_pure_datetime(t) for t in raw_times]
def convert_data(data, masked):
    """
    Convert data to numpy array.

    Parameters
    ----------
    data : arraylike
        The input data.
    masked : bool
        Whether the data should be masked at invalid values (inf, nan).

    Returns
    -------
    numpy.ndarray
        The converted data (a masked array when ``masked`` is True).
    """
    if masked:
        return np.ma.masked_invalid(data)
    return np.asarray(data)
def mesh_type(mesh):
    """
    Determine the type of a pyvista mesh object.

    Parameters
    ----------
    mesh : pyvista.DataSet
        Pyvista pointset or grid

    Returns
    -------
    str
        Class name of the vtk data set.
    """
    return type(mesh).__name__
def is_unstructured(mesh):
    """
    Whether a pyvista mesh is unstructured.

    Parameters
    ----------
    mesh : pyvista.DataSet
        Pyvista pointset or grid

    Returns
    -------
    bool
        Flag for unstructured meshes.
    """
    # everything that is not one of the known structured/point types
    # is treated as unstructured
    structured_names = ("RectilinearGrid", "ImageData", "PointSet")
    return type(mesh).__name__ not in structured_names
def prepare_unstructured_mesh(mesh, remove_low_dim_cells=True):
    """
    Prepare an unstructured pyvista mesh.

    Parameters
    ----------
    mesh : pyvista.DataSet
        Pyvista pointset or grid
    remove_low_dim_cells : bool, optional
        Whether to remove lower dimension cells (like vertices or edges),
        by default True

    Returns
    -------
    pyvista.DataSet
        The prepared mesh.
    """
    mesh = mesh.cast_to_unstructured_grid()
    if not remove_low_dim_cells:
        return mesh
    present_types = np.unique(mesh.celltypes)
    dims = VTK_CELL_DIM[present_types]
    # keep only cell types of the highest spatial dimension present
    keep_types = present_types[dims == dims.max()]
    if len(keep_types) < len(present_types):
        keep_ids = np.nonzero(np.isin(mesh.celltypes, keep_types))[0]
        mesh = mesh.extract_cells(keep_ids)
    return mesh
def grid_from_pyvista(
    mesh,
    return_mesh=False,
    remove_low_dim_cells=True,
    minimal_spatial_dim=True,
    **kwargs,
):
    """
    Generate finam grid from a pyvista object.

    Parameters
    ----------
    mesh : pyvista.DataSet
        Pyvista pointset or grid
    return_mesh : bool, optional
        Whether to return the mesh again. Useful if remove_low_dim_cells is True,
        or mesh was cast to unstructured, by default False
    remove_low_dim_cells : bool, optional
        Whether to remove lower dimension cells (like vertices or edges),
        by default True
    minimal_spatial_dim : bool, optional
        Force grid to have minimal dimension (e.g. 2D grids are defined as 3D in vtk),
        by default True
    **kwargs
        Keyword arguments forwarded to the finam Grid class used.

    Returns
    -------
    finam.Grid or tuple of (finam.Grid, pyvista.DataSet)
        The resulting finam grid, plus the (possibly modified) mesh
        when ``return_mesh`` is True.

    Raises
    ------
    ValueError
        When the grid has unsupported cell types.
    ValueError
        When the mesh type is unsupported by finam.
    """
    mtype = mesh_type(mesh)
    if is_unstructured(mesh):
        mesh = prepare_unstructured_mesh(mesh, remove_low_dim_cells)
        cell_types = INV_VTK_TYPE_MAP[mesh.celltypes]
        # -1 marks VTK cell types without a finam equivalent; validate before
        # building the cell matrix so unsupported meshes fail with the intended
        # message instead of a downstream error.
        if np.any(t_mask := cell_types == -1):
            # BUG FIX: index the celltypes array (was "mesh.celltypes(t_mask)",
            # calling a numpy array and raising TypeError)
            t_err = np.unique(mesh.celltypes[t_mask])
            msg = f"finam-VTK: mesh holds unsupported cell types ({t_err})"
            raise ValueError(msg)
        cells = get_cells_matrix(cell_types, mesh.cells)
        points = np.array(mesh.points, copy=True)
        p_count, p_dim = points.shape
        # vtk always stores 3D points; strip trailing axes that are all zero
        if minimal_spatial_dim and p_count > 1:
            while p_dim > 1 and np.all(np.isclose(points[:, p_dim - 1], 0)):
                p_dim -= 1
        grid = fm.UnstructuredGrid(
            points=points[:, :p_dim], cells=cells, cell_types=cell_types, **kwargs
        )
    elif mtype == "ImageData":
        grid = fm.UniformGrid(
            dims=mesh.dimensions, spacing=mesh.spacing, origin=mesh.origin, **kwargs
        )
    elif mtype == "RectilinearGrid":
        grid = fm.RectilinearGrid(axes=[mesh.x, mesh.y, mesh.z], **kwargs)
    elif mtype == "PointSet":
        grid = fm.UnstructuredPoints(points=mesh.points, **kwargs)
    else:
        msg = f"finam-VTK: Unknown mesh type: {mtype}"
        raise ValueError(msg)
    if return_mesh:
        return grid, mesh
    return grid
def read_vtk_grid(path, remove_low_dim_cells=True, minimal_spatial_dim=True, **kwargs):
    """
    Read a finam grid from a VTK file.

    Parameters
    ----------
    path : pathlike
        Path to the vtk file.
    remove_low_dim_cells : bool, optional
        Whether to remove lower dimension cells (like vertices or edges),
        by default True
    minimal_spatial_dim : bool, optional
        Force grid to have minimal dimension (e.g. 2D grids are defined as 3D in vtk),
        by default True
    **kwargs
        Keyword arguments forwarded to the finam Grid class used.

    Returns
    -------
    finam.Grid
        The resulting finam grid.
    """
    # note: the stray "[docs]" documentation-scrape artifact that preceded
    # this definition was removed (it was a syntax error)
    return grid_from_pyvista(
        mesh=pv.read(path),
        remove_low_dim_cells=remove_low_dim_cells,
        minimal_spatial_dim=minimal_spatial_dim,
        **kwargs,
    )
def _grids_compatible(grid, ref_grid):
"""
Check grid compatibility while tolerating shape-comparison errors.
Parameters
----------
grid : finam.Grid
Grid to compare.
ref_grid : finam.Grid
Reference grid to compare against.
Returns
-------
bool
Whether both grids are compatible.
"""
try:
return grid.compatible_with(ref_grid, check_location=False)
except ValueError:
return False
def _get_reference_grid(in_infos, data_arrays):
    """
    Get the reference grid for a set of writer inputs.

    Parameters
    ----------
    in_infos : dict
        Input infos keyed by variable name.
    data_arrays : list of DataArray
        Writer data arrays.

    Returns
    -------
    finam.Grid
        The first non-field grid, or a dummy grid for field-only outputs.
    """
    candidates = (in_infos[var.name].grid for var in data_arrays)
    for candidate in candidates:
        if not isinstance(candidate, fm.NoGrid):
            return candidate
    # field-only output: fall back to a zero-sized dummy grid
    return fm.UniformGrid(dims=(0, 0, 0))
def _create_pv_mesh(ref_grid, legacy=False):
    """
    Create a PyVista mesh from a FINAM grid.

    Parameters
    ----------
    ref_grid : finam.Grid
        Reference FINAM grid.
    legacy : bool, optional
        Whether to use the legacy VTK file extension, by default False.

    Returns
    -------
    tuple
        Tuple of ``(pv_mesh, is_structured, file_ext)``.
    """
    is_structured = isinstance(ref_grid, fm.data.StructuredGrid)
    if isinstance(ref_grid, fm.UniformGrid):
        # pad geometry to 3D (vtk image data is always 3D)
        pad = 3 - ref_grid.dim
        pv_mesh = pv.ImageData(
            dimensions=ref_grid.dims + (1,) * pad,
            spacing=ref_grid.spacing + (0.0,) * pad,
            origin=ref_grid.origin + (0.0,) * pad,
        )
        file_ext = ".vti"
    elif isinstance(ref_grid, fm.RectilinearGrid):
        # pad missing axes with a single zero coordinate
        axes = list(ref_grid.axes) + [np.array([0.0])] * (3 - ref_grid.dim)
        x, y, z = (np.ascontiguousarray(ax) for ax in axes[:3])
        pv_mesh = pv.RectilinearGrid(x, y, z)
        file_ext = ".vtr"
    else:
        points = np.zeros((ref_grid.point_count, 3), dtype=float)
        points[:, : ref_grid.dim] = ref_grid.points
        vtk_types = fm.data.grid_tools.VTK_TYPE_MAP[ref_grid.cell_types]
        pv_mesh = pv.UnstructuredGrid(ref_grid.cells_definition, vtk_types, points)
        file_ext = ".vtu"
    return pv_mesh, is_structured, ".vtk" if legacy else file_ext
def _prepare_writer_arrays(data_arrays, in_infos, ref_grid, writer_name):
    """
    Attach grids and validate array compatibility for a writer.

    Parameters
    ----------
    data_arrays : list of DataArray
        Writer data arrays.
    in_infos : dict
        Input infos keyed by variable name.
    ref_grid : finam.Grid
        Reference grid used for the output.
    writer_name : str
        Name used in compatibility error messages.

    Raises
    ------
    ValueError
        If a non-field input grid is incompatible with the reference grid.
    """
    for var in data_arrays:
        var_grid = in_infos[var.name].grid
        var.info_kwargs["grid"] = var_grid
        var.set_association_by_grid()
        # field data has no spatial layout, so no compatibility check needed
        if var.association == "field":
            continue
        if not _grids_compatible(var_grid, ref_grid):
            raise ValueError(f"{writer_name}: inputs have incompatible grids.")
def _set_mesh_data(pv_mesh, data_arrays, data, is_structured):
    """
    Set field, point, or cell data on a PyVista mesh.

    Parameters
    ----------
    pv_mesh : pyvista.DataSet
        Target mesh.
    data_arrays : list of DataArray
        Writer data arrays.
    data : dict
        Input data keyed by variable name.
    is_structured : bool
        Whether the mesh is structured and needs canonical reshaping.
    """
    targets = {"field": pv_mesh.field_data, "cell": pv_mesh.cell_data}
    for var in data_arrays:
        grid = var.get_grid()
        values = fm.data.strip_time(data[var.name], grid).magnitude
        if is_structured and var.association != "field":
            # structured meshes need canonical axis order, flattened Fortran-style
            values = grid.to_canonical(values).reshape(-1, order="F")
        # anything that is not field/cell data ends up as point data
        targets.get(var.association, pv_mesh.point_data)[var.name] = values
def _timestep_path(file_path, file_name, step_counter, file_digits, file_ext):
"""
Build the output path for a timestep VTK file.
Parameters
----------
file_path : pathlib.Path
Base file path without the numeric suffix.
file_name : str
File stem used for the output files.
step_counter : int
Current timestep index.
file_digits : int
Width of the numeric suffix.
file_ext : str
Output file extension.
Returns
-------
pathlib.Path
The path for the timestep file.
"""
return file_path.with_name(f"{file_name}{step_counter:0{file_digits}}").with_suffix(
file_ext
)
class DataArray:
    """
    Specifications for a VTK data array.

    Parameters
    ----------
    name : str
        Data array name in the VTK file.
    association : str, optional
        Indicate how data is associated ("point", "cell", or "field").
        Can be None to be determined.
    **info_kwargs
        Optional keyword arguments to instantiate an Info object (i.e. 'grid' and 'meta')
        Used to overwrite meta data, to change units or to provide a desired grid specification.
    """

    # note: the stray "[docs]" documentation-scrape artifacts that were
    # interleaved with this class were removed (they were syntax errors);
    # the "optionl" docstring typo was fixed as well

    def __init__(self, name, association=None, **info_kwargs):
        self.name = name
        self.association = association
        self.info_kwargs = info_kwargs

    def get_grid(self):
        """Get the grid of this data array."""
        return self.info_kwargs.get("grid", None)

    def set_association_by_grid(self, grid=None):
        """
        Set the data association by a grid specification.

        Parameters
        ----------
        grid : finam.Grid, optional
            Grid to derive the association from; falls back to the
            grid stored in ``info_kwargs`` when not given.

        Raises
        ------
        ValueError
            If no grid is available to determine the association.
        """
        grid = self.get_grid() if grid is None else grid
        if grid is None:
            raise ValueError("DataArray: no grid specified to determine association.")
        if isinstance(grid, fm.NoGrid):
            self.association = "field"
        elif grid.data_location == fm.Location.CELLS:
            self.association = "cell"
        else:
            self.association = "point"

    def __repr__(self):
        name, association = self.name, self.association
        return f"DataArray({name=}, {association=}, **{self.info_kwargs})"
def create_data_array_list(data_arrays):
    """
    Create a list of DataArray instances.

    Parameters
    ----------
    data_arrays : list of str or DataArray
        List containing DataArray instances or names.

    Returns
    -------
    list of DataArray
        List containing only DataArray instances.
    """
    result = []
    for entry in data_arrays:
        # plain strings are interpreted as array names
        if isinstance(entry, DataArray):
            result.append(entry)
        else:
            result.append(DataArray(entry))
    return result
def extract_data_arrays(mesh, data_arrays=None):
    """
    Extract the data_array information from a pyvista mesh.

    Parameters
    ----------
    mesh : pyvista.DataSet
        Pyvista mesh.
    data_arrays : list of DataArray or str, optional
        List of desired data_arrays given by name or a :class:`DataArray` instance.
        By default, all data_arrays present in the mesh.

    Returns
    -------
    data_arrays : list of DataArray
        Data array information.

    Raises
    ------
    ValueError
        If a given association does not match the one found in the mesh.
    """
    if data_arrays is None:
        # collect everything present on the mesh, per association
        collected = []
        for assoc, names in (
            ("field", mesh.field_data),
            ("point", mesh.point_data),
            ("cell", mesh.cell_data),
        ):
            collected.extend(DataArray(name=n, association=assoc) for n in names)
        return collected
    data_arrays = create_data_array_list(data_arrays)
    for var in data_arrays:
        found = ASSOCIATION[mesh.get_array_association(var.name)]
        if var.association is None:
            var.association = found
        elif var.association != found:
            msg = (
                f"{var.name}: data is associated with '{found}', "
                f"but '{var.association}' was given."
            )
            raise ValueError(msg)
    return data_arrays
def needs_masking(data):
    """
    Whether data has non-finite values (infs, nans).

    Parameters
    ----------
    data : arraylike
        Data to check

    Returns
    -------
    bool
        Finity status.
    """
    non_finite = np.logical_not(np.isfinite(data))
    return non_finite.any()
def _is_int(value, **kwargs):
return np.isclose(value, np.around(value), **kwargs)
def get_time_unit(step, time_unit=None):
    """
    Determine time units from time step delta.

    Parameters
    ----------
    step : timedelta
        The desired time step
    time_unit : str or None, optional
        The desired time unit. If not given, will be determined.
        By default : None

    Returns
    -------
    str
        The time unit (days, hours, minutes, seconds)

    Raises
    ------
    ValueError
        if provided step is not compatible with supported units
    ValueError
        if provided time unit is not supported
    ValueError
        if provided time unit is not compatible with provided step
    """
    # coarsest unit the step is an integer multiple of
    ref_unit = next(
        (unit for unit in TIME_UNITS if _is_int(step / TIME_DELTAS[unit])), None
    )
    if ref_unit is None:
        raise ValueError("VTK: time step is not compatible with supported units.")
    if time_unit is None:
        return ref_unit
    if time_unit not in TIME_UNITS:
        raise ValueError(f"VTK: time unit '{time_unit}' not supported.")
    # a requested unit coarser than the reference unit cannot represent the step
    if TIME_UNITS.index(time_unit) < TIME_UNITS.index(ref_unit):
        raise ValueError(
            f"VTK: time unit '{time_unit}' not compatible with time step."
        )
    return time_unit
class _DateTimeEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, datetime):
return o.isoformat()
return super().default(o)
def save_dict_to_json(file_path, **data_dict):
    """
    Save a dictionary with datetime objects to a JSON file.

    Parameters
    ----------
    file_path : pathlike
        The path to the JSON file.
    data_dict : dict
        The dictionary to save.
    """
    # datetime values are serialized as ISO strings by _DateTimeEncoder
    serialized = json.dumps(data_dict, cls=_DateTimeEncoder)
    with open(file_path, "w", encoding="utf-8") as json_file:
        json_file.write(serialized)
def _datetime_parser(dct):
for key, value in dct.items():
if key == "reference_date":
try:
dct[key] = datetime.fromisoformat(value)
except (ValueError, TypeError):
pass
return dct
def read_aux_file(file_path):
    """
    Load a dictionary from a JSON file and convert 'reference_date' strings back to datetime objects.

    Parameters
    ----------
    file_path : pathlike
        The path to the JSON file.

    Returns
    -------
    dict
        The loaded dictionary with datetime objects for 'reference_date'.
    """
    # note: the stray "[docs]" documentation-scrape artifact that preceded
    # this definition was removed (it was a syntax error)
    with open(file_path, "r", encoding="utf-8") as json_file:
        return json.load(json_file, object_hook=_datetime_parser)
def write_pvd_file(path, vtk_files, time_steps, time_units=None):
    """
    Generates a PVD file linking VTK files with their corresponding time steps.

    Parameters
    ----------
    path : pathlike
        Name of the PVD file to be created.
    vtk_files : list of str
        List of paths to VTK files.
    time_steps : list of float
        List of time steps corresponding to each VTK file.
    time_units : str, optional
        The units of measurement for time steps added as attribute to the collection.
        Follows CF convention and should be like "days since 2023-09-20 00:00:00".
        Default is None.

    Raises
    ------
    ValueError
        If the numbers of files and time steps differ.
    """
    if len(vtk_files) != len(time_steps):
        raise ValueError("The number of VTK files must match the number of time steps.")
    collection_attr = f' time_units="{time_units}"' if time_units else ""
    lines = [
        '<?xml version="1.0"?>\n',
        '<VTKFile type="Collection" version="0.1">\n',
        f"<Collection{collection_attr}>\n",
    ]
    lines.extend(
        f' <DataSet timestep="{step}" file="{vtk_file}" />\n'
        for vtk_file, step in zip(vtk_files, time_steps)
    )
    lines.append("</Collection>\n</VTKFile>\n")
    with open(path, "w", encoding="utf-8") as handle:
        handle.write("".join(lines))