JAX API

The JAX layer is a callable class compatible with JAX transformations such as jax.grad, jax.vmap, and jax.jit (jax.jit requires the Moreau solver).

CvxpyLayer

class cvxpylayers.jax.CvxpyLayer

Bases: object

A differentiable convex optimization layer for JAX.

This layer wraps a parametrized CVXPY problem, solving it in the forward pass and computing gradients via implicit differentiation. Compatible with jax.grad, jax.jit, and jax.vmap; jax.jit requires solver="MOREAU" (see JIT/vmap Compatibility).
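For intuition, here is the standard implicit-function-theorem argument behind this kind of gradient computation, stated generically; it is not a description of the specific optimality-condition map that cvxpylayers differentiates internally. Assume the solution x*(θ) satisfies optimality conditions F(x*(θ), θ) = 0 and that the Jacobian of F with respect to x is invertible at the solution. Differentiating that identity gives the solution Jacobian without unrolling the solver's iterations, and in reverse mode a single transposed linear solve maps an upstream gradient g = ∂L/∂x* to ∂L/∂θ:

```latex
F\bigl(x^\star(\theta), \theta\bigr) = 0
\;\Longrightarrow\;
\partial_x F \, \frac{\partial x^\star}{\partial \theta} + \partial_\theta F = 0
\;\Longrightarrow\;
\frac{\partial x^\star}{\partial \theta} = -\left(\partial_x F\right)^{-1} \partial_\theta F,
\qquad
\frac{\partial L}{\partial \theta} = -\left(\partial_\theta F\right)^{\top} \left(\partial_x F\right)^{-\top} g .
```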

JIT/vmap Compatibility:

When using solver="MOREAU", this layer is fully compatible with jax.jit, jax.vmap, and jax.pmap. The Moreau solver provides native JAX autodiff support via custom_vjp with pure_callback, enabling JIT compilation of the entire solve-and-differentiate pipeline.

Other solvers (e.g., DIFFCP) solve the problem in Python and are not JIT-compatible because of their closure-based gradient handling.
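The custom_vjp plus pure_callback pattern mentioned above can be sketched on a toy problem. This is not the Moreau integration itself: as an assumption for illustration, the "solver" is numpy.linalg.solve for a linear system A x = b running on the host, and the backward pass applies implicit differentiation to the optimality condition A x - b = 0. The names linear_solve, _host_solve, and _numpy_solve are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _numpy_solve(A, b):
    # Stand-in for an external, non-JAX solver: plain NumPy running on the host.
    A, b = np.asarray(A), np.asarray(b)
    return np.linalg.solve(A, b).astype(b.dtype)


def _host_solve(A, b):
    # pure_callback makes the host call traceable under jit; JAX only needs
    # the output shape and dtype, not the callback's implementation.
    out_spec = jax.ShapeDtypeStruct(b.shape, b.dtype)
    return jax.pure_callback(_numpy_solve, out_spec, A, b)


@jax.custom_vjp
def linear_solve(A, b):
    """Return x with A @ x = b (hypothetical toy 'solver')."""
    return _host_solve(A, b)


def _linear_solve_fwd(A, b):
    x = _host_solve(A, b)
    return x, (A, x)  # residuals saved for the backward pass


def _linear_solve_bwd(residuals, g):
    A, x = residuals
    # Implicit differentiation of F(x, A, b) = A @ x - b = 0:
    # solve the adjoint system A^T lam = g, then dL/db = lam and dL/dA = -lam x^T.
    lam = _host_solve(A.T, g)
    return -jnp.outer(lam, x), lam


linear_solve.defvjp(_linear_solve_fwd, _linear_solve_bwd)

A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, -1.0])

# Both the forward solve and the adjoint solve run inside one jitted function.
loss = lambda A, b: jnp.sum(linear_solve(A, b) ** 2)
grad_fn = jax.jit(jax.grad(loss, argnums=(0, 1)))
dA, db = grad_fn(A, b)
```

Because autodiff never traces into the callback, the custom_vjp is what supplies gradients; the callback only has to be a pure function of its inputs.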

Example

>>> import cvxpy as cp
>>> import jax
>>> import jax.numpy as jnp
>>> from cvxpylayers.jax import CvxpyLayer
>>>
>>> # Define a simple QP
>>> x = cp.Variable(2)
>>> A = cp.Parameter((3, 2))
>>> b = cp.Parameter(3)
>>> problem = cp.Problem(cp.Minimize(cp.sum_squares(A @ x - b)), [x >= 0])
>>>
>>> # Create the layer
>>> layer = CvxpyLayer(problem, parameters=[A, b], variables=[x])
>>>
>>> # Solve and compute gradients
>>> A_jax = jax.random.normal(jax.random.PRNGKey(0), (3, 2))
>>> b_jax = jax.random.normal(jax.random.PRNGKey(1), (3,))
>>> (solution,) = layer(A_jax, b_jax)
>>>
>>> # Gradient computation
>>> def loss_fn(A, b):
...     (x,) = layer(A, b)
...     return jnp.sum(x)
>>> grads = jax.grad(loss_fn, argnums=(0, 1))(A_jax, b_jax)

__init__(problem, parameters, variables, solver=None, gp=False, verbose=False, canon_backend=None, solver_args=None)

Initialize the differentiable optimization layer.

Parameters:
  • problem (Problem) – A CVXPY Problem. Must be DPP-compliant (problem.is_dpp() must return True).

  • parameters (list[Parameter]) – List of CVXPY Parameters that will be filled with values at runtime. Order must match the order of arrays passed to __call__().

  • variables (list[Variable]) – List of CVXPY Variables whose optimal values will be returned by __call__(). Order determines the order of returned arrays.

  • solver (str | None) – CVXPY solver to use (e.g., cp.CLARABEL, cp.SCS). If None, uses the default diffcp solver.

  • gp (bool) – If True, the problem is a geometric program, and parameter values are log-transformed before solving.

  • verbose (bool) – If True, print solver output.

  • canon_backend (str | None) – Backend for canonicalization. Options are 'diffcp', 'cuclarabel', or None (auto-select).

  • solver_args (dict[str, Any] | None) – Default keyword arguments passed to the solver. Can be overridden per-call in __call__().
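The text above does not state whether a per-call solver_args dict replaces the constructor defaults wholesale or key by key. The sketch below assumes a key-by-key override; resolve_solver_args is a hypothetical helper, and the option names in the demo are illustrative only (valid keys depend on the chosen solver).

```python
from typing import Any, Optional


def resolve_solver_args(
    defaults: Optional[dict[str, Any]],
    overrides: Optional[dict[str, Any]],
) -> dict[str, Any]:
    """Hypothetical merge of __init__-level defaults with per-call overrides."""
    # Start from a copy of the defaults so the stored dict is never mutated.
    merged = dict(defaults or {})
    # Assumption: per-call values win key by key.
    merged.update(overrides or {})
    return merged


init_args = {"max_iters": 5000, "verbose": False}  # illustrative option names
print(resolve_solver_args(init_args, {"max_iters": 100}))
# {'max_iters': 100, 'verbose': False}
```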

Raises:
  • AssertionError – If problem is not DPP-compliant.

  • ValueError – If parameters or variables are not part of the problem.

Return type:

None

__call__(*params, solver_args=None)

Solve the optimization problem and return optimal variable values.

Parameters:
  • *params (Array) – Array values for each CVXPY Parameter, in the same order as the parameters argument to __init__(). Each array's shape must match the corresponding Parameter's shape, optionally with a batch dimension prepended. Batched and unbatched parameters can be mixed; unbatched parameters are broadcast across the batch.

  • solver_args (dict[str, Any] | None) – Keyword arguments passed to the solver, overriding any defaults set in __init__().
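The batching rule for *params can be sketched in NumPy. This is an illustration of the documented behavior, not the library's implementation: broadcast_params is a hypothetical helper, and treating a value as batched when it has exactly one extra leading axis is an assumption consistent with "a batch dimension prepended".

```python
import numpy as np


def broadcast_params(values, param_shapes):
    """Hypothetical helper: align a mix of batched and unbatched parameter values.

    Returns (aligned_values, batch_size); batch_size is None when nothing is batched.
    """
    values = [np.asarray(v) for v in values]
    shapes = [tuple(s) for s in param_shapes]
    batch_sizes = set()
    for v, s in zip(values, shapes):
        if v.shape == s:
            continue  # unbatched value
        if v.ndim == len(s) + 1 and v.shape[1:] == s:
            batch_sizes.add(v.shape[0])  # batched: one extra leading axis
        else:
            raise ValueError(f"value shape {v.shape} does not fit parameter shape {s}")
    if len(batch_sizes) > 1:
        raise ValueError(f"inconsistent batch sizes: {sorted(batch_sizes)}")
    if not batch_sizes:
        return values, None
    (batch,) = batch_sizes
    # Repeat each unbatched value along a new leading batch axis (read-only view).
    aligned = [
        np.broadcast_to(v, (batch, *s)) if v.shape == s else v
        for v, s in zip(values, shapes)
    ]
    return aligned, batch


A_batch = np.ones((10, 3, 2))  # batched: 10 values of a (3, 2) parameter
b = np.ones(3)                 # unbatched (3,) parameter
(A_out, b_out), batch = broadcast_params([A_batch, b], [(3, 2), (3,)])
print(batch, A_out.shape, b_out.shape)  # 10 (10, 3, 2) (10, 3)
```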

Returns:

Tuple of arrays containing optimal values for each CVXPY Variable specified in the variables argument to __init__(). If any input is batched, each output has a leading batch dimension of the same size.

Raises:

RuntimeError – If the solver fails to find a solution.

Return type:

tuple[Array, ...]

Example

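>>> k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
>>>
>>> # Single problem: A has shape (3, 2), b has shape (3,)
>>> A_array = jax.random.normal(k1, (3, 2))
>>> b_array = jax.random.normal(k2, (3,))
>>> (x_opt,) = layer(A_array, b_array)
>>>
>>> # Batched: solve 10 problems in parallel
>>> A_batch = jax.random.normal(k3, (10, 3, 2))
>>> b_batch = jax.random.normal(k4, (10, 3))
>>> (x_batch,) = layer(A_batch, b_batch)
>>> x_batch.shape
(10, 2)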

Usage Example

import cvxpy as cp
import jax
import jax.numpy as jnp
from cvxpylayers.jax import CvxpyLayer

# Define problem
n, m = 2, 3
x = cp.Variable(n)
A = cp.Parameter((m, n))
b = cp.Parameter(m)
problem = cp.Problem(
    cp.Minimize(cp.sum_squares(A @ x - b)),
    [x >= 0]
)

# Create layer
layer = CvxpyLayer(problem, parameters=[A, b], variables=[x])

# Solve
key = jax.random.PRNGKey(0)
key, k1, k2 = jax.random.split(key, 3)
A_jax = jax.random.normal(k1, shape=(m, n))
b_jax = jax.random.normal(k2, shape=(m,))

(x_sol,) = layer(A_jax, b_jax)

Computing Gradients

Use jax.grad to compute gradients:

def loss_fn(A, b):
    (x,) = layer(A, b)
    return jnp.sum(x)

# Gradient with respect to A and b
grad_fn = jax.grad(loss_fn, argnums=(0, 1))
dA, db = grad_fn(A_jax, b_jax)
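When debugging, it can help to compare jax.grad against a finite-difference estimate along a random direction. The helper below is a hypothetical sketch, not part of cvxpylayers; the demo uses a smooth toy function so it runs on its own. Because the solution map of a constrained problem is only piecewise differentiable, expect disagreement near parameter values where the set of active constraints changes, and expect solver tolerances to add noise.

```python
import jax
import jax.numpy as jnp


def check_directional_grad(f, args, key, eps=1e-3):
    """Hypothetical check: autodiff vs. central finite difference along a random direction.

    Returns (autodiff_estimate, finite_difference_estimate) as Python floats.
    """
    grads = jax.grad(f, argnums=tuple(range(len(args))))(*args)
    keys = jax.random.split(key, len(args))
    dirs = [jax.random.normal(k, a.shape, a.dtype) for k, a in zip(keys, args)]
    # Directional derivative predicted by autodiff: sum_i <grad_i, dir_i>.
    ad = sum(jnp.vdot(g, d) for g, d in zip(grads, dirs))
    # Central difference along the same direction.
    plus = f(*[a + eps * d for a, d in zip(args, dirs)])
    minus = f(*[a - eps * d for a, d in zip(args, dirs)])
    fd = (plus - minus) / (2 * eps)
    return float(ad), float(fd)


# Toy demo with a smooth function in place of the layer.
A = jnp.array([[0.5, -1.0], [2.0, 0.3], [-0.7, 1.1]])
b = jnp.array([0.4, -0.2])
toy_loss = lambda A, b: jnp.sum(jnp.sin(A @ b))
ad, fd = check_directional_grad(toy_loss, (A, b), jax.random.PRNGKey(0))
```

With the layer above, the same call would be check_directional_grad(loss_fn, (A_jax, b_jax), jax.random.PRNGKey(1)).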

JIT Compilation

jax.jit is supported with the Moreau solver:

import jax
import jax.numpy as jnp

# Create a layer that uses the Moreau solver (required for jax.jit)
layer = CvxpyLayer(problem, parameters=[A, b], variables=[x], solver="MOREAU")

def loss_fn(A, b):
    (x,) = layer(A, b)
    return jnp.sum(x)

# JIT-compile the entire solve + gradient computation
grad_fn = jax.jit(jax.grad(loss_fn, argnums=(0, 1)))
dA, db = grad_fn(A_jax, b_jax)

Note

JIT compilation requires solver="MOREAU". Other solvers (e.g., DIFFCP) solve the problem in Python outside of JAX tracing and are not JIT-compatible.

Vectorization with vmap

Use jax.vmap for batched execution:

# Batched solve
batch_size = 10
A_batch = jax.random.normal(key, shape=(batch_size, m, n))

@jax.vmap
def solve_single(A):
    (x,) = layer(A, b_jax)
    return x

x_batch = solve_single(A_batch)  # Shape: (10, n)

# Or use built-in batching
(x_batch,) = layer(A_batch, b_jax)
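The closure in solve_single fixes b_jax for every batch element. The same effect can be written with vmap's in_axes, which marks b as unbatched (None) while mapping over the leading axis of A. To keep the sketch runnable without cvxpylayers, a hypothetical toy_solve (unconstrained least squares via the normal equations) stands in for the layer; with the Moreau layer, the mapped function would presumably be lambda A, b: layer(A, b)[0].

```python
import jax
import jax.numpy as jnp


def toy_solve(A, b):
    # Hypothetical stand-in for `layer(A, b)`: unconstrained least squares.
    return jnp.linalg.solve(A.T @ A, A.T @ b)


m, n, batch_size = 3, 2, 10
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
A_batch = jax.random.normal(k1, (batch_size, m, n))
b = jax.random.normal(k2, (m,))

# in_axes=(0, None): map over axis 0 of A, reuse the same b for every element.
solve_batch = jax.vmap(toy_solve, in_axes=(0, None))
x_batch = solve_batch(A_batch, b)
print(x_batch.shape)  # (10, 2)
```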