SFT is bad RL
There have been a few papers recently observing that learning from incorrect examples helps, even in verifiable domains. Here verifiable just means we have access to some ground-truth reward function, usually 0/1 or something in between. Those papers found that training on a larger number of incorrect examples can give better performance than training on just the positive examples. This is puzzling. Why would we ever want to train on incorrect examples? My gut says that correct/incorrect is the wrong thing to focus on. Instead, we should be asking: what’s the advantage of any given datapoint?
In this blog post, we show that you can do better than naively training on incorrect examples: supervised fine-tuning (SFT) on incorrect examples is an instance of reinforcement learning (RL), and simply following RL basics improves on it.
Let’s go straight to the math (it’s simple).
The goal of supervised learning is to train a student policy $\pi_\theta$ to clone the behaviour of a teacher $\pi_T$. The teacher gives sample trajectories $\tau \sim \pi_T$ which serve as examples to teach the student. SFT searches over student policies to minimize the KL-divergence from the teacher to the student, which is equivalent to maximizing the log-likelihood of the teacher’s samples:

$$\pi_{\text{SFT}} = \arg\min_\theta \, \mathrm{KL}\!\left(\pi_T \,\|\, \pi_\theta\right) \tag{1}$$

$$= \arg\min_\theta \, \mathbb{E}_{\tau \sim \pi_T}\!\left[\log \pi_T(\tau) - \log \pi_\theta(\tau)\right] \tag{2}$$

$$= \arg\max_\theta \, \mathbb{E}_{\tau \sim \pi_T}\!\left[\log \pi_\theta(\tau)\right] \tag{3}$$
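To make equation (3) concrete, here is a minimal sketch of the usual SFT loss, assuming a token-level language-modelling setup; the tensor names are hypothetical, not from any particular codebase:

```python
import torch
import torch.nn.functional as F

def sft_loss(student_logits, teacher_tokens, pad_mask):
    """Monte Carlo estimate of the negative of equation (3): mean negative
    log-likelihood of teacher-sampled tokens under the student policy.

    student_logits: [batch, seq, vocab] student predictions
    teacher_tokens: [batch, seq] token ids sampled from the teacher
    pad_mask:       [batch, seq] float mask, 1 for real tokens, 0 for padding
    """
    logp = F.log_softmax(student_logits, dim=-1)
    token_logp = logp.gather(-1, teacher_tokens.unsqueeze(-1)).squeeze(-1)  # log pi_theta(token | prefix)
    return -(token_logp * pad_mask).sum() / pad_mask.sum()
```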
Tangent: KL vs Cross-Entropy
The difference between equations (2) and (3) is interesting. Traditionally in ML we only ever saw (3), because we were learning from human teachers and we usually do not have access to calibrated probabilities for human demonstrations. Nowadays we actually have a lot of teacher models in the form of open-weight LLMs, which lets us optimize (2) directly. Using the teacher’s probabilities might help with training sample complexity: imagine what the optimal student policy looks like under each equation when you only have a single example from the teacher.
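As an illustration (my own sketch, with the same hypothetical tensor names as above), one way to use the teacher’s probabilities is to match the teacher’s full next-token distribution at every teacher-sampled prefix, rather than only the sampled token as in (3). With a single teacher sample, the optimum of (3) puts all of the student’s mass on that one continuation, while this kind of estimate of (2) reproduces the teacher’s uncertainty at every step:

```python
import torch
import torch.nn.functional as F

def forward_kl_loss(student_logits, teacher_logits, pad_mask):
    """Token-level forward KL(pi_T || pi_theta) on teacher-sampled prefixes,
    one way to estimate the KL objective in equation (2) when the teacher's
    probabilities are available (e.g. an open-weight LLM)."""
    teacher_logp = F.log_softmax(teacher_logits, dim=-1)
    student_logp = F.log_softmax(student_logits, dim=-1)
    kl = (teacher_logp.exp() * (teacher_logp - student_logp)).sum(-1)  # KL at each position
    return (kl * pad_mask).sum() / pad_mask.sum()
```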
Equation (3) looks similar to off-policy policy gradient, where you optimize the expected reward using samples from an importance distribution (here, the teacher):

$$J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\!\left[R(\tau)\right] \tag{4}$$

$$= \sum_\tau \pi_\theta(\tau)\, R(\tau) \tag{5}$$

$$= \sum_\tau \pi_T(\tau)\, \frac{\pi_\theta(\tau)}{\pi_T(\tau)}\, R(\tau) \tag{6}$$

$$= \mathbb{E}_{\tau \sim \pi_T}\!\left[\frac{\pi_\theta(\tau)}{\pi_T(\tau)}\, R(\tau)\right] \tag{7}$$
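As a sketch (sequence-level, with hypothetical variable names, not a production estimator), estimating (7) from a batch of teacher samples looks like this:

```python
import torch

def off_policy_pg_loss(student_logp, teacher_logp, rewards):
    """Monte Carlo estimate of the negative of equation (7), computed from
    teacher-sampled trajectories.

    student_logp: [batch] sum of log pi_theta over each trajectory's tokens
    teacher_logp: [batch] sum of log pi_T over the same tokens (no gradient)
    rewards:      [batch] verifier reward R(tau) for each trajectory
    """
    importance_weight = torch.exp(student_logp - teacher_logp.detach())  # pi_theta(tau) / pi_T(tau)
    return -(importance_weight * rewards).mean()
```

Note that the sequence-level ratio can easily explode or vanish, which is why practical implementations usually clip it (PPO-style) or work with per-token ratios; the sketch only mirrors the equation.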
Let’s examine each of the three terms in equation (7), contrast them with equation (3), and then put the pieces back together after the list.
- Importance weight denominator $\pi_T(\tau)$: Since the trajectories are sampled from $\pi_T$, it’s reasonable to assume they have high probability under $\pi_T$, so this term can pretty safely be ignored.
- Importance weight numerator $\pi_\theta(\tau)$: This is $\pi_\theta(\tau)$ rather than the $\log \pi_\theta(\tau)$ you see in equation (3). Since they are monotonic transformations of each other, this may not affect learning too much.
- Reward $R(\tau)$: In the SFT setting, the teacher examples are implicitly assumed to have reward $R(\tau) = 1$. This doesn’t make sense if you have the true reward function and the teacher examples do not actually have $R(\tau) = 1$.
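Putting the three observations together: drop the denominator (treat $\pi_T(\tau)$ as constant), swap $\pi_\theta(\tau)$ for $\log \pi_\theta(\tau)$, and assume $R(\tau) = 1$ for every teacher sample, and equation (7) collapses into equation (3):

$$\mathbb{E}_{\tau \sim \pi_T}\!\left[\frac{\pi_\theta(\tau)}{\pi_T(\tau)}\, R(\tau)\right] \;\longrightarrow\; \mathbb{E}_{\tau \sim \pi_T}\!\left[\log \pi_\theta(\tau)\right]$$

In other words, SFT is what you get from off-policy policy gradient after making all three simplifications, whether or not they are justified.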
This gives an obvious way to improve over naive SFT on incorrect examples: use the reward $R(\tau)$ when you have it. In practice, you can add the typical bells and whistles to your gradient estimator. The result should look like a policy gradient method that uses advantages, with the sub-optimal teacher demonstrations seeding the replay buffer alongside some on-policy samples from the student $\pi_\theta$. This lets the model mimic the bad samples while the student is worse than them, but ignore those samples once the student is better.
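A minimal sketch of what that could look like, assuming a mixed batch drawn from a buffer of teacher demonstrations and fresh student samples, and a plain batch-mean baseline (both choices are mine, not a prescribed recipe):

```python
import torch

def advantage_weighted_loss(seq_logp, rewards):
    """Advantage-weighted log-likelihood over a mixed batch of sub-optimal
    teacher demonstrations and on-policy student samples (REINFORCE with a
    batch-mean baseline).

    seq_logp: [batch] sum of log pi_theta over each trajectory's tokens
    rewards:  [batch] verifier reward R(tau) for each trajectory
    """
    advantages = rewards - rewards.mean()   # simple batch-mean baseline
    return -(advantages * seq_logp).mean()
```

With this weighting, a bad teacher demonstration gets a positive coefficient only while the student’s own samples score worse than it, and a negative one once the student overtakes it, which is exactly the mimic-when-worse, ignore-when-better behaviour described above.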