LoLCATs · Technical Walkthrough arXiv:2410.10254 · ICLR 2025
Method paper · Subquadratic LLMs · Stanford / Together AI / Caltech

Low-Rank Linearizing, formally.

A research-scientist's tour of LoLCATs: attention transfer as a feature-map distillation problem, LoRA as residual correction, and a block-wise schedule that makes 405B-parameter linearization tractable.

The problem, precisely

Causal softmax attention is $\mathcal{O}(L^2 d)$ in compute and $\mathcal{O}(L d)$ in KV-cache memory during decoding, where $L$ is sequence length and $d$ is head dimension. Linear attention reduces these to $\mathcal{O}(L d d')$ compute and $\mathcal{O}(d d')$ recurrent state — independent of $L$ during generation. The question is not whether subquadratic attention exists, but whether one can transplant it into a model that has already been pretrained for trillions of tokens.

Prior linearization methods (SUPRA, Mamba-in-Llama, Mohawk) fall short on three axes simultaneously:

| Failure mode | Symptom | Magnitude |
|---|---|---|
| Quality | 5-shot MMLU drop vs. base model | 23.4–28.2 points |
| Token cost | Linearizing-stage corpus size | 20–100B tokens |
| Scale ceiling | Largest model linearized | ≤ 8B parameters |

LoLCATs reframes the goal: rather than treat linearization as re-pretraining a different architecture, treat it as approximating a fixed nonlinear operator (softmax attention) with a learnable subquadratic one (parameterized linear attention), then absorb the residual approximation error via LoRA.

Preliminaries: softmax → linear attention

For a single head with queries, keys, and values $\bm{q}, \bm{k}, \bm{v} \in \mathbb{R}^{L \times d}$, causal softmax attention computes

Eq. 1 $$ \bm{y}_n = \sum_{i=1}^{n} \frac{\exp(\bm{q}_n^\top \bm{k}_i / \sqrt{d})}{\sum_{j=1}^{n} \exp(\bm{q}_n^\top \bm{k}_j / \sqrt{d})} \, \bm{v}_i $$

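The exponential $\exp(\bm{q}^\top \bm{k}/\sqrt{d})$ is a positive-definite kernel $\mathcal{K}(\bm{q}, \bm{k})$. By Mercer's theorem (more generally, the reproducing-kernel construction), any such kernel admits a feature-map factorization $\mathcal{K}(\bm{q},\bm{k}) = \langle \phi(\bm{q}), \phi(\bm{k}) \rangle$ in some feature space. For softmax that space is infinite-dimensional: the Taylor series $\exp(x) = \sum_{p \ge 0} x^p / p!$ contributes monomials of every degree. The linear-attention trick is to approximate the kernel with a tractable finite map $\phi : \mathbb{R}^d \to \mathbb{R}^{d'}$, so that $\mathcal{K}(\bm{q},\bm{k}) \approx \phi(\bm{q})^\top \phi(\bm{k})$. Substituting and rearranging: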

Eq. 2 $$ \hat{\bm{y}}_n = \frac{\phi(\bm{q}_n)^\top \left( \sum_{i=1}^{n} \phi(\bm{k}_i) \bm{v}_i^\top \right)}{\phi(\bm{q}_n)^\top \sum_{i=1}^{n} \phi(\bm{k}_i)} $$

The crucial observation is associativity: by computing the bracketed $\sum_i \phi(\bm{k}_i) \bm{v}_i^\top \in \mathbb{R}^{d' \times d}$ first, attention becomes a linear recurrence. Define the running state $\bm{s}_n$ and normalizer $\bm{z}_n$:

Eq. 3 $$ \bm{s}_n = \bm{s}_{n-1} + \phi(\bm{k}_n) \bm{v}_n^\top, \qquad \bm{z}_n = \bm{z}_{n-1} + \phi(\bm{k}_n), \qquad \hat{\bm{y}}_n = \frac{\phi(\bm{q}_n)^\top \bm{s}_n}{\phi(\bm{q}_n)^\top \bm{z}_n} $$

This collapses generation memory from $\mathcal{O}(Ld)$ (the KV cache) to $\mathcal{O}(dd')$ (a fixed-size recurrent state). The figure below makes the geometry of the cost reduction concrete.

[Figure 1 plot: cost scaling, softmax vs. linear attention. x-axis: sequence length L (2K, 8K, 32K); y-axis: compute / memory. Curves: softmax O(L² d), linear O(L d d').]
Figure 1 The motivating asymmetry. Softmax attention's $L^2$ term dominates at long context; linear attention removes it. The crossover is what makes long-context inference economically interesting.
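A quick numerical check of the associativity argument: the NumPy sketch below evaluates Eq. 2 in its masked parallel form and Eq. 3 as a token-by-token recurrence, and confirms the two agree. It assumes the fixed ELU+1 map of Katharopoulos et al. (2020) as a stand-in for $\phi$; the function names are illustrative and not taken from the LoLCATs code.

```python
import numpy as np

def elu_plus_one(x):
    # ELU(x) + 1 = x + 1 for x > 0, exp(x) otherwise; strictly positive,
    # so the normalizer phi(q)^T z never vanishes.
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))

def linear_attention_parallel(q, k, v, phi):
    """Eq. 2 for every position at once, using an explicit causal mask (O(L^2) reference)."""
    fq, fk = phi(q), phi(k)                      # (L, d')
    scores = np.tril(fq @ fk.T)                  # scores[n, i] = phi(q_n)^T phi(k_i) for i <= n
    return (scores @ v) / scores.sum(axis=1, keepdims=True)

def linear_attention_recurrent(q, k, v, phi):
    """Eq. 3: a fixed-size state s (d' x d) and normalizer z (d'), updated once per token."""
    fq, fk = phi(q), phi(k)
    s = np.zeros((fk.shape[1], v.shape[1]))
    z = np.zeros(fk.shape[1])
    out = np.empty_like(v)
    for n in range(q.shape[0]):
        s += np.outer(fk[n], v[n])               # s_n = s_{n-1} + phi(k_n) v_n^T
        z += fk[n]                               # z_n = z_{n-1} + phi(k_n)
        out[n] = (fq[n] @ s) / (fq[n] @ z)       # y_hat_n
    return out

rng = np.random.default_rng(0)
L, d = 16, 8
q, k, v = (rng.standard_normal((L, d)) for _ in range(3))
y_par = linear_attention_parallel(q, k, v, elu_plus_one)
y_rec = linear_attention_recurrent(q, k, v, elu_plus_one)
print(np.allclose(y_par, y_rec))  # True
```

The recurrent form is the one used at decode time: it keeps only `s` and `z`, whose size does not depend on $L$, while the parallel form is convenient for training.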

Almost every linear-attention paper since Katharopoulos et al. (2020) has been, at its core, a proposal for $\phi$: ELU+1, Performer random features, T2R (ReLU on a learned projection), Hedgehog (split softmax), and so on. The defining choice of LoLCATs is to learn $\phi$ to match the softmax attention of one particular pretrained model, layer by layer and head by head.

The two-stage framework

LoLCATs decomposes linearization into two surgical objectives, training entirely different parameter sets in each.

[Figure 2 diagram. Stage 1, attention transfer: input x ∈ ℝ^(L×d) feeds frozen softmax attention (teacher, y) and trainable linear attention with φ_q, φ_k (student, ŷ); loss MSE(y, ŷ), backprop to φ only; teacher forcing passes y, not ŷ, to the next layer. Stage 2, LoRA recovery: linear attention with the trained φ; W_q' = W_q + B_q A_q (same for k, v, o); next-token loss −Σ log P(u_(t+1) | u_(1:t)); end-to-end training in which only the low-rank B, A matrices update (rank r = 8, < 0.09% of params).]
Figure 2 The two-stage LoLCATs procedure. Stage 1 trains only the feature maps $\phi_q, \phi_k$ to minimize output MSE against the frozen softmax teacher — a per-layer distillation. Stage 2 swaps softmax for the trained linear attention and trains LoRA adapters on the attention projections end-to-end under next-token loss.

Stage 1 — Attention transfer

For each layer $m$ and head $h$, parameterize the feature maps as a shallow learnable layer:

$$ \phi_q(\bm{q}_n) := f(\bm{q}_n \tilde{\bm{W}}_{(q)} + \tilde{\bm{b}}_{(q)}), \qquad \phi_k(\bm{k}_i) := f(\bm{k}_i \tilde{\bm{W}}_{(k)} + \tilde{\bm{b}}_{(k)}) $$

with $\tilde{\bm{W}} \in \mathbb{R}^{d \times d'}$, $\tilde{\bm{b}} \in \mathbb{R}^{d'}$, and $f$ a nonlinearity. The training objective is the per-head output MSE averaged over $H$ heads and $M$ layers:

Eq. 5 $$ \mathcal{L}_{\text{MSE}} = \frac{1}{MH} \sum_{m=1}^{M} \sum_{h=1}^{H} \mathcal{L}_{\text{MSE}}^{h,m}, \qquad \mathcal{L}_{\text{MSE}}^{h,m} = \frac{1}{Ld} \sum_{n=1}^{L} \big\| \bm{y}_n - \hat{\bm{y}}_n \big\|^2 $$

Two implementation details matter:

Teacher-forcing across layers. Even though $\phi$ is being trained jointly across layers, the inputs to layer $m+1$ are the true softmax outputs $\bm{y}^{(m)}$, not the student outputs $\hat{\bm{y}}^{(m)}$. This decouples per-layer optimization and prevents the student error from compounding through the depth of the network during training. It also means $\bm{y}$ can be computed once per forward pass via FlashAttention and reused.

Output-MSE, not attention-weight MSE. Hedgehog (Zhang et al., 2024) supervises on the $L \times L$ attention matrix directly:

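$$ \mathcal{L}_{\text{Hedgehog}} = \sum_{i} \text{KL}\big( \bm{a}_i \,\big\|\, \hat{\bm{a}}_i \big), \qquad a_{ij} = \frac{\exp(\bm{q}_i^\top \bm{k}_j / \sqrt{d})}{\sum_{j' \le i} \exp(\bm{q}_i^\top \bm{k}_{j'} / \sqrt{d})}, \qquad \hat{a}_{ij} = \frac{\phi(\bm{q}_i)^\top \phi(\bm{k}_j)}{\sum_{j' \le i} \phi(\bm{q}_i)^\top \phi(\bm{k}_{j'})}, \quad j \le i $$

This requires materializing the full $L \times L$ matrix of attention weights, so it is $\mathcal{O}(L^2)$ in memory and rules out FlashAttention, which never materializes those weights. LoLCATs instead supervises on the contracted output $\bm{y}_n \in \mathbb{R}^d$, which needs only $\mathcal{O}(Ld)$ memory per head and lets the teacher run under FlashAttention. This is what makes the method viable at 405B scale.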
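The Stage 1 objective with teacher forcing can be sketched as below. This is a toy under loud assumptions: one head per layer, each layer reduced to a residual attention update with no MLP or normalization, $f$ taken to be softplus, and only the loss computed (a real run would use an autodiff framework to backpropagate into `Wf_q`, `bf_q`, `Wf_k`, `bf_k` alone). All names are hypothetical. The teacher is materialized naively here for readability; the loss itself only touches $L \times d$ outputs.

```python
import numpy as np

def causal_softmax_attention(q, k, v):
    """Teacher: Eq. 1, materialized naively here (a real pipeline would use FlashAttention)."""
    L, d = q.shape
    logits = np.where(np.tril(np.ones((L, L), dtype=bool)), q @ k.T / np.sqrt(d), -np.inf)
    w = np.exp(logits - logits.max(axis=1, keepdims=True))
    return (w / w.sum(axis=1, keepdims=True)) @ v

def feature_map(x, W, b):
    # phi(x) = f(x W + b), with f = softplus as an assumption (keeps features positive)
    return np.logaddexp(0.0, x @ W + b)

def causal_linear_attention(fq, fk, v):
    """Student: Eq. 2 in masked parallel form."""
    scores = np.tril(fq @ fk.T)
    return (scores @ v) / scores.sum(axis=1, keepdims=True)

def stage1_loss(x, layers):
    """Mean per-layer output MSE (Eq. 5 with H = 1), with teacher forcing across layers."""
    per_layer = []
    for p in layers:
        q, k, v = x @ p["Wq"], x @ p["Wk"], x @ p["Wv"]   # frozen pretrained projections
        y = causal_softmax_attention(q, k, v)              # teacher output
        y_hat = causal_linear_attention(
            feature_map(q, p["Wf_q"], p["bf_q"]),
            feature_map(k, p["Wf_k"], p["bf_k"]), v)        # student output
        per_layer.append(np.mean((y - y_hat) ** 2))        # (1 / Ld) * sum_n ||y_n - y_hat_n||^2
        x = x + y                                          # teacher forcing: next layer sees y, not y_hat
    return float(np.mean(per_layer)), per_layer

rng = np.random.default_rng(0)
L, d, d_feat, M = 32, 16, 8, 3
scale = 1.0 / np.sqrt(d)

def init_layer():
    p = {name: rng.standard_normal(shape) * scale
         for name, shape in [("Wq", (d, d)), ("Wk", (d, d)), ("Wv", (d, d)),   # frozen teacher weights
                             ("Wf_q", (d, d_feat)), ("Wf_k", (d, d_feat))]}    # Stage 1 trainable
    p["bf_q"], p["bf_k"] = np.zeros(d_feat), np.zeros(d_feat)                  # Stage 1 trainable
    return p

layers = [init_layer() for _ in range(M)]
x = rng.standard_normal((L, d))
loss, per_layer = stage1_loss(x, layers)
print(f"total {loss:.4f}; per layer {[round(float(val), 4) for val in per_layer]}")
```

Because each layer's input comes from the teacher, perturbing layer 0's feature map changes only layer 0's loss term; the later terms are untouched, which is the decoupling the teacher-forcing paragraph describes.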

Stage 2 — LoRA recovery

After Stage 1, $\hat{\bm{y}} \approx \bm{y}$ in expectation but not exactly. Substituting the linear approximation throughout the network shifts the residual stream; the cross-entropy loss creeps up. Instead of full finetuning, LoLCATs applies LoRA only to the four attention projections $\bm{W}_q, \bm{W}_k, \bm{W}_v, \bm{W}_o$:

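$$ \bm{W}' = \bm{W} + \Delta \bm{W}, \qquad \Delta \bm{W} = \bm{B} \bm{A}, \qquad \bm{B} \in \mathbb{R}^{d_{\text{in}} \times r}, \; \bm{A} \in \mathbb{R}^{r \times d_{\text{out}}} $$

where $\bm{W} \in \mathbb{R}^{d_{\text{in}} \times d_{\text{out}}}$ is the frozen pretrained projection (not necessarily square: under grouped-query attention the $k$ and $v$ projections are narrower) and $r = 8$. The paper finds this rank sufficient to absorb the approximation residual without disturbing the rest of the model.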
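A minimal LoRA projection, assuming NumPy and the row-vector convention $\bm{y} = \bm{x}\bm{W}'$ used throughout this walkthrough. Following the common LoRA recipe (Hu et al., 2021), one factor starts at zero so the adapted layer initially reproduces the frozen one; the class name and the $1/\sqrt{r}$ init scale are illustrative choices, not LoLCATs specifics.

```python
import numpy as np

class LoRALinear:
    """Frozen projection W (d_in x d_out) plus a trainable low-rank update B A.

    Row-vector convention: y = x W' = x (W + B A).
    Standard LoRA implementations also scale B A by alpha / r; that factor is omitted here.
    """

    def __init__(self, W, r, rng):
        d_in, d_out = W.shape
        self.W = W                                             # frozen pretrained weight
        self.B = np.zeros((d_in, r))                           # zero init: W' = W before training
        self.A = rng.standard_normal((r, d_out)) / np.sqrt(r)  # small random init (illustrative)

    def __call__(self, x):
        # Apply the low-rank path as (x B) A so the d_in x d_out product B A is never formed.
        return x @ self.W + (x @ self.B) @ self.A

    def num_trainable(self):
        return self.B.size + self.A.size                       # r * (d_in + d_out)

rng = np.random.default_rng(0)
d_in, d_out, r = 64, 64, 8
W_q = rng.standard_normal((d_in, d_out))
proj = LoRALinear(W_q, r, rng)
x = rng.standard_normal((5, d_in))
print(np.allclose(proj(x), x @ W_q))      # True: the adapter starts as an exact no-op
print(proj.num_trainable(), W_q.size)     # 1024 4096
```

At the size of a 4096 × 4096 projection, the same formula gives 8 × (4096 + 4096) = 65,536 trainable entries against about 16.8M frozen ones.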

Parameter accounting · Llama 3 8B

Stage 1: 32 layers × 32 heads × 2 feature maps × (128 × 64) ≈ 16.8M params — about 0.2% of the model. Trainable on a single 40 GB GPU.

Stage 2: LoRA at $r=8$ on four projections, 32 layers: ≈ 6.8M params (the grouped-query $k$ and $v$ projections are 4096 × 1024, not square), under 0.09% of the model.

Tokens: 40M total. Prior methods used 20–100B. That is a 500–2500× reduction.
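The accounting above can be reproduced in a few lines. The sketch assumes Llama 3 8B's published shape (32 layers, hidden size 4096, 32 query heads and 8 key/value heads of dimension 128, about 8.03B parameters), a 128 × 64 feature-map projection with biases ignored, and a LoRA cost of $r(d_{\text{in}} + d_{\text{out}})$ per adapted matrix. Under these assumptions the grouped-query $k$ and $v$ projections (4096 → 1024) put the Stage 2 count at about 6.8M; the paper's own figure may differ depending on exactly what it counts.

```python
# Assumed Llama 3 8B shape (published config; an assumption for this sketch).
n_layers, n_heads, n_kv_heads, head_dim, hidden = 32, 32, 8, 128, 4096
total_params = 8.03e9
feat_dim = 64        # feature-map projection 128 -> 64; biases not counted
r = 8                # LoRA rank

# Stage 1: one phi_q and one phi_k projection per head per layer.
stage1 = n_layers * n_heads * 2 * head_dim * feat_dim

# Stage 2: LoRA adds r * (d_in + d_out) parameters per adapted matrix.
proj_shapes = {
    "q": (hidden, n_heads * head_dim),       # 4096 -> 4096
    "k": (hidden, n_kv_heads * head_dim),    # 4096 -> 1024 (grouped-query attention)
    "v": (hidden, n_kv_heads * head_dim),    # 4096 -> 1024
    "o": (n_heads * head_dim, hidden),       # 4096 -> 4096
}
stage2 = n_layers * sum(r * (d_in + d_out) for d_in, d_out in proj_shapes.values())

print(f"Stage 1: {stage1:,} params ({stage1 / total_params:.2%})")
print(f"Stage 2: {stage2:,} params ({stage2 / total_params:.3%})")
print(f"Token reduction: {20e9 / 40e6:.0f}x to {100e9 / 40e6:.0f}x")
```

Treating all four projections as 4096 × 4096 would instead give about 8.4M, so the grouped-query shapes matter for the headline percentage.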

Architectural extension: sliding window + linear

The naive framework above works, but bare linear attention has a known weakness: it cannot represent the sharp, local, peaky distributions that softmax produces near the diagonal (recent tokens). The MSE residual is dominated by these failures, so LoLCATs adds a hybrid that respects this structure.

For each query position $n$, split the preceding context into a recent window of the last $W$ tokens and a long-range tail. Run exact softmax attention over the window and linear attention over everything before it:

Eq. 6 $$ \hat{\bm{y}}_n = \underbrace{\sum_{i = n-W+1}^{n} \frac{\exp(\bm{q}_n^\top \bm{k}_i / \sqrt{d})}{Z_n^{\text{win}} + Z_n^{\text{lin}}} \bm{v}_i}_{\text{sliding window, exact}} \; + \; \underbrace{\frac{\phi_q(\bm{q}_n)^\top \sum_{i=1}^{n-W} \phi_k(\bm{k}_i) \bm{v}_i^\top}{Z_n^{\text{win}} + Z_n^{\text{lin}}}}_{\text{linear long-range}} $$

where $Z_n^{\text{win}} = \sum_{i=n-W+1}^{n} \exp(\bm{q}_n^\top \bm{k}_i / \sqrt{d})$ and $Z_n^{\text{lin}} = \phi_q(\bm{q}_n)^\top \sum_{i=1}^{n-W} \phi_k(\bm{k}_i)$ are the two normalizers, summed into a shared denominator (window sums are truncated at $i = 1$, and the tail is empty while $n \le W$). The local window keeps compute at $\mathcal{O}(LWd)$, still linear in $L$ for fixed $W$, and crucially gives the model an exact handle on local context. Linear attention handles the diffuse long-range signal, where softmax weights are flatter and therefore easier to approximate.

[Figure 3 diagram: hybrid attention mask, window + linear. Axes: query index n (1 to L) vs. key index i (1 to L). Sliding window: exact softmax over the last W tokens; tail: linear attention φ(q)·φ(k). Both terms share a single denominator (Eq. 6), giving a proper probability distribution across the full context. Cost: O(L·W·d) + O(L·d·d'), still linear in L.]
Figure 3 The hybrid mask. Sharp local structure (red band) is preserved exactly via softmax over the last $W$ tokens; the diffuse long-range tail (teal hatching) is handled by linear attention. Sharing a denominator keeps the output a normalized convex combination of values.
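Eq. 6 can be sketched for a single head as follows. Assumptions: the window holds the last $W$ positions including $n$, the tail state is updated recurrently as keys leave the window (so the linear term costs $\mathcal{O}(dd')$ per step), and softplus stands in for the learned $\phi_q, \phi_k$. Function names are hypothetical.

```python
import numpy as np

def causal_softmax_attention(q, k, v):
    """Reference: exact causal softmax attention (Eq. 1)."""
    L, d = q.shape
    logits = np.where(np.tril(np.ones((L, L), dtype=bool)), q @ k.T / np.sqrt(d), -np.inf)
    w = np.exp(logits - logits.max(axis=1, keepdims=True))
    return (w / w.sum(axis=1, keepdims=True)) @ v

def hybrid_attention(q, k, v, phi_q, phi_k, W):
    """Eq. 6: exact softmax over the last W positions, linear attention over the older tail."""
    L, d = q.shape
    fq, fk = phi_q(q), phi_k(k)
    s = np.zeros((fk.shape[1], v.shape[1]))   # tail state: sum of phi_k(k_i) v_i^T
    z = np.zeros(fk.shape[1])                 # tail normalizer: sum of phi_k(k_i)
    out = np.empty_like(v)
    for n in range(L):
        lo = max(0, n - W + 1)                # window = positions lo..n (0-indexed)
        if n - W >= 0:                        # position n - W just left the window
            s += np.outer(fk[n - W], v[n - W])
            z += fk[n - W]
        # No max-subtraction, so both terms share one denominator exactly as in Eq. 6;
        # a production kernel would need a numerically stable rescaling of both terms.
        a_win = np.exp(q[n] @ k[lo:n + 1].T / np.sqrt(d))
        numer = a_win @ v[lo:n + 1] + fq[n] @ s
        denom = a_win.sum() + fq[n] @ z       # Z_win + Z_lin
        out[n] = numer / denom
    return out

def softplus(x):
    # Placeholder positive feature map standing in for the learned phi.
    return np.logaddexp(0.0, x)

rng = np.random.default_rng(0)
L, d, W = 24, 8, 4
q, k, v = (rng.standard_normal((L, d)) for _ in range(3))
y_hat = hybrid_attention(q, k, v, softplus, softplus, W)
print(y_hat.shape)  # (24, 8)
```

Two sanity checks follow from the shared denominator: with every $\bm{v}_i = \mathbf{1}$ each output is exactly $\mathbf{1}$ (a convex combination of values), and with $W \ge L$ the tail is empty and the result matches exact softmax attention.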

Feature-map choices

The paper compares three parameterizations of $\phi$, with concrete implications for representational capacity:

| Map | Form | Output dim | Property |
|---|---|---|---|
| T2R | $\text{ReLU}(\bm{q} \tilde{\bm{W}} + \tilde{\bm{b}})$ | $d' = d$ | Cheap; loses sign info |
| Hedgehog | $[\text{SM}_d(\bm{q}\tilde{\bm{W}}) \,\oplus\, \text{SM}_d(-\bm{q}\tilde{\bm{W}})]$ | $d' = 2d$ | Preserves both signs; softmax-like |
| LoLCATs default | Hedgehog inside window+linear hybrid | $d' = 2d$ | Lowest output MSE |

Empirically, the Hedgehog parameterization combined with the window+linear hybrid gives the lowest layer-wise MSE and the highest downstream MMLU. The intuition: $\phi$ needs to express both attractive and repulsive interactions in the feature space, and concatenating a $-\bm{q}\tilde{\bm{W}}$ branch gives it room to do so without explicitly negative features.
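The two feature maps in the table can be written in a few lines of NumPy. The projection $\tilde{\bm{W}}$ is random here rather than learned, and its output width equals $d$ to match the table's output dimensions; both choices are assumptions for illustration.

```python
import numpy as np

def softmax(z, axis=-1):
    e = np.exp(z - z.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def t2r_map(x, W, b):
    """T2R-style map: ReLU(x W + b). Non-negative, but negative pre-activations are zeroed."""
    return np.maximum(x @ W + b, 0.0)

def hedgehog_map(x, W):
    """Hedgehog-style map: [softmax(x W) ⊕ softmax(-x W)], softmax over the feature axis."""
    z = x @ W
    return np.concatenate([softmax(z), softmax(-z)], axis=-1)

rng = np.random.default_rng(0)
d = 8
W = rng.standard_normal((d, d))       # would be learned per head; random here
x = rng.standard_normal((4, d))

f_t2r = t2r_map(x, W, np.zeros(d))    # shape (4, d): d' = d
f_hh = hedgehog_map(x, W)             # shape (4, 2d): d' = 2d

# With T2R, an all-negative pre-activation maps to the zero vector,
# so that query or key contributes nothing to any kernel value.
print(t2r_map(-np.abs(x), np.eye(d), np.zeros(d)).max())      # 0.0
# Hedgehog keeps sign information: negating the input swaps the two halves.
print(np.allclose(hedgehog_map(-x, W)[:, :d], f_hh[:, d:]))  # True
```

Every Hedgehog feature is strictly positive and each half sums to one, so the kernel $\phi(\bm{q})^\top \phi(\bm{k})$ stays positive without any explicitly negative features.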

The scaling trick: block-wise attention transfer

Jointly training all $M$ layers of $\phi$ from a single forward pass works for 7B–8B models. At 70B and 405B it breaks down. The paper reports that on Llama 3.1 405B (126 attention layers), late-layer MSE can be 200× larger than early-layer MSE under joint training — feature maps deep in the network simply do not converge to a good local solution when supervised under joint gradients.

The fix is a block-wise schedule. Partition the $M$ layers into blocks of $k$ consecutive layers. Within each block, optimize the MSE jointly; across blocks, treat each block's inputs as the cached softmax outputs of the previous block.

[Figure 4 diagram: block-wise attention transfer (k = 9 shown). Block 1 (layers 1 to k): joint MSE within the block, backprop through k layers, output cached as y^(k). Block 2 (layers k+1 to 2k): inputs are the cached y^(k), optimized independently and trainable in parallel, and so on through block M/k. k controls a memory vs. parallelism trade-off: large k means fewer cached activations but less parallelism; small k means more memory for inter-block caches but more parallelism.]
Figure 4 Block-wise training. Each block of $k$ contiguous layers is optimized independently against its cached softmax-teacher inputs. The choice of $k$ trades off the memory of storing inter-block activations against the wall-clock benefit of training blocks in parallel.

The cost model the authors use to pick $k$ balances two terms:

$$ \text{Cost}(k) \;\propto\; \underbrace{\frac{M}{k} \cdot T_{\text{block}}(k)}_{\text{sequential training}} \;+\; \underbrace{\frac{M}{k} \cdot S_{\text{cache}}}_{\text{cached activations}} $$

where $T_{\text{block}}(k)$ is the wall-clock cost of training one $k$-layer block and $S_{\text{cache}}$ is the memory cost of storing one block's input activations. For Llama 3.1 405B the chosen $k$ enabled attention transfer in 5 hours on 14 H100s — a regime that joint training simply cannot reach without absorbing prohibitive memory.
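The block-wise schedule reduces to two steps: partition the layers, then run the frozen teacher once and cache the hidden state entering each block. The sketch below assumes NumPy, uses hypothetical residual-tanh toy layers in place of real decoder layers, and omits the per-block MSE training itself. With $M = 126$ and $k = 9$ it yields $126 / 9 = 14$ blocks.

```python
import numpy as np

def partition_layers(M, k):
    """Split layer indices 0..M-1 into contiguous blocks of at most k layers."""
    return [list(range(start, min(start + k, M))) for start in range(0, M, k)]

def cache_block_inputs(x, teacher_layers, blocks):
    """Run the frozen softmax teacher once and keep the hidden state entering each block."""
    cached = []
    for block in blocks:
        cached.append(x.copy())          # input to this block, produced by the teacher
        for m in block:
            x = teacher_layers[m](x)
    return cached

def make_toy_layer(rng, d):
    # Hypothetical stand-in for a frozen decoder layer (residual + tanh mixing).
    Wm = rng.standard_normal((d, d)) / np.sqrt(d)
    return lambda h: h + np.tanh(h @ Wm)

rng = np.random.default_rng(0)
M, k, L, d = 126, 9, 16, 32
blocks = partition_layers(M, k)
teacher = [make_toy_layer(rng, d) for _ in range(M)]
x0 = rng.standard_normal((L, d))
cached = cache_block_inputs(x0, teacher, blocks)
print(len(blocks), len(cached))          # 14 14

# Each (blocks[b], cached[b]) pair can now be handed to its own job that
# minimizes the within-block output MSE, independently of the other blocks.
```

In this sketch the cached-activation term of the cost model scales with the number of blocks, $\lceil M/k \rceil$, times the size of one boundary activation ($L \times$ hidden size per sequence); a larger $k$ shrinks that count at the cost of fewer independent jobs.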

Empirical results

The headline measurements, in order of how surprising they are:

- +17.2 points: 5-shot MMLU gain over prior linearization on Llama 3 8B
- 73.1: zero-shot LM-Eval average for linearized Llama 3 8B, vs. 74.2 for the original
- 77.8%: MMLU gap closed for Llama 3.1 70B vs. its softmax teacher
- 78.1%: MMLU gap closed for Llama 3.1 405B

Throughput

On a single 80 GB H100 generating 4096-token samples, linearized Llama 3 8B achieves 3× the throughput and supports 64× larger batch sizes than the FlashAttention-2 softmax baseline. This is consistent with the asymptotic picture in Figure 1: a fixed-size recurrent state replaces a KV cache that grows with sequence length.

Compute footprint

For context: prior work (Mamba-in-Llama) used 5 days on 8×A100s for an 8B model. LoLCATs handles a 50× larger model in half the GPU-hours.

The paper's deepest contribution is not the architecture but the framing: linearization as a local approximation problem, with global behavior recovered by a small residual correction.

Where it sits in the literature

LoLCATs is the cleanest current instantiation of a broader pattern in post-training architecture surgery: swap a layer family, then patch the manifold. Compare with:

LoLCATs is the first to demonstrate that most of the linearization gap can be closed with only low-rank updates after a parameter-frugal attention-matching phase, and the first to actually run the procedure at frontier model scale.

Limitations worth flagging

The paper does not claim parity with softmax on the hardest reasoning benchmarks. Gaps remain on tasks heavy in long-range retrieval, where the constant-size recurrent state of pure linear attention is genuinely lossy. The sliding-window hybrid mitigates but does not eliminate this. The appendix discusses hybrid layer schedules (interleaving softmax and linear layers) as a route to further closing this gap, at proportional inference cost.