"""JAX-differentiable MPCC solve via ``jax.custom_vjp`` (§5.6).
Differentiates through a converged MPCC solution by solving the adjoint
KKT system once (sIPOPT-style implicit differentiation). The forward
pass runs :func:`pympcc.solve` with ``tnlp_refine=True`` to obtain a
certified active set; the backward pass solves
``K · [u; w] = [v; 0]`` and contracts ``u``, ``w`` with
``∂(∇_xL)/∂θ`` and ``∂c/∂θ`` (computed via :func:`jax.vjp`) to produce
the ``θ`` cotangent.
Usage
-----
>>> import jax, jax.numpy as jnp
>>> from pympcc import ParametricMPCC, solve_jax
>>> pmpcc = ParametricMPCC(
... n=2, n_comp=1,
... objective=lambda x, theta: 0.5 * jnp.sum((x - theta) ** 2),
... comp_G=lambda x, theta: x[:1],
... comp_H=lambda x, theta: x[1:],
... )
>>> def loss(theta):
... x_star = solve_jax(pmpcc, theta)
... return jnp.sum(x_star)
>>> grad = jax.grad(loss)(jnp.array([1.0, 2.0]))
Caveats
-------
* **Not jittable.** :func:`pympcc.solve` is opaque NumPy + IPOPT code;
invoke ``jax.grad(loss)(theta)`` outside any :func:`jax.jit` wrapper.
* **No θ-dependent bounds.** ``xl`` / ``xu`` are constants of the solve;
no derivative flows through them. Use soft constraints for parametric
bounds.
* **Forward-mode deferred.** Only ``custom_vjp`` is registered;
``jax.jvp`` / ``jax.jacfwd`` are not yet supported.
Adjoint identity
----------------
At a TNLP-refined solution the implicit-function theorem gives::
K [dx_dθ; dλ_dθ] = -[∂(∇_xL)/∂θ; ∂c/∂θ]
For a cotangent ``v`` on ``x*`` the VJP becomes (``K`` is symmetric)::
K [u; w] = [v; 0]
θ̄ = -(u^T ∂(∇_xL)/∂θ + w^T ∂c/∂θ)
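Both contractions are obtained via two :func:`jax.vjp` calls, and the
backward helper itself is written for an arbitrary ``θ`` pytree. Note,
however, that :func:`solve_jax` converts ``θ`` with
``jnp.asarray(theta, dtype=float)`` before solving, so the public entry
point expects an array-like ``θ`` (of any shape).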
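Skipped (returns zero ``θ``-cotangent with a ``UserWarning``):

* the forward solve did not converge,
* biactive pairs at ``x*`` (IFT prerequisites fail).

Degraded (the adjoint is still solved, with a ``UserWarning``):

* TNLP refinement unavailable or unsuccessful (large biactive set, IPOPT
  restoration failure, etc.): the Lagrangian Hessian and
  ``∂(∇_xL)/∂θ`` are built with zero multipliers, so constraint
  curvature and constraint ``θ``-sensitivity of ``∇_x c`` are ignored
  and the gradient may be inexact;
* singular KKT matrix: the adjoint system is solved by least squares
  on a slightly diagonally-shifted ``K``.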
"""
from __future__ import annotations
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Optional
import numpy as np
from ._constants import BIACTIVE_TOL as _BIACTIVE_TOL
from ._diagnostics import _stack_active_gradient_matrix, active_sets
from ._jax import HAS_JAX
from ._sosc import _build_hessian
from ._typing import StrategyName
from .problem import MPCCProblem
from .result import MPCCResult
__all__ = ["ParametricMPCC", "solve_jax"]
@dataclass
class ParametricMPCC:
    """JAX-traceable MPCC parameterised by an external vector ``θ``.

    Mirrors :class:`pympcc.MPCCProblem` but every callable takes
    ``(x, θ)`` instead of just ``x``. All callables must be
    JAX-differentiable (use :mod:`jax.numpy` operations exclusively).

    Attributes
    ----------
    n : int
        Number of decision variables (must be ``>= 1``).
    n_comp : int
        Number of complementarity pairs (must be ``>= 1``).
    objective : callable
        ``f(x, θ) -> scalar``.
    comp_G, comp_H : callable
        ``G(x, θ) -> (n_comp,)`` / ``H(x, θ) -> (n_comp,)``.
    eq_constraints, ineq_constraints : callable, optional
        ``h(x, θ) -> (n_eq,)`` / ``g(x, θ) -> (n_ineq,)``.
    n_eq, n_ineq : int
        Constraint counts (default ``0``); must be non-negative, and a
        positive count requires the matching callable.
    xl, xu : array-like, optional
        Variable bounds; treated as constants by the autodiff pass.

    Raises
    ------
    ValueError
        On construction, when ``n`` or ``n_comp`` is below 1, when
        ``n_eq`` / ``n_ineq`` is negative, or when a positive constraint
        count is given without its callable.
    """

    n: int
    n_comp: int
    objective: Callable
    comp_G: Callable
    comp_H: Callable
    eq_constraints: Optional[Callable] = None
    ineq_constraints: Optional[Callable] = None
    n_eq: int = 0
    n_ineq: int = 0
    xl: Optional[np.ndarray] = None
    xu: Optional[np.ndarray] = None

    def __post_init__(self) -> None:
        """Validate problem dimensions and constraint callables."""
        if self.n < 1:
            raise ValueError("ParametricMPCC: n must be >= 1")
        if self.n_comp < 1:
            raise ValueError("ParametricMPCC: n_comp must be >= 1")
        # Negative counts are meaningless; without these checks they either
        # slipped through (callable supplied) or hit the misleading
        # "required when n_eq > 0" message below.
        if self.n_eq < 0:
            raise ValueError("ParametricMPCC: n_eq must be >= 0")
        if self.n_ineq < 0:
            raise ValueError("ParametricMPCC: n_ineq must be >= 0")
        if self.n_eq and self.eq_constraints is None:
            raise ValueError("eq_constraints required when n_eq > 0")
        if self.n_ineq and self.ineq_constraints is None:
            raise ValueError("ineq_constraints required when n_ineq > 0")

    def materialise(
        self,
        theta_np: np.ndarray,
        x0: np.ndarray,
    ) -> MPCCProblem:
        """Build a numeric :class:`MPCCProblem` with ``θ`` baked into closures.

        Sets ``derivatives="jax"`` so all gradients/Jacobians are produced
        by :mod:`jax`; the closure over ``theta_np`` stays JAX-traceable
        in the ``x`` argument because ``theta_np`` is a constant.

        Parameters
        ----------
        theta_np : np.ndarray
            Parameter value; converted once with
            ``jnp.asarray(theta_np, dtype=float)`` and captured by every
            closure.
        x0 : np.ndarray
            Initial primal point, stored as a float array on the problem.

        Returns
        -------
        MPCCProblem
            Problem whose callables take only ``x``.

        Raises
        ------
        ImportError
            When :mod:`jax` is not installed.
        """
        if not HAS_JAX:
            raise ImportError(
                "JAX is required for ParametricMPCC.materialise(); "
                "install with `pip install pympcc[jax]`"
            )
        import jax.numpy as jnp

        theta_jnp = jnp.asarray(theta_np, dtype=float)

        # ``_t=theta_jnp`` binds this call's θ at definition time, so each
        # closure is a pure function of ``x``.
        def _obj(x, _t=theta_jnp):
            return self.objective(x, _t)

        def _G(x, _t=theta_jnp):
            return self.comp_G(x, _t)

        def _H(x, _t=theta_jnp):
            return self.comp_H(x, _t)

        eq_fn = None
        if self.eq_constraints is not None:
            def _eq(x, _t=theta_jnp):
                return self.eq_constraints(x, _t)
            eq_fn = _eq

        ineq_fn = None
        if self.ineq_constraints is not None:
            def _ineq(x, _t=theta_jnp):
                return self.ineq_constraints(x, _t)
            ineq_fn = _ineq

        return MPCCProblem(
            n=self.n,
            n_comp=self.n_comp,
            x0=np.asarray(x0, dtype=float),
            objective=_obj,
            comp_G=_G,
            comp_H=_H,
            n_eq=self.n_eq,
            n_ineq=self.n_ineq,
            eq_constraints=eq_fn,
            ineq_constraints=ineq_fn,
            xl=self.xl,
            xu=self.xu,
            derivatives="jax",
        )
def solve_jax(
    parametric: ParametricMPCC,
    theta: Any,
    *,
    x0: Optional[np.ndarray] = None,
    strategy: StrategyName = "scholtes",
    **solve_kwargs,
):
    """Solve a parametric MPCC; returns ``x*`` differentiable in ``θ``.

    Wraps :func:`pympcc.solve` in a :func:`jax.custom_vjp` whose backward
    rule is the sIPOPT-style adjoint solve. ``tnlp_refine=True`` is
    forced (overridden if explicitly supplied) so the bwd pass has access
    to certified TNLP multipliers.

    Parameters
    ----------
    parametric : ParametricMPCC
        JAX-traceable problem definition.
    theta : array-like
        Parameter input. Differentiable. It is converted with
        ``jnp.asarray(theta, dtype=float)`` before the solve, so it must be
        something :func:`jax.numpy.asarray` accepts (an array of any shape).
    x0 : array-like, optional
        Initial primal point of shape ``(parametric.n,)``; defaults to
        ``np.zeros(parametric.n)``. Treated as a constant by the autodiff.
    strategy : str
        Strategy forwarded to :func:`pympcc.solve` (default ``"scholtes"``).
    **solve_kwargs
        Additional kwargs forwarded to :func:`pympcc.solve`
        (e.g. ``ipopt_options``, ``presolve``). Not differentiated.
        ``tnlp_refine`` is always set to ``True``.

    Returns
    -------
    jnp.ndarray, shape (n,)
        Optimal primal vector ``x*``.

    Raises
    ------
    ImportError
        When :mod:`jax` is not installed.
    ValueError
        When ``x0`` is given with a shape other than ``(parametric.n,)``.

    Notes
    -----
    Not jittable: :func:`pympcc.solve` runs IPOPT under the hood. Call
    inside ``jax.grad(...)`` directly, never inside ``jax.jit(...)``.
    Forward-mode autodiff is not yet supported.
    """
    if not HAS_JAX:
        raise ImportError(
            "JAX is required for solve_jax; install with `pip install pympcc[jax]`"
        )
    import jax
    import jax.numpy as jnp

    from .solver import solve as _solve

    if x0 is None:
        x0_np = np.zeros(parametric.n, dtype=float)
    else:
        x0_np = np.asarray(x0, dtype=float)
        # Fail fast with a clear message instead of deep inside the solver.
        if x0_np.shape != (parametric.n,):
            raise ValueError(
                f"solve_jax: x0 must have shape ({parametric.n},), "
                f"got {x0_np.shape}"
            )

    # Force certified MPCC multipliers — the adjoint pass needs them.
    solve_kwargs = dict(solve_kwargs)
    solve_kwargs["tnlp_refine"] = True

    # JAX requires residuals to be JAX-compatible pytrees. ``MPCCProblem`` /
    # ``MPCCResult`` are not, so stash them in a closure dict and only thread
    # ``theta`` through the residual pytree. Each call to ``solve_jax``
    # constructs a fresh ``_state`` / ``_solve_op`` pair, so separate calls
    # don't clobber each other.
    _state: dict = {}

    @jax.custom_vjp
    def _solve_op(theta_):
        return _fwd(theta_)[0]

    def _fwd(theta_):
        # θ must be concrete here (hence "not jittable"): it is turned into a
        # NumPy array, baked into a numeric problem and handed to IPOPT.
        theta_np = np.asarray(theta_, dtype=float)
        problem = parametric.materialise(theta_np, x0_np)
        result = _solve(problem, strategy=strategy, **solve_kwargs)
        _state["problem"] = problem
        _state["result"] = result
        x_star = jnp.asarray(result.x, dtype=float)
        return x_star, theta_

    def _bwd(theta_res, v):
        # Reuse the problem/result recorded by the matching forward pass.
        problem = _state["problem"]
        result = _state["result"]
        v_np = np.asarray(v, dtype=float)
        theta_bar = _theta_cotangent(parametric, theta_res, problem, result, v_np)
        return (theta_bar,)

    _solve_op.defvjp(_fwd, _bwd)
    return _solve_op(jnp.asarray(theta, dtype=float))
# --------------------------------------------------------------------------- #
# Backward pass — adjoint KKT solve + jax.vjp for the parametric pjacs #
# --------------------------------------------------------------------------- #
def _theta_cotangent(
    parametric: ParametricMPCC,
    theta_jnp: Any,
    problem: MPCCProblem,
    result: MPCCResult,
    v: np.ndarray,
) -> Any:
    """Return ``θ̄ = (dx*/dθ)^T v`` via the adjoint KKT solve.

    Solves ``K [u; w] = [v; 0]`` with ``K = [[∇²_xx L, Jcᵀ], [Jc, 0]]``
    built at ``x*``, then contracts ``u`` with ``∂(∇_x L)/∂θ`` and ``w``
    with ``∂c_active/∂θ`` using two :func:`jax.vjp` calls.

    Parameters
    ----------
    parametric : ParametricMPCC
        Parametric problem whose callables are differentiated in ``θ``.
    theta_jnp : pytree of arrays
        The ``θ`` value saved as the custom-VJP residual; the returned
        cotangent is built with ``tree_map`` / :func:`jax.vjp` over it and
        so shares its tree structure.
    problem : MPCCProblem
        Materialised problem at this ``θ`` (active sets, active-constraint
        Jacobian and Lagrangian Hessian are evaluated on it).
    result : MPCCResult
        Forward-solve result. ``result.x`` is ``x*``; ``result.tnlp_refined``
        (when present and successful) supplies the multipliers.
    v : np.ndarray
        Cotangent on ``x*``. It is stacked on top of ``n_active`` zeros to
        form the adjoint right-hand side, so it is expected to have length
        ``problem.n``.

    Returns
    -------
    pytree
        ``θ̄`` with the structure of ``theta_jnp``. All-zero (with a
        ``UserWarning``) when the forward solve did not converge or ``x*``
        has biactive complementarity pairs.
    """
    import jax
    import jax.numpy as jnp
    # Zero cotangent with θ's tree structure, returned by the early exits.
    zero_theta = jax.tree_util.tree_map(
        lambda a: jnp.zeros_like(jnp.asarray(a, dtype=float)),
        theta_jnp,
    )
    # NOTE(review): stacklevel=4 is meant to skip this helper, ``_bwd`` and
    # JAX's custom-VJP frames; it depends on JAX's internal call depth.
    if not result.success:
        warnings.warn(
            "solve_jax: forward solve did not converge; θ-gradient is zero.",
            UserWarning,
            stacklevel=4,
        )
        return zero_theta
    x_star = np.asarray(result.x, dtype=float)
    sets = active_sets(result, problem, tol=_BIACTIVE_TOL)
    # Any biactive pair (index set "I_00") voids the IFT prerequisites.
    if sets["I_00"].size > 0:
        warnings.warn(
            "solve_jax: biactive pairs at x*; θ-gradient is zero. "
            "Implicit differentiation requires MPCC-LICQ (no biactive set).",
            UserWarning,
            stacklevel=4,
        )
        return zero_theta
    # ----- multipliers: TNLP-refined (preferred) → zeros (conservative) ----- #
    n_g, n_h, n_c = problem.n_ineq, problem.n_eq, problem.n_comp
    lam_g = np.zeros(n_g)
    lam_h = np.zeros(n_h)
    lam_G = np.zeros(n_c)
    lam_H = np.zeros(n_c)
    tnlp = getattr(result, "tnlp_refined", None)
    if tnlp is not None and getattr(tnlp, "success", False):
        lam_G = np.asarray(tnlp.mult_comp_G, dtype=float)
        lam_H = np.asarray(tnlp.mult_comp_H, dtype=float)
        # Constraint-block multipliers may be None when the block is absent;
        # the zero arrays above are kept in that case.
        if tnlp.mult_ineq is not None:
            lam_g = np.asarray(tnlp.mult_ineq, dtype=float)
        if tnlp.mult_eq is not None:
            lam_h = np.asarray(tnlp.mult_eq, dtype=float)
    else:
        # Fallback: proceed with zero multipliers, so constraint terms drop
        # out of both the Hessian and the ∂(∇_x L)/∂θ contraction below.
        warnings.warn(
            "solve_jax: TNLP refinement unavailable; using zero multipliers in "
            "the adjoint Hessian. The gradient is exact only when the active "
            "constraints are linear in x.",
            UserWarning,
            stacklevel=4,
        )
    # ----- KKT matrix and adjoint solve ------------------------------------ #
    # Jc: active-constraint gradients stacked as rows, shape (n_active, n).
    Jc, _ = _stack_active_gradient_matrix(problem, x_star, sets)
    n_active = int(Jc.shape[0])
    H_lag = _build_hessian(x_star, problem, lam_g, lam_h, lam_G, lam_H, problem.fd_h)
    # Saddle-point matrix K = [[H, Jcᵀ], [Jc, 0]]; it is symmetric, so the
    # adjoint system can be solved with K itself rather than Kᵀ.
    K = np.block([
        [H_lag, Jc.T],
        [Jc, np.zeros((n_active, n_active))],
    ])
    rhs = np.concatenate([v, np.zeros(n_active)])
    try:
        sol = np.linalg.solve(K, rhs)
    except np.linalg.LinAlgError:
        # NOTE(review): np.linalg.solve only raises when it detects a
        # singular K; nearly singular K is not caught here.
        warnings.warn(
            "solve_jax: KKT matrix singular; using Tikhonov-regularised "
            "pseudoinverse for the adjoint solve.",
            UserWarning,
            stacklevel=4,
        )
        # Tiny diagonal shift followed by a least-squares solve.
        K_reg = K + 1e-12 * np.eye(K.shape[0])
        sol, *_ = np.linalg.lstsq(K_reg, rhs, rcond=None)
    # u: adjoint for x (first n entries); w: adjoint for the active rows.
    u = sol[: problem.n]
    w = sol[problem.n:]
    # ----- VJPs over parametric callables at (x*, θ) ----------------------- #
    x_star_jnp = jnp.asarray(x_star, dtype=float)
    lam_g_jnp = jnp.asarray(lam_g, dtype=float)
    lam_h_jnp = jnp.asarray(lam_h, dtype=float)
    lam_G_jnp = jnp.asarray(lam_G, dtype=float)
    lam_H_jnp = jnp.asarray(lam_H, dtype=float)

    def _lagrangian(x, theta):
        # L = f + λ_G·G + λ_H·H + λ_h·h + λ_g·g.
        # NOTE(review): assumes _build_hessian uses this same sign
        # convention for the multipliers; confirm.
        L = parametric.objective(x, theta)
        L = L + jnp.sum(lam_G_jnp * parametric.comp_G(x, theta))
        L = L + jnp.sum(lam_H_jnp * parametric.comp_H(x, theta))
        if parametric.n_eq:
            L = L + jnp.sum(lam_h_jnp * parametric.eq_constraints(x, theta))
        if parametric.n_ineq:
            L = L + jnp.sum(lam_g_jnp * parametric.ineq_constraints(x, theta))
        return L

    def _grad_x_L_of_theta(theta):
        # ∇_x L at fixed x*, viewed as a function of θ only.
        return jax.grad(_lagrangian, argnums=0)(x_star_jnp, theta)

    I_g = sets["I_g"]
    I_G = sets["I_G"]
    I_H = sets["I_H"]
    n_bnd = int(np.union1d(sets["I_xL"], sets["I_xU"]).size)

    def _active_constraints(theta):
        # Active constraint values at x* as a function of θ, stacked as
        # [eq, G[I_G], H[I_H], g[I_g], bounds].
        # NOTE(review): this order must match the row order of Jc from
        # _stack_active_gradient_matrix, since w is contracted row-by-row;
        # confirm against that helper.
        parts = []
        if parametric.n_eq:
            parts.append(parametric.eq_constraints(x_star_jnp, theta))
        if I_G.size:
            parts.append(parametric.comp_G(x_star_jnp, theta)[jnp.asarray(I_G)])
        if I_H.size:
            parts.append(parametric.comp_H(x_star_jnp, theta)[jnp.asarray(I_H)])
        if I_g.size:
            parts.append(parametric.ineq_constraints(x_star_jnp, theta)[jnp.asarray(I_g)])
        if n_bnd:
            # Bounds are constants of the solve (no θ dependence).
            parts.append(jnp.zeros(n_bnd))
        if not parts:
            return jnp.zeros(0)
        return jnp.concatenate(parts)

    # uᵀ ∂(∇_x L)/∂θ
    u_jnp = jnp.asarray(u, dtype=float)
    _, vjp_grad_L = jax.vjp(_grad_x_L_of_theta, theta_jnp)
    contrib_L = vjp_grad_L(u_jnp)[0]
    # wᵀ ∂c/∂θ (skipped when there are no active constraints)
    if n_active > 0:
        w_jnp = jnp.asarray(w, dtype=float)
        _, vjp_c = jax.vjp(_active_constraints, theta_jnp)
        contrib_c = vjp_c(w_jnp)[0]
    else:
        contrib_c = zero_theta
    # θ̄ = -(uᵀ ∂(∇_x L)/∂θ + wᵀ ∂c/∂θ), applied leaf-wise over θ's tree.
    return jax.tree_util.tree_map(lambda a, b: -(a + b), contrib_L, contrib_c)