Differentiating through optimization with the IFT
A startup recently came out that reminded me of an old (embarrassing) report I wrote on differentiating through optimization problems with the implicit function theorem (IFT). If I remember correctly, I was playing with ways to differentiate through massively sparse softmax problems. Unfortunately, I think I gave up after a week. Don’t be me.
Regardless, the IFT is super cool! It lets you differentiate through implicitly defined functions. For example, suppose you want to optimize the parameters of an optimization problem. The issue is that the solver for this problem uses gradient descent and therefore takes many steps. Reverse-mode autodiff would require storing the entire execution trace to compute gradients, which requires memory linear in the number of iterations. The IFT lets you throw all of that away and compute gradients using only the solution and the optimality conditions, which implicitly define a function.
A solver wiggles from the initialization to the optimum $x^\star$ (red), producing a long trace that reverse-mode AD must store. The IFT (blue, dashed) computes $\partial x^\star / \partial \theta$ directly from the optimality conditions at $x^\star$, without looking at the solver's path.
The IFT shows up in a number of places. OptNet used it to backpropagate through quadratic program solvers, letting neural networks learn constraints (e.g. Sudoku from input-output pairs). Influence functions used it for training data attribution (I guess mainly Anthropic). Meta-learning methods like MAML differentiate through inner-loop optimization, and iMAML used the IFT to do so without unrolling. Hyperparameter optimization and variational inference also fit naturally into this framework. Note: You can also derive autodiff itself from the IFT.
In this post, I’ll walk through the mechanics of applying the IFT to softmax, expressed as an optimization problem. Warning: This post goes a bit deeper into calculations than my previous posts as the IFT is pretty mechanical.
The Implicit Function Theorem
As a starting example, consider the unit circle governed by the relation
$$x^2 + y^2 = 1.$$
Computing the derivative $\frac{dy}{dx}$ is not straightforward, as the circle fails the vertical line test and we cannot write $y$ as a function of $x$ globally. However, we can find local parameterizations: on the upper semicircle ($y > 0$), $y = \sqrt{1 - x^2}$, and on the lower semicircle ($y < 0$), $y = -\sqrt{1 - x^2}$. These let us compute $\frac{dy}{dx}$ at particular solution points, though not everywhere (e.g. not at $(\pm 1, 0)$).
The unit circle $x^2 + y^2 = 1$. Near a point on the upper semicircle, we can locally write $y = \sqrt{1 - x^2}$ (green arc), so $\frac{dy}{dx}$ exists there. At $(\pm 1, 0)$, the circle is vertical, so no local parameterization exists and the derivative is undefined.
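To make this concrete, here is a small sketch (my own, not from any library) comparing the implicit derivative $\frac{dy}{dx} = -x/y$, obtained by differentiating $x^2 + y^2 = 1$ directly, against the derivative of the explicit upper-branch parameterization:

```python
import math

def dy_dx_implicit(x, y):
    """dy/dx from differentiating x^2 + y^2 = 1: 2x + 2y*y' = 0  =>  y' = -x/y."""
    return -x / y

def dy_dx_explicit(x):
    """dy/dx of the explicit upper-branch parameterization y = sqrt(1 - x^2)."""
    return -x / math.sqrt(1 - x**2)

x = 0.6
y = math.sqrt(1 - x**2)  # a point on the upper semicircle
assert abs(dy_dx_implicit(x, y) - dy_dx_explicit(x)) < 1e-12
```

The implicit form only needs the solution point $(x, y)$; it never needed the formula for the branch. Note that both versions fail at $y = 0$, matching the figure.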
The IFT generalizes this. Instead of finding a local parameterization, it guarantees one exists and tells you its derivative. The parameterization is left implicit, hence the name.
Formally, given a system of equations $F(\theta, z) = 0$ with parameters $\theta \in \mathbb{R}^m$ and variables $z \in \mathbb{R}^n$, and a solution point $(\theta_0, z_0)$, the IFT says there exists a local solution mapping $z^\star(\theta)$ with $F(\theta, z^\star(\theta)) = 0$ if:
- We have a solution $F(\theta_0, z_0) = 0$.
- $F$ has continuous first derivatives $\frac{\partial F}{\partial \theta}$ and $\frac{\partial F}{\partial z}$.
- The Jacobian $\frac{\partial F}{\partial z}$ is nonsingular at the solution point.
When these hold, the derivative of the solution mapping is
$$\frac{\partial z^\star}{\partial \theta}(\theta_0) = -\left[\frac{\partial F}{\partial z}(\theta_0, z_0)\right]^{-1} \frac{\partial F}{\partial \theta}(\theta_0, z_0).$$
The key point: as long as the conditions hold at a solution point, we can compute this derivative regardless of how we arrived at that solution.
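A toy illustration of this point, using a hypothetical one-dimensional system $F(\theta, z) = z^3 + z - \theta = 0$ of my own choosing: the solver takes many Newton steps, but the IFT derivative $\frac{\partial z^\star}{\partial \theta} = -\left(\frac{\partial F}{\partial z}\right)^{-1} \frac{\partial F}{\partial \theta} = \frac{1}{3z^2 + 1}$ is computed from the solution alone.

```python
def solve(theta, iters=100):
    """Find z with F(theta, z) = z**3 + z - theta = 0 by Newton's method."""
    z = 0.0
    for _ in range(iters):
        z -= (z**3 + z - theta) / (3 * z**2 + 1)
    return z

def dz_dtheta_ift(z):
    """IFT: dz/dtheta = -(dF/dz)^{-1} dF/dtheta = -(3z^2 + 1)^{-1} * (-1)."""
    return 1.0 / (3 * z**2 + 1)

theta = 2.0
z = solve(theta)
# Compare against a finite-difference estimate pushed through the whole solver.
eps = 1e-6
fd = (solve(theta + eps) - solve(theta - eps)) / (2 * eps)
assert abs(dz_dtheta_ift(z) - fd) < 1e-6
```

The finite-difference check re-runs the solver twice; the IFT derivative touches only the solution and the residual's derivatives.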
The Softmax Optimization Problem
Let’s apply this to something concrete: softmax. Softmax has a known Jacobian, so we can verify the IFT gives the right answer. The exercise is to express softmax as an optimization problem, then derive its Jacobian using the IFT.
Softmax is defined as follows. Given $n$ items with independent utilities $\theta \in \mathbb{R}^n$, softmax yields a distribution $p$ over items:
$$p_i = \frac{\exp(\theta_i)}{\sum_{j=1}^{n} \exp(\theta_j)},$$
with $\sum_i p_i = 1$.
Softmax is also the solution of the following constrained optimization problem:
$$p^\star(\theta) = \operatorname*{arg\,max}_{p \in \mathbb{R}^n} \; \theta^\top p + H(p) \quad \text{subject to} \quad \mathbf{1}^\top p = 1, \; p \ge 0,$$
where $H(p) = -\sum_i p_i \log p_i$ is the entropy.
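We can spot-check this equivalence numerically. The sketch below (my own, assuming the entropy-regularized objective $\theta^\top p + H(p)$ over the simplex) verifies that softmax beats randomly sampled simplex points:

```python
import numpy as np

rng = np.random.default_rng(0)
theta = rng.normal(size=5)

def objective(p, theta):
    """Entropy-regularized linear objective: theta^T p + H(p)."""
    return theta @ p - np.sum(p * np.log(p))

# softmax(theta), computed with the usual max-subtraction for stability
p_star = np.exp(theta - theta.max())
p_star /= p_star.sum()

# softmax should attain the maximum over the simplex
for _ in range(1000):
    p = rng.dirichlet(np.ones(5))  # random point on the simplex interior
    assert objective(p, theta) <= objective(p_star, theta) + 1e-12
```

At the optimum, the objective value is $\log \sum_j e^{\theta_j}$, the familiar log-sum-exp.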
Our goal is to compute the Jacobian of softmax
$$\frac{\partial p^\star}{\partial \theta} \in \mathbb{R}^{n \times n}, \qquad \left[\frac{\partial p^\star}{\partial \theta}\right]_{ij} = \frac{\partial p_i}{\partial \theta_j},$$
using the IFT and the optimization problem above.
Applying the IFT
Applying the IFT consists of four steps:
- Find a solution to the optimization problem.
- Write down a system of equations derived from the optimality conditions.
- Check that the conditions of the IFT hold.
- Compute the derivative of the implicit solution mapping with respect to the parameters.
We assume the first step has been done for us, and we have a solution $p^\star$ to the softmax problem. Most of these manipulations are math-homework flavored.
Step 2: The KKT conditions determine the system of equations
In order to apply the IFT, we need a system of equations for which our outputs of interest are solution points. For solutions of optimization problems, the Karush-Kuhn-Tucker (KKT) conditions are a natural choice: given an optimization problem, they determine a system of equations that any solution must satisfy. The conditions we need here are stationarity (the gradient of the Lagrangian should vanish at a local optimum) and feasibility (the constraints of the problem should not be violated).
We will use the KKT conditions of the softmax problem to determine the vector-valued function $F$ in the IFT. The optimization problem has both equality and inequality constraints, but for finite $\theta$ the softmax solution is always strictly positive:
$$p_i^\star = \frac{\exp(\theta_i)}{\sum_j \exp(\theta_j)} > 0.$$
That means the inequality constraints are inactive at the solution points we care about, so it is cleaner to work with the reduced KKT system containing only the equality constraint. We therefore introduce only the equality multiplier $\nu$ and write out the Lagrangian:
$$L(\theta, p, \nu) = \theta^\top p + H(p) + \nu \left(1 - \mathbf{1}^\top p\right).$$
We therefore have the solution point $(\theta, z^\star)$, with parameters $\theta$ and solution $z^\star = (p^\star, \nu^\star)$. We then have the following necessary conditions for a solution, i.e. the KKT conditions:
$$\begin{aligned}
\text{stationarity:} \quad & \nabla_p L = \theta - \log p - \mathbf{1} - \nu\,\mathbf{1} = 0, \\
\text{primal feasibility (equality):} \quad & \mathbf{1}^\top p - 1 = 0, \\
\text{primal feasibility (inequality):} \quad & p \ge 0.
\end{aligned}$$
As we only need a system of $n + 1$ equations to determine the $n + 1$ solution variables $(p, \nu)$, we use the first two conditions: stationarity and primal feasibility (equality).
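As a sanity check that these two conditions really do pin down softmax (step 1), we can solve the reduced system — stationarity $\theta_i - \log p_i - 1 - \nu = 0$ together with feasibility $\sum_i p_i = 1$ — by hand:
$$\begin{aligned}
\theta_i - \log p_i - 1 - \nu = 0 \;&\Longrightarrow\; p_i = \exp(\theta_i - 1 - \nu), \\
\sum_i p_i = 1 \;&\Longrightarrow\; e^{-1-\nu} \sum_i e^{\theta_i} = 1 \;\Longrightarrow\; \nu = \log \sum_i e^{\theta_i} - 1, \\
&\Longrightarrow\; p_i = \frac{e^{\theta_i}}{\sum_j e^{\theta_j}}.
\end{aligned}$$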
In full, the system of equations we choose for the softmax problem is
$$F(\theta, p, \nu) = \begin{bmatrix} \theta - \log p - \mathbf{1} - \nu\,\mathbf{1} \\ \mathbf{1}^\top p - 1 \end{bmatrix} = 0.$$
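A minimal numeric check (my own sketch, assuming the reduced system with stationarity $\theta - \log p - \mathbf{1} - \nu\mathbf{1} = 0$ and feasibility $\mathbf{1}^\top p = 1$) that the softmax solution and its multiplier indeed zero out $F$:

```python
import numpy as np

def softmax(theta):
    z = np.exp(theta - theta.max())
    return z / z.sum()

def F(theta, p, nu):
    """Reduced KKT system: stationarity stacked on the equality constraint."""
    stationarity = theta - np.log(p) - 1.0 - nu
    feasibility = p.sum() - 1.0
    return np.concatenate([stationarity, [feasibility]])

theta = np.array([0.5, -1.2, 3.0])
p = softmax(theta)
nu = np.log(np.exp(theta).sum()) - 1.0  # multiplier implied by stationarity
assert np.allclose(F(theta, p, nu), 0.0)
```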
Step 3: Check that the IFT conditions hold at the solution point
The IFT only applies if the following three conditions hold, which must be checked on a case-by-case basis at particular solution points. Note that even when the conditions fail, the derivative may still be computed via other means if it exists.
- $F(\theta_0, z_0) = 0$,
- $F$ has at least continuous first derivatives,
- $\det \frac{\partial F}{\partial z}(\theta_0, z_0) \ne 0$, or equivalently $\frac{\partial F}{\partial z}$ is full rank.
In the softmax problem, the first condition holds as we have a solution to the optimization problem and $F$ was chosen using the KKT conditions, so $F(\theta, p^\star, \nu^\star) = 0$. The second condition also holds, as $F$ has continuous first derivatives wherever $p > 0$. All that remains is to check the third condition, that the Jacobian matrix
$$\frac{\partial F}{\partial (p, \nu)}$$
(evaluated at the solution point) is nonsingular.
The Jacobian matrix is given by
$$\frac{\partial F}{\partial (p, \nu)} = \begin{bmatrix} -\operatorname{diag}(1/p) & -\mathbf{1} \\ \mathbf{1}^\top & 0 \end{bmatrix}.$$
The upper-left block $-\operatorname{diag}(1/p)$ comes from the Hessian of $H$; its entries blow up if any component $p_i \to 0$. We saw a similar issue in the unit circle example, where the derivative was undefined when $y = 0$. Luckily, softmax already satisfies $p_i > 0$ for finite $\theta$.
Once $p > 0$, the block $-\operatorname{diag}(1/p)$ is invertible. The Schur complement of the upper-left block is
$$0 - \mathbf{1}^\top \left(-\operatorname{diag}(1/p)\right)^{-1} (-\mathbf{1}) = -\mathbf{1}^\top \operatorname{diag}(p)\,\mathbf{1} = -\sum_i p_i = -1 \ne 0,$$
where we used primal feasibility. Therefore the Jacobian of $F$ is nonsingular. This shows that the conditions of the IFT hold for the solution points that are feasible, optimal, and have strictly positive $p$. Recall that the IFT applies at particular solution points, meaning we can pick and choose which points to analyze.
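We can confirm this numerically. The sketch below (my own, using the block layout $\partial F / \partial(p, \nu) = [[-\operatorname{diag}(1/p), -\mathbf{1}], [\mathbf{1}^\top, 0]]$) builds the KKT Jacobian at a softmax solution and checks the Schur complement:

```python
import numpy as np

theta = np.array([0.5, -1.2, 3.0])
p = np.exp(theta - theta.max())
p /= p.sum()  # softmax(theta), strictly positive
n = len(p)

# Jacobian of F with respect to (p, nu)
J = np.zeros((n + 1, n + 1))
J[:n, :n] = -np.diag(1.0 / p)  # Hessian block, invertible since p > 0
J[:n, n] = -1.0
J[n, :n] = 1.0

# Schur complement of the upper-left block: 0 - 1^T (-diag(1/p))^{-1} (-1) = -sum(p) = -1
schur = 0.0 - J[n, :n] @ np.linalg.solve(J[:n, :n], J[:n, n])
assert np.isclose(schur, -1.0)
assert abs(np.linalg.det(J)) > 1e-8  # the full KKT Jacobian is nonsingular
```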
Step 4: Compute $\frac{\partial p^\star}{\partial \theta}$
Now that we have a set of solution points where the IFT holds, we can use the IFT to compute $\frac{\partial p^\star}{\partial \theta}$. Recall that we have the solution $z^\star = (p^\star, \nu^\star)$. The second part of the IFT tells us that we can compute the Jacobian of the solution mapping
$$\frac{\partial z^\star}{\partial \theta} = -\left[\frac{\partial F}{\partial z}\right]^{-1} \frac{\partial F}{\partial \theta},$$
then pick out the relevant components.
The second term is simple. Since $\theta$ only appears in the first vector-valued function of $F$ (the stationarity conditions), we have
$$\frac{\partial F}{\partial \theta} = \begin{bmatrix} I \\ 0 \end{bmatrix}.$$
Rather than write down the full inverse explicitly, it is simpler to solve the linear system induced by the IFT identity. Differentiating the equations in $F(\theta, p^\star(\theta), \nu^\star(\theta)) = 0$ with respect to $\theta$ gives
$$\begin{aligned}
I - \operatorname{diag}(1/p)\,\frac{\partial p}{\partial \theta} - \mathbf{1}\,\frac{\partial \nu}{\partial \theta} &= 0, \\
\mathbf{1}^\top \frac{\partial p}{\partial \theta} &= 0.
\end{aligned}$$
Solving the first equation for $\frac{\partial p}{\partial \theta}$ yields
$$\frac{\partial p}{\partial \theta} = \operatorname{diag}(p)\left(I - \mathbf{1}\,\frac{\partial \nu}{\partial \theta}\right).$$
Substituting this into the equality constraint gives
$$0 = \mathbf{1}^\top \operatorname{diag}(p)\left(I - \mathbf{1}\,\frac{\partial \nu}{\partial \theta}\right) = p^\top - \frac{\partial \nu}{\partial \theta},$$
where we again used feasibility: $\mathbf{1}^\top \operatorname{diag}(p)\,\mathbf{1} = \sum_i p_i = 1$. Therefore
$$\frac{\partial \nu}{\partial \theta} = p^\top.$$
Plugging this back in,
$$\frac{\partial p}{\partial \theta} = \operatorname{diag}(p)\left(I - \mathbf{1}\,p^\top\right) = \operatorname{diag}(p) - p\,p^\top.$$
Therefore the Jacobian of the solution mapping is
$$\frac{\partial p^\star}{\partial \theta} = \operatorname{diag}(p^\star) - p^\star (p^\star)^\top,$$
which agrees with the analytic Jacobian of softmax, $\frac{\partial p_i}{\partial \theta_j} = p_i(\delta_{ij} - p_j)$.
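The whole pipeline can be checked end to end. This sketch (my own) assembles $\partial F / \partial(p, \nu)$ and $\partial F / \partial \theta$, solves the IFT linear system numerically, and compares against $\operatorname{diag}(p) - pp^\top$:

```python
import numpy as np

theta = np.array([0.5, -1.2, 3.0])
p = np.exp(theta - theta.max())
p /= p.sum()  # softmax(theta)
n = len(p)

# dF/d(p, nu) and dF/dtheta, evaluated at the solution
dF_dz = np.zeros((n + 1, n + 1))
dF_dz[:n, :n] = -np.diag(1.0 / p)
dF_dz[:n, n] = -1.0
dF_dz[n, :n] = 1.0
dF_dtheta = np.vstack([np.eye(n), np.zeros((1, n))])

# IFT: d(p, nu)/dtheta = -(dF/dz)^{-1} dF/dtheta; the first n rows are dp/dtheta
dz_dtheta = -np.linalg.solve(dF_dz, dF_dtheta)
dp_dtheta = dz_dtheta[:n, :]

assert np.allclose(dp_dtheta, np.diag(p) - np.outer(p, p))  # analytic Jacobian
assert np.allclose(dz_dtheta[n], p)  # and dnu/dtheta = p^T, as derived above
```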
For this toy example, the special structure of the KKT system lets us solve the linear system in closed form. However, solving the IFT system in general requires computing the inverse Hessian of the Lagrangian, which takes $O(n^3)$ time for a dense solve. In practice, we typically avoid materializing the full Jacobian and instead compute JVPs or VJPs through the implicit function, where the cost can be further alleviated with approximate inverse-Hessian-vector-product methods.
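To illustrate, here is a hypothetical VJP sketch (my own) that pulls a cotangent vector back through softmax with a single solve against the transposed KKT Jacobian. It still uses a dense solve here, but the same identity is what iterative methods approximate with Hessian-vector products, and it never materializes the $n \times n$ Jacobian $\partial p / \partial \theta$:

```python
import numpy as np

def softmax_vjp(v, p):
    """VJP through softmax via one solve with the transposed KKT Jacobian:
    (dp/dtheta)^T v = -(dF/dtheta)^T (dF/dz)^{-T} [v; 0]."""
    n = len(p)
    dF_dz = np.zeros((n + 1, n + 1))
    dF_dz[:n, :n] = -np.diag(1.0 / p)  # Hessian block of the Lagrangian
    dF_dz[:n, n] = -1.0                # d(stationarity)/d(nu)
    dF_dz[n, :n] = 1.0                 # d(feasibility)/d(p)
    w = np.linalg.solve(dF_dz.T, np.concatenate([v, [0.0]]))
    return -w[:n]  # dF/dtheta = [I; 0], so its transpose keeps the first n entries

theta = np.array([0.5, -1.2, 3.0])
p = np.exp(theta - theta.max())
p /= p.sum()
v = np.array([1.0, 0.0, -2.0])
# agrees with multiplying by the materialized Jacobian (symmetric for softmax)
assert np.allclose(softmax_vjp(v, p), (np.diag(p) - np.outer(p, p)) @ v)
```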