Facebook AI Introduces Fully Sharded Data Parallel (FSDP) Algorithm That Makes Training Large AI Models Easier Using GPUs

Training is a crucial step in developing an accurate AI model. Beyond requiring significant computing and engineering resources, most methods for scaling training to large models add communication overhead and a high level of engineering complexity.

Furthermore, engineers need to weigh trade-offs between memory requirements and computational efficiency. For example, typical data-parallel training maintains a redundant copy of the model on each GPU, whereas model-parallel training incurs additional communication costs to move activations between workers (GPUs).

Facebook has introduced Fully Sharded Data Parallel (FSDP) to make training large AI models easier. FSDP is a data-parallel training approach that shards the model's parameters across data-parallel workers and can optionally offload the sharded parameters to CPUs.

Even though the parameters are sharded across GPUs, the computation for each micro-batch of data remains local to each GPU worker. Compared with other approaches, such as optimizer state + gradient sharding data-parallel methods, FSDP shards parameters more uniformly and can achieve higher performance by overlapping communication and computation during training.
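To make this concrete, here is a toy, pure-Python simulation (not FairScale's implementation) of one worker's forward pass under full parameter sharding. Each simulated rank persistently stores only a contiguous shard of every layer's flattened weights; the worker temporarily gathers a layer's full weights, applies them to its own micro-batch, and then discards the gathered copy. The helpers `shard`, `gather_full_weight`, `matvec`, and `sharded_forward`, and the use of plain lists as "GPU memory", are hypothetical choices for illustration.

```python
from typing import List, Sequence, Tuple


def shard(flat: List[float], world_size: int) -> List[List[float]]:
    """Split a flattened weight into equal contiguous shards, one per rank.

    Assumes len(flat) is divisible by world_size (real implementations typically pad).
    """
    size = len(flat) // world_size
    return [flat[r * size:(r + 1) * size] for r in range(world_size)]


def gather_full_weight(shards: Sequence[List[float]]) -> List[float]:
    """Simulated all-gather: concatenate every rank's shard into the full weight."""
    return [w for s in shards for w in s]


def matvec(flat_w: List[float], rows: int, cols: int, x: List[float]) -> List[float]:
    """Compute y = W x for a row-major flattened weight W of shape (rows, cols)."""
    return [sum(flat_w[i * cols + j] * x[j] for j in range(cols)) for i in range(rows)]


def sharded_forward(layer_shards: Sequence[Sequence[List[float]]],
                    shapes: Sequence[Tuple[int, int]],
                    micro_batch: List[float]) -> List[float]:
    """Run one worker's micro-batch through every layer.

    layer_shards[l][r] is rank r's persistent shard of layer l's weights;
    shapes[l] is that layer's (rows, cols).
    """
    x = micro_batch
    for shards, (rows, cols) in zip(layer_shards, shapes):
        full_w = gather_full_weight(shards)  # materialize this layer only
        x = matvec(full_w, rows, cols, x)    # computation stays on this worker
        del full_w                           # free the gathered copy; shards remain
    return x
```

With two ranks and two 2x2 layers, each rank stores only half of each layer's weights, yet one worker can process the micro-batch `[1.0, 0.0]` while another processes `[0.0, 1.0]`, each producing the same output an unsharded model would.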

The high computational cost of large-scale training

Last year, OpenAI trained GPT-3, which at 175 billion parameters was the largest neural language model ever built. Training GPT-3 is estimated to have taken about 355 GPU-years, the equivalent of 1,000 GPUs working continuously for more than four months. With FSDP, it is now possible to train larger models more efficiently using fewer GPUs.
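Converting the GPU-year estimate into wall-clock time on a 1,000-GPU cluster is a single division, assuming every GPU is used in parallel without interruption:

```latex
\frac{355\ \text{GPU-years}}{1000\ \text{GPUs}} = 0.355\ \text{years} = 0.355 \times 12\ \text{months} \approx 4.3\ \text{months}
```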

FSDP saves memory by sharding model parameters, gradients, and optimizer states across GPUs. It also improves computational efficiency by decomposing communication and overlapping it with both the forward and backward passes.
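The benefit of overlapping can be illustrated with a toy timing model. This is an assumption for illustration, not a measurement of FSDP: each layer needs a communication step (its all-gather) of duration `comm[i]` followed by a compute step of duration `compute[i]`, in arbitrary time units. Without overlap the two run back to back; with overlap, the next layer's all-gather is prefetched on a separate communication stream while the current layer computes, limited here to one layer of lookahead to bound memory. The function names and the one-layer prefetch policy are hypothetical.

```python
from typing import Sequence


def serial_time(comm: Sequence[float], compute: Sequence[float]) -> float:
    """Total time when each layer's communication and computation run back to back."""
    return sum(comm) + sum(compute)


def overlapped_time(comm: Sequence[float], compute: Sequence[float]) -> float:
    """Total time when layer i's communication is prefetched during layer i-1's compute.

    Communication runs on its own stream, one transfer at a time, and may start
    at most one layer ahead (once layer i-1 has started computing).
    """
    comm_end = 0.0
    compute_end = 0.0
    prev_compute_start = 0.0
    for m, c in zip(comm, compute):
        # Prefetch limit: this transfer starts after the previous transfer
        # and once the previous layer has begun computing.
        comm_start = max(comm_end, prev_compute_start)
        comm_end = comm_start + m
        # A layer computes only after its weights arrive and the prior layer finishes.
        compute_start = max(compute_end, comm_end)
        compute_end = compute_start + c
        prev_compute_start = compute_start
    return compute_end
```

With three layers that each take 1 unit to gather and 2 units to compute, the serial schedule takes 9 units while the overlapped one takes 7, because only the first all-gather is left exposed.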

How FSDP Works

Although distributed data-parallel (DDP) training is widely adopted, it consumes far more GPU memory than necessary because the model weights and optimizer states are replicated on every DDP worker. Full parameter sharding reduces this redundancy: each worker holds only a shard of the model parameters, gradients, and optimizer states for local computation.
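A back-of-the-envelope estimate shows why this matters. The sketch below assumes plain fp32 training with Adam, i.e. 4 bytes each for a parameter, its gradient, and Adam's two moment estimates (16 bytes per parameter), and it ignores activations, buffers, and the temporary full-weight copies that sharding still needs. The function `per_gpu_state_gb` and its default are hypothetical.

```python
def per_gpu_state_gb(num_params: int, world_size: int, sharded: bool,
                     bytes_per_param: int = 16) -> float:
    """Approximate GB of parameter + gradient + optimizer state held by one GPU.

    bytes_per_param=16 assumes fp32 weights (4) + fp32 grads (4) + Adam m and v (8).
    DDP-style replication keeps everything on every worker; full sharding
    splits the same total evenly across world_size workers.
    """
    total_bytes = num_params * bytes_per_param
    per_gpu_bytes = total_bytes / world_size if sharded else total_bytes
    return per_gpu_bytes / 1e9
```

For a 1-billion-parameter model on 8 GPUs, this gives about 16 GB of state per GPU when it is replicated, versus 2 GB with full sharding.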

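The key to unlocking full parameter sharding is that each all-reduce operation in DDP can be decomposed into separate reduce-scatter and all-gather operations. Gradients are reduce-scattered so that each worker keeps and updates only its own shard, and a layer's full weights are all-gathered from the shards just before that layer is computed. After each layer's forward pass, its full weights can be discarded to free memory for the following layers, maximizing memory efficiency.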

Source: https://engineering.fb.com/2021/07/15/open-source/fsdp/
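The decomposition can be checked with a small pure-Python simulation of the collectives. Each rank's tensor is a Python list, and these helpers are illustrative stand-ins, not the `torch.distributed` API. Reduce-scatter leaves each rank with the sum of one shard, and an all-gather of those shards reproduces exactly what a single all-reduce would have produced on every rank.

```python
from typing import List


def all_reduce(tensors: List[List[float]]) -> List[List[float]]:
    """Every rank receives the element-wise sum of all ranks' tensors."""
    total = [sum(vals) for vals in zip(*tensors)]
    return [list(total) for _ in tensors]


def reduce_scatter(tensors: List[List[float]]) -> List[List[float]]:
    """Rank r receives only the summed r-th shard (length must divide evenly)."""
    world_size = len(tensors)
    shard = len(tensors[0]) // world_size
    total = [sum(vals) for vals in zip(*tensors)]
    return [total[r * shard:(r + 1) * shard] for r in range(world_size)]


def all_gather(shards: List[List[float]]) -> List[List[float]]:
    """Every rank receives the concatenation of all ranks' shards."""
    full = [x for s in shards for x in s]
    return [list(full) for _ in shards]
```

In FSDP, each rank stops after the reduce-scatter and uses its gradient shard to update only its own parameter shard; the all-gather is then applied to the updated parameters just before each layer is next needed, rather than to the gradients.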

FSDP produces the same results as standard distributed data-parallel (DDP) training and comes with an easy-to-use interface that serves as a drop-in replacement for PyTorch's DistributedDataParallel module. The researchers state that it can enable scaling to trillions of parameters.

FSDP is open source and has been implemented in the FairScale library. Engineers and developers can quickly scale and optimize the training of their models with simple APIs. Currently, FSDP supports NLP and vision models with SGD and Adam optimizers. The team plans to generalize FSDP across a wide range of AI algorithms. They also aim to develop algorithms for auto-tuning both GPU memory usage and training performance.


FairScale Library: https://github.com/facebookresearch/fairscale

Tanushree Shenwai is a consulting intern at MarktechPost. She is currently pursuing her B.Tech at the Indian Institute of Technology (IIT), Bhubaneswar. She is a data science enthusiast with a keen interest in the applications of artificial intelligence across various fields. She is passionate about exploring new advancements in technology and their real-life applications.
