pympcc.solve_jax
- pympcc.solve_jax(parametric, theta, *, x0=None, strategy='scholtes', **solve_kwargs)
Solve a parametric MPCC; returns x* differentiable in θ.
Wraps pympcc.solve() in a jax.custom_vjp() whose backward rule is the sIPOPT-style adjoint solve. tnlp_refine=True is forced (overridden if explicitly supplied) so the backward pass has access to certified TNLP multipliers.
- Parameters:
parametric (ParametricMPCC) – JAX-traceable problem definition.
theta (array-like (or pytree of arrays)) – Parameter input. Differentiable.
x0 (array-like, optional) – Initial primal point; defaults to np.zeros(parametric.n). Treated as a constant by the autodiff.
strategy (str) – Strategy forwarded to pympcc.solve() (default "scholtes").
**solve_kwargs – Additional kwargs forwarded to pympcc.solve() (e.g. ipopt_options, presolve). Not differentiated.
- Returns:
Optimal primal vector x*.
- Return type:
jnp.ndarray, shape (n,)
- Raises:
ImportError – When jax is not installed.
Notes
Not jittable: pympcc.solve() runs IPOPT under the hood. Call inside jax.grad(...) directly, never inside jax.jit(...). Forward-mode autodiff is not yet supported.
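The general pattern behind this function (a host-side solver wrapped in jax.custom_vjp() with an adjoint solve as the backward rule) can be illustrated with a toy problem. The sketch below is not pympcc code: the names Q, host_solve and solve_toy are hypothetical, np.linalg.solve stands in for IPOPT, and the backward rule is a plain implicit-function-theorem adjoint for an unconstrained quadratic, which is only an analogue of the sIPOPT-style rule described above.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical toy problem (not part of pympcc):
#   x*(theta) = argmin_x 0.5 * x^T Q x - theta^T x
# Stationarity condition: F(x, theta) = Q x - theta = 0.
Q = np.array([[3.0, 1.0],
              [1.0, 2.0]])


def host_solve(theta):
    """Stand-in for an external, non-traceable solver such as IPOPT."""
    return np.linalg.solve(Q, np.asarray(theta))


@jax.custom_vjp
def solve_toy(theta):
    # Primal evaluation: runs on the host with concrete values.
    return jnp.asarray(host_solve(theta))


def solve_toy_fwd(theta):
    x_star = jnp.asarray(host_solve(theta))
    # No residuals are needed here because Q is a global constant.
    return x_star, None


def solve_toy_bwd(_residuals, x_bar):
    # Adjoint solve: (dF/dx)^T lam = x_bar, with dF/dx = Q.
    lam = np.linalg.solve(Q.T, np.asarray(x_bar))
    # theta_bar = -(dF/dtheta)^T lam = lam, because dF/dtheta = -I.
    return (jnp.asarray(lam),)


solve_toy.defvjp(solve_toy_fwd, solve_toy_bwd)


def loss(theta):
    # Any differentiable function of the optimal point.
    return jnp.sum(solve_toy(theta) ** 2)


theta = jnp.array([1.0, -2.0])
x_star = solve_toy(theta)
# Called under jax.grad directly, without jax.jit: the host solver
# needs concrete arrays and cannot be traced.
grad_theta = jax.grad(loss)(theta)
```

As with solve_jax, wrapping loss in jax.jit(...) would hand tracers to the host solver and fail, which is why the gradient is taken in eager mode.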