Optimizing Token Generation in PyTorch Decoder Models | Towards Data Science

The generative AI models that have pervaded nearly every facet of our daily lives are autoregressive decoder models. These models apply compute-heavy kernel operations to churn out tokens one by one in a manner that, at first glance, seems extremely inefficient. Given the enormous demand for generative AI, it is no surprise that extraordinary engineering effort is being invested into its optimization. Whether it be through custom CUDA kernels, CUDA Graphs, dedicated AI accelerators, or speculative sampling, any technique that reduces latency and/or cost by even a fraction of a percentage is a win.

In this post, we demonstrate a technique for optimizing token generation in PyTorch using CUDA stream interleaving. While simple to implement, the method addresses a specific, often overlooked bottleneck and can lead to meaningful performance boosts. While pipelining model execution using CUDA streams is common in AI systems engineering, we did not find any tutorial documenting the specific PyTorch-level application we describe here. If you find the technique useful, please be so kind as to reference this post.

To facilitate our discussion, we will use a simple GPT-2 PyTorch decoder model from HuggingFace’s transformers (v5.1.0) library. We will run our experiments on an NVIDIA L40S GPU with PyTorch 2.10.0.

Disclaimer: The code we will share is intended for demonstrative purposes. Please do not rely on its accuracy or optimality. Please do not interpret our mentions of any library, platform, or service as an endorsement of its use.

Importantly, the value of the CUDA stream-based method we will discuss can vary greatly based on the details of your model and runtime environment. Please be sure to run your own benchmarks before integrating its use.

Our focus in this post is on PyTorch-native inference workloads, which remain extremely prevalent in development and test settings. However, it is important to note that for production environments, dedicated LLM inference libraries such as vLLM or NVIDIA TensorRT-LLM tend to deliver greater performance and should be used whenever relevant.

A Toy GPT-2 Model

To simplify our discussion, we will use a GPT-2 decoder model from the HuggingFace transformers library and have it run autoregressively on a batch of empty prompts.

In the following code block, we initialize the model and define a naive token generation function that generates a batch of token sequences up to a given length.

import torch
from transformers import GPT2LMHeadModel, GPT2Config

torch.set_float32_matmul_precision('high')

DEVICE = "cuda"

# define the decoder model
config = GPT2Config.from_pretrained("gpt2")
model = GPT2LMHeadModel(config).to(DEVICE).eval()


@torch.inference_mode()
def generate_sequence(model, max_seqlen, batch_size):
    # Initialize prompts with BOS token
    all_tokens = torch.full(
        (batch_size, 1),
        config.bos_token_id,
        device=DEVICE,
        dtype=torch.long
    )
    finished = torch.zeros(batch_size, device=DEVICE, dtype=torch.bool)
    
    for i in range(max_seqlen):
        outputs = model(all_tokens)
        # extract new token
        logits = outputs.logits[:, -1, :]
        new_tokens = torch.argmax(logits, dim=-1)
        # append new token to sequence
        all_tokens = torch.cat(
            [all_tokens, new_tokens.unsqueeze(-1)],
            dim=-1
        )
        finished |= (new_tokens == config.eos_token_id)
        stop_gpu = torch.all(finished)
        
        # checking stop condition
        if stop_gpu.item():
            print(f"All sequences finished at step {i+1}")
            break
    
    return all_tokens

Next, we define a simple benchmarking function which we use to measure the runtime performance and memory utilization of our token generator in different scenarios.

import time, statistics


def benchmark(func, num_runs=10):
    # Warmup
    func()
    torch.cuda.synchronize()
    
    runtimes = []
    
    for _ in range(num_runs):
        # reset memory stats before each run
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
        
        start = time.perf_counter()
        _ = func()
        torch.cuda.synchronize()
        end = time.perf_counter()
        
        runtimes.append(end - start)
    
    # Get memory allocator stats from last run
    mem_stats = torch.cuda.memory_stats()
    allocated_peak = mem_stats.get('allocated_bytes.all.peak', 0)
    reserved_peak = mem_stats.get('reserved_bytes.all.peak', 0)
    f_peak = reserved_peak - allocated_peak
    f_pct = (
        100 * f_peak / reserved_peak
        if reserved_peak > 0 else 0
    )
    
    print(f"\n{'='*60}")
    print(f"Runtime Results:")
    print(f" Mean:               {statistics.mean(runtimes):.4f}s")
    print(f" Std:                {statistics.stdev(runtimes):.4f}s")
    print(f" Min:                {min(runtimes):.4f}s")
    print(f" Max:                {max(runtimes):.4f}s")

    print(f"\nMemory Stats:")
    print(f" Allocated bytes (peak): {allocated_peak / 1e9:.3f} GB")
    print(f" Reserved bytes (peak):  {reserved_peak / 1e9:.3f} GB")
    print(f" Fragmentation (peak):   {f_peak / 1e9:.3f} GB ({f_pct:.1f}%)")
    print(f"{'='*60}\n")


batch_size = 32
for max_seqlen in [100, 200, 400]:
    print(
        f"Benchmarking generation with batch size {batch_size} "
        f"and max sequence length {max_seqlen}..."
    )
    benchmark(
        lambda: generate_sequence(
            model, max_seqlen=max_seqlen, batch_size=batch_size
        )
    )

In the table below we capture the results for a batch size of 32 and several different sequence lengths:

Baseline Results (By Author)

As the sequence length doubles, the runtime quadruples — appearing to follow a classic O(N²) scaling pattern. Additionally, high memory fragmentation points to severe strain on the CUDA memory allocator, which can result in frequent memory faults and degrade runtime performance. The fragmentation results from each step asking for slightly larger tensor allocations, a pattern which ends up leaving multiple pockets of unusable memory.

Our first optimization, KV caching, addresses the runtime complexity of our decoder model.

KV Caching

Our naive generator is extremely inefficient — rather than storing and reusing the intermediate tensors from previous tokens, it recalculates the entire sequence at every step.

We address the computation inefficiency by using KV caching: We store and reuse the intermediate Key and Value tensors for previous tokens. KV caching reduces the runtime complexity of token generation from O(N²) to O(N).
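To make the mechanism concrete, here is a minimal single-head attention sketch (toy dimensions; the names decode_step, k_cache, and v_cache are ours, not from the transformers library). Each decode step processes only the newest token, but attends over all cached positions:

```python
import torch
import torch.nn.functional as F

# Toy single-head attention with a KV cache (hypothetical dimensions).
D = 16
torch.manual_seed(0)
Wq, Wk, Wv = (torch.randn(D, D) for _ in range(3))

k_cache, v_cache = [], []

def decode_step(x):  # x: (1, D) embedding of the newest token
    q = x @ Wq
    k_cache.append(x @ Wk)  # the cache grows by one row per step
    v_cache.append(x @ Wv)
    K = torch.cat(k_cache)  # (t, D) -- all positions so far
    V = torch.cat(v_cache)
    attn = F.softmax(q @ K.T / D**0.5, dim=-1)  # (1, t)
    return attn @ V  # (1, D)

outs = [decode_step(torch.randn(1, D)) for _ in range(5)]
assert outs[-1].shape == (1, D)
assert len(k_cache) == 5  # one cached key per generated token
```

Each step performs O(t) work instead of recomputing all t previous tokens, which is what drops total generation cost from O(N²) to O(N).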

In the following code block, we utilize the transformers library’s built-in support for KV caching to reprogram our token generation function to compute a single new token per sequence at each step.

@torch.inference_mode()
def generate_sequence(model, max_seqlen, batch_size, use_cache=False):
    # Initialize prompts with BOS token
    all_tokens = torch.full(
        (batch_size, 1),
        config.bos_token_id,
        device=DEVICE,
        dtype=torch.long
    )
    finished = torch.zeros(batch_size, device=DEVICE, dtype=torch.bool)

    # past_key_values is used to store the cached key/values for each layer
    past_key_values = None

    for i in range(max_seqlen):
        current_input = (
            all_tokens if past_key_values is None
            else all_tokens[:, -1:]
        )
        outputs = model(
            current_input,
            past_key_values=past_key_values,
            use_cache=use_cache
        )
        # update cache for next step
        past_key_values = outputs.past_key_values
        logits = outputs.logits[:, -1, :]
        new_tokens = torch.argmax(logits, dim=-1)
        # append new token to sequence
        all_tokens = torch.cat(
            [all_tokens, new_tokens.unsqueeze(-1)],
            dim=-1
        )
        finished |= (new_tokens == config.eos_token_id)
        stop_gpu = torch.all(finished)
        
        # checking stop condition
        if stop_gpu.item():
            print(f"All sequences finished at step {i+1}")
            break
    
    return all_tokens

The resulting performance numbers are captured in the following table:

Token Generation With KV Caching (By Author)

The performance improvement is profound and, as expected, increases as a function of the sequence length.

Although somewhat better than in our baseline experiment, the degree of memory fragmentation remains a concern. To address this we explore two methods: expandable memory allocations and static KV caching.

Expandable CUDA Memory Allocations

To reduce CUDA memory fragmentation, we program PyTorch to use expandable memory segments. As of the time of this writing, this memory optimization is an experimental feature and should be used with caution. Please see the PyTorch documentation for details. To use the feature we set the following environment variable:

export PYTORCH_ALLOC_CONF="expandable_segments:True"
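Alternatively (a sketch; we did not measure this variant above), the same setting can be applied from within the Python script, as long as it runs before the process performs its first CUDA allocation:

```python
import os

# Must run before PyTorch's first CUDA allocation, so place it at the
# very top of the script; the caching allocator reads this setting when
# it initializes on first CUDA use.
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"

import torch  # noqa: E402
```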

Rerunning our benchmark results in the following table:

KV Caching With Expandable Memory Segments (By Author)

Not only do we see a marked improvement in fragmentation, but we also get an additional (marginal) improvement in runtime performance.

KV Caching With StaticCache

The default cache in HuggingFace is dynamic: it grows as keys and values accumulate over the course of generation. HuggingFace also supports a fixed-size cache, StaticCache, which pre-allocates the maximum cache size for the KV pairs and reduces strain on the CUDA memory allocator. The disadvantage of using StaticCache is that the full length of the cache participates in the attention computation at every token generation step, with irrelevant positions masked out. This wastes computation, and the waste grows with the sequence length. For example, when generating a sequence of 400 tokens, the attention computation for each token runs on full 400×400-sized tensors.

In the code block below we enhance our sequence generator to support the use of a StaticCache:

from transformers import StaticCache

@torch.inference_mode()
def generate_sequence(
    model, max_seqlen, batch_size, use_cache=False, use_static_cache=False
):
    # Initialize prompts with BOS token
    all_tokens = torch.full(
        (batch_size, 1),
        config.bos_token_id,
        device=DEVICE,
        dtype=torch.long
    )
    finished = torch.zeros(batch_size, device=DEVICE, dtype=torch.bool)
    
    # Initialize static cache if requested
    if use_cache and use_static_cache:
        past_key_values = StaticCache(
            config=config,
            max_batch_size=batch_size,
            max_cache_len=max_seqlen,
            device=DEVICE,
            dtype=model.dtype
        )
    else:
        past_key_values = None
    
    # Initialize cache position tracking for static cache
    cache_positions = torch.arange(max_seqlen, device=DEVICE)
    
    for i in range(max_seqlen):
        current_input = (
            all_tokens if past_key_values is None
            else all_tokens[:, -1:]
        )
        cache_position = (
            cache_positions[i:i+1] if use_static_cache else None
        )
        outputs = model(
            current_input,
            past_key_values=past_key_values,
            cache_position=cache_position,
            use_cache=use_cache
        )
        # update cache for next step
        past_key_values = outputs.past_key_values
        logits = outputs.logits[:, -1, :]
        new_tokens = torch.argmax(logits, dim=-1)
        # append new token to sequence
        all_tokens = torch.cat(
            [all_tokens, new_tokens.unsqueeze(-1)],
            dim=-1
        )
        finished |= (new_tokens == config.eos_token_id)
        stop_gpu = torch.all(finished)
        
        # checking stop condition
        if stop_gpu.item():
            print(f"All sequences finished at step {i+1}")
            break
    
    return all_tokens

The updated results are captured below:

Token Generation With Static KV Cache (By Author)

Using a fixed-sized cache greatly improves memory utilization as indicated by the decrease in memory fragmentation. However, its impact on runtime performance is mixed — for 100 tokens it reduces performance compared to a dynamic cache, whereas for 200 and 400 tokens it boosts performance by 9% and 10%, respectively.

There are more advanced methods of implementing attention that optimize for memory utilization without the cost of wasted computation. In a previous post, Optimizing Transformer Models for Variable-Length Input Sequences, we covered some PyTorch techniques for computing attention sparsely to reduce computation waste. For production settings, libraries such as vLLM use PagedAttention for maximizing memory utilization. These methods are outside the scope of this post.

For more details on caching in HuggingFace, please see the caching strategies overview.

Model Compilation

One of the documented advantages of using a fixed-size cache is that it enables a variety of just-in-time (JIT) compilation optimizations.

In the following code block we apply our benchmark to a PyTorch-compiled version of our decoder model:

batch_size = 32
max_seqlen = 100

model = torch.compile(model)

benchmark(
    lambda: generate_sequence(
        model,
        max_seqlen=max_seqlen,
        batch_size=batch_size,
        use_cache=True,
        use_static_cache=True
    )
)

Model compilation results in an additional boost to runtime performance as shown in the table below:

Token Generation With torch.compile (By Author)

Note that model compilation can also be applied when using a dynamic cache. However, torch.compile provides the best results when the computation graph is composed of fixed-size tensors (see the PyTorch documentation on torch.compile for more details).
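A related knob worth benchmarking (an assumption on our part, not something we measured above) is torch.compile’s "reduce-overhead" mode, which uses CUDA Graphs to cut per-step kernel launch overhead and likewise benefits from the static shapes that StaticCache provides. A minimal sketch with a toy stand-in model:

```python
import torch

# Toy stand-in model; on a GPU with static input shapes, "reduce-overhead"
# wraps replays in CUDA Graphs (on CPU it silently falls back).
model = torch.nn.Linear(16, 16).eval()
compiled = torch.compile(model, mode="reduce-overhead")

with torch.inference_mode():
    out = compiled(torch.randn(2, 16))
assert out.shape == (2, 16)
```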

The Performance Penalty of Early Stopping

An integral part of common token generators is checking for the end-of-sequence (EOS) token at the end of each step. Without this test, token generators would always run for max_seqlen steps, even if all the sequences in the batch have ended. This can result in considerable computation waste and unnecessary latency, especially when typical sequence lengths are much shorter than the maximum length. In our toy experiment, we wait for all the sequences in the batch to end before discontinuing token generation. Production-grade implementations will commonly perform continuous batching instead, replacing completed sequences with new prompts from the input queue.

        finished |= (new_tokens == config.eos_token_id)
        stop_gpu = torch.all(finished)
        
        # checking stop condition
        if stop_gpu.item():
            print(f"All sequences finished at step {i+1}")
            break

Importantly, the .item() call on the stop_gpu tensor triggers a blocking host-device synchronization event. More specifically, in order to evaluate the conditional if statement, the CPU must wait for the GPU to complete its computation and copy the contents of the tensor to host memory. While the CPU waits, it is blocked from executing the next step of the token generation loop, or, more accurately, from loading the next computation kernels onto the GPU.
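The two flavors of the check can be sketched as follows (a toy illustration; the eos_check_* names are ours). The non-blocking variant, which copies the result into pinned host memory without stalling the CPU, is the building block used by the CUDA stream solution later in this post:

```python
import torch

def eos_check_blocking(finished: torch.Tensor) -> bool:
    # .item() copies the result to the host and forces the CPU to wait
    # for every kernel queued before it on the stream
    return bool(torch.all(finished).item())

def eos_check_async(finished: torch.Tensor, host_buf: torch.Tensor) -> None:
    # enqueue a device-to-host copy into pinned memory and return at once;
    # host_buf is only valid after the producing stream has completed
    host_buf.copy_(torch.all(finished), non_blocking=True)

if torch.cuda.is_available():
    finished = torch.zeros(8, dtype=torch.bool, device="cuda")
    assert eos_check_blocking(finished) is False

    host_buf = torch.tensor(False, pin_memory=True)
    eos_check_async(finished, host_buf)
    torch.cuda.synchronize()  # for the demo only; the whole point is to hide this wait
    assert host_buf.item() is False
```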

To measure the impact of the stopping condition on runtime performance, we add instrumentation for performance profiling with NVIDIA Nsight™ Systems (nsys) using the torch.cuda.profiler and nvtx (v0.2.14) APIs. (See our recent post for more details on performance profiling with nsys).

import nvtx
from torch.cuda import profiler

@torch.inference_mode()
def generate_sequence(
    model, max_seqlen, batch_size, use_cache=False, use_static_cache=False
):
    # Initialize prompts with BOS token
    all_tokens = torch.full(
        (batch_size, 1),
        config.bos_token_id,
        device=DEVICE,
        dtype=torch.long
    )
    finished = torch.zeros(batch_size, device=DEVICE, dtype=torch.bool)
    
    # Initialize static cache if requested
    if use_cache and use_static_cache:
        past_key_values = StaticCache(
            config=config,
            max_batch_size=batch_size,
            max_cache_len=max_seqlen,
            device=DEVICE,
            dtype=model.dtype
        )
    else:
        past_key_values = None
    
    # Initialize cache position tracking for static cache
    cache_positions = torch.arange(max_seqlen, device=DEVICE)
    
    for i in range(max_seqlen):
        if i == 30:
            # start nsys profiler
            torch.cuda.synchronize()
            profiler.start()
        elif i == 50:
            # stop nsys profiler
            torch.cuda.synchronize()
            profiler.stop()
        with nvtx.annotate(f"Step {i+1}", color="blue"):
            with nvtx.annotate("Model Forward", color="green"):
                current_input = (
                    all_tokens if past_key_values is None
                    else all_tokens[:, -1:]
                )
                cache_position = (
                    cache_positions[i:i+1] if use_static_cache else None
                )
                outputs = model(
                    current_input,
                    past_key_values=past_key_values,
                    cache_position=cache_position,
                    use_cache=use_cache
                )
                past_key_values = outputs.past_key_values
                logits = outputs.logits[:, -1, :]
                new_tokens = torch.argmax(logits, dim=-1)
                # append new token to sequence
                all_tokens = torch.cat(
                    [all_tokens, new_tokens.unsqueeze(-1)],
                    dim=-1
                )
                finished |= (new_tokens == config.eos_token_id)
                stop_gpu = torch.all(finished)
            with nvtx.annotate("Check Stop Condition", color="red"):
                # checking stop condition
                if stop_gpu.item():
                    print(f"All sequences finished at step {i+1}")
                    break
    
    return all_tokens

We run our script using the cudaProfilerApi option to start and stop the profiler programmatically. Please see the official documentation for full details on profiling from the nsys CLI.

nsys profile \
  --capture-range=cudaProfilerApi \
  --trace=cuda,nvtx,osrt \
  --output=baseline \
  python train.py

The following trace, captured for a batch size of 16 and sequence length of 100, shows the GPU idling for about 110 microseconds in between steps — an eternity in the context of high-performance GPU workloads. This is a direct result of the synchronization event triggered by the EOS test.

GPU Utilization Drops Between Each Step (By Author)

In production-grade implementations such synchronization issues are avoided by some combination of 1) use of lower level (e.g., C/C++) code that avoids the limitation of the Python interpreter, 2) using CUDA graphs to reduce overhead of kernel loading, 3) moving conditional checks onto the GPU using conditional nodes, and 4) continuously and asynchronously preparing subsequent requests while the EOS check is in progress.
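As an illustration of point (2), PyTorch exposes CUDA Graph capture directly. Below is a minimal sketch under the assumption of fixed input shapes, following the warmup-on-a-side-stream pattern from the PyTorch CUDA Graphs documentation (the model here is a toy stand-in):

```python
import torch

if torch.cuda.is_available():
    model = torch.nn.Linear(64, 64).cuda().eval()
    static_in = torch.randn(8, 64, device="cuda")

    # warm up on a side stream before capture, per the PyTorch docs
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s), torch.no_grad():
        for _ in range(3):
            model(static_in)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g), torch.no_grad():
        static_out = model(static_in)

    # a single replay() re-launches every captured kernel,
    # amortizing Python and kernel-launch overhead
    static_in.copy_(torch.ones_like(static_in))
    g.replay()
    torch.cuda.synchronize()
```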

In the next section, we demonstrate a technique for hiding the overhead of the host-device synchronization in PyTorch using CUDA streams.

A CUDA Stream Optimization

A CUDA stream is a linear sequence of operations (kernels, memory copies, etc.) that execute in order on the GPU. While operations within a single stream are guaranteed to execute sequentially, operations in different streams can execute concurrently or overlap.
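As a minimal sketch (guarded so it only executes on a CUDA machine), two streams can be ordered relative to one another with wait_stream, which enqueues the dependency on the GPU without blocking the CPU:

```python
import torch

if torch.cuda.is_available():
    s1, s2 = torch.cuda.Stream(), torch.cuda.Stream()
    x = torch.randn(1024, 1024, device="cuda")

    with torch.cuda.stream(s1):
        y = x @ x                  # enqueued on s1

    with torch.cuda.stream(s2):
        s2.wait_stream(s1)         # s2 waits for s1 on-device; the CPU moves on
        z = torch.relu(y)          # safe to consume y

    # order the default stream after s2 before using z elsewhere
    torch.cuda.current_stream().wait_stream(s2)
    torch.cuda.synchronize()
    assert z.min().item() >= 0
```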

In previous posts we demonstrated the use of CUDA streams for pipelining common AI/ML workloads, e.g., executing a model on batch N while preparing batch N+1. In this post we will use CUDA streams to enable the CPU to load the GPU kernels of step N+1 before checking the stopping criterion of step N. Contrary to our previous demonstrations of CUDA streams, our current example will not necessarily involve concurrent GPU kernel execution.

We implement an alternative token generation function that interleaves two CUDA streams, running the following operations iteratively:

1. Program stream i%2 to: (A) wait for stream (i-1)%2 to complete its generation of token i-1, (B) use the updated tensors to calculate token i, (C) run the EOS test for token i on the GPU, and (D) perform a (non-blocking) copy of the EOS test result to pinned memory on the CPU.

2. On the default CUDA stream, wait for stream (i-1)%2 to complete its generation of token i-1.

3. On the default CUDA stream, check whether the stopping criteria for token i-1 were met. If so, halt the generator and return. Otherwise, increment i and return to step 1.

Whereas previously, the initialization of token i generation was blocked by the EOS test on token i-1, the use of CUDA streams allows us to program the generation of token i before we check the result of the EOS test on token i-1. In practice, the EOS test for token i-1 on the CPU runs while the GPU is computing token i.

@torch.inference_mode()
def generate_sequence_pipelined(
    model,
    max_seqlen,
    batch_size,
    use_cache=False,
    use_static_cache=False
):
    # Initialize prompts with BOS token
    all_tokens = torch.full(
        (batch_size, 1),
        config.bos_token_id,
        device=DEVICE,
        dtype=torch.long
    )
    finished = torch.zeros(batch_size, device=DEVICE, dtype=torch.bool)
    past_key_values = None
    
    # Initialize static cache if requested
    if use_cache and use_static_cache:
        past_key_values = StaticCache(
            config=config,
            max_batch_size=batch_size,
            max_cache_len=max_seqlen,
            device=DEVICE,
            dtype=model.dtype
        )
    
    # Initialize cache position tracking for static cache
    cache_positions = torch.arange(max_seqlen, device=DEVICE)
    
    # Dual streams for pipelining
    streams = [torch.cuda.Stream(), torch.cuda.Stream()]
    stop_host = [
        torch.tensor(False, pin_memory=True),
        torch.tensor(False, pin_memory=True)
    ]
    
    for i in range(max_seqlen):
        curr_idx, prev_idx = i % 2, (i+1) % 2
        curr_s, prev_s = streams[curr_idx], streams[prev_idx]
        
        # Launch iteration i in current stream
        with torch.cuda.stream(curr_s):
            # program stream to wait for previous stream to complete
            curr_s.wait_stream(prev_s)
            current_input = (
                all_tokens if past_key_values is None
                else all_tokens[:, -1:]
            )
            cache_position = (
                cache_positions[i:i+1] if use_static_cache else None
            )
            outputs = model(
                current_input,
                past_key_values=past_key_values,
                cache_position=cache_position,
                use_cache=use_cache
            )
            past_key_values = outputs.past_key_values
            logits = outputs.logits[:, -1, :]
            new_tokens = torch.argmax(logits, dim=-1)
            all_tokens = torch.cat(
                [all_tokens, new_tokens.unsqueeze(-1)],
                dim=-1
            )
            
            finished |= (new_tokens == config.eos_token_id)
            stop_gpu = torch.all(finished)
            stop_host[curr_idx].copy_(stop_gpu, non_blocking=True)
        
        # Check previous iteration's stop signal
        torch.cuda.current_stream().wait_stream(prev_s)
        if stop_host[prev_idx].item():
            print(f"All sequences finished at step {i}")
            break
    
    return all_tokens

The image below captures the nsys trace for our new token generator:

Constant GPU Activity When Applying CUDA Streams (By Author)

In the CUDA section of the trace we can see the use of two CUDA streams, with token generation being passed back and forth in a sort of ping-pong effect: one stream generates all of the odd tokens and the second all of the even tokens. The CPU is about half a step ahead of the GPU, allowing it to program step i while the GPU is computing step i-1. The CPU-side EOS stop-check of step i-1 (in red) occurs after step i is fully programmed (and has started running). Most importantly, we now find the GPU utilization to be consistent: the idling we saw before is gone.

The CUDA stream interleaving results in an additional performance boost, as shown in the table below:

Token Generation With CUDA Streams (By Author)

We would expect the benefit of the ping-pong solution we have implemented to be impacted by the ratio between the GPU idle time (i.e., the overhead of kernel loading) and the kernel computation time. To test this, we fix the sequence length at 100 and rerun the benchmark for a number of batch sizes:

Impact of Pipelining for Varying Batch Size (By Author)

As expected, the highest performance gain, 11.6%, occurs when the batch size is smallest and the kernel computation load is lowest. As the kernel compute grows, the ratio of kernel-loading time to kernel-compute time shrinks, and so does the impact of CUDA stream interleaving.

Note that there is some overhead to the use of CUDA streams. This can be demonstrated by comparing our interleaving solution to a token generator that skips the EOS test altogether:

Overhead of CUDA Stream Interleaving (By Author)

The Potential Performance Pitfalls of Using CUDA Streams

CUDA streams should be used with extreme caution. When using the default stream we can rely on PyTorch to perform any necessary synchronization when data is moved around. However, when using CUDA streams, we must ensure appropriate synchronization explicitly. In particular, we must ensure appropriate data transfer between the streams. Otherwise, we may experience CUDA errors (e.g., “device-side assert triggered”) — if we are lucky. If we are less lucky, we may experience data corruption without even knowing it. See the PyTorch CUDA stream documentation for more details on appropriate use.
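A minimal sketch of the required discipline when a tensor produced on one stream is consumed on another (the producer/consumer names are ours):

```python
import torch

if torch.cuda.is_available():
    producer, consumer = torch.cuda.Stream(), torch.cuda.Stream()

    with torch.cuda.stream(producer):
        x = torch.randn(1 << 20, device="cuda")
        y = x * 2.0

    with torch.cuda.stream(consumer):
        consumer.wait_stream(producer)  # 1) order the reads after the writes
        total = y.sum()
        # 2) tell the caching allocator that y is in use on `consumer`,
        #    so its memory is not recycled while the kernel may still read it
        y.record_stream(consumer)

    torch.cuda.current_stream().wait_stream(consumer)
    torch.cuda.synchronize()
```

Omitting either the wait_stream call or the record_stream call is exactly the kind of mistake that produces silent data corruption rather than a visible error.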

For AI/ML workloads with large CUDA memory utilization, such as LLMs, another consideration is memory utilization. The PyTorch caching allocator manages memory on a per-stream basis; using multiple streams can lead to increased memory reservation and fragmentation. These could result in increased memory faults that might overshadow the potential gains from the use of streams.

Results

In the table below we summarize the runtime results of applying static caching, compilation, and pipelining on a batch of 32 sequences and a maximum sequence length of 100. The results are sorted in increasing order of performance:

Token Generation Optimization Results (By Author)

In the case of our toy GPT-2 model, the best results — nearly 5 times the baseline performance — are achieved when employing PyTorch compilation and the CUDA stream interleaving method discussed in this post. However, as we have seen, the impact of CUDA interleaving could vary greatly based on the properties of the workload and runtime environment, particularly on the ratio between the kernel loading time and the kernel compute time. Please be sure to run your own benchmarks before adopting this method.

Summary

In high-performance AI engineering, any hint of GPU under-utilization presents an opportunity for optimization. One of the primary optimization tools on NVIDIA GPUs is CUDA streams. In this post, we demonstrated their use in eliminating the idle GPU time that results from the host-device synchronization associated with early stopping in PyTorch-native autoregressive token generation. By interleaving CUDA streams in a “ping-pong” pattern, we successfully hid the latency imposed by the EOS check, which resulted in a meaningful increase in the workload’s throughput. By combining this technique with the well-known methods of model compilation and static caching, we can maximize the performance of PyTorch-native inference.
