Batch Normalization in Neural Networks: A Comprehensive Guide

Contents

  1. Intuition Behind Batch Normalization
  2. Core Logic
  3. Running Statistics
  4. Dynamic Momentum vs Static Momentum
  5. Batch Normalization in Multi-GPU Training
  6. Multi-GPU Running Stats (ADVANCED)

1. Intuition Behind Batch Normalization

Imagine you’re training a machine learning model (like a neural network) to classify images. Each layer of the model learns patterns from the data, but the scale and distribution of its intermediate outputs (activations) can vary widely, and they keep shifting as earlier layers are updated. This can slow down training and make optimization difficult.

Batch Normalization (BN) is like a “balancing act” that ensures the activations (outputs of neurons) at each layer remain well-scaled and centered. It stabilizes the learning process and helps the network converge faster.

Think of it as making sure the data flowing through your network behaves consistently, like keeping traffic on a highway moving smoothly without bottlenecks.

Why is Batch Normalization Useful?

  • Stabilizes Training: Keeps activations well-scaled, avoiding extreme values that can destabilize training.
  • Speeds Up Convergence: Reduces the need for careful initialization and aggressive learning rate tuning.
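  • Regularization Effect: Acts as a mild form of regularization: because each mini-batch has a slightly different mean and variance, activations are normalized with slightly noisy statistics, which can reduce overfitting.
  • Consistent Inference: Running statistics accumulated during training are used at inference time, so predictions on unseen data are consistent and do not depend on which other samples happen to share the batch.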

2. Core Logic

At its core, Batch Normalization normalizes the activations of a layer for each mini-batch during training.

2.1 Training Phase

  1. Calculate the Mean and Variance

    For each feature, compute the mean (𝜇) and variance (𝜎²) over the samples in the mini-batch.

  2. Normalize the Activations

    Center and scale the activations so they have zero mean and unit variance:

    \[\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}\]

    Here, 𝜖 is a small constant added for numerical stability (to avoid division by zero).

  3. Learnable Scaling and Shifting

    Instead of just normalizing, Batch Normalization introduces learnable parameters 𝛾 (scale) and 𝛽 (shift), allowing the model to adjust the normalized values:

    \[y = \gamma \hat{x} + \beta\]
  • The learnable parameters 𝛾 (scale) and 𝛽 (shift) are separate for each BN layer, with one scale and one shift per feature (channel).
  • Within a layer, the same 𝛾 and 𝛽 are applied to every mini-batch; unlike 𝜇 and 𝜎², they are not recomputed per batch.
  • 𝛾 and 𝛽 are updated by gradient descent during backpropagation, like the network's weights.
  Important to Note

    If a neural network has 3 Batch Normalization layers, and layer $i$ has $k_i$ output features, then:

    1. Each layer has its own $\gamma$ and $\beta$, each a vector with one entry per feature ($k_i$ entries for layer $i$).
    2. These parameters are learned for the layer as a whole (not per mini-batch) and remain fixed after training.

    For example:

    • Layer 1 (output dimension $k_1$): \(\gamma_1 \in \mathbb{R}^{k_1}, \quad \beta_1 \in \mathbb{R}^{k_1}\)
    • Layer 2 (output dimension $k_2$): \(\gamma_2 \in \mathbb{R}^{k_2}, \quad \beta_2 \in \mathbb{R}^{k_2}\)
    • Layer 3 (output dimension $k_3$): \(\gamma_3 \in \mathbb{R}^{k_3}, \quad \beta_3 \in \mathbb{R}^{k_3}\)
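The three training-phase steps fit in a few lines of PyTorch. This is a minimal sketch assuming a 2-D input of shape (batch, features); `batch_norm_train` is a hypothetical helper, and its output is compared against PyTorch's `nn.BatchNorm1d` in training mode, which normalizes with the biased batch variance. The last lines check that each BN layer owns one 𝛾 (`weight`) and one 𝛽 (`bias`) entry per feature.

```python
import torch
import torch.nn as nn

def batch_norm_train(x, gamma, beta, eps=1e-5):
    """Training-mode batch norm for x of shape (batch, features)."""
    mu = x.mean(dim=0)                        # step 1: per-feature mean over the batch
    var = x.var(dim=0, unbiased=False)        # step 1: per-feature (biased) variance
    x_hat = (x - mu) / torch.sqrt(var + eps)  # step 2: zero mean, unit variance
    return gamma * x_hat + beta               # step 3: learnable scale and shift

torch.manual_seed(0)
x = torch.randn(8, 4)  # mini-batch of 8 samples, 4 features

bn = nn.BatchNorm1d(4)
with torch.no_grad():
    bn.weight.uniform_(0.5, 1.5)  # arbitrary gamma for the comparison
    bn.bias.uniform_(-0.5, 0.5)   # arbitrary beta
bn.train()

ours = batch_norm_train(x, bn.weight, bn.bias, eps=bn.eps)
ref = bn(x)
print(torch.allclose(ours, ref, atol=1e-5))  # True

# Three BN layers with k_1 = 16, k_2 = 32, k_3 = 64 features
layers = [nn.BatchNorm1d(k) for k in (16, 32, 64)]
print([(tuple(layer.weight.shape), tuple(layer.bias.shape)) for layer in layers])
```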

Batch Normalization Example

Consider a single feature of one layer whose values across a mini-batch of 4 samples are:

\[x = [2.0, 4.0, 6.0, 8.0]\]

Step 1: Compute the Mean and Variance

The mean ($\mu$) is calculated as:

\[\mu = \frac{2.0 + 4.0 + 6.0 + 8.0}{4} = 5.0\]

The variance ($\sigma^2$) is calculated as:

\[\sigma^2 = \frac{(2 - 5)^2 + (4 - 5)^2 + (6 - 5)^2 + (8 - 5)^2}{4} = 5.0\]

Step 2: Normalize

For each $x_i$, calculate the normalized value:

\[\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}\]

Using $\epsilon = 10^{-5}$ for numerical stability:

\[\hat{x} = \left[ \frac{2 - 5}{\sqrt{5 + 10^{-5}}}, \frac{4 - 5}{\sqrt{5 + 10^{-5}}}, \frac{6 - 5}{\sqrt{5 + 10^{-5}}}, \frac{8 - 5}{\sqrt{5 + 10^{-5}}} \right]\]

Simplifying:

\[\hat{x} = [-1.34, -0.45, 0.45, 1.34]\]

Step 3: Scale and Shift

Assume $\gamma = 2.0$ (scale) and $\beta = 1.0$ (shift). Then:

\[y_i = \gamma \hat{x}_i + \beta\]

Substitute the values:

\[y = [2(-1.34) + 1, 2(-0.45) + 1, 2(0.45) + 1, 2(1.34) + 1]\]

Simplifying:

\[y = [-1.68, 0.1, 1.9, 3.68]\]
A Python-generated visualization of the example activations before normalization, after normalization, and after scale-and-shift.
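The worked example can be reproduced in a few lines. This sketch assumes PyTorch and the same values as above (ε = 10⁻⁵, 𝛾 = 2, 𝛽 = 1), with the variance computed by dividing by N = 4. Because 𝑥̂ is not rounded before scaling here, the two middle outputs come out as 0.11 and 1.89 rather than 0.1 and 1.9.

```python
import torch

x = torch.tensor([2.0, 4.0, 6.0, 8.0])  # one feature, 4 samples in the mini-batch
eps, gamma, beta = 1e-5, 2.0, 1.0

mu = x.mean()                             # step 1: mean
var = x.var(unbiased=False)               # step 1: variance, dividing by N = 4
x_hat = (x - mu) / torch.sqrt(var + eps)  # step 2: normalize
y = gamma * x_hat + beta                  # step 3: scale and shift

x_hat_r = [round(v, 2) for v in x_hat.tolist()]
y_r = [round(v, 2) for v in y.tolist()]
print(mu.item(), var.item())  # 5.0 5.0
print(x_hat_r)                # [-1.34, -0.45, 0.45, 1.34]
print(y_r)                    # [-1.68, 0.11, 1.89, 3.68]
```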

2.2 Inference/Testing Phase

During inference, the mean (𝜇) and variance (𝜎²) are not recomputed from each input batch; instead, BN uses the moving averages (running statistics) of the mean and variance collected during training.
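A quick way to see the difference between the two phases is to check whether a sample's output depends on the other samples in its batch. The sketch below assumes PyTorch's `nn.BatchNorm1d`, where `train()` selects batch statistics and `eval()` selects running statistics; all data are random placeholders.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)

# Populate the running statistics with a few training-mode forward passes
bn.train()
with torch.no_grad():
    for _ in range(20):
        bn(torch.randn(32, 3) * 2.0 + 1.0)

sample = torch.randn(1, 3)
batch_a = torch.cat([sample, torch.randn(7, 3)])        # same first sample,
batch_b = torch.cat([sample, torch.randn(7, 3) + 5.0])  # different companions

with torch.no_grad():
    bn.train()  # batch statistics: the output for `sample` depends on the batch
    train_a, train_b = bn(batch_a)[0], bn(batch_b)[0]
    bn.eval()   # running statistics: the output for `sample` is batch-independent
    eval_a, eval_b = bn(batch_a)[0], bn(batch_b)[0]

print(torch.allclose(train_a, train_b))  # False
print(torch.allclose(eval_a, eval_b))    # True
```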

3. Running Statistics

Note: For simplicity, assume single-GPU training throughout this section, so each batch is exactly one mini-batch.

3.1 Introduction

In Batch Normalization, running statistics are the per-feature estimates of the mean (𝜇) and variance (𝜎²) of a BN layer's activations, accumulated over all mini-batches seen during training. Each BN layer keeps its own running statistics, which approximate the mean and variance of its activations over the training data.

  • During training, the mean and variance are computed on each mini-batch, and a moving average of these values across all mini-batches is maintained alongside them; this moving average is what we call the running statistics. Do not confuse running statistics with the learnable parameters 𝛾 and 𝛽: running statistics are not learned by gradient descent but accumulated from batch statistics.
  • During inference, these running statistics are used instead of recalculating the mean and variance on each input batch.

3.2 Implementation

The running mean and variance are updated incrementally with an exponential moving average (EMA) each time a mini-batch is processed. The idea is to average the batch statistics slowly over time to estimate the overall mean and variance.

For the mean ($\mu_{\text{running}}$) and variance ($\sigma_{\text{running}}^2$):

\[\mu_{\text{running}} \gets (1 - \alpha) \cdot \mu_{\text{running}} + \alpha \cdot \mu_{\text{current-batch}}\] \[\sigma_{\text{running}}^2 \gets (1 - \alpha) \cdot \sigma_{\text{running}}^2 + \alpha \cdot \sigma_{\text{current-batch}}^2\]

Where:

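  • $\alpha$ is the momentum parameter, which controls how quickly the running statistics adapt to new data: the larger $\alpha$, the more weight the current mini-batch receives. This is the convention PyTorch uses (default momentum 0.1); Keras/TensorFlow use the opposite convention, in which momentum (default 0.99) is the weight on the previous running value.
  • $(1-\alpha)$ determines how much influence the previous running statistics retain.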
  • $\mu_{\text{current-batch}}$ and $\sigma_{\text{current-batch}}^2$ are the mean and variance computed on the current mini-batch.
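The update rule can be checked against PyTorch, whose `nn.BatchNorm1d` follows the same convention (`momentum` weights the current batch). One detail the formula hides: PyTorch normalizes with the biased batch variance but feeds the unbiased variance (dividing by N − 1) into the running variance. In this sketch, `ema_update` is a hypothetical helper and the data are random.

```python
import torch
import torch.nn as nn

def ema_update(running, current, alpha):
    """running <- (1 - alpha) * running + alpha * current"""
    return (1 - alpha) * running + alpha * current

torch.manual_seed(0)
alpha = 0.1
x = torch.randn(16, 4) * 3.0 + 2.0

bn = nn.BatchNorm1d(4, momentum=alpha)  # running_mean starts at 0, running_var at 1
bn.train()
with torch.no_grad():
    bn(x)  # one training-mode forward pass performs one EMA update

expected_mean = ema_update(torch.zeros(4), x.mean(dim=0), alpha)
expected_var = ema_update(torch.ones(4), x.var(dim=0, unbiased=True), alpha)

print(torch.allclose(bn.running_mean, expected_mean, atol=1e-5))  # True
print(torch.allclose(bn.running_var, expected_var, atol=1e-5))    # True
```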

Once training is complete:

The running mean ($\mu_{\text{running}}$) and running variance ($\sigma_{\text{running}}^2$) are fixed. These are used for normalization during inference to ensure consistency:

\[\hat{x} = \frac{x - \mu_{\text{running}}}{\sqrt{\sigma_{\text{running}}^2 + \epsilon}}\]
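In evaluation mode, this formula followed by the usual scale and shift is the whole computation. The sketch below assumes PyTorch; it simulates some training so the running statistics move away from their initial values, sets 𝛾 = 2 and 𝛽 = 1 as in the earlier example, and compares the manual computation with `nn.BatchNorm1d` after `eval()`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)

# Simulate training so the running statistics are no longer 0 and 1
bn.train()
with torch.no_grad():
    for _ in range(10):
        bn(torch.randn(32, 4) * 2.0 + 3.0)
    bn.weight.fill_(2.0)  # gamma
    bn.bias.fill_(1.0)    # beta

bn.eval()  # use the running statistics and stop updating them
x = torch.randn(5, 4)
with torch.no_grad():
    x_hat = (x - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps)
    manual = bn.weight * x_hat + bn.bias
    out = bn(x)

print(torch.allclose(out, manual, atol=1e-5))  # True
```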

Key Properties of EMA

  1. Weighted Contribution:
    • Recent mini-batches contribute more to the running statistics.
    • Older mini-batches have less influence over time due to the decaying factor (1−α).
  2. Smoothed Global Estimate:
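    • The running mean and variance gradually stabilize over many mini-batches and approximate the statistics of the activations over the whole training set (weighted toward recent mini-batches, which were produced by more recent weights).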
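Unrolling the recursion makes both properties concrete: after t updates starting from an initial value μ₀, mini-batch k (1 ≤ k ≤ t) carries weight α(1 − α)^(t − k), and μ₀ keeps weight (1 − α)^t, so all weights sum to 1. A plain-Python sketch, assuming α = 0.1 and 50 updates; `ema_weights` is a hypothetical helper.

```python
def ema_weights(alpha, t):
    """Weight of each mini-batch k = 1..t in the running statistic after t updates."""
    return [alpha * (1 - alpha) ** (t - k) for k in range(1, t + 1)]

alpha, t = 0.1, 50
w = ema_weights(alpha, t)
initial_weight = (1 - alpha) ** t  # weight left on the initial value

print(f"most recent batch: {w[-1]:.3f}")  # 0.100
print(f"oldest batch:      {w[0]:.5f}")   # alpha * (1 - alpha) ** 49
print(abs(sum(w) + initial_weight - 1) < 1e-12)  # True
```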

Example:

Suppose we are training with 3 mini-batches, and the batch-specific means are:

Mini-batch 1: $\mu_{\text{current-batch}} = 5.0$

Mini-batch 2: $\mu_{\text{current-batch}} = 6.0$

Mini-batch 3: $\mu_{\text{current-batch}} = 7.0$

Let the initial running mean be $\mu_{\text{running}} = 0.0$ and momentum $\alpha = 0.9$.

Step 1 (Mini-batch 1):

\[\mu_{\text{running}} = (1 - 0.9) \cdot 0.0 + 0.9 \cdot 5.0 = 4.5\]

Step 2 (Mini-batch 2):

\[\mu_{\text{running}} = (1 - 0.9) \cdot 4.5 + 0.9 \cdot 6.0 = 5.85\]

Step 3 (Mini-batch 3):

\[\mu_{\text{running}} = (1 - 0.9) \cdot 5.85 + 0.9 \cdot 7.0 = 6.885\]

Similarly, the running variance $\sigma_{\text{running}}^2$ is updated using the same formula with the batch variance.
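The same three updates in plain Python, as a quick check of the arithmetic; `ema_update` is a hypothetical helper implementing the formula above with α = 0.9.

```python
def ema_update(running, current, alpha):
    """running <- (1 - alpha) * running + alpha * current"""
    return (1 - alpha) * running + alpha * current

running_mean, alpha = 0.0, 0.9
history = []
for batch_mean in [5.0, 6.0, 7.0]:  # means of mini-batches 1, 2, 3
    running_mean = ema_update(running_mean, batch_mean, alpha)
    history.append(round(running_mean, 3))

print(history)  # [4.5, 5.85, 6.885]
```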

4. Dynamic Momentum vs Static Momentum

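A moving average whose momentum stays fixed for the whole of training uses static momentum; the EMA in Section 3 with a constant $\alpha$ (e.g., 0.9 in the example above) is an instance. With dynamic momentum, $\alpha$ changes over the course of training.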

Dynamic momentum can offer several advantages over static momentum when calculating running statistics in Batch Normalization:

  1. Adaptability to Training Dynamics:
    • Dynamic momentum adjusts the contribution of the current batch and past batches based on the training stage or data behavior.
    • In early training, it may weigh current statistics more heavily, allowing faster adaptation to changing patterns.
    • Later in training, it can prioritize past statistics, stabilizing the updates.
  2. Improved Stability:
    • Static momentum uses a fixed value, which might not suit all training stages or datasets.
    • Dynamic momentum can help damp oscillations in the running statistics by adapting the smoothing factor as training progresses.
  3. Better Handling of Non-Stationary Data:
    • Dynamic momentum can respond to shifts in data distribution more effectively, ensuring that the running statistics stay representative of the underlying data.
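  4. Optimized Convergence of the Running Statistics:
    • By adjusting the momentum dynamically, the running statistics can settle faster and more reliably, especially when a good fixed momentum value is hard to determine. In standard BN the running statistics are not used in training-mode forward passes, so this improves the statistics used at inference rather than the optimization itself.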
  5. Customizability:
    • Dynamic momentum allows for more nuanced control, as it can be tied to training variables like the learning rate, epoch, or batch size, making it more flexible for different use cases.

In practice, momentum is usually adjusted as a function of the training epoch or iteration; this logic is typically wrapped in a Batch Normalization momentum scheduler.

Common strategies include:

  1. Linear Decay: \(\text{momentum} = \text{start} - \left( \frac{\text{epoch}}{\text{total\_epochs}} \times (\text{start} - \text{end}) \right)\)
  2. Exponential Decay: \(\text{momentum} = \text{start} \cdot \text{(decay\_rate)}^{\text{epoch}}\)
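The two schedules translate directly into functions. This is a sketch under stated assumptions: `linear_decay`, `exponential_decay`, and `set_bn_momentum` are hypothetical helpers, the start/end values are arbitrary, and the momentum follows the PyTorch convention used above, so decaying it makes later mini-batches contribute less. Applying a scheduled value means assigning it to the `momentum` attribute of every BN layer.

```python
import torch.nn as nn

def linear_decay(epoch, total_epochs, start=0.9, end=0.1):
    """Momentum moves linearly from `start` (epoch 0) to `end` (epoch total_epochs)."""
    return start - (epoch / total_epochs) * (start - end)

def exponential_decay(epoch, start=0.9, decay_rate=0.9):
    """Momentum shrinks by a factor of `decay_rate` every epoch."""
    return start * decay_rate ** epoch

def set_bn_momentum(model, momentum):
    """Assign the same momentum to every BatchNorm layer in `model`."""
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)):
            module.momentum = momentum

model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
for epoch in (0, 5, 10):
    set_bn_momentum(model, linear_decay(epoch, total_epochs=10))
    print(epoch, round(model[1].momentum, 3), round(exponential_decay(epoch), 3))
```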

Example

```python
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# Dummy data generator; a fixed seed gives both runs identical mini-batches
def generate_data(batch_size, features, num_batches, seed=0):
    g = torch.Generator().manual_seed(seed)
    for _ in range(num_batches):
        yield torch.randn(batch_size, features, generator=g)

# Simple model with a single BatchNorm layer
class SimpleModel(nn.Module):
    def __init__(self, momentum):
        super().__init__()
        self.bn = nn.BatchNorm1d(5, momentum=momentum)

    def forward(self, x):
        return self.bn(x)

# "Training" loop: forward passes alone update the running statistics
def train_model(model, data_loader, dynamic_momentum=False):
    model.train()  # running statistics are only updated in training mode
    running_means, running_vars = [], []

    with torch.no_grad():  # no backward pass is needed for this demo
        for step, batch in enumerate(data_loader):
            if dynamic_momentum:
                # Linear decay from 0.9, reaching the 0.1 floor at step 40
                model.bn.momentum = max(0.9 - step / 50, 0.1)

            # Forward pass updates running_mean and running_var
            _ = model(batch)

            # Save a copy of the running statistics after this step
            running_means.append(model.bn.running_mean.clone().cpu())
            running_vars.append(model.bn.running_var.clone().cpu())

    return running_means, running_vars

# Static momentum: 0.9 for every step
model_static = SimpleModel(momentum=0.9)
data_loader = generate_data(batch_size=16, features=5, num_batches=50)
static_means, static_vars = train_model(model_static, data_loader)

# Dynamic momentum: starts at 0.9 and decays linearly
model_dynamic = SimpleModel(momentum=0.9)
data_loader = generate_data(batch_size=16, features=5, num_batches=50)  # same seed, same data
dynamic_means, dynamic_vars = train_model(model_dynamic, data_loader, dynamic_momentum=True)

# Plot results (values averaged over the 5 features)
steps = range(len(static_means))

plt.figure(figsize=(12, 5))

# Running mean comparison
plt.subplot(1, 2, 1)
plt.plot(steps, [mean.mean().item() for mean in static_means], label="Static Momentum")
plt.plot(steps, [mean.mean().item() for mean in dynamic_means], label="Dynamic Momentum")
plt.title("Running Mean")
plt.xlabel("Step (mini-batch)")
plt.ylabel("Mean Value")
plt.legend()

# Running variance comparison
plt.subplot(1, 2, 2)
plt.plot(steps, [var.mean().item() for var in static_vars], label="Static Momentum")
plt.plot(steps, [var.mean().item() for var in dynamic_vars], label="Dynamic Momentum")
plt.title("Running Variance")
plt.xlabel("Step (mini-batch)")
plt.ylabel("Variance Value")
plt.legend()

plt.tight_layout()
plt.show()
```
A Python-generated comparison of running mean and running variance under static momentum and a simple dynamic momentum schedule.

5. Batch Normalization in Multi-GPU Training

When training on multiple GPUs using Distributed Data Parallel (DDP), the data is divided across GPUs. Each GPU processes its portion of the batch independently. Batch Normalization works on per-GPU mini-batches, meaning that the mean and variance for normalization are computed only from the samples on that GPU.

Problem in Multi-GPU Training

  • Inconsistent Statistics: Because each GPU computes its own mean and variance, the same BN layer normalizes with different statistics on different GPUs, each estimated from only that GPU's share of the batch. This can degrade accuracy, especially when the per-GPU batch size is small.
  • Small Batch Size Issue: Batch Normalization relies on accurate batch statistics, and estimates computed from only a few samples per GPU are noisy.
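The small-batch issue is easy to quantify. The sketch below assumes standard-normal activations (true mean 0) and measures how far a per-GPU batch mean typically lands from the true mean for several per-GPU batch sizes; for independent samples the spread shrinks roughly as 1/√n. The batch sizes and trial count are arbitrary choices.

```python
import torch

torch.manual_seed(0)
trials = 10_000
spread = {}
for n in (2, 8, 32, 128):
    batch_means = torch.randn(trials, n).mean(dim=1)  # one batch mean per trial
    spread[n] = batch_means.std().item()              # typical error of the estimate
    print(f"batch size {n:4d}: std of batch mean = {spread[n]:.3f} "
          f"(1/sqrt(n) = {n ** -0.5:.3f})")
```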

5.1 Synchronized Batch Normalization

Synchronized Batch Normalization (SyncBatchNorm) solves the problem by synchronizing the batch statistics across all GPUs during training. Rather than normalizing with per-GPU statistics, each GPU shares its local statistics, the GPUs combine them into global statistics (mean and variance) for the whole batch, and every GPU normalizes with these shared values.

How SyncBatchNorm Works

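  • During the forward pass:
    • Each GPU computes the statistics of its local activations (mean, variance, and sample count).
    • These local statistics are exchanged across GPUs and combined into the global mean and variance of the full batch.
    • The global statistics are used to normalize the activations on every GPU.
  • During the backward pass, the gradient terms that depend on the batch statistics are also reduced across GPUs.
  • The communication between GPUs is handled by collective-communication backends such as NCCL.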

To convert a model to use Synchronized Batch Normalization in PyTorch (here trained with PyTorch Lightning):

The following code is a minimal example; `SimpleNN` stands for your own network.

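```python
import torch
import torch.nn as nn
from torch.nn import SyncBatchNorm
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from pytorch_lightning import Trainer

# Placeholder network with a BatchNorm layer (stands in for your own model)
class SimpleNN(nn.Module):
    def __init__(self, in_features=20, hidden=64, num_classes=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x):
        return self.net(x)

# Wrap with LightningModule
class LightningModel(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self(inputs)
        loss = self.loss_fn(outputs, targets)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

if __name__ == "__main__":  # guard the entry point; DDP starts one process per GPU
    pl.seed_everything(0)

    # Instantiate the model and convert every BatchNorm layer to SyncBatchNorm
    # (Trainer(sync_batchnorm=True) would perform the same conversion)
    model = SyncBatchNorm.convert_sync_batchnorm(SimpleNN())

    # Random placeholder data; Lightning shards it across GPUs with a
    # DistributedSampler, so batch_size=32 is the per-GPU batch size
    dataset = TensorDataset(torch.randn(1024, 20), torch.randint(0, 3, (1024,)))
    train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Train with DDP on 4 GPUs
    trainer = Trainer(strategy="ddp", devices=4, accelerator="gpu", max_epochs=10)
    trainer.fit(LightningModel(model), train_dataloaders=train_loader)
```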

6. Multi-GPU Running Stats (ADVANCED)

Key Challenge in Multi-GPU Training

In multi-GPU setups:

  • Each GPU processes its own mini-batch of data independently.
  • Ideally, Batch Normalization would compute the mean and variance over the entire batch spread across the GPUs, but each GPU only “sees” the activations of its local mini-batch.

Example:

If:

  • Batch size per GPU = 32
  • Number of GPUs = 3

The total batch size is $32 \times 3 = 96$.

However:

  • GPU 1 computes statistics for its 32 samples.
  • GPU 2 computes statistics for its 32 samples.
  • GPU 3 computes statistics for its 32 samples.

This results in local batch statistics on each GPU, which do not by themselves give the statistics of all 96 samples.

Solution: Synchronizing Batch Statistics

To solve this, synchronized Batch Normalization (SyncBN) is used. Here’s how it works:

Step 1: Compute Local Statistics

Each GPU computes the mean ($\mu_{\text{local}}$) and variance ($\sigma_{\text{local}}^2$) for its local mini-batch of size 32.

Step 2: All-Reduce Operation

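The local statistics from all GPUs are aggregated using a collective communication operation such as all-reduce, which sums a value across all GPUs and gives every GPU the result. The global mean and variance are computed by:

  • Summing, across GPUs, each GPU's local mean and variance weighted by its batch size (the variance term also includes the squared gap between the local mean and the global mean, as in the formula below).
  • Dividing by the total number of samples ($32 \times 3 = 96$) to obtain the global statistics.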

Formulas:

Global Mean:

\[\mu_{\text{global}} = \frac{\sum_{g=1}^{N} \left( \mu_{\text{local}, g} \cdot \text{batch\_size}_g \right)}{\text{total\_batch\_size}}\]

Global Variance:

\[\sigma_{\text{global}}^2 = \frac{\sum_{g=1}^{N} \left( \sigma_{\text{local}, g}^2 + \left( \mu_{\text{local}, g} - \mu_{\text{global}} \right)^2 \right) \cdot \text{batch\_size}_g}{\text{total\_batch\_size}}\]

Where:

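  • $N$: Number of GPUs.
  • $\text{batch\_size}_g$: Mini-batch size on GPU $g$.
  • $\text{total\_batch\_size} = \sum_{g=1}^{N} \text{batch\_size}_g$: Total number of samples across all GPUs.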
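Both formulas can be checked numerically without any GPUs. The sketch below (PyTorch, single process) splits a batch of 96 activations into three shards standing in for three GPUs, combines the local statistics with the formulas above, and compares the result against statistics computed on the full batch. It deliberately uses uneven shard sizes (40, 32, 24) to exercise the batch-size weighting; in real SyncBN the sums are computed by a collective operation across GPUs, while here a Python `sum` plays that role.

```python
import torch

torch.manual_seed(0)
full_batch = torch.randn(96, 4) * 2.0 + 1.0            # 96 samples, 4 features
shards = torch.split(full_batch, [40, 32, 24], dim=0)  # stand-ins for 3 GPUs

# Step 1: each "GPU" computes local statistics (biased variance)
local = [(s.shape[0], s.mean(dim=0), s.var(dim=0, unbiased=False)) for s in shards]
total_batch_size = sum(n for n, _, _ in local)

# Step 2: combine with the batch-size-weighted formulas
mu_global = sum(n * mu for n, mu, _ in local) / total_batch_size
var_global = sum(n * (var + (mu - mu_global) ** 2) for n, mu, var in local) / total_batch_size

print(total_batch_size)                                                              # 96
print(torch.allclose(mu_global, full_batch.mean(dim=0), atol=1e-5))                  # True
print(torch.allclose(var_global, full_batch.var(dim=0, unbiased=False), atol=1e-5))  # True
```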

6.1 Running Statistics in Multi-GPU Training

When it comes to the running statistics (e.g., running mean and variance used during inference):

  • After each mini-batch:
    • The global statistics ($\mu_{\text{global}}$ and $\sigma_{\text{global}}^2$) are computed using the method above.
    • Update the running statistics using the same moving average formula:
\[\mu_{\text{running}} \gets (1 - \alpha) \cdot \mu_{\text{running}} + \alpha \cdot \mu_{\text{global}}\] \[\sigma_{\text{running}}^2 \gets (1 - \alpha) \cdot \sigma_{\text{running}}^2 + \alpha \cdot \sigma_{\text{global}}^2\]

Thus, the running statistics reflect the global behavior of the activations across all GPUs.
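A consequence worth checking: because the global statistics equal the statistics of the combined batch, one running-statistics update from them should match what a single `nn.BatchNorm1d` records when it sees all 96 samples at once. This sketch assumes PyTorch's convention, including its use of the unbiased variance (dividing by the total count minus 1) for the running variance; the three 32-sample shards are simulated in one process.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
alpha = 0.1
full_batch = torch.randn(96, 4) * 2.0 + 1.0
shards = torch.split(full_batch, [32, 32, 32], dim=0)  # 3 simulated GPUs

# Global statistics from the per-shard statistics (formulas above)
counts = [s.shape[0] for s in shards]
means = [s.mean(dim=0) for s in shards]
variances = [s.var(dim=0, unbiased=False) for s in shards]
total = sum(counts)
mu_g = sum(n * m for n, m in zip(counts, means)) / total
var_g = sum(n * (v + (m - mu_g) ** 2) for n, m, v in zip(counts, means, variances)) / total

# One EMA update, starting from PyTorch's initial values (mean 0, variance 1)
running_mean = (1 - alpha) * torch.zeros(4) + alpha * mu_g
running_var = (1 - alpha) * torch.ones(4) + alpha * var_g * total / (total - 1)  # unbiased

# Reference: a single BatchNorm1d that sees the full batch in training mode
bn = nn.BatchNorm1d(4, momentum=alpha)
with torch.no_grad():
    bn(full_batch)

print(torch.allclose(running_mean, bn.running_mean, atol=1e-5))  # True
print(torch.allclose(running_var, bn.running_var, atol=1e-5))    # True
```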

6.2 Why SyncBN Is Important

Without synchronized Batch Normalization:

  • Each GPU would compute statistics based only on its local mini-batch.
  • This could lead to inconsistent normalization and harm the model’s performance, especially when the per-GPU batch size is small (for example, when a fixed total batch is split across many GPUs).

6.3 Alternatives to SyncBN

When SyncBN is impractical or expensive:

  • Group Normalization (GN): Divides channels into groups and normalizes them within each group, removing the need for batch-wide statistics.
  • Instance Normalization (IN): Normalizes each channel of each sample independently, over its spatial dimensions.
  • Layer Normalization (LN): Normalizes across all features of each sample, independently of the other samples in the batch.

These alternatives avoid the need for synchronization but might have different performance characteristics.
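What the alternatives share is that none of them mixes information across samples, so no cross-GPU communication is needed. The sketch below assumes PyTorch's `nn.GroupNorm`, `nn.InstanceNorm2d`, and `nn.LayerNorm` on a 4-D activation tensor (N, C, H, W), and checks that each layer's output for the first sample is unchanged when the rest of the batch changes, which is not true of `nn.BatchNorm2d` in training mode. Shapes and group count are arbitrary choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, C, H, W = 4, 8, 5, 5
sample = torch.randn(1, C, H, W)
batch_a = torch.cat([sample, torch.randn(N - 1, C, H, W)])
batch_b = torch.cat([sample, torch.randn(N - 1, C, H, W) * 3.0 + 2.0])

layers = {
    "GroupNorm": nn.GroupNorm(num_groups=2, num_channels=C),  # groups of 4 channels
    "InstanceNorm": nn.InstanceNorm2d(C),                     # per sample, per channel
    "LayerNorm": nn.LayerNorm([C, H, W]),                     # per sample, all features
    "BatchNorm (train)": nn.BatchNorm2d(C),                   # per channel, across batch
}

results = {}
with torch.no_grad():
    for name, layer in layers.items():
        layer.train()
        results[name] = torch.allclose(layer(batch_a)[0], layer(batch_b)[0], atol=1e-5)
        print(f"{name:18s} batch-independent: {results[name]}")
```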

6.4 Summary

In multi-GPU training, synchronized Batch Normalization is used to compute global mean and variance across all GPUs.

  • The local statistics on each GPU are aggregated using an all-reduce operation.
  • The running statistics (used for inference) are updated based on the global statistics, ensuring consistency across GPUs.
  • SyncBN matters most when the per-GPU batch size is small; the price is extra communication between GPUs in every BN layer.
