
Fine-Tuning CoT: GRPO vs. PPO vs. Few-Shot Prompting

From-scratch implementations of GRPO, PPO, and CoT few-shot prompting, compared as approaches to Chain-of-Thought reasoning with GPT-Neo on the GSM8K dataset.

Figure: the GRPO objective function.

Comparing Fine-Tuning Approaches for Chain-of-Thought (CoT)

This project explores the effectiveness of three strategies for Chain-of-Thought (CoT) reasoning in large language models, two reinforcement-learning fine-tuning methods and one prompting baseline:

  • Group Relative Policy Optimization (GRPO)
  • Proximal Policy Optimization (PPO)
  • Few-Shot Prompting

We apply these techniques to GPT-Neo using the GSM8K dataset, a benchmark for arithmetic and logical reasoning. A reward model is trained using DistilRoBERTa-base, providing structured feedback for reinforcement learning.
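
For reference, GSM8K can be loaded directly from the Hugging Face Hub. A minimal sketch (split and field names follow the public "gsm8k" dataset card):

from datasets import load_dataset

# GSM8K provides "train" and "test" splits of grade-school math word
# problems, each paired with a step-by-step solution.
gsm8k = load_dataset("gsm8k", "main")
example = gsm8k["train"][0]
print(example["question"])
print(example["answer"])  # the gold numeric answer follows the "#### " marker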


Implementing GRPO from Scratch

GRPO (Group Relative Policy Optimization) stabilizes policy updates with a PPO-style clipped ratio and a KL penalty toward a reference policy, while replacing the learned critic with group-relative advantages. Below is the GRPO loss function used in this implementation:

 
import torch

def grpo_loss(policy, old_policy, advantage, epsilon=0.2, beta=0.01):
    # `policy` / `old_policy`: probabilities assigned to the sampled tokens;
    # the pre-update policy also serves as the KL reference here.
    ratio = policy / old_policy
    clipped_ratio = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
    policy_loss = torch.min(ratio * advantage, clipped_ratio * advantage)
    # kl_div expects log-probs as input and probs as target, so this
    # computes KL(policy || old_policy).
    kl_div = torch.nn.functional.kl_div(old_policy.log(), policy, reduction="batchmean")
    # The KL term is a penalty, so it is added to the loss, not subtracted.
    return -policy_loss.mean() + beta * kl_div
 

The original GRPO algorithm does not use GAE or any value-network-based advantage estimator. It explicitly omits the critic and defines the advantage of each sample as the z-score of its reward within its group:

 
def compute_group_relative_advantage_GRPO(rewards: torch.Tensor,
                                          eps: float = 1e-8) -> torch.Tensor:
    # `rewards` has shape (num_groups, group_size): one row per prompt,
    # one column per sampled completion.
    mu = rewards.mean(dim=1, keepdim=True)
    sigma = rewards.std(dim=1, unbiased=False, keepdim=True)
    return (rewards - mu) / (sigma + eps)
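
For context, a small usage sketch: sample a group of completions per prompt, score each with the reward model, and normalize. The reward values below are made up, and broadcasting each scalar advantage over its completion's tokens is an assumption about how the loss is wired up.

import torch

# Rewards for G = 4 sampled completions per prompt (illustrative values).
rewards = torch.tensor([[0.1, 0.7, 0.3, 0.9],   # completions for prompt 0
                        [0.2, 0.2, 0.8, 0.4]])  # completions for prompt 1
advantages = compute_group_relative_advantage_GRPO(rewards)
# Each completion's scalar advantage would then be broadcast over its
# tokens and passed to grpo_loss with the per-token probabilities.
print(advantages)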
 
 

Implementing PPO from Scratch

PPO constrains policy updates using a clipped objective to prevent large policy shifts.

 
def ppo_loss(policy, old_policy, advantage, epsilon=0.2):
    # Clipped surrogate objective: the probability ratio is clamped to
    # [1 - epsilon, 1 + epsilon] so one update cannot move the policy too far.
    ratio = policy / old_policy
    clipped_ratio = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
    return -torch.min(ratio * advantage, clipped_ratio * advantage).mean()
 

Here is a snippet that (1) computes GAE-style advantages and then (2) centers them per group so that each group's mean advantage is zero, making PPO's baseline comparable to GRPO's group-relative scheme:

 
 
def compute_gae(rewards: torch.Tensor,
                values: torch.Tensor,
                dones: torch.Tensor,
                gamma: float = 0.99,
                lam: float = 0.95) -> torch.Tensor:
    # `rewards` and `dones` have shape (T,); `values` has shape (T + 1,)
    # so that values[t + 1] supplies the bootstrap value at the final step.
    T = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    lastgaelam = 0.0
    for t in reversed(range(T)):
        nonterminal = 1.0 - dones[t]
        delta = rewards[t] + gamma * values[t + 1] * nonterminal - values[t]
        lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
        advantages[t] = lastgaelam
    return advantages
 
def compute_group_relative_advantage_PPO(rewards: torch.Tensor,
                                         values: torch.Tensor,
                                         dones: torch.Tensor,
                                         group_ids: torch.Tensor,
                                         gamma: float = 0.99,
                                         lam: float = 0.95) -> torch.Tensor:
    # 1) Standard GAE advantages.
    advantages = compute_gae(rewards, values, dones, gamma, lam)

    # 2) Center per group: subtract each group's mean advantage so every
    #    group is zero-mean, mirroring GRPO's relative baseline.
    group_ids = group_ids.to(advantages.device)
    rel_adv = torch.zeros_like(advantages)
    for g in torch.unique(group_ids):
        mask = group_ids == g  # torch.unique guarantees a non-empty mask
        rel_adv[mask] = advantages[mask] - advantages[mask].mean()

    return rel_adv
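
PPO, unlike GRPO, needs a critic to produce the values consumed by compute_gae. A minimal sketch of one way to provide them; the ValueHead module and its wiring are assumptions of this sketch, not code from the project:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ValueHead(nn.Module):
    """Maps the policy's last hidden state at each step to a scalar value."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.linear(hidden_states).squeeze(-1)

def value_loss(values: torch.Tensor, returns: torch.Tensor) -> torch.Tensor:
    # PPO regresses predicted values toward empirical returns
    # (advantages + values) alongside the clipped policy loss.
    return F.mse_loss(values, returns)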
 
 

Implementing CoT Few-Shot Prompting

Few-shot prompting leaves the pre-trained model's weights untouched: worked examples are placed in the prompt so the model imitates the demonstrated reasoning pattern instead of being fine-tuned.

 
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")

def generate_cot_response(prompt):
    # Examples are written flush-left so no stray indentation leaks
    # into the prompt.
    few_shot_examples = (
        "Q: What is 12 + 15?\n"
        "A: Let's break it down. First, 12 + 10 = 22. Then, 22 + 5 = 27. Answer: 27.\n"
        "Q: What is 45 - 18?\n"
        "A: Let's break it down. First, 45 - 10 = 35. Then, 35 - 8 = 27. Answer: 27.\n"
    )
    full_prompt = few_shot_examples + "Q: " + prompt + "\nA: "
    inputs = tokenizer(full_prompt, return_tensors="pt")
    # max_new_tokens bounds only the generated text; max_length would count
    # the (long) few-shot prompt against the budget and truncate the answer.
    outputs = model.generate(**inputs, max_new_tokens=100)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
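
A quick usage check (the continuation varies from run to run):

response = generate_cot_response("What is 23 + 19?")
# The decoded string contains the few-shot prompt followed by the
# model's generated reasoning and final answer.
print(response)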

Training the Reward Model

To provide structured feedback for reinforcement learning, we train a DistilRoBERTa-base reward model on a dataset of comparisons.

from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
# num_labels=1 turns the classification head into a scalar scorer:
# the single logit is used directly as the reward.
model = AutoModelForSequenceClassification.from_pretrained("distilroberta-base", num_labels=1)

def train_reward_model(train_dataloader, optimizer, criterion, epochs=3):
    model.train()
    for epoch in range(epochs):
        for batch in train_dataloader:
            inputs = tokenizer(batch["text"], padding=True, truncation=True, return_tensors="pt")
            labels = batch["labels"].float()
            optimizer.zero_grad()
            # squeeze(-1) drops only the label dimension, so a batch of
            # size one keeps its batch axis.
            outputs = model(**inputs).logits.squeeze(-1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
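
The loop above trains pointwise against scalar labels. Since the data consists of comparisons, a pairwise Bradley-Terry-style objective is the other common choice; a minimal sketch (the chosen/rejected score tensors are assumed to come from two forward passes of the model above):

import torch
import torch.nn.functional as F

def pairwise_reward_loss(chosen_scores: torch.Tensor,
                         rejected_scores: torch.Tensor) -> torch.Tensor:
    # Maximize the log-probability that the preferred completion
    # outscores the rejected one: -log sigmoid(r_chosen - r_rejected).
    return -F.logsigmoid(chosen_scores - rejected_scores).mean()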

Performance Analysis

To evaluate these approaches, we measure:

  • Accuracy on GSM8K problems (exact match on the final numeric answer; see the sketch below)
  • Training stability (variance in policy updates)
  • Efficiency of fine-tuning (training steps to convergence)
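
GSM8K marks its gold answer with a "#### " delimiter, so accuracy can be scored by comparing final numbers. A minimal sketch (extracting the last number from a generation is an assumption about our "Answer:" output format):

import re

def extract_final_number(text: str):
    # Take the last number in the generation as the model's final answer.
    matches = re.findall(r"-?\d+(?:\.\d+)?", text.replace(",", ""))
    return matches[-1] if matches else None

def gsm8k_accuracy(predictions, gold_answers):
    # gold_answers are raw GSM8K "answer" strings; the number after
    # "####" is the reference answer.
    correct = 0
    for pred, gold in zip(predictions, gold_answers):
        gold_num = gold.split("####")[-1].strip().replace(",", "")
        if extract_final_number(pred) == gold_num:
            correct += 1
    return correct / len(predictions)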
