Source code for pympcc.bilevel

"""
Bilevel KKT-emitter frontend.

A bilevel program

    min_{x, y}  F(x, y)
    s.t.        y ∈ argmin_y { f(x, y) : g(x, y) ≤ 0,  h(x, y) = 0 }

becomes an MPCC by replacing the lower-level argmin with its KKT system::

    ∇_y f(x, y) + Σ λ_i ∇_y g_i(x, y) + Σ μ_k ∇_y h_k(x, y) = 0     (stationarity)
    h(x, y)  =  0                                                     (lower-eq)
    λ ≥ 0,    -g(x, y) ≥ 0,    λ ⊥ -g(x, y)                         (complementarity)

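The :func:`from_lower_level` emitter does this rewrite automatically and
returns a :class:`~pympcc.problem.MPCCProblem` ready for
:func:`pympcc.solve`. :func:`from_epec` applies the same lower-level KKT
rewrite to a multi-leader-common-follower EPEC (see :class:`Leader` and
:class:`LowerLevel`) and stacks every leader's stationarity conditions into
a single MPCC.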

Variable layout of the emitted MPCC::

    z = [ x_upper (n_x) | y_lower (n_y) | λ (n_g_lower) | μ (n_h_lower) ]

The KKT system characterises the lower-level argmin when the lower-level
problem is convex in ``y`` and satisfies a constraint qualification (e.g.
Slater's condition). For non-convex lower-level problems the resulting MPCC
characterises KKT (stationary) points of the lower problem rather than its
global argmin.

Composes naturally with §4.5 (variable-paired complementarity) — λ enters
as a non-negative slice of ``z`` and the complementarity row ``H_i`` is
``−g_i(x, y)``.
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Optional

import numpy as np

from ._typing import Derivatives
from .problem import MPCCProblem

__all__ = ["from_lower_level", "Leader", "LowerLevel", "from_epec"]


def from_lower_level(
    *,
    n_x: int,
    n_y: int,
    x0: np.ndarray,
    y0: np.ndarray,
    f_upper: Callable[[np.ndarray, np.ndarray], float],
    f_lower: Callable[[np.ndarray, np.ndarray], float],
    n_g_lower: int = 0,
    g_lower: Optional[Callable[[np.ndarray, np.ndarray], np.ndarray]] = None,
    n_h_lower: int = 0,
    h_lower: Optional[Callable[[np.ndarray, np.ndarray], np.ndarray]] = None,
    derivatives: Derivatives = "jax",
    xl: Optional[np.ndarray] = None,
    xu: Optional[np.ndarray] = None,
    yl: Optional[np.ndarray] = None,
    yu: Optional[np.ndarray] = None,
    lambda0: Optional[np.ndarray] = None,
    mu0: Optional[np.ndarray] = None,
) -> MPCCProblem:
    """Emit an :class:`~pympcc.problem.MPCCProblem` from a bilevel program.

    The lower-level ``argmin`` is replaced by its KKT system: stationarity of
    the lower-level Lagrangian in ``y``, the lower-level equalities
    ``h(x, y) = 0``, and the complementarity pair ``λ ⊥ -g(x, y)``.

    Parameters
    ----------
    n_x, n_y : int
        Sizes of the upper-level and lower-level decision-variable blocks.
        Both must be ``≥ 0``.
    x0, y0 : array-like
        Initial guesses for ``x`` and ``y``. They are flattened and must have
        ``n_x`` and ``n_y`` entries respectively.
    f_upper : callable
        Upper-level objective ``F(x, y) -> float``.
    f_lower : callable
        Lower-level objective ``f(x, y) -> float``.
    n_g_lower : int
        Number of lower-level inequality constraints ``g(x, y) ≤ 0``. Must be
        ``≥ 1`` (a bilevel without lower-level inequalities is not an MPCC;
        solve it as a regular NLP). The signature default of ``0`` is
        therefore always rejected, so pass this argument explicitly.
    g_lower : callable
        ``g(x, y) -> ndarray, shape (n_g_lower,)``. Required.
    n_h_lower : int
        Number of lower-level equality constraints ``h(x, y) = 0`` (default 0).
    h_lower : callable, optional
        ``h(x, y) -> ndarray, shape (n_h_lower,)``. Required when
        ``n_h_lower > 0`` and must be omitted when ``n_h_lower == 0``.
    derivatives : {"jax", "fd"}, default "jax"
        Backend used both to form the stationarity rows and to fill every
        Jacobian on the emitted MPCC. ``"jax"`` requires the user-supplied
        callables to be ``jax.numpy``-traceable and is dramatically more
        accurate (single-level autodiff) than the ``"fd"`` fallback (nested
        finite differences).
    xl, xu, yl, yu : array-like, optional
        Bounds on the upper- and lower-level variable blocks (default
        ``±inf``). The λ block is automatically lower-bounded at ``0``; μ
        stays free.
    lambda0, mu0 : array-like, optional
        Initial multipliers for ``λ`` and ``μ`` (default zeros). ``lambda0``
        must be non-negative.

    Returns
    -------
    MPCCProblem
        With ``n = n_x + n_y + n_g_lower + n_h_lower``,
        ``n_comp = n_g_lower``, ``n_eq = n_y + n_h_lower``. The variable
        vector is laid out as ``z = [x | y | λ | μ]``.

    Raises
    ------
    ValueError
        Raised for any of the following:

        * an unknown ``derivatives`` backend;
        * ``n_g_lower < 1`` or a missing ``g_lower``;
        * ``n_h_lower < 0``, or ``h_lower`` / ``n_h_lower`` that disagree;
        * a negative ``n_x`` or ``n_y``;
        * initial points, multipliers or bounds with the wrong shape;
        * a negative entry in ``lambda0``.
    ImportError
        If ``derivatives="jax"`` and JAX is not installed.

    Notes
    -----
    No lower-level inequality slacks are introduced. The complementarity
    pair lives directly between ``λ`` and ``-g(x, y)``.

    The stationarity rows depend on ``y, λ, μ`` through the lower-level
    Lagrangian's second-order partials. The MPCC's own Jacobian computation
    therefore involves second derivatives of ``f_lower``, ``g_lower``, and
    ``h_lower``. With ``derivatives="jax"``, JAX takes those derivatives
    exactly. With ``"fd"``, they are computed by nested finite differences
    (acceptable for prototyping, noisy on tight tolerances).

    With ``derivatives="fd"`` every output of ``g_lower`` and ``h_lower`` is
    passed through ``np.asarray``. Those callables may therefore return any
    array-like (list, tuple, integer array).
    """
    if derivatives not in ("jax", "fd"):
        raise ValueError(
            f"derivatives must be 'jax' or 'fd', got {derivatives!r}"
        )
    if n_g_lower < 1:
        raise ValueError(
            "n_g_lower must be >= 1; a bilevel program without lower-level "
            "inequality constraints is not an MPCC — solve it as a regular NLP."
        )
    if g_lower is None:
        raise ValueError("g_lower is required when n_g_lower > 0")
    if n_h_lower < 0:
        raise ValueError("n_h_lower must be >= 0")
    if (n_h_lower > 0) != (h_lower is not None):
        raise ValueError(
            "h_lower and n_h_lower must agree: pass both or neither"
        )
    # Without these checks a negative block size only surfaces later as a
    # confusing "x0 must have shape (-1,)" message.
    if n_x < 0:
        raise ValueError(f"n_x must be >= 0, got {n_x}")
    if n_y < 0:
        raise ValueError(f"n_y must be >= 0, got {n_y}")

    n_lam = int(n_g_lower)
    n_mu = int(n_h_lower)
    n = n_x + n_y + n_lam + n_mu

    x0_arr = np.asarray(x0, dtype=float).ravel()
    y0_arr = np.asarray(y0, dtype=float).ravel()
    if x0_arr.shape != (n_x,):
        raise ValueError(f"x0 must have shape ({n_x},), got {x0_arr.shape}")
    if y0_arr.shape != (n_y,):
        raise ValueError(f"y0 must have shape ({n_y},), got {y0_arr.shape}")

    if lambda0 is None:
        lam0 = np.zeros(n_lam)
    else:
        lam0 = np.asarray(lambda0, dtype=float).ravel()
        if lam0.shape != (n_lam,):
            raise ValueError(
                f"lambda0 must have shape ({n_lam},), got {lam0.shape}"
            )
        if np.any(lam0 < 0.0):
            raise ValueError("lambda0 must be non-negative (λ ≥ 0)")

    if mu0 is None:
        mu0_arr = np.zeros(n_mu)
    else:
        mu0_arr = np.asarray(mu0, dtype=float).ravel()
        if mu0_arr.shape != (n_mu,):
            raise ValueError(
                f"mu0 must have shape ({n_mu},), got {mu0_arr.shape}"
            )

    z0 = np.concatenate([x0_arr, y0_arr, lam0, mu0_arr])

    def _resolve_bound(
        b: Optional[np.ndarray], default: float, m: int, name: str
    ) -> np.ndarray:
        """Return ``b`` as a length-``m`` float vector, or ``default`` fill."""
        if b is None:
            return np.full(m, default)
        arr = np.asarray(b, dtype=float).ravel()
        if arr.shape != (m,):
            raise ValueError(f"{name} must have shape ({m},), got {arr.shape}")
        return arr

    xl_full = np.concatenate([
        _resolve_bound(xl, -np.inf, n_x, "xl"),
        _resolve_bound(yl, -np.inf, n_y, "yl"),
        np.zeros(n_lam),            # λ ≥ 0
        np.full(n_mu, -np.inf),     # μ free
    ])
    xu_full = np.concatenate([
        _resolve_bound(xu, np.inf, n_x, "xu"),
        _resolve_bound(yu, np.inf, n_y, "yu"),
        np.full(n_lam, np.inf),
        np.full(n_mu, np.inf),
    ])

    # Slice boundaries of each block inside z = [x | y | λ | μ].
    sx0, sx1 = 0, n_x
    sy0, sy1 = n_x, n_x + n_y
    sl0, sl1 = n_x + n_y, n_x + n_y + n_lam
    sm0, sm1 = n_x + n_y + n_lam, n

    def _split(z):
        """Return the ``(x, y, λ, μ)`` slices of the stacked vector ``z``."""
        return z[sx0:sx1], z[sy0:sy1], z[sl0:sl1], z[sm0:sm1]

    def objective(z: np.ndarray) -> float:
        x, y, _, _ = _split(z)
        return f_upper(x, y)

    def comp_G(z: np.ndarray) -> np.ndarray:
        # G side of each complementarity pair: the λ block itself.
        return z[sl0:sl1]

    if derivatives == "jax":
        from ._jax import HAS_JAX

        if not HAS_JAX:
            raise ImportError(
                "derivatives='jax' requires JAX. Install with "
                "`pip install pympcc[jax]` or pass derivatives='fd'."
            )
        import jax
        import jax.numpy as jnp

        def comp_H(z: np.ndarray) -> np.ndarray:
            # H side of each complementarity pair: -g(x, y) ≥ 0.
            x, y, _, _ = _split(z)
            return -g_lower(x, y)

        def _lagrangian_y(x, y, lam, mu):
            """Lower-level Lagrangian ``f + λ·g + μ·h`` (differentiated in y)."""
            L = f_lower(x, y)
            if n_lam > 0:
                L = L + jnp.dot(lam, g_lower(x, y))
            if n_mu > 0:
                L = L + jnp.dot(mu, h_lower(x, y))
            return L

        _grad_y_L = jax.grad(_lagrangian_y, argnums=1)

        def stationarity(z):
            x, y, lam, mu = _split(z)
            return _grad_y_L(x, y, lam, mu)

        def eq_constraints(z):
            x, y, _lam, _mu = _split(z)
            stat = stationarity(z)
            if n_mu > 0:
                return jnp.concatenate([stat, h_lower(x, y)])
            return stat
    else:
        from ._fd import _DEFAULT_H, fd_gradient

        def comp_H(z: np.ndarray) -> np.ndarray:
            # Coerce like the other FD rows do. A list/tuple-returning
            # g_lower is accepted by the stationarity rows, and a bare
            # unary minus on it would raise TypeError.
            x, y, _, _ = _split(z)
            return -np.asarray(g_lower(x, y), dtype=float)

        def stationarity(z):
            x, y, lam, mu = _split(z)
            x_a = np.asarray(x, dtype=float)
            lam_a = np.asarray(lam, dtype=float)
            mu_a = np.asarray(mu, dtype=float)

            def _scalar_L_of_y(yy):
                """Lower-level Lagrangian as a function of ``y`` only."""
                v = float(f_lower(x_a, yy))
                if n_lam > 0:
                    v += float(np.dot(lam_a, np.asarray(g_lower(x_a, yy))))
                if n_mu > 0:
                    v += float(np.dot(mu_a, np.asarray(h_lower(x_a, yy))))
                return v

            grad_fn = fd_gradient(
                _scalar_L_of_y, n_y, h=_DEFAULT_H, mode="forward"
            )
            return grad_fn(np.asarray(y, dtype=float))

        def eq_constraints(z):
            x, y, _lam, _mu = _split(z)
            stat = np.asarray(stationarity(z), dtype=float)
            if n_mu > 0:
                hh = np.asarray(h_lower(x, y), dtype=float)
                return np.concatenate([stat, hh])
            return stat

    # Equality rows: ∇_y L_lower = 0 (n_y rows) then h(x, y) = 0 (n_mu rows).
    n_eq_total = n_y + n_mu
    return MPCCProblem(
        n=n,
        n_comp=n_lam,
        x0=z0,
        xl=xl_full,
        xu=xu_full,
        objective=objective,
        comp_G=comp_G,
        comp_H=comp_H,
        n_eq=n_eq_total,
        eq_constraints=eq_constraints,
        derivatives=derivatives,
    )
# ============================================================================ #
# EPEC multi-leader-common-follower emitter (§5.5)                              #
# ============================================================================ #


@dataclass
class Leader:
    """One leader in a multi-leader-common-follower EPEC.

    Each leader chooses its own block ``x_i ∈ ℝ^{n_x}`` to minimise its
    upper-level objective ``F_i(x, y)``. Here ``x`` is the concatenated
    leader-decision vector across *all* leaders (``x = (x_1, …, x_N)``) and
    ``y`` is the shared lower-level decision. Each ``F_i`` typically reads
    its own block ``x[i_self_slice]`` and the rivals' blocks ``x_{-i}``.

    Parameters
    ----------
    n_x : int
        Size of this leader's decision block (must be ``≥ 1``).
    F : callable
        Upper-level objective ``F_i(x_full, y) -> scalar``, where ``x_full``
        is the concatenated leader decision vector. Must be JAX-traceable
        when ``derivatives='jax'``.
    x0 : array-like, optional
        Initial guess for this leader's block (default zeros).
    """

    n_x: int
    F: Callable[[np.ndarray, np.ndarray], float]
    x0: Optional[np.ndarray] = None


@dataclass
class LowerLevel:
    """Common lower-level optimisation shared across all leaders.

    Defines ``y ∈ argmin_y { f(x, y) : g(x, y) ≤ 0, h(x, y) = 0 }``, where
    ``x`` is the concatenated leader-decision vector.

    Parameters
    ----------
    n_y : int
        Size of the lower-level decision block ``y`` (must be ``≥ 0``).
    f : callable
        Lower-level objective ``f(x_full, y) -> scalar``. Must be
        JAX-traceable when ``derivatives='jax'``.
    n_g : int
        Number of inequality constraints ``g(x, y) ≤ 0``. Must be ``≥ 1``;
        EPECs without a lower-level inequality block are not MPCCs.
    g : callable
        ``g(x_full, y) -> ndarray, shape (n_g,)``.
    y0 : array-like, optional
        Initial guess for ``y`` (default zeros).
    n_h : int, default 0
        Number of equality constraints ``h(x, y) = 0``.
    h : callable, optional
        ``h(x_full, y) -> ndarray, shape (n_h,)``. Required iff ``n_h > 0``.
    """

    n_y: int
    f: Callable[[np.ndarray, np.ndarray], float]
    n_g: int
    g: Callable[[np.ndarray, np.ndarray], np.ndarray]
    y0: Optional[np.ndarray] = None
    n_h: int = 0
    h: Optional[Callable[[np.ndarray, np.ndarray], np.ndarray]] = None


def from_epec(
    *,
    leaders: list[Leader],
    common_lower: LowerLevel,
    derivatives: Derivatives = "jax",
) -> MPCCProblem:
    r"""Emit an :class:`MPCCProblem` from a multi-leader-common-follower EPEC.

    Each leader ``i = 1, …, N`` solves an MPCC

    .. math::

        \min_{x_i}\ F_i(x, y) \quad\text{s.t.}\quad
        y \in \mathrm{argmin}_y\{f(x,y): g(x,y)\le 0,\ h(x,y)=0\}

    with ``x = (x_1, …, x_N)``. A Nash equilibrium of the EPEC is a fixed
    point of best-response: each ``x_i^*`` is optimal for leader ``i``'s MPCC
    given ``x_{-i}^*``. This emitter stacks each leader's KKT conditions,
    with the lower-level KKT system as constraints, into a single MPCC.
    Under MPCC-MFCQ at the equilibrium, the stacked KKT system is a
    necessary characterisation of the Nash point.

    Variable layout
    ---------------
    ``z = [x_1, …, x_N | y | λ_lo | μ_lo | (ξ_i^y, ξ_i^h, θ_i, ν_i)_{i=1}^N]``

    ``λ_lo`` and ``μ_lo`` are the shared lower-level multipliers for ``g``
    and ``h``. ``ξ_i^y``, ``ξ_i^h``, ``θ_i`` and ``ν_i`` are leader ``i``'s
    multipliers on the lower-level KKT rows that act as constraints in
    leader ``i``'s MPCC.

    Equality rows
    -------------
    For each leader ``i``: stationarity of the leader Lagrangian w.r.t.
    ``x_i`` (``n_{x,i}`` rows), ``y`` (``n_y``), ``λ_lo`` (``n_g``), and
    ``μ_lo`` (``n_h``). In addition there are the shared lower-level KKT
    rows ``∇_y L_lo = 0`` (``n_y``) and ``h_lo(x, y) = 0`` (``n_h``).

    Complementarity rows
    --------------------
    ``n_comp = (1 + 2N)·n_g`` pairs in total:

    * shared: ``λ_lo ⊥ −g_lo(x, y)`` (``n_g`` pairs)
    * per leader ``i``: ``θ_i ⊥ −g_lo(x, y)`` and ``ν_i ⊥ λ_lo``
      (``2·n_g`` pairs each)

    Parameters
    ----------
    leaders : list of :class:`Leader`
        ``len(leaders) >= 2``; each leader must have ``n_x >= 1``.
    common_lower : :class:`LowerLevel`
        The lower-level problem shared by all leaders.
    derivatives : {"jax"}, default "jax"
        Only ``"jax"`` is supported in v1. The leader Lagrangian involves
        the gradient of the lower-level Lagrangian w.r.t. ``y``, so the
        emitted equality block needs second-order autodiff. Pass
        ``derivatives="fd"`` to receive an explicit ``NotImplementedError``.

    Returns
    -------
    MPCCProblem
        Objective is the constant zero: EPEC equilibrium is a feasibility
        problem, and the complementarity and stationarity rows pin the
        equilibrium without help from the upper objective. The problem has
        ``n_eq = Σ n_{x,i} + N·(n_y + n_g + n_h) + n_y + n_h`` equality rows.

    Raises
    ------
    NotImplementedError
        If ``derivatives="fd"``.
    ValueError
        Raised for any of the following:

        * any other unsupported ``derivatives`` value;
        * fewer than two leaders;
        * ``common_lower.n_g < 1`` or a missing ``common_lower.g``;
        * ``common_lower.n_h < 0``, or ``h`` / ``n_h`` that disagree;
        * ``common_lower.n_y < 0``;
        * a leader with ``n_x < 1``;
        * a wrongly shaped ``x0`` / ``y0``.
    ImportError
        If JAX is not installed.

    Notes
    -----
    * No private leader constraints in v1. Every leader sees the same
      lower-level KKT system as its only constraint set. Their upper
      objectives ``F_i`` may, of course, differ.
    * The ``ν_i`` block is a primal variable of the emitted MPCC, even
      though leader ``i``'s KKT determines it from ``ξ_i^y`` via
      ``ν_i = (∇_y g_lo)·ξ_i^y``. Carrying ``ν_i`` explicitly keeps the
      Jacobian sparse and the complementarity row ``ν_i ⊥ λ_lo`` linear.
    * All multipliers (``λ_lo``, ``μ_lo``, ``ξ_i``, ``θ_i``, ``ν_i``) start
      at zero. With ``λ_lo = θ_i = ν_i = 0``, every complementarity product
      vanishes at the starting point. The sign condition
      ``−g_lo(x0, y0) ≥ 0`` on the H side still depends on the
      user-supplied ``Leader.x0`` / ``LowerLevel.y0`` and is not enforced
      here.
    """
    if derivatives == "fd":
        raise NotImplementedError(
            "from_epec v1 supports derivatives='jax' only. The emitted "
            "equality rows differentiate the lower-level Lagrangian "
            "gradient w.r.t. (x, y, λ, μ), so a second-order autodiff "
            "backend is required. FD fallback is planned for v2."
        )
    if derivatives != "jax":
        raise ValueError(
            f"derivatives must be 'jax' (got {derivatives!r})."
        )
    if not isinstance(leaders, (list, tuple)) or len(leaders) < 2:
        raise ValueError(
            "from_epec requires at least 2 leaders; for a single-leader "
            "bilevel program use from_lower_level instead."
        )

    cl = common_lower
    if cl.n_g < 1:
        raise ValueError(
            "common_lower.n_g must be >= 1; an EPEC without lower-level "
            "inequality constraints is not an MPCC."
        )
    if cl.g is None:
        raise ValueError("common_lower.g is required when n_g > 0")
    if cl.n_h < 0:
        raise ValueError("common_lower.n_h must be >= 0")
    if (cl.n_h > 0) != (cl.h is not None):
        raise ValueError(
            "common_lower.h and common_lower.n_h must agree: pass both or neither"
        )
    if cl.n_y < 0:
        raise ValueError(f"common_lower.n_y must be >= 0, got {cl.n_y}")

    # Structural size checks run before the JAX import so malformed input is
    # reported as such even on installs without JAX.
    N = len(leaders)
    n_x_per = [int(leader.n_x) for leader in leaders]
    if any(s < 1 for s in n_x_per):
        raise ValueError("each Leader must have n_x >= 1")

    from ._jax import HAS_JAX

    if not HAS_JAX:
        raise ImportError(
            "from_epec requires JAX. Install with `pip install pympcc[jax]`."
        )
    import jax
    import jax.numpy as jnp

    n_x_total = int(sum(n_x_per))
    n_y = int(cl.n_y)
    n_g = int(cl.n_g)
    n_h = int(cl.n_h)

    # z layout offsets.
    x_offsets = np.cumsum([0] + n_x_per).tolist()
    off_y = n_x_total
    off_lam = off_y + n_y
    off_mu = off_lam + n_g
    # Per-leader dual block: ξ_i^y (n_y) | ξ_i^h (n_h) | θ_i (n_g) | ν_i (n_g).
    leader_dual_size = n_y + n_h + 2 * n_g
    # N + 1 entries: leader_off[i] is where leader i's dual block starts, and
    # leader_off[N] is the total number of variables.
    leader_off = [off_mu + n_h + i * leader_dual_size for i in range(N + 1)]
    n_total = leader_off[-1]

    def _split(z):
        """Slice ``z`` into ``(x, y, λ_lo, μ_lo, per-leader duals)``."""
        x_full = z[:n_x_total]
        y = z[off_y:off_lam]
        lam_lo = z[off_lam:off_mu]
        mu_lo = z[off_mu:off_mu + n_h]
        leader_duals = []
        for i in range(N):
            base = leader_off[i]
            xi_y = z[base : base + n_y]
            xi_h = z[base + n_y : base + n_y + n_h]
            theta = z[base + n_y + n_h : base + n_y + n_h + n_g]
            nu = z[base + n_y + n_h + n_g : base + n_y + n_h + 2 * n_g]
            leader_duals.append((xi_y, xi_h, theta, nu))
        return x_full, y, lam_lo, mu_lo, leader_duals

    f_lo = cl.f
    g_lo = cl.g
    h_lo = cl.h

    # Lower-level Lagrangian gradient w.r.t. y. It is used as a stationarity
    # row in the shared lower-level KKT block and inside each leader's
    # Lagrangian, where it acts as an = 0 constraint with multiplier ξ_i^y.
    def _stat_lo(x_full, y, lam_lo, mu_lo):
        def _lag(yy):
            v = f_lo(x_full, yy)
            if n_g > 0:
                v = v + jnp.dot(lam_lo, g_lo(x_full, yy))
            if n_h > 0:
                v = v + jnp.dot(mu_lo, h_lo(x_full, yy))
            return v

        return jax.grad(_lag)(y)

    def _leader_lagrangian(i_int: int, z):
        """Leader ``i``'s MPCC Lagrangian.

        ``F_i + ξ^y·∇_y L_lo + ξ^h·h + θ·g − ν·λ_lo``; with ``H = −g`` and
        ``G = λ_lo`` this is ``F_i − θ·H − ν·G`` plus the equality terms.
        """
        x_full, y, lam_lo, mu_lo, leader_duals = _split(z)
        xi_y, xi_h, theta, nu = leader_duals[i_int]
        lag = leaders[i_int].F(x_full, y)
        lag = lag + jnp.dot(xi_y, _stat_lo(x_full, y, lam_lo, mu_lo))
        if n_h > 0:
            lag = lag + jnp.dot(xi_h, h_lo(x_full, y))  # type: ignore[misc]
        lag = lag + jnp.dot(theta, g_lo(x_full, y)) - jnp.dot(nu, lam_lo)
        return lag

    # Bind i through a default argument to avoid late-binding closures.
    _leader_grads = [
        jax.grad(lambda z, _i=i: _leader_lagrangian(_i, z)) for i in range(N)
    ]

    def eq_constraints(z):
        x_full, y, lam_lo, mu_lo, _ = _split(z)
        rows: list = []
        for i in range(N):
            grad_z = _leader_grads[i](z)
            xi, xi1 = x_offsets[i], x_offsets[i + 1]
            rows.append(grad_z[xi:xi1])                 # ∂L_i/∂x_i
            rows.append(grad_z[off_y:off_lam])          # ∂L_i/∂y
            rows.append(grad_z[off_lam:off_mu])         # ∂L_i/∂λ_lo
            if n_h > 0:
                rows.append(grad_z[off_mu:off_mu + n_h])  # ∂L_i/∂μ_lo
        rows.append(_stat_lo(x_full, y, lam_lo, mu_lo))  # ∇_y L_lo = 0
        if n_h > 0:
            rows.append(h_lo(x_full, y))                  # h_lo(x, y) = 0
        return jnp.concatenate(rows)

    def comp_G(z):
        # [λ_lo, θ_1, ν_1, …, θ_N, ν_N]; pairs index-wise with comp_H.
        _, _, lam_lo, _, leader_duals = _split(z)
        parts = [lam_lo]
        for _xi_y, _xi_h, theta, nu in leader_duals:
            parts.append(theta)
            parts.append(nu)
        return jnp.concatenate(parts)

    def comp_H(z):
        # [−g, −g, λ_lo, …, −g, λ_lo]; pairs index-wise with comp_G.
        x_full, y, lam_lo, _, _ = _split(z)
        neg_g = -g_lo(x_full, y)
        parts = [neg_g]
        for _ in range(N):
            parts.append(neg_g)
            parts.append(lam_lo)
        return jnp.concatenate(parts)

    def objective(_z):
        return 0.0

    # Initial point: user blocks for x and y, zeros for every multiplier.
    x0_parts = []
    for leader in leaders:
        if leader.x0 is None:
            x0_parts.append(np.zeros(leader.n_x))
        else:
            arr = np.asarray(leader.x0, dtype=float).ravel()
            if arr.shape != (leader.n_x,):
                raise ValueError(
                    f"Leader.x0 must have shape ({leader.n_x},), got {arr.shape}"
                )
            x0_parts.append(arr)
    if cl.y0 is None:
        y0 = np.zeros(n_y)
    else:
        y0 = np.asarray(cl.y0, dtype=float).ravel()
        if y0.shape != (n_y,):
            raise ValueError(
                f"common_lower.y0 must have shape ({n_y},), got {y0.shape}"
            )
    z0 = np.concatenate([
        *x0_parts,
        y0,
        np.zeros(n_g),
        np.zeros(n_h),
        *[np.zeros(leader_dual_size) for _ in range(N)],
    ])

    # Bounds: λ_lo ≥ 0, θ_i ≥ 0, ν_i ≥ 0; everything else free.
    xl_full = np.full(n_total, -np.inf)
    xu_full = np.full(n_total, np.inf)
    xl_full[off_lam:off_mu] = 0.0
    for i in range(N):
        base = leader_off[i]
        theta_off = base + n_y + n_h
        nu_off = theta_off + n_g
        xl_full[theta_off:theta_off + n_g] = 0.0
        xl_full[nu_off:nu_off + n_g] = 0.0

    n_eq_total = n_x_total + N * (n_y + n_g + n_h) + n_y + n_h
    n_comp_total = (1 + 2 * N) * n_g
    return MPCCProblem(
        n=n_total,
        n_comp=n_comp_total,
        x0=z0,
        xl=xl_full,
        xu=xu_full,
        objective=objective,
        comp_G=comp_G,
        comp_H=comp_H,
        n_eq=n_eq_total,
        eq_constraints=eq_constraints,
        derivatives=derivatives,
    )