Distributed Training in PyTorch
Overview of some options available for multi-node training in PyTorch
Contents
Distributed Training
Why?
Need more compute power to process large batches in parallel (DDP)
- Uses collective communication
A large model that doesn’t fit in the memory of a single GPU (RPC)
- Uses P2P communication
All of the above XD
DDP in PyTorch
- Every GPU has a model replica, controlled by a process.
- Every process fetches a different batch of data.
- Forward pass runs independently on each replica.
- Backward pass: gradient computation overlaps with gradient communication (gradients are all-reduced; model states are broadcast from rank 0 when DDP is constructed).
- Validation
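A minimal sketch of one DDP training epoch, assuming model has already been wrapped in DistributedDataParallel, loader uses a DistributedSampler, and optimizer / criterion / num_epochs are defined (the setup is covered in the recipe below):

    # One epoch with a DDP-wrapped model (setup shown in the recipe below).
    for epoch in range(num_epochs):
        loader.sampler.set_epoch(epoch)        # reshuffle so shards differ every epoch
        for inputs, targets in loader:         # every process sees a different batch
            inputs, targets = inputs.cuda(), targets.cuda()
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)  # forward on the local replica
            loss.backward()                    # all-reduce of gradients overlaps with this
            optimizer.step()                   # identical update on every replica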
A 4-step recipe for Distributed Training
Initialize Distributed Group
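A minimal sketch of the initialization step, assuming the launcher (e.g. torchrun, see below) has already set RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT, and LOCAL_RANK in the environment:

    import os
    import torch
    import torch.distributed as dist

    def init_distributed():
        # NCCL is the usual backend for multi-GPU training; the default env://
        # init method reads rank / world size / master address from the environment.
        dist.init_process_group(backend="nccl")
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)  # bind this process to its own GPU
        return local_rank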
Data
- Local Training
- Distributed Training
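A sketch of the data step, with dataset and the batch size as placeholders; the only distributed-specific change is the DistributedSampler, which gives every process its own shard of the data:

    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler

    # Local training:
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Distributed training: the sampler shards the dataset across processes,
    # so shuffling is handled by the sampler instead of the DataLoader.
    sampler = DistributedSampler(dataset, shuffle=True)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)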
Model
- Local Training
- Distributed Training
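A sketch of the model step; MyModel is a placeholder and local_rank comes from the initialization step above. The only distributed-specific change is the DDP wrapper:

    from torch.nn.parallel import DistributedDataParallel as DDP

    # Local training:
    model = MyModel().cuda()

    # Distributed training: wrap the local replica so that gradients are
    # all-reduced across processes during backward().
    model = MyModel().to(local_rank)
    model = DDP(model, device_ids=[local_rank])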
Saving and Logging
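A sketch of the last step, following the usual convention of saving and logging from rank 0 only so the processes don’t overwrite each other; the checkpoint path and the logged values are placeholders:

    import torch
    import torch.distributed as dist

    if dist.get_rank() == 0:
        # model.module is the underlying (unwrapped) model inside DDP.
        torch.save(model.module.state_dict(), "checkpoint.pt")
        print(f"epoch {epoch}: loss = {loss.item():.4f}")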
DP vs DDP
DP:
- Can’t scale to multiple machines
- Single Process, multiple threads
- Module is replicated on each device, and the gradients are all summed into the original module
- Doesn’t give the best performance, because of Python’s GIL contention in multi-threaded applications
DDP:
- Can be used for single-machine multi-GPU training, or for multi-node training
- It launches one process per device (e.g. 2 nodes with 4 GPUs each = 8 processes)
- Gradients are synchronized across processes with the “all_reduce” operation
It’s advised to use DDP for any distributed training
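For comparison, the two wrappers side by side (a sketch; model and local_rank as in the recipe above). DP is a one-liner inside a single process, while DDP expects one process per GPU:

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    # DP: single process, one thread per GPU, gradients summed on the default device.
    dp_model = nn.DataParallel(model)

    # DDP: one process per GPU, gradients all-reduced across processes.
    ddp_model = DDP(model, device_ids=[local_rank])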
torch.distributed.launch vs torchrun
torch.distributed.launch:
- We need to run the script on every node, with the correct ranks
- We have to pass it all the necessary environment variables ourselves
Example:
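A sketch of the kind of command this involves, for 2 nodes with 4 GPUs each; the master address, port, and script name are placeholders, and the command has to be repeated on node 1 with --node_rank=1:

    # On node 0 (repeat on node 1 with --node_rank=1):
    python -m torch.distributed.launch \
        --nproc_per_node=4 --nnodes=2 --node_rank=0 \
        --master_addr=10.0.0.1 --master_port=29500 \
        train.py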
Watch out for “&” when launching on several nodes (without it, the first launch blocks the shell)
torchrun:
- We run the same command on every node; torchrun assigns the ranks automatically through its rendezvous mechanism
- It sets all the required environment variables (RANK, LOCAL_RANK, WORLD_SIZE, ...), and we can use them directly in the code
Example:
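A sketch of the equivalent torchrun command (same placeholder address, port, and script); the same line is used on every node, and torchrun exports RANK, LOCAL_RANK, and WORLD_SIZE to the training script:

    torchrun --nproc_per_node=4 --nnodes=2 \
        --rdzv_backend=c10d --rdzv_endpoint=10.0.0.1:29500 \
        train.py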
Overall, torchrun removes a lot of the mundane setup steps.
Hugging Face
Option 1: using the HF Trainer
- You can use the HF Trainer, but in this case you still need to launch the training script on every node yourself (torch.distributed.launch in a loop over the nodes, or torchrun on each node)
Option 2: using HF Accelerate
- You just need to launch the script with the accelerate library (accelerate launch).
- You have to write the training loop yourself, and can’t use the HF Trainer in that case.
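A minimal sketch of option 2; Accelerator, prepare, and backward are real Accelerate APIs, while model, optimizer, dataloader, and loss_fn are placeholders. The script would then be started with the accelerate CLI (accelerate launch train.py) on each node:

    from accelerate import Accelerator

    accelerator = Accelerator()
    # prepare() moves everything to the right device and wraps the model for DDP.
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    for inputs, targets in dataloader:
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        accelerator.backward(loss)  # used instead of loss.backward()
        optimizer.step()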
PyTorch Lightning
- PyTorch Lightning is the easiest way to run distributed training
- We pass the number of nodes and the number of GPUs per node to the trainer
- It launches your script internally multiple times (once per process) with the correct environment variables
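A sketch of the Lightning setup for 2 nodes with 4 GPUs each, assuming LitModel is a LightningModule that defines its own dataloaders; devices, num_nodes, and strategy are real Trainer arguments in recent Lightning releases:

    import pytorch_lightning as pl

    trainer = pl.Trainer(
        accelerator="gpu",
        devices=4,        # GPUs per node
        num_nodes=2,
        strategy="ddp",   # DistributedDataParallel under the hood
    )
    trainer.fit(LitModel())  # assumes LitModel defines train_dataloader()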
Common errors:
state_dict key names issue:
Problem: when we wrap the model with DDP, PyTorch prefixes every key in the state_dict with “module.”
Solution: add a helper function that detects whether the checkpoint comes from a DDP-wrapped model, and adds or strips the “module.” prefix from all keys accordingly (see the sketch below).
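A sketch of such a helper (strip_module_prefix is a hypothetical name, not a PyTorch API) for loading a checkpoint saved from a DDP-wrapped model into a plain, unwrapped model:

    import torch

    def strip_module_prefix(state_dict):
        # Remove the "module." prefix that DDP adds to every state_dict key.
        return {k[len("module."):] if k.startswith("module.") else k: v
                for k, v in state_dict.items()}

    # Usage (paths are placeholders):
    # state_dict = torch.load("checkpoint.pt", map_location="cpu")
    # model.load_state_dict(strip_module_prefix(state_dict))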