Output-approximating methods
Quantization
Quantization is the process of decreasing the numerical precision in which weights and activations are stored, transferred and operated upon. Weights and activations are usually represented as 32-bit floating-point numbers by default; with quantization we can drop the precision to 8-bit or even 4-bit integers.
Quantization can be applied either as an inference-only operation or incorporated into training (referred to as quantization-aware training, or QAT). QAT is generally considered the more resilient approach, as the model can recover some of the quantization-related quality losses during training. To get the best cost/quality tradeoff, we tune the quantization strategy (e.g. selecting different precisions for weights vs. activations) and the granularity at which quantization is applied to tensors (e.g. channel-wise or group-wise).
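As a rough illustration of inference-only quantization with channel-wise granularity, the sketch below quantizes each row of a small weight matrix to 8-bit integers using one scale per row. The symmetric scheme, the function names and the toy weights are assumptions made for this example; production systems typically rely on library or hardware support rather than code like this.

```python
def quantize_row(row, num_bits=8):
    """Symmetric quantization of one channel (row) to signed integers."""
    qmax = 2 ** (num_bits - 1) - 1            # 127 for 8 bits
    max_abs = max(abs(x) for x in row)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    # Round to the nearest integer step and clamp to the representable range.
    q = [max(-qmax, min(qmax, round(x / scale))) for x in row]
    return q, scale


def dequantize_row(q, scale):
    """Map the integers back to approximate floating-point values."""
    return [v * scale for v in q]


def quantize_per_channel(weights, num_bits=8):
    """One scale per row, i.e. channel-wise granularity."""
    return [quantize_row(row, num_bits) for row in weights]


# Toy weights (made-up numbers): two channels with very different magnitudes.
weights = [
    [0.02, -0.01, 0.03, 0.00],
    [1.50, -2.00, 0.75, 0.25],
]
quantized = quantize_per_channel(weights)
restored = [dequantize_row(q, scale) for q, scale in quantized]
max_error = max(
    abs(w - r)
    for row_w, row_r in zip(weights, restored)
    for w, r in zip(row_w, row_r)
)
```

Because each row has its own scale, the small-magnitude channel keeps its resolution instead of being rounded away by the scale of the large-magnitude channel, which is the motivation for finer granularities.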
Quantization has multiple performance benefits:
- it reduces the memory footprint of the model, allowing larger models to fit on the same hardware
- it reduces the communication overhead of weights and activations within one chip and across chips in a distributed inference setup, which speeds up inference because communication is a major contributor to latency
- it can enable faster arithmetic operations, as some accelerator hardware (e.g. TPUs/GPUs) natively supports faster matrix multiplication for some lower-precision representations
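The memory benefit can be made concrete with some back-of-the-envelope arithmetic. The sketch below assumes a hypothetical 7-billion-parameter model and counts weight storage only (no activations, no quantization scales); both the model size and the helper name are illustrative.

```python
def weight_memory_gb(num_params, bits_per_weight):
    """Weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_weight / 8 / 1e9


NUM_PARAMS = 7_000_000_000  # hypothetical model size

# Footprint at 32-bit, 8-bit and 4-bit precision.
footprint = {bits: weight_memory_gb(NUM_PARAMS, bits) for bits in (32, 8, 4)}
for bits, gigabytes in footprint.items():
    print(f"{bits:>2}-bit weights: {gigabytes:.1f} GB")
```

Under these assumptions the weights shrink from 28 GB at 32 bits to 7 GB at 8 bits and 3.5 GB at 4 bits, and the same ratio applies to the bytes that have to be moved between memory tiers or chips.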
Quantization’s impact on quality can be very mild to non-existent, depending on the use case and model. Further, in cases where quantization does introduce a quality regression, that regression can be small compared to the performance gain, allowing for an effective quality vs. latency/cost tradeoff. For example, Benoit Jacob et al. reported a 2x speed-up for a 2% drop in accuracy for the face detection task on MobileNet SSD.
Distillation
Distillation is a set of training techniques that aims to improve the quality of a smaller model (the student) using a larger model (the teacher). This method can be effective because larger models tend to outperform smaller ones even when both are trained on the same data, mainly due to parametric capacity and training dynamics.
Data distillation or model compression is a variant of distillation in which a large model generates additional synthetic data to train the smaller model. The increase in data volume helps move the student further along the quality line compared to training only on the original data. Synthetic data needs to be approached carefully: it must be of high quality, and it can have negative effects otherwise.
Knowledge distillation is an approach in which we try to align the output token distribution of the student model with that of the teacher, and it is more effective than data distillation. On-policy distillation uses the teacher model to provide feedback to the student in a reinforcement learning setup.
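One common way to express "aligning the output token distribution" is a KL divergence between the teacher's and the student's softmax outputs, often softened with a temperature. The sketch below assumes that formulation; the function names, the temperature value and the toy logits are illustrative and not taken from a specific training recipe.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax with an optional temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions over the same vocabulary."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def distillation_loss(teacher_logits, student_logits, temperature=2.0):
    """Penalize the student for diverging from the teacher's token distribution."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    return kl_divergence(p_teacher, p_student)


# Toy logits over a 3-token vocabulary (made-up numbers).
teacher = [4.0, 1.0, 0.5]
close_student = [3.5, 1.2, 0.4]
far_student = [0.5, 1.0, 4.0]

loss_close = distillation_loss(teacher, close_student)
loss_far = distillation_loss(teacher, far_student)
```

The loss is zero when the two distributions match and grows as the student's distribution drifts away from the teacher's, so minimizing it during training pulls the student towards the teacher.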
Output-preserving methods
Flash Attention
Scaled dot-product attention, the predominant attention mechanism in the transformer architecture, is quadratic in the input length. Optimizing the self-attention calculation can therefore bring significant latency and cost wins.
Flash Attention optimizes the attention calculation by making the attention algorithm IO-aware, specifically by minimizing the amount of data moved between the slower HBM (high-bandwidth memory) and the faster memory tiers (SRAM/VMEM) in TPUs and GPUs. When calculating attention, the order of operations is changed and multiple layers are fused so that the faster memory tiers are used as efficiently as possible.
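The reordering can be illustrated with the "online softmax" idea: keys and values are processed block by block while a running maximum, a running normalizer and a running weighted sum are maintained, so the full row of attention scores never has to be materialized. The sketch below is a plain-Python illustration for a single query vector, with hypothetical function names; it shows why the output is unchanged, and it is not the fused accelerator kernel itself.

```python
import math


def naive_attention(q, keys, values):
    """Reference: materialize all scores, then softmax, then weighted sum."""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    peak = max(scores)
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    return [
        sum(w * v[j] for w, v in zip(weights, values)) / total
        for j in range(len(values[0]))
    ]


def tiled_attention(q, keys, values, block_size=2):
    """Same result, computed one block of keys/values at a time."""
    d = len(q)
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block_size):
        k_block = keys[start:start + block_size]
        v_block = values[start:start + block_size]
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in k_block]
        new_max = max(running_max, max(scores))
        # Rescale what was accumulated so far to the new maximum.
        correction = math.exp(running_max - new_max)
        weights = [math.exp(s - new_max) for s in scores]
        running_sum = running_sum * correction + sum(weights)
        acc = [
            a * correction + sum(w * v[j] for w, v in zip(weights, v_block))
            for j, a in enumerate(acc)
        ]
        running_max = new_max
    return [a / running_sum for a in acc]


# Toy inputs (made-up numbers): 5 keys, so the last block is smaller.
q = [1.0, 0.5]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5], [2.0, 1.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]

reference = naive_attention(q, keys, values)
tiled = tiled_attention(q, keys, values, block_size=2)
```

Both functions return the same values up to floating-point rounding, which is why this technique belongs with the output-preserving methods: only the order of operations and the memory traffic change.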
Prefix Caching
One of the most compute-intensive, and thus slowest, operations in LLM inference is calculating the attention keys and values (a.k.a. KV) for the input we’re passing to the LLM. This operation is often referred to as prefill. By caching the attention keys and values for each layer of the transformer for the entire input, we can avoid recomputing them and significantly improve latency and cost.
Tip
We do not cache Q in a decoder-only architecture: tokens are processed left to right, one at a time, and each inference step uses only the query of the newest token, which changes at every step.
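The sketch below illustrates this for a single layer: at each step only the newest token is projected, its key and value are appended to the cache, and its query is used once and discarded. The `project` helper, the `LayerKVCache` class and the toy weights are hypothetical and exist only for this example.

```python
def project(embedding, weight):
    """Toy linear projection; weight is a list of rows (hypothetical helper)."""
    return [sum(w * x for w, x in zip(row, embedding)) for row in weight]


class LayerKVCache:
    """Keys and values of all tokens seen so far, for one transformer layer."""

    def __init__(self):
        self.keys = []
        self.values = []

    def append(self, key, value):
        self.keys.append(key)
        self.values.append(value)


def decode_step(new_embedding, cache, w_q, w_k, w_v):
    # Only the newest token is projected; earlier K and V come from the cache.
    cache.append(project(new_embedding, w_k), project(new_embedding, w_v))
    query = project(new_embedding, w_q)  # used in this step only, never stored
    return query, cache.keys, cache.values


# Toy 2-dimensional weights and embeddings (made-up numbers).
W_Q = [[1, 0], [0, 1]]
W_K = [[2, 0], [0, 2]]
W_V = [[0, 1], [1, 0]]
embeddings = [[1, 2], [3, 4], [5, 6]]

cache = LayerKVCache()
for emb in embeddings:
    query, keys, values = decode_step(emb, cache, W_Q, W_K, W_V)
```

After three steps the cache holds three keys and three values, while only the query of the last token exists; a real model keeps one such cache per layer.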
In the case of large language models (LLMs), input typically refers to the entire sequence fed to the model in one go. This includes everything the model needs to process for generating an output, such as:
- Context: Background information, instructions, or a long prompt that helps the model generate a relevant response.
- User Request or Query: The specific question, command, or update that prompts the model to generate an answer.
For example, if you’re working with a document (like a PDF book) and you ask multiple questions about it, your input for each query might look like: Context (PDF Book Content) + Query (User’s Question)
Tip
In practice, the LLM cannot itself decide which parts of the input it receives are stable context, so it is the application surrounding the inference system that specifies that only the PDF needs to be cached.
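A minimal sketch of this division of responsibility is shown below. The application passes the document and the question separately, and a small cache keyed on a hash of the document reuses the prefill result across questions. The `PrefixCache` class and the `fake_prefill` stand-in are hypothetical; a real system would store per-layer key and value tensors, and the keys and values of the question tokens would also depend on the prefix.

```python
import hashlib

prefill_log = []  # records every text that went through the expensive step


def fake_prefill(text):
    """Stand-in for the expensive prefill; returns one placeholder per word."""
    prefill_log.append(text)
    return [f"kv({word})" for word in text.split()]


class PrefixCache:
    """Caches prefill results for prefixes the application marks as stable."""

    def __init__(self, prefill_fn):
        self._prefill_fn = prefill_fn
        self._store = {}

    def get_or_compute(self, prefix_text):
        key = hashlib.sha256(prefix_text.encode("utf-8")).hexdigest()
        if key not in self._store:
            self._store[key] = self._prefill_fn(prefix_text)
        return self._store[key]


prefix_cache = PrefixCache(fake_prefill)


def build_kv(document, question):
    # The application, not the model, decides that the document is the stable part.
    cached_kv = prefix_cache.get_or_compute(document)
    question_kv = fake_prefill(question)  # the short suffix is always computed
    return cached_kv + question_kv


document = "chapter one of a long book"
first = build_kv(document, "who is the author")
second = build_kv(document, "what is the plot")
```

The document goes through prefill once even though two questions were asked about it; only the short questions are processed each time.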
Speculative decoding
Prefill is compute bound due to large matrix operations on many tokens occurring in parallel. The following phase, known as decode, is generally memory bound, as tokens are auto-regressively decoded one at a time.
It is not easy to naively use additional parallel compute capacity to speed up decode: we need to wait for the current token to be produced before we can calculate what the next token should be (as per the self-attention mechanism), so the decode process is inherently serial.
The main idea behind speculative decoding is to use a much smaller secondary model (often referred to as the drafter) to run ahead of the main model and predict several tokens (e.g. 4 tokens ahead). This happens very quickly, as the drafter is much faster and smaller than the main model. We then use the main model to verify the drafter’s hypotheses in parallel for each of the 4 steps (i.e. the first token, the first two tokens, the first three tokens and finally all four tokens), and select the accepted hypothesis with the maximum number of tokens.
For example, in predicting the answer to "What is the capital of France?", the drafter model could reply "It is Lyon". The drafter takes only 1ms per token while the main model takes 10ms, so in this case the total inference time is 13ms: 3 * 1ms for the draft, plus one 10ms pass of the main model that rejects "Lyon" and produces "Paris". Generating the same three tokens with the main model alone would take 3 * 10ms = 30ms.
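The acceptance rule and the latency arithmetic of this example can be sketched as follows. The sketch assumes greedy verification (a draft token is kept only if it equals the token the main model would have chosen) and treats each word as one token; the function name and the per-token timings mirror the example and are otherwise illustrative. Sampling-based acceptance schemes are not shown.

```python
def accept_draft(draft_tokens, main_tokens):
    """Keep draft tokens up to the first disagreement, then take the main model's token.

    main_tokens[i] is the token the main model picks at position i given the
    draft prefix before it; all positions are scored in one parallel pass.
    """
    accepted = []
    for drafted, verified in zip(draft_tokens, main_tokens):
        if drafted != verified:
            accepted.append(verified)  # the correction comes from the same pass
            break
        accepted.append(drafted)
    return accepted


draft = ["It", "is", "Lyon"]          # drafter output, word-level tokens assumed
main_choice = ["It", "is", "Paris"]   # main model's choice at each position
result = accept_draft(draft, main_choice)

DRAFT_MS_PER_TOKEN = 1
MAIN_MS_PER_PASS = 10

speculative_ms = len(draft) * DRAFT_MS_PER_TOKEN + MAIN_MS_PER_PASS
baseline_ms = len(result) * MAIN_MS_PER_PASS  # main model alone, one pass per token
```

Because every emitted token is either confirmed or produced by the main model, the output is the same as if the main model had decoded on its own; only the time to produce it changes (13ms instead of 30ms under these assumptions).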
Batching and parallelization
As in any software system, there are opportunities to improve throughput and latency through a combination of:
- batching less compute-intensive operations (i.e. we can run multiple requests on the same hardware simultaneously to better utilize the spare compute)
- parallelizing the more compute-intensive parts of the computations (i.e. we can divide the computation and split it amongst more hardware instances to get more compute capacity and therefore better latencies)
Batching in LLMs is most useful on the decode side: decode is not compute-bound, so there is spare compute that additional requests can use. We need to batch computations in a way that actually uses this spare capacity, which is possible on accelerators (e.g. TPUs and GPUs). We also need to remain within memory limits: decode is a memory-intensive operation, and batching more requests puts more pressure on the available free memory. Batching has become an important component of most high-throughput LLM inference setups.
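The memory constraint can be estimated with simple arithmetic: each request in the batch needs room for its own keys and values. The sketch below assumes a hypothetical model configuration (32 layers, 8 key/value heads of dimension 128, 2 bytes per value), 4096 tokens per request and 16 GiB of free accelerator memory; all of these numbers and helper names are made up for illustration.

```python
def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, bytes_per_value):
    """Bytes needed to cache one token: a key and a value in every layer."""
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_value


def max_batch_size(free_memory_bytes, tokens_per_request, bytes_per_token):
    """How many requests fit if each one caches tokens_per_request tokens."""
    return free_memory_bytes // (tokens_per_request * bytes_per_token)


# Hypothetical model configuration and memory budget.
per_token = kv_bytes_per_token(
    num_layers=32, num_kv_heads=8, head_dim=128, bytes_per_value=2
)
batch = max_batch_size(
    free_memory_bytes=16 * 1024**3,
    tokens_per_request=4096,
    bytes_per_token=per_token,
)
```

Under these assumptions each token costs 128 KiB of cache, each request 512 MiB, and at most 32 requests fit in the batch. Longer contexts or a larger cache precision reduce that number proportionally.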
Parallelization is also a widely used technique, given the variety of opportunities in transformers for horizontal scaling across more hardware instances. There are multiple parallelism techniques:
- across the model input (sequence parallelism)
- across the model layers (pipeline parallelism)
- within a single layer (tensor parallelism)
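As an illustration of the last item, the sketch below splits the weight matrix of a single layer by columns across two simulated devices, lets each compute its share of the output, and concatenates the partial results. The function names are hypothetical, the "devices" are just list slices, and the sketch assumes the number of columns divides evenly by the number of shards; real implementations also have to handle the communication between devices.

```python
def matmul(x, w):
    """x has shape [rows][inner], w has shape [inner][cols]."""
    return [
        [sum(x_row[k] * w[k][j] for k in range(len(w))) for j in range(len(w[0]))]
        for x_row in x
    ]


def split_columns(w, num_shards):
    """Give each shard a contiguous slice of the output columns."""
    shard_width = len(w[0]) // num_shards
    return [
        [row[i * shard_width:(i + 1) * shard_width] for row in w]
        for i in range(num_shards)
    ]


def tensor_parallel_matmul(x, w, num_shards):
    # Each shard could run on a different device; here they run one after another.
    partials = [matmul(x, w_shard) for w_shard in split_columns(w, num_shards)]
    # Concatenate the partial outputs row by row.
    return [sum((p[r] for p in partials), []) for r in range(len(x))]


# Toy activations and weights (made-up integers).
x = [[1, 2], [3, 4]]
w = [[1, 0, 2, 1], [0, 1, 1, 3]]

full = matmul(x, w)
sharded = tensor_parallel_matmul(x, w, num_shards=2)
```

The sharded result is identical to the single-device result, while each device only needs to hold half of the layer's weights.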