"""Source code for ``cvxpylayers.mlx.cvxpylayer``: a differentiable CVXPY layer for MLX."""

from typing import Any, cast

import cvxpy as cp
import mlx.core as mx
import numpy as np
import scipy.sparse

import cvxpylayers.utils.parse_args as pa


def _scipy_csr_to_dense(
    scipy_csr: scipy.sparse.csr_array | scipy.sparse.csr_matrix | None,
) -> np.ndarray | None:
    """Convert scipy sparse CSR matrix to dense numpy array.

    MLX does not currently support sparse linear algebra, so we convert
    to dense matrices for computation.
    """
    if scipy_csr is None:
        return None
    scipy_csr = cast(scipy.sparse.csr_matrix, scipy_csr)
    return np.asarray(scipy_csr.toarray())


def _reshape_fortran(array: mx.array, shape: tuple) -> mx.array:
    """Reshape ``array`` to ``shape`` using column-major (Fortran) ordering.

    ``mx.reshape`` only supports row-major order. A column-major reshape is
    obtained by reversing the input axes, reshaping (row-major) into the
    reversed target shape, and reversing the axes of the result again.

    Args:
        array: Input array to reshape.
        shape: Target shape tuple (may contain a single ``-1``).

    Returns:
        The reshaped array, with elements laid out in Fortran order.
    """
    ndim_in = len(array.shape)
    if ndim_in == 0:
        # A 0-d array has no axis order to respect.
        return mx.reshape(array, shape)

    flipped = mx.transpose(array, axes=tuple(range(ndim_in - 1, -1, -1)))
    out = mx.reshape(flipped, tuple(shape[::-1]))

    ndim_out = len(shape)
    if ndim_out:
        out = mx.transpose(out, axes=tuple(range(ndim_out - 1, -1, -1)))
    return out


def _apply_gp_log_transform(
    params: tuple[mx.array, ...],
    ctx: pa.LayersContext,
) -> tuple[mx.array, ...]:
    """Move geometric-program (GP) parameters into log-space.

    A GP is solved in log-space after conversion to DCP. Every parameter
    listed in ``ctx.gp_param_to_log_param`` is replaced by its elementwise
    natural log, and all other parameters pass through unchanged. When the
    layer is not a GP, or no parameter needs the transform, ``params`` is
    returned as-is.

    Args:
        params: Parameter arrays in the original GP space, in the same order
            as ``ctx.parameters``.
        ctx: Layer context carrying the GP flag and the parameter mapping.

    Returns:
        Tuple of parameters, log-transformed where required.
    """
    if not (ctx.gp and ctx.gp_param_to_log_param):
        return params

    log_mapped = ctx.gp_param_to_log_param
    return tuple(
        mx.log(value) if ctx.parameters[idx] in log_mapped else value
        for idx, value in enumerate(params)
    )


def _flatten_and_batch_params(
    params: tuple[mx.array, ...],
    ctx: pa.LayersContext,
    batch: tuple,
) -> mx.array:
    """Pack every parameter value into one column-ordered parameter vector.

    Each parameter is flattened in column-major (Fortran) order and placed in
    the slot given by ``ctx.user_order_to_col_order``. A trailing constant
    1.0 entry is appended for the offset terms of the canonical problem.
    During a batched call, parameters that ``ctx.batch_sizes`` marks as
    unbatched (size 0) are broadcast along ``batch``.

    Args:
        params: Parameter arrays, in the user's parameter order.
        ctx: Layer context providing ``batch_sizes`` and
            ``user_order_to_col_order``.
        batch: Batch dimensions tuple, or ``()`` for an unbatched call.

    Returns:
        A vector of shape ``(num_params,)`` when unbatched. When batched, the
        transpose of the ``(*batch, num_params)`` stack, i.e.
        ``(num_params, batch_size)`` for a single batch dimension.
    """
    slots: list[mx.array | None] = [None] * (len(params) + 1)
    flat_shape = batch + (-1,)

    for idx, value in enumerate(params):
        if ctx.batch_sizes[idx] == 0 and batch:  # type: ignore[index]
            # Unbatched parameter in a batched call: replicate it along the batch.
            value = mx.broadcast_to(mx.expand_dims(value, axis=0), batch + value.shape)
        slots[ctx.user_order_to_col_order[idx]] = _reshape_fortran(value, flat_shape)

    # Constant 1.0 that multiplies the offset column of the parametrized matrices.
    slots[-1] = mx.ones(batch + (1,), dtype=params[0].dtype)

    assert all(slot is not None for slot in slots), "All parameters must be assigned"
    stacked = mx.concatenate(cast(list[mx.array], slots), axis=-1)
    # Batched: (*batch, num_params) -> parameters first, so that `M @ stacked` works.
    return mx.transpose(stacked) if batch else stacked


def _svec_to_symmetric(
    svec: mx.array,
    n: int,
    batch: tuple,
    rows: np.ndarray,
    cols: np.ndarray,
    scale: np.ndarray | None = None,
) -> mx.array:
    """Expand a packed triangle vector into a full symmetric matrix.

    Entry ``k`` of the (optionally scaled) packed vector is written to both
    ``(rows[k], cols[k])`` and ``(cols[k], rows[k])``. Positions not named by
    ``rows``/``cols`` are zero.

    The expansion is one ``mx.take`` gather driven by a precomputed
    ``(n, n)`` index map, so it stays inside the MLX graph. Gradients
    therefore flow back to ``svec``: an entry mirrored to two positions
    receives the sum of both cotangents. All batch elements are handled at
    once, with no per-element Python loop and no NumPy round trip (which
    would detach the result from autodiff).

    Args:
        svec: Packed values, shape ``(*batch, len(rows))``.
        n: Matrix dimension.
        batch: Batch dimensions of the output.
        rows: Row index of each packed entry.
        cols: Column index of each packed entry.
        scale: Optional per-entry multipliers applied before unpacking (e.g.
            ``1/sqrt(2)`` on off-diagonals to undo svec scaling).

    Returns:
        Symmetric matrix of shape ``(*batch, n, n)`` with ``svec``'s dtype.
    """
    data = svec if scale is None else svec * mx.array(scale, dtype=svec.dtype)

    rows = np.asarray(rows, dtype=np.int64)
    cols = np.asarray(cols, dtype=np.int64)
    num_entries = data.shape[-1]

    # Cells not covered by (rows, cols) point one past the last packed entry.
    # A zero column is appended there below, which reproduces zero fill.
    index_map = np.full((n, n), num_entries, dtype=np.int32)
    entry_ids = np.arange(rows.shape[0], dtype=np.int32)
    index_map[rows, cols] = entry_ids
    index_map[cols, rows] = entry_ids

    zero_col = mx.zeros(tuple(data.shape[:-1]) + (1,), dtype=data.dtype)
    padded = mx.concatenate([data, zero_col], axis=-1)
    gathered = mx.take(padded, mx.array(index_map.reshape(-1)), axis=padded.ndim - 1)
    return mx.reshape(gathered, tuple(batch) + (n, n))


def _unpack_primal_svec(svec: mx.array, n: int, batch: tuple) -> mx.array:
    """Rebuild a symmetric primal variable from its packed upper triangle.

    CVXPY stores symmetric variables as their upper triangle in row-major
    order, i.e. the order produced by ``np.triu_indices``:
    ``X[0,0], X[0,1], ..., X[0,n-1], X[1,1], X[1,2], ..., X[n-1,n-1]``.
    No scaling is applied to the entries.

    Args:
        svec: Packed symmetric variable.
        n: Matrix dimension.
        batch: Batch dimensions.

    Returns:
        The full symmetric ``n x n`` matrix (with leading batch dimensions).
    """
    upper_rows, upper_cols = np.triu_indices(n)
    return _svec_to_symmetric(svec, n, batch, rows=upper_rows, cols=upper_cols)


def _unpack_svec(svec: mx.array, n: int, batch: tuple) -> mx.array:
    """Unpack the scaled vectorized (svec) form into a full symmetric matrix.

    The svec format stores a symmetric ``n x n`` matrix as a vector of length
    ``n*(n+1)/2``. The vector holds the lower triangle in column-major order,
    ``(0,0), (1,0), ..., (n-1,0), (1,1), (2,1), ..., (n-1,n-1)``, and every
    off-diagonal entry is multiplied by ``sqrt(2)``. This function removes that
    scaling and mirrors the triangle into the full matrix.

    Args:
        svec: Scaled vectorized form, shape ``(*batch, n*(n+1)/2)``.
        n: Matrix dimension.
        batch: Batch dimensions.

    Returns:
        Full symmetric matrix with the scaling removed.
    """
    # np.triu_indices yields (r, c) with r <= c in row-major order. Reading it
    # with row and column swapped enumerates the lower triangle in column-major
    # order. This is the same ordering as tril_indices sorted by (col, row).
    cols, rows = np.triu_indices(n)
    # Undo the sqrt(2) off-diagonal scaling; diagonal entries are unscaled.
    scale = np.where(rows == cols, 1.0, 1.0 / np.sqrt(2.0))
    return _svec_to_symmetric(svec, n, batch, rows, cols, scale)


def _recover_results(
    primal: mx.array,
    dual: mx.array,
    ctx: pa.LayersContext,
    batch: tuple,
) -> tuple[mx.array, ...]:
    """Pull the requested variable values out of the solver's solutions.

    For every entry of ``ctx.var_recover`` the function proceeds in order:

    - It slices the primal solution (or, failing that, the dual solution).
    - It unpacks packed symmetric / PSD-dual values, and reshapes everything
      else in Fortran order.
    - For GPs, it exponentiates the primal recoveries (duals are already in
      the original space).
    - For unbatched calls, it drops the leading batch axis.

    Args:
        primal: Primal solution from the solver.
        dual: Dual solution from the solver.
        ctx: Layer context with variable recovery info.
        batch: Batch dimensions tuple (empty if unbatched).

    Returns:
        Tuple of recovered variable values, in ``ctx.var_recover`` order.

    Raises:
        RuntimeError: If a recovery entry has neither a primal nor a dual slice.
    """
    batch_shape = tuple(primal.shape[:-1])

    def _extract(var) -> mx.array:
        if var.primal is not None:
            chunk = primal[..., var.primal]
            if var.is_symmetric:
                return _unpack_primal_svec(chunk, var.shape[0], batch_shape)
            return _reshape_fortran(chunk, batch_shape + var.shape)
        if var.dual is not None:
            chunk = dual[..., var.dual]
            if var.is_psd_dual:
                return _unpack_svec(chunk, var.shape[0], batch_shape)
            return _reshape_fortran(chunk, batch_shape + var.shape)
        raise RuntimeError(
            "Invalid VariableRecovery: both primal and dual slices are None. "
            "At least one must be set to recover variable values."
        )

    recovered = [_extract(var) for var in ctx.var_recover]

    if ctx.gp:
        # GP primal variables were solved in log-space.
        recovered = [
            mx.exp(value) if var.primal is not None else value
            for value, var in zip(recovered, ctx.var_recover)
        ]

    if not batch:
        recovered = [mx.squeeze(value, axis=0) for value in recovered]

    return tuple(recovered)


class CvxpyLayer:
    """A differentiable convex optimization layer for MLX.

    This layer wraps a parametrized CVXPY problem, solving it in the forward
    pass and computing gradients via implicit differentiation. Optimized for
    Apple Silicon (M1/M2/M3) with unified memory architecture.

    Example:
        >>> import cvxpy as cp
        >>> import mlx.core as mx
        >>> from cvxpylayers.mlx import CvxpyLayer
        >>>
        >>> # Define a simple QP
        >>> x = cp.Variable(2)
        >>> A = cp.Parameter((3, 2))
        >>> b = cp.Parameter(3)
        >>> problem = cp.Problem(cp.Minimize(cp.sum_squares(A @ x - b)), [x >= 0])
        >>>
        >>> # Create the layer
        >>> layer = CvxpyLayer(problem, parameters=[A, b], variables=[x])
        >>>
        >>> # Solve and compute gradients
        >>> A_mx = mx.random.normal((3, 2))
        >>> b_mx = mx.random.normal((3,))
        >>> (solution,) = layer(A_mx, b_mx)
        >>>
        >>> # Gradient computation
        >>> def loss_fn(A, b):
        ...     (x,) = layer(A, b)
        ...     return mx.sum(x)
        >>> grad_fn = mx.grad(loss_fn, argnums=[0, 1])
        >>> grads = grad_fn(A_mx, b_mx)
    """

    def __init__(
        self,
        problem: cp.Problem,
        parameters: list[cp.Parameter],
        variables: list[cp.Variable],
        solver: str | None = None,
        gp: bool = False,
        verbose: bool = False,
        canon_backend: str | None = None,
        solver_args: dict[str, Any] | None = None,
    ) -> None:
        """Initialize the differentiable optimization layer.

        Args:
            problem: A CVXPY Problem. Must be DPP-compliant
                (``problem.is_dpp()`` must return True).
            parameters: List of CVXPY Parameters that will be filled with
                values at runtime. Order must match the order of arrays passed
                to __call__().
            variables: List of CVXPY Variables whose optimal values will be
                returned by __call__(). Order determines the order of returned
                arrays.
            solver: CVXPY solver to use (e.g., ``cp.CLARABEL``, ``cp.SCS``).
                If None, uses the default diffcp solver.
            gp: If True, problem is a geometric program. Parameters will be
                log-transformed before solving.
            verbose: If True, print solver output.
            canon_backend: Backend for canonicalization. Options are 'diffcp',
                'cuclarabel', or None (auto-select).
            solver_args: Default keyword arguments passed to the solver. Can
                be overridden per-call in __call__().

        Raises:
            AssertionError: If problem is not DPP-compliant.
            ValueError: If parameters or variables are not part of the problem.
        """
        if solver_args is None:
            solver_args = {}
        self.ctx = pa.parse_args(
            problem,
            variables,
            parameters,
            solver,
            gp=gp,
            verbose=verbose,
            canon_backend=canon_backend,
            solver_args=solver_args,
        )
        # MLX doesn't support sparse LA, so we store dense numpy arrays and
        # convert them to MLX arrays (once per dtype, see _as_mx).
        if self.ctx.reduced_P.reduced_mat is not None:  # type: ignore[attr-defined]
            self._P_np = _scipy_csr_to_dense(self.ctx.reduced_P.reduced_mat)  # type: ignore[attr-defined]
        else:
            self._P_np = None
        self._q_np: np.ndarray = _scipy_csr_to_dense(self.ctx.q.tocsr())  # type: ignore[assignment]
        self._A_np: np.ndarray = _scipy_csr_to_dense(self.ctx.reduced_A.reduced_mat)  # type: ignore[attr-defined, assignment]
        # MLX copies of the dense matrices above, keyed by (matrix name, dtype),
        # so the numpy -> MLX conversion is not repeated on every forward call.
        self._mx_cache: dict[tuple[str, str], mx.array] = {}

    def _as_mx(self, name: str, dtype: Any) -> mx.array:
        """Return the dense matrix ``self._<name>_np`` as a cached MLX array.

        The conversion is cached per (matrix, dtype). MLX arrays are
        immutable, so reusing one across calls and autodiff graphs is safe.

        Args:
            name: ``"P"``, ``"q"`` or ``"A"``.
            dtype: Target MLX dtype.
        """
        key = (name, str(dtype))
        cached = self._mx_cache.get(key)
        if cached is None:
            cached = mx.array(getattr(self, f"_{name}_np"), dtype=dtype)
            self._mx_cache[key] = cached
        return cached

    def __call__(
        self,
        *params: mx.array,
        solver_args: dict[str, Any] | None = None,
    ) -> tuple[mx.array, ...]:
        """Solve the optimization problem and return optimal variable values.

        Args:
            *params: Array values for each CVXPY Parameter, in the same order
                as the ``parameters`` argument to __init__(). Each array shape
                must match the corresponding Parameter shape, optionally with
                a batch dimension prepended. Batched and unbatched parameters
                can be mixed; unbatched parameters are broadcast.
            solver_args: Keyword arguments passed to the solver, overriding
                any defaults set in __init__().

        Returns:
            Tuple of arrays containing optimal values for each CVXPY Variable
            specified in the ``variables`` argument to __init__(). If inputs
            are batched, outputs will have matching batch dimensions.

        Raises:
            RuntimeError: If the solver fails to find a solution.

        Example:
            >>> # Single problem
            >>> (x_opt,) = layer(A_array, b_array)
            >>>
            >>> # Batched: solve 10 problems in parallel
            >>> A_batch = mx.random.normal((10, 3, 2))
            >>> b_batch = mx.random.normal((10, 3))
            >>> (x_batch,) = layer(A_batch, b_batch)  # x_batch.shape = (10, 2)
        """
        if solver_args is None:
            solver_args = {}
        batch = self.ctx.validate_params(list(params))

        # Apply log transformation to GP parameters
        params = _apply_gp_log_transform(params, self.ctx)

        # Flatten and batch parameters
        p_stack = _flatten_and_batch_params(params, self.ctx, batch)

        # Get dtype from input parameters to ensure type matching
        param_dtype = params[0].dtype

        # Evaluate parametrized matrices (dense MLX copies are cached per dtype)
        P_eval = self._as_mx("P", param_dtype) @ p_stack if self._P_np is not None else None
        q_eval = self._as_mx("q", param_dtype) @ p_stack
        A_eval = self._as_mx("A", param_dtype) @ p_stack

        # Solve optimization problem with custom VJP for gradients
        primal, dual = self._solve_with_vjp(P_eval, q_eval, A_eval, solver_args)

        # Recover results and apply GP inverse transform if needed
        return _recover_results(primal, dual, self.ctx, batch)

    def forward(
        self,
        *params: mx.array,
        solver_args: dict[str, Any] | None = None,
    ) -> tuple[mx.array, ...]:
        """Forward pass (alias for __call__)."""
        return self.__call__(*params, solver_args=solver_args)

    def _solve_with_vjp(
        self,
        P_eval: mx.array | None,
        q_eval: mx.array,
        A_eval: mx.array,
        solver_args: dict[str, Any],
    ) -> tuple[mx.array, mx.array]:
        """Solve the canonical problem with custom VJP for backpropagation."""
        ctx = self.ctx

        # Solver data and adjoint, captured by the forward pass for the backward pass
        data_container: dict[str, Any] = {}

        # Handle None P by using a dummy tensor (required for custom_function signature)
        param_dtype = q_eval.dtype
        P_arg = P_eval if P_eval is not None else mx.zeros((1,), dtype=param_dtype)
        has_P = P_eval is not None

        @mx.custom_function
        def solve_layer(P_tensor: mx.array, q_tensor: mx.array, A_tensor: mx.array):
            # Forward pass: solve the optimization problem
            quad_values = P_tensor if has_P else None
            data = ctx.solver_ctx.mlx_to_data(quad_values, q_tensor, A_tensor)  # type: ignore[attr-defined]
            primal, dual, adj_batch = data.mlx_solve(solver_args)  # type: ignore[attr-defined]

            # Store for backward pass (outside MLX tracing)
            data_container["data"] = data
            data_container["adj_batch"] = adj_batch

            return primal, dual

        @solve_layer.vjp
        def solve_layer_vjp(primals, cotangents, outputs):  # noqa: F811
            # Backward pass using the solver's adjoint method
            if isinstance(cotangents, (tuple, list)):
                cot_list = list(cotangents)
            else:
                cot_list = [cotangents]

            # A missing cotangent (output unused downstream) contributes zeros.
            dprimal = (
                cot_list[0]
                if cot_list and cot_list[0] is not None
                else mx.zeros(outputs[0].shape, dtype=outputs[0].dtype)
            )
            ddual = (
                cot_list[1]
                if len(cot_list) >= 2 and cot_list[1] is not None
                else mx.zeros(outputs[1].shape, dtype=outputs[1].dtype)
            )

            data = data_container["data"]
            dP, dq, dA = data.mlx_derivative(dprimal, ddual, data_container["adj_batch"])

            # Zero gradient for the dummy P (no quadratic term) or a missing dP
            if not has_P or dP is None:
                grad_P = mx.zeros(primals[0].shape, dtype=primals[0].dtype)
            else:
                grad_P = dP

            return (grad_P, dq, dA)

        primal, dual = solve_layer(P_arg, q_eval, A_eval)  # type: ignore[misc]
        return primal, dual