Discovering The Gems
Abstract
Large Language Models (LLMs) have demonstrated remarkable capabilities in handling long
context inputs, but this comes at the cost of increased computational resources and latency. Our
research introduces a novel approach to the long context bottleneck that accelerates LLM inference
and reduces GPU memory consumption. Our research demonstrates that LLMs can identify
relevant tokens in the early layers before generating answers to a query. Leveraging this insight,
we propose an algorithm that uses early layers of an LLM as filters to select and compress input
tokens, significantly reducing the context length for subsequent processing. Our method, Gem-
Filter, demonstrates substantial improvements in both speed and memory efficiency compared
to existing techniques, such as standard attention and SnapKV/H2O. Notably, it achieves a
2.4× speedup and 30% reduction in GPU memory usage compared to SOTA methods. Evaluation
on the Needle in a Haystack task shows that GemFilter significantly outperforms standard
attention and SnapKV, and it demonstrates comparable performance on the LongBench benchmark.
GemFilter is simple, training-free, and broadly applicable across different LLMs. Crucially,
it provides interpretability by allowing humans to inspect the selected input sequence. These
findings not only offer practical benefits for LLM deployment, but also enhance our understand-
ing of LLM internal mechanisms, paving the way for further optimizations in LLM design and
inference. Our code is available at https://github.com/SalesforceAIResearch/GemFilter.
∗ [email protected]. University of Wisconsin-Madison.
† [email protected]. Salesforce AI Research.
‡ [email protected]. Salesforce AI Research.
§ [email protected]. The University of Hong Kong. [email protected]. University of Wisconsin-Madison.
¶ [email protected]. Salesforce AI Research.
Contents
1 Introduction
2 Related Works
3 Method
  3.1 Notations and Preliminary
  3.2 Our Algorithm: GemFilter
  3.3 Running Time and Memory Complexity Analysis
  3.4 Comparison with Other Methods
4 Experiments
  4.1 Needle in a Haystack
  4.2 LongBench
  4.3 Filter Layer Choice
  4.4 Running Time and GPU Memory Consumption
5 Conclusion
A More Preliminary
1 Introduction
Large Language Models (LLMs) have demonstrated impressive abilities [WTB+ 22, BCE+ 23] and
found widespread application in various AI systems, such as ChatGPT [SZK+ 22], Gemini [ABW+ 23],
and Claude [Ant24]. They are also a fundamental component in building language-based
AI agents that can orchestrate plans and execute complex tasks through interaction with external
tools. A key requirement for many of these applications is the ability to process long-context inputs.
This ability can also potentially eliminate the need for a retriever in retrieval-augmented generation
(RAG) [XPW+ 24] or enhance its performance [JMC24]. Therefore, significant efforts have been
made recently to build LLMs that support long context inputs. For instance, LLaMA 3.1 [DJP+ 24],
Mistral [JSM+ 23], and Phi 3.5 [AJA+ 24] now support input sequences of up to 128K tokens, while
Gemini can handle inputs of up to 1M tokens. However, processing such lengthy inputs comes at
a substantial cost in terms of computational resources and time. Therefore, accelerating the LLM
generation speed while simultaneously reducing GPU memory consumption for long-context inputs
is essential to minimize response latency and increase throughput for LLM API calls.
One prominent optimization for fast text generation in decoder-only LLMs (i.e., using a causal
attention mask) is the KV cache. Specifically, there are two phases involved in auto-regressive
generation. Given a long context input, the first is the prompt computation phase, when the LLM
computes the KV cache for all layers, storing the intermediate attention keys and values of the
input tokens. Next, in the iterative generation phase, the LLM generates tokens iteratively using
the pre-computed KV cache, avoiding redundant computations. GPU memory usage and running
time scale linearly with the KV cache size, meaning that the computational cost is high for long inputs.
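To make the linear scaling concrete, the following sketch estimates the KV cache footprint for a given context length. The function name `kv_cache_bytes` and the configuration values (32 layers, 8 key/value heads, head dimension 128, 2 bytes per element) are illustrative assumptions, not figures reported in this paper.

```python
def kv_cache_bytes(n_tokens: int, n_layers: int, n_kv_heads: int,
                   head_dim: int, bytes_per_elem: int = 2) -> int:
    """Bytes needed to store the keys and values of n_tokens across all layers."""
    # Factor 2 accounts for storing both a key and a value vector per token.
    per_token = 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem
    return n_tokens * per_token


# Hypothetical model configuration, assumed for illustration only.
config = dict(n_layers=32, n_kv_heads=8, head_dim=128, bytes_per_elem=2)

full = kv_cache_bytes(128 * 1024, **config)   # cache for a 128K-token prompt
small = kv_cache_bytes(1024, **config)        # cache after compression to 1024 tokens

print(f"128K tokens: {full / 2**30:.1f} GiB")
print(f"1024 tokens: {small / 2**20:.1f} MiB")
print(f"ratio: {full // small}x")
```

Under these assumed values the cache size is exactly proportional to the number of cached tokens, which is why shrinking the cache from 128K to 1024 entries matters so much for the iterative generation phase.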
To reduce GPU memory usage and running time during the iterative generation phase, H2O
[ZSZ+ 23] and SnapKV [LHY+ 24] introduce static methods to compress/evict the KV cache. These
techniques can shrink the KV cache size from 128K to 1024 with negligible performance loss,
resulting in faster speeds and lower GPU memory consumption during the iterative generation
phase. However, these methods do not improve the efficiency of the prompt computation phase,
which becomes the dominant bottleneck as the input context lengthens. Thus, we ask:
Can we accelerate the speed and reduce memory usage during the prompt computation phase?
We observe that when serving a query, LLMs often find the necessary information in the early
layers, even before generating the answer. Specifically, the relevant tokens can be identified using
the attention matrix from these early layers (Figure 2), which we refer to as filter layers. Figure 1
provides a real example from the Needle in a Haystack task, where LLMs must find a small piece
of information within a large context. For LLaMA 3.1 8B, we observe that the information needed
to answer the query can be distilled from the attention matrix in any of the 13th-19th layers.
Furthermore, LLMs explicitly summarize the required information in these filter layers. As a
consequence, we only need to perform the prompt computation on a long context input for the
filter layers, allowing us to compress the input tokens into a smaller subset (e.g., reducing from
128K tokens to 100), saving both time and GPU memory. We then feed the selected tokens for full
model inference and proceed with a standard generation function. Algorithm 1 in Section 3 presents
our method GemFilter.
[Figure 2 diagram labels: attention matrix QK^⊤; useful information for retrieval; top k selection based on the last row.]
Figure 2: The last row of attention matrices in early layers can locate answer-related tokens.
[Figure 1 diagram: a 108,172-token input (instruction, book text containing the key message, and the question "What is the best thing to do in San Francisco?") passes through text selection on the first few layers, using top k selection based on the last row of the attention matrix; the result is compressed about 1,000 times to 100 tokens, which are then fed to the full LLM.]
Figure 1: Illustration of our method GemFilter: generation with context selection based on early
filter layers. We demonstrate a real Needle in a Haystack task (Section 4.1). The original input
consists of 108,172 tokens, including the initial instruction, key message, and the query. In the
first step, we use the 13th layer of the LLM (LLaMA 3.1 8B Instruct) as a filter to compress the
input tokens by choosing the top k indices from the last row of the attention matrix. Notably, the
selected input retains the initial instruction, key message, and query. GemFilter achieves a 1000×
compression, reducing the input token length to 100. In the second step, we feed the selected tokens
for full LLM inference using a standard generation function, which produces the correct output.
GemFilter significantly reduces running time and GPU memory with negligible performance loss.
[Figure 3 panels: "LLaMA 3.1 8B Instruct running time comparison" (y-axis: running time in seconds) and "LLaMA 3.1 8B Instruct GPU memory comparison"; legend entries include standard prompt time, standard gen time, standard prompt GPU mem, and standard gen GPU mem.]
As shown in Figure 3, GemFilter runs faster and consumes less GPU memory than Snap-
KV/H2O and standard attention (full KV cache) during the prompt computation phase. During the
iterative generation phase, GemFilter has the same running time and GPU memory consumption as
SnapKV/H2O, both of which outperform standard attention. We discuss the complexity further in
Section 3.3 theoretically and in Section 4.4 empirically. GemFilter significantly outperforms stan-
dard attention and SnapKV on the Needle in a Haystack benchmark (Section 4.1). Additionally,
on LongBench, a multi-task benchmark designed to rigorously evaluate long-context understanding
across various datasets, GemFilter achieves performance comparable to SnapKV/H2O (Section 4.2).
Furthermore, our ablation study in Section 4.3 shows that our method is quite robust to the filter
layer selection strategy. Our contributions are as follows:
• We found that LLMs can identify relevant tokens using attention matrices in the early layers,
suggesting crucial information is recognized before the answer generation. Furthermore, LLMs
explicitly summarize this information within specific filter layers. This observation provides
insights into LLM mechanisms and opens avenues for LLM understanding and algorithm design.
• GemFilter significantly outperforms both standard attention (all KV cache) and SnapKV on the
Needle in a Haystack benchmark (Section 4.1), while maintaining performance comparable to
SnapKV/H2O on the LongBench benchmark (Table 1).
• Our approach offers several advantages: it is simple, training-free, and broadly applicable to
various LLMs. Furthermore, it enhances interpretability by allowing humans to directly inspect
the selected token sequence.
2 Related Works
Generation Speed-up with Long Context Input. One effective technique to accelerate auto-
regressive generation is KV cache compression/eviction. During generation, LLMs store the previ-
ous key and value matrices to reduce computational complexity. However, when the input context is
long (e.g., 128K tokens), the memory consumption and running time associated with the KV cache
dominate iterative generation. Many studies have focused on KV cache eviction. For instance,
[GZL+ 23] evict long-range contexts on attention heads to prioritize local contexts, using the KV
cache only for heads that broadly attend to all tokens. Streaming LLM [XTC+ 23] introduces an at-
tention sink that retains only the first few tokens and the latest k tokens in the KV cache to enable
fast streaming generation. LOOK-M [WWL+ 24] applies KV eviction in the multimodal setting so that
the model only needs to look once for the image. LongWriter [BZL+ 24] uses KV eviction to enable
LLMs to generate coherent outputs exceeding 20,000 words. MInference 1.0 [JLZ+ 24] determines
the optimal KV cache pattern for each attention head offline and dynamically builds sparse indices
based on the assigned query during inference. QuickLLaMA [LSJ+ 24] classifies the KV cache into
several subsets, e.g., query tokens, context tokens, global tokens, and local tokens, and only preserves
some types of tokens in the KV cache. ThinK [XJD+ 24] proposes a query-dependent KV cache
pruning method by pruning the least significant channel dimensions of the KV cache. H2O [ZSZ+ 23]
retains only tokens contributing to cumulative attention. SnapKV [LHY+ 24] evicts non-essential
KV positions for each attention head based on observation windows. While the aforementioned
studies focus on eviction and compression of the KV cache during the prompt computation phase
to optimize the iterative generation phase, they do not reduce the running time or GPU memory
usage during the prompt computation phase. In contrast, our method, GemFilter, achieves both
reduced running time and GPU memory usage in the prompt computation phase, as well as during
the iterative generation phase. We provide a more detailed comparison in Section 3.4.
More related to our work, [LDLG23] compress input sequences by pruning redundancy in the
context, making inputs more compact. However, they need to retain 50% of the input tokens to
maintain the LLMs’ performance, whereas GemFilter achieves comparable performance while
retaining only 1% of the input tokens. For further details, we refer the reader to Section 4.1.
3 Method
3.1 Notations and Preliminary
While the Transformer and self-attention architecture [VSP+ 17] have already become overwhelm-
ingly popular, we first introduce certain preliminary definitions to provide a better methodological
connection to our proposed GemFilter method in Section 3.2.
For any positive integer n, we use [n] to denote the set {1, 2, · · · , n}. We use ◦ to denote function
composition and ⊙ to denote the Hadamard product. Let n be the input token/prompt length,
d the hidden feature dimension, and V the vocabulary set. We now introduce the key concept of
attention and transformers. We first define the query, key, and value matrices. It is important to
note that during text generation, the key and value matrices are also referred to as the KV cache,
as they are stored in GPU memory to reduce running time during the iterative prediction of the
next token.
Definition 3.1 (Single layer self-attention). Let $Q \in \mathbb{R}^{n \times d}$ be the query matrix, $K \in \mathbb{R}^{n \times d}$ the
key cache, and $V \in \mathbb{R}^{n \times d}$ the value cache. Let $M_c \in \{0, 1\}^{n \times n}$ be the causal attention mask, where
$(M_c)_{i,j}$ is 1 if $i \ge j$ and 0 otherwise. The self-attention function Attn is defined as:
$$\mathrm{Attn}(Q, K, V) = M_c \odot \mathrm{Softmax}(QK^\top / \sqrt{d}) \cdot V$$
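As a minimal sketch, Definition 3.1 can be written out in plain Python as follows. The code follows the definition literally, applying the causal mask after the softmax, so the masked weights of a row need not sum to one; practical implementations typically mask the scores before the softmax instead. The helper names `softmax` and `attn` are illustrative choices.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one row of scores."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attn(Q: Matrix, K: Matrix, V: Matrix) -> Matrix:
    """Single-head attention as in Definition 3.1: (M_c ⊙ Softmax(QK^T / sqrt(d))) V."""
    n, d = len(Q), len(Q[0])
    out = []
    for i in range(n):
        # Row i of QK^T / sqrt(d).
        scores = [sum(Q[i][c] * K[j][c] for c in range(d)) / math.sqrt(d)
                  for j in range(n)]
        probs = softmax(scores)
        # Causal mask: position i may only attend to positions j <= i.
        masked = [p if i >= j else 0.0 for j, p in enumerate(probs)]
        out.append([sum(masked[j] * V[j][c] for j in range(n))
                    for c in range(len(V[0]))])
    return out


# Tiny example: two tokens with zero queries/keys, so every softmax row is uniform.
example = attn([[0.0, 0.0], [0.0, 0.0]],
               [[0.0, 0.0], [0.0, 0.0]],
               [[2.0, 0.0], [0.0, 4.0]])
print(example)
```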
Definition 3.2 (Multi-layer transformer). Let $T \in \mathcal{V}^n$ represent the input tokens, and let m denote
the number of transformer layers. Let $g_i$ represent components in the i-th transformer layer other
than self-attention, such as layer normalization, residual connections, and the MLP block, where
$g_i : \mathbb{R}^{n \times d} \to \mathbb{R}^{n \times d}$ for any $i \in \{0, 1, \ldots, m\}$. Let $\mathrm{Attn}_i$ denote the self-attention module in the i-th
transformer layer. We define an m-layer transformer $F_{1:m} : \mathcal{V}^n \to \mathbb{R}^{n \times d}$ as
$$F_{1:m}(T) := g_m \circ \mathrm{Attn}_m \circ g_{m-1} \circ \cdots \circ g_1 \circ \mathrm{Attn}_1 \circ g_0 \circ E(T),$$
where E is the input embedding function mapping the input tokens to hidden features using the
vocabulary dictionary, i.e., $E(T) \in \mathbb{R}^{n \times d}$.
Note that the above definitions use a single attention head for simplicity, but in practice, multi-
head attention is used [VSP+ 17].
Algorithm 1 GemFilter: Generation with Token Selection Based on Early Layers
1: procedure SelectionGen($F_{1:m}$, $T \in [\mathcal{V}]^n$, $r \in [m]$, $k \in [n]$)
2:     ▷ $F_{1:m}$: an m-layer transformer network; T: input sequence of tokens
3:     ▷ r: filter layer index for token selection; k: number of selected tokens
4:     Get $Q^{(r)}, K^{(r)}$ by doing an r-layer forward pass: $F_{1:r}(T)$
5:     ▷ $Q^{(r)}, K^{(r)} \in \mathbb{R}^{n \times d}$: the r-th layer query, key
6:     $J \leftarrow \mathrm{topk\_index}(Q_n^{(r)} K^{(r)\top}, k)$    ▷ $Q_n^{(r)}$: the last row of $Q^{(r)}$; $Q_n^{(r)} K^{(r)\top} \in \mathbb{R}^n$ are attn scores
7:     Sort the indices in J    ▷ $J \subseteq [n]$ and $|J| = k$
8:     return Gen($F_{1:m}$, $T_J$)    ▷ Gen is the generation function; $T_J \in [\mathcal{V}]^k$ is a sub-sequence of T on J
9: end procedure
The input of the algorithm is an m-layer transformer $F_{1:m}$ (Definition 3.2), an input token sequence
$T \in \mathcal{V}^n$, and two hyperparameters r ≤ m, k ≤ n, where r represents the index of the filter layer for
context token selection and k denotes the number of tokens to select. For example, in the case of
LLaMA 3.1 8B Instruct (Figure 1), we have m = 32, r = 13, and k = 1024.
In the first step (Line 4), we run only the first r layers forward to serve as a filter, obtaining the
r-th layer’s query and key matrices, $Q^{(r)}$ and $K^{(r)}$. Note that we do not need to run all layers of
the LLM on a long context input, thereby saving both computation time and memory (see detailed
analysis in Section 3.3). In Line 6, we select token indices based on the r-th layer attention matrix.
The selection is made by identifying the k largest values from the last row of the attention matrix,
i.e., the inner product between the last query token $Q_n^{(r)}$ and all key tokens $K^{(r)}$. For multi-head
attention, the top-k indices are selected based on the summation of the last row across the attention
matrices of all heads. For instance, suppose we have h attention heads, and let $Q^{(r,j)}, K^{(r,j)} \in \mathbb{R}^{n \times d}$
represent the query and key matrices for the r-th layer and j-th attention head. Then, we compute
$J \leftarrow \mathrm{topk\_index}(\sum_{j=1}^{h} Q_n^{(r,j)} K^{(r,j)\top}, k)$, where J is the set of the k selected indices. Note that our
method uses a single index set J, whereas SnapKV [LHY+ 24] and H2O [ZSZ+ 23] use different
index sets for each layer and attention head, resulting in m · h index sets in total. A detailed
discussion is provided in Section 3.4.
In Line 6, J is sorted by inner product values. However, we need to re-sort J so that the selected
tokens follow their original input order, ensuring, for example, that the ⟨bos⟩ token is placed at the
beginning. Line 7 performs this reordering operation. Finally, in Line 8, we can run any language
generation function using the selected tokens TJ , which is a sub-sequence of T on the index set J,
across all layers. This generation is efficient as the input context length is reduced from n to k,
e.g., from 128K to 1024 tokens in Figure 1. Below, we provide a formal time complexity analysis.
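The selection step (Lines 6 and 7 of Algorithm 1) can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the released implementation: the per-head last-row queries and the key matrices are passed in as nested lists, scores are raw inner products summed over heads as described above, and the names `select_token_indices` and `compress_tokens` are hypothetical.

```python
from typing import List, Sequence

Vector = Sequence[float]


def dot(u: Vector, v: Vector) -> float:
    return sum(a * b for a, b in zip(u, v))


def select_token_indices(last_queries: Sequence[Vector],
                         keys: Sequence[Sequence[Vector]],
                         k: int) -> List[int]:
    """Top-k token indices from the last attention row, summed over heads.

    last_queries[j] is the last query row of head j (length d);
    keys[j] is the list of n key vectors of head j.
    """
    n = len(keys[0])
    scores = [0.0] * n
    for q_last, head_keys in zip(last_queries, keys):
        for i, key in enumerate(head_keys):
            scores[i] += dot(q_last, key)
    # Line 6: indices of the k largest scores (ordered by score).
    top_k = sorted(range(n), key=lambda i: scores[i], reverse=True)[:k]
    # Line 7: restore the original input order.
    return sorted(top_k)


def compress_tokens(tokens: Sequence[str], indices: Sequence[int]) -> List[str]:
    """Build the sub-sequence T_J that is passed to the second, full-model run."""
    return [tokens[i] for i in indices]


# Toy example with two heads, five tokens, and d = 2.
tokens = ["<bos>", "a", "needle", "b", "query"]
head_keys = [[5.0, 0.0], [1.0, 0.0], [0.0, 9.0], [2.0, 0.0], [4.5, 0.0]]
last_queries = [[1.0, 0.0], [0.0, 1.0]]
keys = [head_keys, head_keys]

J = select_token_indices(last_queries, keys, k=3)
print(J, compress_tokens(tokens, J))
```

Because a single index set J is produced, the selected tokens can be decoded and printed for inspection before the second run.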
Theorem 3.3 (Complexity analysis). Let n be the input sequence (prompt) length and d the hidden
feature dimensions. In our Algorithm 1, GemFilter uses the r-th layer as a filter to select k input
tokens. Let SnapKV and H2O also use k as their cache size. Assume the LLM has m attention
layers, each with h attention heads, and each transformer layer’s parameters consume w GPU mem-
ory. Assuming that we generate t tokens with the Gen function and n ≥ max{d, k, t}, the following
table summarizes the complexity for standard attention, SnapKV and H2O, and GemFilter:
Complexity                      Standard attention    SnapKV and H2O         GemFilter
Time       Prompt comp.         Θ(mhn²d)              Θ(mhn²d)               Θ(rhn²d)
Time       Iter. generation     Θ(mh(nt + t²)d)       Θ(mh(kt + t²)d)        Θ(mh(k² + t²)d)
GPU mem.   Prompt comp.         mw + 2mhnd            mw + 2hnd + 2mhkd      rw + 2hnd
GPU mem.   Iter. generation     mw + 2mh(n + t)d      mw + 2mh(k + t)d       mw + 2mh(k + t)d
Recall that there are two phases in text generation. The first phase is prompt computation,
which involves attention computation on the long context input tokens and generating the KV
cache. The second phase is iterative generation, where auto-regressive generation occurs based on
the pre-computed KV cache. Theorem 3.3 demonstrates that GemFilter is faster and consumes less
GPU memory than SnapKV/H2O and standard attention during the prompt computation phase.
Additionally, during the iterative generation phase, GemFilter has the same running time and GPU
memory consumption as SnapKV/H2O, which is significantly better than standard attention. This
conclusion aligns with our experimental results in Section 4.4.
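Case Study. Let us consider the case n ≫ k ≈ t, e.g., n = 128K, k = t = 1024 and r < m.
During the prompt computation phase, the running times of standard attention, SnapKV/H2O,
and GemFilter given in Theorem 3.3 stand in the ratio
$$\Theta(mhn^2 d) : \Theta(mhn^2 d) : \Theta(rhn^2 d) = m : m : r.$$
We see that GemFilter has a lower time complexity and less GPU memory consumption than
standard attention, SnapKV, and H2O. During the iterative generation phase, the running
times stand in the ratio
$$\Theta(mh(nt + t^2)d) : \Theta(mh(kt + t^2)d) : \Theta(mh(k^2 + t^2)d) = (n + t) : (k + t) : (k^2/t + t),$$
where the last two terms coincide when k = t.
As such, GemFilter has the same time complexity and GPU memory consumption as SnapKV/H2O,
while significantly outperforming the standard attention.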
The running time bottleneck for all methods occurs during prompt computation, which takes
Θ(mhn2 d) for standard attention, SnapKV, and H2O. In contrast, GemFilter only requires Θ(rhn2 d)
for prompt computation, as it only processes the early layers of the LLMs to select and compress
the input tokens during the first run. See detailed proof in Appendix B.
Note that the GPU memory bottleneck for standard attention occurs during iterative generation,
while for other methods, the memory bottleneck arises during prompt computation due to the
reduced KV cache. GemFilter consumes less GPU memory than SnapKV and H2O because it only
needs to load the weights of the first r layers when processing the long context input in its first run.
Our empirical results in Section 4.4 support our complexity analysis findings.
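The expressions in Theorem 3.3 can also be evaluated numerically. The sketch below plugs in the case-study values n = 128K and k = t = 1024 with m = 32 and r = 13; the number of heads h, the head dimension d, and the per-layer weight size w are assumed values chosen only for illustration, and constants hidden by the Θ notation are ignored.

```python
def complexity_table(m, r, h, n, d, k, t, w):
    """Evaluate the leading terms of Theorem 3.3 (constants dropped)."""
    return {
        "standard": {
            "prompt_time": m * h * n * n * d,
            "gen_time": m * h * (n * t + t * t) * d,
            "prompt_mem": m * w + 2 * m * h * n * d,
            "gen_mem": m * w + 2 * m * h * (n + t) * d,
        },
        "snapkv_h2o": {
            "prompt_time": m * h * n * n * d,
            "gen_time": m * h * (k * t + t * t) * d,
            "prompt_mem": m * w + 2 * h * n * d + 2 * m * h * k * d,
            "gen_mem": m * w + 2 * m * h * (k + t) * d,
        },
        "gemfilter": {
            "prompt_time": r * h * n * n * d,
            "gen_time": m * h * (k * k + t * t) * d,
            "prompt_mem": r * w + 2 * h * n * d,
            "gen_mem": m * w + 2 * m * h * (k + t) * d,
        },
    }


# m, r, n, k, t follow the text; h, d, w are assumptions for illustration.
params = dict(m=32, r=13, h=32, n=128 * 1024, d=128, k=1024, t=1024, w=2 * 10**8)
table = complexity_table(**params)

prompt_ratio = table["snapkv_h2o"]["prompt_time"] / table["gemfilter"]["prompt_time"]
mem_ratio = table["gemfilter"]["prompt_mem"] / table["snapkv_h2o"]["prompt_mem"]
print(f"prompt time, SnapKV/H2O relative to GemFilter: {prompt_ratio:.2f}")
print(f"prompt memory, GemFilter relative to SnapKV/H2O: {mem_ratio:.2f}")
```

The prompt-phase time ratio reduces to m / r regardless of the assumed h and d, while the memory ratio depends on the assumed w.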
3.4 Comparison with Other Methods
GemFilter reduces both running time and GPU memory usage in both the prompt computation
and iterative generation phases, whereas SnapKV [LHY+ 24] and H2O [ZSZ+ 23] focus only on the
iterative generation phase. During the prompt computation phase, standard attention computes
and stores the entire KV cache for all layers in GPU memory, which is used during the generation
phase. SnapKV and H2O, on the other hand, compute the entire KV cache for all layers but
only store a portion of it in GPU memory (e.g., k = 1024). They use the selected KV cache
for memory-efficient generation. SnapKV selects important clustered positions of the KV cache
from an ‘observation’ window located at the end of the prompt, while H2O greedily drops tokens
based on cumulative attention scores to retain only a small portion of the KV cache. In contrast,
GemFilter avoids computing the KV cache for all layers during the prompt computation phase.
Compared to SnapKV and H2O, there are two additional differences. First, SnapKV and H2O
maintain separate index sets for each layer and attention head, resulting in m · h index sets in total.
This leads to different behaviors across attention heads, making their intermediate mechanisms
more difficult to interpret. On the other hand, GemFilter uses a single index set, J, allowing for
easier interpretability by enabling the printing of the selected sequence for human review before the
second run (see a real example in Figure 1). Another distinction lies in how positional embeddings
are handled. In SnapKV and H2O, the maximum positional embedding distance is n + t, as the
same positional embedding is used in both the prompt computation and iterative generation phases.
However, in GemFilter’s second run, the maximum positional embedding distance is reduced to k+t
because the input token length is reduced from n to k, and the RoPE function is re-computed. This
reduction makes GemFilter more efficient, as the model can better handle shorter input sequences,
as demonstrated in Figure 4 (a).
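The difference in positional range can be illustrated with a small sketch. It assumes that the second run simply renumbers the k selected tokens consecutively from 0, which is what re-computing the positional embedding on the shorter input amounts to; the function names are hypothetical.

```python
from typing import List, Sequence


def position_span(prompt_len: int, generated: int) -> int:
    """Number of distinct positions touched: prompt tokens plus generated tokens."""
    return prompt_len + generated


def second_run_positions(selected_indices: Sequence[int]) -> List[int]:
    """Positions used in the second run: kept tokens are renumbered 0..k-1."""
    return list(range(len(selected_indices)))


n, k, t = 108_172, 3, 50            # toy values; k is tiny here for readability
J = [0, 40_000, 108_171]            # original indices of the kept tokens

print(second_run_positions(J))      # [0, 1, 2]
print(position_span(n, t))          # span when original positions are kept
print(position_span(k, t))          # span in GemFilter's second run
```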
4 Experiments
Model and Datasets. We evaluated our approach using three popular long-context models:
LLaMA 3.1 8B Instruct [DJP+ 24], Mistral Nemo 12B Instruct [JSM+ 23], and Phi 3.5 Mini 3.8B
Instruct [AJA+ 24], all of which support an input token length of 128K. We compared our method,
GemFilter, against standard attention and two state-of-the-art methods, SnapKV [LHY+ 24] and
H2O [ZSZ+ 23]. For our experiments, we used two popular datasets: Needle in a Haystack [Kam24]
(Section 4.1) and LongBench [BLZ+ 23] (Section 4.2). More implementation details are provided in
Appendix C.2.
Filter Layer. Except in Section 4.3, we always use layer 13 of 32, layer 19 of 40, and layer 19 of 32
as the input filter for context selection with LLaMA 3.1, Mistral Nemo, and Phi 3.5,
respectively. In Section 4.3, we provide an ablation study for the filter layer choice.
[Figure 4(a) panels: "Pressure Testing Mistral Nemo 12B Instruct All KV" (left) and "Pressure Testing LLaMA 3.1 8B Instruct All KV" (right), Fact Retrieval Across Context Lengths ("Needle In A HayStack"). x-axis: Token Limit; y-axis: Depth Percent (0 to 100); color scale: Score (0 to 1).]
(a) All KV. Mistral Nemo average score: 0.486; LLaMA 3.1 average score: 0.841.
[Figure 4 panels: "Pressure Testing Mistral Nemo 12B Instruct SnapKV-1024" (left) and "Pressure Testing LLaMA 3.1 8B Instruct SnapKV-1024" (right), Fact Retrieval Across Context Lengths ("Needle In A HayStack"). x-axis: Token Limit; y-axis: Depth Percent (0 to 100); color scale: Score (0 to 1).]
[Figure 4 panels: "Pressure Testing Mistral Nemo 12B Instruct GemFilter-1024 (layer-19)" (left) and "Pressure Testing LLaMA 3.1 8B Instruct GemFilter-1024 (layer-13)" (right), Fact Retrieval Across Context Lengths ("Needle In A HayStack"). x-axis: Token Limit; y-axis: Depth Percent (0 to 100); color scale: Score (0 to 1).]
Figure 4: Needle in a Haystack performance comparison of different methods using the Mistral
Nemo 12B Instruct model (left column) and the LLaMA 3.1 8B Instruct model (right column).
Results for the Phi 3.5 Mini 3.8B Instruct model are provided in Appendix C.3. The x-axis
represents the length of the input tokens, while the y-axis shows the position depth percentage of the
‘needle’ information (e.g., 0% indicates the beginning, and 100% indicates the end). A higher score
reflects better performance, meaning more effective retrieval of the ‘needle’ information. GemFilter
significantly outperforms both standard attention (full KV cache) and SnapKV.
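4.1 Needle in a Haystack
The Needle in a Haystack benchmark [Kam24] tests whether an LLM can retrieve the information in
a specific sentence (the ‘needle’) hidden within a long document
(the ‘haystack’), where the sentence can appear at any arbitrary location. The difficulty increases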
as the length of the haystack grows. We use input lengths of 60K for Mistral Nemo 12B Instruct
and 120K for LLaMA 3.1 8B Instruct, as these are the maximum lengths for standard attention
on two A100-40GB GPUs. The KV cache size is set to 1024 for both SnapKV and GemFilter. In
Figure 4, we see that GemFilter significantly outperforms both All KV (standard attention) and
SnapKV with Mistral Nemo and LLaMA 3.1. The Needle in a Haystack results suggest that our
method, GemFilter, achieves superior retrieval performance for long input contexts compared to
SnapKV and standard attention. Additional results are provided in Appendix C.3. (H2O is excluded
from this comparison, following [LHY+ 24, XJD+ 24]: it cannot be implemented with FlashAttention
due to its cumulative attention score strategy and is therefore unable to handle super long input
contexts.)
Table 1: Performance comparison on LongBench across various LLMs and methods. A larger
number means better performance. The best score is boldfaced.
4.2 LongBench
LongBench [BLZ+ 23] is a multi-task benchmark designed to rigorously evaluate long-context un-
derstanding capabilities across various datasets, including single- and multi-document Question
Answering (QA), summarization, few-shot learning, and synthetic tasks. We evaluate on the
English-only dataset, following [LHY+ 24, XJD+ 24].
For each LLM, we evaluate GemFilter and SnapKV with selected tokens/KV caches of 1024,
2048, and 4096. We also evaluated standard attention (all KV cache) and H2O with a KV cache size
of 4096 on the LongBench dataset to further demonstrate the performance of GemFilter, follow-
ing [LHY+ 24]. Table 1 shows a negligible performance drop in LLMs using GemFilter compared to
standard attention, even with only 1024 selected tokens. In some cases, GemFilter even outperforms
standard attention, such as GemFilter-2048 for Mistral Nemo 12B Instruct. It demonstrates sig-
nificantly better performance than H2O and comparable performance with SnapKV. Furthermore,
GemFilter effectively filters key information in long contexts, provides interpretable summaries,
and compresses the input context effectively, e.g., it reduces input tokens to an average of 8% when
using 1024 tokens, and 32% when using 4096, with negligible accuracy drops.
[Figure 5 panels: (a) LLaMA 3.1 8B Instruct, input of 108172 tokens; (b) Mistral Nemo 12B Instruct, input of 55989 tokens; (c) Phi 3.5 Mini 3.8B Instruct, input of 122647 tokens. Each panel plots the distance between the top 1024 nearest neighbors and the needle position (y-axis: token distance) against the layer index (x-axis).]
Figure 5: Distance between the needle position and selected token index position across three
LLMs. The position depth percentage of the “needle” information is 50%. The x-axis is the
layer index of each LLM. The y-axis is min(topk index − needle index). When y = 0, the
needle information is covered by the selected tokens. The needle information has been
successfully discovered in the early layers of all three LLMs.
4.3 Filter Layer Choice
Table 2: Performance of our method on LongBench using different layers as an input filter. A
larger number means better performance. The best score is boldfaced.
We then use the first layer that accurately identifies the needle’s position as the input filter.
In our experiments, we find that this layer remains consistent across different inputs. As shown in
Table 2, performance first increases and then decreases as we select the input filter layer from the
beginning to the end. The peak performance is observed at the 13th layer, which supports our layer
selection strategy. Performance remains robust between layers 13 and 25, providing flexibility in
layer selection. Exploring the distinct functions of different layers presents an interesting direction
for future research.
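The layer-selection rule described above can be sketched as follows: for each candidate layer, measure how far the selected indices are from the needle, and pick the first layer whose selection covers it. This is an illustrative sketch; it assumes the per-layer top-k indices have already been computed, uses an absolute token distance, and indexes layers from 0, and the function names are hypothetical.

```python
from typing import List, Optional, Sequence


def needle_distance(selected: Sequence[int], needle_positions: Sequence[int]) -> int:
    """Smallest absolute distance between any selected index and any needle token."""
    return min(abs(i - p) for i in selected for p in needle_positions)


def first_covering_layer(selected_per_layer: Sequence[Sequence[int]],
                         needle_positions: Sequence[int]) -> Optional[int]:
    """Index of the first layer whose selected tokens include a needle token."""
    for layer, selected in enumerate(selected_per_layer):
        if needle_distance(selected, needle_positions) == 0:
            return layer
    return None  # no layer located the needle


# Toy data: top-3 indices chosen by four layers; the needle occupies tokens 50 and 51.
selected_per_layer = [[0, 10, 20], [0, 48, 90], [0, 50, 51], [0, 50, 99]]
needle_positions = [50, 51]

distances = [needle_distance(s, needle_positions) for s in selected_per_layer]
print(distances)
print(first_covering_layer(selected_per_layer, needle_positions))
```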
[Figure 6 panels: (a) Mistral Nemo 12B Instruct; "Phi 3.5 Mini 3.8B Instruct running time comparison" and "Phi 3.5 Mini 3.8B Instruct GPU memory comparison". x-axis: input token number (8192 to 131072); y-axis: running time in seconds or GPU memory; legend entries include standard prompt time, standard gen time, standard prompt GPU mem, and standard gen GPU mem.]
Figure 6: Comparison of time and GPU memory usage across different methods on Mistral Nemo
12B Instruct and Phi 3.5 Mini 3.8B Instruct. GemFilter uses the 19th layer as an input filter for
both LLMs. It achieves a 2.4× speedup and reduces GPU memory usage by 30% compared to
SnapKV.
7 We exclude H2O as it does not support FlashAttention and thus requires more GPU memory and running time
than standard attention during prompt computation.
5 Conclusion
In this work, we presented a novel approach, GemFilter, to accelerate LLM inference and reduce
memory consumption for long context inputs. By leveraging the ability of early LLM layers to
identify relevant information, GemFilter achieves significant improvements over existing techniques.
It demonstrates a 2.4× speedup and 30% reduction in GPU memory usage compared to SOTA
methods, while also showing superior performance on the Needle in a Haystack benchmark. Our
approach is simple, training-free, applicable to various LLMs, and offers enhanced interpretability
because the selected tokens can be inspected directly. These results not only provide practical benefits
for LLM deployment, but also deepen our understanding of LLM internal mechanisms.
References
[ABW+ 23] Rohan Anil, Sebastian Borgeaud, Yonghui Wu, Jean-Baptiste Alayrac, Jiahui Yu, Radu
Soricut, Johan Schalkwyk, Andrew M Dai, Anja Hauth, et al. Gemini: a family of
highly capable multimodal models. arXiv preprint arXiv:2312.11805, 2023.
[AJA+ 24] Marah Abdin, Sam Ade Jacobs, Ammar Ahmad Awan, Jyoti Aneja, Ahmed Awadal-
lah, Hany Awadalla, Nguyen Bach, Amit Bahree, Arash Bakhtiari, Harkirat Behl, et al.
Phi-3 technical report: A highly capable language model locally on your phone. arXiv
preprint arXiv:2404.14219, 2024.
[Ant24] Anthropic. The claude 3 model family: Opus, sonnet, haiku.
https://www-cdn.anthropic.com, 2024.
[BCE+ 23] Sébastien Bubeck, Varun Chandrasekaran, Ronen Eldan, Johannes Gehrke, Eric
Horvitz, Ece Kamar, Peter Lee, Yin Tat Lee, Yuanzhi Li, Scott Lundberg, et al.
Sparks of artificial general intelligence: Early experiments with gpt-4. arXiv preprint
arXiv:2303.12712, 2023.
[BLZ+ 23] Yushi Bai, Xin Lv, Jiajie Zhang, Hongchang Lyu, Jiankai Tang, Zhidian Huang,
Zhengxiao Du, Xiao Liu, Aohan Zeng, Lei Hou, et al. Longbench: A bilingual, mul-
titask benchmark for long context understanding. arXiv preprint arXiv:2308.14508,
2023.
[BZL+ 24] Yushi Bai, Jiajie Zhang, Xin Lv, Linzhi Zheng, Siqi Zhu, Lei Hou, Yuxiao Dong, Jie
Tang, and Juanzi Li. Longwriter: Unleashing 10,000+ word generation from long
context llms. arXiv preprint arXiv:2408.07055, 2024.
[Dao23] Tri Dao. Flashattention-2: Faster attention with better parallelism and work parti-
tioning. arXiv preprint arXiv:2307.08691, 2023.
[DFE+ 22] Tri Dao, Dan Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. Flashattention:
Fast and memory-efficient exact attention with io-awareness. Advances in Neural In-
formation Processing Systems, 35:16344–16359, 2022.
[DJP+ 24] Abhimanyu Dubey, Abhinav Jauhri, Abhinav Pandey, Abhishek Kadian, Ahmad Al-
Dahle, Aiesha Letman, Akhil Mathur, Alan Schelten, Amy Yang, Angela Fan, et al.
The llama 3 herd of models. arXiv preprint arXiv:2407.21783, 2024.
[GZL+ 23] Suyu Ge, Yunan Zhang, Liyuan Liu, Minjia Zhang, Jiawei Han, and Jianfeng Gao.
Model tells you what to discard: Adaptive kv cache compression for llms. arXiv
preprint arXiv:2310.01801, 2023.
[JLZ+ 24] Huiqiang Jiang, Yucheng Li, Chengruidong Zhang, Qianhui Wu, Xufang Luo, Surin
Ahn, Zhenhua Han, Amir H Abdi, Dongsheng Li, Chin-Yew Lin, et al. Minference
1.0: Accelerating pre-filling for long-context llms via dynamic sparse attention. arXiv
preprint arXiv:2407.02490, 2024.
[JMC24] Ziyan Jiang, Xueguang Ma, and Wenhu Chen. Longrag: Enhancing retrieval-
augmented generation with long-context llms. arXiv preprint arXiv:2406.15319, 2024.
[JSM+ 23] Albert Q. Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Deven-
dra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume
Lample, Lucile Saulnier, Lélio Renard Lavaud, Marie-Anne Lachaux, Pierre Stock,
Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, and William El
Sayed. Mistral 7b, 2023.
[LDLG23] Yucheng Li, Bo Dong, Chenghua Lin, and Frank Guerin. Compressing context to
enhance inference efficiency of large language models. arXiv preprint arXiv:2310.06201,
2023.
[LHY+ 24] Yuhong Li, Yingbing Huang, Bowen Yang, Bharat Venkitesh, Acyr Locatelli, Hanchen
Ye, Tianle Cai, Patrick Lewis, and Deming Chen. Snapkv: Llm knows what you are
looking for before generation. arXiv preprint arXiv:2404.14469, 2024.
[LSJ+ 24] Jingyao Li, Han Shi, Xin Jiang, Zhenguo Li, Hong Xu, and Jiaya Jia. Quickl-
lama: Query-aware inference acceleration for large language models. arXiv preprint
arXiv:2406.07528, 2024.
[SAL+ 24] Jianlin Su, Murtadha Ahmed, Yu Lu, Shengfeng Pan, Wen Bo, and Yunfeng Liu.
Roformer: Enhanced transformer with rotary position embedding. Neurocomputing,
568:127063, 2024.
[SBZ+ 24] Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, and Tri
Dao. Flashattention-3: Fast and accurate attention with asynchrony and low-precision.
arXiv preprint arXiv:2407.08608, 2024.
[SZK+ 22] John Schulman, Barret Zoph, Christina Kim, Jacob Hilton, Jacob Menick, Jiayi Weng,
Juan Felipe Ceron Uribe, Liam Fedus, Luke Metz, Michael Pokorny, et al. Chatgpt:
Optimizing language models for dialogue. OpenAI blog, 2(4), 2022.
[VSP+ 17] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N
Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in
neural information processing systems, 30, 2017.
[WTB+ 22] Jason Wei, Yi Tay, Rishi Bommasani, Colin Raffel, Barret Zoph, Sebastian Borgeaud,
Dani Yogatama, Maarten Bosma, Denny Zhou, Donald Metzler, et al. Emergent abil-
ities of large language models. arXiv preprint arXiv:2206.07682, 2022.
[WWL+ 24] Zhongwei Wan, Ziang Wu, Che Liu, Jinfa Huang, Zhihong Zhu, Peng Jin, Longyue
Wang, and Li Yuan. Look-m: Look-once optimization in kv cache for efficient multi-
modal long-context inference. arXiv preprint arXiv:2406.18139, 2024.
[XJD+ 24] Yuhui Xu, Zhanming Jie, Hanze Dong, Lei Wang, Xudong Lu, Aojun Zhou, Amrita
Saha, Caiming Xiong, and Doyen Sahoo. Think: Thinner key cache by query-driven
pruning. arXiv preprint arXiv:2407.21018, 2024.
[XPW+ 24] Peng Xu, Wei Ping, Xianchao Wu, Lawrence McAfee, Chen Zhu, Zihan Liu, Sandeep
Subramanian, Evelina Bakhturina, Mohammad Shoeybi, and Bryan Catanzaro. Re-
trieval meets long context large language models, 2024.
[XTC+ 23] Guangxuan Xiao, Yuandong Tian, Beidi Chen, Song Han, and Mike Lewis. Efficient
streaming language models with attention sinks. arXiv preprint arXiv:2309.17453,
2023.
[ZSZ+ 23] Zhenyu Zhang, Ying Sheng, Tianyi Zhou, Tianlong Chen, Lianmin Zheng, Ruisi Cai,
Zhao Song, Yuandong Tian, Christopher Ré, Clark Barrett, et al. H2o: Heavy-hitter
oracle for efficient generative inference of large language models. Advances in Neural
Information Processing Systems, 36, 2023.
Appendix
A More Preliminary
In this section, we introduce some key definitions of language modeling modules. We begin with
the input embedding function and the output embedding function. They are functions that bridge
between the input token space and the real vector space.
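Definition A.1 (Input embedding function and input tokens). The input embedding function
$E : \mathcal{V}^n \to \mathbb{R}^{n \times d}$ maps the input tokens to hidden features using the vocabulary dictionary
$D^{\mathrm{voc}} \in \mathbb{R}^{|\mathcal{V}| \times d}$. Let $T \in \mathcal{V}^n$ be input tokens. Then, we have $E(T) \in \mathbb{R}^{n \times d}$ and
$E(T)_i = D^{\mathrm{voc}}_{T_i} \in \mathbb{R}^d$ for any $i \in [n]$.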
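As a small illustration of Definition A.1, the embedding function is a row lookup into the vocabulary dictionary. The sketch below uses plain Python lists; the name `embed` and the toy dictionary values are hypothetical.

```python
from typing import List


def embed(tokens: List[int], d_voc: List[List[float]]) -> List[List[float]]:
    """Row i of the result is the dictionary row indexed by token T_i."""
    return [list(d_voc[token]) for token in tokens]


# Toy vocabulary dictionary with |V| = 4 entries and d = 2 features (made-up values).
d_voc = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]
hidden = embed([2, 0, 2], d_voc)  # n = 3 tokens, so hidden is 3 x 2
```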
Definition A.2 (Output embedding function). The output embedding function $G : \mathbb{R}^d \to \mathbb{R}^{|\mathcal{V}|}$
maps hidden features to the probability logits of the vocabulary dictionary.
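We introduce Softmax, which normalizes a vector of scores $z \in \mathbb{R}^n$ into a probability distribution,
so that self-attention learns a distribution rather than an arbitrary function.
$\mathrm{Softmax}(z) := \exp(z) / \langle \exp(z), \mathbf{1}_n \rangle$.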
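A minimal sketch of the Softmax definition in plain Python follows. Subtracting the maximum before exponentiating is a standard numerical-stability step that we add here; it does not change the result, since the shift cancels between the numerator and the denominator.

```python
import math
from typing import List


def softmax(z: List[float]) -> List[float]:
    """exp(z) divided by the sum of its entries, i.e. <exp(z), 1_n>."""
    shift = max(z)  # stability shift; cancels in the ratio
    exps = [math.exp(value - shift) for value in z]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([1.0, 2.0, 3.0])
```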
During iterative generation, it takes $mw$ GPU memory consumption for the model weights and
$2mh(n + t)d$ GPU memory consumption for the KV cache.
Proof of SnapKV and H2O:
During prompting computation, it takes $\Theta(mhn^2 d)$ time complexity, which is the same as
standard attention.
During iterative generation, it takes $\Theta(mh(kt + t^2)d)$ time complexity, as it reduces the KV
cache size from $n$ to $k$.
During prompting computation, $mw$ GPU memory is consumed for the model weights, $2hnd$
for the selection of the key-value matrix for each layer, and $2mhkd$ for the selected KV cache.
During iterative generation, $mw$ GPU memory is consumed for the model weights and $2mh(k + t)d$
GPU memory is consumed for the KV cache.
Proof of our Algorithm 1 GemFilter:
During prompting computation, GemFilter takes $\Theta(rhn^2 d)$ time complexity, which is faster
than other methods.
During iterative generation, it takes $\Theta(mh(k^2 + kt + t^2)d) = \Theta(mh(k^2 + t^2)d)$ time complexity,
as it reduces the KV cache size from $n$ to $k$.
During prompting computation, $rw + 2hnd$ GPU memory is consumed for the model weights
and the selection of the key-value matrix for each layer.
During iterative generation, $mw + 2mh(k + t)d$ GPU memory is consumed for the KV cache
and model weights.
Thus, we finish the proof.
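To make the comparison concrete, the leading terms above can be evaluated for a sample configuration. This is a rough sketch under several assumptions. We read $m$ as the number of layers, $h$ as the number of heads, $d$ as the per-head dimension, $r$ as the filter layer, and $w$ as the weight memory of one layer, as the formulas suggest. Constants hidden by $\Theta$ are ignored, and the numeric values are hypothetical rather than measurements.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass(frozen=True)
class Setting:
    m: int    # number of transformer layers
    h: int    # attention heads per layer
    d: int    # per-head hidden dimension
    n: int    # prompt length in tokens
    k: int    # tokens kept after selection
    t: int    # tokens generated
    r: int    # filter layer used by GemFilter (r <= m)
    w: float  # assumed weight memory of a single layer


def snapkv_terms(s: Setting) -> Dict[str, float]:
    """Leading terms stated in the proof for SnapKV and H2O."""
    return {
        "prompt_time": s.m * s.h * s.n**2 * s.d,
        "gen_time": s.m * s.h * (s.k * s.t + s.t**2) * s.d,
        "prompt_mem": s.m * s.w + 2 * s.h * s.n * s.d + 2 * s.m * s.h * s.k * s.d,
        "gen_mem": s.m * s.w + 2 * s.m * s.h * (s.k + s.t) * s.d,
    }


def gemfilter_terms(s: Setting) -> Dict[str, float]:
    """Leading terms stated in the proof for GemFilter."""
    return {
        "prompt_time": s.r * s.h * s.n**2 * s.d,
        "gen_time": s.m * s.h * (s.k**2 + s.k * s.t + s.t**2) * s.d,
        "prompt_mem": s.r * s.w + 2 * s.h * s.n * s.d,
        "gen_mem": s.m * s.w + 2 * s.m * s.h * (s.k + s.t) * s.d,
    }


# Hypothetical configuration; not taken from the experiments.
setting = Setting(m=32, h=32, d=128, n=131072, k=1024, t=50, r=13, w=1.0)
snap = snapkv_terms(setting)
gem = gemfilter_terms(setting)
for name in snap:
    print(name, gem[name] / snap[name])
```

Under these formulas the prompt-time ratio is exactly $r/m$, the generation memory is identical, and GemFilter pays an extra $k^2$ term during generation because it reprocesses the $k$ selected tokens.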
HuggingFace v4.43 PyTorch implementation. There is no randomness or training in any baseline
method or in our method. For SnapKV/H2O, we use a recent size/observation window of 32, which
is the optimal choice suggested by [LHY+ 24, XJD+ 24]. However, GemFilter does not have an
observation window. We use a maximum pooling kernel size (line 16 of the PyTorch code below) of
5 for SnapKV and our method. For generation, we use standard generation (greedy generation)⁹,
where num_beams=1, do_sample=False.
⁹ https://huggingface.co/docs/transformers/v4.43.2/en/main_classes/text_generation
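The role of the pooling kernel can be illustrated with a small pure-Python sketch. It assumes stride-1 max pooling that preserves sequence length, applied to per-token attention scores before top-k selection. The helper names and the toy scores are hypothetical, and this is not the PyTorch code the text refers to.

```python
from typing import List


def max_pool_1d(scores: List[float], kernel_size: int = 5) -> List[float]:
    """Stride-1 max pooling; windows are truncated at the sequence edges."""
    if kernel_size < 1 or kernel_size % 2 == 0:
        raise ValueError("kernel_size must be a positive odd integer")
    half = kernel_size // 2
    n = len(scores)
    return [max(scores[max(0, i - half): min(n, i + half + 1)]) for i in range(n)]


def top_k_indices(scores: List[float], k: int) -> List[int]:
    """Indices of the k largest scores, returned in original token order."""
    order = sorted(range(len(scores)), key=lambda i: (-scores[i], i))
    return sorted(order[:k])


# Toy attention scores with one sharp peak at position 3.
scores = [0.1, 0.0, 0.0, 0.9, 0.0, 0.0, 0.0, 0.2, 0.0, 0.0]
kept_raw = top_k_indices(scores, 5)
kept_pooled = top_k_indices(max_pool_1d(scores, 5), 5)
print(kept_raw, kept_pooled)
```

Without pooling, top-k picks isolated positions and breaks ties arbitrarily. With a kernel of 5, the peak spreads to its neighbors, so the selection keeps a contiguous span around it.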
[Figure: heatmaps titled "Pressure Testing Phi 3.5 Mini 3.8B Instruct", Fact Retrieval Across Context Lengths ("Needle In A HayStack"). x-axis: Token Limit; y-axis: Depth Percent (0.0 to 100.0); color scale: Score (0.0 to 1.0).]
(a) All KV. Phi 3.5 average score: 0.851.
(b) SnapKV-1024. Phi 3.5 average score: 0.864.
(c) GemFilter-1024 (layer-19). Phi 3.5 average score: 0.910.
Figure 7: Needle in a Haystack performance comparison of different methods using the Phi 3.5
Mini 3.8B Instruct model. The x-axis represents the length of the input tokens, while the y-axis
shows the position depth percentage of the ‘needle’ information (e.g., 0% indicates the beginning,
and 100% indicates the end). A higher score reflects better performance, meaning more effective
retrieval of the ‘needle’ information. GemFilter significantly outperforms both standard attention
(full KV cache) and SnapKV.
[Figure: heatmap titled "Pressure Testing LLaMA 3.1 8B Instruct GemFilter-1024 (layer-14)", Fact Retrieval Across Context Lengths ("Needle In A HayStack"). x-axis: Token Limit; y-axis: Depth Percent (0.0 to 100.0); color scale: Score (0.0 to 1.0).]
(a) GemFilter-1024 (layer-14). LLaMA 3.1 average score: 0.870.
Figure 8: Needle in a Haystack performance comparison of different filter layers with the LLaMA 3.1
8B Instruct model. The x-axis represents the length of the input tokens, while the y-axis shows the
position depth percentage of the ‘needle’ information (e.g., 0% indicates the beginning, and 100%
indicates the end). A higher score reflects better performance, meaning more effective retrieval of
the ‘needle’ information.