"""Running the experiments in parallel using mpi."""
from __future__ import annotations

import time
from collections import deque
from typing import TYPE_CHECKING, Any, Literal

from mpi4py import MPI

from sweepexp import SweepExp, log

if TYPE_CHECKING:  # pragma: no cover
    import xarray as xr
# Rank that coordinates the sweep: it hands out jobs and collects the results,
# while every other rank acts as a worker (see ``SweepExpMPI.run``).
MAIN_RANK = 0
# Rank of the current process in MPI.COMM_WORLD (evaluated once at import time).
MY_RANK = MPI.COMM_WORLD.Get_rank()
# True only on the coordinating rank.
IS_MAIN_RANK = MY_RANK == MAIN_RANK
# Seconds the main rank sleeps between two polls of the workers (50 ms).
WAIT_TIME = 0.05  # 50 ms
class SweepExpMPI(SweepExp):
    """
    Run a parameter sweep in parallel using MPI.

    Parameters
    ----------
    func : Callable
        The experiment function to run. The function should take the parameters
        as keyword arguments and return a dictionary with the return values.
    parameters : dict[str, list]
        The parameters to sweep over. The keys are the parameter names and the
        values are lists of the parameter values.
    save_path : Path | str | None
        The path to save the results to. Supported file formats are: '.zarr',
        '.nc', '.cdf', '.pkl'. The '.zarr' and '.nc' formats only support
        numeric and boolean data. Only the '.pkl' format supports saving data
        of any type.

    Description
    -----------
    The SweepExpMPI class can be used to run a custom experiment function with
    different parameter combinations. The results of the experiments are saved
    as an xarray dataset. The dataset can be saved to a file and loaded later
    to continue the experiments. All parameter combinations are run
    in parallel using mpi4py.

    Rank 0 only coordinates the sweep (it distributes the jobs and collects
    the results); all other ranks execute the experiments. Therefore at least
    two ranks are required, and ``n`` ranks provide ``n - 1`` workers. Saving
    to and loading from file happens on rank 0 only, and only rank 0 writes
    the results into its dataset.

    SweepExpMPI supports the following additional features:

    - Custom arguments: Add custom arguments to the experiment function.
    - UUID: Pass a unique identifier to the experiment function.
    - Auto save: Automatically save the results after each experiment.
    - Timeit: Measure the duration of each experiment.
    - Priorities: Run experiments with higher priority first.

    Examples
    --------
    .. code-block:: python
        :caption: my_experiment.py

        # NOTE: import path assumed; adjust it if SweepExpMPI is exposed
        # from a submodule in your installation.
        from sweepexp import SweepExpMPI

        # Create a simple experiment function
        def my_experiment(x: int, y: float) -> dict:
            return {"sum": x + y, "product": x * y}

        # Initialize the SweepExpMPI object
        sweep = SweepExpMPI(
            func=my_experiment,
            parameters={"x": [1, 2, 3], "y": [4, 5, 6]},
        )

        # Run the sweep
        sweep.run()

    To run the experiment on 4 CPUs using MPI, use the following command:

    .. code-block:: bash

        mpiexec -n 4 python3 my_experiment.py

    Or, alternatively any other MPI launcher, e.g., `mpirun`, `srun`, etc.
    """

    # Override save and load methods to only save and load on the main rank

    def save(self, mode: Literal["x", "w"] = "x") -> None:
        """
        Save the results on the main rank; do nothing on all other ranks.

        Parameters
        ----------
        mode : Literal["x", "w"]
            Forwarded unchanged to ``SweepExp.save``.
        """
        if IS_MAIN_RANK:
            super().save(mode=mode)

    def _load_data_from_file(self) -> xr.Dataset:
        """
        Load the dataset from file on the main rank only.

        Returns
        -------
        xr.Dataset
            The dataset loaded by the parent class on the main rank, or a
            freshly created dataset (``self._create_data()``) on any other
            rank.
        """
        if IS_MAIN_RANK:
            return super()._load_data_from_file()
        # If this is not the main rank, we don't load the data (just create as usual)
        return self._create_data()

    def run(
        self,
        status: str | list[str] | None = "N",
        max_workers: int | None = None,
    ) -> xr.Dataset:
        """
        Run the sweep, using the main rank as coordinator and the rest as workers.

        Parameters
        ----------
        status : str | list[str] | None
            Selects which experiments are run; forwarded to
            ``self._get_indices``. Workers report "C" for a finished and "F"
            for a failed experiment (the default "N" presumably selects
            experiments that have not been run yet -- confirm in ``SweepExp``).
        max_workers : int | None
            Ignored in MPI mode; a warning is logged on the main rank when a
            value is given. The number of workers is set by the MPI launcher.

        Returns
        -------
        xr.Dataset
            ``self.data``. Results are written into the dataset on the main
            rank only; worker ranks return their dataset without the results
            of this run.

        Raises
        ------
        ValueError
            If fewer than two MPI ranks are available.
        """
        if max_workers is not None and IS_MAIN_RANK:
            msg = (
                f"Argument 'max_workers={max_workers}' has no effect in "
                "mode=mpi. "
                "Use the 'mode=parallel' argument to run the sweep in parallel."
            )
            log.warning(msg)

        # One rank coordinates, so at least one more is needed to do the work
        min_size = 2
        if MPI.COMM_WORLD.Get_size() < min_size:
            msg = "At least two ranks are required to run the sweep."
            raise ValueError(msg)

        if IS_MAIN_RANK:
            self._manage_jobs(status)
        else:
            self._handle_jobs()
        return self.data

    # ----------------------------------------------------------------
    #  Methods for the main rank
    # ----------------------------------------------------------------

    def _manage_jobs(self, status: str | list[str] | None) -> None:
        """
        Distribute the experiments to the workers and collect their results.

        Jobs are handed out in priority order to idle workers. The workers are
        polled for pending result messages, and once every job has finished a
        ``None`` message is sent to each worker to make it leave its loop.

        Parameters
        ----------
        status : str | list[str] | None
            Forwarded to ``self._get_indices`` to select the experiments.
        """
        comm = MPI.COMM_WORLD

        # Create a list of all experiments that need to be run
        indices = self._get_indices(status)
        number_of_experiments = len(indices[0])
        log.info(f"Found {number_of_experiments} experiments to run.")

        # Sort the experiments based on the priorities
        indices = self._sort_indices(indices)

        # One job (index tuple) per experiment. Deques give O(1) pops from
        # the front, which matters for sweeps with many experiments.
        jobs = deque(zip(*indices, strict=True))
        free_workers = deque(range(1, comm.Get_size()))
        # Maps each busy worker rank to the index of the job it is running
        active_jobs: dict[int, tuple[int, ...]] = {}

        while jobs or active_jobs:
            # Start new jobs
            while jobs and free_workers:
                index = jobs.popleft()
                kwargs = self._get_kwargs(index)
                worker = free_workers.popleft()
                comm.send(kwargs, dest=worker)
                active_jobs[worker] = index

            # Collect the results of all workers that have finished. Iterate
            # over a snapshot because finished workers are removed from
            # ``active_jobs`` inside the loop.
            for worker in list(active_jobs):
                # A pending message from the worker means its result is ready
                if not comm.Iprobe(source=worker):
                    continue
                result = comm.recv(source=worker)
                index = active_jobs.pop(worker)
                self._handle_finished_job(index, result)
                free_workers.append(worker)
                # log the number of remaining jobs
                remaining = len(jobs) + len(active_jobs)
                log.debug(f"Number of remaining jobs: {remaining}")

            # Sleep for a short time to prevent busy waiting, but only if
            # nothing can be dispatched right away
            waiting_for_worker = jobs and not free_workers
            waiting_for_results = not jobs and active_jobs
            if waiting_for_worker or waiting_for_results:
                time.sleep(WAIT_TIME)

        # Send a signal to the workers to stop (all workers are idle here)
        for worker in free_workers:
            comm.send(None, dest=worker)

    def _handle_finished_job(self, index: tuple[int, ...], result: tuple) -> None:
        """
        Store the result of a finished experiment in the dataset.

        Parameters
        ----------
        index : tuple[int, ...]
            Index of the experiment in the parameter grid.
        result : tuple
            The ``(return_values, status, duration)`` tuple produced by
            ``_run_experiment`` on a worker.
        """
        # unpack the result
        return_values, status, duration = result

        # Set the status and return values of the experiment
        self._set_status_at(index, status)
        self._set_return_values_at(index, return_values)

        # Set the duration of the experiment
        if self.timeit:
            self._set_duration_at(index, duration)
            log.debug(f"Experiment took {duration:.2f} seconds.")

        # Save the results (if enabled)
        if self.auto_save:
            self.save(mode="w")

    # ----------------------------------------------------------------
    #  Methods for workers
    # ----------------------------------------------------------------

    def _handle_jobs(self) -> None:
        """Run experiments sent by the main rank until ``None`` is received."""
        while True:
            # Receive the keyword arguments of the next experiment
            kwargs = MPI.COMM_WORLD.recv(source=MAIN_RANK)
            # ``None`` is the stop signal of the main rank
            if kwargs is None:
                break
            # Run the experiment and send the results back to the main rank
            MPI.COMM_WORLD.send(self._run_experiment(kwargs), dest=MAIN_RANK)

    def _run_experiment(self, kwargs: dict[str, Any]) -> tuple:
        """
        Run a single experiment.

        Parameters
        ----------
        kwargs : dict[str, Any]
            Keyword arguments passed to the experiment function.

        Returns
        -------
        tuple
            ``(return_values, status, duration)``. ``status`` is "C" if the
            function returned and "F" if it raised an exception, in which
            case ``return_values`` is an empty dict. ``duration`` is the
            elapsed time in seconds, or 0 when ``self.timeit`` is disabled.
        """
        log.debug(f"Rank {MY_RANK} - Starting: {kwargs}")

        # Use a monotonic clock so that system clock adjustments cannot
        # distort (or make negative) the measured duration
        start_time = time.perf_counter() if self.timeit else None

        try:
            return_values = self.func(**kwargs)
            status = "C"
        except Exception as error:  # noqa: BLE001
            # A failing experiment must not stop the worker: report it as failed
            log.error(f"Error in experiment {kwargs}: {error}")
            return_values = {}
            status = "F"

        # Calculate the duration of the experiment
        if start_time is None:
            duration = 0
        else:
            duration = time.perf_counter() - start_time

        log.debug(f"Rank {MY_RANK} - Finished: {kwargs}")
        return return_values, status, duration