Source code for sweepexp.sweepexp_parallel
"""Running the experiments in parallel using multiprocessing."""
from __future__ import annotations
import multiprocessing as mp
import time
from typing import TYPE_CHECKING
from sweepexp import SweepExp, log
if TYPE_CHECKING: # pragma: no cover
import xarray as xr
WAIT_TIME = 0.05 # 50 ms
[docs]
class SweepExpParallel(SweepExp):
"""
Run a parameter sweep in parallel using multiprocessing.
Parameters
----------
func : Callable
The experiment function to run. The function should take the parameters
as keyword arguments and return a dictionary with the return values.
parameters : dict[str, list]
The parameters to sweep over. The keys are the parameter names and the
values are lists of the parameter values.
save_path : Path | str | None
The path to save the results to. Supported file formats are: '.zarr',
'.nc', '.cdf', '.pkl'. The '.zarr' and '.nc' formats only support
numeric and boolean data. Only the '.pkl' format supports saving data
of any type.
Description
-----------
The SweepExpParallel class can be used to run a custom experiment function with
different parameter combinations. The results of the experiments are saved
as an xarray dataset. The dataset can be saved to a file and loaded later
to continue the experiments. All parameter combinations are run
in parallel using multiprocessing. The number of workers can be specified
using the 'max_workers' parameter in the 'run' method.
SweepExp supports the following additional features:
- Custom arguments: Add custom arguments to the experiment function.
- UUID: Pass a unique identifier to the experiment function.
- Auto save: Automatically save the results after each experiment.
- Timeit: Measure the duration of each experiment.
- Priorities: Run experiments with higher priority first.
Examples
--------
.. code-block:: python
from sweepexp import SweepExpParallel
# Create a simple experiment function
def my_experiment(x: int, y: float) -> dict:
return {"sum": x + y, "product": x * y}
# Initialize the SweepExp object
sweep = SweepExpParallel(
func=my_experiment,
parameters={"x": [1, 2, 3], "y": [4, 5, 6]},
)
# Run the sweep
sweep.run()
"""
[docs]
def run(self, # noqa: D102
status: str | list[str] | None = "N",
max_workers: int | None = None,
) -> xr.Dataset:
# Set the max_workers to the number of CPUs if not specified
max_workers = max_workers or mp.cpu_count()
# Create a list of all experiments that need to be run
indices = self._get_indices(status)
number_of_experiments = len(indices[0])
log.info(f"Found {number_of_experiments} experiments to run.")
# Set the experiment function
self._set_experiment_function()
# Sort the experiments based on the priorities
indices = self._sort_indices(indices)
# Create a job for each experiment
jobs = [self._create_job(index)
for index in zip(*indices, strict=True)]
self._run_jobs(jobs, max_workers)
return self.data
def _run_jobs(self, jobs: list[dict[str, any]], max_workers: int) -> None:
"""Run the list of processes in parallel."""
# Create a list to store the active processes
active_jobs = []
# Run the experiments
while jobs or active_jobs:
# Start new processes
while jobs and len(active_jobs) < max_workers:
# Remove the first process, start it and add it to the active processes
job = jobs.pop(0)
kwargs = self._get_kwargs(job["index"])
log.debug(f"Starting: {kwargs}")
job["process"].start()
active_jobs.append(job)
# Check if any of the active processes have finished
for job in active_jobs:
# Check if the process is still alive
if job["process"].is_alive():
continue
self._handle_finished_job(job)
active_jobs.remove(job)
log.debug(f"Number of remaining jobs: {len(jobs) + len(active_jobs)}")
# Sleep if there are processes left but we can't start new ones
if ( (jobs and len(active_jobs) >= max_workers) or
(not jobs and active_jobs) ):
time.sleep(WAIT_TIME)
def _handle_finished_job(self, job: dict[str, any]) -> None:
"""Handle the return values of a finished job."""
# Get the index of the experiment
index = job["index"]
# Get the return values from the queue
if self.timeit:
return_values, duration = job["queue"].get()
else:
return_values = job["queue"].get()
# Check if the return values are an exception
if isinstance(return_values, Exception):
log.error(f"Error in experiment {self._get_kwargs(index)}: {return_values}")
return_values = {}
status = "F"
else:
log.debug(f"Finished: {self._get_kwargs(index)}")
status = "C"
# Set the status and return values of the experiment
self._set_status_at(index, status)
self._set_return_values_at(index, return_values)
# Set the duration of the experiment
if self.timeit:
self._set_duration_at(index, duration)
log.debug(f"Experiment took {duration:.2f} seconds.")
# Save the results (if enabled)
if self.auto_save:
self.save(mode="w")
def _create_job(self, index: tuple[int, ...]) -> dict[str, any]:
# Get the kwargs for the experiment
kwargs = self._get_kwargs(index)
# Create a queue for the experiment
queue = mp.Queue()
# Create a process for the experiment
process = mp.Process(
target=self._exp_func,
args=(kwargs, queue),
)
return {
"process": process,
"queue": queue,
"index": index,
}
def _set_experiment_function(self) -> None:
"""Wrap the experiment function to be run in a separate process."""
def wrapper(kwargs: dict[str, any], # pragma: no cover
queue: mp.Queue) -> None:
# Save the start time if timeit is enabled
if self.timeit:
start_time = time.time()
# Try to run the experiment function
try:
return_values = self.func(**kwargs)
except Exception as e: # noqa: BLE001
return_values = e
# Save the end time if timeit is enabled
if self.timeit:
end_time = time.time()
return_values = (return_values, end_time - start_time)
# Put the return values in the queue
queue.put(return_values)
self._exp_func = wrapper