Source code for sweepexp.sweepexp
"""Main entry point for the sweepexp package."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal

import sweepexp as se

if TYPE_CHECKING:  # pragma: no cover
    from collections.abc import Callable
    from pathlib import Path
[docs]
def sweepexp(
    func: Callable,
    parameters: dict[str, list],
    mode: Literal["sequential", "parallel", "mpi"] = "sequential",
    save_path: Path | str | None = None,
    **kwargs: Any,
) -> se.SweepExp | se.SweepExpMPI | se.SweepExpParallel:
    """
    Create a new instance of the SweepExp class.

    Parameters
    ----------
    func : Callable
        The experiment function to run. The function should take the parameters
        as keyword arguments and return a dictionary with the return values.
    parameters : dict[str, list]
        The parameters to sweep over. The keys are the parameter names and the
        values are lists of the parameter values.
    mode : "sequential" | "parallel" | "mpi", default="sequential"
        The mode to run the experiments in. ``"sequential"`` creates a
        ``SweepExp``, ``"parallel"`` a ``SweepExpParallel`` and ``"mpi"`` a
        ``SweepExpMPI``. If creating the ``SweepExpMPI`` instance raises an
        ``ImportError`` (presumably because ``mpi4py`` is unavailable), a
        warning is logged and a ``SweepExpParallel`` instance is returned
        instead.
    save_path : Path | str | None (optional)
        The path to save the results to. Supported file formats are: '.zarr',
        '.nc', '.cdf', '.pkl'. The '.zarr' and '.nc' formats only support
        numeric and boolean data. Only the '.pkl' format supports saving data
        of any type.
    **kwargs : Any
        Additional settings, forwarded unchanged (together with ``func``,
        ``parameters`` and ``save_path``) to the constructor of the selected
        class:

        - `timeit`: bool, measure the duration of each experiment.
        - `auto_save`: bool, automatically save the results after each experiment.
        - `enable_priorities`: bool, run experiments with higher priority first.
        - `pass_uuid`: bool, pass a unique identifier to the experiment function.

    Returns
    -------
    SweepExp | SweepExpMPI | SweepExpParallel
        An instance of the appropriate SweepExp class based on the mode.

    Raises
    ------
    ValueError
        If ``mode`` is not one of ``"sequential"``, ``"parallel"`` or
        ``"mpi"``.

    Examples
    --------
    .. code-block:: python

        from sweepexp import sweepexp

        # Create a simple experiment function
        def my_experiment(x: int, y: float) -> dict:
            return {"sum": x + y, "product": x * y}

        # Initialize the sweepexp object
        sweep = sweepexp(
            func=my_experiment,
            parameters={"x": [1, 2, 3], "y": [4, 5, 6]},
        )

        # Run the sweep
        sweep.run()
    """
    # Every SweepExp variant is constructed purely from keyword arguments, so
    # merge the explicit arguments into the extra settings before dispatching.
    kwargs.update({
        "func": func,
        "parameters": parameters,
        "save_path": save_path,
    })

    if mode == "mpi":
        try:
            return se.SweepExpMPI(**kwargs)
        except ImportError:
            # Degrade gracefully instead of failing: the MPI backend could not
            # be created (the message assumes a missing 'mpi4py'), so use the
            # parallel backend with the same settings.
            msg = "Failed to import 'mpi4py'. "
            msg += "Fallback to 'parallel' mode."
            se.log.warning(msg)
            return se.SweepExpParallel(**kwargs)
    if mode == "parallel":
        return se.SweepExpParallel(**kwargs)
    if mode == "sequential":
        return se.SweepExp(**kwargs)

    # Reached only for values outside the documented Literal choices.
    msg = f"Unknown mode '{mode}'. "
    msg += "Supported modes are: 'sequential', 'parallel', 'mpi'."
    raise ValueError(msg)