Introduction
The history of computing has been punctuated by optimizations that look obvious in hindsight yet went unnoticed for years.
FlashAttention (2022) is one such advance. While many researchers focused on cutting FLOPs through approximations, Tri Dao and the FlashAttention team identified the real bottleneck: unnecessary memory transfers between GPU HBM and SRAM. By applying classic techniques such as kernel fusion and tiling, FlashAttention dramatically reduced attention computation time without the loss of accuracy that approximation methods incur.
Subsequently, FlashAttention-2 (2023) refined this hardware-aware, IO-aware strategy further, roughly doubling the speed of its predecessor.
This article will delve into how FlashAttention-2 advanced FlashAttention. The key modifications to the algorithm include:
- Minimizing non-matmul FLOPs to sustain high throughput
- Modifying work distribution among warps to lower shared memory accesses
- Enhancing occupancy through improved parallelization
Prerequisites
Before proceeding, it’s recommended to review the previous article on FlashAttention. Familiarity with GPU performance optimization, knowledge of the GPU memory hierarchy, and understanding warps could be beneficial.
Reduce Non-Matmul FLOPs to Maintain High Throughput
Maintaining a high throughput—the rate at which a GPU processes data—is essential for managing increased workloads. To achieve this, applications must be designed to effectively harness computational resources.
NVIDIA GPUs include specialized processing units called Tensor Cores that dramatically accelerate matrix multiplication. Floating point operations that are not matrix multiplications (non-matmul FLOPs) cannot use these units, so they run far more slowly: on an A100, for example, peak FP16 Tensor Core throughput is 312 TFLOPS, while peak non-matmul FP32 throughput is only 19.5 TFLOPS, roughly a 16x gap. Cutting down on non-matmul work therefore keeps the GPU closer to its peak throughput.
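As a rough back-of-the-envelope illustration (not a benchmark), the snippet below compares how long the same number of FLOPs would take at the A100's published peak Tensor Core throughput versus its peak non-matmul FP32 throughput. Real kernels never reach these peaks, but the roughly 16x gap is what makes non-matmul work disproportionately expensive.

```python
# Back-of-the-envelope cost of the same number of FLOPs on an A100,
# using NVIDIA's published peak throughputs (real kernels never reach peak).
MATMUL_TFLOPS = 312.0      # FP16/BF16 Tensor Core matmul throughput
NON_MATMUL_TFLOPS = 19.5   # FP32 non-matmul (CUDA core) throughput

flops = 1e12  # one trillion floating point operations
matmul_ms = flops / (MATMUL_TFLOPS * 1e12) * 1e3
non_matmul_ms = flops / (NON_MATMUL_TFLOPS * 1e12) * 1e3
print(f"as matmul FLOPs:     {matmul_ms:.1f} ms")
print(f"as non-matmul FLOPs: {non_matmul_ms:.1f} ms")
print(f"penalty: ~{MATMUL_TFLOPS / NON_MATMUL_TFLOPS:.0f}x")
```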
Fused Attention vs PyTorch Attention
This illustration demonstrates that a significant portion of attention computation time is consumed by non-matmul operations.
Modifying the Online Softmax Trick to Reduce Non-Matmul FLOPs
FlashAttention-2 lowers non-matmul FLOPs by reworking parts of the computation in ways that leave the final output unchanged, chiefly by refining how the online softmax is computed.
For a more comprehensive understanding, the blog post by Aleksa Gordić and the article by Zihao Ye effectively break down the algorithm’s intricacies.
In FlashAttention, the softmax is computed block by block while two running statistics are tracked: the maximum score m and the sum of exponentials ℓ. FlashAttention rescales the output accumulator on every iteration so that it stays correctly normalized. FlashAttention-2 instead keeps an unscaled version of the output throughout the loop and divides by the final ℓ only once at the end, producing the same result with fewer non-matmul operations.

In addition, for the backward pass FlashAttention-2 stores only the logsumexp L^(j) = m^(j) + log(ℓ^(j)) rather than both the maximum m^(j) and the sum of exponentials ℓ^(j), which trims the extra memory and bookkeeping.
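To make the bookkeeping concrete, here is a minimal NumPy sketch of the block-by-block (online) softmax for a single query, written in the FlashAttention-2 style: the output accumulator is left unscaled inside the loop, divided by the running sum only once at the end, and only the logsumexp is kept for the backward pass. The function and variable names are illustrative rather than taken from the official implementation, and real kernels process whole blocks of queries on-chip.

```python
import numpy as np

def online_softmax_attention(q, K, V, block_size=64):
    """Attention for a single query vector, computed block by block over K/V.

    FlashAttention-2 style bookkeeping: the output accumulator stays unscaled
    inside the loop and is divided by the running sum of exponentials only
    once at the end; only the logsumexp is kept for the backward pass.
    """
    d = q.shape[0]
    m = -np.inf            # running maximum of the scores
    l = 0.0                # running sum of exponentials
    o = np.zeros(d)        # unscaled output accumulator

    for start in range(0, K.shape[0], block_size):
        k_blk = K[start:start + block_size]
        v_blk = V[start:start + block_size]
        s = k_blk @ q / np.sqrt(d)              # scores for this block

        m_new = max(m, s.max())                 # updated running maximum
        correction = np.exp(m - m_new)          # rescales the old statistics
        p = np.exp(s - m_new)                   # unnormalized block probabilities

        l = l * correction + p.sum()
        o = o * correction + p @ v_blk          # note: no division inside the loop
        m = m_new

    out = o / l                                 # single final rescale
    logsumexp = m + np.log(l)                   # only statistic saved for backward
    return out, logsumexp

# Sanity check against a straightforward softmax implementation.
rng = np.random.default_rng(0)
q = rng.normal(size=64)
K = rng.normal(size=(256, 64))
V = rng.normal(size=(256, 64))
scores = K @ q / np.sqrt(64)
reference = (np.exp(scores - scores.max()) / np.exp(scores - scores.max()).sum()) @ V
output, _ = online_softmax_attention(q, K, V)
assert np.allclose(output, reference)
```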
Adjusting How Work is Partitioned Among Warps to Reduce Shared Memory Access
Recall that a thread comprises the program's code, its current point of execution, and its variables and data structures. Threads are grouped into thread blocks and executed on a streaming multiprocessor, which runs many threads concurrently.
Warp-level thread management becomes feasible due to NVIDIA’s single instruction multiple thread (SIMT) model, where one instruction operates across multiple threads within a 32-thread warp. Threads in a warp can collaborate to conduct tasks like matrix multiplication and can also communicate by accessing shared memory.
Distributing tasks to warps involves dividing substantial computations into smaller, manageable tasks for concurrent execution. Suboptimal task assignments can lead to repeated access to shared memory. Therefore, FlashAttention-2 aims to optimize shared memory access through the strategic division of attention computation among warps.
| | FlashAttention (split-K) | FlashAttention-2 (split-Q) |
|---|---|---|
| Which matrix/matrices are split among the 4 warps? | K and V | Q |
| Which matrix/matrices are accessible to all 4 warps? | Q | K and V |
| How is QK^T computed? | Each warp multiplies Q with its slice of K^T, producing a partial result that must later be combined with the other warps' results. | Each warp multiplies its slice of Q with K^T, producing its own slice of QK^T independently. |
| Is synchronization and communication between warps necessary? | Yes: all 4 warps write their intermediate results to shared memory, synchronize, and add the intermediate results together. | No communication or synchronization is needed in the forward pass; each warp multiplies its slice directly with V to produce its slice of the output. The backward pass still requires some synchronization to handle the more complex dependencies among inputs and gradients. |
| What does this mean for speed? | The forward pass is slowed by the extra shared memory reads and writes. | Splitting Q while sharing K^T and V among warps removes the shared memory traffic between warps, speeding up both the forward and backward passes compared to FlashAttention. |
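The difference is easiest to see in a toy simulation. The NumPy sketch below is not how the CUDA kernels are written; it simply treats each "warp" as a slice of the work to show why split-Q needs no cross-warp reduction while split-K does (the split-K branch omits the max-subtraction normally used for numerical stability).

```python
import numpy as np

def attention(q_rows, K, V):
    """Plain softmax attention for a block of query rows."""
    s = q_rows @ K.T / np.sqrt(K.shape[1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(3, 128, 64))   # unpacks into three (128, 64) matrices
NUM_WARPS = 4

# Split-Q (FlashAttention-2): each "warp" owns a slice of Q rows and computes
# its slice of the output end to end; the slices are simply stacked, so no
# cross-warp reduction (and no shared memory round trip) is needed.
split_q_out = np.concatenate(
    [attention(q_slice, K, V) for q_slice in np.array_split(Q, NUM_WARPS)]
)

# Split-K (FlashAttention): each "warp" owns a slice of K and V, so every warp
# produces a partial, unnormalized result for *all* query rows. The partial
# numerators and denominators must be summed across warps, which on the GPU
# means writing intermediates to shared memory and synchronizing.
partial_num = np.zeros_like(Q)                 # unnormalized partial outputs
partial_den = np.zeros((Q.shape[0], 1))        # partial softmax denominators
for k_slice, v_slice in zip(np.array_split(K, NUM_WARPS),
                            np.array_split(V, NUM_WARPS)):
    s = Q @ k_slice.T / np.sqrt(K.shape[1])
    p = np.exp(s)                              # max-subtraction omitted for brevity
    partial_num += p @ v_slice                 # cross-warp reduction step
    partial_den += p.sum(axis=-1, keepdims=True)
split_k_out = partial_num / partial_den

assert np.allclose(split_q_out, split_k_out)
```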
Increase Occupancy With More Parallelization
Occupancy is the ratio of the number of warps resident on a streaming multiprocessor to the maximum number it supports. Memory-bound workloads such as attention benefit from higher occupancy because more resident warps help hide memory latency.

The A100 GPU has 108 streaming multiprocessors, and keeping most of them busy requires on the order of 80 or more thread blocks. With too few thread blocks, some streaming multiprocessors sit idle and the GPU's resources go underused.
FlashAttention-2 boosts occupancy through parallelization across sequence length. This approach involves executing independent tasks concurrently across thread blocks without requiring synchronization.
| | FlashAttention (and standard multi-head attention) parallelizes over: | FlashAttention-2 parallelizes over: |
|---|---|---|
| Batch size: the number of input sequences in a batch | ✔️ | ✔️ |
| Number of heads: the number of attention heads | ✔️ | ✔️ |
| Sequence length: the number of tokens in an input sequence | | ✔️ |
FlashAttention parallelizes over batch size and number of heads, so the number of thread blocks equals batch size × number of heads. FlashAttention-2 adds the sequence length as a third dimension, so the number of thread blocks becomes batch size × number of heads × the number of blocks along the sequence length.
Why Bother Parallelizing Over Sequence Length?
A longer sequence length usually forces a smaller batch size, since fewer input sequences fit in GPU memory at once.

For FlashAttention this means fewer active thread blocks, because the count is capped at batch size × number of heads. By also parallelizing over the sequence length, FlashAttention-2 launches batch size × number of heads × sequence-length blocks, which keeps far more of the GPU's streaming multiprocessors busy.
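The numbers below are illustrative (a hypothetical block size of 128 query rows and a small-batch, long-sequence workload), but they show why parallelizing over the sequence dimension matters for keeping an A100's 108 streaming multiprocessors busy.

```python
from math import ceil

NUM_SMS = 108          # streaming multiprocessors on an A100
Q_BLOCK_ROWS = 128     # query rows handled per thread block (illustrative choice)

def thread_blocks(batch_size, num_heads, seq_len, parallelize_seq):
    """Rough count of independent thread blocks the GPU scheduler can place."""
    blocks = batch_size * num_heads
    if parallelize_seq:
        blocks *= ceil(seq_len / Q_BLOCK_ROWS)
    return blocks

# Long-sequence, small-batch regime: 2 sequences of 8192 tokens, 16 heads.
fa1 = thread_blocks(2, 16, 8192, parallelize_seq=False)   # 32 blocks
fa2 = thread_blocks(2, 16, 8192, parallelize_seq=True)    # 32 * 64 = 2048 blocks
print(f"FlashAttention:   {fa1:>5} thread blocks for {NUM_SMS} SMs (many SMs idle)")
print(f"FlashAttention-2: {fa2:>5} thread blocks for {NUM_SMS} SMs (all SMs busy)")
```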
Loop Reversal
This illustration shows the loop order (red arrows): FlashAttention's outer loop loads blocks of K and V into SRAM, while its inner loop loads blocks of Q and writes the attention output back to HBM. FlashAttention-2 reverses the two loops.
| Loop | FlashAttention | FlashAttention-2 |
|---|---|---|
| Outer | Over K, V blocks | Over Q blocks |
| Inner | Over Q blocks | Over K, V blocks |
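The following schematic Python (illustrative pseudocode in runnable form, not the actual CUDA kernels) shows the two loop orders side by side and why the swap matters: in FlashAttention-2 each outer iteration over a Q block is independent of the others, so it can be handed to its own thread block.

```python
# Schematic loop orders in runnable form (illustrative, not the actual kernels).
NUM_Q_BLOCKS, NUM_KV_BLOCKS = 4, 4

def update_output(i, j):
    # Stand-in for: load block j of K/V and block i of Q, update O_i, m_i, l_i.
    print(f"  Q block {i} x K/V block {j}")

# FlashAttention: outer loop over K/V blocks, inner loop over Q blocks.
# Every K/V block touches every Q block's running output, which lives in HBM.
print("FlashAttention loop order:")
for j in range(NUM_KV_BLOCKS):
    for i in range(NUM_Q_BLOCKS):
        update_output(i, j)

# FlashAttention-2: outer loop over Q blocks, inner loop over K/V blocks.
# Each outer iteration is independent of the others, so each Q block can be
# assigned to its own thread block; this is the sequence-length parallelism.
print("FlashAttention-2 loop order:")
for i in range(NUM_Q_BLOCKS):
    for j in range(NUM_KV_BLOCKS):
        update_output(i, j)
```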
Phil Tillet was instrumental in developing and executing these optimizations of loop order reversal and sequence length parallelization in Triton.
This figure demonstrates the forward and backward passes of the parallelization approach, depicting worker thread blocks.
Forward pass: Each thread block processes a block of rows from the attention matrix.
Backward pass: Each thread block handles a block of columns from the attention matrix.
Conclusion
In summary, FlashAttention-2 improved on FlashAttention by reducing non-matmul FLOPs to sustain high throughput, partitioning work among warps to minimize shared memory traffic, and parallelizing over the sequence length to increase occupancy.

The success of FlashAttention and FlashAttention-2 shows that working with the hardware, rather than against it, can yield outstanding results. Deeply understanding the systems we run on, instead of treating them as black boxes, is what makes advances like these possible.