Distributed Training Basics

Date: Feb 16, 2025
Tags: MLSys

TL;DR

This post summarizes the main parallelism strategies for distributed training and two popular data parallelism (DP) techniques: DeepSpeed ZeRO and PyTorch FSDP.

Why do we need distributed training?

As modern large language models continue to grow in size, we cannot train them on a single GPU (or even a single node), because they need more GPU memory than one device provides for weights, activations, gradients, and optimizer states. Therefore, we use distributed training and partition these components across multiple GPUs.
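As a rough worked example, the ZeRO paper's accounting for mixed-precision Adam counts 16 bytes of model state per parameter: 2 for fp16 weights, 2 for fp16 gradients, and 12 for the fp32 master weights, momentum, and variance. The sketch below applies that accounting to a hypothetical 7B-parameter model; the 80 GB per-GPU figure is an assumption, and activations are ignored, so real usage is higher.

```python
# Bytes per parameter for mixed-precision Adam, following the ZeRO paper's accounting.
BYTES_PER_PARAM = {
    "fp16 weights": 2,
    "fp16 gradients": 2,
    "fp32 master weights + momentum + variance": 12,
}


def model_state_gb(num_params: int) -> float:
    """Memory for weights + gradients + optimizer states (activations excluded), in GB (1e9 bytes)."""
    return num_params * sum(BYTES_PER_PARAM.values()) / 1e9


if __name__ == "__main__":
    params = 7_000_000_000  # hypothetical 7B-parameter model
    gpu_memory_gb = 80      # assumed per-GPU memory
    print(f"model states: {model_state_gb(params):.0f} GB vs. {gpu_memory_gb} GB per GPU")
```

Even before counting activations, the model states alone exceed what one such GPU can hold.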

Parallel Strategies

Under the classical view, current parallelism strategies fall into three categories:
  • Data Parallelism: Each GPU processes only a fraction of the mini-batch and holds a full copy of the model weights. After each backward pass, the gradients are averaged across all workers using all-reduce.
  • Model Parallelism / Vertical Model Parallelism: The model is sliced vertically and different layers are placed on different GPUs. However, with naive vertical model parallelism, the GPU computing later layers must wait for the GPUs computing earlier layers to finish. Pipeline Parallelism reduces this idle time by splitting the mini-batch into micro-batches, so that different GPUs can work on different micro-batches at the same time.
  • Tensor Parallelism: The model is sliced horizontally, i.e. the weight tensors of individual layers are split across GPUs. All workers process the same batch of data, exchange the intermediate results they need from each other, and each worker computes the gradients only for its own slice of the weights.
However, we can also classify parallelism using a more computationally precise criterion:
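A minimal sketch of tensor parallelism for one linear layer Y = XW, assuming W is split column-wise across simulated workers. Plain Python lists stand in for GPUs, list concatenation stands in for the all-gather, and the helper names are hypothetical. Every worker sees the same input X but multiplies it only by its own block of columns.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain row-by-column matrix multiply."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def split_columns(w: Matrix, num_workers: int) -> List[Matrix]:
    """Give each worker a contiguous block of W's columns (assumes an even split)."""
    cols = len(w[0]) // num_workers
    return [[row[r * cols:(r + 1) * cols] for row in w] for r in range(num_workers)]


def tensor_parallel_linear(x: Matrix, w: Matrix, num_workers: int) -> Matrix:
    shards = split_columns(w, num_workers)
    # Each worker: same input X, only its own column block of W.
    partial_outputs = [matmul(x, shard) for shard in shards]
    # "All-gather": concatenate every worker's output columns, row by row.
    return [sum((out[i] for out in partial_outputs), []) for i in range(len(x))]


if __name__ == "__main__":
    x = [[1.0, 2.0]]
    w = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
    print(tensor_parallel_linear(x, w, num_workers=2))
```

Splitting W by rows instead (with X split by columns to match) gives each worker a partial sum of Y, which must then be combined with an all-reduce rather than a concatenation.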
  • Inter-Op Parallelism: Assign different operators to different devices.
  • Intra-Op Parallelism: Assign different regions of a single operator to different devices.
Under this criterion, Data Parallelism and Tensor Parallelism are Intra-Op (each operator is split along its batch or weight dimension), while vertical model parallelism, including pipeline parallelism, is Inter-Op.
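To make the Inter-Op case and its waiting problem concrete, here is a sketch of a GPipe-style forward schedule, assuming every stage takes exactly one time step per micro-batch. With p stages and m micro-batches the schedule takes m + p - 1 steps, so the idle ("bubble") fraction is (p - 1) / (m + p - 1); m = 1 corresponds to naive vertical model parallelism.

```python
def forward_schedule(num_stages: int, num_microbatches: int):
    """Return timeline[t][s] = micro-batch handled by stage s at step t, or None if idle.

    Assumes unit time per (stage, micro-batch) and a simple fill/drain pipeline.
    """
    total_steps = num_microbatches + num_stages - 1
    timeline = []
    for t in range(total_steps):
        row = []
        for s in range(num_stages):
            mb = t - s  # stage s can start micro-batch mb once stage s-1 has finished it
            row.append(mb if 0 <= mb < num_microbatches else None)
        timeline.append(row)
    return timeline


def bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Fraction of (step, stage) slots in which a stage sits idle."""
    timeline = forward_schedule(num_stages, num_microbatches)
    idle = sum(cell is None for row in timeline for cell in row)
    return idle / (len(timeline) * num_stages)


if __name__ == "__main__":
    for row in forward_schedule(num_stages=4, num_microbatches=4):
        print(row)
    print(bubble_fraction(4, 1))  # naive vertical MP: only one stage busy at any step
    print(bubble_fraction(4, 8))  # micro-batching shrinks the bubble to (4-1)/(8+4-1)
```

More micro-batches shrink the bubble, at the cost of smaller per-micro-batch work on each GPU.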

ZeRO-powered Data Parallelism

ZeRO (Zero Redundancy Optimizer) is a form of data parallelism that removes the redundancy of plain data-parallel training, where every GPU stores a full copy of the model states. It partitions these states across GPUs to improve memory efficiency, and depending on the ZeRO stage, this may cost additional communication volume.

Baseline

Simple data parallelism, as implemented in PyTorch DDP: gradients are averaged across GPUs using all-reduce (typically implemented as a reduce-scatter followed by an all-gather), and every GPU then applies the same update to its full copy of the model weights.
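The all-reduce decomposition can be sketched with a single-process ring all-reduce over simulated workers. This is an illustration under assumptions (the vector length is divisible by the number of workers, and all sends within a step happen simultaneously), not NCCL's actual implementation.

```python
from typing import List


def ring_all_reduce_mean(grads: List[List[float]]) -> List[List[float]]:
    """Average equal-length gradient vectors across simulated workers with a ring all-reduce.

    Phase 1 (reduce-scatter) leaves each worker with one fully summed chunk.
    Phase 2 (all-gather) circulates those chunks so every worker ends with the full sum.
    """
    w = len(grads)
    size = len(grads[0]) // w
    # Each worker's buffer, split into w chunks.
    bufs = [[g[c * size:(c + 1) * size] for c in range(w)] for g in grads]

    # Reduce-scatter: in step k, worker r sends chunk (r - k) % w to worker r + 1, which adds it.
    for step in range(w - 1):
        msgs = [(r, (r - step) % w, list(bufs[r][(r - step) % w])) for r in range(w)]
        for r, c, data in msgs:
            dst = (r + 1) % w
            bufs[dst][c] = [a + b for a, b in zip(bufs[dst][c], data)]

    # All-gather: worker r now owns the complete chunk (r + 1) % w and forwards it around the ring.
    for step in range(w - 1):
        msgs = [(r, (r + 1 - step) % w, list(bufs[r][(r + 1 - step) % w])) for r in range(w)]
        for r, c, data in msgs:
            bufs[(r + 1) % w][c] = data

    # Flatten and divide by the number of workers to get the averaged gradient.
    return [[x / w for chunk in b for x in chunk] for b in bufs]


if __name__ == "__main__":
    print(ring_all_reduce_mean([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]))
```

Each worker sends w - 1 chunks of size N/w in each phase, about 2N elements in total for an N-element gradient. This is the baseline communication volume the ZeRO stages below are compared against.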

ZeRO Stage 1: Optimizer State Partitioning

The first stage partitions optimizer states across GPUs, so that each GPU holds only a non-redundant partition. The forward and backward passes are the same as in normal data parallelism. After the gradients are averaged, each GPU updates only its own partition of the optimizer states and the corresponding partition of the weights.
To avoid extra communication, the all-reduce is split into its two halves. After the reduce-scatter step, each worker already holds the averaged gradients for its own partition, which is exactly what it needs to update the corresponding weights. The all-gather step is then applied to the updated weights instead of the gradients. Therefore, stage 1 has the same communication volume as simple data parallelism.
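A single-process sketch of the stage-1 step described above, with simulated ranks. Two assumptions: SGD with momentum stands in for Adam (both update element-wise, so sharding the state works the same way), and the parameter count is divisible by the number of ranks. The hypothetical `ddp_step` reference shows the result matches plain data parallelism.

```python
from typing import List

Vec = List[float]


def reduce_scatter_mean(grads: List[Vec], w: int) -> List[Vec]:
    """For each rank r, return the averaged gradient restricted to shard r."""
    size = len(grads[0]) // w
    return [[sum(g[i] for g in grads) / w for i in range(r * size, (r + 1) * size)]
            for r in range(w)]


def all_gather(shards: List[Vec]) -> Vec:
    """Concatenate every rank's shard; every rank receives the same full vector."""
    return [x for shard in shards for x in shard]


def zero1_step(params: Vec, grads: List[Vec], momentum_shards: List[Vec],
               lr: float = 0.1, beta: float = 0.9) -> Vec:
    """One ZeRO-1 style step; momentum_shards[r] is the only optimizer state rank r keeps."""
    w = len(grads)
    size = len(params) // w
    grad_shards = reduce_scatter_mean(grads, w)        # 1) reduce-scatter gradients
    new_param_shards = []
    for r in range(w):                                  # 2) each rank updates its own shard
        m, g = momentum_shards[r], grad_shards[r]
        for i in range(size):
            m[i] = beta * m[i] + g[i]                   # state is updated in place on rank r
        p = params[r * size:(r + 1) * size]
        new_param_shards.append([pi - lr * mi for pi, mi in zip(p, m)])
    return all_gather(new_param_shards)                 # 3) all-gather updated weights


def ddp_step(params: Vec, grads: List[Vec], momentum: Vec,
             lr: float = 0.1, beta: float = 0.9) -> Vec:
    """Reference: every rank holds the full momentum and applies the full update."""
    w = len(grads)
    avg = [sum(g[i] for g in grads) / w for i in range(len(params))]
    for i in range(len(params)):
        momentum[i] = beta * momentum[i] + avg[i]
    return [p - lr * m for p, m in zip(params, momentum)]
```

Each rank stores optimizer state for only 1/w of the parameters, yet the updated weights are the same as with DDP, which is why stage 1 saves memory without changing the training result.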

ZeRO Stage 2: Optimizer + Gradient Partitioning

In stage 2, both optimizer states and gradients are partitioned. During the backward pass, gradients are reduce-scattered, so each GPU keeps only the averaged gradients of its own partition and uses them to update its optimizer states and weights. An all-gather step then synchronizes the updated model weights across GPUs. The communication volume remains the same as in normal data parallelism, while each GPU no longer stores the full gradients.

ZeRO Stage 3: Optimizer + Gradient + Parameter Partitioning

In stage 3, optimizer states, gradients, and parameters are all partitioned. Let N be the number of model parameters. Before the forward pass, parameters must be gathered with an all-gather (N data movement), and because they are discarded after use, they must be all-gathered again before the backward pass (another N). Together with the gradient reduce-scatter (N), this totals 3N data movement, compared with 2N (reduce-scatter + all-gather) for simple data parallelism. As a result, the communication volume increases by 1.5x.
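The per-stage trade-offs can be summarized with the ZeRO paper's memory accounting: with N parameters, data-parallel degree N_d, and K bytes of optimizer state per parameter (K = 12 for mixed-precision Adam), each GPU holds 2N bytes of fp16 weights, 2N of fp16 gradients, and KN of optimizer state, with the partitioned parts divided by N_d. The sketch below encodes this; activations and temporary buffers are ignored, and the 7.5B-parameter, 64-GPU setting in the usage line is only illustrative.

```python
def zero_memory_gb(num_params: float, dp_degree: int, stage: int, k: int = 12) -> float:
    """Per-GPU model-state memory in GB (1e9 bytes) for ZeRO stage 0 (plain DP) to 3."""
    n, nd = num_params, dp_degree
    if stage == 0:    # everything replicated
        total = (2 + 2 + k) * n
    elif stage == 1:  # optimizer states partitioned
        total = 2 * n + 2 * n + k * n / nd
    elif stage == 2:  # optimizer states + gradients partitioned
        total = 2 * n + (2 + k) * n / nd
    elif stage == 3:  # optimizer states + gradients + parameters partitioned
        total = (2 + 2 + k) * n / nd
    else:
        raise ValueError("stage must be 0, 1, 2 or 3")
    return total / 1e9


def comm_volume(num_params: float, stage: int) -> float:
    """Elements moved per GPU per step, counted as in the ZeRO analysis."""
    # Stages 0-2: reduce-scatter (N) + all-gather (N).
    # Stage 3: parameter all-gather before forward (N) and before backward (N) + reduce-scatter (N).
    return 3 * num_params if stage == 3 else 2 * num_params


if __name__ == "__main__":
    for s in range(4):
        print(f"stage {s}: {zero_memory_gb(7.5e9, 64, s):.2f} GB/GPU, "
              f"{comm_volume(7.5e9, s) / 7.5e9:.1f}N communication")
```

Stages 1 and 2 cut memory substantially at no extra communication cost, while stage 3 trades the 1.5x communication increase for dividing all model states by N_d.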

Follow-up optimizations

  • ZeRO-R: partitions activations as well
  • ZeRO-Offload: offloads less compute-intensive work (optimizer states and updates) to the host CPU
  • ZeRO-Infinity: additionally offloads states to NVMe
  • ZeRO++: hierarchical weight partitioning plus quantized weight and gradient communication

Fully-Sharded Data Parallelism

FSDP (Fully Sharded Data Parallel) is PyTorch's data-parallel technique built on the ideas of ZeRO. This post covers two of its sharding strategies: full sharding and hybrid sharding.

Full Sharding

Model weights are partitioned across GPUs. During the forward pass, for each FSDP unit:
  • gather the full weights using all-gather
  • run the forward computation
  • discard the weight shards owned by other GPUs
During the backward pass:
  • gather the full weights using all-gather
  • run the backward computation
  • run reduce-scatter to synchronize gradients
  • discard the weight shards owned by other GPUs
In summary, full sharding partitions the weights across multiple GPUs and gathers them on demand.
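A single-process sketch of the full-sharding steps above, under several assumptions. The toy model is hypothetical: each layer scales its input element-wise (y = w * x) and the loss is the sum of the outputs. Simulated ranks run in lockstep, and the plain-Python `all_gather` / `reduce_scatter_mean` helpers stand in for real collectives; they are not `torch.distributed` calls. The function also reports how many parameters are materialized in full at once.

```python
from typing import List, Tuple

Vec = List[float]


def shard(vec: Vec, w: int) -> List[Vec]:
    size = len(vec) // w
    return [vec[r * size:(r + 1) * size] for r in range(w)]


def all_gather(shards: List[Vec]) -> Vec:
    return [x for s in shards for x in s]


def reduce_scatter_mean(per_rank: List[Vec], w: int) -> List[Vec]:
    """Average full-length vectors over ranks; rank r keeps slice r."""
    return shard([sum(vals) / w for vals in zip(*per_rank)], w)


def fsdp_full_shard_step(layer_shards: List[List[Vec]],
                         inputs: List[Vec]) -> Tuple[List[List[Vec]], int]:
    """layer_shards[l][r] is rank r's shard of layer l; inputs[r] is rank r's data.

    Returns (grad_shards[l][r], peak number of parameters materialized in full at once).
    """
    w = len(inputs)
    acts = [list(x) for x in inputs]
    saved: List[List[Vec]] = []  # each layer's input, kept for the backward pass
    peak = 0
    # Forward: gather one layer, compute, then drop the gathered copy.
    for shards in layer_shards:
        full = all_gather(shards)
        peak = max(peak, len(full))
        saved.append([list(a) for a in acts])
        acts = [[wi * ai for wi, ai in zip(full, a)] for a in acts]
        del full
    # Backward: d(loss)/d(output) = 1 for every element.
    upstream = [[1.0] * len(a) for a in acts]
    grad_shards: List[List[Vec]] = [[] for _ in layer_shards]
    for l in reversed(range(len(layer_shards))):
        full = all_gather(layer_shards[l])  # re-gather, since it was discarded after forward
        local_grads = [[g * x for g, x in zip(upstream[r], saved[l][r])] for r in range(w)]
        upstream = [[g * wi for g, wi in zip(upstream[r], full)] for r in range(w)]
        grad_shards[l] = reduce_scatter_mean(local_grads, w)  # each rank keeps its gradient shard
        del full
    return grad_shards, peak


if __name__ == "__main__":
    layers = [shard([1.0, 2.0, 3.0, 4.0], 2), shard([0.5, 1.0, 1.5, 2.0], 2)]
    grads, peak = fsdp_full_shard_step(layers, [[1.0, 1.0, 1.0, 1.0], [2.0, 0.0, 1.0, 3.0]])
    print(grads, peak)
```

Only one layer's full weights exist at a time, so the peak equals the largest layer rather than the whole model. Real FSDP can prefetch the next unit's weights to overlap communication with computation, so its actual peak is somewhat higher.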

Hybrid Sharding

In hybrid sharding mode, model weights are sharded within each node but replicated across nodes. Within a node, it behaves like full sharding; after the intra-node reduce-scatter, gradients are additionally all-reduced across nodes among the GPUs that hold the same shard.
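A sketch of the gradient synchronization in hybrid sharding, assuming ranks are numbered node by node (rank = node * gpus_per_node + local_rank); the function names are hypothetical. PyTorch exposes this mode in FSDP as `ShardingStrategy.HYBRID_SHARD`, but the code below only simulates the communication pattern: a reduce-scatter inside each node, then an all-reduce across nodes among the ranks that hold the same slice.

```python
from typing import List


def hybrid_groups(num_nodes: int, gpus_per_node: int):
    """shard_groups: GPUs in one node that together hold one full copy of the weights.
    replicate_groups: GPUs on different nodes that hold the same shard."""
    shard_groups = [[n * gpus_per_node + g for g in range(gpus_per_node)]
                    for n in range(num_nodes)]
    replicate_groups = [[n * gpus_per_node + g for n in range(num_nodes)]
                        for g in range(gpus_per_node)]
    return shard_groups, replicate_groups


def hybrid_grad_sync(grads: List[List[float]], num_nodes: int,
                     gpus_per_node: int) -> List[List[float]]:
    """Return each rank's globally averaged gradient shard."""
    shard_groups, replicate_groups = hybrid_groups(num_nodes, gpus_per_node)
    size = len(grads[0]) // gpus_per_node
    shard_of = {}
    # Step 1: reduce-scatter (sum) inside each node; local rank g receives slice g.
    for group in shard_groups:
        for g, rank in enumerate(group):
            shard_of[rank] = [sum(grads[r][i] for r in group)
                              for i in range(g * size, (g + 1) * size)]
    # Step 2: all-reduce (sum) across nodes among ranks holding the same slice, then average.
    world = num_nodes * gpus_per_node
    result = {}
    for group in replicate_groups:
        total = [sum(vals) for vals in zip(*(shard_of[r] for r in group))]
        for rank in group:
            result[rank] = [t / world for t in total]
    return [result[r] for r in range(world)]


if __name__ == "__main__":
    grads = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0],
             [9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]]
    print(hybrid_grad_sync(grads, num_nodes=2, gpus_per_node=2))
```

The design choice is that the bandwidth-heavy all-gather and reduce-scatter traffic stays on fast intra-node links, while only a shard-sized all-reduce crosses the slower inter-node network.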

Are DeepSpeed and FSDP here to stay?

DeepSpeed ZeRO and FSDP are likely to remain dominant or be replaced only by similar strategies, thanks to their ease of use and straightforward integration with various components of large language models.

Efficient Finetuning Techniques

Some common techniques used in efficient finetuning:
  • mixed precision
  • parameter-efficient finetuning (PEFT)
  • flash-attention
  • gradient / activation checkpointing
  • quantization
  • gradient accumulation

© Lifan Sun 2023 - 2025