How a hierarchical, parameter-free trick lets you pre-train Transformers on million-token contexts — and still get a fully dense model at the end.
To understand why this paper exists, you only need to understand one number: $N^2$.
In a Transformer, every token in your sequence "looks at" every other token via the attention operation. If your sequence is 1,000 tokens long, that's a million pairwise comparisons. If it's 1 million tokens long? One trillion comparisons. This is the famous quadratic cost of scaled dot-product attention (SDPA), and it's the wall that everyone hits when they try to train on long contexts.
Modern frontier models want context windows of 128K, 1M, even longer — for agents reasoning over many steps, for processing entire codebases, for analyzing video. FlashAttention made the constants better, but it didn't change the asymptotics. A quadratic curve is still a quadratic curve.
A whole family of sparse attention methods says: instead of every token looking at every other token, let each query only look at a small, cleverly-chosen subset. Methods like MoBA, Native Sparse Attention, and DeepSeek Sparse Attention all live in this space.
But the authors point out two design choices that make these methods fit poorly for training from scratch:
1. Queries stay at full resolution; only keys/values get pooled. The hierarchy becomes "compressed memory," not a true multi-scale view.
2. Selection logic lives inside the attention kernel. You can't reuse stock FlashAttention, because every method ships its own custom kernel.
And there's a deeper concern: if you train using a sparse approximation, does the resulting model still know how to be a normal dense Transformer when you serve it? Inference-time sparse methods don't have to worry — they're evaluated against a dense backbone they didn't change. Training-time methods have to prove they didn't damage the underlying model.
That's Lighthouse Attention in one paragraph. The "ships" are tokens. The "clusters" are pyramid levels that pool $p^\ell$ tokens together. The "looking interesting" is a parameter-free score (just the $\ell_2$ norm of the projection). And the "focusing" is a top-K selection.
Then at the end of training, you do a brief continuation with normal dense SDPA, like a short rehab. The model "remembers" how to be a full Transformer. The crucial empirical claim: this recovered model is as good as or better than one trained dense the whole way, but you spent only ~60% of the compute getting there.
Here's the entire forward pass. The orange path is the data; the green path is the index selection; the red dashed path is the gradient.
The previous section was the cartoon. This is the engineering. We walk through all four stages, with their math, their motivation, and the design choices that make each one defensible.
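Before Lighthouse changes anything, we need to remember what we're replacing. Standard scaled dot-product attention takes input $X \in \mathbb{R}^{N \times d_{\text{model}}}$, projects it three ways into $Q, K, V \in \mathbb{R}^{N \times d}$, then computes:

$$\text{Attn}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right) V$$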
The problem is the $QK^\top$ matrix: it's $N \times N$. At $N = 100{,}000$ that's 10 billion entries per layer per head. FlashAttention makes this fit in memory by tiling, but it doesn't change the fundamental $\Theta(N^2 d)$ compute cost.
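To make the quadratic term concrete, here is a minimal NumPy reference for single-head causal SDPA. It is a sketch for illustration only: the function name `sdpa` and the toy sizes are assumptions, and unlike FlashAttention it materialises the full $N \times N$ score matrix.

```python
import numpy as np

def sdpa(q, k, v, causal=True):
    """Reference scaled dot-product attention for one head.

    q, k, v: arrays of shape (N, d). Builds the full N x N score matrix,
    which is exactly the quadratic term discussed above.
    """
    n, d = q.shape
    scores = q @ k.T / np.sqrt(d)                        # shape (N, N)
    if causal:
        future = np.triu(np.ones((n, n), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)       # hide future keys
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
N, d = 8, 4
q, k, v = (rng.standard_normal((N, d)) for _ in range(3))
out = sdpa(q, k, v)
print(out.shape)   # (8, 4)
print(N * N)       # 64 score entries for 8 tokens
```

Doubling `N` quadruples the number of score entries, which is the scaling that tiling alone cannot remove.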
Lighthouse replaces the entire equation with a four-stage pipeline. Let's walk through each.
Take the sequence of length $N$. Average-pool it down by a factor of $p$ to get a sequence of length $N/p$. Pool again to get $N/p^2$. Repeat until you have pooled $L-1$ times. You now have an $L$-level pyramid where level 0 is the original sequence and level $L{-}1$ is a heavily compressed summary.
Crucially: do this to all three projections — Q, K, and V — in lockstep. This is the part that makes Lighthouse different from NSA, HISA, and InfLLM-v2, which all leave queries alone and only pool keys/values.
The $i$-th window at level $\ell$ covers a span of $p^\ell$ consecutive tokens:

$$W^{(\ell)}_i = [\,ip^\ell,\; (i+1)p^\ell - 1\,]$$
Then each pyramid entry at level $\ell$, position $i$ is the mean of all the base-level tokens in that window:
Two big payoffs. First, pooled queries and pooled keys now live in the same representation space at every level: a pooled query at level 2 can route to a pooled key at level 2 and the dot product is meaningful. With asymmetric pooling (only K and V), the coarse keys exist but there's nothing coarse to query them.
Second, the dense attention call cost drops from $\mathcal{O}(NSd)$ down to $\mathcal{O}(S^2 d)$, because both sides of the attention are now of length $S$. That's a huge win when $N$ is in the millions.
Total pyramid cost: $\sum_{\ell=0}^{L-1} N/p^\ell \le N \cdot p/(p-1) = \Theta(N)$. Pooling is essentially free.
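The pooling stage can be sketched in a few lines of NumPy. This is an illustrative sketch, not the paper's kernel: `build_pyramid` is a hypothetical helper, and it assumes $N$ is divisible by $p^{L-1}$.

```python
import numpy as np

def build_pyramid(x, p, num_levels):
    """Mean-pool x (shape (N, d)) into a pyramid with num_levels levels.

    Level 0 is x itself. Level l has N / p**l rows, each the mean of
    p**l consecutive base rows (a mean of equal-sized means).
    """
    n, d = x.shape
    assert n % p ** (num_levels - 1) == 0, "N must be divisible by p^(L-1)"
    levels = [x]
    for _ in range(num_levels - 1):
        prev = levels[-1]
        levels.append(prev.reshape(-1, p, d).mean(axis=1))
    return levels

rng = np.random.default_rng(0)
N, d, p, L = 64, 8, 4, 3
q, k, v = (rng.standard_normal((N, d)) for _ in range(3))

# Pool all three projections in lockstep.
q_pyr, k_pyr, v_pyr = (build_pyramid(t, p, L) for t in (q, k, v))

sizes = [len(level) for level in q_pyr]
print(sizes)                        # [64, 16, 4]
print(sum(sizes), N * p / (p - 1))  # total entries vs. the geometric bound
```

The total entry count stays under $N \cdot p/(p-1)$, matching the bound above.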
We now have a pyramid of candidates. Summed over all levels there are at most $N \cdot p/(p-1)$ of them, which stays below $2N$ even at $p = 2$. We need to pick the $K$ most important to form our sub-sequence. The score is the simplest thing imaginable: the $\ell_2$ norm of the projection.
At the base level, the scores are:
At coarser levels, we don't recompute on the pooled projections — instead we max-pool the base-level scores. A coarse window "inherits" the importance of its single most-relevant token:
Then we run a single global top-K over everything: every level, every position, both QK and KQ scores. The output is a set of indices $\mathcal{I}$:
where $\mathcal{P}$ is the set of all pyramid coordinates. There's one extra rule: the coarsest level is always kept in full. This is cheap (it's small) and guarantees that every base position has at least one summary that touches it. The remaining budget is spent on finer levels where it matters more.
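A rough sketch of the scoring and selection logic follows. Several things here are assumptions made for illustration: a single norm score per token (the separate QK and KQ scores are not modelled), the helper names `pyramid_scores` and `select`, and the interpretation of `budget` as the number of finer-level entries kept by one global top-K.

```python
import numpy as np

def pyramid_scores(x, p, num_levels):
    """Base scores are per-token L2 norms; coarser levels max-pool them."""
    scores = [np.linalg.norm(x, axis=-1)]             # level 0, shape (N,)
    for _ in range(num_levels - 1):
        scores.append(scores[-1].reshape(-1, p).max(axis=1))
    return scores

def select(scores, budget):
    """Keep the whole coarsest level plus the top-`budget` finer entries.

    Returns a list of (level, position) pairs.
    """
    coarsest = len(scores) - 1
    keep = [(coarsest, i) for i in range(len(scores[coarsest]))]
    finer = [(s, lvl, i)
             for lvl in range(coarsest)
             for i, s in enumerate(scores[lvl])]
    # Stable sort: on ties (a window inherits its best token's score),
    # the finer-level entry comes first.
    finer.sort(key=lambda t: t[0], reverse=True)
    keep += [(lvl, i) for _, lvl, i in finer[:budget]]
    return keep

rng = np.random.default_rng(0)
x = rng.standard_normal((64, 8))
scores = pyramid_scores(x, p=4, num_levels=3)
selected = select(scores, budget=10)
print(len(selected))   # 4 coarsest entries + 10 selected finer entries
```

Because nothing in this path has parameters, there is nothing to train and nothing for the optimizer to exploit.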
You might wonder: why not learn what's important? Methods like DSA and NSA do exactly that, training a small scoring head end-to-end. Lighthouse takes the harder, more honest path:
Now the magic happens. We use the indices $\mathcal{I}$ to gather the selected triples into a contiguous sequence:
The total length of this sub-sequence is:

$$S = \frac{N}{p^{L-1}} + (L-1)\, p\, k$$
The first term is the coarsest level (always kept in full); the second is the budget spent on the finer levels. The numbers are wild. The paper gives this example:
At $N = 10^6,\ L = 4,\ p = 4,\ k = 4096$:
$S \approx 65{,}000 \ll 1{,}000{,}000$
Attention runs on a sequence 15× shorter than the original. Because the gather is topologically sorted, the causal mask is just a standard $S \times S$ causal mask — no sparse indexing inside the kernel.
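The arithmetic behind that example is easy to check using the formula $S = N/p^{L-1} + (L-1)\,p\,k$. The helper name below is hypothetical.

```python
def subsequence_length(n, num_levels, p, k):
    """S = N / p^(L-1) + (L-1) * p * k."""
    coarsest = n // p ** (num_levels - 1)   # coarsest level, kept in full
    finer = (num_levels - 1) * p * k        # budget spent on finer levels
    return coarsest + finer

S = subsequence_length(n=1_000_000, num_levels=4, p=4, k=4096)
print(S)                          # 64777
print(round(1_000_000 / S, 1))    # 15.4
```

The coarsest level contributes 15,625 entries and the finer levels 49,152, which is where the "about 65,000" and "15× shorter" figures come from.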
Then attention is just standard attention on this sub-sequence:
Every previous selection-based method jams its index list inside the attention kernel and writes a custom sparse matmul. That means:
Lighthouse sidesteps all of this. The gather produces a dense, contiguous tensor. Stock FlashAttention runs on it, unchanged: the attention call's forward and backward passes are the same kernels a standard Transformer uses. Switching to FlashAttention-3, FlashAttention-4, or whatever comes next is free.
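The gather-then-attend idea can be sketched as below. The ordering rule is an assumption made for illustration (sort by the last base token each entry covers, finer levels first on ties); the post does not spell out the exact sort key. `pool`, `causal_sdpa`, and `gather_and_attend` are hypothetical helpers, with plain NumPy attention standing in for FlashAttention.

```python
import numpy as np

def pool(x, p, num_levels):
    """Mean-pool x (N, d) into a list of pyramid levels."""
    levels = [x]
    for _ in range(num_levels - 1):
        levels.append(levels[-1].reshape(-1, p, x.shape[1]).mean(axis=1))
    return levels

def causal_sdpa(q, k, v):
    """Plain dense causal attention; stands in for a stock kernel."""
    s, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    future = np.triu(np.ones((s, s), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ v

def gather_and_attend(q_pyr, k_pyr, v_pyr, selected, p):
    """selected: list of (level, position). Returns (order, output)."""
    # Assumed ordering: by the last base token an entry covers.
    order = sorted(
        selected,
        key=lambda e: ((e[1] + 1) * p ** e[0] - 1, e[0]),
    )
    q = np.stack([q_pyr[lvl][i] for lvl, i in order])
    k = np.stack([k_pyr[lvl][i] for lvl, i in order])
    v = np.stack([v_pyr[lvl][i] for lvl, i in order])
    # Dense, contiguous (S, d) tensors: an ordinary S x S causal mask applies.
    return order, causal_sdpa(q, k, v)

rng = np.random.default_rng(0)
N, d, p, L = 16, 4, 2, 3
q_pyr, k_pyr, v_pyr = (pool(rng.standard_normal((N, d)), p, L) for _ in range(3))
selected = [(2, 0), (2, 1), (2, 2), (2, 3), (0, 5), (1, 3)]
order, out = gather_and_attend(q_pyr, k_pyr, v_pyr, selected, p)
print(order)
print(out.shape)   # (6, 4)
```

The attention function never sees an index list, which is why any dense kernel can be dropped in.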
We computed attention on $S$ tokens. But the layer's output is supposed to have $N$ tokens. So each output entry from the sub-sequence needs to be scattered back to all the base positions it represents.
An entry at pyramid level $\ell$, position $i$ summarized a window $W^{(\ell)}_i = [ip^\ell,\; (i+1)p^\ell - 1]$. But its output isn't written to that range directly — it's written to a shifted range that respects causality:
The shift of $p^\ell - 1$ is the causality trick: it ensures that a base position $j$ never receives a summary that contains its own future. This is non-obvious and matters! Without the shift, position 0 might read a summary that already "saw" position 5, which would leak future information.
Final output is just the sum of all contributions arriving at each position:
The per-position fan-in is bounded by $L$ (the number of pyramid levels) regardless of how big $k$ gets. Because there are no "holes" in the gathered sequence and no missing base positions, the output is fully dense — it's a compressive approximation of full attention, but with no gaps to cause gradient instabilities.
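A small sketch of the scatter-back makes the shift and the fan-in bound easier to see. It assumes the shifted range is the window moved right by $p^\ell - 1$ at both ends and clipped at $N$; that reading follows from the shift described above but is not quoted from the paper, and `scatter_back` is a hypothetical helper.

```python
import numpy as np

def scatter_back(outputs, order, n, p):
    """Sum sub-sequence outputs into an (N, d) array.

    Entry (level, i) covers base tokens [i*p^l, (i+1)*p^l - 1]. Its output
    is written to that range shifted right by p^l - 1, so the first target
    is the window's last token, which has already seen the whole window.
    """
    y = np.zeros((n, outputs.shape[1]))
    fan_in = np.zeros(n, dtype=int)
    for out, (lvl, i) in zip(outputs, order):
        width = p ** lvl
        start = i * width + (width - 1)
        stop = min(start + width, n)      # exclusive end, clipped at N
        y[start:stop] += out
        fan_in[start:stop] += 1
    return y, fan_in

N, d, p, L = 16, 4, 2, 3
# Worst case for fan-in: every pyramid entry is selected.
order = [(lvl, i) for lvl in range(L) for i in range(N // p ** lvl)]
outputs = np.ones((len(order), d))
y, fan_in = scatter_back(outputs, order, N, p)
print(fan_in)   # each position receives at most L contributions
```

Within one level the shifted ranges are disjoint, so each level contributes at most one entry per position, which gives the bound of $L$.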
This is the most surprising part of the design. There's no straight-through estimator, no Gumbel-softmax, no auxiliary scoring loss. Just a hard top-K with zero gradient. The whole thing still works.
Here's how the gradient flows during the backward pass:
Gradients flow back through the entire data path — from the loss, through the scatter, through FlashAttention's existing backward, through the gather, and into the projection matrices. The scorer and the top-K never receive a gradient; they're bypassed entirely.
This is the elegant part. Because $W_Q, W_K, W_V$ get gradient signal only through entries that were selected, the projections learn to produce values that are useful when they're chosen — not to game a scoring function. There's no scorer to game. The scorer is just an $\ell_2$ norm; it doesn't have parameters.
This sidesteps every known optimization pathology of learned selectors: scorer collapse (where the scorer fixates on one set of tokens), scorer-attention misalignment (where the scorer picks tokens the attention doesn't actually use), and auxiliary-loss tuning hell.
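To see why the hard top-K needs no gradient, it helps to write the gather's backward pass by hand. The sketch below is illustrative (manual NumPy rather than an autograd framework, with hypothetical helper names): the indices are treated as constants, and the gradient simply scatter-adds into the selected rows.

```python
import numpy as np

def gather_forward(x, idx):
    """Select rows of x; idx is treated as a constant."""
    return x[idx]

def gather_backward(grad_out, idx, n):
    """Gradient of x[idx] w.r.t. x: scatter-add grad_out into rows idx."""
    grad_x = np.zeros((n, grad_out.shape[1]))
    np.add.at(grad_x, idx, grad_out)
    return grad_x

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 4))

scores = np.linalg.norm(x, axis=-1)        # parameter-free scorer
idx = np.sort(np.argsort(scores)[-3:])     # hard top-3, no gradient path

y = gather_forward(x, idx)
grad_y = np.ones_like(y)                   # pretend dLoss/dy is all ones
grad_x = gather_backward(grad_y, idx, n=len(x))
print(idx)
print(grad_x)   # nonzero only in the selected rows
```

Unselected rows get exactly zero gradient from this layer, and the scorer itself is never differentiated.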
The whole pipeline is a sequence of stages. Each one has a cost. The only potentially super-linear term is the dense FlashAttention call on the gathered sub-sequence, and even that becomes polylogarithmic in $N$ with the right choice of $L$.
| Stage | Primitive | Cost |
|---|---|---|
| Projections Q, K, V | GEMM | $\Theta(N\, d_{\text{model}}\, d)$ |
| Pyramid pool | view + mean | $\Theta(N\, d)$ |
| Scoring (norms, max-pool) | norm + max | $\Theta(N\, d)$ |
| Top-K selection | chunked bitonic | $\Theta(N \log k)$ |
| Gather to sub-sequence | torch.gather | $\Theta(S\, d)$ |
| Dense sub-sequence attention | FlashAttention | $\Theta(S^2\, d)$ |
| Scatter-back | custom atomic | $\Theta(N\, d)$ |
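Recall $S = N/p^{L-1} + (L-1)\, p\, k$. If we set $L = \log_p(N/k)$, then $p^L = N/k$, so $N/p^{L-1} = p \cdot N/p^L = p\,k$, and:

$$S = p\,k + (L-1)\,p\,k = L\,p\,k = p\,k\,\log_p(N/k)$$

So the dense attention term becomes:

$$\Theta(S^2 d) = \Theta\!\left(p^2\, k^2 \log_p^2(N/k)\; d\right)$$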
For bounded $k$, the total per-layer compute reduces to $\Theta(N \cdot d)$. That's the same asymptotic class as linear attention and SSMs, while still using softmax attention.
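A quick numeric check of the choice $L = \log_p(N/k)$, using sizes picked so that $N/k$ is an exact power of $p$. The values and helper names are illustrative assumptions, not settings from the paper.

```python
import math

def levels_for(n, p, k):
    """L = log_p(N / k); assumes N / k is an exact power of p."""
    num_levels = round(math.log(n / k, p))
    assert p ** num_levels * k == n, "N / k must be an exact power of p"
    return num_levels

def subsequence_length(n, num_levels, p, k):
    """S = N / p^(L-1) + (L-1) * p * k."""
    return n // p ** (num_levels - 1) + (num_levels - 1) * p * k

n, p, k = 2 ** 20, 4, 4096
L = levels_for(n, p, k)
S = subsequence_length(n, L, p, k)
print(L, S, L * p * k)   # 4 65536 65536
```

With this $L$, the coarsest level has exactly $p\,k$ entries, so $S$ collapses to $L\,p\,k$ and grows only logarithmically in $N$ for fixed $k$.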
| Method | Per-layer compute | Notes |
|---|---|---|
| Dense softmax attention | $\Theta(N^2 \cdot d)$ | What we're trying to escape |
| Log-Linear Attention | $\Theta(N \log N \cdot d)$ | Better, still super-linear |
| Lighthouse (bounded k) | $\Theta(N \cdot d)$ | Linear, with softmax |
| Linear attention / SSMs | $\Theta(N \cdot d)$ | Same class, no softmax |
This is the part that makes Lighthouse usable. If you trained the entire run with the hierarchical approximation, you'd be left with a model that only knows how to do hierarchical attention — not a normal Transformer. So the recipe has two phases.
The crucial insight: Lighthouse's inner attention call is just FlashAttention. When you flip the switch in phase 2 (disable selection, pass full Q, K, V into the same kernel), the model is immediately running real attention. It's not switching architectures. The weights $W_Q, W_K, W_V$ are the same weights; they've been training all along to produce sensible Q, K, V projections. The only thing that changes is which projected tokens reach the attention kernel: a selected subset before, all of them after.
The "recovery" is the model adjusting from only seeing top-K subsets to seeing everything. Empirically this is fast. The paper tests three resume points (10k+6k, 11k+5k, 12k+4k) and all three end at lower final loss than dense-from-scratch at the same token budget.
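The two-phase recipe amounts to a step-dependent switch. The sketch below is hypothetical glue code, not the paper's training loop: the function name, the mode strings, and 0-based step counting are assumptions, while the 16k total and the three resume points come from the experiments described here.

```python
def attention_mode(step: int, lighthouse_steps: int) -> str:
    """Which attention path to use at a given 0-based optimizer step."""
    return "lighthouse" if step < lighthouse_steps else "sdpa"

TOTAL_STEPS = 16_000
schedules = {"10k+6k": 10_000, "11k+5k": 11_000, "12k+4k": 12_000}

for name, lh_steps in schedules.items():
    dense_steps = sum(
        attention_mode(step, lh_steps) == "sdpa" for step in range(TOTAL_STEPS)
    )
    print(f"{name}: {lh_steps} Lighthouse steps, {dense_steps} dense steps")
```

Only the branch taken inside the attention layer changes at the switch; the projection weights and the optimizer state carry over.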
The empirical story is clean. Compared to dense SDPA from scratch on the exact same architecture, data, context length, and token budget, every Lighthouse configuration was faster and at least as good on final loss.
Numbers from Table 1 in the paper. 530M Llama-3-style model, 16k optimizer steps, ~50.3B tokens, 98K context. "Final Loss" is training loss at step 16,000.
| Config | LH Steps | B200-Hours ↓ | Tok/s/GPU ↑ | Final Loss ↓ |
|---|---|---|---|---|
| SDPA Baseline (dense from scratch) | — | 303.2 | 45.6K | 0.7237 |
| LH→SDPA · 12k+4k · k=6144 | 12k | 214.7 | 74.7K | 0.7102 |
| LH→SDPA · 11k+5k · k=6144 | 11k | 219.6 | 75.4K | 0.7001 |
| LH→SDPA · 10k+6k · k=6144 | 10k | 228.0 | 75.0K | 0.6980 |
| L=3, p=2, k=1536 (best loss) | 10k | 203.9 | 93.9K | 0.6825 |
| L=3, p=4, k=1536, Norm scorer (best speed) | 10k | 179.6 | 126.0K | 0.6946 |
| 1M-token training (CP=8) · k=4096 | 10k | 1300.3 | 48.9K | 0.6721 |
Read the table this way: every Lighthouse row at the 98K context has a lower final loss than dense-from-scratch (0.7237) while training in less wall-clock time. The 1M-token row runs at a different context length, so its B200-hours are not comparable to the baseline's. The best speed-vs-quality balance (L=3, p=4, k=1536, norm scorer) finishes in about 59% of the dense baseline's time with a loss of 0.6946, which is 0.029 below the baseline. The best raw-loss configuration (L=3, p=2, k=1536) saves about a third of the compute and gets 0.6825.
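The compute fractions can be recomputed directly from the B200-hours column. The snippet below just restates the 98K-context rows of the table; the dictionary keys are shorthand labels.

```python
BASELINE_HOURS = 303.2   # SDPA baseline, dense from scratch

lighthouse_hours = {
    "LH->SDPA 12k+4k, k=6144": 214.7,
    "LH->SDPA 11k+5k, k=6144": 219.6,
    "LH->SDPA 10k+6k, k=6144": 228.0,
    "L=3, p=2, k=1536": 203.9,
    "L=3, p=4, k=1536, norm scorer": 179.6,
}

fractions = {name: hours / BASELINE_HOURS for name, hours in lighthouse_hours.items()}
for name, frac in fractions.items():
    print(f"{name}: {frac:.1%} of baseline B200-hours")
```

The fastest configuration lands just under 60% of the baseline's B200-hours, and the best-loss configuration at roughly two thirds.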
Counter-intuitively, decreasing the selection budget $k$ tends to improve final loss within the tested range: $0.6825 \to 0.6880 \to 0.6890 \to 0.6951$ as $k$ goes $1536 \to 2048 \to 3072 \to 4096$. The authors speculate that sparser configurations regularize better against their limited training budget — picking fewer, more confident tokens helps generalization. Whether this reverses at much larger budgets remains open.
The most stringent test: hide a single digit in random alphanumeric filler at various depths up to 96K context, and ask the model to retrieve it. The Lighthouse-trained models with dilated scoring beat the dense baseline (mean retrieval 0.76 at k=2048 dilated vs 0.72 baseline). The norm scorer costs more on retrieval than it does on loss, so the right default depends on your downstream task.
A few things in this paper are worth flagging because they push against the field's current direction:
Many sparse attention papers are evaluated only at inference, against a dense backbone they didn't train. That's a much easier test — the dense model already works, you're just trying not to lose too much. Lighthouse explicitly takes on the harder question: does the model you trained still work as a full Transformer? The answer they get (yes, with a brief resume) is non-obvious and a contribution in its own right.
Once selection lives outside the kernel, the kernel can be anything. FlashAttention-2, -3, -4 — Lighthouse picks up all upstream improvements for free. Custom sparse kernels can't say that. Ring attention for context parallelism? Works unchanged. The implementation surface shrinks dramatically.
A persistent theme in deep learning: the simpler thing often works as well as the learned thing. Lighthouse adds no learnable parameters and no auxiliary losses; the $\ell_2$ norm of the projection is enough. The paper even argues this is the cheaper option, so any benefit they show is a lower bound on what a smarter scorer could achieve.
Most of the field treats discrete operations as a problem to soften (Gumbel-softmax, straight-through estimators). Lighthouse just keeps them hard and routes gradients around them. The projections learn to be useful when selected — a cleaner optimization story than "learn to score well so you get selected."
The paper is upfront about its limits, and they matter: