A research-scientist's tour of LoLCATs: attention transfer as a feature-map distillation problem, LoRA as residual correction, and a block-wise schedule that makes 405B-parameter linearization tractable.
Causal softmax attention is $\mathcal{O}(L^2 d)$ in compute and $\mathcal{O}(L d)$ in KV-cache memory during decoding, where $L$ is sequence length and $d$ is head dimension. Linear attention reduces these to $\mathcal{O}(L d d')$ compute and $\mathcal{O}(d d')$ recurrent state — independent of $L$ during generation. The question is not whether subquadratic attention exists, but whether one can transplant it into a model that has already been pretrained for trillions of tokens.
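To put rough numbers on the decoding-memory claim, the sketch below counts the floats one head must hold during generation. A softmax KV cache grows with every token. The linear-attention recurrent state (a $d' \times d$ matrix plus a $d'$-vector normalizer) stays fixed. The sizes $d = 128$ and $d' = 64$ are illustrative assumptions, and the count ignores dtype, batching, and grouped-query sharing.

```python
def kv_cache_floats(L: int, d: int) -> int:
    """Floats in one head's softmax KV cache after L tokens: L keys plus L values."""
    return 2 * L * d

def linear_state_floats(d: int, d_prime: int) -> int:
    """Floats in one head's linear-attention state: a d' x d matrix plus a d'-vector."""
    return d_prime * d + d_prime

# Illustrative sizes: head dim d = 128, feature dim d' = 64.
for L in (1_024, 4_096, 32_768):
    print(f"L={L:>6}: KV cache {kv_cache_floats(L, 128):>9,} floats, "
          f"linear state {linear_state_floats(128, 64):,} floats")
```

The KV cache doubles every time $L$ doubles. The linear state stays at 8,256 floats per head no matter how long generation runs.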
Prior linearization methods (SUPRA, Mamba-in-Llama, Mohawk) fall short on three axes simultaneously:
| Failure mode | Symptom | Magnitude |
|---|---|---|
| Quality | 5-shot MMLU drop vs. base model | 23.4 – 28.2 points |
| Token cost | Linearizing-stage corpus size | 20 – 100B tokens |
| Scale ceiling | Largest model linearized | ≤ 8B parameters |
LoLCATs reframes the goal: rather than treat linearization as re-pretraining a different architecture, treat it as approximating a fixed nonlinear operator (softmax attention) with a learnable subquadratic one (parameterized linear attention), then absorb the residual approximation error via LoRA.
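For a single head with queries, keys, and values $\bm{q}, \bm{k}, \bm{v} \in \mathbb{R}^{L \times d}$, causal softmax attention computes

$$\bm{y}_n = \sum_{i=1}^{n} \frac{\exp(\bm{q}_n^\top \bm{k}_i / \sqrt{d})}{\sum_{j=1}^{n} \exp(\bm{q}_n^\top \bm{k}_j / \sqrt{d})}\, \bm{v}_i$$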
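The exponential $\exp(\bm{q}^\top \bm{k}/\sqrt{d})$ is a positive-definite kernel $\mathcal{K}(\bm{q}, \bm{k})$. By Mercer's theorem, such a kernel admits a feature-map factorization $\mathcal{K}(\bm{q},\bm{k}) = \phi(\bm{q})^\top \phi(\bm{k})$, but for softmax the exact feature space is infinite-dimensional. The linear-attention trick is to approximate it with a tractable $\phi : \mathbb{R}^d \to \mathbb{R}^{d'}$. Substituting and rearranging:

$$\hat{\bm{y}}_n = \sum_{i=1}^{n} \frac{\phi(\bm{q}_n)^\top \phi(\bm{k}_i)}{\sum_{j=1}^{n} \phi(\bm{q}_n)^\top \phi(\bm{k}_j)}\, \bm{v}_i = \frac{\phi(\bm{q}_n)^\top \left[\sum_{i=1}^{n} \phi(\bm{k}_i)\, \bm{v}_i^\top\right]}{\phi(\bm{q}_n)^\top \sum_{j=1}^{n} \phi(\bm{k}_j)}$$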
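The crucial observation is associativity: by computing the bracketed $\sum_i \phi(\bm{k}_i) \bm{v}_i^\top \in \mathbb{R}^{d' \times d}$ first, attention becomes a linear recurrence. Define the running state $\bm{s}_n$ and normalizer $\bm{z}_n$:

$$\bm{s}_n = \bm{s}_{n-1} + \phi(\bm{k}_n)\, \bm{v}_n^\top, \qquad \bm{z}_n = \bm{z}_{n-1} + \phi(\bm{k}_n), \qquad \hat{\bm{y}}_n = \frac{\phi(\bm{q}_n)^\top \bm{s}_n}{\phi(\bm{q}_n)^\top \bm{z}_n}$$

with $\bm{s}_0 = \bm{0} \in \mathbb{R}^{d' \times d}$ and $\bm{z}_0 = \bm{0} \in \mathbb{R}^{d'}$.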
This collapses generation memory from $\mathcal{O}(Ld)$ (the KV cache) to $\mathcal{O}(dd')$ (a fixed-size recurrent state). The figure below makes the geometry of the cost reduction concrete.
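The associativity argument can be checked numerically. This minimal NumPy sketch computes causal linear attention twice. The first pass uses the quadratic "parallel" form, which materializes the masked $L \times L$ kernel matrix. The second pass uses the recurrence over $\bm{s}_n$ and $\bm{z}_n$. The feature map is ELU+1 (Katharopoulos et al.), chosen only as a simple positive placeholder. It is not the learned LoLCATs map.

```python
import numpy as np

def phi(x: np.ndarray) -> np.ndarray:
    """Placeholder positive feature map: ELU(x) + 1."""
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention_parallel(q, k, v):
    """O(L^2) reference: causal kernel weights phi(q_n)^T phi(k_i), normalized per row."""
    fq, fk = phi(q), phi(k)                      # (L, d')
    scores = np.tril(fq @ fk.T)                  # (L, L), zero above the diagonal
    return (scores @ v) / scores.sum(axis=1, keepdims=True)

def linear_attention_recurrent(q, k, v):
    """O(L d d') recurrence with fixed-size state s (d' x d) and normalizer z (d')."""
    L, d = v.shape
    fq, fk = phi(q), phi(k)
    s = np.zeros((fk.shape[1], d))
    z = np.zeros(fk.shape[1])
    out = np.empty_like(v)
    for n in range(L):
        s += np.outer(fk[n], v[n])               # s_n = s_{n-1} + phi(k_n) v_n^T
        z += fk[n]                               # z_n = z_{n-1} + phi(k_n)
        out[n] = (fq[n] @ s) / (fq[n] @ z)
    return out

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((16, 8)) for _ in range(3))
print(np.allclose(linear_attention_parallel(q, k, v), linear_attention_recurrent(q, k, v)))
```

Both forms produce the same outputs. Only the recurrent one keeps memory constant as generation proceeds.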
Almost every linear-attention paper since Katharopoulos et al. (2020) has been a proposal for $\phi$: ELU+1, Performer random features, T2R (ReLU on a learned projection), Hedgehog (split softmax), and so on. The defining choice of LoLCATs is to learn $\phi$ to match the softmax kernel of one particular pretrained model, layer by layer and head by head.
LoLCATs decomposes linearization into two surgical objectives, training entirely different parameter sets in each.
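For each layer $m$ and head $h$, parameterize the query and key feature maps as shallow learnable layers:

$$\phi_q(\bm{q}) = f(\bm{q}\tilde{\bm{W}}_q + \tilde{\bm{b}}_q), \qquad \phi_k(\bm{k}) = f(\bm{k}\tilde{\bm{W}}_k + \tilde{\bm{b}}_k)$$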
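with $\tilde{\bm{W}} \in \mathbb{R}^{d \times d'}$, $\tilde{\bm{b}} \in \mathbb{R}^{d'}$, and $f$ a nonlinearity. The training objective is the per-head output MSE averaged over $H$ heads and $M$ layers:

$$\ell_{\text{MSE}} = \frac{1}{MH} \sum_{m=1}^{M} \sum_{h=1}^{H} \frac{1}{Ld} \sum_{n=1}^{L} \left\lVert \bm{y}_n^{(m,h)} - \hat{\bm{y}}_n^{(m,h)} \right\rVert_2^2$$

where $\hat{\bm{y}}_n^{(m,h)}$ is the linear-attention output of layer $m$, head $h$.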
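The Stage 1 objective for a single head can be sketched as follows. The sketch assumes NumPy and a T2R-style choice $f = \text{ReLU}$. Only the loss is shown. In practice $\tilde{\bm{W}}, \tilde{\bm{b}}$ would be optimized with an autodiff framework. The helper names (`head_mse`, `stage1_loss`) are hypothetical. The `eps` term guards against rows where ReLU zeroes every feature.

```python
import numpy as np

def softmax_attention(q, k, v):
    """Teacher: causal softmax attention for one head, shapes (L, d)."""
    L, d = q.shape
    logits = q @ k.T / np.sqrt(d)
    causal = np.tril(np.ones((L, L), dtype=bool))
    logits = np.where(causal, logits, -np.inf)
    logits = logits - logits.max(axis=1, keepdims=True)   # stabilize exp
    a = np.exp(logits)
    return (a @ v) / a.sum(axis=1, keepdims=True)

def learned_phi(x, W, b):
    """Shallow feature map f(x W + b) with f = ReLU; W: (d, d'), b: (d',)."""
    return np.maximum(x @ W + b, 0.0)

def student_attention(q, k, v, Wq, bq, Wk, bk, eps=1e-6):
    """Student: causal linear attention with separate query/key feature maps."""
    fq, fk = learned_phi(q, Wq, bq), learned_phi(k, Wk, bk)
    scores = np.tril(fq @ fk.T)                            # phi_q(q_n)^T phi_k(k_i), i <= n
    return (scores @ v) / (scores.sum(axis=1, keepdims=True) + eps)

def head_mse(q, k, v, params):
    """Output MSE between teacher y and student y_hat for one (layer, head)."""
    y = softmax_attention(q, k, v)
    y_hat = student_attention(q, k, v, *params)
    return float(np.mean((y - y_hat) ** 2))

def stage1_loss(heads):
    """Average head_mse over all (layer, head) pairs; heads = [(q, k, v, params), ...]."""
    return float(np.mean([head_mse(q, k, v, p) for q, k, v, p in heads]))
```

Here is a quick sanity check. With $\bm{q} = \bm{0}$ the teacher attends uniformly. A feature map with $\tilde{\bm{W}} = \bm{0}$ and $\tilde{\bm{b}} = \bm{1}$ reproduces that uniform attention exactly, so `head_mse` is zero up to `eps`. Random parameters generally give a positive loss for the optimizer to drive down.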
Two implementation details matter:
Teacher-forcing across layers. Even though $\phi$ is being trained jointly across layers, the inputs to layer $m+1$ are the true softmax outputs $\bm{y}^{(m)}$, not the student outputs $\hat{\bm{y}}^{(m)}$. This decouples per-layer optimization and prevents the student error from compounding through the depth of the network during training. It also means $\bm{y}$ can be computed once per forward pass via FlashAttention and reused.
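Output-MSE, not attention-weight matching. Hedgehog (Zhang et al., 2024) supervises on the $L \times L$ attention matrix directly. It minimizes a soft-label cross-entropy between the softmax weights $a_{n,i}$ and the linear-attention weights $\hat{a}_{n,i}$:

$$\ell_{\text{Hedgehog}} = -\sum_{n=1}^{L} \sum_{i=1}^{n} a_{n,i} \log \hat{a}_{n,i}, \qquad a_{n,i} = \frac{\exp(\bm{q}_n^\top \bm{k}_i/\sqrt{d})}{\sum_{j=1}^{n} \exp(\bm{q}_n^\top \bm{k}_j/\sqrt{d})}, \qquad \hat{a}_{n,i} = \frac{\phi(\bm{q}_n)^\top \phi(\bm{k}_i)}{\sum_{j=1}^{n} \phi(\bm{q}_n)^\top \phi(\bm{k}_j)}$$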
This is $\mathcal{O}(L^2)$ in memory per head. It also rules out FlashAttention, which never materializes the attention matrix. LoLCATs instead supervises on the contracted output $\bm{y}_n \in \mathbb{R}^d$, which shrinks the target to $\mathcal{O}(Ld)$ per head. That reduction is what makes the method viable at 405B scale.
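A back-of-the-envelope comparison makes the memory gap concrete. The sketch below counts the bytes of the supervision target for one layer of one sequence under each scheme. It assumes bf16 (2 bytes per element) and illustrative Llama-3-8B-like sizes ($H = 32$ heads, $d = 128$), and it ignores every other activation.

```python
def attention_weight_target_bytes(L, heads, bytes_per_el=2):
    """Attention-weight supervision: one L x L weight matrix per head."""
    return heads * L * L * bytes_per_el

def output_target_bytes(L, heads, d, bytes_per_el=2):
    """Output supervision: one L x d output per head."""
    return heads * L * d * bytes_per_el

L, H, d = 4096, 32, 128
weights = attention_weight_target_bytes(L, H)
outputs = output_target_bytes(L, H, d)
print(f"attention weights: {weights / 2**30:.2f} GiB per layer")   # 1.00 GiB
print(f"outputs:           {outputs / 2**20:.2f} MiB per layer")    # 32.00 MiB
print(f"ratio L/d = {weights // outputs}")                           # 32
```

The ratio is $L/d$, so the gap widens linearly with context length. At $L = 32768$ it is 256×.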
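After Stage 1, $\hat{\bm{y}} \approx \bm{y}$ in expectation but not exactly. Substituting the linear approximation throughout the network shifts the residual stream, and the cross-entropy loss creeps up. Instead of full finetuning, LoLCATs applies LoRA only to the four attention projections $\bm{W}_q, \bm{W}_k, \bm{W}_v, \bm{W}_o$. Each pretrained $\bm{W}$ stays frozen, and only a low-rank update is learned:

$$\bm{W}' = \bm{W} + \bm{A}\bm{B}, \qquad \bm{A} \in \mathbb{R}^{d_{\text{in}} \times r}, \quad \bm{B} \in \mathbb{R}^{r \times d_{\text{out}}}$$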
with $r = 8$. This is sufficient to absorb the approximation residual without disturbing the rest of the model.
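The following minimal NumPy sketch shows a LoRA-wrapped projection. It uses the row-vector convention $\bm{x}\bm{W}$ from the feature-map definition, and the frozen weight $\bm{W}$ gets a trainable update $\bm{A}\bm{B}$ of rank at most $r$. The class name, the zero initialization of $\bm{B}$, and the `alpha / r` scaling follow common LoRA practice. They are assumptions here, not details taken from the LoLCATs code. Zero-initializing $\bm{B}$ means Stage 2 starts exactly at the Stage 1 model.

```python
import numpy as np

class LoRALinear:
    """Frozen projection W (d_in x d_out, applied as x @ W) plus a trainable rank-r update A @ B."""

    def __init__(self, W, r=8, alpha=None, seed=0):
        rng = np.random.default_rng(seed)
        d_in, d_out = W.shape
        self.W = W                                                 # frozen pretrained weight
        self.A = rng.standard_normal((d_in, r)) / np.sqrt(d_in)   # trainable, random init
        self.B = np.zeros((r, d_out))                              # trainable, zero init
        self.scale = (r if alpha is None else alpha) / r          # common alpha / r scaling

    def __call__(self, x):
        # x @ (W + scale * A @ B), computed without forming the d_in x d_out update.
        return x @ self.W + self.scale * ((x @ self.A) @ self.B)

    def num_trainable(self):
        return self.A.size + self.B.size                           # r * (d_in + d_out)
```

For a 4096 → 4096 projection at $r = 8$, `num_trainable()` returns $8 \times 8192 = 65{,}536$.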
Stage 1: 32 layers × 32 heads × 2 feature maps × (128 × 64) ≈ 16.8M params — about 0.2% of the model. Trainable on a single 40 GB GPU.
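Stage 2: LoRA at $r=8$ on four projections, 32 layers: ≈ 6.8M params. Each projection contributes $r(d_{\text{in}} + d_{\text{out}})$, with the grouped-query $\bm{W}_k, \bm{W}_v$ mapping 4096 to 1024. That is under 0.09% of the model.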
Tokens: 40M total. Prior methods used 20–100B. That is a 500–2500× reduction.
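The budget arithmetic above can be reproduced with a few lines of Python. The Llama 3 8B shapes below are assumptions supplied for the calculation: hidden size 4096, 32 layers, 32 query heads of dimension 128, and 8 key/value heads under grouped-query attention. The ~8.03B total is also an assumption. Feature-map biases are ignored, as in the Stage 1 count.

```python
# Assumed Llama 3 8B shapes.
LAYERS, HEADS, HEAD_DIM, HIDDEN, KV_HEADS = 32, 32, 128, 4096, 8
FEATURE_DIM = 64        # d' per feature map, matching the 128 x 64 Stage 1 count
RANK = 8
MODEL_PARAMS = 8.03e9   # approximate total parameter count (assumption)

def stage1_params():
    """Feature-map weights: one HEAD_DIM x FEATURE_DIM matrix for each of 2 maps per head."""
    return LAYERS * HEADS * 2 * HEAD_DIM * FEATURE_DIM

def lora_params(d_in, d_out, r=RANK):
    """A LoRA update A (d_in x r) @ B (r x d_out) has r * (d_in + d_out) parameters."""
    return r * (d_in + d_out)

def stage2_params():
    kv_dim = KV_HEADS * HEAD_DIM                      # 1024 under grouped-query attention
    per_layer = (lora_params(HIDDEN, HIDDEN)          # W_q
                 + lora_params(HIDDEN, kv_dim)        # W_k
                 + lora_params(HIDDEN, kv_dim)        # W_v
                 + lora_params(HIDDEN, HIDDEN))       # W_o
    return LAYERS * per_layer

print(stage1_params(), f"{stage1_params() / MODEL_PARAMS:.2%}")   # 16777216 0.21%
print(stage2_params(), f"{stage2_params() / MODEL_PARAMS:.3%}")   # 6815744 0.085%
print(20e9 / 40e6, 100e9 / 40e6)                                  # 500.0 2500.0
```

The Stage 2 total, 6,815,744, is about 6.8M in decimal units (exactly 6.5 × 2^20).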
The naive framework above works, but bare linear attention has a known weakness. It cannot represent the sharp, local, peaky distributions that softmax produces near the diagonal (recent tokens), and these failures dominate the MSE residual. LoLCATs therefore adds a hybrid that respects this structure.
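For each query position $n$, split the causal context into a recent window of the last $W$ tokens and a long-range tail. Run exact softmax attention over the window and linear attention over everything before it:

$$\hat{\bm{y}}_n = \frac{\sum_{i=n-W+1}^{n} \exp(\bm{q}_n^\top \bm{k}_i/\sqrt{d})\, \bm{v}_i \;+\; \sum_{j=1}^{n-W} \phi(\bm{q}_n)^\top \phi(\bm{k}_j)\, \bm{v}_j}{Z_n^{\text{win}} + Z_n^{\text{lin}}}$$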
where $Z_n^{\text{win}}$ and $Z_n^{\text{lin}}$ are the sums of the window's exponential weights and of the linear weights $\phi(\bm{q}_n)^\top \phi(\bm{k}_j)$, added to form one shared denominator. For $n \le W$ the window simply starts at the first token. The window term adds $\mathcal{O}(LWd)$ compute, still linear in $L$ for fixed $W$. Crucially, it gives the model an exact handle on local context. Linear attention handles the diluted long-range signal, where its approximation is much closer to softmax behavior anyway.
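This is a minimal NumPy sketch of the window + linear hybrid for one head. It assumes a generic positive feature map `phi`, which is a hypothetical argument and not the trained LoLCATs map. It also uses plain $\exp$ without max-subtraction, so it is only suitable for small inputs. Keys that slide out of the window are folded into the running state $(\bm{s}, \bm{z})$ of the linear-attention recurrence, so the long-range term costs $\mathcal{O}(dd')$ per step.

```python
import numpy as np

def hybrid_window_linear_attention(q, k, v, phi, window):
    """Softmax over the last `window` keys, linear attention over older keys,
    both normalized by the shared denominator Z_win + Z_lin. Shapes: (L, d)."""
    L, d = q.shape
    fq, fk = phi(q), phi(k)
    s = np.zeros((fk.shape[1], d))   # sum of phi(k_j) v_j^T over keys that left the window
    z = np.zeros(fk.shape[1])        # sum of phi(k_j) over the same keys
    out = np.empty_like(v, dtype=float)
    for n in range(L):
        j = n - window               # key that leaves the window at step n
        if j >= 0:
            s += np.outer(fk[j], v[j])
            z += fk[j]
        start = max(0, n - window + 1)
        w = np.exp(q[n] @ k[start:n + 1].T / np.sqrt(d))   # exact softmax numerators
        num = w @ v[start:n + 1] + fq[n] @ s
        den = w.sum() + fq[n] @ z                           # Z_win + Z_lin
        out[n] = num / den
    return out
```

Setting `window` to at least the sequence length recovers exact causal softmax attention, which makes a convenient sanity check.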
The paper compares three parameterizations of $\phi$, with concrete implications for representational capacity:
| Map | Form | Output dim | Property |
|---|---|---|---|
| T2R | $\text{ReLU}(\bm{q} \tilde{\bm{W}} + \tilde{\bm{b}})$ | $d' = d$ | Cheap; loses sign info |
| Hedgehog | $[\text{SM}_d(\bm{q}\tilde{\bm{W}}) \,\oplus\, \text{SM}_d(-\bm{q}\tilde{\bm{W}})]$ | $d' = 2d$ | Preserves both signs; softmax-like |
| LoLCATs default | Hedgehog inside window+linear hybrid | $d' = 2d$ | Lowest output MSE |
Empirically, the Hedgehog parameterization combined with the window+linear hybrid gives the lowest layer-wise MSE and the highest downstream MMLU. The intuition: $\phi$ needs to express both attractive and repulsive interactions in the feature space. Concatenating a $-\bm{q}\tilde{\bm{W}}$ branch gives it room to do so while keeping every feature nonnegative.
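The two learnable maps in the table can be written in a few lines of NumPy. This is a sketch under the table's definitions, and the function names are hypothetical. It makes the sign argument concrete. Flipping the sign of the input simply swaps the two Hedgehog halves, whereas T2R sends every negative pre-activation to the same value, zero.

```python
import numpy as np

def softmax_last(x):
    """Numerically stable softmax over the last axis."""
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def t2r_features(x, W, b):
    """T2R: ReLU(x W + b). Nonnegative, but all negative pre-activations map to 0."""
    return np.maximum(x @ W + b, 0.0)

def hedgehog_features(x, W):
    """Hedgehog: [softmax(x W), softmax(-x W)] concatenated on the feature axis.
    Every feature is nonnegative, each half sums to 1, and the width doubles."""
    h = x @ W
    return np.concatenate([softmax_last(h), softmax_last(-h)], axis=-1)
```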
Jointly training all $M$ layers of $\phi$ from a single forward pass works for 7B–8B models. At 70B and 405B it breaks down. The paper reports that on Llama 3.1 405B (126 attention layers), late-layer MSE can be 200× larger than early-layer MSE under joint training — feature maps deep in the network simply do not converge to a good local solution when supervised under joint gradients.
The fix is a block-wise schedule. Partition the $M$ layers into blocks of $k$ consecutive layers. Within each block, optimize the MSE jointly; across blocks, treat each block's inputs as the cached softmax outputs of the previous block.
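The block-wise schedule reduces to partitioning layer indices and training each block against cached inputs. In this sketch `train_block` is a hypothetical callback that stands in for the per-block attention-transfer step. It reads the cached softmax-model hidden states entering the block, jointly minimizes the output MSE over the block's layers, and returns a loss. The block size `k` is a free parameter, not the value the authors chose.

```python
from typing import Callable, List

def make_blocks(num_layers: int, k: int) -> List[List[int]]:
    """Partition layers 0..num_layers-1 into consecutive blocks of k (last block may be shorter)."""
    return [list(range(start, min(start + k, num_layers))) for start in range(0, num_layers, k)]

def blockwise_schedule(num_layers: int, k: int,
                       train_block: Callable[[List[int]], float]) -> List[float]:
    """Train each block independently. Because a block only needs its cached inputs,
    blocks do not depend on each other's trained feature maps and could run in parallel."""
    return [train_block(layers) for layers in make_blocks(num_layers, k)]
```

For Llama 3.1 405B's 126 layers, `make_blocks(126, k)` yields ⌈126/k⌉ blocks. A larger `k` means fewer cached activation boundaries but larger per-block training jobs, which is the kind of tradeoff the authors' cost model weighs.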
The cost model the authors use to pick $k$ balances two terms:
where $T_{\text{block}}(k)$ is the wall-clock cost of training one $k$-layer block and $S_{\text{cache}}$ is the memory cost of storing one block's input activations. For Llama 3.1 405B the chosen $k$ enabled attention transfer in 5 hours on 14 H100s — a regime that joint training simply cannot reach without absorbing prohibitive memory.
The headline measurements, in order of how surprising they are:
On a single 80 GB H100 generating 4096-token samples, linearized Llama 3 8B achieves 3× the throughput and supports 64× larger batch sizes than the FlashAttention-2 softmax baseline. This is consistent with the asymptotic cost analysis of Figure 1.
For context: prior work (Mamba-in-Llama) used 5 days on 8×A100s for an 8B model. LoLCATs handles a 50× larger model in half the GPU-hours.
LoLCATs is the cleanest current instantiation of a broader pattern in post-training architecture surgery: swap a layer family, then patch the manifold. Compare with:
LoLCATs is the first to demonstrate that the linearization gap can be closed with only low-rank updates after a parameter-frugal attention-matching phase, and the first to actually run the procedure at frontier model scale.
The paper does not claim parity with softmax on the hardest reasoning benchmarks. Gaps remain on tasks heavy in long-range retrieval, where the constant-size recurrent state of pure linear attention is genuinely lossy. The sliding-window hybrid mitigates but does not eliminate this. The appendix discusses hybrid layer schedules (interleaving softmax and linear layers) as a route to further closing this gap, at proportional inference cost.