from typing import Any, cast
import cvxpy as cp
import jax
import jax.experimental.sparse
import jax.numpy as jnp
import numpy as np
import scipy.sparse
import cvxpylayers.utils.parse_args as pa
def _reshape_fortran(array: jnp.ndarray, shape: tuple) -> jnp.ndarray:
"""Reshape array using Fortran (column-major) order.
Args:
array: Input array to reshape
shape: Target shape tuple
Returns:
Reshaped array in Fortran order
"""
return jnp.reshape(array, shape, order="F")
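# Illustrative sketch (comment-only, not part of the layer API): Fortran order
# fills the leading axis fastest, matching CVXPY's column-major vectorization.
#
#   >>> _reshape_fortran(jnp.arange(6), (2, 3))
#   Array([[0, 2, 4],
#          [1, 3, 5]], dtype=int32)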
def _apply_gp_log_transform(
params: tuple[jnp.ndarray, ...],
ctx: pa.LayersContext,
) -> tuple[jnp.ndarray, ...]:
"""Apply log transformation to geometric program (GP) parameters.
Geometric programs are solved in log-space after conversion to DCP.
This function applies log transformation to the appropriate parameters.
Args:
params: Tuple of parameter arrays in original GP space
ctx: Layer context containing GP parameter mapping info
Returns:
Tuple of transformed parameters (log-space for GP params, unchanged otherwise)
"""
if not ctx.gp or ctx.gp_log_mask is None:
return params
# Use pre-computed mask for JIT compatibility (no dict lookups)
return tuple(
jnp.log(p) if needs_log else p
for p, needs_log in zip(params, ctx.gp_log_mask)
)
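# Illustrative sketch with a hypothetical stand-in context (a real
# pa.LayersContext would come from parse_args); only parameters flagged in
# gp_log_mask are mapped to log-space:
#
#   >>> from types import SimpleNamespace
#   >>> ctx = SimpleNamespace(gp=True, gp_log_mask=(True, False))
#   >>> _apply_gp_log_transform((jnp.array([1.0, 2.0]), jnp.array([3.0])), ctx)
#   # -> (log([1., 2.]), [3.]): only the first parameter is transformed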
def _flatten_and_batch_params(
params: tuple[jnp.ndarray, ...],
ctx: pa.LayersContext,
batch: tuple,
) -> jnp.ndarray:
"""Flatten and batch parameters into a single stacked array.
Converts a tuple of parameter arrays (which may mix batched and unbatched
entries) into a single concatenated array suitable for matrix multiplication
with the parametrized problem matrices.
Args:
params: Tuple of parameter arrays
ctx: Layer context with batch info and ordering
batch: Batch dimensions tuple (empty if unbatched)
Returns:
Concatenated parameter array with shape (num_params, batch_size) or (num_params,)
"""
flattened_params: list[jnp.ndarray | None] = [None] * (len(params) + 1)
for i, param in enumerate(params):
# Check if this parameter is batched or needs broadcasting
if ctx.batch_sizes[i] == 0 and batch: # type: ignore[index]
# Unbatched parameter - expand to match batch size
param_expanded = jnp.broadcast_to(jnp.expand_dims(param, 0), batch + param.shape)
flattened_params[ctx.user_order_to_col_order[i]] = _reshape_fortran(
param_expanded,
batch + (-1,),
)
else:
# Already batched or no batch dimension needed
flattened_params[ctx.user_order_to_col_order[i]] = _reshape_fortran(
param,
batch + (-1,),
)
# Add constant 1.0 column for offset terms in canonical form
flattened_params[-1] = jnp.ones(batch + (1,), dtype=params[0].dtype)
assert all(p is not None for p in flattened_params), "All parameters must be assigned"
p_stack = jnp.concatenate(cast(list[jnp.ndarray], flattened_params), -1)
# When batched, p_stack is (batch_size, num_params) but we need (num_params, batch_size)
if batch:
p_stack = p_stack.T
return p_stack
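# Illustrative sketch with a hypothetical stand-in context: two unbatched
# parameters of shape (2,), identity column ordering, and batch=() yield a
# flat vector of length 2 + 2 + 1 (the trailing 1.0 is the constant offset
# column):
#
#   >>> from types import SimpleNamespace
#   >>> ctx = SimpleNamespace(batch_sizes=[0, 0], user_order_to_col_order=[0, 1])
#   >>> _flatten_and_batch_params((jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])), ctx, ())
#   Array([1., 2., 3., 4., 1.], dtype=float32)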
def _svec_to_symmetric(
svec: jnp.ndarray,
n: int,
batch: tuple,
rows: np.ndarray,
cols: np.ndarray,
scale: np.ndarray | None = None,
) -> jnp.ndarray:
"""Convert vectorized form to full symmetric matrix.
Args:
svec: Vectorized form, shape (*batch, n*(n+1)/2)
n: Matrix dimension
batch: Batch dimensions
rows: Row indices for unpacking
cols: Column indices for unpacking
scale: Optional scaling factors (for svec format with sqrt(2) scaling)
Returns:
Full symmetric matrix, shape (*batch, n, n)
"""
rows_arr = jnp.array(rows)
cols_arr = jnp.array(cols)
data = svec * jnp.array(scale) if scale is not None else svec
out_shape = batch + (n, n)
result = jnp.zeros(out_shape, dtype=svec.dtype)
result = result.at[..., rows_arr, cols_arr].set(data)
result = result.at[..., cols_arr, rows_arr].set(data)
return result
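# Illustrative sketch: unpacking the upper triangle of a 2x2 matrix stored as
# (0,0), (0,1), (1,1) mirrors the off-diagonal entry into both halves:
#
#   >>> rows, cols = np.triu_indices(2)
#   >>> _svec_to_symmetric(jnp.array([1.0, 2.0, 3.0]), 2, (), rows, cols)
#   Array([[1., 2.],
#          [2., 3.]], dtype=float32)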
def _unpack_primal_svec(svec: jnp.ndarray, n: int, batch: tuple) -> jnp.ndarray:
"""Unpack symmetric primal variable from vectorized form.
CVXPY stores symmetric variables in upper triangular row-major order:
[X[0,0], X[0,1], ..., X[0,n-1], X[1,1], X[1,2], ..., X[n-1,n-1]]
Args:
svec: Vectorized symmetric variable
n: Matrix dimension
batch: Batch dimensions
Returns:
Full symmetric matrix
"""
rows, cols = np.triu_indices(n)
return _svec_to_symmetric(svec, n, batch, rows, cols)
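# Illustrative sketch: the same data interpreted in CVXPY's row-major upper
# triangular ordering gives the same 2x2 result (the orderings coincide for
# n = 2):
#
#   >>> _unpack_primal_svec(jnp.array([1.0, 2.0, 3.0]), 2, ())
#   Array([[1., 2.],
#          [2., 3.]], dtype=float32)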
def _unpack_svec(svec: jnp.ndarray, n: int, batch: tuple) -> jnp.ndarray:
"""Unpack scaled vectorized (svec) form to full symmetric matrix.
The svec format stores a symmetric n x n matrix as a vector of length n*(n+1)/2,
with off-diagonal elements scaled by sqrt(2). Uses column-major lower triangular
ordering: (0,0), (1,0), ..., (n-1,0), (1,1), (2,1), ...
Args:
svec: Scaled vectorized form
n: Matrix dimension
batch: Batch dimensions
Returns:
Full symmetric matrix with scaling removed
"""
rows_rm, cols_rm = np.tril_indices(n)
sort_idx = np.lexsort((rows_rm, cols_rm))
rows = rows_rm[sort_idx]
cols = cols_rm[sort_idx]
# Scale: 1.0 for diagonal, 1/sqrt(2) for off-diagonal
scale = np.where(rows == cols, 1.0, 1.0 / np.sqrt(2.0))
return _svec_to_symmetric(svec, n, batch, rows, cols, scale)
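# Illustrative sketch: for n = 2 the svec layout is (0,0), (1,0), (1,1) with
# the off-diagonal entry scaled by sqrt(2), which this helper removes:
#
#   >>> _unpack_svec(jnp.array([1.0, jnp.sqrt(2.0), 3.0]), 2, ())
#   # -> approximately [[1., 1.], [1., 3.]] (off-diagonal rescaled to 1.0)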
def _recover_results(
primal: jnp.ndarray,
dual: jnp.ndarray,
ctx: pa.LayersContext,
batch: tuple,
) -> tuple[jnp.ndarray, ...]:
"""Recover variable values from primal/dual solutions.
Extracts the requested variables from the solver's primal and dual
solutions, unpacks symmetric matrices if needed, applies inverse GP
transformation, and removes batch dimension for unbatched inputs.
Args:
primal: Primal solution from solver
dual: Dual solution from solver
ctx: Layer context with variable recovery info
batch: Batch dimensions tuple (empty if unbatched)
Returns:
Tuple of recovered variable values
"""
results = []
batch_shape = tuple(primal.shape[:-1])
for var in ctx.var_recover:
# Use pre-computed source field to select data (JIT-compatible)
if var.source == "primal":
data = primal[..., var.primal]
else: # var.source == "dual"
data = dual[..., var.dual]
# Use pre-computed unpack_fn field (JIT-compatible)
if var.unpack_fn == "svec_primal":
results.append(_unpack_primal_svec(data, var.shape[0], batch_shape))
elif var.unpack_fn == "svec_dual":
results.append(_unpack_svec(data, var.shape[0], batch_shape))
elif var.unpack_fn == "reshape":
results.append(_reshape_fortran(data, batch_shape + var.shape))
else:
raise ValueError(f"Unknown variable recovery type: {var.unpack_fn}")
# Apply exp transformation to recover primal variables from log-space for GP
# (dual variables stay in original space - no transformation needed)
# Uses pre-computed source field (JIT-compatible)
if ctx.gp:
results = [
jnp.exp(r) if var.source == "primal" else r
for r, var in zip(results, ctx.var_recover)
]
# Squeeze batch dimension for unbatched inputs
if not batch:
results = [jnp.squeeze(r, 0) for r in results]
return tuple(results)
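# Illustrative sketch with hypothetical stand-ins for the recovery records (a
# real ctx.var_recover is built by parse_args); a single primal variable of
# shape (2,) is sliced out, reshaped, and squeezed for the unbatched case:
#
#   >>> from types import SimpleNamespace
#   >>> var = SimpleNamespace(source="primal", primal=slice(0, 2),
#   ...                       unpack_fn="reshape", shape=(2,))
#   >>> ctx = SimpleNamespace(var_recover=[var], gp=False)
#   >>> _recover_results(jnp.array([[1.0, 2.0]]), jnp.zeros((1, 0)), ctx, ())
#   (Array([1., 2.], dtype=float32),)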
class CvxpyLayer:
"""A differentiable convex optimization layer for JAX.
This layer wraps a parametrized CVXPY problem, solving it in the forward pass
and computing gradients via implicit differentiation. Compatible with
``jax.grad``, ``jax.jit``, and ``jax.vmap``.
JIT/vmap Compatibility:
When using solver="MOREAU", this layer is fully compatible with
jax.jit, jax.vmap, and jax.pmap. The Moreau solver provides native
JAX autodiff support via custom_vjp with pure_callback, enabling
JIT compilation of the entire solve-differentiate pipeline.
Other solvers (DIFFCP) use Python-based solving and are not
JIT-compatible due to closure-based gradient handling.
Example:
>>> import cvxpy as cp
>>> import jax
>>> import jax.numpy as jnp
>>> from cvxpylayers.jax import CvxpyLayer
>>>
>>> # Define a simple QP
>>> x = cp.Variable(2)
>>> A = cp.Parameter((3, 2))
>>> b = cp.Parameter(3)
>>> problem = cp.Problem(cp.Minimize(cp.sum_squares(A @ x - b)), [x >= 0])
>>>
>>> # Create the layer
>>> layer = CvxpyLayer(problem, parameters=[A, b], variables=[x])
>>>
>>> # Solve and compute gradients
>>> A_jax = jax.random.normal(jax.random.PRNGKey(0), (3, 2))
>>> b_jax = jax.random.normal(jax.random.PRNGKey(1), (3,))
>>> (solution,) = layer(A_jax, b_jax)
>>>
>>> # Gradient computation
>>> def loss_fn(A, b):
... (x,) = layer(A, b)
... return jnp.sum(x)
>>> grads = jax.grad(loss_fn, argnums=[0, 1])(A_jax, b_jax)
"""
def __init__(
self,
problem: cp.Problem,
parameters: list[cp.Parameter],
variables: list[cp.Variable],
solver: str | None = None,
gp: bool = False,
verbose: bool = False,
canon_backend: str | None = None,
solver_args: dict[str, Any] | None = None,
) -> None:
"""Initialize the differentiable optimization layer.
Args:
problem: A CVXPY Problem. Must be DPP-compliant (``problem.is_dpp()``
must return True).
parameters: List of CVXPY Parameters that will be filled with values
at runtime. Order must match the order of arrays passed to __call__().
variables: List of CVXPY Variables whose optimal values will be returned
by __call__(). Order determines the order of returned arrays.
solver: CVXPY solver to use (e.g., ``cp.CLARABEL``, ``cp.SCS``).
If None, uses the default diffcp solver.
gp: If True, problem is a geometric program. Parameters will be
log-transformed before solving.
verbose: If True, print solver output.
canon_backend: Backend for canonicalization. Options are 'diffcp',
'cuclarabel', or None (auto-select).
solver_args: Default keyword arguments passed to the solver.
Can be overridden per-call in __call__().
Raises:
AssertionError: If problem is not DPP-compliant.
ValueError: If parameters or variables are not part of the problem.
"""
if solver_args is None:
solver_args = {}
self.ctx = pa.parse_args(
problem,
variables,
parameters,
solver,
gp=gp,
verbose=verbose,
canon_backend=canon_backend,
solver_args=solver_args,
)
if self.ctx.reduced_P.reduced_mat is not None: # type: ignore[attr-defined]
self.P = scipy_csr_to_jax_bcsr(self.ctx.reduced_P.reduced_mat) # type: ignore[attr-defined]
else:
self.P = None
self.q: jax.experimental.sparse.BCSR = scipy_csr_to_jax_bcsr(self.ctx.q.tocsr()) # type: ignore[assignment]
self.A: jax.experimental.sparse.BCSR = scipy_csr_to_jax_bcsr(self.ctx.reduced_A.reduced_mat) # type: ignore[attr-defined,assignment]
# Cache the Moreau solve function for JIT compatibility.
# Must be captured as closure, not looked up dynamically during tracing.
if self.ctx.solver == "MOREAU":
self._moreau_solve_fn = self.ctx.solver_ctx.get_jax_solver()._impl.solve
def __call__(
self, *params: jnp.ndarray, solver_args: dict[str, Any] | None = None
) -> tuple[jnp.ndarray, ...]:
"""Solve the optimization problem and return optimal variable values.
Args:
*params: Array values for each CVXPY Parameter, in the same order
as the ``parameters`` argument to __init__(). Each array shape must
match the corresponding Parameter shape, optionally with a batch
dimension prepended. Batched and unbatched parameters can be mixed;
unbatched parameters are broadcast.
solver_args: Keyword arguments passed to the solver, overriding any
defaults set in __init__().
Returns:
Tuple of arrays containing optimal values for each CVXPY Variable
specified in the ``variables`` argument to __init__(). If inputs are
batched, outputs will have matching batch dimensions.
Raises:
RuntimeError: If the solver fails to find a solution.
Example:
>>> # Single problem
>>> (x_opt,) = layer(A_array, b_array)
>>>
>>> # Batched: solve 10 problems in parallel
>>> A_batch = jax.random.normal(key, (10, 3, 2))
>>> b_batch = jax.random.normal(key, (10, 3))
>>> (x_batch,) = layer(A_batch, b_batch) # x_batch.shape = (10, 2)
"""
if solver_args is None:
solver_args = {}
batch = self.ctx.validate_params(list(params))
# Apply log transformation to GP parameters
params = _apply_gp_log_transform(params, self.ctx)
# Flatten and batch parameters
p_stack = _flatten_and_batch_params(params, self.ctx, batch)
# Evaluate parametrized matrices
P_eval = self.P @ p_stack if self.P is not None else None
q_eval = self.q @ p_stack
A_eval = self.A @ p_stack
# Check if solver has native JAX autodiff (Moreau)
# If so, bypass the closure-based custom_vjp wrapper and use Moreau's native autodiff
if self.ctx.solver == "MOREAU":
return self._solve_moreau(P_eval, q_eval, A_eval, batch, solver_args)
# Non-Moreau: use existing custom_vjp wrapper (not JIT-compatible)
return self._solve_with_custom_vjp(P_eval, q_eval, A_eval, batch, solver_args)
def _solve_moreau(
self,
P_eval: jnp.ndarray | None,
q_eval: jnp.ndarray,
A_eval: jnp.ndarray,
batch: tuple,
solver_args: dict[str, Any],
) -> tuple[jnp.ndarray, ...]:
"""Direct call to Moreau solver - uses its native custom_vjp (JIT-compatible).
Moreau's JAX solver (moreau.jax.Solver) implements custom_vjp with
pure_callback and vmap_method="broadcast_all", making it fully compatible
with jax.jit, jax.vmap, and jax.pmap.
This method uses jax.vmap for batched cases to ensure each problem is
solved individually, which is required for JIT compatibility.
"""
solver_ctx = self.ctx.solver_ctx # type: ignore[attr-defined]
# Apply per-call solver_args to solver settings
if solver_args:
settings = solver_ctx.get_jax_solver()._impl._settings
for key, value in solver_args.items():
if hasattr(settings, key):
setattr(settings, key, value)
# Use cached solve function (captured as closure, not dynamic lookup)
solve_fn = self._moreau_solve_fn
# Cache solver_ctx attributes for closure capture
P_idx = solver_ctx.P_idx
A_idx = solver_ctx.A_idx
b_idx = solver_ctx.b_idx
m = solver_ctx.A_shape[0]
def extract_and_solve(
P_eval_single: jnp.ndarray | None,
q_eval_single: jnp.ndarray,
A_eval_single: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray]:
"""Extract problem data and solve a single (unbatched) problem."""
# Extract P values in CSR order
if P_idx is not None and P_eval_single is not None:
P_values = P_eval_single[P_idx] # (nnzP,)
else:
P_values = jnp.zeros(0, dtype=jnp.float64)
# Extract A values in CSR order (negated for Ax + s = b form)
A_values = -A_eval_single[A_idx] # (nnzA,)
# Extract b vector from the end of A_eval
b_start = A_eval_single.shape[0] - b_idx.size
b_raw = A_eval_single[b_start:] # (b_idx.size,)
# Scatter b_raw into full b vector at correct indices
b = jnp.zeros(m, dtype=jnp.float64)
b = b.at[b_idx].set(b_raw)
# Extract q (linear cost) - exclude the constant term at the end
q = q_eval_single[:-1] # (n,)
# Call moreau's solve with unbatched inputs
solution, _info = solve_fn(P_values, A_values, q, b)
return solution.x, solution.z # (n,), (m,)
if batch:
# Batched case: inputs have shape (dim, batch)
# Transpose to (batch, dim) for vmap
P_eval_batched = P_eval.T if P_eval is not None else None
q_eval_batched = q_eval.T # (batch, dim)
A_eval_batched = A_eval.T # (batch, dim)
# Use vmap to solve each problem individually
vmapped_solve = jax.vmap(extract_and_solve, in_axes=(0, 0, 0))
primal, dual = vmapped_solve(P_eval_batched, q_eval_batched, A_eval_batched)
# primal: (batch, n), dual: (batch, m)
else:
# Unbatched case: inputs have shape (dim,)
primal, dual = extract_and_solve(P_eval, q_eval, A_eval)
# primal: (n,), dual: (m,)
# Add batch dimension for _recover_results (which expects it)
primal = jnp.expand_dims(primal, 0) # (1, n)
dual = jnp.expand_dims(dual, 0) # (1, m)
return _recover_results(primal, dual, self.ctx, batch)
def _solve_with_custom_vjp(
self,
P_eval: jnp.ndarray | None,
q_eval: jnp.ndarray,
A_eval: jnp.ndarray,
batch: tuple,
solver_args: dict[str, Any],
) -> tuple[jnp.ndarray, ...]:
"""Solve using closure-based custom_vjp wrapper (not JIT-compatible).
This is used for non-Moreau solvers (e.g., DIFFCP) that don't have
native JAX autodiff support. The closure-based approach stores data
in a Python dict for the backward pass, which breaks JIT tracing.
"""
# Store data and adjoint in closure for backward pass
# This avoids JAX trying to trace through DIFFCP's Python-based solver
data_container: dict[str, Any] = {}
@jax.custom_vjp
def solve_problem(P_eval, q_eval, A_eval):
# Forward pass: solve the optimization problem
data = self.ctx.solver_ctx.jax_to_data(P_eval, q_eval, A_eval) # type: ignore[attr-defined]
primal, dual, adj_batch = data.jax_solve(solver_args) # type: ignore[attr-defined]
# Store for backward pass (outside JAX tracing)
data_container["data"] = data
data_container["adj_batch"] = adj_batch
return primal, dual
def solve_problem_fwd(P_eval, q_eval, A_eval):
# Call forward to execute and populate container
primal, dual = solve_problem(P_eval, q_eval, A_eval)
# Return empty residuals (data is in closure)
return (primal, dual), ()
def solve_problem_bwd(res, g):
# Backward pass using adjoint method
dprimal, ddual = g
data = data_container["data"]
adj_batch = data_container["adj_batch"]
dP, dq, dA = data.jax_derivative(dprimal, ddual, adj_batch)
return dP, dq, dA
solve_problem.defvjp(solve_problem_fwd, solve_problem_bwd)
primal, dual = solve_problem(P_eval, q_eval, A_eval)
# Recover results and apply GP inverse transform if needed
return _recover_results(primal, dual, self.ctx, batch)
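# A minimal usage sketch (assumes the optional Moreau solver is installed):
# with solver="MOREAU", the layer composes with jax.jit and jax.vmap.
#
#   layer = CvxpyLayer(problem, parameters=[A, b], variables=[x], solver="MOREAU")
#   jitted_solve = jax.jit(lambda A, b: layer(A, b)[0])
#   batched_solve = jax.vmap(lambda A, b: layer(A, b)[0])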
def scipy_csr_to_jax_bcsr(
scipy_csr: scipy.sparse.csr_array | None,
) -> "jax.experimental.sparse.BCSR | _EmptyBCSRWrapper | None":
if scipy_csr is None:
return None
# Use cast to help type checker understand scipy_csr is not None
scipy_csr = cast(scipy.sparse.csr_array, scipy_csr)
num_rows, num_cols = scipy_csr.shape # type: ignore[misc]
# JAX BCSR doesn't handle empty matrices (0 rows) properly.
# Return a lightweight wrapper that mimics an empty (0, num_cols) matrix and
# yields correctly shaped empty results under matrix multiplication.
if num_rows == 0:
return _EmptyBCSRWrapper(num_cols)
# Get the CSR format components
values = scipy_csr.data
col_indices = scipy_csr.indices
row_ptr = scipy_csr.indptr
# Create the JAX BCSR tensor
jax_bcsr = jax.experimental.sparse.BCSR(
(jnp.array(values), jnp.array(col_indices), jnp.array(row_ptr)),
shape=(num_rows, num_cols),
)
return jax_bcsr
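# Illustrative sketch: a 2x2 identity round-trips through the converter into a
# JAX BCSR with the same dense representation.
#
#   >>> mat = scipy.sparse.csr_array(np.eye(2))
#   >>> scipy_csr_to_jax_bcsr(mat).todense()
#   Array([[1., 0.],
#          [0., 1.]], dtype=float32)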
class _EmptyBCSRWrapper:
"""Wrapper for empty (0-row) sparse matrices that JAX BCSR can't handle.
When multiplied with a vector/matrix, returns an empty array with the correct shape.
"""
def __init__(self, num_cols: int):
self.num_cols = num_cols
self.shape = (0, num_cols)
def __matmul__(self, other: jnp.ndarray) -> jnp.ndarray:
# other shape: (num_cols,) or (num_cols, batch)
if other.ndim == 1:
return jnp.zeros((0,), dtype=other.dtype)
else:
batch_size = other.shape[1]
return jnp.zeros((0, batch_size), dtype=other.dtype)
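# Illustrative sketch: the wrapper behaves like an empty (0, k) matrix under
# matmul, so downstream code sees correctly shaped empty results.
#
#   >>> _EmptyBCSRWrapper(3) @ jnp.ones((3, 5))
#   Array([], shape=(0, 5), dtype=float32)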