import warnings
from typing import Any, cast
import cvxpy as cp
import numpy as np
import scipy.sparse
import torch
import cvxpylayers.utils.parse_args as pa
class _ScipySparseMatmul(torch.autograd.Function):
"""Sparse matrix-vector multiply using scipy on CPU with autograd support.
PyTorch sparse CSR matmul on CPU is 80-200x slower than scipy due to
MKL/OpenBLAS thread-spawning overhead. This uses scipy for the forward
pass and computes gradients via the transpose multiply.
"""
@staticmethod
def forward(scipy_csr: scipy.sparse.csr_array, x: torch.Tensor) -> torch.Tensor:
x_np = x.detach().numpy()
result = scipy_csr @ x_np
return torch.from_numpy(np.asarray(result))
@staticmethod
def setup_context(ctx: Any, inputs: tuple, output: torch.Tensor) -> None:
scipy_csr, x = inputs
ctx.scipy_csr_T = scipy_csr.T.tocsr()
ctx.save_for_backward(x)
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx: Any, grad_output: torch.Tensor) -> tuple[None, torch.Tensor]:
grad_np = grad_output.numpy()
result = ctx.scipy_csr_T @ grad_np
return None, torch.from_numpy(np.asarray(result))
def _reshape_fortran(array: torch.Tensor, shape: tuple) -> torch.Tensor:
"""Reshape array using Fortran (column-major) order.
PyTorch doesn't support order='F' directly, so we use permutation.
Args:
array: Input tensor to reshape
shape: Target shape tuple
Returns:
Reshaped tensor in Fortran order
"""
if len(array.shape) == 0:
return array.reshape(shape)
x = array.permute(*reversed(range(len(array.shape))))
return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))
def _apply_gp_log_transform(
    params: tuple[torch.Tensor, ...],
    ctx: pa.LayersContext,
) -> tuple[torch.Tensor, ...]:
    """Log-transform the parameters of a geometric program (GP).

    GPs are solved in log-space after conversion to DCP, so the parameters
    flagged in the context's pre-computed mask are mapped through ``log``.
    Non-GP problems (or a missing mask) pass through untouched.

    Args:
        params: Parameter tensors in the original GP space
        ctx: Layer context carrying the GP log mask

    Returns:
        Parameters with log applied where the mask requires it
    """
    if not ctx.gp or ctx.gp_log_mask is None:
        return params
    # Mask is pre-computed for JIT compatibility (no dict lookups here).
    transformed: list[torch.Tensor] = []
    for value, take_log in zip(params, ctx.gp_log_mask):
        transformed.append(torch.log(value) if take_log else value)
    return tuple(transformed)
def _flatten_and_batch_params(
    params: tuple[torch.Tensor, ...],
    ctx: pa.LayersContext,
    batch: tuple,
) -> torch.Tensor:
    """Flatten and batch parameters into one stacked tensor.

    Turns a tuple of parameter tensors (batched and unbatched may be mixed)
    into a single concatenated tensor ready for matmul against the
    parametrized problem matrices.

    Args:
        params: Parameter tensors
        ctx: Layer context with batch info and column ordering
        batch: Batch dimensions tuple (empty if unbatched)

    Returns:
        Concatenated tensor of shape (num_params, batch_size) or (num_params,)
    """
    # Process with an effective batch of at least 1 so batched and
    # unbatched inputs follow the same code path.
    work_batch = batch or (1,)
    slots: list[torch.Tensor | None] = [None] * (len(params) + 1)
    for i, param in enumerate(params):
        if ctx.batch_sizes[i] == 0:  # type: ignore[index]
            # Unbatched parameter: broadcast it across the batch dimension.
            param = param.unsqueeze(0).expand(work_batch + param.shape)
        slots[ctx.user_order_to_col_order[i]] = _reshape_fortran(
            param,
            work_batch + (-1,),
        )
    # Constant 1.0 column for the offset terms of the canonical form.
    slots[-1] = torch.ones(
        work_batch + (1,),
        dtype=params[0].dtype,
        device=params[0].device,
    )
    assert all(s is not None for s in slots), "All parameters must be assigned"
    # (batch_size, num_params) -> (num_params, batch_size)
    stacked = torch.cat(slots, -1).T  # type: ignore[arg-type]
    # Conditional-free final shape: () batch collapses the trailing dim,
    # (B,) batch keeps it — total element count matches either way.
    return stacked.reshape((stacked.shape[0],) + batch)
def _svec_to_symmetric(
svec: torch.Tensor,
n: int,
batch: tuple,
rows: np.ndarray,
cols: np.ndarray,
scale: np.ndarray | None = None,
) -> torch.Tensor:
"""Convert vectorized form to full symmetric matrix.
Args:
svec: Vectorized form, shape (*batch, n*(n+1)/2)
n: Matrix dimension
batch: Batch dimensions (from tensor shape, not boolean)
rows: Row indices for unpacking
cols: Column indices for unpacking
scale: Optional scaling factors (for svec format with sqrt(2) scaling)
Returns:
Full symmetric matrix, shape (*batch, n, n)
"""
rows_t = torch.tensor(rows, dtype=torch.long, device=svec.device)
cols_t = torch.tensor(cols, dtype=torch.long, device=svec.device)
if scale is not None:
scale_t = torch.tensor(scale, dtype=svec.dtype, device=svec.device)
data = svec * scale_t
else:
data = svec
# Output shape: (*batch, n, n) - works for () or (B,) batch
out_shape = batch + (n, n)
result = torch.zeros(out_shape, dtype=svec.dtype, device=svec.device)
# Use ellipsis indexing - works uniformly for any number of batch dimensions
result[..., rows_t, cols_t] = data
result[..., cols_t, rows_t] = data
return result
def _unpack_primal_svec(svec: torch.Tensor, n: int, batch: tuple) -> torch.Tensor:
    """Expand a symmetric primal variable from its vectorized form.

    CVXPY stores symmetric variables in upper-triangular row-major order:
    [X[0,0], X[0,1], ..., X[0,n-1], X[1,1], X[1,2], ..., X[n-1,n-1]]

    Args:
        svec: Vectorized symmetric variable
        n: Matrix dimension
        batch: Batch dimensions

    Returns:
        Full symmetric matrix
    """
    upper_rows, upper_cols = np.triu_indices(n)
    return _svec_to_symmetric(svec, n, batch, upper_rows, upper_cols)
def _unpack_svec(svec: torch.Tensor, n: int, batch: tuple) -> torch.Tensor:
    """Expand scaled vectorized (svec) form to a full symmetric matrix.

    The svec format packs a symmetric n x n matrix into n*(n+1)/2 entries
    with off-diagonal elements scaled by sqrt(2), ordered column-major over
    the lower triangle: (0,0), (1,0), (1,1), (2,0), ...

    Args:
        svec: Scaled vectorized form
        n: Matrix dimension
        batch: Batch dimensions

    Returns:
        Full symmetric matrix with the sqrt(2) scaling removed
    """
    tril_r, tril_c = np.tril_indices(n)
    # Reorder row-major tril indices into column-major order
    # (primary key: column, secondary key: row).
    order = np.lexsort((tril_r, tril_c))
    r = tril_r[order]
    c = tril_c[order]
    # Undo scaling: diagonal entries stay, off-diagonal shrink by 1/sqrt(2).
    undo = np.where(r != c, 1.0 / np.sqrt(2.0), 1.0)
    return _svec_to_symmetric(svec, n, batch, r, c, undo)
def _recover_results(
    primal: torch.Tensor,
    dual: torch.Tensor,
    ctx: pa.LayersContext,
    batch: tuple,
) -> tuple[torch.Tensor, ...]:
    """Recover requested variable values from primal/dual solutions.

    Pulls each requested variable out of the solver's primal and dual
    solutions, unpacks symmetric matrices where needed, applies the inverse
    GP transform, and drops the batch dimension for unbatched inputs.

    Args:
        primal: Primal solution from solver
        dual: Dual solution from solver
        ctx: Layer context with variable recovery info
        batch: Batch dimensions tuple (empty () for unbatched, (B,) for batched)

    Returns:
        Tuple of recovered variable values
    """
    # The solver always yields 2D tensors (batch_size, num_vars): unbatched
    # inputs arrive with batch_size=1. Intermediate work uses that internal
    # shape; the final reshape applies the caller-visible batch.
    internal_batch = tuple(primal.shape[:-1])
    recovered: list[torch.Tensor] = []
    for var in ctx.var_recover:
        # Pre-computed source/unpack_fn fields keep this JIT-compatible.
        raw = primal[..., var.primal] if var.source == "primal" else dual[..., var.dual]
        if var.unpack_fn == "svec_primal":
            value = _unpack_primal_svec(raw, var.shape[0], internal_batch)
        elif var.unpack_fn == "svec_dual":
            value = _unpack_svec(raw, var.shape[0], internal_batch)
        elif var.unpack_fn == "reshape":
            value = _reshape_fortran(raw, internal_batch + var.shape)
        else:
            raise ValueError(f"Unknown variable recovery type: {var.unpack_fn}")
        # (1,) + var.shape reshapes cleanly to () + var.shape for the
        # unbatched case because the element counts match.
        recovered.append(value.reshape(batch + var.shape))
    if ctx.gp:
        # GP primal variables live in log-space; exponentiate them back.
        # Dual variables are already in the original space.
        recovered = [
            torch.exp(val) if info.source == "primal" else val
            for val, info in zip(recovered, ctx.var_recover)
        ]
    return tuple(recovered)
class CvxpyLayer(torch.nn.Module):
    """A differentiable convex optimization layer for PyTorch.

    This layer wraps a parametrized CVXPY problem, solving it in the forward pass
    and computing gradients via implicit differentiation in the backward pass.

    Example:
        >>> import cvxpy as cp
        >>> import torch
        >>> from cvxpylayers.torch import CvxpyLayer
        >>>
        >>> # Define a simple QP
        >>> x = cp.Variable(2)
        >>> A = cp.Parameter((3, 2))
        >>> b = cp.Parameter(3)
        >>> problem = cp.Problem(cp.Minimize(cp.sum_squares(A @ x - b)), [x >= 0])
        >>>
        >>> # Create the layer
        >>> layer = CvxpyLayer(problem, parameters=[A, b], variables=[x])
        >>>
        >>> # Solve with gradients
        >>> A_t = torch.randn(3, 2, requires_grad=True)
        >>> b_t = torch.randn(3, requires_grad=True)
        >>> (solution,) = layer(A_t, b_t)
        >>> solution.sum().backward()
    """

    def __init__(
        self,
        problem: cp.Problem,
        parameters: list[cp.Parameter],
        variables: list[cp.Variable],
        solver: str | None = None,
        gp: bool = False,
        verbose: bool = False,
        canon_backend: str | None = None,
        solver_args: dict[str, Any] | None = None,
    ) -> None:
        """Initialize the differentiable optimization layer.

        Args:
            problem: A CVXPY Problem. Must be DPP-compliant (``problem.is_dpp()``
                must return True).
            parameters: List of CVXPY Parameters that will be filled with values
                at runtime. Order must match the order of tensors passed to forward().
            variables: List of CVXPY Variables whose optimal values will be returned
                by forward(). Order determines the order of returned tensors.
            solver: CVXPY solver to use (e.g., ``cp.CLARABEL``, ``cp.SCS``).
                If None, uses the default diffcp solver.
            gp: If True, problem is a geometric program. Parameters will be
                log-transformed before solving.
            verbose: If True, print solver output.
            canon_backend: Backend for canonicalization. Options are 'diffcp',
                'cuclarabel', or None (auto-select).
            solver_args: Default keyword arguments passed to the solver.
                Can be overridden per-call in forward().

        Raises:
            AssertionError: If problem is not DPP-compliant.
            ValueError: If parameters or variables are not part of the problem.
        """
        super().__init__()
        if solver_args is None:
            solver_args = {}
        self.ctx = pa.parse_args(
            problem,
            variables,
            parameters,
            solver,
            gp=gp,
            verbose=verbose,
            canon_backend=canon_backend,
            solver_args=solver_args,
        )
        # Each matrix is kept twice: a torch CSR buffer (moves with the
        # module across devices, used on GPU) and a scipy CSR copy (used on
        # CPU, where scipy matmul is much faster — see _ScipySparseMatmul).
        if self.ctx.reduced_P.reduced_mat is not None:  # type: ignore[attr-defined]
            self.register_buffer(
                "P",
                scipy_csr_to_torch_csr(self.ctx.reduced_P.reduced_mat),  # type: ignore[attr-defined]
            )
            self._P_scipy: scipy.sparse.csr_array | None = self.ctx.reduced_P.reduced_mat.tocsr()  # type: ignore[attr-defined]
        else:
            self.P = None
            self._P_scipy = None
        self.register_buffer("q", scipy_csr_to_torch_csr(self.ctx.q.tocsr()))
        self._q_scipy: scipy.sparse.csr_array = self.ctx.q.tocsr()
        self.register_buffer(
            "A",
            scipy_csr_to_torch_csr(self.ctx.reduced_A.reduced_mat),  # type: ignore[attr-defined]
        )
        self._A_scipy: scipy.sparse.csr_array = self.ctx.reduced_A.reduced_mat.tocsr()  # type: ignore[attr-defined]

    def forward(
        self, *params: torch.Tensor, solver_args: dict[str, Any] | None = None
    ) -> tuple[torch.Tensor, ...]:
        """Solve the optimization problem and return optimal variable values.

        Args:
            *params: Tensor values for each CVXPY Parameter, in the same order
                as the ``parameters`` argument to __init__. Each tensor shape must
                match the corresponding Parameter shape, optionally with a batch
                dimension prepended. Batched and unbatched parameters can be mixed;
                unbatched parameters are broadcast.
            solver_args: Keyword arguments passed to the solver, overriding any
                defaults set in __init__.

        Returns:
            Tuple of tensors containing optimal values for each CVXPY Variable
            specified in the ``variables`` argument to __init__. If inputs are
            batched, outputs will have matching batch dimensions.

        Raises:
            RuntimeError: If the solver fails to find a solution.

        Example:
            >>> # Single problem
            >>> (x_opt,) = layer(A_tensor, b_tensor)
            >>>
            >>> # Batched: solve 10 problems in parallel
            >>> A_batch = torch.randn(10, 3, 2)
            >>> b_batch = torch.randn(10, 3)
            >>> (x_batch,) = layer(A_batch, b_batch)  # x_batch.shape = (10, 2)
        """
        if solver_args is None:
            solver_args = {}
        batch = self.ctx.validate_params(list(params))
        # Apply log transformation to GP parameters
        params = _apply_gp_log_transform(params, self.ctx)
        # Flatten and batch parameters
        p_stack = _flatten_and_batch_params(params, self.ctx, batch)
        # Get dtype and device from input parameters to ensure type/device matching
        param_dtype = p_stack.dtype
        param_device = p_stack.device
        # Evaluate parametrized matrices
        if param_device.type == "cpu":
            # Use scipy sparse matmul on CPU (80-200x faster than torch sparse CSR)
            P_eval = (
                _ScipySparseMatmul.apply(self._P_scipy, p_stack)
                if self._P_scipy is not None
                else None
            )
            q_eval = _ScipySparseMatmul.apply(self._q_scipy, p_stack)
            A_eval = _ScipySparseMatmul.apply(self._A_scipy, p_stack)
        else:
            # Use torch sparse CSR on GPU (fast there)
            P_eval = (
                (self.P.to(dtype=param_dtype, device=param_device) @ p_stack)
                if self.P is not None
                else None
            )
            q_eval = self.q.to(dtype=param_dtype, device=param_device) @ p_stack  # type: ignore[operator]
            A_eval = self.A.to(dtype=param_dtype, device=param_device) @ p_stack  # type: ignore[operator]
        # Get the solver-specific _CvxpyLayer class (imported lazily to avoid
        # a circular import at module load time).
        from cvxpylayers.interfaces import get_torch_cvxpylayer

        _CvxpyLayer = get_torch_cvxpylayer(self.ctx.solver)
        # Solve optimization problem
        primal, dual, _, _ = _CvxpyLayer.apply(  # type: ignore[misc]
            P_eval,
            q_eval,
            A_eval,
            self.ctx,
            solver_args,
        )
        # Recover results and apply GP inverse transform if needed
        return _recover_results(primal, dual, self.ctx, batch)
def scipy_csr_to_torch_csr(
scipy_csr: scipy.sparse.csr_array | None,
) -> torch.Tensor | None:
if scipy_csr is None:
return None
# Use cast to help type checker understand scipy_csr is not None
scipy_csr = cast(scipy.sparse.csr_array, scipy_csr)
# Get the CSR format components
values = scipy_csr.data
col_indices = scipy_csr.indices
row_ptr = scipy_csr.indptr
num_rows, num_cols = scipy_csr.shape # type: ignore[misc]
# Create the torch sparse csr_tensor
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="Sparse CSR tensor support is in beta state",
category=UserWarning,
)
torch_csr = torch.sparse_csr_tensor(
crow_indices=torch.tensor(row_ptr, dtype=torch.int64),
col_indices=torch.tensor(col_indices, dtype=torch.int64),
values=torch.tensor(values, dtype=torch.float64),
size=(num_rows, num_cols),
)
return torch_csr