The Challenge
Detecting polyps in gastrointestinal endoscopy is a crucial but challenging task. Traditional detection methods often suffer from slow inference, high computational cost, and reduced accuracy when they have to run in real time.
With the growing demand for efficient, accurate AI-driven medical diagnostics, we needed a fast and lightweight model that detects polyps without sacrificing accuracy.
The Solution: Pruned YOLOv3
Building on YOLOv3, we designed an optimized deep learning model that speeds up real-time polyp detection by pruning redundant channels from the network.
Key Features
- **Model Pruning**: Reduces redundant parameters for faster inference.
- **Training Pipeline**: Implements dataset augmentation and transfer learning.
- **Real-Time Detection**: Optimized for endoscopy applications.
- **Performance Optimization**: Achieves higher efficiency with minimal loss in accuracy.
How It Works
I structured the notebook into three main phases. First, I prepared the Kvasir‑SEG dataset and implemented the full YOLOv3 architecture in PyTorch: I downloaded and preprocessed the images and annotations, ran K‑Means clustering to find suitable anchor boxes, and built data loaders with appropriate augmentations. After defining the backbone, detection heads, and loss function, along with mAP and F1 tracking, I trained the unpruned model to establish a performance baseline on both loss curves and detection metrics.
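As a quick illustration of the anchor step, here is a minimal sketch of clustering box shapes into YOLOv3 anchors. It assumes the Kvasir‑SEG annotations have already been converted to normalized (width, height) pairs, and it uses plain Euclidean K‑Means from scikit-learn rather than the IoU-based distance often used for YOLO anchors; the `compute_anchors` helper and the random stand-in data are illustrative only, not the notebook's exact code.

```python
import numpy as np
from sklearn.cluster import KMeans

def compute_anchors(boxes, n_anchors=9):
    """Cluster normalized (width, height) pairs into YOLOv3 anchor boxes.

    boxes: array of shape (N, 2) with box widths and heights scaled to [0, 1].
    Returns anchors sorted by area, smallest first, so they can be split
    three per detection scale as YOLOv3 expects.
    """
    kmeans = KMeans(n_clusters=n_anchors, n_init=10, random_state=0).fit(boxes)
    anchors = kmeans.cluster_centers_
    return anchors[np.argsort(anchors[:, 0] * anchors[:, 1])]

# Random box shapes stand in for the real Kvasir-SEG boxes in this sketch
boxes = np.random.rand(500, 2) * 0.5 + 0.05
print(compute_anchors(boxes))
```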
Next, I integrated two pruning strategies—Taylor‑expansion pruning and random pruning—for channel‑wise slimming of the convolutional layers. For Taylor pruning, I accumulated the absolute value of (weight × gradient) over a batch to estimate each channel’s sensitivity to the loss. Channels with the lowest accumulated scores were deemed least important and were masked out. As a sanity check, I also implemented random pruning by selecting the same number of channels at random. Both pruned models were then fine‑tuned for several epochs to recover accuracy.
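The random baseline needs no gradient information at all. A minimal sketch of that idea, written so it can plug straight into the `prune_model` routine shown in the next snippet, is to assign each output channel a random score (the `compute_random_scores` name is mine, not the notebook's):

```python
import torch
import torch.nn as nn

def compute_random_scores(model, device):
    """Assign a random importance score to every Conv2d output channel,
    so the same masking routine prunes a random subset of channels."""
    scores = {}
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            scores[name] = torch.rand(module.out_channels, device=device)
    return scores

# e.g. random_scores = compute_random_scores(model, device)
#      randomly_pruned = prune_model(model, random_scores, prune_ratio=0.3)
```

Because the interface matches the Taylor scores, both strategies share the same pruning and fine-tuning code, which keeps the comparison fair.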
Below is a distilled snippet showing how I computed Taylor importance scores and applied a pruning mask to each convolutional layer. After creating the masks, I re‑initialized the optimizer and ran a short fine‑tuning loop to adapt the remaining weights.
```python
import torch
import torch.nn as nn


def compute_taylor_scores(model, data_loader, device):
    """Estimate per-channel importance as the accumulated |weight * gradient|."""
    model.eval()
    scores = {}
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            scores[name] = torch.zeros(module.out_channels, device=device)

    # Accumulate |w * grad| over one pass
    for images, targets in data_loader:
        images = images.to(device)
        outputs = model(images)
        loss = model.compute_loss(outputs, targets.to(device))
        model.zero_grad()
        loss.backward()
        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d) and module.weight.grad is not None:
                # Sum over input channels and kernel dims
                grad = module.weight.grad.abs()
                weight = module.weight.data.abs()
                # Importance per output channel
                imp = (grad * weight).sum(dim=(1, 2, 3))
                scores[name] += imp
        break  # one batch is enough for scoring
    return scores


def prune_model(model, scores, prune_ratio=0.2):
    """Zero out the lowest-scoring fraction of output channels in each Conv2d."""
    for name, module in model.named_modules():
        if name in scores:
            num_ch = module.out_channels
            k = int(prune_ratio * num_ch)
            # Indices of the k least important channels
            _, prune_idx = torch.topk(scores[name], k, largest=False)
            mask = torch.ones(num_ch, device=scores[name].device)
            mask[prune_idx] = 0.0
            # Reshape mask for broadcasting to weights and bias
            mask = mask[:, None, None, None]
            module.weight.data *= mask
            if module.bias is not None:
                module.bias.data *= mask.squeeze()
    return model


# Example usage (model and train_loader come from the earlier training setup):
device = torch.device('cuda')
scores = compute_taylor_scores(model, train_loader, device)
pruned_model = prune_model(model, scores, prune_ratio=0.3)

# Fine-tune, re-applying the masks after each update so pruned channels stay at zero
optimizer = torch.optim.SGD(pruned_model.parameters(), lr=1e-3, momentum=0.9)
for epoch in range(5):
    pruned_model.train()
    for imgs, tgts in train_loader:
        imgs, tgts = imgs.to(device), tgts.to(device)
        out = pruned_model(imgs)
        loss = pruned_model.compute_loss(out, tgts)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        prune_model(pruned_model, scores, prune_ratio=0.3)  # keep pruned channels masked
```
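Because pruning here zeroes out whole output channels in place rather than physically removing them, a quick way to sanity-check the effective slimming is to count the channels whose weights are still non-zero. The helper below is an illustrative check on top of the snippet above, not part of the original notebook:

```python
import torch.nn as nn

def count_active_channels(model):
    """Report how many Conv2d output channels survived pruning.

    A masked channel has all-zero weights, so a per-channel sum of absolute
    values separates live channels from pruned ones.
    """
    total, active = 0, 0
    for _, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            alive = (module.weight.data.abs().sum(dim=(1, 2, 3)) > 0).sum().item()
            total += module.out_channels
            active += int(alive)
    print(f"Active channels: {active}/{total} ({100.0 * active / total:.1f}%)")

count_active_channels(pruned_model)
```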