arXiv:2605.06554 · Nous Research · May 2026 · Long-Context Pre-training

Lighthouse Attention.

How a hierarchical, parameter-free trick lets you pre-train Transformers on million-token contexts — and still get a fully dense model at the end.

Attention is expensive. Painfully so.

To understand why this paper exists, you only need to understand one number: $N^2$.

In a Transformer, every token in your sequence "looks at" every other token via the attention operation. If your sequence is 1,000 tokens long, that's a million pairwise comparisons. If it's 1 million tokens long? One trillion comparisons. This is the famous quadratic cost of scaled dot-product attention (SDPA), and it's the wall that everyone hits when they try to train on long contexts.

Modern frontier models want context windows of 128K, 1M, even longer — for agents reasoning over many steps, for processing entire codebases, for analyzing video. FlashAttention made the constants better, but it didn't change the asymptotics. A quadratic curve is still a quadratic curve.

[Figure: attention latency (ms) vs. sequence length N from 8K to 512K; SDPA · Θ(N²d) vs. Lighthouse · Θ(N·d); 21× faster at 512K.]
Figure · Single-layer attention latency on a B200 GPU. The red curve is regular FlashAttention/SDPA, scaling as $\Theta(N^2 d)$. The cyan curve is Lighthouse, which scales roughly linearly, as $\Theta(N \cdot d)$. At 512K context, Lighthouse is 21× faster on the forward pass and 17.3× faster on forward+backward.

What others have tried

A whole family of sparse attention methods says: instead of every token looking at every other token, let each query only look at a small, cleverly-chosen subset. Methods like MoBA, Native Sparse Attention, and DeepSeek Sparse Attention all live in this space.

But the authors point out two design choices that make these methods a poor fit for training from scratch:

① Asymmetry

Queries stay at full resolution, only keys/values get pooled. The hierarchy becomes "compressed memory," not a true multi-scale view.

② Entanglement

Selection logic lives inside the attention kernel. You can't reuse stock FlashAttention — every method ships its own custom kernel.

And there's a deeper concern: if you train using a sparse approximation, does the resulting model still know how to be a normal dense Transformer when you serve it? Inference-time sparse methods don't have to worry — they're evaluated against a dense backbone they didn't change. Training-time methods have to prove they didn't damage the underlying model.

"A training-time sparse method must survive a harder test: once training is done, will the resulting model still be a competent dense-attention model?"

Think of it like a lighthouse keeper.

An analogy ↓
Imagine you're a lighthouse keeper who has to watch a million ships in the ocean. You can't pay attention to all of them — that's quadratic, that's impossible. So you build a hierarchy. The closest ships you see individually. Mid-distance ships you group into small clusters. Far ships you group into bigger clusters still. Then at each moment, you only focus on the few ships (or clusters) that look most interesting right now — the ones lighting up your beam.

That's Lighthouse Attention in one paragraph. The "ships" are tokens. The "clusters" are pyramid levels that pool $p^\ell$ tokens together. The "looking interesting" is a parameter-free score (just the $\ell_2$ norm of the projection). And the "focusing" is a top-K selection.

TL;DR — the four moves

  1. Pool symmetrically. Build a pyramid over Q, K, and V: all three projections, not just the keys and values that prior hierarchical methods pool.
  2. Score parameter-free. Use the $\ell_2$ norm of each pyramid entry — no learned scorer, no extra parameters.
  3. Select top-K, then gather into a contiguous sub-sequence. Now you have a small, dense sequence ready for stock FlashAttention — no custom sparse kernel needed.
  4. Scatter the outputs back. Each pooled token writes its result to all the original positions it summarized.

Then, at the end of training, you do a brief continuation with normal dense SDPA, like a short rehab. The model "remembers" how to be a full Transformer. The crucial empirical claim: this recovered model is as good as or better than one trained dense the whole way, and in the fastest configuration you spend only ~60% of the dense baseline's B200-hours getting there.

Four stages wrapped around a stock kernel.

Here's the entire forward pass. The orange path is the data; the green path is the index selection; the red dashed path is the gradient.

[Diagram: input H_t (N × d_model) → project with W_Q, W_K, W_V → Q, K, V → ① pyramid pool → ② hierarchical selector (parameter-free, gradient-free: ‖·‖₂ score, bitonic top-K, indices I) → ③ gather Q̃, K̃, Ṽ (S × d, S ≪ N) → stock SDPA (FlashAttention, untouched, Θ(S²d)) → ④ scatter back Õ → O → output O_t (N × d). The gradient ∇L bypasses the selector (non-differentiable top-K).]
Figure 1 (recreated) · The Lighthouse layer wraps but does not modify the attention kernel. Stock FlashAttention runs on a tiny sub-sequence of size $S \ll N$. Selection is decoupled from attention; the top-K step is non-differentiable but that's fine — gradients flow through the gather/scatter back into the projections, which learn to produce values that are useful when selected.

Stage by stage. Equation by equation.

The previous section was the cartoon. This is the engineering. We walk through all four stages, with their math, their motivation, and the design choices that make each one defensible.

The starting point: standard attention

Before Lighthouse changes anything, we need to remember what we're replacing. Standard scaled dot-product attention takes input $X \in \mathbb{R}^{N \times d_{\text{model}}}$, projects it three ways, then computes:

Equation 1 · Standard SDPA
$$Q = X W_Q,\quad K = X W_K,\quad V = X W_V$$ $$\text{Attn}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d}} + M\right) V$$

The problem is the $QK^\top$ matrix: it's $N \times N$. At $N = 100{,}000$ that's 10 billion entries per layer per head. FlashAttention makes this fit in memory by tiling, but it doesn't change the fundamental $\Theta(N^2 d)$ compute cost.
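To have something concrete to compare against, here is a minimal pure-Python sketch of Equation 1 for a single head, assuming $M$ is the usual causal mask and using plain lists instead of tensors. The function name `causal_sdpa` is hypothetical. The loop over every query/key pair is exactly the $\Theta(N^2 d)$ cost that Lighthouse avoids.

```python
import math

def causal_sdpa(Q, K, V):
    """Naive causal scaled dot-product attention (Equation 1), single head.

    Q, K, V are lists of N vectors (lists of floats). The causal mask M is
    applied by letting query i see only keys j <= i.
    """
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for i in range(len(Q)):
        # One dot product per visible key: the N^2 * d term lives here.
        scores = [scale * sum(q * k for q, k in zip(Q[i], K[j])) for j in range(i + 1)]
        m = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        out.append([sum(w * V[j][c] for j, w in enumerate(weights)) / z
                    for c in range(len(V[0]))])
    return out
```

For example, with two identical keys the second query averages both values: `causal_sdpa([[1.0, 0.0]] * 2, [[1.0, 0.0]] * 2, [[2.0, 0.0], [0.0, 4.0]])` returns `[[2.0, 0.0], [1.0, 2.0]]`.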

Lighthouse replaces the entire equation with a four-stage pipeline. Let's walk through each.

1

Pyramid Pool — building multi-resolution views

Take the sequence of length $N$. Average-pool it down by a factor of $p$ to get a sequence of length $N/p$. Pool again to get $N/p^2$. Keep going until you have $L$ levels: an $L$-level pyramid where level 0 is the original sequence and level $L{-}1$ (length $N/p^{L-1}$) is a heavily compressed summary.

Crucially: do this to all three projections — Q, K, and V — in lockstep. This is the part that makes Lighthouse different from NSA, HISA, and InfLLM-v2, which all leave queries alone and only pool keys/values.

[Figure: ℓ=0, 16 tokens → ℓ=1, pooled by p, 8 tokens → ℓ=2, pooled by p², 4 tokens. Each pyramid entry is a coherent (Q^(ℓ), K^(ℓ), V^(ℓ)) triple summarizing p^ℓ tokens.]
Pyramid Pool · A 3-level pyramid with pooling factor $p=2$. Level 0 is the full sequence. Level 1 averages every 2 tokens. Level 2 averages every 4 tokens. The same pyramid is built independently for Q, K, and V.

The math of the pyramid

The $i$-th window at level $\ell$ covers a span of $p^\ell$ consecutive tokens:

Equation 2 · Window definition
$$W^{(\ell)}_i = \big[\, i\,p^\ell,\; (i+1)\,p^\ell - 1\,\big], \quad i = 0, \ldots, \tfrac{N}{p^\ell} - 1$$

Then each pyramid entry at level $\ell$, position $i$ is the mean of all the base-level tokens in that window:

Equation 3 · Symmetric mean-pooling
$$Q^{(\ell)}_i = \text{Pool}_\mu\!\big\{ Q_j \,\big|\, j \in W^{(\ell)}_i \big\}$$ $$K^{(\ell)}_i = \text{Pool}_\mu\!\big\{ K_j \,\big|\, j \in W^{(\ell)}_i \big\}$$ $$V^{(\ell)}_i = \text{Pool}_\mu\!\big\{ V_j \,\big|\, j \in W^{(\ell)}_i \big\}$$
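Equations 2–3 translate almost directly into code. Below is a minimal pure-Python sketch, assuming $N$ is divisible by $p^{L-1}$ and vectors are plain lists; `build_pyramid` and `symmetric_pyramids` are hypothetical names, not the reference repository's API. Pooling level $\ell-1$ by $p$ again gives the same result as averaging each $p^\ell$-token window directly, because all windows have equal size.

```python
def mean_pool(seq, p):
    """Average every p consecutive vectors (len(seq) must be divisible by p)."""
    d = len(seq[0])
    return [[sum(seq[i * p + j][c] for j in range(p)) / p for c in range(d)]
            for i in range(len(seq) // p)]

def build_pyramid(X, p, L):
    """Return [X^(0), ..., X^(L-1)]; level l holds len(X) / p**l window means."""
    assert len(X) % p ** (L - 1) == 0, "N must be divisible by p^(L-1)"
    levels = [X]
    for _ in range(L - 1):
        levels.append(mean_pool(levels[-1], p))
    return levels

def symmetric_pyramids(Q, K, V, p, L):
    # Lighthouse pools all three projections over the same windows W_i^(l).
    return build_pyramid(Q, p, L), build_pyramid(K, p, L), build_pyramid(V, p, L)
```

On 16 one-dimensional tokens with values 0 to 15, `build_pyramid(X, 2, 3)` yields levels of length 16, 8, and 4, matching the figure, and level 2 starts with the mean of tokens 0 to 3, i.e. `[1.5]`.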

Why pool Q symmetrically?

Two big payoffs. First, pooled queries and pooled keys now live in the same representation space at every level: a pooled query at level 2 can route to a pooled key at level 2 and the dot product is meaningful. With asymmetric pooling (only K and V), the coarse keys exist but there's nothing coarse to query them.

Second, the dense attention call cost drops from $\mathcal{O}(NSd)$ down to $\mathcal{O}(S^2 d)$, because both sides of the attention are now of length $S$. That's a huge win when $N$ is in the millions.

Total pyramid size: $\sum_{\ell=0}^{L-1} N/p^\ell \le N \cdot p/(p-1) = \Theta(N)$ entries, so building it costs $\Theta(N d)$. Pooling is essentially free.
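A quick numeric check of that bound, with hypothetical helper names: summing the level sizes for a million-scale $N$ lands just under $N p/(p-1)$, which is $2N$ for $p = 2$ and about $1.33N$ for $p = 4$.

```python
def pyramid_entries(N, p, L):
    """Total entries across levels 0..L-1: the sum over l of N / p**l."""
    return sum(N // p ** l for l in range(L))

def geometric_bound(N, p):
    """Upper bound N * p / (p - 1) from the infinite geometric series."""
    return N * p / (p - 1)

N = 2 ** 20
for p in (2, 4):
    # Entry count for a 4-level pyramid next to its geometric-series bound.
    print(p, pyramid_entries(N, p, L=4), geometric_bound(N, p))
```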

2

Score and Top-K — picking what matters

We now have a pyramid of candidates: at most $N p/(p-1)$ of them across all levels (about $2N$ for $p = 2$). We need to pick the $k$ most important to form our sub-sequence. The score is the simplest thing imaginable: the $\ell_2$ norm of the projection.

At the base level, the scores are:

Equation 4 · Base-level scores
$$s^{QK}_{0,i} = \lVert Q_i \rVert_2, \qquad s^{KQ}_{0,i} = \lVert K_i \rVert_2$$

At coarser levels, we don't recompute on the pooled projections — instead we max-pool the base-level scores. A coarse window "inherits" the importance of its single most-relevant token:

Equation 5 · Hierarchical max-pooled scores
$$s^{QK}_{\ell,i} = \max_{0 \le j < p^\ell} s^{QK}_{0,\, i p^\ell + j}$$ $$s^{KQ}_{\ell,i} = \max_{0 \le j < p^\ell} s^{KQ}_{0,\, i p^\ell + j}$$
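Equations 4–5 as a pure-Python sketch (hypothetical names, plain lists). Max-pooling the previous level by $p$ is equivalent to taking the max over the whole $p^\ell$-token base window, since max is associative.

```python
import math

def l2_norms(X):
    """Base-level scores (Equation 4): the L2 norm of each projected vector."""
    return [math.sqrt(sum(x * x for x in v)) for v in X]

def hierarchical_scores(base_scores, p, L):
    """Max-pool base scores into levels 0..L-1 (Equation 5).

    Each coarse entry inherits the score of its strongest base token.
    """
    levels = [list(base_scores)]
    for _ in range(1, L):
        prev = levels[-1]
        levels.append([max(prev[i * p:(i + 1) * p]) for i in range(len(prev) // p)])
    return levels

# s^QK uses the query norms and s^KQ the key norms; both get the same treatment:
# s_qk = hierarchical_scores(l2_norms(Q), p, L)
# s_kq = hierarchical_scores(l2_norms(K), p, L)
```

For base scores `[1, 5, 2, 3, 0, 0, 7, 1]` with `p=2, L=3`, the coarser levels are `[5, 3, 0, 7]` and `[5, 7]`.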

Then we run a single global top-K over everything: every level, every position, both QK and KQ scores. The output is a set of indices $\mathcal{I}$:

Equation 6 · The top-K selection
$$\mathcal{I} = \text{TopK}\!\Big(\big\{ s^{QK}_{\ell,i},\, s^{KQ}_{\ell,i} : (\ell, i) \in \mathcal{P} \big\},\; k\Big)$$

where $\mathcal{P}$ is the set of all pyramid coordinates. There's one extra rule: the coarsest level is always kept in full. This is cheap (it's small) and guarantees that every base position has at least one summary that touches it. The remaining budget is spent on finer levels where it matters more.
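A minimal sketch of Equation 6 plus the coarsest-level rule, in pure Python. It follows Equation 6 literally: one flat top-$k$ over every finer-level score, both QK and KQ, with a pyramid coordinate selected if either of its scores makes the cut. How the paper's kernel accounts the budget so that it matches the $(L-1)\,p\,k$ term of Equation 8 isn't spelled out here, so treat that detail, and the name `select_indices`, as assumptions.

```python
def select_indices(q_levels, k_levels, k):
    """Flat top-k selection over a score pyramid (Equation 6), coarsest level kept.

    q_levels / k_levels: per-level lists of s^QK and s^KQ scores (level 0 first).
    Returns the selected pyramid coordinates as a set of (level, position).
    """
    coarsest = len(q_levels) - 1
    # Rule: the coarsest level is always kept in full.
    selected = {(coarsest, i) for i in range(len(q_levels[coarsest]))}
    # Every finer-level score, of both kinds, competes in one global ranking.
    candidates = []
    for level in range(coarsest):
        for i, (sq, sk) in enumerate(zip(q_levels[level], k_levels[level])):
            candidates.append((sq, level, i))
            candidates.append((sk, level, i))
    candidates.sort(reverse=True)  # highest score first; ties broken by (level, i)
    for _, level, i in candidates[:k]:
        selected.add((level, i))  # a coordinate picked by both scores counts once
    return selected
```

With `k=0` only the coarsest level survives, which is the guaranteed floor of coverage the rule provides.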

[Figure: per-entry scores ‖Q_i‖₂ and ‖K_i‖₂ with a top-K cutoff; the selected entries are gathered into a contiguous sub-sequence, a dense, causally ordered sequence ready for stock FlashAttention.]
Scoring & Top-K · Every pyramid entry gets a parameter-free $\ell_2$ score. The top-K entries with the highest scores are gathered into a dense sub-sequence. The chunked-bitonic kernel does this in $\Theta(N \log k)$, with each chunk dispatched as an independent CTA, so no thread block ever holds more than $m$ scores (the chunk size).
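The chunking idea can be sketched in a few lines of Python. This is not the paper's kernel: `heapq.nlargest` stands in for the per-chunk bitonic sort, and $m$ is taken to be the chunk size. The structure is what matters: each chunk keeps only its local top-$k$, then one final top-$k$ runs over the survivors. That is exact, because any element of the global top-$k$ is also in the top-$k$ of its own chunk.

```python
import heapq

def chunked_topk(scores, k, m):
    """Two-pass chunked top-k: sorted indices of the k largest scores.

    Pass 1: each chunk of at most m scores keeps its local top-k
    (heapq.nlargest stands in for a per-CTA bitonic sort).
    Pass 2: a final top-k over the surviving candidates.
    """
    survivors = []
    for start in range(0, len(scores), m):
        chunk = range(start, min(start + m, len(scores)))
        survivors.extend(heapq.nlargest(k, chunk, key=lambda i: scores[i]))
    return sorted(heapq.nlargest(k, survivors, key=lambda i: scores[i]))
```

For `scores = [3, 9, 1, 7, 5, 8, 2, 6]`, `chunked_topk(scores, 3, 3)` returns `[1, 3, 5]`, the same indices a full sort would pick.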

Why no learned scorer?

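You might wonder: why not learn what's important? Methods like DSA and NSA do exactly that, training a small scoring head end-to-end. Lighthouse takes the harder, more honest path: no learned scorer, no auxiliary loss, no extra parameters. The $\ell_2$ norm alone decides what gets selected.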

3

Gather + Dense Attention — stock FlashAttention on a tiny sub-sequence

Now the magic happens. We use the indices $\mathcal{I}$ to gather the selected triples into a contiguous sequence:

Equation 7 · The gathered sub-sequence
$$\tilde{Q}_m = Q^{(\ell_m)}_{i_m},\quad \tilde{K}_m = K^{(\ell_m)}_{i_m},\quad \tilde{V}_m = V^{(\ell_m)}_{i_m}, \quad m = 1, \ldots, S$$

The total length of this sub-sequence is:

Equation 8 · The crucial size formula
$$S = \frac{N}{p^{L-1}} + (L-1)\,p\,k$$

The first term is the coarsest level (always kept in full); the second is the budget spent on the finer levels. The numbers are wild. The paper gives this example:

At $N = 10^6,\ L = 4,\ p = 4,\ k = 4096$:
$S \approx 65{,}000 \ll 1{,}000{,}000$

Attention runs on a sequence 15× shorter than the original. Because the gather is topologically sorted (the selected entries are laid out in an order consistent with causality), the causal mask is just a standard $S \times S$ causal mask, with no sparse indexing inside the kernel.
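The paper's example is easy to reproduce from Equation 8 (the helper name is hypothetical):

```python
def subsequence_length(N, L, p, k):
    """Equation 8: the coarsest level kept in full plus (L - 1) * p * k finer entries."""
    return N // p ** (L - 1) + (L - 1) * p * k

S = subsequence_length(N=10**6, L=4, p=4, k=4096)
print(S)            # 64777: 15625 coarsest entries + 49152 finer entries
print(10**6 / S)    # about 15.4, i.e. the "15x shorter" above
```

Because the dense call costs $\Theta(S^2 d)$, that factor of about 15 in length becomes a factor of roughly 240 in pairwise attention scores.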

Then attention is just standard attention on this sub-sequence:

Equation 9 · Stock FlashAttention
$$\tilde{O} = \text{Attn}(\tilde{Q}, \tilde{K}, \tilde{V}; \tilde{M})$$
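Stage 3 and the attention call of Equation 9 can be sketched together. The ordering rule below is an assumption: entries are sorted by the last base position their window covers, so every key at or before a query in the sub-sequence only covers base tokens up to that query's window end, and a plain lower-triangular mask suffices. `causal_attention` is a naive stand-in for the stock FlashAttention call; all names are hypothetical.

```python
import math

def window_end(level, i, p):
    """Last base position covered by window W_i^(level) (Equation 2)."""
    return (i + 1) * p ** level - 1

def gather_sorted(selected, Qp, Kp, Vp, p):
    """Gather selected (level, i) entries into contiguous lists (Equation 7).

    Assumption: entries are ordered by window end, coarser level first on ties,
    one plausible reading of "topologically sorted".
    """
    order = sorted(selected, key=lambda li: (window_end(li[0], li[1], p), -li[0]))
    Qt = [Qp[l][i] for l, i in order]
    Kt = [Kp[l][i] for l, i in order]
    Vt = [Vp[l][i] for l, i in order]
    return order, Qt, Kt, Vt

def causal_attention(Q, K, V):
    """Causal softmax attention on the short S-length sub-sequence (Equation 9).

    Stands in for an unmodified FlashAttention call with a standard S x S causal mask.
    """
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for i in range(len(Q)):
        scores = [scale * sum(q * k for q, k in zip(Q[i], K[j])) for j in range(i + 1)]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wj * V[j][c] for j, wj in enumerate(w)) / z
                    for c in range(len(V[0]))])
    return out
```

The only index-dependent work happens in `gather_sorted`; the attention function never sees a pyramid coordinate, which is the decoupling the paper relies on.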

Why this is the load-bearing design choice

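Every previous selection-based method jams its index list inside the attention kernel and writes a custom sparse matmul. That means every method ships and maintains its own kernel, and none of them can reuse stock FlashAttention or pick up its upstream improvements for free.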

Lighthouse sidesteps all of this. The gather produces a dense, contiguous tensor, and stock FlashAttention runs on it unchanged. The forward and backward kernels are exactly the ones a standard Transformer uses. Switching to FlashAttention-3, FlashAttention-4, or whatever comes next is free.

4

Scatter-Back — distributing outputs to base positions

We computed attention on $S$ tokens. But the layer's output is supposed to have $N$ tokens. So each output entry from the sub-sequence needs to be scattered back to all the base positions it represents.

An entry at pyramid level $\ell$, position $i$ summarized a window $W^{(\ell)}_i = [ip^\ell,\; (i+1)p^\ell - 1]$. But its output isn't written to that range directly — it's written to a shifted range that respects causality:

Equation 10 · The shifted output range
$$R(\ell, i) = \big[\, i p^\ell + p^\ell - 1,\; i p^\ell + 2p^\ell - 2\,\big]$$

The shift of $p^\ell - 1$ is the causality trick: it ensures that a base position $j$ never receives a summary that contains its own future. This is non-obvious and matters! Without the shift, position 0 might read a summary that already "saw" position 5, which would leak future information.

The final output is just the sum of all contributions arriving at each position:

Equation 11 · Per-position output
$$O_j = \sum_{m\,:\, j \in R(\ell_m, i_m)} \tilde{O}_m$$

The per-position fan-in is bounded by $L$ (the number of pyramid levels) regardless of how big $k$ gets. Because there are no "holes" in the gathered sequence and no missing base positions, the output is fully dense — it's a compressive approximation of full attention, but with no gaps to cause gradient instabilities.
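A minimal sketch of Equations 10–11 in pure Python, taking a list of selected `(level, i)` coordinates and their sub-sequence outputs $\tilde{O}_m$. The names are hypothetical, and dropping the part of a range that runs past position $N-1$ is an assumption about edge handling. The real kernel sums with atomics; a plain loop performs the same additions here.

```python
def output_range(level, i, p):
    """Equation 10: inclusive range of base positions that receive entry (level, i)."""
    start = i * p ** level + p ** level - 1
    end = i * p ** level + 2 * p ** level - 2
    return start, end

def scatter_back(order, O_tilde, N, p):
    """Equation 11: O_j is the sum of every O~_m whose shifted range covers j.

    Assumption: positions past N - 1 are dropped. Also returns the per-position
    fan-in, at most one per pyramid level since same-level ranges never overlap.
    """
    d = len(O_tilde[0])
    O = [[0.0] * d for _ in range(N)]
    fan_in = [0] * N
    for (level, i), o in zip(order, O_tilde):
        start, end = output_range(level, i, p)
        for j in range(start, min(end, N - 1) + 1):
            for c in range(d):
                O[j][c] += o[c]
            fan_in[j] += 1
    return O, fan_in
```

Note that each range starts exactly at the last token of its window, $(i+1)p^\ell - 1$: that is the causality guarantee of Equation 10 expressed as an index.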

[Figure: sub-sequence outputs Õ₁ (ℓ=2), Õ₂ (ℓ=1), Õ₃, Õ₄ (ℓ=1), Õ₅ (ℓ=2) scattered into the full dense output O; color marks the pyramid level that contributed. Every base position receives at least one summary: no holes, no gradient instabilities. The shift of p^ℓ−1 enforces causality.]
Scatter-Back · Each $\tilde{O}_m$ is broadcast to its $p^\ell$ base positions, shifted right by $p^\ell - 1$ to preserve causality. Contributions across levels are summed atomically. This guarantees the output is fully dense.

The top-K is not differentiable. And that's fine.

This is the most surprising part of the design. There's no straight-through estimator, no Gumbel-softmax, no auxiliary scoring loss. Just a hard top-K with zero gradient. The whole thing still works.

Here's how the gradient flows during the backward pass:

∇L (loss) → scatter → FlashAttention → gather → pyramid pool → W_Q, W_K, W_V

Gradients flow back through the entire data path — from the loss, through the scatter, through FlashAttention's existing backward, through the gather, and into the projection matrices. The scorer and the top-K never receive a gradient; they're bypassed entirely.

What does the model actually learn?

This is the elegant part. Because $W_Q, W_K, W_V$ get gradient signal only through entries that were selected, the projections learn to produce values that are useful when they're chosen — not to game a scoring function. There's no scorer to game. The scorer is just an $\ell_2$ norm; it doesn't have parameters.
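Why the projections only learn through selected entries is easiest to see in the backward pass of the gather itself. A minimal scalar sketch with hypothetical names: the gradient of a gather is a scatter-add into the source rows, so unselected rows get exactly zero, and the integer indices, which are all the top-K hands over, have no gradient at all.

```python
def gather(src, idx):
    """Forward: pick src[i] for every selected index i."""
    return [src[i] for i in idx]

def gather_backward(grad_out, idx, n):
    """Backward of gather: scatter-add upstream gradients into the n source rows.

    Rows that were never selected receive exactly zero gradient; the indices
    themselves are integers and carry no gradient.
    """
    grad_src = [0.0] * n
    for g, i in zip(grad_out, idx):
        grad_src[i] += g
    return grad_src

# Toy loss: weighted sum of the gathered values, loss = sum(w_m * src[idx_m]).
src = [0.5, -1.0, 2.0, 3.0]
idx = [2, 0]          # pretend the (gradient-free) top-K chose rows 2 and 0
w = [4.0, 5.0]        # d(loss)/d(gathered value) for each gathered entry
loss = sum(wm * x for wm, x in zip(w, gather(src, idx)))
print(gather_backward(w, idx, len(src)))  # [5.0, 0.0, 4.0, 0.0]
```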

This sidesteps the classic optimization pathologies of learned selectors: scorer collapse (where the scorer fixates on one set of tokens), scorer-attention misalignment (where the scorer picks tokens the attention doesn't actually use), and auxiliary-loss tuning hell.

How fast is fast? Let's do the asymptotics.

The whole pipeline is a sequence of stages. Each one has a cost. The only super-linear term is the dense FlashAttention call on the gathered sub-sequence — and even that becomes polylogarithmic with the right choice of $L$.

Per-stage cost breakdown

| Stage | Primitive | Cost |
|---|---|---|
| Projections Q, K, V | GEMM | $\Theta(N\, d_{\text{model}}\, d)$ |
| Pyramid pool | view + mean | $\Theta(N\, d)$ |
| Scoring (norms, max-pool) | norm + max | $\Theta(N\, d)$ |
| Top-K selection | chunked bitonic | $\Theta(N \log k)$ |
| Gather to sub-sequence | torch.gather | $\Theta(S\, d)$ |
| Dense sub-sequence attention | FlashAttention | $\Theta(S^2\, d)$ |
| Scatter-back | custom atomic | $\Theta(N\, d)$ |

The asymptotic trick

Recall $S = N/p^{L-1} + (L-1)\, p\, k$. If we set $L = \log_p(N/k)$, then $p^L = N/k$, and:

The balanced choice
$$S = pk + (L-1)pk = pk \cdot \log_p(N/k) = \Theta(k \log N)$$

So the dense attention term becomes:

Final attention cost
$$S^2 \cdot d = \Theta\!\big(k^2 \log^2 N \cdot d\big) \quad \text{← polylogarithmic in } N$$

For bounded $k$, the total per-layer compute reduces to $\Theta(N \cdot d)$. That's the same asymptotic class as linear attention and SSMs, while still using softmax attention.
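The balanced choice is easy to check numerically. With hypothetical values $p = 4$ and $k = 1024$, each 4× increase in $N$ adds one level and only $pk = 4096$ entries to $S$, exactly the $S = L\,p\,k$ growth derived above. The rounding in `balanced_levels` assumes $N/k$ is a power of $p$.

```python
import math

def balanced_levels(N, k, p):
    """L = log_p(N / k); rounded, which assumes N / k is an exact power of p."""
    return round(math.log(N / k, p))

def subsequence_length(N, L, p, k):
    """Equation 8: coarsest level in full plus (L - 1) * p * k finer entries."""
    return N // p ** (L - 1) + (L - 1) * p * k

p, k = 4, 1024
for N in (2**20, 2**22, 2**24):
    L = balanced_levels(N, k, p)
    S = subsequence_length(N, L, p, k)
    print(N, L, S, S == L * p * k)
# 1048576 5 20480 True
# 4194304 6 24576 True
# 16777216 7 28672 True
```

$N$ grows 16× across the loop while $S$ grows by well under 1.5×, which is the logarithmic behavior the derivation predicts.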

How does that compare?

| Method | Per-layer compute | Notes |
|---|---|---|
| Dense softmax attention | $\Theta(N^2 \cdot d)$ | What we're trying to escape |
| Log-Linear Attention | $\Theta(N \log N \cdot d)$ | Better, still super-linear |
| Lighthouse (bounded k) | $\Theta(N \cdot d)$ | Linear, with softmax |
| Linear attention / SSMs | $\Theta(N \cdot d)$ | Same class, no softmax |

Train hierarchical, end with dense. The recoverability test.

This is the part that makes Lighthouse usable. If you trained the entire run with the hierarchical approximation, you'd be left with a model that only knows how to do hierarchical attention — not a normal Transformer. So the recipe has two phases.

[Figure: Stage 1 · Lighthouse, fast hierarchical training at ~2× per-step throughput, from step 0 to the switch at step ~10k; Stage 2 · dense SDPA recovery with the same kernel as the baseline, through step 16k. A brief training-loss spike (1.12–1.57) at the switch recovers within ~1–1.5k steps.]
The two-stage recipe · Stage 1 trains with Lighthouse at ~2× the throughput. Stage 2 resumes the same checkpoint under dense SDPA (the same optimizer state, the same data loader). The training loss spikes briefly when attention switches mode, then recovers within ~1–1.5k steps and crosses below the dense-from-scratch baseline.

Why does this even work?

The crucial insight: Lighthouse's inner attention call is just FlashAttention. When you flip the switch in stage 2 — disable selection, pass full Q, K, V into the same kernel — the model is immediately running real attention. It's not switching architectures. The weights $W_Q, W_K, W_V$ are the same weights; they've been training all along to produce sensible Q, K, V projections. The only thing that changes is which tokens those projections get applied to.

The "recovery" is the model adjusting from only seeing top-K subsets to seeing everything. Empirically this is fast. The paper tests three resume points (10k+6k, 11k+5k, 12k+4k) and all three end at lower final loss than dense-from-scratch at the same token budget.
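The recipe amounts to a one-line switch in the training loop. A minimal, hypothetical config sketch (the class and mode names are not from the paper's code), instantiated with the three resume points tested above, each on a 16k-step budget:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class TwoStageSchedule:
    """Hypothetical schedule: Lighthouse attention first, then dense SDPA.

    Weights, optimizer state, and data loader are shared across the switch;
    only the attention mode changes.
    """
    total_steps: int
    lighthouse_steps: int

    def mode(self, step: int) -> str:
        return "lighthouse" if step < self.lighthouse_steps else "dense_sdpa"

    def label(self) -> str:
        dense_steps = self.total_steps - self.lighthouse_steps
        return f"{self.lighthouse_steps // 1000}k+{dense_steps // 1000}k"

for lh in (10_000, 11_000, 12_000):
    s = TwoStageSchedule(total_steps=16_000, lighthouse_steps=lh)
    print(s.label(), s.mode(lh - 1), s.mode(lh))
# 10k+6k lighthouse dense_sdpa
# 11k+5k lighthouse dense_sdpa
# 12k+4k lighthouse dense_sdpa
```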

Faster training, same or better final model.

The empirical story is clean. Compared to dense-SDPA from scratch on the exact same architecture, data, and token budget, every Lighthouse configuration was faster and at least as good on final loss.

21× · faster attention at 512K context (forward)
17.3× · faster forward+backward at 512K context
1.7× · end-to-end training speedup
126K tok/s/GPU · vs. 46K for dense SDPA

Headline ablation table

Numbers from Table 1 in the paper. 530M Llama-3-style model, 16k optimizer steps, ~50.3B tokens, 98K context. "Final Loss" is training loss at step 16,000.

| Config | LH Steps | B200-Hours ↓ | Tok/s/GPU ↑ | Final Loss ↓ |
|---|---|---|---|---|
| SDPA Baseline (dense from scratch) | 0 | 303.2 | 45.6K | 0.7237 |
| LH→SDPA · 12k+4k · k=6144 | 12k | 214.7 | 74.7K | 0.7102 |
| LH→SDPA · 11k+5k · k=6144 | 11k | 219.6 | 75.4K | 0.7001 |
| LH→SDPA · 10k+6k · k=6144 | 10k | 228.0 | 75.0K | 0.6980 |
| L=3, p=2, k=1536 (best loss) | 10k | 203.9 | 93.9K | 0.6825 |
| L=3, p=4, k=1536, Norm scorer (best speed) | 10k | 179.6 | 126.0K | 0.6946 |
| 1M-token training (CP=8) · k=4096 | 10k | 1300.3 | 48.9K | 0.6721 |

Read the table this way: every Lighthouse row at the 98K-context setting has a lower final loss than dense-from-scratch (0.7237) while using fewer B200-hours. (The 1M-token run is a different, much longer-context setup and isn't directly comparable on cost.) The best speed-vs-quality balance, L=3, p=4, k=1536 with the norm scorer, finishes in about 60% of the dense baseline's time with a loss of 0.6946 (a meaningful improvement). The best raw-loss configuration, L=3, p=2, k=1536, saves about a third of the compute and reaches 0.6825.
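The percentages in that reading come straight from the table; a few lines of arithmetic reproduce the ~60% and 1.7× figures:

```python
# B200-hours and final training loss from the table (530M model, 98K context).
baseline_hours, baseline_loss = 303.2, 0.7237
runs = {
    "L=3, p=4, k=1536, norm scorer (best speed)": (179.6, 0.6946),
    "L=3, p=2, k=1536 (best loss)": (203.9, 0.6825),
}
for name, (hours, loss) in runs.items():
    frac = hours / baseline_hours
    print(f"{name}: {frac:.1%} of baseline hours, "
          f"{1 / frac:.2f}x faster, loss {loss - baseline_loss:+.4f}")
# L=3, p=4, k=1536, norm scorer (best speed): 59.2% of baseline hours, 1.69x faster, loss -0.0291
# L=3, p=2, k=1536 (best loss): 67.2% of baseline hours, 1.49x faster, loss -0.0412
```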

Surprising finding: smaller k is better

Counter-intuitively, decreasing the selection budget $k$ tends to improve final loss within the tested range: $0.6825 \to 0.6880 \to 0.6890 \to 0.6951$ as $k$ goes $1536 \to 2048 \to 3072 \to 4096$. The authors speculate that sparser configurations act as a regularizer under their limited training budget: picking fewer, more confident tokens may help generalization. Whether the trend reverses at much larger budgets remains open.

Long-context retrieval (needle in a haystack)

The most stringent test: hide a single digit in random alphanumeric filler at various depths up to 96K context, and ask the model to retrieve it. The Lighthouse-trained models with dilated scoring beat the dense baseline (mean retrieval 0.76 at k=2048 dilated vs. 0.72 for the baseline). The norm scorer costs more on retrieval than it does on loss, so the right default depends on your downstream task.

Why this is interesting beyond the speedup.

A few things in this paper are worth flagging because they push against the field's current direction:

① The "training-time correctness" criterion is hard, and rarely tested.

Many sparse attention papers are evaluated only at inference, against a dense backbone they didn't train. That's a much easier test — the dense model already works, you're just trying not to lose too much. Lighthouse explicitly takes on the harder question: does the model you trained still work as a full Transformer? The answer they get (yes, with a brief resume) is non-obvious and a contribution in its own right.

② Decoupling selection from attention is a small idea with big consequences.

Once selection lives outside the kernel, the kernel can be anything. FlashAttention-2, -3, -4 — Lighthouse picks up all upstream improvements for free. Custom sparse kernels can't say that. Ring attention for context parallelism? Works unchanged. The implementation surface shrinks dramatically.

③ Parameter-free wins, again.

A persistent theme in deep learning: the simpler thing often works as well as the learned thing. Lighthouse adds zero learnable parameters and no auxiliary losses; the $\ell_2$ norm of the projection is enough. The paper even argues that, since this is the cheapest and simplest option, any benefit they show is effectively a lower bound on what a smarter scorer could achieve.

④ The non-differentiable top-K is a confidence move.

Most of the field treats discrete operations as a problem to soften (Gumbel-softmax, straight-through estimators). Lighthouse just keeps them hard and routes gradients around them. The projections learn to be useful when selected — a cleaner optimization story than "learn to score well so you get selected."

What this doesn't do.

The paper is upfront about its limits, and they matter:

Lighthouse Attention is, in the end, a recipe for training without paying the full $N^2$ cost — while ending up with a model that doesn't know you cheated.

Code: github.com/ighoshsubho/lighthouse-attention