"""ADMM core — generic coordinator and message types.
Provides the :class:`ADMMGenericCoordinator` which drives the standard
Alternating Direction Method of Multipliers iteration loop. Concrete
global-actor implementations live in :mod:`.consensus_admm` and
:mod:`.sharing_admm`; the local actor lives in :mod:`.flex_actor`.
"""
from __future__ import annotations
import asyncio
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import numpy as np
from ..core import Coordinator
if TYPE_CHECKING:
from ...carrier.core import Carrier
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Message types
# ---------------------------------------------------------------------------
@dataclass
class ADMMStart:
    """Kick-off message handed to the coordinator to launch one ADMM run.

    :param data: Algorithm-specific payload (e.g. :class:`ADMMSharingData`
        or a target vector).
    :param solution_length: Number of decision variables held by each
        participant.
    """

    # Opaque, variant-specific input; its structure is not interpreted here.
    data: Any
    # Length of every participant's solution vector.
    solution_length: int
@dataclass
class ADMMMessage:
    """Request for an x-update, sent by the coordinator to one participant.

    :param v: Scaled consensus/sharing vector that serves as the reference
        point of the participant's local QP.
    :param rho: ADMM penalty parameter to use for this update.
    """

    # Per-participant reference vector for the local solve.
    v: np.ndarray
    # Penalty weight applied in the local solve.
    rho: float
@dataclass
class ADMMAnswer:
    """Participant's reply once its local update has been solved.

    :param x: The participant's local solution vector.
    :param aux: Optional follower-side scalar/data (for instance a per-step
        move magnitude) that is passed on to the global actor's convergence
        hook. Stays ``None`` for variants whose convergence depends only on
        the primal/dual residuals.
    """

    # Local primal solution returned to the coordinator.
    x: np.ndarray
    # Extra convergence information; absent by default.
    aux: Any = None
# ---------------------------------------------------------------------------
# Abstract global-actor interface
# ---------------------------------------------------------------------------
[docs]
class ADMMGlobalActor(ABC):
"""Interface for the coordinator-side global update in ADMM variants."""
[docs]
@abstractmethod
def z_update(
self,
input_data: Any,
x: list[np.ndarray],
u: Any,
z: Any,
rho: float,
n: int,
) -> Any:
"""Compute the new global *z* from the current *x* and *u*."""
[docs]
@abstractmethod
def u_update(
self,
x: list[np.ndarray],
u: Any,
z: Any,
rho: float,
n: int,
) -> Any:
"""Update the dual variable *u*."""
[docs]
@abstractmethod
def init_z(self, n: int, m: int) -> Any:
"""Initialise *z* (called once before the iteration loop)."""
[docs]
@abstractmethod
def init_u(self, n: int, m: int) -> Any:
"""Initialise *u* (called once before the iteration loop)."""
[docs]
@abstractmethod
def actor_correction(
self,
x: list[np.ndarray],
z: Any,
u: Any,
i: int,
) -> np.ndarray:
"""Compute the correction vector sent to participant *i* (0-indexed)."""
[docs]
@abstractmethod
def primal_residual(self, x: list[np.ndarray], z: Any) -> float:
"""Compute the primal residual used for convergence checking."""
# ---- optional hooks (defaults preserve standard ADMM behaviour) ----
[docs]
def dual_residual(self, z: Any, z_old: Any, rho: float) -> float:
"""Dual residual for convergence. Default: ``rho * max||z - z_old||``."""
return float(rho * _max_diff_norm(z, z_old))
[docs]
def should_stop(
self,
primal_res: float,
dual_res: float,
aux: list[Any],
abs_tol: float,
) -> bool | None:
"""Convergence override. Return ``True``/``False`` to decide directly,
or ``None`` to fall back to the coordinator's eps_pri/eps_dual test."""
return None
[docs]
def adapt_rho(
self,
primal_res: float,
dual_res: float,
rho: float,
u: Any,
) -> tuple[float, Any]:
"""Optionally rebalance ``rho`` (rescaling the scaled dual ``u``
inversely). Default: leave both unchanged."""
return rho, u
class ADMMGlobalObjective(ABC):
    """Optional global objective; at present it is informational only."""

    @abstractmethod
    def objective(
        self,
        x: list[np.ndarray],
        u: Any,
        z: Any,
        n: int,
    ) -> float:
        """Return the value of the global objective for the given state."""
# ---------------------------------------------------------------------------
# Helper: max-norm over list-of-arrays or single array
# ---------------------------------------------------------------------------
def _max_norm(v: Any) -> float:
"""Return ``max ||v_i||`` if *v* is a list, else ``max |v_j|`` for a vector."""
if isinstance(v, list):
return float(max(float(np.linalg.norm(vi)) for vi in v))
return float(np.max(np.abs(v)))
def _max_diff_norm(a: Any, b: Any) -> float:
"""``max ||a_i - b_i||`` for lists or ``max |a_j - b_j|`` for arrays."""
if isinstance(a, list):
return float(max(float(np.linalg.norm(ai - bi)) for ai, bi in zip(a, b)))
return float(np.max(np.abs(a - b)))
def _deepcopy_z(z: Any) -> Any:
if isinstance(z, list):
return [np.copy(zi) for zi in z]
return np.copy(z)
# ---------------------------------------------------------------------------
# Generic ADMM coordinator
# ---------------------------------------------------------------------------
class ADMMGenericCoordinator(Coordinator):
    """Coordinator that runs the standard ADMM iteration loop.

    One round consists of:

    1. Sending an :class:`ADMMMessage` (correction + ρ) to every participant
       in parallel and awaiting an :class:`ADMMAnswer` from each of them.
    2. The global *z*-update via :meth:`~ADMMGlobalActor.z_update`.
    3. The dual *u*-update via :meth:`~ADMMGlobalActor.u_update`.
    4. Comparing primal and dual residuals with the tolerances and stopping
       once both are met.

    :param global_actor: Variant-specific global update logic.
    :param rho: ADMM penalty parameter (default: 1.0).
    :param max_iters: Maximum number of iterations (default: 1000).
    :param abs_tol: Absolute convergence tolerance (default: 1e-4).
    :param rel_tol: Relative convergence tolerance (default: 1e-3).
    """

    def __init__(
        self,
        global_actor: ADMMGlobalActor,
        rho: float = 1.0,
        max_iters: int = 1000,
        abs_tol: float = 1e-4,
        rel_tol: float = 1e-3,
    ) -> None:
        self.global_actor = global_actor
        self.rho, self.max_iters = rho, max_iters
        self.abs_tol, self.rel_tol = abs_tol, rel_tol

    async def start_optimization(
        self,
        carrier: "Carrier",
        message_data: ADMMStart,
        meta: Any,
    ) -> list[np.ndarray]:
        """Run ADMM for *message_data* and return the participants' solutions.

        Only the primal list is returned; the final *z* and *u* from the loop
        are discarded. *meta* is not used here.
        """
        outcome = await self._run(carrier, message_data.data, message_data.solution_length)
        return outcome[0]

    async def _run(
        self,
        carrier: "Carrier",
        input_data: Any,
        m: int,
        *,
        x_init: list[np.ndarray] | None = None,
    ) -> tuple[list[np.ndarray], Any, Any]:
        """Core ADMM loop.

        :param carrier: Coordinator's carrier.
        :param input_data: Algorithm-specific data (target, priorities, …).
        :param m: Problem dimension (number of decision variables).
        :param x_init: Optional warm-start primal list (e.g. carried over
            from a previous round); defaults to per-participant zeros.
        :returns: ``(x_list, z, u)`` at convergence or max-iter.
        """
        actor = self.global_actor
        penalty = self.rho
        addrs = carrier.others("coordinator")
        n = len(addrs)

        # Warm start gets copied into float arrays; cold start is all zeros.
        if x_init is None:
            x: list[np.ndarray] = [np.zeros(m) for _ in range(n)]
        else:
            x = [np.array(xi, dtype=float) for xi in x_init]
        z = actor.init_z(n, m)
        u = actor.init_u(n, m)

        for k in range(1, self.max_iters + 1):
            # 1. Fan the x-update requests out to all participants, then wait
            #    for every reply at once.
            pending = [
                carrier.send_awaitable(
                    ADMMMessage(v=actor.actor_correction(x, z, u, idx), rho=penalty),
                    addr,
                )
                for idx, addr in enumerate(addrs)
            ]
            replies = await asyncio.gather(*pending)
            aux = [getattr(reply, "aux", None) for reply in replies]
            for idx, reply in enumerate(replies):
                x[idx] = np.asarray(reply.x, dtype=float)

            # 2. z-update (keep the previous z for the dual residual).
            z_prev = _deepcopy_z(z)
            z = actor.z_update(input_data, x, u, z, penalty, n)

            # 3. u-update
            u = actor.u_update(x, u, z, penalty, n)

            # 4. Convergence check — the actor may decide by itself; ``None``
            #    means "use the eps_pri/eps_dual test below".
            r_norm = actor.primal_residual(x, z)
            s_norm = actor.dual_residual(z, z_prev, penalty)
            verdict = actor.should_stop(r_norm, s_norm, aux, self.abs_tol)
            if verdict is None:
                abs_part = np.sqrt(m * n) * self.abs_tol
                eps_pri = abs_part + self.rel_tol * max(_max_norm(x), _max_norm(z))
                eps_dual = abs_part + self.rel_tol * _max_norm(u)
                verdict = bool(r_norm < eps_pri and s_norm < eps_dual)
            if verdict:
                logger.debug("ADMM converged in %d iterations.", k)
                break

            # 5. Optional penalty re-balancing (no-op unless the actor adapts).
            penalty, u = actor.adapt_rho(r_norm, s_norm, penalty, u)

            if k == self.max_iters:
                logger.warning(
                    "ADMM reached max iterations (%d) without full convergence (r=%.4g, s=%.4g).",
                    self.max_iters,
                    r_norm,
                    s_norm,
                )

        # Persist the (possibly adapted) penalty for callers reusing this loop.
        self.rho = penalty
        return x, z, u
# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------
def create_admm_start(data: Any, length: int | None = None) -> ADMMStart:
    """Build an :class:`ADMMStart` message for *data*.

    If *length* is not given, it is taken from ``data.solution_length`` when
    that attribute exists, otherwise from ``len(data.target)`` (for
    :class:`.sharing_admm.ADMMSharingData`).

    :raises ValueError: If *length* is omitted and *data* has neither a
        ``solution_length`` nor a ``target`` attribute.
    """
    if length is None:
        # Inference order matters: an explicit attribute wins over the target.
        if hasattr(data, "solution_length"):
            length = data.solution_length
        elif hasattr(data, "target"):
            length = len(data.target)
        else:
            raise ValueError(
                "Cannot infer solution_length; pass it explicitly as the second argument."
            )
    return ADMMStart(data=data, solution_length=length)