Attention is All You Need
If you weren’t already aware, the T in ChatGPT stands for Transformer, a crucial framework for developing advanced AI. Originally engineered for machine translation, the Transformer is a neural network architecture introduced in the paper Attention Is All You Need, the first to rely entirely on self-attention. It employs layers of interconnected nodes that build an internal mathematical representation of relationships and relevance, converting an input sequence into an output sequence.
If Attention is All You Need, Let’s Make it Better…
The emergence of the Transformer architecture opened a new chapter in AI research, concentrating on improving the efficiency of its fundamental mechanism: attention. The scalability of attention is hindered by its time and memory complexity, which grows quadratically, O(n^2), with the input sequence length n. This poses a significant challenge, since modeling long sequences is essential for capturing the long-range dependencies needed to interpret lengthy texts, codebases, and high-resolution images. To address it, researchers have developed hardware-aware and memory-efficient algorithms such as FlashAttention.
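To make that quadratic cost concrete, here is a quick back-of-the-envelope calculation; the sequence length, head dimension, and FP16 storage below are illustrative assumptions rather than figures from the paper.

```python
# Rough memory cost of materializing the N x N attention matrix.
# N = 8192, d = 64, 2-byte FP16 elements are illustrative assumptions.
N = 8192            # sequence length
d = 64              # head dimension
bytes_per_elem = 2  # FP16

qkv_bytes = 3 * N * d * bytes_per_elem  # Q, K, V together
scores_bytes = N * N * bytes_per_elem   # the intermediate N x N similarity matrix

print(f"Q, K, V combined: {qkv_bytes / 2**20:.1f} MiB")   # ~3 MiB
print(f"N x N scores:     {scores_bytes / 2**20:.1f} MiB") # ~128 MiB, growing quadratically with N
```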
Introduction
This article aims to showcase the concepts that contributed to the success of FlashAttention (2022) in achieving a wall-clock speedup over traditional attention mechanisms. Techniques utilized in its second (2023) and third (2024) iterations will be discussed in future blog posts.
Prerequisites
A working knowledge of the following topics will aid in comprehending the material discussed in this article:
- Understanding the Transformer and the Attention Mechanism
- Matrix Multiplication
- Softmax Operation
- Forward and Backward Propagation
- The GPU Memory Hierarchy
- CUDA Programming Concepts
- Floating Point Formats (FP16, BF16, FP8)
Designing Hardware-Aware and Memory-Efficient Algorithms
Modern accelerators, including Hopper and Ampere GPUs, boast a high number of floating-point operations per second (FLOPS), a metric reflecting a device’s computational power. However, these accelerators are often limited by memory bandwidth: the speed at which data moves between the GPU’s memory and its processing units. Creating hardware-aware and memory-efficient algorithms for GPUs therefore requires careful planning of how to exploit the memory hierarchy so that computation can approach the theoretical maximum FLOPS.
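As a rough illustration of why bandwidth matters, the sketch below estimates the arithmetic intensity (FLOPs per byte of memory traffic) a kernel needs before compute, rather than memory, becomes the bottleneck. The peak-FLOPS and bandwidth figures are approximate A100 numbers used purely as assumptions.

```python
# Rough roofline estimate: FLOPs per byte of HBM traffic needed before compute,
# rather than memory bandwidth, becomes the bottleneck.
# The hardware numbers below are approximate A100 figures, used only as assumptions.
peak_flops = 312e12     # ~312 TFLOPS (FP16 tensor cores)
hbm_bandwidth = 1.5e12  # ~1.5 TB/s HBM bandwidth

ridge_point = peak_flops / hbm_bandwidth  # FLOPs per byte needed to saturate compute
print(f"Compute-bound above ~{ridge_point:.0f} FLOPs per byte of HBM traffic")

# A kernel that streams data from HBM, performs a handful of FLOPs per element,
# and writes the result back (e.g., a softmax over an N x N matrix) falls far
# below this ridge point, so its runtime is dominated by memory traffic.
```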
FlashAttention exemplifies a hardware-aware and memory-efficient algorithm that enhances the context length in Transformers by optimizing the attention mechanism based on the hardware used for computation.
FlashAttention (2022)
FlashAttention is described as an “IO-aware exact attention algorithm that utilizes tiling to minimize data reads/writes between the GPU’s high bandwidth memory (HBM) and on-chip SRAM.”
Let’s elaborate on that.
GPU Memory: HBM & SRAM
The terminology around GPU memory types can be perplexing, with various terms describing overlapping ideas. FlashAttention makes use of two of them: HBM (high bandwidth memory), the GPU’s large but comparatively slow main memory, and SRAM, the much smaller but far faster on-chip memory that sits close to the compute units.
GPU Compute Model
Understanding data transfer within the GPU is crucial.
- Input begins in HBM (GPU Memory)
- Data transitions to compute units & SRAM for processing
- Output returns to HBM
Computing Attention
The self-attention calculation can be illustrated in matrix form, following The Illustrated Transformer by Jay Alammar.
The Attention Line-up
Here’s a summary of the variables essential for computing the self-attention layer of the transformer.
Query (Q): This vector represents the current input for which attention calculations are performed. It is part of a query matrix of size Nxd, where N indicates the sequence length (typically ranging from 1K to 8K) and d represents the head dimension (typically 64-128).
Key (K): The key matrix shares the same dimensions as the query matrix. The interaction of key vectors with query vectors yields the similarity score.
Similarity Score (S): This score measures the similarity between the query and each sequence element. By multiplying the query matrix with the transposed key matrix, an NxN similarity score matrix is generated.
Attention Probability (P in algorithm, A in diagram): This probability distribution is derived by applying softmax to the similarity scores. The softmax function normalizes similarity scores to ensure positivity and that their sum totals one.
Note that S and P/A are intermediate matrices and are not depicted in the attention formula.
Value (V): The value vectors from the Nxd value matrix contain data about each sequence element and combine with attention probabilities to produce the Nxd output.
The standard attention algorithm, as illustrated in the FlashAttention paper, loads Q and K from HBM to compute S, writes S back to HBM, reads S again to apply the softmax, and writes the resulting P back to HBM before multiplying it with V to produce the output. These repeated reads and writes of large intermediate matrices to and from HBM are the most time-consuming part of the computation.
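For reference, here is a minimal NumPy sketch of the standard attention computation described above, written to make the intermediate N x N matrices S and P explicit; the 1/sqrt(d) scaling follows the original Transformer formulation, and in a real implementation these intermediates would round-trip through HBM.

```python
import numpy as np

def standard_attention(Q, K, V):
    """Naive attention: materializes the full N x N matrices S and P.

    Q, K, V: arrays of shape (N, d). Returns O of shape (N, d).
    """
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                        # similarity scores, shape (N, N)
    P = np.exp(S - S.max(axis=-1, keepdims=True))   # numerically stable softmax...
    P = P / P.sum(axis=-1, keepdims=True)           # ...row-wise attention probabilities
    O = P @ V                                       # output, shape (N, d)
    return O

# Tiny usage example with random data (N = 8, d = 4).
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((8, 4)) for _ in range(3))
print(standard_attention(Q, K, V).shape)  # (8, 4)
```

Every element of S and P is produced, stored, and re-read, which is exactly the HBM traffic FlashAttention is designed to avoid.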
FlashAttention is IO-aware
Having established that the standard attention implementation lacks IO-awareness due to redundant data transfers to and from slow GPU memory (HBM), we will explore the challenges FlashAttention overcame to achieve IO-awareness.
Kernel Fusion
FlashAttention enhances performance by fusing the attention computations into a single CUDA kernel. Although kernel fusion may appear uncomplicated, the algorithm required careful design to ensure that on-chip memory utilization adheres to hardware limits.
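As an illustration (assuming PyTorch 2.x and a CUDA-capable GPU; this framework is not part of the paper itself), the unfused version below launches several kernels and materializes the intermediate matrices, while torch.nn.functional.scaled_dot_product_attention can dispatch to a single fused FlashAttention-style kernel.

```python
import math
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)  # (batch, heads, N, d)
k, v = torch.randn_like(q), torch.randn_like(q)

# Unfused: each line is at least one kernel, and S and P are written out to GPU memory.
S = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
P = S.softmax(dim=-1)
out_unfused = P @ v

# Fused: a single call that can dispatch to one FlashAttention-style kernel,
# avoiding materialization of the N x N intermediates in HBM.
out_fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(out_unfused, out_fused, atol=1e-2))
```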
Tiling
Tiling is a method that segments data into smaller blocks, or “tiles,” that fit onto on-chip memory. Thanks to tiling-assisted kernel fusion, bandwidth demands are reduced as data only transfers from global memory to streaming multiprocessors once per tile.
Tiling works particularly well for associative tasks like matrix multiplication, permitting computations to be reordered without altering outcomes, thus allowing effective processing of smaller tiles. However, as the softmax operation in self-attention is non-associative, the order of computations becomes significant.
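To make this concrete, here is a minimal NumPy sketch of blocked (tiled) matrix multiplication; because each output tile simply accumulates partial products, the tiles can be processed in any order and still yield the same result. The tile size and matrix shapes are illustrative assumptions.

```python
import numpy as np

def tiled_matmul(A, B, tile=64):
    """Blocked matrix multiply: C = A @ B computed one tile at a time.

    Each (i, j) output tile accumulates partial products over k-tiles; because
    addition is associative, the k-tiles can be processed in any order.
    """
    M, K = A.shape
    K2, N = B.shape
    assert K == K2
    C = np.zeros((M, N), dtype=A.dtype)
    for i in range(0, M, tile):
        for j in range(0, N, tile):
            for k in range(0, K, tile):
                C[i:i+tile, j:j+tile] += A[i:i+tile, k:k+tile] @ B[k:k+tile, j:j+tile]
    return C

A = np.random.rand(256, 128)
B = np.random.rand(128, 192)
print(np.allclose(tiled_matmul(A, B), A @ B))  # True: same result as the unblocked product
```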
Making Softmax Associative
Employing the online softmax trick to establish softmax associativity is a pivotal innovation of FlashAttention. The attention computation is reshaped to allow incremental softmax computation over blocks of Q, K, and V, eliminating the need to store intermediate matrices in HBM; these computations occur in SRAM instead.
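The sketch below illustrates the online softmax idea on a single row of scores, assumed here to be a plain NumPy array: the row is processed block by block while a running maximum and running sum are maintained, and previously accumulated values are rescaled whenever the maximum changes. It is a simplified illustration of the trick, not the actual FlashAttention kernel.

```python
import numpy as np

def online_softmax(scores, block=4):
    """Softmax over a 1-D array computed one block at a time.

    Maintains a running maximum m and a running sum s of exp(x - m); whenever a
    new block raises the maximum, the previously accumulated sum is rescaled.
    """
    m = -np.inf   # running max
    s = 0.0       # running sum of exponentials, relative to the current max
    for start in range(0, len(scores), block):
        x = scores[start:start + block]
        m_new = max(m, x.max())
        s = s * np.exp(m - m_new) + np.exp(x - m_new).sum()  # rescale old sum, add new block
        m = m_new
    # With the final m and s known, each element's probability is exp(x_i - m) / s.
    return np.exp(scores - m) / s

row = np.random.randn(10)
reference = np.exp(row - row.max()) / np.exp(row - row.max()).sum()
print(np.allclose(online_softmax(row), reference))  # True
```

In FlashAttention, the same rescaling is applied to the partially accumulated output, so the softmax and the multiplication with V can both proceed block by block without ever materializing a full row of P.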
Recomputation in the Backward Pass
FlashAttention avoids redundant read/write operations by forgoing the storage of the intermediate S and P/A matrices and instead recomputing them during the backward pass. This involves retaining the output O and the softmax normalization statistics, from which the intermediate matrices are recomputed in SRAM as needed.
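As a simplified sketch of the idea (not the paper’s exact backward pass), the forward function below stores only the output O and the per-row log-sum-exp statistics L, which is enough to rebuild P from Q and K when the backward pass needs it.

```python
import numpy as np

def forward(Q, K, V):
    """Forward pass that stores only O and the per-row log-sum-exp L, not S or P."""
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)
    m = S.max(axis=-1, keepdims=True)
    L = m + np.log(np.exp(S - m).sum(axis=-1, keepdims=True))  # log-sum-exp per row
    P = np.exp(S - L)
    return P @ V, L

def recompute_P(Q, K, L):
    """Backward-pass helper: rebuild P from Q, K and the saved statistics L."""
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)  # recomputed in fast on-chip memory in FlashAttention
    return np.exp(S - L)

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((8, 4)) for _ in range(3))
O, L = forward(Q, K, V)
print(np.allclose(recompute_P(Q, K, L) @ V, O))  # True: rebuilt P reproduces the stored output
```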
Conclusion
Through skillful reordering of attention computations using classical techniques like tiling and recomputation, FlashAttention significantly accelerated the attention mechanism while reducing memory usage from quadratic to linear in the sequence length. The algorithm underscores a compelling blend of artistry and efficiency in hardware-aware algorithm design.