\( \newcommand{\matr}[1] {\mathbf{#1}} \newcommand{\vertbar} {\rule[-1ex]{0.5pt}{2.5ex}} \newcommand{\horzbar} {\rule[.5ex]{2.5ex}{0.5pt}} \newcommand{\E} {\mathrm{E}} \)
Math and science::INF ML AI

Sharpness-aware minimization

How is the training process changed when using "sharpness-aware minimization"?

Standard SGD/Adam
Find parameters w where the loss L(w) is small.
Sharpness-aware minimization (SAM)
Find parameters w where the loss is small in a whole neighborhood around w. More precisely, minimize \( \max_{\|\epsilon\|_2 \le \rho} L(w + \epsilon) \). This seeks a point where even the worst nearby perturbation still gives low loss, i.e. a flat minimum.

The two passes

The first forward-backward pass computes the gradient at the current parameters, which is used to move the parameters up the gradient (the direction of worse performance) to the edge of the \( \rho \)-neighborhood. From this perturbed point we run a second forward-backward pass, which gives the gradient we will eventually use. Before applying that update, however, we first move the parameters back to their original location. A sketch of what the two optimizer steps do is given below.
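The optimizer used in the training code further down is assumed to be a SAM wrapper exposing first_step and second_step, in the style of common open-source PyTorch implementations (e.g. davda54/sam). The following is a minimal, single-device sketch of what those two steps do, not the canonical implementation:

import torch

class SAM(torch.optim.Optimizer):
    """Sketch of a SAM wrapper: first_step perturbs, second_step restores and updates."""

    def __init__(self, params, base_optimizer_cls, rho=0.05, **kwargs):
        super().__init__(params, dict(rho=rho, **kwargs))
        # The base optimizer (e.g. SGD) shares our parameter groups.
        self.base_optimizer = base_optimizer_cls(self.param_groups, **kwargs)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        # Move up the gradient to the approximate worst point in the
        # rho-ball: w -> w + rho * g / ||g||.
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)
            for p in group["params"]:
                if p.grad is None:
                    continue
                e_w = p.grad * scale
                p.add_(e_w)
                self.state[p]["e_w"] = e_w  # remember the perturbation
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        # Undo the perturbation, then step from the original w using the
        # gradient that was computed at the perturbed point.
        for group in self.param_groups:
            for p in group["params"]:
                if "e_w" in self.state[p]:
                    p.sub_(self.state[p]["e_w"])
        self.base_optimizer.step()
        if zero_grad:
            self.zero_grad()

    def _grad_norm(self):
        grads = [p.grad.norm(p=2) for group in self.param_groups
                 for p in group["params"] if p.grad is not None]
        return torch.norm(torch.stack(grads), p=2)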

Intuition

We update the parameters using the gradient from the perturbed location, not from the current location. This pushes us toward regions where the loss stays low even when the parameters are perturbed.

Implementation


def training_step(self, sample, optimizer, scaler):
    """Override training step to handle SAM's two-step process."""
    if not self.use_sam:
        # Standard training: one forward-backward pass, one optimizer step.
        optimizer.zero_grad()
        _, loss = self.forward(sample)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        return loss.item()
    else:
        # SAM training: two forward-backward passes per parameter update.
        def closure():
            optimizer.zero_grad()
            _, loss = self.forward(sample)
            scaler.scale(loss).backward()
            # Unscale so first_step/second_step see true gradient magnitudes.
            scaler.unscale_(optimizer)
            return loss

        # First forward-backward pass: gradient at the current parameters w.
        loss = closure()

        # First step: perturb the parameters to the (approximate) worst
        # point within the rho-ball, w + e(w).
        optimizer.first_step(zero_grad=True)

        # GradScaler allows only one unscale_() per update(); reset its
        # per-optimizer state so the second closure can unscale again.
        scaler.update()

        # Second forward-backward pass: gradient at the perturbed point.
        loss = closure()

        # Second step: restore the original parameters and apply the update
        # using the gradient computed at the perturbed point.
        optimizer.second_step(zero_grad=True)
        scaler.update()

        # Note: unlike scaler.step(), first_step/second_step are called
        # directly and do not skip the update when AMP finds inf/NaN grads.
        return loss.item()
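For context, a hedged sketch of how the training step above might be wired up; the SAM class from the earlier sketch and names such as use_sam are assumptions about the surrounding code, not part of the original note:

import torch
import torch.nn as nn

# Illustrative model and optimizer setup (assumed names, mirroring the snippet above).
model = nn.Linear(10, 2)
use_sam = True

if use_sam:
    # SAM wraps a base optimizer; rho controls the neighborhood radius.
    optimizer = SAM(model.parameters(), torch.optim.SGD, rho=0.05, lr=0.1, momentum=0.9)
else:
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# GradScaler degrades to a no-op when AMP/CUDA is unavailable.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())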