
experimental-transformer-architectures

Active

17 experimental language-model architectures where hidden dim = vocab size — no embedding, no output projection; every intermediate state is a distribution over vocabulary.

Python · Updated Apr 22, 2026
ablation-study · interpretable-ml · language-model · linear-attention · pytorch · register-machine · research · transformer

Experimental Transformer Architectures

17 experimental language model architectures where hidden dimension equals vocabulary size. The register state IS the prediction at every step.



Exploring whether vocabulary-space computation, where every hidden state is a readable distribution over words, can match opaque-embedding architectures at language modeling.

What This Is

A collection of 17 experimental language model architectures that share one constraint: hidden dimension = vocabulary size. There is no learned embedding and no output projection. The register state IS the prediction. Every intermediate state is interpretable as “which words are active and how strongly.”

This constraint is genuinely novel — no published architecture we’re aware of operates this way. Whether it’s a good idea is an open question we’re trying to answer empirically.

One variant (v13_with_embedding) deliberately breaks the constraint as a labeled control.

What We’ve Found So Far

Benchmark results (10 min, 3x A40, batch=491,520 tokens)

| MODEL_VERSION | Architecture | Params | Steps | val_loss | val_bpb | tok/s | Status |
|---|---|---|---|---|---|---|---|
| v8_lowrank_vv (rank 8) | Recurrent rank-r V x V linear layer | 164K | 100 | 5.24 | 3.10 | 270K | Still descending |
| v2_conv | Depthwise causal conv + Fourier channel mix | 353K | 464 | 5.39 | 3.19 | 383K | Still descending |
| v6_banded_fourier | Band-partitioned Fourier with gated coupling | 824K | 166 | 5.66 | 3.35 | 136K | Still descending |
| v1_shared_attn | Shared GQA attention + Fourier channel mix | 3.4M | 239 | 6.06 | 3.59 | 196K | Plateaued |
| v7_soft_ops | Soft op-bank + soft register addressing | 329K | 348 | 6.26 | 3.71 | 287K | Unstable (loss spikes) |
| v3_fourier_linattn | Linear attn with causal decay (Fourier proj) | 329K | 397 | 6.81 | 4.03 | 326K | Stuck |
| v8_lowrank_vv (rank 64) | Recurrent rank-r V x V linear layer | 1.1M | 188 | n/a | n/a | 270K | Memorized (train 0.04, overfitting) |

What these results mean

The low-rank V x V linear layer (v8) is the best architecture so far. At rank 8 with 164K params, it reaches val_loss 5.24 in 100 steps, beating v2_conv (353K params, 464 steps) with roughly half the parameters in about one-fifth the steps. The train/val gap is essentially zero, which indicates it is generalizing rather than memorizing.

But at rank 64, the same architecture memorizes. The 1.1M-param version drove train_loss to 0.04 while val_loss stayed high. The likely explanation is that the rank-64 U @ V^T matrix has enough capacity to act as a bigram-style lookup table for the training data. Rank 8 does not, so it is forced to learn a compressed, generalizable mapping instead.

This is still far from useful. val_loss 5.24 (3.10 bpb) is well above the ~1.7 loss needed for 1 bpb. GPT-2 at 124M params achieves ~0.93 bpb. We’re at 164K params, so the comparison isn’t fair, but the gap is large.
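The ~1.7 figure follows from the table itself: every row has val_bpb ≈ val_loss / 1.69, which implies the tokenizer covers roughly 2.44 bytes per token. A minimal sketch of the conversion, assuming val_loss is mean cross-entropy in nats per token and using the tokens-per-byte ratio implied by the v8 row (an inference from the reported numbers, not a value read from the repo's tokenizer):

```python
import math

def bits_per_byte(val_loss_nats: float, tokens_per_byte: float) -> float:
    """Convert per-token cross-entropy (nats) to bits per byte."""
    # nats/token -> bits/token, then scale by tokens per byte of text
    return val_loss_nats / math.log(2) * tokens_per_byte

def loss_for_bpb(target_bpb: float, tokens_per_byte: float) -> float:
    """Per-token loss (nats) that corresponds to a target bits-per-byte."""
    return target_bpb * math.log(2) / tokens_per_byte

# Ratio inferred from the v8 row (5.24 nats -> 3.10 bpb); about 0.41 tokens per byte.
TOKENS_PER_BYTE = 3.10 * math.log(2) / 5.24

print(round(bits_per_byte(5.39, TOKENS_PER_BYTE), 2))  # 3.19, matches the v2_conv row
print(round(loss_for_bpb(1.0, TOKENS_PER_BYTE), 2))    # 1.69, the loss needed for 1 bpb
```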

What’s actually unique here

  1. hidden_dim = vocab_size with no embedding or output projection. No published architecture we're aware of does this. The state IS the prediction at every step.

  2. Interpretability by construction. You can read intermediate states as distributions over vocab dimensions. This is not a post-hoc technique.

  3. The specific combination of vocabulary-space state + various cross-position mechanisms (conv, decay memory, low-rank linear) + recurrent depth has not, to our knowledge, been explored before.
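Because the register has one coordinate per vocab entry, reading an intermediate state needs no probe or learned decoder. A minimal sketch, assuming a softmax is an acceptable way to normalize an intermediate state (the repo may read raw or softcapped activations instead); `read_register`, the toy vocabulary, and the state values are hypothetical:

```python
import numpy as np

def read_register(state: np.ndarray, vocab: list[str], k: int = 3) -> list[tuple[str, float]]:
    """Return the top-k vocab entries of a register state as (token, prob) pairs."""
    # Softmax over the vocab dimension turns raw activations into a distribution.
    z = state - state.max()
    probs = np.exp(z) / np.exp(z).sum()
    top = np.argsort(probs)[::-1][:k]
    return [(vocab[i], float(probs[i])) for i in top]

vocab = ["the", "cat", "sat", "on", "mat"]     # toy vocabulary (hypothetical)
state = np.array([0.1, 2.0, 1.5, -1.0, 0.3])   # made-up intermediate register state
print(read_register(state, vocab, k=2))        # "cat" first, then "sat"
```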

What’s NOT unique

  • Weight sharing across depth: Universal Transformer (2019), ALBERT (2020), DEQ (2019) all do this.
  • Fourier parameterization: FNet (2022), butterfly matrices, Fourier Neural Operators.
  • Causal decay memory / linear attention: RWKV, Mamba, and S4 all use closely related mechanisms.
  • Low-rank dimension-to-dimension interaction: mathematically, x @ U @ V^T is just a rank-r linear layer.
  • Recurrent register machines: Neural Turing Machine (2014), Neural GPU (2016).
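The low-rank point is just associativity of matrix multiplication: (x @ U) @ V^T equals x @ (U @ V^T), and the product matrix has rank at most r. A quick numerical check with random factors, assuming V=1024 and rank 8 to match the table (the initialization scale is arbitrary); `V_mat` stands in for the V of U @ V^T to avoid clashing with the vocab size:

```python
import numpy as np

V, r = 1024, 8
rng = np.random.default_rng(0)
U = rng.standard_normal((V, r)) / np.sqrt(V)
V_mat = rng.standard_normal((V, r)) / np.sqrt(r)
x = rng.standard_normal((4, V))            # 4 positions, one register each

# Factored application: two thin matmuls, O(V * r) work per position.
y_factored = (x @ U) @ V_mat.T
# Dense application: materialize the V x V matrix, O(V^2) work per position.
W = U @ V_mat.T
y_dense = x @ W

print(np.allclose(y_factored, y_dense))    # True
print(np.linalg.matrix_rank(W))            # 8
print(2 * V * r, V * V)                    # 16384 factored params vs 1048576 dense
```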

Honest assessment of the v8_lowrank_vv results

The rank-8 variant works well because direct bilinear dimension-to-dimension interaction is a good inductive bias when the dimensions are vocab entries. Language is fundamentally about which words predict which other words. A model that directly parameterizes W[i, j] = "dim i predicts dim j" captures this structure more efficiently than architectures that must discover it through generic operations (convolutions, MLPs, Fourier transforms).
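To make the W[i, j] reading concrete: when the register is one-hot at word i, multiplying by W returns row i, i.e. a score for every candidate next word. A toy sketch with a hand-set W (the vocabulary and weights are illustrative, not learned; per the Model Versions table, v8's W is U @ V^T + diag):

```python
import numpy as np

vocab = ["<s>", "the", "cat", "sat", "on", "mat"]   # toy vocabulary (hypothetical)
idx = {w: i for i, w in enumerate(vocab)}
n = len(vocab)

# Hand-set interaction matrix: W[i, j] = how strongly word i predicts word j.
W = np.zeros((n, n))
W[idx["the"], idx["cat"]] = 2.0
W[idx["the"], idx["mat"]] = 1.0
W[idx["cat"], idx["sat"]] = 2.0

def one_hot(word: str) -> np.ndarray:
    x = np.zeros(n)
    x[idx[word]] = 1.0
    return x

# A one-hot register times W is exactly row i of W: the next-word scores for word i.
scores = one_hot("the") @ W
print(vocab[int(np.argmax(scores))])   # cat
```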

But this is a well-known insight. Bigram and n-gram models encode the same structure. The open question is whether multi-hop propagation (8 hops through the low-rank interaction matrix) captures longer-range dependencies that simple n-grams cannot. The current results don’t answer this — we’d need to test on tasks requiring longer-range reasoning.
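A toy illustration of what multi-hop propagation can add in principle: applying a direct-transition matrix twice reaches words two steps away, which no single W[i, j] entry encodes. This is a purely linear sketch that ignores softcap, the diagonal term, and cross-position mixing. In a purely linear stack, k hops collapse to the single matrix W^k, so whatever the trained model gains beyond that comes from the operations between hops; the sketch says nothing about whether it actually does so.

```python
import numpy as np

vocab = ["A", "B", "C", "D"]
W = np.zeros((4, 4))
W[0, 1] = 1.0   # A -> B
W[1, 2] = 1.0   # B -> C
W[2, 3] = 1.0   # C -> D

x = np.eye(4)[0]                                  # one-hot register at "A"
one_hop = x @ W                                   # only B is active
two_hops = x @ W @ W                              # A -> B -> C, with no direct A -> C weight
three_hops = x @ np.linalg.matrix_power(W, 3)     # linear hops collapse to a matrix power

for name, h in [("1 hop", one_hop), ("2 hops", two_hops), ("3 hops", three_hops)]:
    print(name, [vocab[i] for i in np.flatnonzero(h)])
```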

Architecture

All variants share the same skeleton:

Input:  one-hot("cat") -> R["cat"] = 1.0, everything else 0.0
Repeat N times:
  1. Cross-position mixing  (how do words at different positions interact?)
  2. Within-position transform  (how do vocab-dim activations combine?)
Output: register state -> softcap -> cross-entropy loss

No embedding. No output projection. (Except v13_with_embedding, the labeled control.)
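A runnable sketch of the shared skeleton, using numpy and picking one concrete mechanism per slot: a decay-weighted causal average for cross-position mixing and a v8-style low-rank V x V transform for the within-position step. The residual connection, decay value, cap=30, and initialization scale are assumptions, not the repo's settings:

```python
import numpy as np

def softcap(x: np.ndarray, cap: float = 30.0) -> np.ndarray:
    # Smoothly bounds logits to (-cap, cap); cap=30 is an assumed value.
    return cap * np.tanh(x / cap)

def causal_decay_mix(R: np.ndarray, decay: float) -> np.ndarray:
    # Cross-position mixing: position t becomes a decay-weighted average of registers 0..t.
    out = np.empty_like(R)
    acc, norm = np.zeros(R.shape[1]), 0.0
    for t in range(R.shape[0]):
        acc = decay * acc + R[t]
        norm = decay * norm + 1.0
        out[t] = acc / norm
    return out

def lowrank_vv(R: np.ndarray, U: np.ndarray, V_mat: np.ndarray, d: np.ndarray) -> np.ndarray:
    # Within-position transform: (U @ V^T + diag(d)) applied to every register.
    return (R @ U) @ V_mat.T + R * d

def forward(tokens: np.ndarray, p: dict, n_steps: int = 8) -> np.ndarray:
    vocab_size = p["U"].shape[0]
    R = np.eye(vocab_size)[tokens]      # one-hot registers; no embedding table
    for _ in range(n_steps):            # same weights reused at every step
        R = causal_decay_mix(R, p["decay"])
        R = R + lowrank_vv(R, p["U"], p["V_mat"], p["d"])
    return softcap(R)                   # register state used directly as logits

def next_token_loss(logits: np.ndarray, tokens: np.ndarray) -> float:
    # Cross-entropy of position t's register against token t+1.
    z = logits[:-1] - logits[:-1].max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(tokens) - 1), tokens[1:]].mean())

rng = np.random.default_rng(0)
V, r = 1024, 8
params = {
    "U": 0.01 * rng.standard_normal((V, r)),
    "V_mat": 0.01 * rng.standard_normal((V, r)),
    "d": np.zeros(V),
    "decay": 0.9,
}
tokens = rng.integers(0, V, size=16)
logits = forward(tokens, params)
print(logits.shape, next_token_loss(logits, tokens))
```

Because both the mixing and the output path are causal, changing the last token leaves every earlier position's logits unchanged.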

Model Versions

Names describe mechanism, not metaphor.

Core variants

| MODEL_VERSION | Cross-position | Within-position | Notes |
|---|---|---|---|
| v1_shared_attn | GQA + RoPE (weights shared across depth) | Fourier-parameterized channel mix | 3.4M params, plateaus early |
| v2_conv | Depthwise causal 1D conv | Fourier-parameterized channel mix | 353K params, strong baseline |
| v3_fourier_linattn | Linear attn with causal decay; Q/K/V/O via Fourier basis | Fourier-parameterized channel mix | Stuck: Fourier parameterization bottleneck |
| v4_weight_shared | Shared Q/K + per-head decay (v3 body) | Factored (diag + low-rank) channel mix | Size-reduction ablation of v3 |
| v5_fft_linattn | Linear attn with causal decay; Q/K/V/O via rFFT | FFT-based channel mix | Same Fourier-over-vocab caveat as v3 |
| v6_banded_fourier | Band-partitioned Fourier linattn, gated coupling | Three parallel band projections, gated | 824K, still descending |
| v7_soft_ops | Linear attn with causal decay | Gumbel-soft op-bank + soft register addressing | Unstable, loss spikes |
| v8_lowrank_vv | Diagonal Q/K linear attn, activation similarity | Low-rank V x V (U @ V^T + diag) | Best so far at rank 8 |
| v9_linattn | Linear attn with causal decay (dense projections) | MLP bottleneck | 4.2M params, best non-attention variant |
| v10_state_cond_op | Linear attn in compressed state space | State-conditioned soft read/op/write dispatch | Untested |
| v11a_mixed_ops | High-decay EMA + linear-attn | Sigmoid gate, dense layer, low-decay EMA | Untested |
| v11b_hard_routing | Multi-timescale linear attn | Gumbel-hard op routing + PonderNet halting | Untested |
| v12_vocab_slice | Causal decay in fixed k-dim slice | MLP in k-dim slice | Untested; slice indices are deterministic vocab-id windows |
| v14_data_dependent | Input-modulated conv (Hyena) | Data-dependent decay (Mamba), DCT mix | Mamba / RWKV / Hyena bundle |
| v15_aux_loss | v12 body + per-step CE + top-k sparsity | Entropy-adaptive write scaling | Training-side additions on v12 |
| v16_multi_branch | Per-column decay memory | Branched gated MLP + cross-column inhibition | Ensemble + gated branches |
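Several variants (v3, v5, v7, v9 and others) build on linear attention with causal decay. A minimal sketch of that mechanism in its generic form, assuming a single head, a scalar decay gamma, and no feature map or normalization (the per-variant projections, such as Fourier or diagonal Q/K, are not modeled). The recurrent form keeps a d_k x d_v memory; the parallel form computes the same thing as a decay-masked attention matrix:

```python
import numpy as np

def causal_decay_linattn(q: np.ndarray, k: np.ndarray, v: np.ndarray, gamma: float) -> np.ndarray:
    """Recurrent form: S_t = gamma * S_{t-1} + outer(k_t, v_t), o_t = q_t @ S_t."""
    T, dk = q.shape
    S = np.zeros((dk, v.shape[1]))
    out = np.empty((T, v.shape[1]))
    for t in range(T):
        S = gamma * S + np.outer(k[t], v[t])   # decayed running key-value memory
        out[t] = q[t] @ S
    return out

def causal_decay_linattn_parallel(q: np.ndarray, k: np.ndarray, v: np.ndarray, gamma: float) -> np.ndarray:
    """Parallel form: weight gamma^(t-s) for s <= t, zero for future positions."""
    T = q.shape[0]
    t_idx, s_idx = np.arange(T)[:, None], np.arange(T)[None, :]
    decay = np.where(s_idx <= t_idx, gamma ** (t_idx - s_idx), 0.0)
    return ((q @ k.T) * decay) @ v

rng = np.random.default_rng(1)
q, k, v = (rng.standard_normal((10, 4)) for _ in range(3))
print(np.allclose(causal_decay_linattn(q, k, v, 0.9),
                  causal_decay_linattn_parallel(q, k, v, 0.9)))   # True
```

The recurrent form is what makes these variants cheap at inference: memory stays d_k x d_v regardless of sequence length.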

Control variant

| MODEL_VERSION | Purpose |
|---|---|
| v13_with_embedding | Thesis-breaking control. Adds Embedding(V, d) -> Linear(d, V) before the register state (same body as v12_vocab_slice). Exists to measure what the no-embedding constraint costs; do not reuse as a template. |
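One way to see what the control adds: on a one-hot input, Embedding(V, d) followed by Linear(d, V) is equivalent to reading a row of an implicit V x V table E @ W^T (plus bias) whose rank is at most d, which parallels v8's rank-r interaction matrix. A numerical check, assuming d=64 purely for illustration (the control's actual d is not stated here) and PyTorch's (out_features, in_features) weight layout for the linear layer:

```python
import numpy as np

V, d = 1024, 64
rng = np.random.default_rng(0)
E = rng.standard_normal((V, d))    # Embedding(V, d) weight
W = rng.standard_normal((V, d))    # Linear(d, V) weight, shape (out_features, in_features)
b = rng.standard_normal(V)         # Linear(d, V) bias

token = 42
via_layers = E[token] @ W.T + b                  # embedding lookup, then linear projection
via_table = np.eye(V)[token] @ (E @ W.T) + b     # one-hot times an implicit V x V table

print(np.allclose(via_layers, via_table))        # True
print(np.linalg.matrix_rank(E @ W.T))            # 64
```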

Quick Start

# Setup on RunPod
curl -sSL https://raw.githubusercontent.com/urmzd/experimental-transformer-architectures/main/setup.sh | bash

# Or manually
uv pip install --system -r pyproject.toml
python data/download_data.py --variant sp1024

# Train the best model (low-rank V x V, rank 8)
INTERACTION_RANK=8 MODEL_VERSION=v8_lowrank_vv \
  torchrun --standalone --nproc_per_node=$(nvidia-smi -L | wc -l) train.py

# Benchmark all models
benchmark

# Benchmark specific models
benchmark --versions v8_lowrank_vv,v2_conv,v14_data_dependent --minutes 10

All hyperparameters are configurable via environment variables; see core/config.py.
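The environment-variable pattern used in the Quick Start (INTERACTION_RANK=8 MODEL_VERSION=v8_lowrank_vv) can be read with a few lines of standard library code. This is a hypothetical sketch of the pattern, not the contents of core/config.py; the `Config` class, `load_config` function, and default values are assumptions:

```python
import os
from dataclasses import dataclass

@dataclass
class Config:
    model_version: str
    interaction_rank: int

def load_config() -> Config:
    # Hypothetical defaults; the real names and defaults live in core/config.py.
    return Config(
        model_version=os.environ.get("MODEL_VERSION", "v8_lowrank_vv"),
        interaction_rank=int(os.environ.get("INTERACTION_RANK", "8")),
    )

print(load_config())
```

Under this pattern, prefixing the launch command with INTERACTION_RANK=64 would select the rank-64 run without touching code.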

What We’ve Learned

Inductive bias matters more than parameter count. v8_lowrank_vv (164K params, rank 8) beats v1_shared_attn (3.4M params, 20x more) because direct dimension-to-dimension interaction is a better prior for language than generic attention in vocab space.

Too much capacity in the right place enables memorization. v8_lowrank_vv at rank 64 memorizes the training batch (train loss 0.04). At rank 8 it generalizes (train ≈ val). The rank constraint forces it to generalize.

Fourier-over-vocab parameterization is a structural bottleneck. v3_fourier_linattn and v5_fft_linattn both constrain their linear-attention projections to linear combinations of sin/cos over vocab indices. Both got stuck. Vocab ids from BPE have no meaningful ordering, so “smooth over vocab ids” throws away useful capacity. The linear-attention core itself works fine — see v9_linattn, which uses dense projections on the same core.
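A small sketch of why the Fourier-over-vocab constraint bites, assuming a projection built as a linear combination of n_freq sine/cosine pairs over vocab ids (n_freq=8 and the sizes are illustrative, not the repo's values; `fourier_projection` and `row_cosine` are hypothetical helpers). The projection's rank is capped at 2 * n_freq, and adjacent vocab ids receive nearly identical rows even though BPE ids carry no semantic ordering:

```python
import numpy as np

def fourier_projection(coeffs: np.ndarray, vocab_size: int) -> np.ndarray:
    """Build a (vocab_size, d_out) projection whose columns are smooth sin/cos functions of the vocab id."""
    n_freq = coeffs.shape[0] // 2
    ids = np.arange(vocab_size)[:, None]
    freqs = np.arange(1, n_freq + 1)[None, :]
    angles = 2 * np.pi * freqs * ids / vocab_size
    basis = np.concatenate([np.sin(angles), np.cos(angles)], axis=1)   # (vocab_size, 2 * n_freq)
    return basis @ coeffs                                              # (vocab_size, d_out)

def row_cosine(P: np.ndarray, i: int) -> float:
    # Cosine similarity between the rows for vocab ids i and i + 1.
    return float(P[i] @ P[i + 1] / (np.linalg.norm(P[i]) * np.linalg.norm(P[i + 1])))

V, n_freq, d_out = 1024, 8, 256
rng = np.random.default_rng(0)
P_fourier = fourier_projection(rng.standard_normal((2 * n_freq, d_out)), V)
P_dense = rng.standard_normal((V, d_out))

print(np.linalg.matrix_rank(P_fourier), np.linalg.matrix_rank(P_dense))   # 16 256
print(round(row_cosine(P_fourier, 100), 3), round(row_cosine(P_dense, 100), 3))
```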

Attention in vocab space is expensive and unhelpful at this scale. v1_shared_attn spends most of its 3.4M params on Q/K/V/O projections over V=1024-dimensional register vectors (a single dense 1024 x 1024 projection is already ~1.05M params) and still plateaus at val_loss ~6.06. The overhead isn't justified.

Training instability is a real problem. v7_soft_ops had two catastrophic loss spikes (9.35 at step 161, 8.28 at step 181) before recovering. The soft op-selection path is fragile.

Inspirations

  • Linear Genetic Programming — register machines, sequential cheap operations
  • Tangled Program Graphs — hard bidding, multi-timescale memory
  • Neural GPU (Kaiser 2016) — repeated convolution learns algorithms
  • Deep Equilibrium Models (Bai 2019) — weight-shared iteration to convergence
  • Mamba (Gu & Dao 2023) — data-dependent state transitions
  • RWKV — linear attention with causal decay
  • Hyena — input-dependent long convolutions

Agent Skill

This repo’s conventions are available as portable agent skills in skills/.

License

Apache-2.0