Comparing Fine-Tuning Approaches for Chain-of-Thought (CoT)
This project explores the effectiveness of three strategies for eliciting Chain-of-Thought (CoT) reasoning in large language models:
- Group Relative Policy Optimization (GRPO)
- Proximal Policy Optimization (PPO)
- Few-Shot Prompting
We apply these techniques to GPT-Neo using the GSM8K dataset, a benchmark for arithmetic and logical reasoning. A reward model is trained using DistilRoBERTa-base, providing structured feedback for reinforcement learning.
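For context, here is a minimal sketch of how the GSM8K data might be pulled in; it assumes the Hugging Face `datasets` copy of GSM8K (dataset name `gsm8k`, configuration `main`), which is an assumption about this setup rather than something specified above.

from datasets import load_dataset

# Assumed source: the "gsm8k" dataset on the Hugging Face Hub, "main" config (train/test splits).
gsm8k = load_dataset("gsm8k", "main")
sample = gsm8k["train"][0]

print(sample["question"])   # grade-school word problem
print(sample["answer"])     # step-by-step solution ending in "#### <final answer>"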
Implementing GRPO from Scratch
GRPO (Group Relative Policy Optimization) stabilizes policy updates by ensuring a relative improvement over a reference policy. Below is the GRPO loss function used in this implementation:
import torch

def grpo_loss(policy, old_policy, advantage, epsilon=0.2, beta=0.01):
    # policy / old_policy: probabilities assigned to the sampled tokens.
    ratio = policy / old_policy
    clipped_ratio = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
    policy_loss = torch.min(ratio * advantage, clipped_ratio * advantage)
    # KL(policy || reference); here old_policy stands in for the reference policy.
    kl_div = torch.nn.functional.kl_div(old_policy.log(), policy, reduction="batchmean")
    # Maximize the clipped surrogate while penalizing drift from the reference.
    return -policy_loss.mean() + beta * kl_div
The original GRPO algorithm does not use GAE or any value-network-based advantage estimator. Instead, it omits the critic entirely and defines the advantage of each sample in a group as the z-score of its reward within that group:
def compute_group_relative_advantage_GRPO(rewards: torch.Tensor,
                                          eps: float = 1e-8) -> torch.Tensor:
    # rewards: (num_prompts, group_size) -- one row of sampled completions per prompt.
    mu = rewards.mean(dim=1, keepdim=True)
    sigma = rewards.std(dim=1, unbiased=False, keepdim=True)
    # z-score each completion's reward against its own group.
    return (rewards - mu) / (sigma + eps)
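To make the shapes concrete, here is a hypothetical toy call that turns a group of rewards into group-relative advantages and feeds them to grpo_loss; the probability tensors are illustrative placeholders, not outputs of a real model.

# Hypothetical toy usage: one prompt with a group of four sampled completions.
rewards = torch.tensor([[0.1, 0.9, 0.4, 0.6]])
advantages = compute_group_relative_advantage_GRPO(rewards)

# Placeholder per-completion probabilities for the current and reference policies.
policy_probs = torch.tensor([[0.30, 0.25, 0.20, 0.25]])
old_probs = torch.tensor([[0.28, 0.27, 0.22, 0.23]])

loss = grpo_loss(policy_probs, old_probs, advantages)
print(loss)   # scalar loss; backpropagate through it once the probabilities come from the model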
Implementing PPO from Scratch
PPO constrains policy updates using a clipped objective to prevent large policy shifts.
def ppo_loss(policy, old_policy, advantage, epsilon=0.2):
    # Probability ratio between the current and old policies for the sampled tokens.
    ratio = policy / old_policy
    clipped_ratio = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
    # Clipped surrogate objective: take the pessimistic (minimum) of the two terms.
    return -torch.min(ratio * advantage, clipped_ratio * advantage).mean()
Here is a snippet that (1) computes GAE-style advantages and then (2) centers them per group so that each group's mean advantage is zero:
def compute_gae(rewards: torch.Tensor,
                values: torch.Tensor,
                dones: torch.Tensor,
                gamma: float = 0.99,
                lam: float = 0.95) -> torch.Tensor:
    # rewards, dones: shape (T,); values: shape (T + 1,), including the bootstrap value.
    T = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    lastgaelam = 0.0
    # Standard GAE recursion, walked backwards through the trajectory.
    for t in reversed(range(T)):
        nonterminal = 1.0 - dones[t]
        delta = rewards[t] + gamma * values[t + 1] * nonterminal - values[t]
        lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
        advantages[t] = lastgaelam
    return advantages
def compute_group_relative_advantage_PPO(rewards: torch.Tensor,
                                         values: torch.Tensor,
                                         dones: torch.Tensor,
                                         group_ids: torch.Tensor,
                                         gamma: float = 0.99,
                                         lam: float = 0.95) -> torch.Tensor:
    # 1) GAE advantages over the whole batch.
    advantages = compute_gae(rewards, values, dones, gamma, lam)
    # 2) Center per group so each group's mean advantage is zero.
    group_ids = group_ids.to(advantages.device)
    rel_adv = torch.zeros_like(advantages)
    for g in torch.unique(group_ids):
        mask = group_ids == g
        rel_adv[mask] = advantages[mask] - advantages[mask].mean()
    return rel_adv
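As a sanity check, the toy example below wires the two helpers into ppo_loss; the rewards, values, done flags, and probabilities are made-up placeholders for two three-step trajectories belonging to different prompt groups.

# Toy rollout: six flattened timesteps from two trajectories (groups 0 and 1).
rewards   = torch.tensor([0.0, 0.0, 1.0, 0.0, 0.0, 0.5])
values    = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.2, 0.4, 0.0])   # T + 1 entries (last one bootstraps)
dones     = torch.tensor([0.0, 0.0, 1.0, 0.0, 0.0, 1.0])
group_ids = torch.tensor([0, 0, 0, 1, 1, 1])

advantages = compute_group_relative_advantage_PPO(rewards, values, dones, group_ids)

# Placeholder per-token probabilities of the sampled actions.
policy_probs = torch.tensor([0.30, 0.22, 0.18, 0.25, 0.31, 0.27])
old_probs    = torch.tensor([0.28, 0.24, 0.20, 0.24, 0.30, 0.26])

loss = ppo_loss(policy_probs, old_probs, advantages)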
Implementing CoT Few-Shot Prompting
Few-shot prompting leaves the pre-trained model's weights untouched; instead, worked examples are placed directly in the prompt.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")

def generate_cot_response(prompt, num_samples=3):
    # Worked examples demonstrating the step-by-step (CoT) answer format.
    few_shot_examples = """
Q: What is 12 + 15?
A: Let’s break it down. First, 12 + 10 = 22. Then, 22 + 5 = 27. Answer: 27.
Q: What is 45 - 18?
A: Let’s break it down. First, 45 - 10 = 35. Then, 35 - 8 = 27. Answer: 27.
"""
    full_prompt = few_shot_examples + "\nQ: " + prompt + "\nA: "
    inputs = tokenizer(full_prompt, return_tensors="pt")
    # Sample num_samples continuations; cap new tokens rather than total length.
    outputs = model.generate(**inputs,
                             max_new_tokens=100,
                             do_sample=True,
                             num_return_sequences=num_samples,
                             pad_token_id=tokenizer.eos_token_id)
    return [tokenizer.decode(o, skip_special_tokens=True) for o in outputs]
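A hypothetical call might look like the following; it reuses the gsm8k split loaded earlier (an assumption about how the pieces are wired together) and prints each sampled chain of thought after stripping the few-shot prefix.

# Hypothetical usage: sample three chains of thought for one GSM8K test question.
question = gsm8k["test"][0]["question"]
completions = generate_cot_response(question, num_samples=3)

for i, text in enumerate(completions):
    continuation = text.split(question, 1)[-1]   # rough: drop the few-shot prefix and the question
    print(f"--- sample {i} ---\n{continuation}")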
Training the Reward Model
To provide structured feedback for reinforcement learning, we train a DistilRoBERTa-base reward model on a dataset of comparisons.
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
model = AutoModelForSequenceClassification.from_pretrained("distilroberta-base", num_labels=1)

def train_reward_model(train_dataloader, optimizer, criterion, epochs=3):
    model.train()
    for epoch in range(epochs):
        for batch in train_dataloader:
            # Tokenize the candidate solutions in the batch.
            inputs = tokenizer(batch["text"], padding=True, truncation=True, return_tensors="pt")
            labels = batch["labels"].float()
            optimizer.zero_grad()
            # Single-logit head: squeeze only the last dim so a batch size of 1 still works.
            outputs = model(**inputs).logits.squeeze(-1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
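To show how the pieces fit together, here is a hypothetical setup with a tiny in-memory dataset exposing the "text"/"labels" fields the loop expects, a mean-squared-error criterion, and AdamW; the example data and hyperparameters are illustrative only.

from torch.utils.data import DataLoader

# Hypothetical toy dataset: correct and incorrect reasoning traces with scalar quality labels.
toy_data = [
    {"text": "Q: What is 2 + 2? A: 2 + 2 = 4. Answer: 4.", "labels": 1.0},
    {"text": "Q: What is 2 + 2? A: 2 + 2 = 5. Answer: 5.", "labels": 0.0},
]

def collate(batch):
    return {"text": [b["text"] for b in batch],
            "labels": torch.tensor([b["labels"] for b in batch])}

train_dataloader = DataLoader(toy_data, batch_size=2, shuffle=True, collate_fn=collate)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
criterion = torch.nn.MSELoss()   # regress the scalar reward toward the label

train_reward_model(train_dataloader, optimizer, criterion, epochs=1)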
Performance Analysis
To evaluate these approaches, we measure:
- Accuracy on GSM8K problems (see the answer-extraction sketch below)
- Training stability (variance in policy updates)
- Efficiency of fine-tuning (training steps to convergence)
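GSM8K reference answers end with "#### <number>", and the few-shot format above ends each model answer with "Answer: <number>", so a rough sketch of exact-match accuracy could look like this; the regular expression and helper names are assumptions, not part of the project code.

import re

def extract_final_answer(text):
    # Reference answers end with "#### <number>"; model outputs end with "Answer: <number>".
    matches = re.findall(r"(?:####|Answer:)\s*(-?[\d,\.]+)", text)
    return matches[-1].replace(",", "").rstrip(".") if matches else None

def gsm8k_accuracy(predictions, references):
    # Exact match on the extracted final numeric answers.
    correct = sum(extract_final_answer(p) == extract_final_answer(r)
                  for p, r in zip(predictions, references))
    return correct / len(references)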