# Source code for mcframework.backends.torch_mps
r"""
Torch MPS (Metal Performance Shaders) backend for Apple Silicon.
This module provides:
Classes
:class:`TorchMPSBackend` — GPU-accelerated batch execution on Apple Silicon
Functions
:func:`is_mps_available` — Check MPS availability
:func:`validate_mps_device` — Validate MPS is usable
The MPS backend enables GPU-accelerated Monte Carlo simulations on
Apple Silicon Macs (M1/M2/M3/M4) using Metal Performance Shaders.
Notes
-----
**MPS determinism caveat.** Torch MPS preserves RNG stream structure but does
not guarantee bitwise reproducibility due to Metal backend scheduling and
float32 arithmetic. Statistical properties (mean, variance, CI coverage)
remain correct.
**Dtype policy.** MPS performs best with float32. Sampling uses float32,
but results are promoted to float64 on CPU before returning to ensure
stats engine precision.
**System requirements:**
- macOS 12.3 (Monterey) or later
- Apple Silicon (M1, M2, M3, M4 series)
- PyTorch built with MPS support
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Callable
import numpy as np
from .torch_base import import_torch, make_torch_generator
if TYPE_CHECKING:
from ..simulation import MonteCarloSimulation
logger = logging.getLogger(__name__)
__all__ = [
"TorchMPSBackend",
"is_mps_available",
"validate_mps_device",
]
def is_mps_available() -> bool:
    """
    Check if MPS (Metal Performance Shaders) is available.

    Returns
    -------
    bool
        True if MPS is available and PyTorch was built with MPS support.

    Examples
    --------
    >>> if is_mps_available():
    ...     backend = TorchMPSBackend()  # doctest: +SKIP
    """
    try:
        torch_mod = import_torch()
        # Older/partial torch builds may lack the ``mps`` backend entirely.
        mps = getattr(torch_mod.backends, "mps", None)
        return mps is not None and mps.is_available() and mps.is_built()
    except ImportError:
        # PyTorch itself is not installed -> MPS cannot be available.
        return False
def validate_mps_device() -> None:
    """
    Validate that MPS device is available and usable.

    Raises
    ------
    ImportError
        If PyTorch is not installed.
    RuntimeError
        If MPS is not available or not built into PyTorch.

    Examples
    --------
    >>> validate_mps_device()  # doctest: +SKIP
    """
    mps = import_torch().backends.mps
    if not mps.is_available():
        raise RuntimeError(
            "MPS device requested but not available. "
            "MPS requires macOS 12.3+ with Apple Silicon (M1/M2/M3/M4)."
        )
    if not mps.is_built():
        raise RuntimeError(
            "MPS device requested but PyTorch was not built with MPS support. "
            "Reinstall PyTorch with MPS support enabled."
        )
class TorchMPSBackend:
    r"""
    Torch MPS batch execution backend for Apple Silicon GPUs.

    Uses PyTorch's MPS (Metal Performance Shaders) backend for
    GPU-accelerated execution on Apple Silicon Macs, leveraging the
    unified memory architecture.

    Simulations must set
    :attr:`~mcframework.simulation.MonteCarloSimulation.supports_batch`
    to ``True`` and implement
    :meth:`~mcframework.core.MonteCarloSimulation.torch_batch` to be
    eligible for this backend.

    Notes
    -----
    **RNG architecture.** Explicit :class:`~torch.Generator` objects are
    seeded from :class:`~numpy.random.SeedSequence` via
    :meth:`~numpy.random.SeedSequence.spawn`, preserving:

    - Deterministic parallel streams (best-effort on MPS)
    - Counter-based RNG (Philox) semantics
    - Correct statistical structure

    **Never uses** :meth:`~torch.Generator.manual_seed` (global state).

    **Dtype policy.** MPS performs best with float32, so sampling happens
    in float32 on device; results are moved to CPU and promoted to
    float64 before being handed back as a :class:`numpy.ndarray` for the
    stats engine.

    **MPS determinism caveat.** Stream structure is preserved, but
    bitwise reproducibility is not guaranteed (Metal scheduling, float32
    rounding, kernel execution order). Statistical properties (mean,
    variance, CI coverage) remain correct; see ``TestMPSDeterminism`` in
    ``tests/test_torch_backend.py``.

    Examples
    --------
    >>> if is_mps_available():
    ...     backend = TorchMPSBackend()
    ...     results = backend.run(sim, n_simulations=1_000_000, seed_seq=seed_seq)
    ...     # doctest: +SKIP

    See Also
    --------
    :func:`is_mps_available` : Check MPS availability before instantiation.
    :class:`TorchCPUBackend` : Fallback for non-Apple systems.
    """

    # Identifies the torch device family this backend targets.
    device_type: str = "mps"

    def __init__(self):
        """
        Initialize Torch MPS backend.

        Raises
        ------
        ImportError
            If PyTorch is not installed.
        RuntimeError
            If MPS is not available on this system.
        """
        # Fail fast with a clear message before touching the device.
        validate_mps_device()
        self.device = import_torch().device("mps")

    def run(
        self,
        sim: "MonteCarloSimulation",
        n_simulations: int,
        seed_seq: np.random.SeedSequence | None,
        progress_callback: Callable[[int, int], None] | None = None,
        **_simulation_kwargs: Any,
    ) -> np.ndarray:
        """
        Run simulations using Torch MPS batch execution.

        Parameters
        ----------
        sim : MonteCarloSimulation
            Simulation to run; must declare ``supports_batch = True``
            and implement ``torch_batch()``.
        n_simulations : int
            Number of simulation draws to perform.
        seed_seq : SeedSequence or None
            Seed sequence for reproducible random streams.
        progress_callback : callable or None
            Optional ``f(completed, total)`` progress hook; invoked once
            at completion since the batch executes atomically.
        **_simulation_kwargs : Any
            Ignored for Torch backend (batch method handles all parameters).

        Returns
        -------
        np.ndarray
            Results with shape ``(n_simulations,)`` as float64, even
            though sampling runs in float32 on the device.

        Raises
        ------
        ValueError
            If the simulation does not support batch execution.
        NotImplementedError
            If the simulation does not implement ``torch_batch()``.

        Notes
        -----
        Dtype flow: float32 on MPS -> ``detach().cpu()`` -> promoted to
        float64 -> converted to a NumPy array. MPS requires the CPU move
        before the float64 promotion.
        """
        torch_mod = import_torch()

        if not getattr(sim, "supports_batch", False):
            raise ValueError(
                f"Simulation '{sim.name}' does not support Torch batch execution. "
                "Set supports_batch = True and implement torch_batch()."
            )

        logger.info(
            "Computing %d simulations using Torch MPS (Apple Silicon GPU)...",
            n_simulations
        )

        # Explicit per-run generator derived from the SeedSequence; the
        # global torch RNG state is never touched.
        gen = make_torch_generator(self.device, seed_seq)

        # Vectorized batch; torch_batch is expected to return float32
        # for MPS compatibility.
        device_samples = sim.torch_batch(
            n_simulations, device=self.device, generator=gen
        )

        # CPU move first (MPS cannot hold float64), then promote for
        # stats-engine precision.
        host_samples = device_samples.detach().cpu().to(torch_mod.float64)

        # Batch execution is atomic, so report completion in one step.
        if progress_callback:
            progress_callback(n_simulations, n_simulations)

        return host_samples.numpy()