Batched Execution¶
CVXPYlayers supports solving multiple problem instances in parallel through batching.
How Batching Works¶
Add a batch dimension as the first dimension of your parameter tensors:
import cvxpy as cp
import torch
from cvxpylayers.torch import CvxpyLayer
# Problem with parameters of shape (3, 2) and (3,)
x = cp.Variable(2)
A = cp.Parameter((3, 2))
b = cp.Parameter(3)
problem = cp.Problem(cp.Minimize(cp.sum_squares(A @ x - b)), [x >= 0])
layer = CvxpyLayer(problem, parameters=[A, b], variables=[x])
# Single instance: shapes (3, 2) and (3,)
A_single = torch.randn(3, 2)
b_single = torch.randn(3)
(x_single,) = layer(A_single, b_single) # x_single shape: (2,)
# Batched: shapes (batch_size, 3, 2) and (batch_size, 3)
batch_size = 10
A_batch = torch.randn(batch_size, 3, 2)
b_batch = torch.randn(batch_size, 3)
(x_batch,) = layer(A_batch, b_batch) # x_batch shape: (10, 2)
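As a quick sanity check (the tolerance here is illustrative and depends on the solver's accuracy), the first batched solution should match a single solve of the first instance:
# The first batched solution should agree with solving the first instance on its own
(x_first,) = layer(A_batch[0], b_batch[0])
print(torch.allclose(x_batch[0], x_first, atol=1e-4)) # True, up to solver tolerance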
Broadcasting¶
Batched and unbatched parameters can be mixed in the same call; unbatched parameters are broadcast across the batch:
batch_size = 10
# A is batched, b is shared across the batch
A_batch = torch.randn(batch_size, 3, 2) # Different A for each instance
b_shared = torch.randn(3) # Same b for all instances
(x_batch,) = layer(A_batch, b_shared) # x_batch shape: (10, 2)
This is useful when some parameters are fixed and others vary.
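The broadcast works in either direction; for example, reusing the layer above with a shared A and a per-instance b:
# b is batched, A is shared across the batch
A_shared = torch.randn(3, 2) # Same A for all instances
b_batch = torch.randn(batch_size, 3) # Different b for each instance
(x_batch,) = layer(A_shared, b_batch) # x_batch shape: (10, 2)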
Performance Considerations¶
Batch Size Selection¶
Larger batches are more efficient due to:
Parallelized matrix operations
Reduced Python overhead
Better GPU utilization (if using CuClarabel)
# Efficient: solve 100 problems at once
m, n = 3, 2 # Sizes of the layer defined above
A_batch = torch.randn(100, m, n)
b_val = torch.randn(m)
(x_batch,) = layer(A_batch, b_val)
# Less efficient: solve 100 problems one at a time
solutions = []
for i in range(100):
    (x,) = layer(A_batch[i], b_val)
    solutions.append(x)
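To measure the difference on your own problem, a rough timing sketch (exact numbers depend on the solver, problem size, and hardware):
import time
start = time.perf_counter()
(x_batch,) = layer(A_batch, b_val) # One batched call
print(f"Batched: {time.perf_counter() - start:.3f} s")
start = time.perf_counter()
for i in range(100):
    (x,) = layer(A_batch[i], b_val) # 100 separate calls
print(f"Looped: {time.perf_counter() - start:.3f} s")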
Memory Trade-offs¶
Batching increases memory usage linearly with batch size. If you run out of memory:
Reduce batch size
Process in chunks:
def chunked_solve(layer, A_all, b, chunk_size=32):
    results = []
    for i in range(0, len(A_all), chunk_size):
        chunk = A_all[i:i+chunk_size]
        (x_chunk,) = layer(chunk, b)
        results.append(x_chunk)
    return torch.cat(results, dim=0)
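For example, reusing the (3, 2)/(3,) layer defined earlier (the sizes here are only illustrative):
A_all = torch.randn(1000, 3, 2)
b_val = torch.randn(3)
x_all = chunked_solve(layer, A_all, b_val, chunk_size=64)
print(x_all.shape) # torch.Size([1000, 2])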
Batched Gradients¶
Gradients work naturally with batching:
batch_size = 10
A_batch = torch.randn(batch_size, 3, 2, requires_grad=True)
b_batch = torch.randn(batch_size, 3, requires_grad=True)
(x_batch,) = layer(A_batch, b_batch)
# Sum over batch for scalar loss
loss = x_batch.sum()
loss.backward()
# Gradients have same shape as inputs
print(A_batch.grad.shape) # (10, 3, 2)
print(b_batch.grad.shape) # (10, 3)
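Because the layer behaves like any other autograd function, batched solves also compose with standard optimizers. A minimal training-loop sketch, where b_param and target are hypothetical tensors introduced only for illustration:
# Learn an unbatched parameter b_param so the batched solutions approach a target
b_param = torch.ones(3, requires_grad=True) # Hypothetical learnable parameter
target = torch.rand(batch_size, 2) # Hypothetical target solutions
optimizer = torch.optim.Adam([b_param], lr=0.1)
for step in range(20):
    optimizer.zero_grad()
    (x_batch,) = layer(A_batch, b_param) # b_param is broadcast over the batch
    loss = ((x_batch - target) ** 2).mean()
    loss.backward()
    optimizer.step()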
JAX: vmap Integration¶
In JAX, you can also use jax.vmap for batching:
import jax
import jax.numpy as jnp
from cvxpylayers.jax import CvxpyLayer
layer = CvxpyLayer(problem, parameters=[A, b], variables=[x])
# Option 1: Pass batched arrays directly (built-in support)
key_A, key_b = jax.random.split(jax.random.PRNGKey(0))
batch_size = 10
A_batch = jax.random.normal(key_A, (batch_size, 3, 2))
b_single = jax.random.normal(key_b, (3,))
(x_batch,) = layer(A_batch, b_single)
# Option 2: Use vmap explicitly
def solve_single(A):
    (x,) = layer(A, b_single)
    return x
solve_batched = jax.vmap(solve_single)
x_batch = solve_batched(A_batch)
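Gradients compose the same way in JAX; as a sketch, you can differentiate a scalar function of the batched solutions through the built-in batched call:
# Differentiate the sum of all batched solutions with respect to A_batch
def solution_sum(A):
    (x,) = layer(A, b_single)
    return x.sum()
dA = jax.grad(solution_sum)(A_batch) # dA shape: (batch_size, 3, 2)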
Example: Batch Portfolio Optimization¶
import cvxpy as cp
import torch
from cvxpylayers.torch import CvxpyLayer
n_assets = 10
n_scenarios = 50
# Portfolio optimization problem
weights = cp.Variable(n_assets)
expected_returns = cp.Parameter(n_assets)
risk_aversion = cp.Parameter(nonneg=True)
# Simplified: maximize return - risk_aversion * variance
problem = cp.Problem(
    cp.Maximize(expected_returns @ weights - risk_aversion * cp.sum_squares(weights)),
    [cp.sum(weights) == 1, weights >= 0]
)
layer = CvxpyLayer(
    problem,
    parameters=[expected_returns, risk_aversion],
    variables=[weights]
)
# Different return predictions for each scenario
returns_batch = torch.randn(n_scenarios, n_assets)
risk_aversion_val = torch.tensor(0.5) # Shared across scenarios
# Solve all scenarios at once
(optimal_weights,) = layer(returns_batch, risk_aversion_val)
print(f"Optimal weights shape: {optimal_weights.shape}") # (50, 10)