Differentiable solve with jax.grad
pympcc.solve_jax is registered as a jax.custom_vjp, so JAX can differentiate through a converged MPCC solve. The forward pass calls IPOPT; the backward pass solves the KKT saddle-point system once and contracts the result with the user’s parametric callables via jax.vjp.
import warnings
import jax
import jax.numpy as jnp
import numpy as np
import pympcc
jax.config.update("jax_enable_x64", True)
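The forward/backward split described above can be illustrated on a much smaller problem. The following is a minimal sketch, not pympcc's implementation: it assumes a hypothetical equality-constrained quadratic program (the names `solve_toy`, `toy_objective`, `Q`, `A` and `b` are invented for illustration), uses a dense linear solve in place of IPOPT, and performs one adjoint KKT solve followed by a `jax.vjp` contraction on the user-style objective callable.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical toy problem (NOT pympcc): minimise toy_objective(x, theta) s.t. A x = b.
Q = jnp.array([[2.0, 0.5], [0.5, 1.0]])
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])

def toy_objective(x, theta):
    return 0.5 * x @ Q @ x - theta @ x

def kkt_matrix():
    # Symmetric saddle-point matrix [[Q, A^T], [A, 0]].
    m = A.shape[0]
    return jnp.block([[Q, A.T], [A, jnp.zeros((m, m))]])

@jax.custom_vjp
def solve_toy(theta):
    # Forward: solve the KKT system for (x, lambda) and return x only.
    rhs = jnp.concatenate([theta, b])
    return jnp.linalg.solve(kkt_matrix(), rhs)[: Q.shape[0]]

def solve_toy_fwd(theta):
    x_star = solve_toy(theta)
    return x_star, (x_star, theta)

def solve_toy_bwd(residuals, x_bar):
    x_star, theta = residuals
    m = A.shape[0]
    # One adjoint solve with the (symmetric) KKT matrix.
    adjoint = jnp.linalg.solve(kkt_matrix(), jnp.concatenate([x_bar, jnp.zeros(m)]))
    u = adjoint[: x_star.size]
    # Contract -u with d(grad_x f)/d(theta) via jax.vjp on the user callable.
    grad_x = lambda th: jax.grad(toy_objective, argnums=0)(x_star, th)
    _, pullback = jax.vjp(grad_x, theta)
    (theta_bar,) = pullback(-u)
    return (theta_bar,)

solve_toy.defvjp(solve_toy_fwd, solve_toy_bwd)

toy_loss = lambda th: jnp.sum(solve_toy(th) ** 2)
print(jax.grad(toy_loss)(jnp.array([0.3, -0.7])))
```

Because the stationarity condition is differentiated implicitly, the backward pass never needs to trace through the solver's iterations; only the converged point and the problem callables are used.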
A parametric branch-selection MPCC
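Two complementarity pairs, with objective targets shifted by \(\theta\) (the formulation encoded by the callables in the code that follows):
\[
\min_{x \in \mathbb{R}^4} \; \tfrac{1}{2} \lVert x - \theta \rVert^2
\quad \text{s.t.} \quad
0 \le x_0 \perp x_2 \ge 0, \qquad 0 \le x_1 \perp x_3 \ge 0.
\]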
Within each pair at most one variable can be nonzero, and \(\theta\) determines which one. For the pair \((x_0, x_2)\): if \(\theta_0 > \max(\theta_2, 0)\) then \(x_0 = \theta_0\) and \(x_2 = 0\); if \(\theta_2 > \max(\theta_0, 0)\) then \(x_0 = 0\) and \(x_2 = \theta_2\); if neither target is positive, both are zero (and similarly for the pair \((x_1, x_3)\)). The optimum varies non-smoothly in \(\theta\) at branch boundaries, which is exactly where MPCC differentiation is interesting.
pmpcc = pympcc.ParametricMPCC(
    n=4, n_comp=2,
    objective=lambda x, theta: 0.5 * jnp.sum((x - theta) ** 2),
    comp_G=lambda x, theta: x[:2],
    comp_H=lambda x, theta: x[2:],
)
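Because the objective is separable and each pair involves only two variables, this problem can also be solved by enumerating the two branches of each pair. The sketch below is a reference calculation in plain NumPy, independent of pympcc; it assumes the usual convention \(0 \le G(x) \perp H(x) \ge 0\), and the helper names `solve_pair` and `x_star_reference` are hypothetical. Ties between branches are broken arbitrarily in favour of the first variable.

```python
import numpy as np

def solve_pair(t_g, t_h):
    """Minimise 0.5*((g - t_g)**2 + (h - t_h)**2) subject to 0 <= g, 0 <= h, g*h = 0."""
    # Branch A: h = 0 and g is the target clipped at zero.
    cand_a = (max(t_g, 0.0), 0.0)
    # Branch B: g = 0 and h is the target clipped at zero.
    cand_b = (0.0, max(t_h, 0.0))
    cost = lambda g, h: 0.5 * ((g - t_g) ** 2 + (h - t_h) ** 2)
    # Keep the cheaper branch (branch A on ties, an arbitrary choice).
    return cand_a if cost(*cand_a) <= cost(*cand_b) else cand_b

def x_star_reference(theta):
    theta = np.asarray(theta, dtype=float)
    x0, x2 = solve_pair(theta[0], theta[2])  # pair (x_0, x_2)
    x1, x3 = solve_pair(theta[1], theta[3])  # pair (x_1, x_3)
    return np.array([x0, x1, x2, x3])

print(x_star_reference([2.0, 0.0, 0.0, 3.0]))
```

A reference like this is useful for sanity-checking solver output on small instances, since it does not depend on a starting point or on solver tolerances.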
Forward solve
theta0 = jnp.array([2.0, 0.0, 0.0, 3.0])
x_star = pympcc.solve_jax(pmpcc, theta0, x0=np.asarray(theta0))
print("x* =", np.asarray(x_star))
x* = [2.00000000e+00 8.04201683e-10 1.72455387e-09 3.00000000e+00]
<string>:38: UserWarning: JAX autodiff active for: gradient, comp_G_jacobian, comp_H_jacobian. Ensure all primal callables are JAX-differentiable.
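At \(\theta = (2, 0, 0, 3)\) the active branches pin \(x_1\) and \(x_2\) to 0 (the printed values of order \(10^{-9}\) are zero to solver tolerance); the free coordinates \(x_0\) and \(x_3\) take their unconstrained optima \(\theta_0 = 2\) and \(\theta_3 = 3\).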
Gradient via jax.grad
def loss(theta):
    x = pympcc.solve_jax(pmpcc, theta, x0=np.asarray(theta0))
    return 0.5 * jnp.sum(x ** 2)

with warnings.catch_warnings():
    warnings.simplefilter("ignore", UserWarning)
    g = jax.grad(loss)(theta0)
print("∇_θ L =", np.asarray(g))
∇_θ L = [ 2.00000001 -0. -0. 3. ]
Finite-difference verification
def fd_grad(loss_fn, theta, eps=1e-3):
    theta = np.asarray(theta, dtype=float)
    g = np.zeros_like(theta)
    for k in range(theta.size):
        tp = theta.copy(); tp[k] += eps
        tm = theta.copy(); tm[k] -= eps
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            lp = float(loss_fn(jnp.asarray(tp)))
            lm = float(loss_fn(jnp.asarray(tm)))
        g[k] = (lp - lm) / (2 * eps)
    return g

g_fd = fd_grad(loss, theta0)
print("∇_θ L (FD) =", g_fd)
print("max abs error:", float(np.max(np.abs(np.asarray(g) - g_fd))))
∇_θ L (FD) = [ 2.00000003e+00 -8.04742051e-07 4.28401359e-06 3.00000003e+00]
max abs error: 4.284013588318203e-06
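The agreement above relies on \(\theta\) sitting strictly inside one branch. At a branch boundary, one-sided differences can disagree and a central difference averages them into a number that matches neither. The sketch below illustrates this on a single complementarity pair using a closed-form solution in plain NumPy rather than pympcc; it assumes the convention \(0 \le g \perp h \ge 0\), and `pair_solution` and `pair_loss` are hypothetical helper names.

```python
import numpy as np

def pair_solution(t_g, t_h):
    # Closed-form minimiser of 0.5*((g - t_g)**2 + (h - t_h)**2)
    # subject to 0 <= g, 0 <= h, g*h = 0 (assumed convention).
    target = np.array([t_g, t_h])
    cand_a = np.array([max(t_g, 0.0), 0.0])  # branch with h = 0
    cand_b = np.array([0.0, max(t_h, 0.0)])  # branch with g = 0
    cost_a = 0.5 * np.sum((cand_a - target) ** 2)
    cost_b = 0.5 * np.sum((cand_b - target) ** 2)
    return cand_a if cost_a <= cost_b else cand_b

def pair_loss(t_g, t_h):
    # Same outer loss as in the example: 0.5 * ||x*||^2.
    return 0.5 * np.sum(pair_solution(t_g, t_h) ** 2)

eps = 1e-3
t_g, t_h = 1.0, 1.0  # equal positive targets: a branch boundary
forward = (pair_loss(t_g + eps, t_h) - pair_loss(t_g, t_h)) / eps
backward = (pair_loss(t_g, t_h) - pair_loss(t_g - eps, t_h)) / eps
central = (pair_loss(t_g + eps, t_h) - pair_loss(t_g - eps, t_h)) / (2 * eps)
print("forward:", forward, "backward:", backward, "central:", central)
```

Here the forward difference is close to 1, the backward difference is 0, and the central difference lands near 0.5, so a mismatch between an analytic gradient and a central difference at such a point does not by itself indicate a bug.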
Full Jacobian via per-row jax.grad
jax.jacrev is not supported in Phase 1 (its internal vmap cannot trace the NumPy/IPOPT backward). Build the Jacobian one row at a time instead:
def x_i(theta, i):
    return pympcc.solve_jax(pmpcc, theta, x0=np.asarray(theta0))[i]

with warnings.catch_warnings():
    warnings.simplefilter("ignore", UserWarning)
    rows = [jax.grad(lambda th, i=i: x_i(th, i))(theta0) for i in range(4)]
J = np.stack([np.asarray(r) for r in rows], axis=0)
print("dx*/dθ =")
print(J)
dx*/dθ =
[[ 1.00000001 -0. -0. -0. ]
[-0. -0. -0. -0. ]
[-0. -0. -0. -0. ]
[-0. -0. -0. 1. ]]
For this \(\theta\), the active branches give \(x^* = (\theta_0, 0, 0, \theta_3)\), so the Jacobian is \(\operatorname{diag}(1, 0, 0, 1)\).
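The per-row loop calls `jax.grad` once per output, and each call repeats the forward evaluation. A possible alternative is to call `jax.vjp` once and pull back one-hot cotangents, which uses the same reverse-mode machinery as `jax.grad` and does not involve `vmap`. The sketch below demonstrates the pattern on a hypothetical stand-in function, `x_star_stand_in`, that hard-codes the active branch from this example; whether `pympcc.solve_jax` can be dropped in unchanged, and whether its backward pass can be reused across several cotangents, is an assumption that has not been checked here.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def x_star_stand_in(theta):
    # Hypothetical stand-in for pympcc.solve_jax: on the branch active at
    # theta0 the solution is x* = (theta_0, 0, 0, theta_3).
    mask = jnp.array([1.0, 0.0, 0.0, 1.0])
    return mask * theta

theta0 = jnp.array([2.0, 0.0, 0.0, 3.0])

# One forward evaluation; vjp_fn maps an output cotangent to an input cotangent.
x_star, vjp_fn = jax.vjp(x_star_stand_in, theta0)

# Pulling back the i-th unit vector yields row i of dx*/dtheta.
rows = [vjp_fn(e)[0] for e in jnp.eye(x_star.size)]
J = jnp.stack(rows, axis=0)
print(J)
```

For a solver-backed function the saving would come from running the forward solve once instead of once per row, while still performing one backward pass per row.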