pympcc.solve_jax

pympcc.solve_jax(parametric, theta, *, x0=None, strategy='scholtes', **solve_kwargs)

Solve a parametric MPCC (mathematical program with complementarity constraints) and return the solution x*, differentiable with respect to θ.
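The differentiability of x* in θ rests, under regularity assumptions, on the implicit function theorem. The block below sketches the standard reverse-mode (adjoint) formulas. It assumes that the KKT conditions of the refined TNLP at the solution are written as F(z, θ) = 0, with z collecting x and the multipliers, and that ∂F/∂z is nonsingular there. The exact system pympcc assembles is not spelled out on this page.

```latex
F\big(z^*(\theta), \theta\big) = 0,
\qquad
\frac{\partial z^*}{\partial \theta}
  = -\left(\frac{\partial F}{\partial z}\right)^{-1} \frac{\partial F}{\partial \theta},
\\[6pt]
\left(\frac{\partial F}{\partial z}\right)^{\top} \lambda = \bar{z},
\qquad
\bar{\theta}
  = \left(\frac{\partial z^*}{\partial \theta}\right)^{\top} \bar{z}
  = -\left(\frac{\partial F}{\partial \theta}\right)^{\top} \lambda .
```

Here \bar{z} = (\bar{x}, 0) is the incoming cotangent, because only x* is returned. A single linear solve for λ therefore yields the full gradient \bar{\theta}, whatever the dimension of θ.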

Wraps pympcc.solve() in a jax.custom_vjp() whose backward rule is an sIPOPT-style adjoint solve. The wrapper always sets tnlp_refine=True, overriding any value passed in solve_kwargs, so that the backward pass has access to certified TNLP multipliers.
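The following self-contained sketch shows the same custom_vjp pattern on a hypothetical toy problem. It is not pympcc's implementation. A NumPy Newton solver, standing in for IPOPT, finds x with x + exp(x) = θ. The backward rule then solves the adjoint system (∂F/∂x)ᵀ λ = x̄ and returns θ̄ = λ, since ∂F/∂θ = -I for this residual. x0 receives a None cotangent, which means it is treated as a constant. Unlike pympcc.solve_jax, which the Notes describe as not jittable, this sketch routes the host call through jax.pure_callback. All helper names (_newton_solve, solve_toy, ...) are illustrative.

```python
import numpy as np
import jax
import jax.numpy as jnp


def _newton_solve(theta, x0):
    """Hypothetical black-box solver (stand-in for IPOPT): x + exp(x) = theta."""
    x = x0.copy()
    for _ in range(50):
        residual = x + np.exp(x) - theta
        if np.max(np.abs(residual)) < 1e-12:
            break
        x = x - residual / (1.0 + np.exp(x))  # elementwise Newton step
    return x


def _solve_host(theta, x0):
    # Receives concrete arrays; solve in float64, return in theta's dtype.
    theta, x0 = np.asarray(theta), np.asarray(x0)
    x = _newton_solve(theta.astype(np.float64), x0.astype(np.float64))
    return x.astype(theta.dtype)


def _call_solver(theta, x0):
    out_type = jax.ShapeDtypeStruct(theta.shape, theta.dtype)
    return jax.pure_callback(_solve_host, out_type, theta, x0)


@jax.custom_vjp
def solve_toy(theta, x0):
    return _call_solver(theta, x0)


def _solve_toy_fwd(theta, x0):
    x = _call_solver(theta, x0)
    return x, x  # keep x* as the residual for the backward pass


def _solve_toy_bwd(x, x_bar):
    # Stationarity residual F(x, theta) = x + exp(x) - theta.
    dF_dx = jnp.eye(x.shape[0]) + jnp.diag(jnp.exp(x))
    lam = jnp.linalg.solve(dF_dx.T, x_bar)  # adjoint solve
    theta_bar = lam  # -(dF/dtheta)^T lam, with dF/dtheta = -I
    return theta_bar, None  # None: no cotangent for x0 (treated as constant)


solve_toy.defvjp(_solve_toy_fwd, _solve_toy_bwd)

theta = jnp.array([0.5, 1.0, 2.0])
x0 = jnp.zeros(3)
x_star = solve_toy(theta, x0)
theta_grad = jax.grad(lambda t: jnp.sum(solve_toy(t, x0) ** 2))(theta)
```

For this toy problem, dx*/dθ = 1/(1 + exp(x*)) elementwise. The gradient of sum(x*²) should therefore equal 2x*/(1 + exp(x*)), which gives a quick check of the backward rule.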

Parameters:
  • parametric (ParametricMPCC) – JAX-traceable problem definition.

  • theta (array-like (or pytree of arrays)) – Parameter input. Differentiable.

  • x0 (array-like, optional) – Initial primal point; defaults to np.zeros(parametric.n). Treated as a constant by autodiff: no gradient flows through it.

  • strategy (str) – Strategy forwarded to pympcc.solve() (default "scholtes").

  • **solve_kwargs – Additional kwargs forwarded to pympcc.solve() (e.g. ipopt_options, presolve). Not differentiated.
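theta may be a pytree of arrays. The sketch below shows one plausible way a wrapper can support this. It is an assumption, not pympcc's documented internals: θ is flattened with jax.flatten_util.ravel_pytree before it reaches a vector-only solver, and jax.grad then returns a gradient with the same pytree structure as θ. The linear "solver" (matrix A, solve_flat) is a hypothetical stand-in.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical vector-only "solver": x*(theta) solves A x = theta_flat.
A = jnp.array([[3.0, 1.0, 0.0],
               [1.0, 2.0, 0.0],
               [0.0, 0.0, 4.0]])


def solve_flat(theta_flat):
    return jnp.linalg.solve(A, theta_flat)


def solve_tree(theta_tree):
    # Flatten the pytree (dict leaves in key order) into one 1-D vector.
    theta_flat, _ = ravel_pytree(theta_tree)
    return solve_flat(theta_flat)


theta = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
grads = jax.grad(lambda t: jnp.sum(solve_tree(t)))(theta)
# grads is a dict with the same keys and leaf shapes as theta.
```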

Returns:

Optimal primal vector x*.

Return type:

jnp.ndarray, shape (n,)

Raises:

ImportError – When jax is not installed.

Notes

Not jittable: pympcc.solve() runs IPOPT, a native solver that needs concrete numerical inputs and cannot be traced by JAX. Call solve_jax inside jax.grad(...) directly, never inside jax.jit(...). Forward-mode autodiff (jax.jvp, jax.jacfwd) is not yet supported.
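The small pure-JAX illustration below reproduces both limitations with hypothetical toy functions. host_only needs concrete values via np.asarray. We assume, without the page stating it, that a call into IPOPT fails under jax.jit for the same reason. The second part shows that a jax.custom_vjp function supports reverse mode (jax.grad) but raises a TypeError under forward mode (jax.jvp).

```python
import numpy as np
import jax
import jax.numpy as jnp


def host_only(theta):
    # Hypothetical stand-in for a native solver: np.asarray needs a concrete array.
    return jnp.asarray(np.tanh(np.asarray(theta)))


theta = jnp.array([0.5])
eager_out = host_only(theta)  # works: theta holds concrete values

# Under jax.jit, theta is an abstract tracer, so the host call fails.
try:
    jax.jit(host_only)(theta)
    jit_failed = False
except jax.errors.TracerArrayConversionError:
    jit_failed = True


# A custom_vjp function defines a reverse-mode rule only.
@jax.custom_vjp
def f(t):
    return jnp.tanh(t)


def f_fwd(t):
    y = jnp.tanh(t)
    return y, y  # save the output as the residual


def f_bwd(y, g):
    return (g * (1.0 - y ** 2),)  # d tanh(t)/dt = 1 - tanh(t)^2


f.defvjp(f_fwd, f_bwd)

grad_val = jax.grad(lambda t: jnp.sum(f(t)))(theta)  # reverse mode: fine

try:
    jax.jvp(f, (theta,), (jnp.ones_like(theta),))  # forward mode
    jvp_failed = False
except TypeError:
    jvp_failed = True
```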