JAX API¶
The JAX layer is a callable class compatible with JAX transformations like jax.grad, jax.vmap, and jax.jit (with the Moreau solver).
CvxpyLayer¶
- class cvxpylayers.jax.CvxpyLayer[source]¶
Bases: object
A differentiable convex optimization layer for JAX.
This layer wraps a parametrized CVXPY problem, solving it in the forward pass and computing gradients via implicit differentiation. Compatible with jax.grad, jax.jit, and jax.vmap.
- JIT/vmap Compatibility:
When using solver="MOREAU", this layer is fully compatible with jax.jit, jax.vmap, and jax.pmap. The Moreau solver provides native JAX autodiff support via custom_vjp with pure_callback, enabling JIT compilation of the entire solve-differentiate pipeline.
Other solvers (DIFFCP) use Python-based solving and are not JIT-compatible due to closure-based gradient handling.
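For intuition only, the general custom_vjp-plus-pure_callback pattern looks roughly like the sketch below. This is not the library's actual implementation: the host-side "solver" (a nonnegative projection) and its gradient rule are illustrative placeholders, chosen only to show how a host routine can run under jit while autodiff uses a hand-written backward pass.
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical host-side "solver": project onto the nonnegative orthant.
def _solve_on_host(x):
    return np.maximum(x, 0.0)

def _solve(x):
    # pure_callback lets the NumPy routine run inside jit-compiled code.
    return jax.pure_callback(_solve_on_host, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

@jax.custom_vjp
def solve(x):
    return _solve(x)

def solve_fwd(x):
    y = _solve(x)
    return y, y  # stash the solution for the backward pass

def solve_bwd(y, g):
    # Hand-written VJP: gradients flow only through the active (positive) entries.
    return (g * (y > 0),)

solve.defvjp(solve_fwd, solve_bwd)

# jit and grad both work because autodiff never traces into the callback.
jax.jit(jax.grad(lambda x: jnp.sum(solve(x))))(jnp.array([-1.0, 2.0]))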
Example
>>> import cvxpy as cp
>>> import jax
>>> import jax.numpy as jnp
>>> from cvxpylayers.jax import CvxpyLayer
>>>
>>> # Define a simple QP
>>> x = cp.Variable(2)
>>> A = cp.Parameter((3, 2))
>>> b = cp.Parameter(3)
>>> problem = cp.Problem(cp.Minimize(cp.sum_squares(A @ x - b)), [x >= 0])
>>>
>>> # Create the layer
>>> layer = CvxpyLayer(problem, parameters=[A, b], variables=[x])
>>>
>>> # Solve and compute gradients
>>> A_jax = jax.random.normal(jax.random.PRNGKey(0), (3, 2))
>>> b_jax = jax.random.normal(jax.random.PRNGKey(1), (3,))
>>> (solution,) = layer(A_jax, b_jax)
>>>
>>> # Gradient computation
>>> def loss_fn(A, b):
...     (x,) = layer(A, b)
...     return jnp.sum(x)
>>> grads = jax.grad(loss_fn, argnums=[0, 1])(A_jax, b_jax)
- __init__(problem, parameters, variables, solver=None, gp=False, verbose=False, canon_backend=None, solver_args=None)[source]¶
Initialize the differentiable optimization layer.
- Parameters:
problem (Problem) – A CVXPY Problem. Must be DPP-compliant (problem.is_dpp() must return True).
parameters (list[Parameter]) – List of CVXPY Parameters that will be filled with values at runtime. Order must match the order of arrays passed to __call__().
variables (list[Variable]) – List of CVXPY Variables whose optimal values will be returned by __call__(). Order determines the order of returned arrays.
solver (str | None) – CVXPY solver to use (e.g., cp.CLARABEL, cp.SCS). If None, uses the default diffcp solver.
gp (bool) – If True, problem is a geometric program. Parameters will be log-transformed before solving.
verbose (bool) – If True, print solver output.
canon_backend (str | None) – Backend for canonicalization. Options are ‘diffcp’, ‘cuclarabel’, or None (auto-select).
solver_args (dict[str, Any] | None) – Default keyword arguments passed to the solver. Can be overridden per-call in __call__().
- Raises:
AssertionError – If problem is not DPP-compliant.
ValueError – If parameters or variables are not part of the problem.
- Return type:
None
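As one example of solver_args, defaults set at construction can be overridden for an individual call. The keys accepted depend on the solver actually in use; the "eps" tolerance below assumes an SCS-style interface and is shown only for illustration:
# Default solver arguments set at construction (assumed SCS-style "eps" key)
layer = CvxpyLayer(problem, parameters=[A, b], variables=[x], solver_args={"eps": 1e-6})

# Tighter tolerance for this one call, overriding the construction-time default
(x_tight,) = layer(A_jax, b_jax, solver_args={"eps": 1e-9})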
- __call__(*params, solver_args=None)[source]¶
Solve the optimization problem and return optimal variable values.
- Parameters:
*params (Array) – Array values for each CVXPY Parameter, in the same order as the parameters argument to __init__(). Each array shape must match the corresponding Parameter shape, optionally with a batch dimension prepended. Batched and unbatched parameters can be mixed; unbatched parameters are broadcast.
solver_args (dict[str, Any] | None) – Keyword arguments passed to the solver, overriding any defaults set in __init__().
- Returns:
Tuple of arrays containing optimal values for each CVXPY Variable specified in the variables argument to __init__(). If inputs are batched, outputs will have matching batch dimensions.
- Raises:
RuntimeError – If the solver fails to find a solution.
- Return type:
tuple[Array, …]
Example
>>> # Single problem
>>> (x_opt,) = layer(A_array, b_array)
>>>
>>> # Batched: solve 10 problems in parallel
>>> A_batch = jax.random.normal(key, (10, 3, 2))
>>> b_batch = jax.random.normal(key, (10, 3))
>>> (x_batch,) = layer(A_batch, b_batch)  # x_batch.shape = (10, 2)
Usage Example¶
import cvxpy as cp
import jax
import jax.numpy as jnp
from cvxpylayers.jax import CvxpyLayer
# Define problem
n, m = 2, 3
x = cp.Variable(n)
A = cp.Parameter((m, n))
b = cp.Parameter(m)
problem = cp.Problem(
cp.Minimize(cp.sum_squares(A @ x - b)),
[x >= 0]
)
# Create layer
layer = CvxpyLayer(problem, parameters=[A, b], variables=[x])
# Solve
key = jax.random.PRNGKey(0)
key, k1, k2 = jax.random.split(key, 3)
A_jax = jax.random.normal(k1, shape=(m, n))
b_jax = jax.random.normal(k2, shape=(m,))
(x_sol,) = layer(A_jax, b_jax)
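Since the constructor asserts DPP compliance, it can be worth checking problem.is_dpp() up front and, as a rough sanity check, comparing the layer's output with a direct CVXPY solve. The numpy conversion and tolerance below are just one way to do this and are not part of the layer's API:
import numpy as np

# The constructor requires a DPP-compliant problem
assert problem.is_dpp()

# Sanity check: solve the same instance directly with CVXPY
A.value = np.asarray(A_jax)
b.value = np.asarray(b_jax)
problem.solve()
print(np.allclose(x.value, np.asarray(x_sol), atol=1e-4))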
Computing Gradients¶
Use jax.grad to compute gradients:
def loss_fn(A, b):
(x,) = layer(A, b)
return jnp.sum(x)
# Gradient with respect to A and b
grad_fn = jax.grad(loss_fn, argnums=[0, 1])
dA, db = grad_fn(A_jax, b_jax)
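In training loops it is often convenient to get the loss value and the gradients together; jax.value_and_grad does this in a single call using the same loss_fn:
# Loss and gradients in one call
loss, (dA, db) = jax.value_and_grad(loss_fn, argnums=[0, 1])(A_jax, b_jax)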
JIT Compilation¶
jax.jit is supported with the Moreau solver:
import jax
import jax.numpy as jnp
# Create layer with Moreau solver
layer = CvxpyLayer(problem, parameters=[A, b], variables=[x], solver="MOREAU")
# JIT the entire solve + gradient computation
@jax.jit
def solve_and_grad(A, b):
(x,) = layer(A, b)
return jnp.sum(x)
grad_fn = jax.grad(solve_and_grad, argnums=[0, 1])
dA, db = grad_fn(A_jax, b_jax)
Note
JIT compilation requires solver="MOREAU". Other solvers (DIFFCP) use Python callbacks that are not JIT-compatible.
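Because the whole solve-differentiate pipeline is JIT-compatible with the Moreau solver, a complete gradient step can also be compiled. The plain gradient-descent update on the problem data below is purely illustrative:
learning_rate = 0.1

# A jitted gradient step on the problem data (plain gradient descent, for illustration)
@jax.jit
def step(A, b):
    dA, db = jax.grad(solve_and_grad, argnums=[0, 1])(A, b)
    return A - learning_rate * dA, b - learning_rate * db

A_new, b_new = step(A_jax, b_jax)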
Vectorization with vmap¶
Use jax.vmap for batched execution:
# Batched solve
batch_size = 10
A_batch = jax.random.normal(key, shape=(batch_size, m, n))
@jax.vmap
def solve_single(A):
(x,) = layer(A, b_jax)
return x
x_batch = solve_single(A_batch) # Shape: (10, n)
# Or use built-in batching
(x_batch,) = layer(A_batch, b_jax)
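Batching also composes with differentiation. For example, per-example gradients can be obtained by composing jax.vmap with jax.grad, again assuming the Moreau-solver layer constructed above:
# Per-example gradients with respect to A, one per problem in the batch
def single_loss(A):
    (x,) = layer(A, b_jax)
    return jnp.sum(x)

per_example_grads = jax.vmap(jax.grad(single_loss))(A_batch)  # Shape: (10, m, n)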