Pretraining Recurrent Networks without Recurrence
Basic Information
- Title: Pretraining Recurrent Networks without Recurrence
- First Author: Akarsh Kumar (MIT)
- Research Team: MIT
- Venue: arXiv Preprint 2026
- Code: https://github.com/akarshkumar0101/smt
- PDF File: [Pretraining Recurrent Networks without Recurrence](file:///C:/Users/admin/.openclaw/workspace/attachment/papers/20260604_pretraining_recurrent_networks_without_recurrence.pdf)
Research Summary
Recurrent neural networks (RNNs) have long occupied a peculiar position in the deep learning landscape. On one hand, their fixed-size memory state and constant-time sequential update offer an elegant computational model that mirrors biological cognition: our brains do not store every sensory impression we have ever experienced, but instead compress the past into a compact, evolving representation. On the other hand, training these networks has remained stubbornly difficult. The canonical algorithm, backpropagation through time (BPTT), requires unrolling the recurrent computation graph across the entire sequence. This makes training inherently sequential and exposes gradients to a long and treacherous path that can stretch across the full length of the sequence.
This paper by Akarsh Kumar and Phillip Isola confronts this problem with a reframing that is as conceptually bold as it is technically elegant. Rather than patching BPTT with architectural gating mechanisms or orthogonal weight parameterizations, they ask whether recurrent credit propagation is necessary at all. Their answer is no, or at least it can be bypassed during the critical pretraining phase. The authors propose Supervised Memory Training (SMT), a method that sidesteps the entire machinery of BPTT by reducing RNN training to a supervised learning problem on one-step memory transitions. The core intuition is disarmingly simple. If we knew the optimal memory state at every timestep, training the RNN would reduce to learning the one-step map from the previous memory and the current input to the next memory. That is an ordinary supervised problem with no need to propagate credit backward through time.
To operationalize this insight, SMT employs a Transformer-based encoder-decoder pair as a "teacher." The encoder compresses the past context into a compact memory representation, while the decoder uses this memory to predict future outputs. The encoder is trained on a predictive state objective: it must produce a memory that is sufficient for predicting the future output sequence, given the future input sequence. Once this teacher model has learned to construct such memories, the RNN is trained to predict the next memory state from the current one, using standard supervised objectives. The RNN is never unrolled during training; it learns only one-step transitions, yet it can roll out autoregressively at inference. This decoupling of memory representation (what to remember) from memory dynamics (how to update memory) is the paper's central conceptual contribution. Because the teacher is time-parallel, SMT inherits the Transformer's parallelism across timesteps. Because the RNN is supervised one transition at a time, no gradient ever has to travel back through a long chain of recurrent updates.
The paper's empirical findings are equally compelling. On synthetic tasks designed to isolate specific properties—gradient stability, memory capacity, state tracking, associative recall, and in-context learning—SMT consistently outperforms BPTT, often dramatically so as sequences grow longer. On naturalistic tasks, including character-level language modeling on TinyStories and pixel sequence modeling on MNIST and Sketchy, SMT-trained RNNs achieve better performance with less sequential computation. A secondary contribution, DAgger Memory Training (DMT), addresses the train-test mismatch that arises when the RNN uses its own predicted memories rather than the teacher's memories at inference time. DMT fine-tunes the RNN by unrolling it and training on its own induced states, providing a lightweight correction mechanism that mitigates the accumulation of prediction errors over time.
Beyond its immediate utility as a training algorithm, SMT offers a broader conceptual reorientation. It suggests that the difficulty of training RNNs is not intrinsic to the recurrent architecture itself, but rather an artifact of the training algorithm. By treating the past as a set and the future as a prediction task, SMT aligns recurrent modeling with the powerful parallelization capabilities of modern hardware. It also opens the door to scaling laws along a new axis: memory compression. The authors demonstrate that SMT can achieve better compression of the past into smaller memory states when given more compute, a property that speaks directly to the efficiency of intelligent systems. In the landscape of sequence modeling, where Transformers have dominated due to their parallelizability, SMT offers a credible pathway for nonlinear RNNs to reclaim their place as a viable, efficient, and scalable alternative.
Theoretical Framework
Intellectual Lineage and Motivation
The theoretical foundations of this paper sit at the intersection of several long-standing threads in machine learning and dynamical systems theory. The problem of credit assignment in recurrent networks has been recognized since the earliest days of connectionism: how do you attribute an error at the output to the parameters that influenced it many timesteps earlier? Rumelhart, Hinton, and Williams (1986) introduced backpropagation through time as the canonical solution, and Paul Werbos (1990) elaborated its implementation. Yet the limitations of BPTT were apparent almost immediately. Bengio, Simard, and Frasconi (1994) famously proved that learning long-term dependencies with gradient descent is difficult, because gradients tend to vanish or explode as they propagate through the recurrent nonlinear dynamics. This observation sparked decades of architectural innovations (LSTM gates, GRU gating, orthogonal and unitary weight parameterizations, residual connections), all designed to preserve gradient magnitudes across time. These modifications helped, but they did not eliminate the fundamental problem. The credit path length remains proportional to the sequence length, because an error signal must still traverse every intermediate timestep to reach the parameters that caused it.
Kumar and Isola's work draws a different line of intellectual descent, one that traces back to the theory of predictive state representations (PSRs) in reinforcement learning and control theory. Littman and Sutton (2001) introduced PSRs as a way to represent the state of a partially observable dynamical system entirely in terms of predictions about future observations. A PSR captures the idea that the state of a system is not its internal hidden configuration, but rather the set of probabilities it assigns to future outcomes. This is a profound epistemological shift: the state is defined by what it enables you to predict, not by what it is. Belief states in partially observable Markov decision processes share this flavor, representing a sufficient statistic of the past for optimal decision-making (Kaelbling, Littman, and Cassandra, 1998). The authors explicitly connect their work to this tradition, stating that an effective memory is a sufficient statistic of the past for predicting the future—a predictive state.
Sequence-to-Set Reframing: The Key Theoretical Insight
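The paper's most elegant theoretical move is to reparameterize sequential computation as set-based computation. In the standard formulation, an RNN processes an ordered sequence $x_1, \dots, x_t$ one token at a time, so the memory $m_t$ is built by $t$ nested applications of the update. The key observation is that this ordered sequence carries exactly the same information as the unordered set of timestamped tokens $\{(1, x_1), \dots, (t, x_t)\}$, because the sequence can be recovered by sorting the set on its timestamps. Any memory computed recurrently from the past is therefore also a function of this set.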
This result is deceptively simple but carries profound implications. Computing a memory from the past is not an inherently sequential problem: it can be done in parallel over all past tokens, provided the model has access to the timestamps. Transformers treat their input as a set of tokens and see order only through positional encodings, which makes them natural candidates for this computation. The authors acknowledge that full theoretical expressivity may require model depth to scale with sequence length. This is consistent with the circuit complexity limits of constant-depth parallel models analyzed by Merrill and Sabharwal (2023). Even so, their empirical results show that relatively shallow Transformers learn highly effective memory representations.
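The equivalence between a sequence and its set of timestamped tokens can be checked on a toy example. The sketch below uses a hypothetical scalar update `rnn_step` (not the paper's model) to show that shuffling the (timestamp, token) pairs leaves the computed memory unchanged. The sort-and-replay inside `memory_from_set` only demonstrates that the memory is a well-defined function of the set. It is not how a Transformer would compute it; a Transformer would process the set in parallel.

```python
import math
import random


def rnn_step(m, x, a=0.8, b=1.0):
    """Hypothetical nonlinear recurrent update: m_t = tanh(a * m_{t-1} + b * x_t)."""
    return math.tanh(a * m + b * x)


def memory_from_sequence(xs, m0=0.0):
    """Compute the memory recurrently, one token at a time (the usual RNN view)."""
    m = m0
    for x in xs:
        m = rnn_step(m, x)
    return m


def memory_from_set(events, m0=0.0):
    """Compute the same memory from an unordered collection of (timestamp, token) pairs.

    Timestamps make the collection determine the sequence, so the memory is a
    well-defined function of the set. Sorting and replaying only proves that fact;
    a Transformer encoder would instead compute a memory in parallel over the set.
    """
    ordered = [x for _, x in sorted(events)]
    return memory_from_sequence(ordered, m0)


xs = [0.3, -1.2, 0.7, 0.05, 2.0]
events = [(t, x) for t, x in enumerate(xs, start=1)]
random.shuffle(events)  # destroy the original order; only the timestamps remain
print(memory_from_set(events) == memory_from_sequence(xs))  # True
```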
Predictive State Formulation
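The authors formalize the predictive state objective as follows. Each sequence is split at a timestep $t$ into a past and a future. The encoder maps the past inputs $x_{1:t}$ to a memory $m_t$. The decoder must then predict the future outputs from $m_t$ together with the future inputs, and the predictive state loss is the decoder's prediction loss on those future outputs.
This is essentially a conditional sequence modeling problem where the conditioning variable is the compressed memory $m_t$ rather than the raw past.
The decoder ensures that $m_t$ retains whatever information about the past is needed to predict the future. Nothing in the objective rewards keeping anything else, so a compact memory is pushed toward being a predictive state rather than a verbatim record of the past.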
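A minimal sketch of what "sufficient for predicting the future" means, assuming a hypothetical task in which each output $y_t$ is the running parity of the input bits $x_{1:t}$. The encoders, decoder, and task are illustrative and nothing is learned. Prediction errors stand in for the decoder's loss. A one-bit memory holding the parity of the past is a predictive state and yields zero errors. A collapsed constant memory does not.

```python
import itertools


def outputs(xs):
    """Hypothetical task: y_t is the running parity of x_1..x_t."""
    ys, p = [], 0
    for x in xs:
        p ^= x
        ys.append(p)
    return ys


def parity_encoder(past):
    """Keeps exactly one bit of the past: its parity."""
    return sum(past) % 2


def constant_encoder(past):
    """A collapsed memory that forgets everything."""
    return 0


def decoder(memory, future_inputs):
    """Predict future outputs from the memory and the future inputs only."""
    preds, p = [], memory
    for x in future_inputs:
        p ^= x
        preds.append(p)
    return preds


def predictive_errors(encoder, length=8, split=4):
    """Count wrong future predictions over every binary sequence of the given length."""
    errors = 0
    for xs in itertools.product([0, 1], repeat=length):
        past, future = list(xs[:split]), list(xs[split:])
        target = outputs(list(xs))[split:]
        preds = decoder(encoder(past), future)
        errors += sum(p != y for p, y in zip(preds, target))
    return errors


print(predictive_errors(parity_encoder))    # 0
print(predictive_errors(constant_encoder))  # 512
```

The constant encoder is wrong on every future step whenever the past has odd parity. That covers half of the 256 sequences, with four future steps each, giving 512 errors.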
Dynamics and Uniformity Losses
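The RNN's updater is trained to reproduce the encoder's memory transitions: given the encoder memory $m_{t-1}$ and the next input $x_t$, it must output the encoder memory $m_t$. The dynamics loss $\mathcal{L}_{\text{dyn}}$ measures the discrepancy between the updater's prediction and $m_t$.
This dynamics loss serves two purposes: it trains the RNN, and it explicitly shapes the encoder memory representations to be approximately Markovian. In practice, the authors find it beneficial to train the encoder, decoder, and RNN jointly. The combined objective includes a third term, a uniformity loss borrowed from contrastive learning.
The uniformity loss prevents the memory space from collapsing. Without it, the encoder might map all contexts to the same memory vector, which would be trivially Markovian but useless for prediction. The full objective is a weighted sum of the three terms, $\mathcal{L} = \mathcal{L}_{\text{pred}} + \lambda_{\text{dyn}}\,\mathcal{L}_{\text{dyn}} + \lambda_{\text{unif}}\,\mathcal{L}_{\text{unif}}$. Here $\mathcal{L}_{\text{pred}}$ is the predictive state loss, $\mathcal{L}_{\text{unif}}$ is the uniformity loss, and the $\lambda$ weights are hyperparameters.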
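The three-term objective can be sketched on plain lists of vectors. Several details are assumptions, not the paper's exact definitions:

- the squared-error form of the dynamics loss;
- the uniformity loss in the standard Wang and Isola (2020) style, $\log \mathbb{E}\,e^{-t\|m - m'\|^2}$ over distinct pairs (they apply it to unit-normalized features, which this sketch does not enforce);
- the temperature `t`;
- the weights `lam_dyn` and `lam_unif`.

Collapsed memories give the maximum uniformity value of 0, and spreading them out lowers it.

```python
import math


def sq_dist(u, v):
    """Squared Euclidean distance between two memory vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def dynamics_loss(predicted, targets):
    """Mean squared error between RNN-predicted memories and encoder memories (assumed form)."""
    return sum(sq_dist(p, t) for p, t in zip(predicted, targets)) / len(predicted)


def uniformity_loss(memories, t=2.0):
    """log of the mean of exp(-t * ||m_i - m_j||^2) over distinct pairs (Wang-Isola style)."""
    n = len(memories)
    kernel = [
        math.exp(-t * sq_dist(memories[i], memories[j]))
        for i in range(n)
        for j in range(i + 1, n)
    ]
    return math.log(sum(kernel) / len(kernel))


def total_loss(pred_loss, predicted, targets, memories, lam_dyn=1.0, lam_unif=0.1):
    """Weighted sum of predictive, dynamics, and uniformity terms (weights are hypothetical)."""
    return (pred_loss
            + lam_dyn * dynamics_loss(predicted, targets)
            + lam_unif * uniformity_loss(memories))


collapsed = [[0.5, 0.5]] * 4
spread = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
print(uniformity_loss(collapsed))                            # 0.0
print(uniformity_loss(spread) < uniformity_loss(collapsed))  # True
```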
The authors note that theoretically, it should be sufficient to train the encoder-decoder with only
Theoretical Properties and Limitations
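The theoretical framework of SMT inherits both the strengths and the limitations of its teacher Transformer. Transformers are time-parallel and have short credit paths, since any output can attend directly to any input, and this gives SMT its stable training signal. The flip side is that time-parallel models of bounded depth are limited in expressive power, so an RNN trained to imitate such a teacher may inherit those limits.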
Technical Architecture
System Overview
The complete SMT system comprises three distinct neural networks that interact in a teacher-student relationship: a bidirectional Transformer encoder, a causal Transformer decoder, and the recurrent neural network (the student) that is the ultimate target of training. The encoder and decoder together form the "teacher" that generates the training labels for the RNN. The RNN itself consists of an updater network (the recurrent transition function) and a readout network (the output head). This triad of models is trained in a single joint stage, though conceptually the objectives can be decomposed into representation learning (encoder), future prediction (decoder), and dynamics imitation (RNN).
The encoder is a bidirectional Transformer that processes the input context tokens along with a set of learned memory register tokens. The input tokens and register tokens are concatenated and fed through a stack of bidirectional Transformer blocks (with full attention masks). The register tokens at the output are interpreted as the memory state that summarizes the context.
The decoder is a causal Transformer. It concatenates the encoder's memory tokens with the embedded future input tokens and processes them through a stack of causally masked Transformer blocks. The causal mask ensures that the decoder can only attend to previous tokens when making predictions, enforcing the autoregressive generative structure. Output predictions are read out at positions chosen so that each future output is predicted only from the memory tokens and the future inputs up to that step.
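The two attention patterns can be written as boolean masks, where `True` means "may attend". The sketch assumes the token layout described above: context tokens followed by register tokens for the encoder, and memory tokens followed by future input tokens for the decoder. It applies a plain causal mask across the whole decoder sequence. Whether the paper lets memory tokens attend to one another bidirectionally inside the decoder is not stated here. The helper names are hypothetical.

```python
def encoder_mask(n_context, n_registers):
    """Bidirectional mask over [context tokens; register tokens]: every token sees every token."""
    n = n_context + n_registers
    return [[True] * n for _ in range(n)]


def decoder_mask(n_memory, n_future):
    """Causal mask over [memory tokens; future input tokens]: position i sees j only if j <= i."""
    n = n_memory + n_future
    return [[j <= i for j in range(n)] for i in range(n)]


def show(mask):
    """Print a mask row by row: 'x' = may attend, '.' = blocked."""
    for row in mask:
        print("".join("x" if ok else "." for ok in row))


show(decoder_mask(n_memory=2, n_future=3))
# x....
# xx...
# xxx..
# xxxx.
# xxxxx
```

Every future position can see all memory tokens, which is what lets the memory condition the whole predicted future.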
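The RNN used in most experiments is a Transformer-based recurrent network. At each timestep, the current memory tokens are concatenated with the embedded current input token and passed through a stack of Transformer blocks. The memory tokens at the output form the next memory state, and the readout network maps that state to the output prediction.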
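One recurrent step of such an updater can be sketched with single-head dot-product attention. Each memory token attends over the current memory tokens plus the input token, and a residual connection produces the next memory. This is a deliberately stripped-down, hypothetical version: it has no learned projections, MLPs, normalization, or multiple layers. The averaging `readout` is a placeholder, since the readout network's architecture is not specified here.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query, keys, values):
    """Scaled dot-product attention for a single query (identity projections)."""
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    weights = softmax(scores)
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(d)]


def updater_step(memory_tokens, input_token):
    """Hypothetical recurrent step: memory tokens attend over [memory; input], plus a residual."""
    context = memory_tokens + [input_token]
    return [
        [m + a for m, a in zip(tok, attend(tok, context, context))]
        for tok in memory_tokens
    ]


def readout(memory_tokens):
    """Placeholder readout: the mean of the memory tokens."""
    k = len(memory_tokens)
    return [sum(tok[i] for tok in memory_tokens) / k for i in range(len(memory_tokens[0]))]


memory = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5]]  # three memory tokens of width 2
for x in ([1.0, 0.0], [0.0, 1.0]):
    memory = updater_step(memory, x)
    print(readout(memory))
```

The number and width of memory tokens stay fixed from step to step, which is what keeps the state size constant regardless of sequence length.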
Data Flow and Training Dynamics
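The data flow during training is architecturally distinct from BPTT. In BPTT, the input sequence is fed token by token into the RNN, which generates a memory trajectory one step at a time. The loss is computed on the outputs, and gradients flow backward through the entire trajectory. In SMT, the encoder computes the memories at many timesteps in parallel, and the RNN is trained on one-step transitions between them, so no SMT training step unrolls the RNN.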
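This architectural difference has profound implications for computational efficiency. In BPTT, the memory footprint scales linearly with the sequence length, since every intermediate state must be stored for the backward pass, and both the forward and backward passes are inherently sequential. SMT's one-step RNN updates, by contrast, never need to store a trajectory.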
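A minimal sketch of how SMT-style training examples could be assembled. A hypothetical stand-in teacher, `toy_encoder`, computes each memory from its own prefix. Each of the resulting one-step pairs $((m_{t-1}, x_t), m_t)$ depends only on the teacher, not on the RNN's outputs, so the encoder calls are independent and could run in parallel. A real encoder would be a Transformer over the timestamped prefix.

```python
def toy_encoder(prefix):
    """Hypothetical teacher: memory = (count, mean) of the prefix it is given."""
    if not prefix:
        return (0, 0.0)
    return (len(prefix), sum(prefix) / len(prefix))


def smt_transition_pairs(xs, encoder):
    """Build one-step training examples ((m_{t-1}, x_t), m_t) for t = 1..T.

    Each memory is computed from its own prefix, so the T + 1 encoder calls are
    independent of each other and of the RNN being trained.
    """
    memories = [encoder(xs[:t]) for t in range(len(xs) + 1)]  # m_0, ..., m_T
    return [((memories[t - 1], xs[t - 1]), memories[t]) for t in range(1, len(xs) + 1)]


for (m_prev, x), m_next in smt_transition_pairs([1.0, 3.0, 2.0], toy_encoder):
    print(m_prev, x, "->", m_next)
# (0, 0.0) 1.0 -> (1, 1.0)
# (1, 1.0) 3.0 -> (2, 2.0)
# (2, 2.0) 2.0 -> (3, 2.0)
```

This toy memory happens to be Markovian: $(n, \mu) \mapsto (n + 1, (n\mu + x)/(n + 1))$. A student fitting these pairs therefore only has to learn that single update, never a multi-step trajectory.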
DAgger Memory Training: Correcting the Train-Test Mismatch
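A critical engineering challenge arises from the train-test mismatch. During SMT, the RNN learns to predict the encoder's next memory from the encoder's current memory. At inference time, however, it must roll forward from its own predictions, so small one-step errors compound and the memory drifts away from the states seen during training.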
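The authors address this with DAgger Memory Training (DMT), inspired by the DAgger algorithm for imitation learning. DMT is a fine-tuning phase in which the RNN is unrolled on its own predicted memories and then trained to imitate the encoder's memory at each step. The training inputs are now the RNN's own rolled-out memories, while the labels remain the encoder's memories at the corresponding timesteps. The RNN thus learns to steer back toward the teacher's trajectory from the states it actually visits at inference time.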
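The drift problem and the DAgger-style fix can be seen with scalar stand-ins. The teacher and student updates below are hypothetical. The student has a small, constant one-step error, yet its rollout error compounds toward a much larger value. `dmt_pairs` builds the DMT data: the student's own states as inputs and the teacher's memories as labels. The actual fitting step (a gradient update on those pairs) is omitted.

```python
def teacher_step(m, x):
    """Hypothetical transition that the encoder's memories follow exactly."""
    return 0.9 * m + x


def student_step(m, x, bias=0.01):
    """Hypothetical RNN that is almost right: a small systematic one-step error."""
    return 0.9 * m + x + bias


def trajectory(step, xs, m0=0.0):
    """Roll a step function forward from m0, feeding back its own memories."""
    ms = [m0]
    for x in xs:
        ms.append(step(ms[-1], x))
    return ms


def dmt_pairs(xs, m0=0.0):
    """DAgger-style data: inputs are the RNN's own states, labels are the teacher's memories."""
    own = trajectory(student_step, xs, m0)
    teacher = trajectory(teacher_step, xs, m0)
    return [((own[t - 1], xs[t - 1]), teacher[t]) for t in range(1, len(xs) + 1)]


xs = [0.0] * 50
teacher = trajectory(teacher_step, xs)
rolled = trajectory(student_step, xs)
one_step = [abs(student_step(m, x) - teacher_step(m, x)) for m, x in zip(teacher, xs)]
drift = [abs(a - b) for a, b in zip(rolled, teacher)]
print(max(one_step), drift[-1])  # one-step error stays at 0.01; rollout drift approaches 0.1
```

Here the rollout error obeys $e_t = 0.9\,e_{t-1} + 0.01$, so it grows toward $0.1$, ten times the one-step error. A model trained only on teacher-forced transitions never sees this gap.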
Key Architectural Innovations
The most important architectural innovation is the use of register tokens for memory representation. Rather than compressing the past into a single vector, the encoder produces a set of memory vectors, each potentially capturing different aspects of the context. This set-based memory structure is preserved through the RNN transition, with the Transformer-based updater attending across all memory tokens and the input token. This design allows the memory to have a more complex, structured geometry than a simple vector state, which is crucial for tasks requiring state tracking or associative recall. The authors visualize these memory spaces in 2D and 3D, showing that different tasks induce qualitatively different memory geometries: retrieval tasks collapse the sequence into a few discrete states (finite-state-machine-like behavior), while string copying tasks require a tree-like structure to preserve all possible sequences without aliasing.
Experimental Evaluation
Experimental Strategy and Datasets
The experimental design is structured to systematically validate the theoretical claims of SMT. The authors first evaluate on five synthetic tasks designed to isolate specific properties of the training algorithm: gradient stability (Retrieval), memory capacity (String Copy), state tracking (Stack Operations), associative recall (Keys-Values), and in-context learning (Modular Arithmetic). Each task has a controllable difficulty parameter (sequence length, memory size, state complexity, number of associations, or number of in-context examples) that allows the authors to stress-test the algorithms along specific axes. The RNN architecture for these experiments uses the Transformer backbone, and both the context length
For naturalistic tasks, the authors evaluate on character-level language modeling using TinyStories, a curated dataset of short stories generated by GPT-3.5 and GPT-4. This task requires long-range memory for narrative coherence, as character names and plot threads must be tracked across thousands of characters. More challengingly, they evaluate on pixel sequence modeling of sparse images from MNIST and Sketchy in raster-scan order. The authors call this "Attneave's task," after Fred Attneave's classic work on visual perception. The task is deliberately difficult for RNNs. To predict the next pixel, the model must remember the locations and shapes of strokes seen hundreds of timesteps earlier, buried among black pixels. RNNs cannot attend directly to earlier pixels and must instead compress all relevant spatial information into a fixed memory state.
Synthetic Task Results
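The synthetic task results reveal a stark and consistent pattern: SMT outperforms BPTT across the tasks, and the gap widens as the tasks become harder, particularly as sequences grow longer.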
On the stack operations task, which tests state tracking, BPTT fails to track deeper stacks as the sequence length increases, while SMT successfully tracks the stack state. This is a significant result because state tracking is a known weakness of linear RNN models and constant-depth Transformers, which are limited by their circuit complexity. The nonlinear RNN trained with SMT, by contrast, inherits the expressivity of nonlinear dynamics while benefiting from stable training. On the modular arithmetic task, which tests in-context learning, SMT again outperforms BPTT, demonstrating that the method is capable of learning to infer latent rules from examples and apply them to novel inputs.
Natural Task Results and Efficiency Analysis
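The natural task results are equally compelling. On TinyStories, the SMT-trained encoder and the SMT→DMT RNN are compared against a BPTT-trained RNN. The SMT→DMT RNN reaches a comparable test loss while requiring less sequential computation during training, and the encoder itself reaches a lower loss.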
The pixel sequence modeling results are visually striking. Samples generated by BPTT-trained RNNs, even with GRU architectures, fail to capture the long-range structure of handwritten digits. The generated images show either large streaks of white or black pixels driven only by local context, or blurry, incoherent structures. In contrast, SMT→DMT-trained RNNs generate samples that preserve the global shape of digits and the stroke structure of sketches.
| Task | Metric | BPTT RNN | SMT→DMT RNN | SMT Encoder* |
|---|---|---|---|---|
| Retrieval (T=512, noise=0.3) | Test Loss | High | Low | Low |
| String Copy (T=4096, M=64) | Test Loss | Fails | Succeeds | Succeeds |
| Stack Ops (T=64, depth=5) | Test Loss | Fails | Succeeds | Succeeds |
| TinyStories (seq len 256) | Test Loss | Comparable | Comparable | Lower |
| MNIST (seq len 784) | Test Loss | High | Low | Low |
| Sketchy (seq len 1024) | Test Loss | Very High | Moderate | Moderate |
*SMT Encoder is the Transformer teacher, not an RNN; shown as reference upper bound.
Scaling Behavior
The authors evaluate scaling along three axes: context length, memory state size, and model parameter count. On TinyStories, SMT
A particularly novel scaling analysis is performed along the compression axis. The authors train SMT models across a sweep of memory state sizes and training compute budgets, plotting iso-loss contours. The results confirm that SMT can achieve more compression (better performance with smaller memory) when allocated more compute. This suggests a new scaling law for sequence models: memory state compression. The authors note that Transformers perform no compression of the past (their memory grows with sequence length), which may explain their training efficiency but also their inference costs. RNNs, by contrast, compress the past into a fixed-size state, and SMT provides a principled way to learn this compression.
Case Studies
Memory Space Visualization: What the Encoder Learns
To understand what SMT is actually learning, the authors train smaller models with 2D memory states and visualize the memory space across three synthetic tasks. In the retrieval task with two possible needles, the encoder collapses the infinite set of possible pasts into only three kinds of memory states: an initial state, a state indicating that the next token is the needle, and states corresponding to the needle values. The RNN then learns finite-state-machine behavior to transition between these states. This is a remarkable finding: the encoder has discovered a minimal state representation that captures the task structure, and the RNN has learned to implement the transitions of a finite state machine.
In contrast, the string copying task requires lossless sequence compression because every sequence must be reproduced exactly in reverse. The encoder cannot alias distinct memory states together; it must preserve enough information to uniquely identify the input sequence. The visualization reveals a tree-like memory geometry, where each branch corresponds to a different input token, matching the tree structure of all possible strings. This demonstrates that the encoder adapts its memory geometry to the requirements of the task: when the future requires only categorical information (which needle was seen), it collapses the state space; when the future requires lossless reconstruction (copy the string), it expands into a tree structure.
Gradient Properties: The Fundamental Difference
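The authors analyze the gradient properties of BPTT and SMT on the needle retrieval task, where the loss is applied only at the last timestep. In BPTT, the gradient of that final loss with respect to early memory states is a product of one Jacobian per intervening timestep. Its magnitude therefore tends to vanish or explode as the gap between the needle and the final step grows. In SMT, each transition receives its own one-step supervision, so the gradient reaching each update does not depend on how far back the relevant information lies.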
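The difference in credit paths can be made concrete with a scalar toy recurrence $m_t = \tanh(w\,m_{t-1})$; real RNNs have Jacobian matrices rather than scalars. Under BPTT, the sensitivity of the final memory to the initial one is a product of one factor $w(1 - m_t^2)$ per step, which shrinks geometrically here because $|w| < 1$. With a linear recurrence and $|w| > 1$, the same product would instead explode. Under SMT-style one-step supervision, only a single factor is ever involved, whatever the sequence length. The function names are illustrative.

```python
import math


def bptt_sensitivity(w, steps, m0=0.5):
    """d m_T / d m_0 for the toy recurrence m_t = tanh(w * m_{t-1}), via the chain rule."""
    grad, m = 1.0, m0
    for _ in range(steps):
        m = math.tanh(w * m)
        grad *= w * (1.0 - m * m)  # one Jacobian factor per recurrent step
    return grad


def one_step_sensitivity(w, m_prev):
    """d m_t / d m_{t-1} for a single supervised transition, the only path SMT uses."""
    m = math.tanh(w * m_prev)
    return w * (1.0 - m * m)


for steps in (1, 10, 100):
    print(steps, bptt_sensitivity(0.9, steps))
print(one_step_sensitivity(0.9, 0.5))
```

Each factor is at most $|w| = 0.9$ here, so after 100 steps the BPTT sensitivity is below $0.9^{100} \approx 2.7 \times 10^{-5}$. The one-step sensitivity stays around $0.74$.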
DMT and Drift Dynamics
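The authors analyze the drift phenomenon and DMT's correction in detail. Before DMT, the RNN's rollout error (the gap between its own rolled-out memory and the encoder's memory) grows over the course of a rollout as one-step errors compound. DMT fine-tuning reduces this drift.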
Sequence Length Generalization
The authors compare an SMT
Synthesis: Value and Limitations
Theoretical Significance
SMT changes how we think about the relationship between sequence modeling and parallel computation. It demonstrates that the sequentiality of RNN training is not an intrinsic property of recurrent architectures, but rather a consequence of the training algorithm. By reframing the past as a set of timestamped events, the paper opens a theoretical pathway for time-parallel training of any model that computes a sufficient statistic of the past. The concept of predictive states—memories that are defined by what they predict rather than by how they are computed—provides a new conceptual tool for thinking about memory, compression, and representation learning in dynamical systems.
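The paper also makes a valuable contribution to the understanding of gradient stability in deep learning. It contrasts BPTT, whose credit paths grow with sequence length, with SMT, whose RNN receives one-step supervision. This makes concrete why long-range credit assignment through recurrence is hard and how it can be avoided.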
Practical Impact and Strengths
The practical strengths of SMT are clear. It enables time-parallel training of nonlinear RNNs, which has been a long-standing goal in the field. It provides stable gradients for learning long-range dependencies, outperforming BPTT on tasks that require credit assignment across hundreds or thousands of timesteps. It reduces the sequential computation required for training, which is the primary bottleneck on modern parallel hardware. It also provides a principled method for pretraining RNNs, producing strong initializations that can be fine-tuned with lightweight post-training.
The empirical results on pixel sequence modeling are particularly impressive. Generating coherent MNIST digits and Sketchy sketches from a raster-scan pixel sequence fundamentally requires long-range memory, since the model must integrate shape information across the entire image. SMT→DMT-trained RNNs capture this structure where BPTT-trained RNNs, including GRUs, fail. That is strong evidence that the method learns genuinely long-range memory.
Limitations and Honest Weaknesses
The authors are commendably transparent about the limitations of their method. First, the teacher Transformer is time-parallel and thus limited in expressive power compared to unbounded sequential computation. This means SMT-trained RNNs may inherit the same limitations and may require BPTT finetuning to achieve full expressivity beyond the teacher. This is not a fatal flaw—the authors position SMT as a pretraining method—but it means that SMT alone does not eliminate all need for BPTT.
Second, SMT is useful for learning how to encode sequences, but not necessarily for learning complex reasoning where intermediate steps are not supervised. The same limitation applies to Transformers, and post-training has proven effective for extending their capabilities; the authors suggest the same may be true for SMT-trained RNNs. Third, the current SMT variant computes and trains on only a single memory
Fourth, DMT, while effective, is not time-parallel and requires unrolling the RNN. It is a necessary but imperfect solution to the drift problem. The authors note that DMT could potentially be parallelized using DEER or similar techniques, but this remains future work. The one-step drift only partially correlates with rollout drift, suggesting that predicting and mitigating drift during SMT remains an open challenge.
Finally, the SMT-trained RNN on Sketchy generates images that capture stroke structure but are not always fully interpretable. This indicates that while the method significantly improves RNN performance, it does not fully solve the hardest sequence modeling problems.
Broader Implications
SMT sits at a fascinating intersection of trends in the field. It connects the parallelization power of Transformers with the memory efficiency of RNNs, potentially offering a path to sequence models that scale efficiently along both compute and memory axes. It also raises intriguing questions about the nature of intelligence and compression: the paper's finding that SMT can learn better compression with more compute suggests that memory efficiency is a learnable and scalable property, not merely an architectural constraint. In the context of lifelong learning and continuous agent experience, where the sequence length is unbounded and memory must be constant, SMT provides a principled framework for training models that can build and maintain temporal abstractions over a lifetime of experience.
Further Reading and Reflection
Prior Work and Foundations
This paper builds upon several foundational lines of research. The predictive state representation framework, introduced by Littman and Sutton (2001) and developed by Singh, James, and Rudary (2012), provides the conceptual backbone: the idea that state can be defined purely in terms of predictions about the future. Earlier work by Downey et al. (2017) incorporated PSRs into RNNs, but still used BPTT and thus was not time-parallelizable. The cross-architecture distillation literature, including Kasai et al. (2021) on finetuning Transformers into RNNs and Chen et al. (2026) on hybrid linear attention, provides methodological context, though these works do not address the fundamental credit assignment problem of training nonlinear RNNs without BPTT.
The concurrent work by Teoh et al. (2025) on Next-Latent Prediction (NextLat) is particularly relevant. With a specific hyperparameter setting, SMT and NextLat are equivalent. However, Teoh et al. focus on using the RNN as a regularizer for the Transformer, and their experiments primarily use BPTT for the latent representations (optionally truncated to
Related Approaches and Comparisons
Alternative approaches to time-parallel RNN training have emerged recently. Parallelizing linear RNNs through associative scan (Martin and Cundy, 2018; Dao and Gu, 2024) has enabled impressive results with models like Mamba. However, these models are fundamentally limited by the linearity of their transition function, which constrains the class of functions they can represent (Merrill et al., 2026). Recent attempts to parallelize nonlinear RNNs, such as DeepPCR (Danieli et al., 2023) and ParaRNN (Danieli et al., 2025), formulate the forward pass as an iterative optimization problem solved with Newton's method. While appealing, these approaches approximate BPTT and thus inherit its long credit paths and the gradient instability that comes with them.
Linear attention models and state space models (Katharopoulos et al., 2020; Gu, Goel, and Ré, 2021) also offer time-parallel training and fixed-size memory, but their linearity fundamentally limits their expressive power. The Recurrent Transformer (Oncescu et al., 2026) achieves
Future Directions and Open Problems
The paper opens several promising research directions. The most immediate is the development of fully time-parallel drift correction methods that can replace or augment DMT. If the RNN's rollout drift could be predicted and corrected without unrolling, the entire training pipeline would remain time-parallel, preserving SMT's computational advantages. Techniques from deep equilibrium models or parallel nonlinear solvers may be relevant here.
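Another direction is the exploration of SMT for lifelong learning and continuous agent control. RNNs are natural candidates for agents that must learn and remember over unbounded time horizons, but training methods have been the primary obstacle. SMT's time-parallel, BPTT-free pretraining could remove much of that obstacle, making it more practical to train recurrent agents on long streams of experience.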
The scaling laws for compression also invite further investigation. The paper shows that SMT can achieve better compression with more compute, but the functional form of this relationship and its limits remain unknown. Understanding how memory compression scales with model size, data, and compute could provide a new axis for designing efficient sequence models.
Finally, the theoretical question of full expressivity remains open. The sequence-to-set reframing shows that Transformer encoders can approximate recurrent memory functions, but may require depth to scale with sequence length for full theoretical expressivity. Understanding when and how shallow encoders succeed, and when deeper or different architectures are needed, would strengthen the theoretical foundations of the method.
Personal Reflection
What strikes me most about this paper is its conceptual clarity and boldness. The authors do not merely propose another architectural tweak or training trick; they fundamentally reframe the problem of RNN training. The idea that the past can be treated as a set rather than a sequence, and that memory can be learned through a predictive state objective, is both elegant and powerful. It reminds me of the kind of paradigm-shifting insight that characterized the original Transformer paper: instead of fighting the sequentiality of RNNs, the Transformer eliminated it by using attention. Kumar and Isola do something similar for RNN training: instead of fighting the instability of BPTT, they eliminate the need for recurrent credit propagation altogether.
The most thought-provoking aspect is the decoupling of memory representation from memory dynamics. This separation feels deeply right: it mirrors how we think about human cognition, where the formation of memories (encoding) is distinct from the maintenance and updating of memories (consolidation). It also suggests a general principle for training dynamical systems: if you can learn what the state should be, learning how to transition between states becomes much simpler. I find myself wondering how this principle could be extended beyond sequence modeling to other domains involving temporal dynamics, such as physical simulation, control systems, or even the training of recurrent connections in biological neural networks. This paper does not just solve a technical problem; it opens a new way of thinking about time, memory, and learning.
Topics:
- "memory_mechanism"
- "recurrent_neural_networks"
- "llm"
- "reinforce_learning"
- "neuro_science"
References:
- "mit"
- "supervised_memory_training"
- "backpropagation_through_time"