"""Differentiable convex optimization layer for JAX (cvxpylayers.jax.cvxpylayer)."""

from typing import Any, cast

import cvxpy as cp
import jax
import jax.experimental.sparse
import jax.numpy as jnp
import numpy as np
import scipy.sparse

try:
    from moreau._types import BatchedWarmStart, WarmStart
except ImportError:
    WarmStart = None  # type: ignore[assignment,misc]
    BatchedWarmStart = None  # type: ignore[assignment,misc]

import cvxpylayers.utils.parse_args as pa


def _reshape_fortran(array: jnp.ndarray, shape: tuple) -> jnp.ndarray:
    """Reshape array using Fortran (column-major) order.

    Args:
        array: Input array to reshape
        shape: Target shape tuple

    Returns:
        Reshaped array in Fortran order
    """
    return jnp.reshape(array, shape, order="F")


def _apply_gp_log_transform(
    params: tuple[jnp.ndarray, ...],
    ctx: pa.LayersContext,
) -> tuple[jnp.ndarray, ...]:
    """Map geometric-program (GP) parameters into log-space.

    A GP is solved in log-space after conversion to DCP form, so every
    parameter flagged in ``ctx.gp_log_mask`` is log-transformed before
    canonicalization. Non-GP problems (or a missing mask) pass through
    unchanged.

    Args:
        params: Tuple of parameter arrays in original GP space
        ctx: Layer context containing GP parameter mapping info

    Returns:
        Tuple of transformed parameters (log-space for GP params, unchanged otherwise)
    """
    if not ctx.gp or ctx.gp_log_mask is None:
        return params

    # The boolean mask is precomputed so this loop stays JIT-traceable
    # (no dict lookups keyed on parameter identity).
    transformed: list[jnp.ndarray] = []
    for value, in_log_space in zip(params, ctx.gp_log_mask):
        transformed.append(jnp.log(value) if in_log_space else value)
    return tuple(transformed)


def _flatten_and_batch_params(
    params: tuple[jnp.ndarray, ...],
    ctx: pa.LayersContext,
    batch: tuple,
) -> jnp.ndarray:
    """Flatten and batch parameters into a single stacked array.

    Converts a tuple of parameter arrays (potentially with mixed
    batched/unbatched entries) into one concatenated array suitable for
    matrix multiplication with the parametrized problem matrices.

    Args:
        params: Tuple of parameter arrays
        ctx: Layer context with batch info and ordering
        batch: Batch dimensions tuple (empty if unbatched)

    Returns:
        Concatenated parameter array with shape (num_params, batch_size) or (num_params,)
    """
    # One slot per parameter, in canonical column order, plus a trailing
    # constant-1 slot for the offset terms of the canonical form.
    slots: list[jnp.ndarray | None] = [None] * (len(params) + 1)

    for idx, value in enumerate(params):
        # A parameter with batch size 0 is unbatched; when the call itself is
        # batched it must first be broadcast up to the batch shape.
        if batch and ctx.batch_sizes[idx] == 0:  # type: ignore[index]
            value = jnp.broadcast_to(jnp.expand_dims(value, 0), batch + value.shape)
        slots[ctx.user_order_to_col_order[idx]] = _reshape_fortran(value, batch + (-1,))

    slots[-1] = jnp.ones(batch + (1,), dtype=params[0].dtype)

    assert all(entry is not None for entry in slots), "All parameters must be assigned"
    stacked = jnp.concatenate(cast(list[jnp.ndarray], slots), -1)
    # When batched, stacked is (batch_size, num_params); the matrices expect
    # (num_params, batch_size), so transpose.
    if batch:
        stacked = stacked.T
    return stacked


def _svec_to_symmetric(
    svec: jnp.ndarray,
    n: int,
    batch: tuple,
    rows: np.ndarray,
    cols: np.ndarray,
    scale: np.ndarray | None = None,
) -> jnp.ndarray:
    """Convert vectorized form to full symmetric matrix.

    Args:
        svec: Vectorized form, shape (*batch, n*(n+1)/2)
        n: Matrix dimension
        batch: Batch dimensions
        rows: Row indices for unpacking
        cols: Column indices for unpacking
        scale: Optional scaling factors (for svec format with sqrt(2) scaling)

    Returns:
        Full symmetric matrix, shape (*batch, n, n)
    """
    rows_arr = jnp.array(rows)
    cols_arr = jnp.array(cols)
    data = svec * jnp.array(scale) if scale is not None else svec
    out_shape = batch + (n, n)
    result = jnp.zeros(out_shape, dtype=svec.dtype)
    result = result.at[..., rows_arr, cols_arr].set(data)
    result = result.at[..., cols_arr, rows_arr].set(data)
    return result


def _unpack_primal_svec(svec: jnp.ndarray, n: int, batch: tuple) -> jnp.ndarray:
    """Unpack symmetric primal variable from vectorized form.

    CVXPY stores symmetric variables in upper triangular row-major order:
    [X[0,0], X[0,1], ..., X[0,n-1], X[1,1], X[1,2], ..., X[n-1,n-1]]

    Args:
        svec: Vectorized symmetric variable
        n: Matrix dimension
        batch: Batch dimensions

    Returns:
        Full symmetric matrix
    """
    # np.triu_indices yields exactly the row-major upper-triangular ordering
    # CVXPY uses, so no reordering (and no scaling) is needed here.
    upper_rows, upper_cols = np.triu_indices(n)
    return _svec_to_symmetric(svec, n, batch, upper_rows, upper_cols)


def _unpack_svec(svec: jnp.ndarray, n: int, batch: tuple) -> jnp.ndarray:
    """Unpack scaled vectorized (svec) form to full symmetric matrix.

    The svec format stores a symmetric n x n matrix as a vector of length
    n*(n+1)/2, with off-diagonal elements scaled by sqrt(2). Uses
    column-major lower triangular ordering: (0,0), (1,0), (1,1), (2,0), ...

    Args:
        svec: Scaled vectorized form
        n: Matrix dimension
        batch: Batch dimensions

    Returns:
        Full symmetric matrix with scaling removed
    """
    # tril_indices is row-major; lexsort by (column, then row) rearranges the
    # index pairs into the column-major order the svec format uses.
    tril_rows, tril_cols = np.tril_indices(n)
    column_major = np.lexsort((tril_rows, tril_cols))
    ordered_rows = tril_rows[column_major]
    ordered_cols = tril_cols[column_major]
    # Undo the sqrt(2) packing scale on off-diagonal entries only.
    off_diagonal = ordered_rows != ordered_cols
    scale = np.where(off_diagonal, 1.0 / np.sqrt(2.0), 1.0)
    return _svec_to_symmetric(svec, n, batch, ordered_rows, ordered_cols, scale)


def _recover_results(
    primal: jnp.ndarray,
    dual: jnp.ndarray,
    ctx: pa.LayersContext,
    batch: tuple,
) -> tuple[jnp.ndarray, ...]:
    """Recover variable values from primal/dual solutions.

    Extracts the requested variables from the solver's primal and dual
    solutions, unpacks symmetric matrices if needed, applies inverse GP
    transformation, and removes batch dimension for unbatched inputs.

    Args:
        primal: Primal solution from solver
        dual: Dual solution from solver
        ctx: Layer context with variable recovery info
        batch: Batch dimensions tuple (empty if unbatched)

    Returns:
        Tuple of recovered variable values
    """
    batch_shape = tuple(primal.shape[:-1])
    recovered: list[jnp.ndarray] = []

    for var in ctx.var_recover:
        # The source field is precomputed, so selection is JIT-compatible.
        if var.source == "primal":
            raw = primal[..., var.primal]
        else:  # var.source == "dual"
            raw = dual[..., var.dual]

        # Dispatch on the precomputed unpack strategy (JIT-compatible).
        unpack = var.unpack_fn
        if unpack == "svec_primal":
            value = _unpack_primal_svec(raw, var.shape[0], batch_shape)
        elif unpack == "svec_dual":
            value = _unpack_svec(raw, var.shape[0], batch_shape)
        elif unpack == "reshape":
            value = _reshape_fortran(raw, batch_shape + var.shape)
        else:
            raise ValueError(f"Unknown variable recovery type: {var.unpack_fn}")
        recovered.append(value)

    # GP problems are solved in log-space: exponentiate primal variables to
    # return them to the original space; dual variables need no transform.
    if ctx.gp:
        recovered = [
            jnp.exp(value) if var.source == "primal" else value
            for value, var in zip(recovered, ctx.var_recover)
        ]

    # Unbatched calls carry a synthetic leading batch dim of 1 — drop it.
    if not batch:
        recovered = [jnp.squeeze(value, 0) for value in recovered]

    return tuple(recovered)


class CvxpyLayer:
    """A differentiable convex optimization layer for JAX.

    This layer wraps a parametrized CVXPY problem, solving it in the forward
    pass and computing gradients via implicit differentiation. Compatible
    with ``jax.grad``, ``jax.jit``, and ``jax.vmap``.

    JIT/vmap Compatibility:
        When using solver="MOREAU", this layer is fully compatible with
        jax.jit, jax.vmap, and jax.pmap. The Moreau solver provides native
        JAX autodiff support via custom_vjp with pure_callback, enabling JIT
        compilation of the entire solve-differentiate pipeline.

        Other solvers (DIFFCP) use Python-based solving and are not
        JIT-compatible due to closure-based gradient handling.

    Example:
        >>> import cvxpy as cp
        >>> import jax
        >>> import jax.numpy as jnp
        >>> from cvxpylayers.jax import CvxpyLayer
        >>>
        >>> # Define a simple QP
        >>> x = cp.Variable(2)
        >>> A = cp.Parameter((3, 2))
        >>> b = cp.Parameter(3)
        >>> problem = cp.Problem(cp.Minimize(cp.sum_squares(A @ x - b)), [x >= 0])
        >>>
        >>> # Create the layer
        >>> layer = CvxpyLayer(problem, parameters=[A, b], variables=[x])
        >>>
        >>> # Solve and compute gradients
        >>> A_jax = jax.random.normal(jax.random.PRNGKey(0), (3, 2))
        >>> b_jax = jax.random.normal(jax.random.PRNGKey(1), (3,))
        >>> (solution,) = layer(A_jax, b_jax)
        >>>
        >>> # Gradient computation
        >>> def loss_fn(A, b):
        ...     (x,) = layer(A, b)
        ...     return jnp.sum(x)
        >>> grads = jax.grad(loss_fn, argnums=[0, 1])(A_jax, b_jax)
    """

    def __init__(
        self,
        problem: cp.Problem,
        parameters: list[cp.Parameter],
        variables: list[cp.Variable],
        solver: str | None = None,
        gp: bool = False,
        verbose: bool = False,
        canon_backend: str | None = None,
        solver_args: dict[str, Any] | None = None,
    ) -> None:
        """Initialize the differentiable optimization layer.

        Args:
            problem: A CVXPY Problem. Must be DPP-compliant
                (``problem.is_dpp()`` must return True).
            parameters: List of CVXPY Parameters that will be filled with
                values at runtime. Order must match the order of arrays
                passed to __call__().
            variables: List of CVXPY Variables whose optimal values will be
                returned by __call__(). Order determines the order of
                returned arrays.
            solver: CVXPY solver to use (e.g., ``cp.CLARABEL``, ``cp.SCS``).
                If None, uses the default diffcp solver.
            gp: If True, problem is a geometric program. Parameters will be
                log-transformed before solving.
            verbose: If True, print solver output.
            canon_backend: Backend for canonicalization. Options are
                'diffcp', 'cuclarabel', or None (auto-select).
            solver_args: Default keyword arguments passed to the solver.
                Can be overridden per-call in __call__().

        Raises:
            AssertionError: If problem is not DPP-compliant.
            ValueError: If parameters or variables are not part of the problem.
        """
        if solver_args is None:
            solver_args = {}

        self.ctx = pa.parse_args(
            problem,
            variables,
            parameters,
            solver,
            gp=gp,
            verbose=verbose,
            canon_backend=canon_backend,
            solver_args=solver_args,
        )

        if self.ctx.reduced_P.reduced_mat is not None:  # type: ignore[attr-defined]
            self.P = scipy_csr_to_jax_bcsr(self.ctx.reduced_P.reduced_mat)  # type: ignore[attr-defined]
        else:
            self.P = None
        self.q: jax.experimental.sparse.BCSR = scipy_csr_to_jax_bcsr(self.ctx.q.tocsr())  # type: ignore[assignment]
        self.A: jax.experimental.sparse.BCSR = scipy_csr_to_jax_bcsr(self.ctx.reduced_A.reduced_mat)  # type: ignore[attr-defined,assignment]

        # Cache the Moreau solve functions for JIT compatibility.
        # Must be captured as closure, not looked up dynamically during tracing.
        if self.ctx.solver == "MOREAU":
            self._moreau_jax_impl = self.ctx.solver_ctx.get_jax_solver()._impl
            self._moreau_solve_fn = self._moreau_jax_impl.solve
            # solve_warm accepts (P, A, q, b, warm_x, warm_z, warm_s) as pure
            # function args — vmap-compatible and avoids the _pending_warm_start
            # side-channel. Only available on CUDA; None on CPU.
            self._moreau_solve_warm_fn = self._moreau_jax_impl.solve_warm
            self._warm_start_cache = None

    def __call__(
        self,
        *params: jnp.ndarray,
        solver_args: dict[str, Any] | None = None,
        warm_start: bool = False,
    ) -> tuple[jnp.ndarray, ...]:
        """Solve the optimization problem and return optimal variable values.

        Args:
            *params: Array values for each CVXPY Parameter, in the same
                order as the ``parameters`` argument to __init__(). Each
                array shape must match the corresponding Parameter shape,
                optionally with a batch dimension prepended. Batched and
                unbatched parameters can be mixed; unbatched parameters
                are broadcast.
            solver_args: Keyword arguments passed to the solver, overriding
                any defaults set in __init__().
            warm_start: If True, use the cached solution from the previous
                solve as a warm start for the solver. Only supported with
                solver="MOREAU".

        Returns:
            Tuple of arrays containing optimal values for each CVXPY
            Variable specified in the ``variables`` argument to __init__().
            If inputs are batched, outputs will have matching batch
            dimensions.

        Raises:
            RuntimeError: If the solver fails to find a solution.
            ValueError: If warm_start=True with a non-Moreau solver.

        Example:
            >>> # Single problem
            >>> (x_opt,) = layer(A_array, b_array)
            >>>
            >>> # Batched: solve 10 problems in parallel
            >>> A_batch = jax.random.normal(key, (10, 3, 2))
            >>> b_batch = jax.random.normal(key, (10, 3))
            >>> (x_batch,) = layer(A_batch, b_batch)  # x_batch.shape = (10, 2)
        """
        if solver_args is None:
            solver_args = {}

        if warm_start and self.ctx.solver != "MOREAU":
            raise ValueError(
                "warm_start=True is only supported with solver='MOREAU'. "
                f"Current solver is '{self.ctx.solver}'."
            )

        batch = self.ctx.validate_params(list(params))

        # Apply log transformation to GP parameters
        params = _apply_gp_log_transform(params, self.ctx)

        # Flatten and batch parameters
        p_stack = _flatten_and_batch_params(params, self.ctx, batch)

        # Evaluate parametrized matrices
        P_eval = self.P @ p_stack if self.P is not None else None
        q_eval = self.q @ p_stack
        A_eval = self.A @ p_stack

        # Check if solver has native JAX autodiff (Moreau).
        # If so, bypass the closure-based custom_vjp wrapper and use Moreau's
        # native autodiff.
        if self.ctx.solver == "MOREAU":
            ws = self._warm_start_cache if warm_start else None
            return self._solve_moreau(P_eval, q_eval, A_eval, batch, solver_args, warm_start=ws)

        # Non-Moreau: use existing custom_vjp wrapper (not JIT-compatible)
        return self._solve_with_custom_vjp(P_eval, q_eval, A_eval, batch, solver_args)

    @staticmethod
    def _validate_warm_start(warm_start: Any, batch: tuple) -> Any:
        """Check that cached warm start is compatible with the current batch size.

        Returns the warm start if compatible, None otherwise.
        """
        if warm_start is None:
            return None
        cached_ndim = warm_start.x.ndim
        current_is_batched = bool(batch)
        if cached_ndim == 1 and not current_is_batched:
            return warm_start  # Both unbatched
        if cached_ndim == 2 and current_is_batched:
            if warm_start.x.shape[0] == batch[0]:
                return warm_start  # Batch sizes match
        return None  # Mismatch — skip warm start

    def _solve_moreau(
        self,
        P_eval: jnp.ndarray | None,
        q_eval: jnp.ndarray,
        A_eval: jnp.ndarray,
        batch: tuple,
        solver_args: dict[str, Any],
        warm_start: Any = None,
    ) -> tuple[jnp.ndarray, ...]:
        """Direct call to Moreau solver - uses its native custom_vjp (JIT-compatible).

        Moreau's JAX solver (moreau.jax.Solver) implements custom_vjp with
        pure_callback and vmap_method="broadcast_all", making it fully
        compatible with jax.jit, jax.vmap, and jax.pmap.

        This method uses jax.vmap for batched cases to ensure each problem is
        solved individually, which is required for JIT compatibility.

        Args:
            warm_start: Optional WarmStart or BatchedWarmStart from a previous
                solve, passed as extra arguments to solve_warm when available.
        """
        solver_ctx = self.ctx.solver_ctx  # type: ignore[attr-defined]
        jax_solver = solver_ctx.get_jax_solver()

        # Apply per-call solver_args to solver settings
        if solver_args:
            settings = jax_solver._impl._settings
            for key, value in solver_args.items():
                if hasattr(settings, key):
                    setattr(settings, key, value)

        # Validate warm start batch compatibility
        warm_start = self._validate_warm_start(warm_start, batch)

        # Prepare warm start arrays for solve_warm (public API, vmap-safe)
        if warm_start is not None:
            warm_x = jnp.asarray(warm_start.x, dtype=jnp.float64)
            warm_z = jnp.asarray(warm_start.z, dtype=jnp.float64)
            warm_s = jnp.asarray(warm_start.s, dtype=jnp.float64)

        # Use cached solve functions (captured as closure, not dynamic lookup)
        solve_fn = self._moreau_solve_fn
        solve_warm_fn = self._moreau_solve_warm_fn

        # CPU fallback: solve_warm is None on CPU backend.
        # Use _pending_warm_start side channel with the regular solve function.
        use_side_channel_warm_start = (
            warm_start is not None and solve_warm_fn is None
        )
        if use_side_channel_warm_start:
            self._moreau_jax_impl._pending_warm_start = {
                'warm_x': np.asarray(warm_start.x, dtype=np.float64),
                'warm_z': np.asarray(warm_start.z, dtype=np.float64),
                'warm_s': np.asarray(warm_start.s, dtype=np.float64),
            }

        # Cache solver_ctx attributes for closure capture.
        # Parametrization matrices are pre-permuted so matrix multiplication
        # directly produces values in CSR order — no shuffle indexing needed.
        b_idx = solver_ctx.b_idx
        nnz_A = solver_ctx.nnz_A
        m = solver_ctx.A_shape[0]

        def _extract_problem_data(
            P_eval_single: jnp.ndarray | None,
            q_eval_single: jnp.ndarray,
            A_eval_single: jnp.ndarray,
        ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
            """Extract P_values, A_values, q, b from parametrized matrices."""
            P_values = P_eval_single if P_eval_single is not None else jnp.zeros(0, dtype=jnp.float64)
            A_values = -A_eval_single[:nnz_A]
            b_raw = A_eval_single[nnz_A:]
            b = jnp.zeros(m, dtype=jnp.float64)
            b = b.at[b_idx].set(b_raw)
            q = q_eval_single[:-1]
            return P_values, A_values, q, b

        def extract_and_solve(
            P_eval_single,
            q_eval_single,
            A_eval_single,
        ):
            """Extract problem data and solve (cold)."""
            P_values, A_values, q, b = _extract_problem_data(
                P_eval_single, q_eval_single, A_eval_single
            )
            solution, _info = solve_fn(P_values, A_values, q, b)
            return solution.x, solution.z, solution.s

        def extract_and_solve_warm(
            P_eval_single,
            q_eval_single,
            A_eval_single,
            ws_x,
            ws_z,
            ws_s,
        ):
            """Extract problem data and solve with warm start."""
            P_values, A_values, q, b = _extract_problem_data(
                P_eval_single, q_eval_single, A_eval_single
            )
            solution, _info = solve_warm_fn(
                P_values, A_values, q, b, ws_x, ws_z, ws_s
            )
            return solution.x, solution.z, solution.s

        # Select solve function and extra args.
        # When solve_warm_fn is available (CUDA), pass warm arrays as positional args.
        # When solve_warm_fn is None (CPU), warm start is already set via
        # _pending_warm_start side channel above — use the cold solve function.
        if warm_start is not None and solve_warm_fn is not None:
            solve = extract_and_solve_warm
            extra_args = (warm_x, warm_z, warm_s)
        else:
            solve = extract_and_solve
            extra_args = ()

        if batch:
            P_eval_t = P_eval.T if P_eval is not None else None
            vmapped = jax.vmap(solve, in_axes=(0,) * (3 + len(extra_args)))
            primal, dual, slack = vmapped(
                P_eval_t, q_eval.T, A_eval.T, *extra_args
            )
        else:
            primal, dual, slack = solve(P_eval, q_eval, A_eval, *extra_args)
            # Add batch dimension for _recover_results (which expects it)
            primal = jnp.expand_dims(primal, 0)
            dual = jnp.expand_dims(dual, 0)

        if use_side_channel_warm_start:
            self._moreau_jax_impl._pending_warm_start = None

        # Always update warm start cache (negligible cost).
        # Skip when inside jit/vmap — traced arrays can't be converted to numpy.
        try:
            if batch:
                self._warm_start_cache = BatchedWarmStart(
                    x=np.asarray(primal, dtype=np.float64),
                    z=np.asarray(dual, dtype=np.float64),
                    s=np.asarray(slack, dtype=np.float64),
                )
            else:
                self._warm_start_cache = WarmStart(
                    x=np.asarray(primal.squeeze(0), dtype=np.float64),
                    z=np.asarray(dual.squeeze(0), dtype=np.float64),
                    s=np.asarray(slack, dtype=np.float64),
                )
        except jax.errors.TracerArrayConversionError:
            pass  # Inside jit/vmap — warm start cache not available

        return _recover_results(primal, dual, self.ctx, batch)

    def _solve_with_custom_vjp(
        self,
        P_eval: jnp.ndarray | None,
        q_eval: jnp.ndarray,
        A_eval: jnp.ndarray,
        batch: tuple,
        solver_args: dict[str, Any],
    ) -> tuple[jnp.ndarray, ...]:
        """Solve using closure-based custom_vjp wrapper (not JIT-compatible).

        This is used for non-Moreau solvers (e.g., DIFFCP) that don't have
        native JAX autodiff support. The closure-based approach stores data
        in a Python dict for the backward pass, which breaks JIT tracing.
        """
        # Store data and adjoint in closure for backward pass.
        # This avoids JAX trying to trace through DIFFCP's Python-based solver.
        data_container: dict[str, Any] = {}

        @jax.custom_vjp
        def solve_problem(P_eval, q_eval, A_eval):
            # Function body: runs when NOT differentiating.
            # Use solve_only to skip computing the adjoint operator.
            data = self.ctx.solver_ctx.jax_to_data(P_eval, q_eval, A_eval)  # type: ignore[attr-defined]
            primal, dual = data.jax_solve_only(solver_args)  # type: ignore[attr-defined]
            return primal, dual

        def solve_problem_fwd(P_eval, q_eval, A_eval):
            # Runs when differentiating: compute derivatives for backward pass
            data = self.ctx.solver_ctx.jax_to_data(P_eval, q_eval, A_eval)  # type: ignore[attr-defined]
            primal, dual, adj_batch = data.jax_solve(solver_args)  # type: ignore[attr-defined]
            # Store for backward pass (outside JAX tracing)
            data_container["data"] = data
            data_container["adj_batch"] = adj_batch
            # Return empty residuals (data is in closure)
            return (primal, dual), ()

        def solve_problem_bwd(res, g):
            # Backward pass using adjoint method
            dprimal, ddual = g
            data = data_container["data"]
            adj_batch = data_container["adj_batch"]
            dP, dq, dA = data.jax_derivative(dprimal, ddual, adj_batch)
            return dP, dq, dA

        solve_problem.defvjp(solve_problem_fwd, solve_problem_bwd)
        primal, dual = solve_problem(P_eval, q_eval, A_eval)

        # Recover results and apply GP inverse transform if needed
        return _recover_results(primal, dual, self.ctx, batch)
def scipy_csr_to_jax_bcsr( scipy_csr: scipy.sparse.csr_array | None, ) -> jax.experimental.sparse.BCSR | None: if scipy_csr is None: return None # Use cast to help type checker understand scipy_csr is not None scipy_csr = cast(scipy.sparse.csr_array, scipy_csr) num_rows, num_cols = scipy_csr.shape # type: ignore[misc] # JAX BCSR doesn't handle empty matrices (0 rows) properly. # Create a minimal valid BCSR with a single zero element instead. if num_rows == 0: # Create a (1, num_cols) matrix with a single zero at position (0, 0) # This will produce a (1, ...) result when multiplied, which we'll slice to (0, ...) return _EmptyBCSRWrapper(num_cols) # Get the CSR format components values = scipy_csr.data col_indices = scipy_csr.indices row_ptr = scipy_csr.indptr # Create the JAX BCSR tensor jax_bcsr = jax.experimental.sparse.BCSR( (jnp.array(values), jnp.array(col_indices), jnp.array(row_ptr)), shape=(num_rows, num_cols), ) return jax_bcsr class _EmptyBCSRWrapper: """Wrapper for empty (0-row) sparse matrices that JAX BCSR can't handle. When multiplied with a vector/matrix, returns an empty array with the correct shape. """ def __init__(self, num_cols: int): self.num_cols = num_cols self.shape = (0, num_cols) def __matmul__(self, other: jnp.ndarray) -> jnp.ndarray: # other shape: (num_cols,) or (num_cols, batch) if other.ndim == 1: return jnp.zeros((0,), dtype=other.dtype) else: batch_size = other.shape[1] return jnp.zeros((0, batch_size), dtype=other.dtype)