Math and science::INF ML AI
Sharpness-aware minimization
How is the training process changed when using "sharpness-aware minimization"?
- Standard SGD/Adam
- Find parameters w where loss L(w) is small.
- Sharpness aware minimization (SAM)
- Find parameters w where the loss is small in a whole neighborhood around w. More precisely, minimize \( \max_{\lVert \epsilon \rVert \le \rho} L(w + \epsilon) \) over w. This finds a point where even the worst nearby perturbation has low loss → a flat minimum. (A first-order sketch of how the inner max is handled follows below.)
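In practice the inner maximization is not solved exactly; the SAM paper uses a first-order approximation, which is what motivates the perturbation used in the two passes:

\[ \hat{\epsilon}(w) \;=\; \arg\max_{\lVert \epsilon \rVert \le \rho} L(w + \epsilon) \;\approx\; \arg\max_{\lVert \epsilon \rVert \le \rho} \epsilon^{\top} \nabla_w L(w) \;=\; \rho \, \frac{\nabla_w L(w)}{\lVert \nabla_w L(w) \rVert} \]

The gradient used for the actual parameter update is then \( \nabla_w L(w + \hat{\epsilon}(w)) \).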
The two passes
The first forward-backward pass computes the gradient at the current weights, and this gradient is used to temporarily move the parameters up the gradient (the direction of worse loss), with the perturbation scaled to norm ρ. From this perturbed point, a second forward-backward pass gives the gradient we will eventually use. Before applying that update, the parameters are first moved back to their original location. A sketch of an optimizer wrapper that implements these two steps follows.
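As a concrete illustration, here is a minimal sketch of a SAM-style optimizer wrapper in PyTorch. It is not the exact class used in the Implementation section below; the name SAMSketch is made up, but the first_step / second_step interface matches what that code calls (and what the widely used davda54/sam package exposes).

import torch

class SAMSketch(torch.optim.Optimizer):
    """Minimal SAM-style wrapper: perturb up the gradient, then update from the original point."""

    def __init__(self, params, base_optimizer_cls, rho=0.05, **kwargs):
        defaults = dict(rho=rho, **kwargs)
        super().__init__(params, defaults)
        # The base optimizer (e.g. SGD/Adam) performs the actual weight update.
        self.base_optimizer = base_optimizer_cls(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)
            for p in group["params"]:
                if p.grad is None:
                    continue
                e_w = p.grad * scale          # epsilon = rho * g / ||g||
                p.add_(e_w)                   # climb to the (approximate) worst nearby point
                self.state[p]["e_w"] = e_w    # remember the perturbation
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.sub_(self.state[p]["e_w"])  # move back to the original weights
        self.base_optimizer.step()            # update using the gradient from the perturbed point
        if zero_grad:
            self.zero_grad()

    def _grad_norm(self):
        return torch.norm(
            torch.stack([
                p.grad.norm(p=2)
                for group in self.param_groups
                for p in group["params"]
                if p.grad is not None
            ]),
            p=2,
        )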
Intuition
We update the parameters using the gradient taken at the perturbed location, not at the current one. This pushes the weights toward regions where the loss stays low even when perturbed.
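In symbols, with learning rate \( \eta \) and \( \hat{\epsilon} \) as defined above:

\[ \text{SGD: } w_{t+1} = w_t - \eta \, \nabla_w L(w_t) \qquad \text{SAM: } w_{t+1} = w_t - \eta \, \nabla_w L(w)\big|_{w = w_t + \hat{\epsilon}(w_t)} \]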
Implementation
def training_step(self, sample, optimizer, scaler):
    """Override training step to handle SAM's two-step process."""
    if not self.use_sam:
        # Standard training: single forward-backward pass
        optimizer.zero_grad()
        _, loss = self.forward(sample)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        return loss.item()
    else:
        # SAM training: requires two forward-backward passes
        def closure():
            optimizer.zero_grad()
            _, loss = self.forward(sample)
            scaler.scale(loss).backward()
            # first_step/second_step work on raw gradients, so unscale them here
            scaler.unscale_(optimizer)
            return loss

        # First forward-backward pass: gradient at the current weights
        loss = closure()
        # First step: perturb parameters up the gradient (w -> w + epsilon)
        optimizer.first_step(zero_grad=True)
        # Reset the scaler's per-optimizer state so unscale_ may be called again
        scaler.update()
        # Second forward-backward pass at the perturbed location
        loss = closure()
        # Second step: move back to w and apply the base optimizer update
        optimizer.second_step(zero_grad=True)
        scaler.update()
        return loss.item()
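For completeness, a hedged wiring sketch of how this method might be driven. model, train_loader, and the use_sam flag are placeholders, and SAMSketch is the illustrative wrapper from above (the davda54/sam SAM class exposes the same first_step/second_step interface):

import torch

# Placeholder objects: `model` is assumed to define training_step/forward/use_sam,
# and `train_loader` yields the `sample` batches expected by forward().
optimizer = SAMSketch(model.parameters(), torch.optim.SGD, rho=0.05, lr=0.1, momentum=0.9)
scaler = torch.cuda.amp.GradScaler()   # GradScaler(enabled=False) runs the same path without AMP
model.use_sam = True                   # attribute checked at the top of training_step

for sample in train_loader:
    loss = model.training_step(sample, optimizer, scaler)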