Quickstart

Build your first differentiable optimization layer in 4 steps.

Prerequisites

  • Python 3.11+

  • CVXPY (pip install cvxpy)

  • One of: PyTorch, JAX, or MLX


1 Define the Problem

Create a parametrized convex optimization problem with CVXPY:

import cvxpy as cp

# Problem dimensions
n, m = 2, 3

# Decision variable (what we solve for)
x = cp.Variable(n)

# Parameters (inputs that change at runtime)
A = cp.Parameter((m, n))
b = cp.Parameter(m)

# Build the problem
problem = cp.Problem(
    cp.Minimize(cp.sum_squares(A @ x - b)),  # Objective
    [x >= 0]                                   # Constraints
)

Tip

Parameters are placeholders for values you’ll provide later. Variables are what the solver finds.


2 Check DPP Compliance

Your problem must follow the Disciplined Parametrized Programming (DPP) rules:

assert problem.is_dpp(), "Problem must be DPP-compliant"

What is DPP?

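DPP guarantees that the problem's structure is fixed: between calls, only the parameter values change, never the problem's shape. cvxpylayers relies on this to compute gradients of the solution with respect to the parameters by implicit differentiation.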
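One way to see why a fixed structure matters: write the solution as a composition of three maps. The notation here is ours and is only a sketch of the usual account of differentiable convex optimization layers. With θ = (A, b), CVXPY canonicalizes the problem into a cone program whose data depend on θ through a map 𝒞. A cone solver s solves that program, and a map ℛ recovers x from the cone solution:

$$
(\tilde{A}, \tilde{b}, \tilde{c}) = \mathcal{C}(\theta), \qquad x^\star(\theta) = \mathcal{R}\big(s(\tilde{A}, \tilde{b}, \tilde{c})\big)
$$

$$
\frac{\partial x^\star}{\partial \theta} = D\mathcal{R} \cdot Ds \cdot D\mathcal{C}
$$

For a DPP problem, 𝒞 is affine in θ, so D𝒞 is a fixed matrix computed once. Ds comes from implicitly differentiating the solver's optimality conditions.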

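Valid: Parameters in affine expressions, including a parameter multiplying a parameter-free expression

A @ x - b  # Good: A multiplies the parameter-free x, and b enters affinely

Invalid: Parameters used where the DPP rules forbid them, such as the matrix of a quadratic form

cp.quad_form(x, P)  # Not DPP when P is a Parameter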

3 Create the Layer

Wrap your problem as a differentiable layer:

PyTorch

import torch
from cvxpylayers.torch import CvxpyLayer

layer = CvxpyLayer(
    problem,
    parameters=[A, b],   # CVXPY parameters (in order)
    variables=[x]        # Variables to return
)

JAX

import jax
from cvxpylayers.jax import CvxpyLayer

layer = CvxpyLayer(
    problem,
    parameters=[A, b],
    variables=[x]
)

MLX

import mlx.core as mx
from cvxpylayers.mlx import CvxpyLayer

layer = CvxpyLayer(
    problem,
    parameters=[A, b],
    variables=[x]
)

4 Solve & Differentiate

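Call the layer with one array per parameter, in the same order as the parameters list, then backpropagate through the solution. Each framework's snippet continues the matching snippet from step 3: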

PyTorch

# Create tensors with gradients enabled
A_t = torch.randn(m, n, requires_grad=True)
b_t = torch.randn(m, requires_grad=True)

# Forward: solve the optimization
(solution,) = layer(A_t, b_t)

# Backward: compute gradients
loss = solution.sum()
loss.backward()

print(f"Solution: {solution}")
print(f"dL/dA: {A_t.grad}")
print(f"dL/db: {b_t.grad}")

JAX

# Create arrays (split the key so A and b get independent random values)
key = jax.random.PRNGKey(0)
key_A, key_b = jax.random.split(key)
A_jax = jax.random.normal(key_A, (m, n))
b_jax = jax.random.normal(key_b, (m,))

# Forward
(solution,) = layer(A_jax, b_jax)

# Gradients via jax.grad
def loss_fn(A, b):
    (sol,) = layer(A, b)
    return sol.sum()

dA, db = jax.grad(loss_fn, argnums=(0, 1))(A_jax, b_jax)

print(f"Solution: {solution}")
print(f"dL/dA: {dA}")
print(f"dL/db: {db}")

MLX

# Create arrays
A_mx = mx.random.normal((m, n))
b_mx = mx.random.normal((m,))

# Forward
(solution,) = layer(A_mx, b_mx)

# Gradients via mx.grad
def loss_fn(A, b):
    (sol,) = layer(A, b)
    return mx.sum(sol)

grad_fn = mx.grad(loss_fn, argnums=[0, 1])
dA, db = grad_fn(A_mx, b_mx)

print(f"Solution: {solution}")
print(f"dL/dA: {dA}")
print(f"dL/db: {db}")

Complete Example

Here’s everything together: a training loop that learns a matrix A for which the layer's solution matches x_true:

import cvxpy as cp
import torch
from cvxpylayers.torch import CvxpyLayer

# 1. Define problem
n, m = 5, 10
x = cp.Variable(n)
A = cp.Parameter((m, n))
b = cp.Parameter(m)

problem = cp.Problem(
    cp.Minimize(cp.sum_squares(A @ x - b)),
    [x >= 0]
)

# 2. Create layer
layer = CvxpyLayer(problem, parameters=[A, b], variables=[x])

# 3. Training setup
A_true = torch.randn(m, n)
x_true = torch.abs(torch.randn(n))
b_true = A_true @ x_true

A_learn = torch.randn(m, n, requires_grad=True)
optimizer = torch.optim.Adam([A_learn], lr=0.1)

# 4. Training loop
for i in range(100):
    optimizer.zero_grad()
    (x_pred,) = layer(A_learn, b_true)
    loss = torch.sum((x_pred - x_true) ** 2)
    loss.backward()
    optimizer.step()

    if i % 25 == 0:
        print(f"Step {i:3d} | Loss: {loss.item():.4f}")

Next Steps

  • Basic Usage: constructor options, parameter handling, and error handling.

  • Batching: solve multiple problems in parallel for a 10-100x speedup.

  • Solvers: choose the right solver (SCS, Clarabel, CuClarabel).

  • Examples: real-world applications in control, finance, ML, and robotics.