import os
import pathlib
from typing import Any, Protocol, TypedDict
import numpy as np
from numpy.typing import NDArray
from orbit.core import orbit_mpi
from orbit.core.bunch import Bunch
class SyncPartDict(TypedDict):
    """Typed dictionary describing the synchronous particle of a bunch.

    The keys match what ``collect_bunch`` stores under ``"sync_part"`` and
    what ``load_bunch`` reads back from it.
    """

    # Vector that load_bunch hands to the sync particle's rVector().
    coords: NDArray[np.float64]
    # Value of the sync particle's kinEnergy().
    kin_energy: np.float64
    # Value of the sync particle's momentum().
    momentum: np.float64
    # Value of the sync particle's beta().
    beta: np.float64
    # Value of the sync particle's gamma().
    gamma: np.float64
    # Value of the sync particle's time().
    time: np.float64
class BunchDict(TypedDict):
    """Typed dictionary holding the serialized form of a bunch."""

    # Particle coordinates, one row per particle; collect_bunch fills the six
    # columns with (x, xp, y, yp, z, dE).
    coords: NDArray[np.float64]
    # Data describing the synchronous particle.
    sync_part: SyncPartDict
    # Bunch attributes keyed by name; collect_bunch stores double attributes
    # as np.float64 and integer attributes as np.int32.
    attributes: dict[str, np.float64 | np.int32]
class FileHandler(Protocol):
    """Protocol for file handlers to read/write bunch data.

    Any class providing these three methods satisfies the protocol; it does
    not have to inherit from ``FileHandler``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Create the handler; the accepted arguments are implementation-defined."""
        ...

    def read(self) -> BunchDict:
        """Load stored bunch data and return it as a ``BunchDict``."""
        ...

    def write(self, bunch: BunchDict) -> None:
        """Persist the given ``BunchDict``."""
        ...
[docs]class NumPyHandler:
"""Handler implementing the FileHandler protocol for NumPy binary files.
This handler will create two files in the directory passed to the constructor:
- coords.npy: A memory-mapped NumPy array containing the bunch coordinates.
- attributes.npz: A NumPy archive containing data related to the synchronous particle and other bunch attributes.
"""
_coords_fname = "coords.npy"
_attributes_fname = "attributes.npz"
def __init__(self, dir_name: str | pathlib.Path):
if isinstance(dir_name, str):
dir_name = pathlib.Path(dir_name)
self._dir_name = dir_name
self._coords_path = dir_name / self._coords_fname
self._attributes_path = dir_name / self._attributes_fname
[docs] def read(self) -> BunchDict:
if not self._coords_path.exists() or not self._attributes_path.exists():
raise FileNotFoundError(
f"Required files not found in directory: {self._dir_name}"
)
coords = np.load(self._coords_path, mmap_mode="r")
attr_data = np.load(self._attributes_path, allow_pickle=True)
sync_part = attr_data["sync_part"].item()
attributes = attr_data["attributes"].item()
return BunchDict(coords=coords, sync_part=sync_part, attributes=attributes)
[docs] def write(self, bunch: BunchDict) -> None:
self._dir_name.mkdir(parents=True, exist_ok=True)
np.save(self._coords_path, bunch["coords"])
np.savez(
self._attributes_path,
sync_part=bunch["sync_part"],
attributes=bunch["attributes"],
)
def collect_bunch(
    bunch: Bunch,
    output_dir: str | pathlib.Path | None = "/tmp",
    return_memmap: bool = True,
) -> BunchDict | None:
    """Collects attributes from a PyOrbit Bunch across all MPI ranks and returns it as a dictionary.

    Every rank dumps its local particle coordinates into a temporary file
    ``collect_bunch_tmpfile_<rank>.dat`` inside ``output_dir``. After an MPI
    barrier, rank 0 appends the contents of the other ranks' files to its own
    file and removes them.

    Parameters
    ----------
    bunch : Bunch
        The PyOrbit::Bunch object from which to collect attributes.
    output_dir : str | pathlib.Path | None, optional
        The directory to use for temporary storage of the bunch coordinates on each MPI rank.
        If None, the bunch will be stored in "/tmp".
        Note: take care that the temporary files are created in a directory where all MPI ranks
        have write access and that rank 0 can read the files written by the other ranks.
    return_memmap : bool, optional
        Return the bunch coordinates as a memory-mapped NumPy array, otherwise the
        entire array is copied into RAM and returned as normal NDArray. Default is True.
        When True, the file ``collect_bunch_tmpfile_0.dat`` backs the returned array and is
        left in ``output_dir``; when False, that file is removed after the copy.

    Returns
    -------
    BunchDict | None
        A dictionary containing the collected bunch attributes. Returns None if not on the
        root MPI rank or if the global bunch size is 0.

    Raises
    ------
    FileNotFoundError
        If the temporary files created by non-root MPI ranks could not be found by the root
        rank during the collection process.
    """
    global_size = bunch.getSizeGlobal()
    if global_size == 0:
        return None

    mpi_comm = bunch.getMPIComm()
    mpi_rank = orbit_mpi.MPI_Comm_rank(mpi_comm)
    is_root = mpi_rank == 0

    n_cols = 6  # x, xp, y, yp, z, dE
    dtype = np.float64
    local_rows = bunch.getSize()

    output_dir = pathlib.Path("/tmp" if output_dir is None else output_dir)
    fname = output_dir / f"collect_bunch_tmpfile_{mpi_rank}.dat"

    if is_root:
        # The root file is sized for the whole bunch right away; the rows
        # after the local ones are filled from the other ranks' files below.
        # global_size > 0 here, so the mapping is never empty.
        coords_memmap = np.memmap(
            fname, dtype=dtype, mode="w+", shape=(global_size, n_cols)
        )
    elif local_rows > 0:
        coords_memmap = np.memmap(
            fname, dtype=dtype, mode="w+", shape=(local_rows, n_cols)
        )
    else:
        # np.memmap cannot represent zero rows in a form the root could read
        # back, so a rank without particles leaves an empty marker file
        # (truncating any stale file from an earlier run). The root skips it.
        fname.write_bytes(b"")
        coords_memmap = None

    if coords_memmap is not None:
        for i in range(local_rows):
            coords_memmap[i, :] = (
                bunch.x(i),
                bunch.xp(i),
                bunch.y(i),
                bunch.yp(i),
                bunch.z(i),
                bunch.dE(i),
            )
        coords_memmap.flush()

    sync_part_dict = {}
    attributes = {}
    if is_root:
        sync_part = bunch.getSyncParticle()
        sync_part_dict = {
            # NOTE(review): load_bunch feeds this value to SyncPart.rVector(),
            # so it is read from rVector() here (previously pVector()) to keep
            # the save/load round trip symmetric.
            "coords": np.array(sync_part.rVector()),
            "kin_energy": np.float64(sync_part.kinEnergy()),
            "momentum": np.float64(sync_part.momentum()),
            "beta": np.float64(sync_part.beta()),
            "gamma": np.float64(sync_part.gamma()),
            "time": np.float64(sync_part.time()),
        }
        for attr in bunch.bunchAttrDoubleNames():
            attributes[attr] = np.float64(bunch.bunchAttrDouble(attr))
        for attr in bunch.bunchAttrIntNames():
            attributes[attr] = np.int32(bunch.bunchAttrInt(attr))

    # Make sure every rank has finished writing before the root starts reading.
    orbit_mpi.MPI_Barrier(mpi_comm)

    if not is_root:
        return None

    start_row = local_rows
    for r in range(1, orbit_mpi.MPI_Comm_size(mpi_comm)):
        src_fname = output_dir / f"collect_bunch_tmpfile_{r}.dat"
        if not src_fname.exists():
            raise FileNotFoundError(
                f"Expected temporary file '{src_fname}' not found. Something went wrong."
            )

        # A zero-byte file marks a rank that held no particles.
        if src_fname.stat().st_size > 0:
            src_memmap = np.memmap(src_fname, dtype=dtype, mode="r")
            src_memmap = src_memmap.reshape((-1, n_cols))
            stop_row = start_row + src_memmap.shape[0]
            coords_memmap[start_row:stop_row, :] = src_memmap[:, :]
            coords_memmap.flush()
            # Drop the mapping before deleting the file it is backed by.
            del src_memmap
            start_row = stop_row

        os.remove(src_fname)

    if return_memmap:
        coords = coords_memmap
    else:
        # The data now lives in RAM, so the root's temporary file has no
        # further purpose; release the mapping and remove it.
        coords = np.array(coords_memmap)
        del coords_memmap
        fname.unlink(missing_ok=True)

    bunch_dict: BunchDict = {
        "coords": coords,
        "sync_part": sync_part_dict,
        "attributes": attributes,
    }
    return bunch_dict
def save_bunch(
    bunch: Bunch | BunchDict,
    output_dir: str | pathlib.Path = "bunch_data/",
    Handler: type[FileHandler] = NumPyHandler,
) -> None:
    """Saves the collected bunch attributes to a specified directory.

    Only MPI rank 0 writes anything; all other ranks return without touching
    the file system. If a ``Bunch`` is passed, it is first gathered with
    ``collect_bunch`` using that function's default temporary directory.

    Parameters
    ----------
    bunch : Bunch | BunchDict
        The PyOrbit::Bunch object or the dictionary containing the collected bunch attributes.
        ``None`` (e.g. the result of ``collect_bunch`` on a non-root rank) is accepted
        and results in nothing being written.
    output_dir : str | pathlib.Path, optional
        The directory where the bunch data files will be saved. Default is "bunch_data/".
    Handler : type[FileHandler], optional
        The file handler class to use for writing the bunch data. Default is NumPyHandler.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If, on rank 0, the provided `bunch` is neither a Bunch instance nor a
        BunchDict with a ``"coords"`` array.
    """
    if isinstance(bunch, Bunch):
        mpi_comm = bunch.getMPIComm()
        bunch = collect_bunch(bunch)
    else:
        mpi_comm = orbit_mpi.mpi_comm.MPI_COMM_WORLD

    mpi_rank = orbit_mpi.MPI_Comm_rank(mpi_comm)
    if mpi_rank != 0 or bunch is None:
        return

    # Anything that does not look like a BunchDict is reported with the
    # documented ValueError instead of an opaque TypeError/KeyError.
    try:
        n_particles = bunch["coords"].shape[0]
    except (TypeError, KeyError, AttributeError, IndexError) as err:
        raise ValueError(
            "`bunch` must be a Bunch instance or a BunchDict with a 'coords' array, "
            f"got {type(bunch)}."
        ) from err

    if n_particles == 0:
        print("No particles in the bunch to save.")
        return

    handler = Handler(pathlib.Path(output_dir))
    handler.write(bunch)
def load_bunch(
    input_dir: str | pathlib.Path, Handler: type[FileHandler] = NumPyHandler
) -> tuple[Bunch, BunchDict]:
    """Loads a bunch from a directory and distributes its particles over the MPI ranks.

    Every rank reads the stored data through ``Handler`` and adds a
    contiguous slice of the coordinate rows to its own ``Bunch``. The rows are
    split over the ranks of ``MPI_COMM_WORLD`` as evenly as possible; the
    first ``global_size % mpi_size`` ranks receive one extra row.

    Parameters
    ----------
    input_dir : str | pathlib.Path
        The directory from which to load the bunch data files.
    Handler : type[FileHandler], optional
        The file handler class to use for reading the bunch data. Default is NumPyHandler.
        See `orbit.bunch_utils.file_handler` for available handlers.

    Returns
    -------
    tuple[Bunch, BunchDict]
        The Bunch holding this rank's share of the particles, and the
        dictionary with the complete loaded bunch data.

    Raises
    ------
    FileNotFoundError
        If the required files are not found in the specified directory.
    TypeError
        If an attribute in the loaded bunch has an unsupported type.
    """
    mpi_comm = orbit_mpi.mpi_comm.MPI_COMM_WORLD
    mpi_rank = orbit_mpi.MPI_Comm_rank(mpi_comm)
    mpi_size = orbit_mpi.MPI_Comm_size(mpi_comm)

    bunch_dict = Handler(input_dir).read()
    coords = bunch_dict["coords"]

    # Block distribution: ranks below `remainder` hold base_size + 1 rows,
    # so each rank's offset is shifted by the extra rows handed out before it.
    base_size, remainder = divmod(coords.shape[0], mpi_size)
    local_size = base_size + (1 if mpi_rank < remainder else 0)
    start_row = mpi_rank * base_size + min(mpi_rank, remainder)
    local_coords = coords[start_row : start_row + local_size, :]

    bunch = Bunch()
    for row in local_coords:
        bunch.addParticle(*row)

    for attr, value in bunch_dict["attributes"].items():
        # Classify by the dtype NumPy assigns to the value. This works for
        # NumPy scalars as well as plain Python floats/ints, whereas passing
        # the value itself to np.issubdtype fails for the latter.
        value_dtype = np.asarray(value).dtype
        if np.issubdtype(value_dtype, np.floating):
            bunch.bunchAttrDouble(attr, float(value))
        elif np.issubdtype(value_dtype, np.integer):
            bunch.bunchAttrInt(attr, int(value))
        else:
            raise TypeError(f"Unsupported attribute type for '{attr}': {type(value)}")

    # Restore the synchronous particle from the stored values.
    sync_part_obj = bunch.getSyncParticle()
    sync_part_obj.rVector(tuple(bunch_dict["sync_part"]["coords"]))
    sync_part_obj.kinEnergy(bunch_dict["sync_part"]["kin_energy"])
    sync_part_obj.time(bunch_dict["sync_part"]["time"])

    return bunch, bunch_dict