Pretraining Recurrent Networks without Recurrence

Research Summary

Recurrent neural networks (RNNs) have long occupied a peculiar position in the deep learning landscape. On one hand, their fixed-size memory state and constant-time sequential update offer an elegant computational model that mirrors biological cognition—our brains do not store every sensory impression we have ever experienced, but instead compress the past into a compact, evolving representation. On the other hand, training these networks has remained stubbornly difficult. The canonical algorithm, backpropagation through time (BPTT), requires unrolling the recurrent computation graph across the entire sequence, making the training process inherently sequential and exposing gradients to a long and treacherous path that stretches across up to O(T) timesteps. This path length is not merely an engineering inconvenience; it is a fundamental bottleneck that limits both parallelism and the network's capacity to learn long-range associations, as gradients may vanish or explode depending on the spectral properties of the recurrent transition Jacobian.

This paper by Akarsh Kumar and Phillip Isola confronts this problem with a reframing that is as conceptually bold as it is technically elegant. Rather than attempting to patch BPTT with architectural gating mechanisms or orthogonal weight parameterizations, they ask whether recurrent credit propagation is necessary at all. Their answer is no, or at least that it can be bypassed during the critical pretraining phase. The authors propose Supervised Memory Training (SMT), a method that sidesteps the entire machinery of BPTT by reducing RNN training to a supervised learning problem on one-step memory transitions. The core intuition is disarmingly simple: if we knew the optimal memory state $m_t$ at every timestep $t$, then training the RNN would reduce to learning the mapping $(m_t, x_{t+1}) \to m_{t+1}$, a straightforward supervised regression task that requires no temporal backpropagation. The challenge, of course, is obtaining these optimal memory labels. Kumar and Isola's key insight is that the optimal memory can be characterized as a predictive state: a sufficient statistic of the past that retains only the information necessary to predict the future. Crucially, by augmenting each observation with its timestamp, the past can be losslessly represented as a set of timestamped events rather than as an ordered sequence, making the optimal memory a permutation-invariant function computable in parallel over time.

To operationalize this insight, SMT employs a Transformer-based encoder-decoder pair as a "teacher." The encoder compresses the past context into a compact memory representation, while the decoder uses this memory to predict future outputs. The encoder is trained on a predictive state objective: it must produce a memory that is sufficient for predicting the future output sequence, given the future input sequence. Once this teacher model has learned to construct such memories, the RNN is trained to predict the next memory state from the current one, using standard supervised objectives. The RNN is never unrolled during training; it learns only one-step transitions, yet it acquires the capacity to roll out autoregressively during inference. This decoupling of memory representation (what to remember) from memory dynamics (how to update memory) is the paper's central conceptual contribution. Because the teacher is time-parallel, SMT inherits the Transformer's O(1) credit path length between any two tokens, eliminating the gradient stability problems that plague BPTT. The longest gradient path in SMT is independent of sequence length, making long-range credit assignment qualitatively easier.

The paper's empirical findings are equally compelling. On synthetic tasks designed to isolate specific properties—gradient stability, memory capacity, state tracking, associative recall, and in-context learning—SMT consistently outperforms BPTT, often dramatically so as sequences grow longer. On naturalistic tasks, including character-level language modeling on TinyStories and pixel sequence modeling on MNIST and Sketchy, SMT-trained RNNs achieve better performance with less sequential computation. A secondary contribution, DAgger Memory Training (DMT), addresses the train-test mismatch that arises when the RNN uses its own predicted memories rather than the teacher's memories at inference time. DMT fine-tunes the RNN by unrolling it and training on its own induced states, providing a lightweight correction mechanism that mitigates the accumulation of prediction errors over time.

Beyond its immediate utility as a training algorithm, SMT offers a broader conceptual reorientation. It suggests that the difficulty of training RNNs is not intrinsic to the recurrent architecture itself, but rather an artifact of the training algorithm. By treating the past as a set and the future as a prediction task, SMT aligns recurrent modeling with the powerful parallelization capabilities of modern hardware. It also opens the door to scaling laws along a new axis: memory compression. The authors demonstrate that SMT can achieve better compression of the past into smaller memory states when given more compute, a property that speaks directly to the efficiency of intelligent systems. In the landscape of sequence modeling, where Transformers have dominated due to their parallelizability, SMT offers a credible pathway for nonlinear RNNs to reclaim their place as a viable, efficient, and scalable alternative.

Theoretical Framework

Intellectual Lineage and Motivation

The theoretical foundations of this paper sit at the intersection of several long-standing threads in machine learning and dynamical systems theory. The problem of credit assignment in recurrent networks—how to attribute an error at the output to the parameters that influenced it many timesteps earlier—has been recognized since the earliest days of connectionism. Rumelhart, Hinton, and Williams (1986) introduced backpropagation through time as the canonical solution, and Paul Werbos (1990) elaborated its implementation. Yet the limitations of BPTT were apparent almost immediately. Bengio, Simard, and Frasconi (1994) famously proved that learning long-term dependencies with gradient descent is difficult because gradients tend to vanish or explode as they propagate through the recurrent nonlinear dynamics. This observation sparked decades of architectural innovations—LSTM gates, GRU gating, orthogonal and unitary weight parameterizations, residual connections—all designed to preserve gradient magnitudes across time. While these modifications helped, they did not eliminate the fundamental problem: the credit path length remains O(T), and the training process remains sequential.

Kumar and Isola's work draws a different line of intellectual descent, one that traces back to the theory of predictive state representations (PSRs) in reinforcement learning and control theory. Littman, Sutton, and Singh (2001) introduced PSRs as a way to represent the state of a partially observable dynamical system entirely in terms of predictions about future observations. A PSR captures the idea that the state of a system is not its internal hidden configuration, but rather the set of probabilities it assigns to future outcomes. This is a profound epistemological shift: the state is defined by what it enables you to predict, not by what it is. Belief states in partially observable Markov decision processes share this flavor, representing a sufficient statistic of the past for optimal decision-making (Kaelbling, Littman, and Cassandra, 1998). The authors explicitly connect their work to this tradition, stating that an effective memory is a sufficient statistic of the past for predicting the future, that is, a predictive state.

Sequence-to-Set Reframing: The Key Theoretical Insight

The paper's most elegant theoretical move is the reparameterization of sequential computation as a set-based computation. In the standard formulation, an RNN processes an ordered sequence $[x_0, x_1, \ldots, x_t]$ through a recurrent function $Q$ that accumulates information one token at a time: $m_t = Q([x_0, \ldots, x_t]) = f(\ldots f(f(m, x_0), x_1), \ldots, x_t)$, where $m$ is the memory before any token has been seen. This formulation is inherently sequential because the function $f$ is applied iteratively, and the order of application matters. The authors prove, however, that any such recurrent function $Q$ can be equivalently represented as a permutation-invariant function $g$ over the set of timestamped tokens $\{(x_0, 0), (x_1, 1), \ldots, (x_t, t)\}$. The proof is constructive: given the set, sort the elements by their timestamps to recover the original sequence, then apply $Q$. Since the timestamps are unique, the sort order is deterministic, and the function $g$ is permutation-invariant because any reordering of the set elements yields the same sorted sequence.
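The constructive argument above can be illustrated with a minimal sketch. The update rule `f` here is a hypothetical order-sensitive toy (a rolling hash), not anything from the paper; the point is only that `g`, defined as "sort by timestamp, then apply `Q`", returns the same memory for any ordering of the timestamped set.

```python
import random
from functools import reduce

def f(memory, x):
    """Toy order-sensitive update (hypothetical): a polynomial rolling hash."""
    return (31 * memory + x) % 1_000_003

def Q(sequence, m_init=0):
    """Sequential recurrence: fold f over the ordered sequence."""
    return reduce(f, sequence, m_init)

def g(timestamped_set, m_init=0):
    """Permutation-invariant version: sort by timestamp, then apply Q."""
    ordered = [x for x, _t in sorted(timestamped_set, key=lambda pair: pair[1])]
    return Q(ordered, m_init)

xs = [5, 3, 8, 1, 9, 2]
events = [(x, t) for t, x in enumerate(xs)]   # attach timestamps

shuffled = events[:]
random.Random(0).shuffle(shuffled)            # arbitrary reordering of the set

print(g(events) == g(shuffled) == Q(xs))      # same memory for any ordering
print(Q(xs) == Q(list(reversed(xs))))         # Q itself does depend on order
```

Note that this sketch only demonstrates equivalence; it still sorts and folds sequentially. The practical claim in the text is that a parallel set model with access to timestamps can learn such a `g` directly.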

This result is deceptively simple but carries profound implications. It means that the computational problem of computing a memory from the past is not inherently sequential—it can be computed in parallel over all past tokens, provided the model has access to the timestamps. Transformers, which are permutation-invariant set models with positional encodings, are natural candidates for this computation. The authors acknowledge that this reframing may require model depth to scale with sequence length for full theoretical expressivity (consistent with the circuit complexity limitations of constant-depth parallel models, as analyzed by Merrill and Sabharwal, 2023), but their empirical results demonstrate that even relatively shallow Transformers learn highly effective memory representations.

Predictive State Formulation

The authors formalize the predictive state objective as follows. Let $x = [x_0, \ldots, x_T]$ and $y = [y_0, \ldots, y_T]$ denote input and output sequences. At each timestep $t$, the input sequence is decomposed into the context $x_t^{\text{ctx}} = [x_0, \ldots, x_t]$ and the future inputs $x_t^{\text{fut}} = [x_{t+1}, \ldots, x_T]$, with corresponding future outputs $y_t^{\text{fut}} = [y_t, \ldots, y_T]$. The encoder $E_\phi$ maps each context to a memory state $m_t = E_\phi(x_t^{\text{ctx}})$, and the decoder $D_\psi$ predicts the future output distribution using the memory and teacher-forced future inputs:

$$p_{\phi,\psi}(y_t^{\text{fut}} \mid x_t^{\text{ctx}}, x_t^{\text{fut}}) = \prod_{\tau=t}^{T} p_\psi(y_\tau \mid m_t, x_{t+1:\tau}) = D_\psi(m_t, x_t^{\text{fut}})$$

This is essentially a conditional sequence modeling problem where the conditioning variable is the compressed memory $m_t$. The encoder-decoder pair is trained with a cross-entropy loss on the future outputs:

$$L_t^{\text{dec}} = \mathrm{CE}\big(y_t^{\text{fut}},\; p_{\phi,\psi}(y_t^{\text{fut}} \mid x_t^{\text{ctx}}, x_t^{\text{fut}})\big)$$

The decoder ensures that $m_t$ contains sufficient information for predicting the future, while the encoder ensures that this information is compressed into a fixed-size state. The authors also prove, in Appendix F, that if $m_t$ is an optimal minimal sufficient statistic for predicting the future, then the memory sequence $(m_t)$ is Markovian: $m_{t+1}$ is conditionally independent of the past context $x_t^{\text{ctx}}$ given $(m_t, x_{t+1})$. This means that the optimal memory can be predicted from only the previous memory and the current input, which is exactly the RNN transition structure.
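Written out, the Markov property described above amounts to the following conditional independence statement. This is a sketch in the notation of this section, not the exact statement proved in the appendix.

```latex
p\left(m_{t+1} \mid m_t,\, x_{t+1},\, x_t^{\mathrm{ctx}}\right)
  \;=\;
p\left(m_{t+1} \mid m_t,\, x_{t+1}\right)
```

Once $(m_t, x_{t+1})$ is known, the raw context carries no further information about $m_{t+1}$, which is why a transition that sees only the previous memory and the new input can, in principle, reproduce the encoder's memories.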

Dynamics and Uniformity Losses

The RNN $f_\theta$ is trained to predict the next memory state from the current one: $\hat{m}_{t+1} = f_\theta(m_t, x_{t+1})$. The supervision signal comes from the encoder's memory labels:

$$L_t^{\text{dyn}} = \mathrm{MSE}(\hat{m}_{t+1}, m_{t+1})$$

This dynamics loss serves two purposes: it trains the RNN, and it explicitly shapes the encoder memory representations to be approximately Markovian. In practice, the authors find it beneficial to jointly train the encoder, decoder, and RNN with a combined objective that includes a third term, a uniformity loss borrowed from contrastive learning:

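$$L^{\text{unif}} = \log \mathbb{E}_{t_a, t_b \sim [0, \ldots, T]} \exp\big(-2\, \lVert m_{t_a} - m_{t_b} \rVert_2^2\big)$$

The uniformity loss prevents the memory space from collapsing. Without it, the encoder might learn to map all contexts to the same memory vector, which would be trivially Markovian but useless for prediction. The full objective is a weighted sum:

$$L^{\text{smt}} = \lambda_{\text{dec}}\, \mathbb{E}_t[L_t^{\text{dec}}] + \lambda_{\text{dyn}}\, \mathbb{E}_t[L_t^{\text{dyn}}] + \lambda_{\text{unif}}\, L^{\text{unif}}$$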
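To make the role of the uniformity term concrete, here is a minimal sketch that evaluates it on two hand-made sets of 2D memories. It assumes the standard form of the uniformity loss from contrastive learning (log of the mean Gaussian kernel with scale 2) and averages over all ordered pairs, including a memory paired with itself; both choices are assumptions about details the text does not spell out.

```python
import math

def uniformity_loss(memories):
    """log E_{a,b} exp(-2 * ||m_a - m_b||^2) over all ordered pairs (a, b)."""
    n = len(memories)
    total = 0.0
    for a in memories:
        for b in memories:
            sq_dist = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
            total += math.exp(-2.0 * sq_dist)
    return math.log(total / (n * n))

# Every context mapped to the same vector: the collapsed case.
collapsed = [[0.5, 0.5]] * 4
# Memories spread over the corners of a unit square.
spread = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

print(uniformity_loss(collapsed))   # maximal value: all pairwise distances are zero
print(uniformity_loss(spread))      # lower (negative) value: memories are spread out
```

Because the loss is maximal when all memories coincide, minimizing it pushes distinct contexts apart, which is the anti-collapse effect described above.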

The authors note that theoretically, it should be sufficient to train the encoder-decoder with only Ldec and the RNN with only Ldyn, but joint training provides practical benefits by explicitly optimizing the Markovian property and enabling additional temporal credit propagation through the RNN's parameters during the encoder's training phase.

Theoretical Properties and Limitations

The theoretical framework of SMT inherits both the strengths and limitations of its teacher Transformer. Transformers are time-parallel and have O(1) credit path length, but they are also limited in expressive power compared to unbounded sequential computation. Merrill and Sabharwal (2023) showed that constant-depth, log-precision Transformers are limited to problems in the complexity class TC0, which, under standard complexity-theoretic assumptions, excludes problems requiring deep sequential computation such as tracking the full state of a chess game. This implies that SMT-trained RNNs may inherit the same expressive limitations from their teacher. However, the authors argue that SMT is intended as a pretraining algorithm; a lightweight post-training phase (such as BPTT finetuning) can push the RNN beyond the teacher's limitations. This is a crucial and honest acknowledgment of the method's boundaries: SMT provides a strong initialization and a stable learning signal, but it does not eliminate the need for all recurrent training.

Technical Architecture

System Overview

The complete SMT system comprises three distinct neural networks that interact in a teacher-student relationship: a bidirectional Transformer encoder, a causal Transformer decoder, and the recurrent neural network (the student) that is the ultimate target of training. The encoder and decoder together form the "teacher" that generates the training labels for the RNN. The RNN itself consists of an updater network (the recurrent transition function) and a readout network (the output head). This triad of models is trained in a single joint stage, though conceptually the objectives can be decomposed into representation learning (encoder), future prediction (decoder), and dynamics imitation (RNN).

The encoder architecture is a bidirectional Transformer that processes the input context tokens along with a set of learned memory register tokens. The input tokens and register tokens are concatenated and fed through a stack of bidirectional Transformer blocks (with full attention masks). The register tokens at the output are interpreted as the memory state $m_t = [m_t^1, \ldots, m_t^M]$. The use of register tokens, rather than taking the output of a single token position, allows the memory to be structured as a set of vectors, analogous to the multiple slots in external memory architectures or the multiple heads in attention mechanisms. The encoder uses rotary position embeddings (RoPE) and RMSNorm for stability.

The decoder is a causal Transformer that takes the memory tokens from the encoder and the embedded future input tokens, concatenates them, and processes them through a stack of causally masked Transformer blocks. The causal mask ensures that the decoder can only attend to previous tokens when making predictions, enforcing the autoregressive generative structure. The output predictions are read out at positions such that $\hat{y}_{t+k}$ depends only on the memory $m_t$ and the future inputs $x_{t+1}, \ldots, x_{t+k}$. This architecture forces the encoder to compress all relevant information from the past into the memory, because the decoder has no other access to the context once the memory tokens are passed in.
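The attention pattern described above can be sketched as a boolean mask over the concatenation of memory tokens and future input tokens. The sizes are hypothetical, and a plain causal mask over the whole concatenation is an assumption: the text does not say whether the memory tokens may attend to each other bidirectionally.

```python
def decoder_mask(num_memory_tokens, num_future_inputs):
    """Boolean causal mask over [memory tokens | future input tokens].

    mask[i][j] is True when position i may attend to position j.
    """
    n = num_memory_tokens + num_future_inputs
    return [[j <= i for j in range(n)] for i in range(n)]

M, K = 2, 3                      # hypothetical: 2 memory tokens, 3 future inputs
mask = decoder_mask(M, K)

# Position M + k - 1 holds x_{t+k}. Its row shows what a prediction read out
# there can depend on: the memory tokens and x_{t+1}, ..., x_{t+k}, nothing later.
k = 2
visible = [j for j, allowed in enumerate(mask[M + k - 1]) if allowed]
print(visible)

for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
```

Since no position can see past-context tokens directly, any information about the context has to arrive through the memory positions at the front of the sequence.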

The RNN architecture used in most experiments is a Transformer-based recurrent network. At each timestep, the current memory tokens mt and the new input token xt+1 are concatenated and processed through a stack of bidirectional Transformer blocks to produce the next memory tokens mt+1. This is a departure from traditional RNNs that use gated recurrent units or simple matrix multiplications; instead, the transition function is itself a Transformer, operating on the memory as a set of tokens. The authors also experiment with MLP-based and GRU-based RNNs, where the memory is flattened into a single vector before being processed. The readout function is a separate bidirectional Transformer that processes the memory tokens to produce the output distribution, regardless of the RNN transition architecture.

Data Flow and Training Dynamics

The data flow during training is architecturally distinct from BPTT. In BPTT, the input sequence is fed token by token into the RNN, which generates a memory trajectory m0,m1,...,mT sequentially. The output predictions are computed at each timestep, and gradients flow backward from the final loss through the entire unrolled graph. In SMT, the encoder processes the entire past context in parallel, producing a memory mt at a sampled timestep t. The decoder then predicts the future outputs using this memory and the future inputs. The RNN, meanwhile, predicts the next memory from the current one, supervised by the encoder's memory label. There is no unrolling; the RNN's prediction error is computed and backpropagated only through the one-step transition.
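The one-step training regime described above can be sketched with a deliberately tiny example. Everything here is a hypothetical stand-in: the "encoder memories" are produced by a fixed exponential moving average rather than a Transformer, and the student is a two-parameter linear map. What the sketch shows is that training touches only shuffled `(m_t, x_{t+1}, m_{t+1})` triples, with no unrolling, while inference still rolls out autoregressively.

```python
import random

rng = random.Random(0)

# Stand-in teacher labels (assumption): in SMT these come from the encoder.
xs = [rng.uniform(-1.0, 1.0) for _ in range(200)]
teacher = [0.0]
for x in xs:
    teacher.append(0.9 * teacher[-1] + 0.1 * x)

# One-step triples (m_t, x_{t+1}, m_{t+1}); their order no longer matters.
triples = [(teacher[t], xs[t], teacher[t + 1]) for t in range(len(xs))]

a, b, lr = 0.0, 0.0, 0.5          # student f(m, x) = a*m + b*x
for _ in range(100):
    rng.shuffle(triples)          # samples are independent: no recurrence here
    for m, x, target in triples:
        err = (a * m + b * x) - target   # d(0.5*err**2)/d(prediction)
        a -= lr * err * m
        b -= lr * err * x

# Inference: roll out autoregressively on the student's own memories.
m_hat = 0.0
for x in xs:
    m_hat = a * m_hat + b * x

print(a, b)                       # learned transition parameters
print(abs(m_hat - teacher[-1]))   # gap to the teacher memory after the rollout
```

In this toy the teacher's dynamics are exactly representable by the student, so the rollout matches closely. With a real encoder and an imperfect student, the rollout would drift, which is the mismatch discussed in the next subsection.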

This architectural difference has profound implications for computational efficiency. In BPTT, the memory footprint scales as O(MT) where M is the memory size, because the entire unrolled trajectory must be stored for gradient computation. In SMT, the memory footprint is O(M+Tc) where Tc is the encoder context length, because only a single memory state and the encoder's context window need to be stored. The sequential computation (measured in sequential FLOPs, the amount of work that must be done serially even on an infinitely parallel machine) is O(T) per optimization step for BPTT, but only O(1) for SMT, because the encoder and decoder can operate in parallel over time. This makes SMT fundamentally better suited to modern parallel hardware, where the limiting factor is often the amount of sequential computation rather than total computation.

DAgger Memory Training: Correcting the Train-Test Mismatch

A critical engineering challenge arises from the train-test mismatch. During SMT, the RNN learns to predict $m_{t+1}$ from $(m_t, x_{t+1})$, where $m_t$ comes from the encoder. At inference time, however, the RNN uses its own predicted memory $\hat{m}_t$ as input. Small errors in the one-step prediction accumulate over time, causing the RNN's memory trajectory to drift away from the encoder trajectory. This drift is quantified as $\delta_t = \mathrm{MSE}(\hat{m}_t, m_t)$, and it grows as the sequence progresses.

The authors address this with DAgger Memory Training (DMT), inspired by the DAgger algorithm for imitation learning. DMT is a fine-tuning phase where the RNN is unrolled using its own predicted memories, and then trained to imitate the encoder's memory at each step. The training pairs are now $(\hat{m}_t, x_{t+1}) \to m_{t+1}$, rather than $(m_t, x_{t+1}) \to m_{t+1}$. During DMT, the encoder and decoder are frozen, and only the RNN is trained, with a small learning rate. This on-policy imitation learning exposes the RNN to its own induced state distribution and teaches it to auto-correct its errors. Although DMT is not time-parallel (it requires unrolling the RNN), it is a lightweight post-processing phase that follows the heavy SMT pretraining. The authors also note that DMT could potentially be parallelized using techniques like DEER, though this is not explored in the current work.
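The DMT loop can be sketched on a toy problem. All specifics are assumptions made for illustration: the frozen "encoder memories" come from a fixed exponential moving average, the student is a linear map whose starting parameters are deliberately slightly wrong (standing in for an imperfect SMT result), and each update treats the student's own input memory as a constant, so no gradient flows back through time.

```python
import random

rng = random.Random(0)
xs = [rng.uniform(-1.0, 1.0) for _ in range(200)]

# Stand-in for frozen encoder memories (assumption).
teacher = [0.0]
for x in xs:
    teacher.append(0.9 * teacher[-1] + 0.1 * x)

def mean_rollout_drift(a, b):
    """Average delta_t = (m_hat_t - m_t)^2 along an autoregressive rollout."""
    m_hat, total = 0.0, 0.0
    for t, x in enumerate(xs):
        m_hat = a * m_hat + b * x
        total += (m_hat - teacher[t + 1]) ** 2
    return total / len(xs)

# Hypothetical imperfect student after SMT: f(m, x) = a*m + b*x.
a, b, lr = 0.95, 0.1, 0.05
drift_before = mean_rollout_drift(a, b)

for _ in range(50):                      # DMT epochs
    m_hat = 0.0
    for t, x in enumerate(xs):
        pred = a * m_hat + b * x         # input is the student's own memory
        err = pred - teacher[t + 1]      # label is still the teacher's memory
        a -= lr * err * m_hat            # one-step gradient; m_hat is treated
        b -= lr * err * x                # as a constant (no BPTT through it)
        m_hat = pred                     # continue the on-policy rollout

drift_after = mean_rollout_drift(a, b)
print(drift_before, drift_after)
```

The only change from the one-step SMT objective is where the input memory comes from: the student's own rollout rather than the teacher. That is what exposes the student to the states it will actually visit at inference time.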

Key Architectural Innovations

The most important architectural innovation is the use of register tokens for memory representation. Rather than compressing the past into a single vector, the encoder produces a set of memory vectors, each potentially capturing different aspects of the context. This set-based memory structure is preserved through the RNN transition, with the Transformer-based updater attending across all memory tokens and the input token. This design allows the memory to have a more complex, structured geometry than a simple vector state, which is crucial for tasks requiring state tracking or associative recall. The authors visualize these memory spaces in 2D and 3D, showing that different tasks induce qualitatively different memory geometries: retrieval tasks collapse the sequence into a few discrete states (finite-state-machine-like behavior), while string copying tasks require a tree-like structure to preserve all possible sequences without aliasing.

Experimental Evaluation

Experimental Strategy and Datasets

The experimental design is structured to systematically validate the theoretical claims of SMT. The authors first evaluate on five synthetic tasks designed to isolate specific properties of the training algorithm: gradient stability (Retrieval), memory capacity (String Copy), state tracking (Stack Operations), associative recall (Keys-Values), and in-context learning (Modular Arithmetic). Each task has a controllable difficulty parameter (sequence length, memory size, state complexity, number of associations, or number of in-context examples) that allows the authors to stress-test the algorithms along specific axes. The RNN architecture for these experiments uses the Transformer backbone, and both the context length Tc and future length Tf are set equal to the full sequence length T.

For naturalistic tasks, the authors evaluate on character-level language modeling using TinyStories, a synthetic dataset of short stories generated by GPT-3.5 and GPT-4. This task requires long-range memory for narrative coherence, as character names and plot threads must be tracked across thousands of characters. As a harder test, they evaluate on pixel sequence modeling of sparse images from MNIST and Sketchy, following raster-scan order. The authors term this "Attneave's task," referencing Fred Attneave's classic work on visual perception. The task is deliberately difficult for RNNs: when processing a pixel sequence, the model must remember the locations and shapes of strokes seen hundreds of timesteps earlier, buried among black pixels, in order to predict the next pixel. RNNs cannot attend directly to earlier pixels and must instead compress all relevant spatial information into a fixed memory state.

Synthetic Task Results

The synthetic task results reveal a stark and consistent pattern: SMT→DMT outperforms BPTT across all tasks and all settings. BPTT struggles as sequences grow longer, even on the simple retrieval task, which requires only remembering and reproducing a token seen earlier. This failure is attributed to BPTT's O(T) credit path length and the resulting gradient instability. In contrast, SMT appears largely agnostic to sequence length, solving all tasks except associative recall with ease. The failure on associative recall is attributed to the complexity of the task itself rather than the training method; even the SMT encoder (the Transformer teacher) shows difficulty with the most complex associative recall settings.

On the stack operations task, which tests state tracking, BPTT fails to track deeper stacks as the sequence length increases, while SMT successfully tracks the stack state. This is a significant result because state tracking is a known weakness of linear RNN models and constant-depth Transformers, which are limited by their circuit complexity. The nonlinear RNN trained with SMT, by contrast, inherits the expressivity of nonlinear dynamics while benefiting from stable training. On the modular arithmetic task, which tests in-context learning, SMT again outperforms BPTT, demonstrating that the method is capable of learning to infer latent rules from examples and apply them to novel inputs.

Natural Task Results and Efficiency Analysis

The natural task results are equally compelling. On TinyStories, the SMT-trained encoder and the SMT→DMT RNN achieve data efficiency comparable to BPTT with Transformer and MLP backbones, but with significantly lower sequential computation. On MNIST, SMT and SMT→DMT show substantially better data efficiency than BPTT. The authors attribute this difference to the short-range vs. long-range memory requirements of natural language versus pixel sequence modeling. Natural language has relatively local dependencies (words depend on nearby words), while pixel sequences require integrating information over hundreds of timesteps to recognize shapes and strokes. This is precisely the regime where SMT's O(1) credit path length provides the greatest advantage.

The pixel sequence modeling results are visually striking. Samples generated by BPTT-trained RNNs, even with GRU architectures, fail to capture the long-range structure of handwritten digits. The generated images show either large streaks of white or black pixels based only on local context, or blurry, incoherent structures. In contrast, SMT→DMT-trained RNNs generate recognizable digits with correct stroke structure, demonstrating that they have successfully learned to maintain long-range memory of the digit shape across the entire 784-pixel sequence. On the more challenging Sketchy dataset (32x32 binarized sketches, yielding 1024-pixel sequences), SMT→DMT captures the overall stroke structure of human-drawn sketches, though the generated images are not always fully interpretable.

| Task | Metric | BPTT RNN | SMT→DMT RNN | SMT Encoder* |
| --- | --- | --- | --- | --- |
| Retrieval (T=512, noise=0.3) | Test Loss | High | Low | Low |
| String Copy (T=4096, M=64) | Test Loss | Fails | Succeeds | Succeeds |
| Stack Ops (T=64, depth=5) | Test Loss | Fails | Succeeds | Succeeds |
| TinyStories (seq len 256) | Test Loss | Comparable | Comparable | Lower |
| MNIST (seq len 784) | Test Loss | High | Low | Low |
| Sketchy (seq len 1024) | Test Loss | Very High | Moderate | Moderate |

*SMT Encoder is the Transformer teacher, not an RNN; shown as reference upper bound.

Scaling Behavior

The authors evaluate scaling along three axes: context length, memory state size, and model parameter count. On TinyStories, SMT→DMT exhibits smooth, predictable performance improvements as both the context length Tc and the memory size increase. This confirms that the method effectively leverages longer contexts and larger memory capacities, which is not a given for RNN training methods. For model scaling, sweeping the width and depth of the RNN, encoder, and decoder shows that the SMT encoder follows a standard power-law-like scaling trend, while the SMT→DMT RNN also improves smoothly with scale, though with a differently shaped curve. Interestingly, the RNN appears to more closely match the encoder's performance at larger scales, suggesting that the dynamics imitation task becomes easier when both the teacher and the student have more capacity.

A particularly novel scaling analysis is performed along the compression axis. The authors train SMT models across a sweep of memory state sizes and training compute budgets, plotting iso-loss contours. The results confirm that SMT can achieve more compression (better performance with smaller memory) when allocated more compute. This suggests a new scaling law for sequence models: memory state compression. The authors note that Transformers perform no compression of the past (their memory grows with sequence length), which may explain their training efficiency but also their inference costs. RNNs, by contrast, compress the past into a fixed-size state, and SMT provides a principled way to learn this compression.

Case Studies

Memory Space Visualization: What the Encoder Learns

To understand what SMT is actually learning, the authors train smaller models with 2D memory states and visualize the memory space across three synthetic tasks. In the retrieval task with two possible needles, the encoder learns to collapse the infinite set of possible pasts into only three kinds of effective memory state: an initial state, a state indicating that the next token is the needle, and states corresponding to the needle values. The RNN then learns finite-state-machine behavior to transition between these states. This is a remarkable finding: the encoder has discovered a minimal state representation that captures the task structure, and the RNN has learned to implement the transitions of a finite state machine.
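The finite-state-machine behavior described above can be written down explicitly as a sketch. The token vocabulary is hypothetical (a `?` marker announcing the needle, and needle values `A` and `B`); the actual task encoding is not given in this summary. The `step` function plays the role of the RNN transition, with a discrete state standing in for the learned 2D memory.

```python
NEEDLES = {"A", "B"}      # hypothetical needle values
MARKER = "?"              # hypothetical token announcing the needle

def step(state, token):
    """One memory transition of the retrieval finite-state machine."""
    if state == "init" and token == MARKER:
        return "armed"                    # the next token is the needle
    if state == "armed" and token in NEEDLES:
        return f"found:{token}"           # remember which needle was seen
    return state                          # every other token is ignored

def run(tokens):
    """Fold the transition over a token sequence, starting from 'init'."""
    state = "init"
    for token in tokens:
        state = step(state, token)
    return state

print(run(list("xyx?Ayxxy")))   # long distractor tail does not change the state
print(run(list("x?Byy")))
print(run(list("xyz")))
```

However long the sequence, the memory only ever occupies one of a few states, which is the kind of collapse the visualization reveals.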

In contrast, the string copying task requires lossless sequence compression because every sequence must be reproduced exactly in reverse. The encoder cannot alias distinct memory states together; it must preserve enough information to uniquely identify the input sequence. The visualization reveals a tree-like memory geometry, where each branch corresponds to a different input token, matching the tree structure of all possible strings. This demonstrates that the encoder adapts its memory geometry to the requirements of the task: when the future requires only categorical information (which needle was seen), it collapses the state space; when the future requires lossless reconstruction (copy the string), it expands into a tree structure.

Gradient Properties: The Fundamental Difference

The authors analyze the gradient properties of BPTT and SMT on the needle retrieval task, where the loss is applied only at the last timestep. In BPTT, the magnitude of the gradient of the loss with respect to $m_t$, $\partial L / \partial m_t$, depends critically on the time position $t$ and the weight initialization. Depending on the singular values of the Jacobian, gradients either vanish (becoming exponentially small for early timesteps) or explode (becoming exponentially large), leading to a recency bias or training instability. In SMT, the gradient magnitude is independent of $t$, because the credit path length between the needle token and the answer token is always O(1), routed through the encoder rather than through the recurrent chain. This is not an empirical trick or a regularization effect; it is a structural property of the algorithm that eliminates the gradient stability problem at its root.
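The exponential dependence on time position can be seen in the simplest possible case. The sketch below assumes a scalar linear RNN, $m_t = w\, m_{t-1} + x_t$, with the loss equal to the final memory; this is a textbook stand-in, not the model analyzed in the paper. Each step back in time multiplies the gradient by the transition Jacobian $w$.

```python
def bptt_grad_norms(w, T):
    """|dL/dm_t| for a scalar linear RNN m_t = w*m_{t-1} + x_t with L = m_T.

    The chain rule gives dL/dm_t = w ** (T - t): one Jacobian factor per
    timestep between t and the loss.
    """
    return [abs(w) ** (T - t) for t in range(T + 1)]

T = 100
vanishing = bptt_grad_norms(0.5, T)   # |w| < 1: early timesteps get ~no signal
exploding = bptt_grad_norms(1.5, T)   # |w| > 1: early timesteps blow up

for name, grads in [("w=0.5", vanishing), ("w=1.5", exploding)]:
    print(name, "t=0:", grads[0], "t=T/2:", grads[T // 2], "t=T:", grads[T])
```

In SMT there is no corresponding product over timesteps, since credit between a past token and a future prediction is routed through the encoder's attention rather than through repeated applications of the transition.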

DMT and Drift Dynamics

The authors analyze the drift phenomenon and DMT's correction in detail. Before DMT, the RNN's rollout error, measured as $1 - R^2$ between the predicted memory and the encoder memory, grows steadily over time. After DMT, the drift is significantly reduced. Interestingly, DMT seems to discover RNNs that have a higher initial drift but plateau at a much lower equilibrium drift. The one-step drift (the error in a single step) only partially correlates with the rollout drift, suggesting that the dynamics of error accumulation are complex and not fully predictable from single-step metrics. This invites future investigation into predicting and mitigating rollout drift during the SMT phase itself, perhaps by training the RNN to be more stable or by designing better initialization strategies.

Sequence Length Generalization

The authors compare an SMT→DMT RNN against its Transformer teacher on the synthetic stack state tracking task, evaluating on sequence lengths longer than those seen during training. The Transformer outperforms the RNN on training sequence lengths, but significantly underperforms on longer sequences. This reflects the distinct inductive biases of the architectures: Transformers behave like growing lookup tables, where the context window determines the maximum sequence length they can handle. RNNs, by contrast, update finite states and can generalize to arbitrary sequence lengths if the state transition dynamics are correct. This is a powerful argument for RNNs in applications requiring unbounded horizons, such as lifelong learning or continuous agent control, where the sequence length is not known in advance.

Synthesis: Value and Limitations

Theoretical Significance

SMT changes how we think about the relationship between sequence modeling and parallel computation. It demonstrates that the sequentiality of RNN training is not an intrinsic property of recurrent architectures, but rather a consequence of the training algorithm. By reframing the past as a set of timestamped events, the paper opens a theoretical pathway for time-parallel training of any model that computes a sufficient statistic of the past. The concept of predictive states—memories that are defined by what they predict rather than by how they are computed—provides a new conceptual tool for thinking about memory, compression, and representation learning in dynamical systems.

The paper also makes a valuable contribution to the understanding of gradient stability in deep learning. The O(1) credit path length of SMT is not achieved through architectural constraints (like unitary weights) or approximations (like truncated BPTT), but through a fundamental reorganization of the computation graph. This is a qualitatively different approach to the vanishing gradients problem, one that treats the symptom (unstable gradients) by removing the cause (long recurrent paths) rather than by modifying the patient's constitution (the architecture).
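The link between credit path length and gradient stability can be illustrated numerically. The sketch below is a simplification that assumes a purely linear recurrence h_t = W h_{t-1}; in a nonlinear RNN the Jacobian also includes the activation derivative at each step. The helper `recurrent_jacobian_norm` is a hypothetical name introduced for this example.

```python
import numpy as np

def recurrent_jacobian_norm(W, T):
    """Spectral norm of d h_T / d h_0 for the linear recurrence h_t = W h_{t-1}."""
    J = np.eye(W.shape[0])
    for _ in range(T):
        J = W @ J
    return np.linalg.norm(J, 2)

rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.standard_normal((8, 8)))   # random orthogonal matrix

for rho in (0.9, 1.1):
    W = rho * Q                                     # every singular value equals rho
    for T in (1, 10, 100):
        print(f"rho={rho}, T={T}: norm={recurrent_jacobian_norm(W, T):.3e}")
```

Because every singular value of `W` equals `rho`, the norm is exactly `rho ** T`: it shrinks geometrically when `rho < 1` and grows geometrically when `rho > 1`. A loss that touches only a single transition never multiplies more than one such factor, which is the sense in which an O(1) credit path avoids this compounding.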

Practical Impact and Strengths

The practical strengths of SMT are clear. It enables time-parallel training of nonlinear RNNs, which has been a long-standing goal in the field. It provides stable gradients for learning long-range dependencies, outperforming BPTT on tasks that require credit assignment across hundreds or thousands of timesteps. It reduces the sequential computation required for training, which is the primary bottleneck on modern parallel hardware. It also provides a principled method for pretraining RNNs, producing strong initializations that can be fine-tuned with lightweight post-training.

The empirical results on pixel sequence modeling are particularly impressive. Generating coherent MNIST digits and Sketchy sketches from a raster-scan pixel sequence is a task that fundamentally requires long-range memory, as the model must integrate shape information across the entire image. That SMT+DMT achieves this with a fixed-size memory state, while BPTT fails even with GRU architectures, is a strong validation of the method's practical utility.

Limitations and Honest Weaknesses

The authors are commendably transparent about the limitations of their method. First, the teacher Transformer is time-parallel and thus limited in expressive power compared to unbounded sequential computation. This means SMT-trained RNNs may inherit the same limitations and may require BPTT finetuning to achieve full expressivity beyond the teacher. This is not a fatal flaw, since the authors position SMT as a pretraining method, but it does mean that SMT alone does not eliminate all need for BPTT.

Second, SMT is useful for learning how to encode sequences, but not necessarily for learning complex reasoning where intermediate steps are not supervised. The same limitation applies to Transformers, and post-training has proven effective for extending their capabilities; the authors suggest the same may be true for SMT-trained RNNs. Third, the current SMT variant computes and trains on only a single memory m_t within a sequence per optimization step, rather than all memories. While the authors found no improvement from training on all memories in their settings, this may not hold at larger scales.

Fourth, DMT, while effective, is not time-parallel and requires unrolling the RNN. It is a necessary but imperfect solution to the drift problem. The authors note that DMT could potentially be parallelized using DEER or similar techniques, but this remains future work. The one-step drift only partially correlates with rollout drift, suggesting that predicting and mitigating drift during SMT remains an open challenge.

Finally, the SMT-trained RNN on Sketchy generates images that capture stroke structure but are not always fully interpretable. This indicates that while the method significantly improves RNN performance, it does not fully solve the hardest sequence modeling problems.
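To make the third limitation concrete, the sketch below shows what sampling a single one-step transition per sequence could look like. It is an illustration under strong assumptions: the encoder here is a fixed random sum-pooling over timestamped events, whereas the paper's encoder is a learned Transformer trained with a predictive objective. The names `set_encoder`, `sample_one_step_pair`, and `W_EVENT` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
D_X, D_M = 3, 5
W_EVENT = rng.standard_normal((D_X + 1, D_M))  # hypothetical fixed encoder weights

def set_encoder(xs, ts):
    """Toy permutation-invariant memory: sum of per-event features of (x, t)."""
    events = np.concatenate([xs, ts[:, None]], axis=1)   # (n, D_X + 1)
    return np.tanh(events @ W_EVENT).sum(axis=0)         # (D_M,)

def sample_one_step_pair(xs, rng):
    """Pick one prefix length t and return (m_t, x_next, m_next) for one sequence."""
    T = len(xs)
    ts = np.arange(T, dtype=float)
    t = int(rng.integers(1, T))                  # prefix length, 1 <= t <= T - 1
    m_t = set_encoder(xs[:t], ts[:t])            # memory of the first t events
    m_next = set_encoder(xs[:t + 1], ts[:t + 1]) # memory after one more event
    return m_t, xs[t], m_next

xs = rng.standard_normal((20, D_X))
m_t, x_next, m_next = sample_one_step_pair(xs, rng)
# An RNN cell f would then be regressed so that f(m_t, x_next) is close to m_next.
```

Because each memory depends only on the set of timestamped events in its prefix, memories for different prefix lengths can be computed independently of one another. Training on all memories would amount to looping over every prefix length rather than sampling one.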

Broader Implications

SMT sits at a fascinating intersection of trends in the field. It connects the parallelization power of Transformers with the memory efficiency of RNNs, potentially offering a path to sequence models that scale efficiently along both compute and memory axes. It also raises intriguing questions about the nature of intelligence and compression: the paper's finding that SMT can learn better compression with more compute suggests that memory efficiency is a learnable and scalable property, not merely an architectural constraint. In the context of lifelong learning and continuous agent experience, where the sequence length is unbounded and memory must be constant, SMT provides a principled framework for training models that can build and maintain temporal abstractions over a lifetime of experience.

Further Reading and Reflection

Prior Work and Foundations

This paper builds upon several foundational lines of research. The predictive state representation framework, introduced by Littman and Sutton (2001) and developed by Singh, James, and Rudary (2012), provides the conceptual backbone: the idea that state can be defined purely in terms of predictions about the future. Earlier work by Downey et al. (2017) incorporated PSRs into RNNs, but still used BPTT and thus was not time-parallelizable. The cross-architecture distillation literature, including Kasai et al. (2021) on finetuning Transformers into RNNs and Chen et al. (2026) on hybrid linear attention, provides methodological context, though these works do not address the fundamental credit assignment problem of training nonlinear RNNs without BPTT.

The concurrent work by Teoh et al. (2025) on Next-Latent Prediction (NextLat) is particularly relevant. With a specific hyperparameter setting, SMT and NextLat are equivalent. However, Teoh et al. focus on using the RNN as a regularizer for the Transformer, and their experiments primarily use BPTT for the latent representations (optionally truncated to T_f = 1), whereas Kumar and Isola focus on the T_f = 1 case and on eliminating BPTT entirely for RNN training. This distinction in emphasis leads to complementary rather than competing contributions.

Alternative approaches to time-parallel RNN training have emerged recently. The parallelization of linear RNNs through associative scan (Martin and Cundy, 2018; Dao and Gu, 2024) has enabled impressive results with models like Mamba, but these are fundamentally limited by the linearity of their transition function, which constrains the class of functions they can represent (Merrill et al., 2026). Recent attempts to parallelize nonlinear RNNs, such as DeepPCR (Danieli et al., 2023) and ParaRNN (Danieli et al., 2025), formulate the forward pass as an iterative optimization problem solved with Newton's method. While appealing, these approaches approximate BPTT and thus inherit its O(T) credit path length and gradient instability, along with the convergence concerns of Newton's method. SMT, by contrast, avoids recurrent credit propagation entirely.
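For readers unfamiliar with the associative-scan idea mentioned above, the following is a minimal NumPy sketch for an elementwise linear recurrence h_t = a_t * h_{t-1} + b_t with h_0 = 0. It is not code from any of the cited works. The function names are hypothetical, and NumPy only vectorizes each of the log-many combine steps rather than running them on parallel hardware.

```python
import numpy as np

def linear_rnn_sequential(a, b):
    """Reference: h_t = a_t * h_{t-1} + b_t, computed step by step from h_0 = 0."""
    h = np.empty_like(b, dtype=float)
    prev = np.zeros_like(b[0], dtype=float)
    for t in range(len(b)):
        prev = a[t] * prev + b[t]
        h[t] = prev
    return h

def linear_rnn_scan(a, b):
    """Same recurrence via an inclusive Hillis-Steele scan over affine maps.

    Each element (A, B) stands for the map h -> A * h + B. Applying an earlier
    map (A_l, B_l) and then a later map (A_r, B_r) gives
    (A_r * A_l, A_r * B_l + B_r), and this composition is associative.
    """
    A = np.array(a, dtype=float)
    B = np.array(b, dtype=float)
    shift = 1
    while shift < len(B):
        new_A, new_B = A.copy(), B.copy()
        # Combine the window ending at t - shift (left) with the one ending at t.
        new_B[shift:] = A[shift:] * B[:-shift] + B[shift:]
        new_A[shift:] = A[shift:] * A[:-shift]
        A, B = new_A, new_B
        shift *= 2
    return B   # with h_0 = 0, the offset term is the hidden state

rng = np.random.default_rng(0)
a = rng.uniform(0.5, 1.0, size=37)
b = rng.standard_normal(37)
assert np.allclose(linear_rnn_scan(a, b), linear_rnn_sequential(a, b))
```

The trick depends on the transition being affine in the hidden state, so that composing steps stays within the same family of maps. A nonlinear transition such as tanh has no such closed-form composition, which is why nonlinear RNNs need a different route to time-parallel training.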

Linear attention models and state space models (Katharopoulos et al., 2020; Gu, Goel, and Ré, 2021) also offer time-parallel training and fixed-size memory, but their linearity fundamentally limits their expressive power. The Recurrent Transformer (Oncescu et al., 2026) achieves O(1) gradient paths by attending to all past hidden states, but this causes its memory to grow unboundedly during inference, making it more akin to a Transformer than a fixed-memory RNN. SMT supports arbitrary fixed-memory RNN architectures and enables time-parallel training without ever unrolling the RNN.

Future Directions and Open Problems

The paper opens several promising research directions. The most immediate is the development of fully time-parallel drift correction methods that can replace or augment DMT. If the RNN's rollout drift could be predicted and corrected without unrolling, the entire training pipeline would remain time-parallel, preserving SMT's computational advantages. Techniques from deep equilibrium models or parallel nonlinear solvers may be relevant here.

Another direction is the exploration of SMT for lifelong learning and continuous agent control. RNNs are natural candidates for agents that must learn and remember over unbounded time horizons, but training methods have been the primary obstacle. SMT's O(1) credit path length could enable RNNs to learn memories that are useful only after thousands or millions of timesteps, an ability crucial for true lifelong learning.

The scaling laws for compression also invite further investigation. The paper shows that SMT can achieve better compression with more compute, but the functional form of this relationship and its limits remain unknown. Understanding how memory compression scales with model size, data, and compute could provide a new axis for designing efficient sequence models.

Finally, the theoretical question of full expressivity remains open. The sequence-to-set reframing shows that Transformer encoders can approximate recurrent memory functions, but may require depth to scale with sequence length for full theoretical expressivity. Understanding when and how shallow encoders succeed, and when deeper or different architectures are needed, would strengthen the theoretical foundations of the method.

Personal Reflection

What strikes me most about this paper is its conceptual clarity and boldness. The authors do not merely propose another architectural tweak or training trick; they fundamentally reframe the problem of RNN training. The idea that the past can be treated as a set rather than a sequence, and that memory can be learned through a predictive state objective, is both elegant and powerful. It reminds me of the kind of paradigm-shifting insight that characterized the original Transformer paper: instead of fighting the sequentiality of RNNs, the Transformer eliminated it by using attention. Kumar and Isola do something similar for RNN training: instead of fighting the instability of BPTT, they eliminate the need for recurrent credit propagation altogether.

The most thought-provoking aspect is the decoupling of memory representation from memory dynamics. This separation feels deeply right: it mirrors how we think about human cognition, where the formation of memories (encoding) is distinct from the maintenance and updating of memories (consolidation). It also suggests a general principle for training dynamical systems: if you can learn what the state should be, learning how to transition between states becomes much simpler. I find myself wondering how this principle could be extended beyond sequence modeling to other domains involving temporal dynamics, such as physical simulation, control systems, or even the training of recurrent connections in biological neural networks. This paper does not just solve a technical problem; it opens a new way of thinking about time, memory, and learning.
