How to Correctly Reproduce InstructGPT / RLHF? (English)
Generated: 2026-06-21 17:15:30
---
I've been hitting every possible pitfall with RLHF since last year, and today I'm spilling it all for you.
You know what? The first time I ran PPO, I almost threw my computer out the window.
It was the second half of last year, right when GPT was taking the whole internet by storm. I excitedly opened the most popular "one-click InstructGPT reproduction" project on GitHub and ran the official example script. Three hours later, the loss had blown up, the reward was dead flat at zero, and my GPU fan was still spinning like crazy.
Guess how I felt?
I wanted to cry.
So what did I do next? I went through the frameworks one by one: PaLM-rlhf-pytorch, ColossalAI-Chat, DeepSpeed-Chat, TRLX, Hugging Face TRL. I tried them all, and they all had one thing in common: they either crashed halfway through for no apparent reason, or the reward curve jumped around like it had eaten a bag of Pop Rocks.
Most of those "one-click reproduction" repos on GitHub? Nine out of ten can't reproduce the results from the paper at all.
Alright, I'm not going to sugarcoat it for you today. I'm just laying out everything I've tried, everything I've fallen into, and everything I've finally verified.
---
Which InstructGPT are you reproducing?
The projects on GitHub fall into two camps.
First camp: the "distillation school."
These people are clever: they call OpenAI's API (ChatGPT, or text-davinci-003 in Alpaca's case) to generate instruction data, then fine-tune an open-source base model on it. That's how Alpaca came from LLaMA and how BELLE came from BLOOMZ. It's fast, and the results aren't bad at all. Why? Because the "ground truth" answers were literally copied from OpenAI's models.
But let's be clear: that's distillation, not real RLHF.
Second camp: the "purist school."
These folks insist on doing the full three-stage pipeline: SFT → Reward Model → PPO. These are the frameworks with the most problems, because every single step has engineering traps that'll drive you crazy. This camp is what I'm talking about here.
Look, the InstructGPT paper describes those three stages clearly:
Stage one, SFT: Supervised fine-tuning on human-annotated instruction data.
Stage two, Reward Model: Train a scoring model with pairwise ranking loss.
Stage three, PPO: Proximal policy optimization using reward signals plus a KL penalty.
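For stage three, the paper's objective rewards the policy with r_θ(x, y) − β·log(π_RL(y|x) / π_SFT(y|x)), plus an optional pretraining term in the PPO-ptx variant. A common way implementations apply the KL part is per token, with the reward-model score added on the last token. The sketch below assumes that per-token formulation; the function name and the beta value are illustrative, not taken from any particular framework, and real code would work on batched tensors rather than Python lists.

```python
from typing import List


def kl_shaped_rewards(
    policy_logprobs: List[float],
    ref_logprobs: List[float],
    rm_score: float,
    beta: float = 0.1,  # illustrative KL coefficient, tune for your setup
) -> List[float]:
    """Per-token rewards for one sampled response.

    Every token gets -beta * (log pi_RL - log pi_SFT) for the token that was
    actually sampled; the reward model's scalar score is added on the last token.
    """
    if not policy_logprobs or len(policy_logprobs) != len(ref_logprobs):
        raise ValueError("need two non-empty log-prob lists of equal length")
    # KL penalty: pulls the policy back toward the frozen SFT (reference) model
    rewards = [-beta * (lp - lr) for lp, lr in zip(policy_logprobs, ref_logprobs)]
    # The reward model only scores the complete response, so its score lands at the end
    rewards[-1] += rm_score
    return rewards


# Example: three generated tokens, reward-model score 1.0
print(kl_shaped_rewards([-1.0, -0.5, -2.0], [-1.2, -0.5, -1.0], rm_score=1.0))
```

If beta is too small, the policy is free to drift into reward-hacking text; if it's too large, PPO barely moves away from the SFT model.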
Every one of these stages can make you crash and burn so hard you won't recognize yourself.
Let me go through them one by one.
---
The SFT stage: You think it's a warm-up, but it's the first trap
SFT is the starting point of RLHF. Lots of people think: just grab a dialogue dataset, run a few epochs of LoRA, how hard can it be?
No.
It's a huge trap.
I fell into this one myself: I took the 52k Alpaca dataset, fine-tuned LLaMA-7B for three epochs, and got the loss down to 0.3. I was pretty happy, thinking, "Okay, this is solid." Then I generated some outputs, and the model kept repeating itself like a broken record.
Why? Two reasons: overfitting, and bad data quality.
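If you want a number for the "broken record" symptom rather than eyeballing samples, one simple option is the fraction of distinct n-grams in a generation: heavy repetition drives it toward zero. This is a minimal sketch assuming whitespace tokenization (so it won't work as-is on Chinese text, which needs a real tokenizer); the function name is hypothetical.

```python
def distinct_ngram_ratio(text: str, n: int = 3) -> float:
    """Share of n-grams in `text` that are unique: 1.0 = no repeats, near 0 = looping."""
    tokens = text.split()  # crude whitespace tokenization (assumption)
    ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
    if not ngrams:  # text shorter than n tokens: nothing to repeat
        return 1.0
    return len(set(ngrams)) / len(ngrams)


print(distinct_ngram_ratio("the model gives a clear and complete answer here"))  # 1.0
print(distinct_ngram_ratio("I am sorry I am sorry I am sorry I am sorry"))       # 0.3
```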
Here are a few key points you absolutely need to pay attention to.
First, don't set the learning rate too high.
When I used LLaMA-3-8B, I set the learning rate to 1e-5 for full fine-tuning and about 2e-4 for LoRA. The common LLaMA 3 config you see online (learning_rate: 0.0001, lora_target: q_proj,v_proj) can indeed get the loss down to 0.26 after three epochs. But before you reuse that same config on your own data, check whether your validation loss suddenly starts climbing again. If it does, you're overfitting.
Second, set a reasonable cutoff_len.
I usually use 1024. But in multi-turn dialogue scenarios, some responses can be really long, and if you cut them too short, the model never learns their complete structure. I once did something stupid: I truncated a 2000-token response to 512 tokens. The result? The model learned to leave sentences unfinished; its outputs would just stop mid-sentence, out of nowhere.
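Before picking cutoff_len, it helps to measure how much of your data it would actually cut. A minimal sketch, assuming you already have per-sample token counts (for example from `len(tokenizer(text)["input_ids"])` with a Hugging Face tokenizer); the helper name, the report fields, and the sample lengths are hypothetical.

```python
from typing import Dict, List


def truncation_report(token_lengths: List[int], cutoff_len: int) -> Dict[str, float]:
    """Summarize what a given cutoff_len would cut from a dataset."""
    if not token_lengths:
        raise ValueError("token_lengths must not be empty")
    over = [n for n in token_lengths if n > cutoff_len]
    tokens_lost = sum(n - cutoff_len for n in over)
    return {
        # share of samples that get cut at all
        "truncated_fraction": len(over) / len(token_lengths),
        # share of all tokens thrown away by truncation
        "tokens_lost_fraction": tokens_lost / max(sum(token_lengths), 1),
        "max_len": float(max(token_lengths)),
    }


lengths = [300, 800, 2000, 1500]  # made-up counts for illustration
for cutoff in (512, 1024, 2048):
    print(cutoff, truncation_report(lengths, cutoff))
```

If a noticeable share of samples gets cut, either raise cutoff_len or drop or split those samples instead of silently training on truncated answers.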
Third, the template must match the base model exactly.
This sounds simple, but you have no idea how many people it's tripped up. LLaMA's tokenizer has special tokens, and if your format is wrong, you're essentially feeding garbage to the model. For instance, in the llamafactory-cli config, if you write template: llama3 and get a single letter wrong, the model starts babbling nonsense.
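To see what a correct format looks like, here is a hand-written sketch of the Llama 3 Instruct chat layout: header tokens around each role, `<|eot_id|>` closing each turn. Treat it as an assumption to verify, not a reference. The authoritative version is the chat template shipped with the tokenizer, which in transformers you can render with `tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)` and compare against your training samples. The function name is hypothetical.

```python
from typing import Dict, List


def render_llama3_chat(messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
    """Hand-written sketch of the Llama 3 Instruct chat format (verify against the tokenizer)."""
    parts = ["<|begin_of_text|>"]
    for m in messages:
        # Each turn: role header, blank line, trimmed content, end-of-turn token
        parts.append(
            f"<|start_header_id|>{m['role']}<|end_header_id|>\n\n"
            f"{m['content'].strip()}<|eot_id|>"
        )
    if add_generation_prompt:
        # Open an assistant header so generation continues as the assistant
        parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
    return "".join(parts)


print(repr(render_llama3_chat([{"role": "user", "content": "Hi"}])))
```

A quick diff between this kind of rendering and one decoded training sample catches most template mismatches before you burn GPU hours on them.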
Here's my current SFT habit: I use 10% of the data as a validation set and keep an eye on the loss. If the validation loss stops dropping for two consecutive epochs, I stop. No way am I forcing it through all three epochs.
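That habit boils down to two small pieces: hold out 10% of the data, and stop once the validation loss hasn't improved for two epochs. A minimal sketch; the helper names and the seed are hypothetical, and in practice you'd usually reach for your trainer's built-in option instead (transformers, for example, has `EarlyStoppingCallback` with an `early_stopping_patience` argument).

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def train_val_split(data: Sequence[T], val_frac: float = 0.1, seed: int = 42) -> Tuple[List[T], List[T]]:
    """Shuffle a copy of the data and hold out `val_frac` of it for validation."""
    items = list(data)
    random.Random(seed).shuffle(items)  # seeded, so the split is reproducible
    n_val = max(1, int(len(items) * val_frac))
    return items[n_val:], items[:n_val]


def should_stop(val_losses: List[float], patience: int = 2) -> bool:
    """True once the best validation loss is `patience` or more epochs old."""
    if not val_losses:
        return False
    best_epoch = val_losses.index(min(val_losses))
    return (len(val_losses) - 1) - best_epoch >= patience


train, val = train_val_split(list(range(1000)))
print(len(train), len(val))  # 900 100
print(should_stop([1.20, 0.95, 0.90]))        # False: still improving
print(should_stop([1.20, 0.95, 0.97, 0.99]))  # True: no gain for two epochs
```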
---
Training the Reward Model: the sweetest trap
The core idea of the Reward Model is to take a pair of responses to the same prompt, one good and one bad, and teach the model to score the good one high and the bad one low.
The loss from the paper looks like this:
loss = -log(sigmoid(reward_chosen - reward_rejected))
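Written out in code, the same loss is -log σ(r_chosen − r_rejected) = softplus(r_rejected − r_chosen), and computing it through softplus avoids overflow when the reward gap is large. A minimal scalar sketch; in PyTorch the batched equivalent is usually `-F.logsigmoid(chosen - rejected).mean()`. In the paper, each prompt's K ranked responses are expanded into all C(K, 2) pairs and the loss is averaged over them, which is what averaging over a list of pairs does here. The function names are hypothetical.

```python
import math
from typing import List, Tuple


def pairwise_rm_loss(reward_chosen: float, reward_rejected: float) -> float:
    """-log(sigmoid(reward_chosen - reward_rejected)), computed as softplus(-diff)."""
    z = reward_rejected - reward_chosen  # = -(reward_chosen - reward_rejected)
    # softplus(z) = log(1 + exp(z)); branch on the sign so exp() never overflows
    if z > 0:
        return z + math.log1p(math.exp(-z))
    return math.log1p(math.exp(z))


def batch_rm_loss(pairs: List[Tuple[float, float]]) -> float:
    """Mean pairwise loss over (chosen, rejected) reward pairs."""
    if not pairs:
        raise ValueError("need at least one (chosen, rejected) pair")
    return sum(pairwise_rm_loss(c, r) for c, r in pairs) / len(pairs)


print(pairwise_rm_loss(0.0, 0.0))   # log(2): the model can't tell the pair apart
print(pairwise_rm_loss(3.0, -1.0))  # small: chosen already scored higher
print(pairwise_rm_loss(-1.0, 3.0))  # large: ranking is backwards
```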
The loss function itself is fine. But the key question is: what do you use as the reward output?
I've seen three approaches:
First, add a linear layer at the end of the language model to output a single scalar.
Second, take the hidden state of the last token, i.e., the end-of-sequence (EOS) token, and pass it through a projection layer.
Third,
Cael Lee
Full-stack developer with 8+ years of experience. Currently building AI-powered developer tools. I've tested 20+ AI API providers and coding assistants.