Unlocking the Working Memory of Large Language Models for Latent Reasoning

Authors: Lukas Aichberger, Sepp Hochreiter
Institutions: ELLIS Unit Linz and LIT AI Lab, Institute for Machine Learning, Johannes Kepler University Linz; NXAI GmbH, Linz, Austria
Venue: arXiv:2605.30343v1 [cs.CL]
Year: 2026
Pages: 21 (main paper + appendix)
arXiv ID: 2605.30343v1
Paper URL: https://arxiv.org/abs/2605.30343


研究摘要 (Research Summary)

当我们要求一个大型语言模型(Large Language Model, LLM)解决一道复杂的数学题时,它通常会怎么做?在 chain-of-thought (CoT, 思维链) 推理范式下,模型会逐字逐句地写出中间推理步骤——"首先计算总价,然后减去折扣,最后得到答案"。这种"大声思考"(thinking out loud)的方式虽然有效,却将推理过程与语言生成过程深度耦合在一起:模型不仅要进行数学运算,还要耗费宝贵的计算预算来生成语法正确、流畅自然的中间文本。语言本质上是用于沟通的符号系统,而非优化的计算媒介,因此相当一部分计算资源被浪费在了确保文本连贯性上,而非纯粹的内部推理。

这篇论文的核心关切正是如何打破这种耦合。Aichberger 和 Hochreiter 提出了一个根本性的问题:人类在解决复杂问题时,真的会把每一个中间步骤都用语言表达出来吗?认知心理学的答案是明确的否定。人类拥有工作记忆(working memory),这是一个内部的认知工作空间,可以在不将信息外化为语言的情况下暂时存储和操纵任务相关的表征。发展心理学进一步观察到,儿童在学习解决问题时确实会依赖"自言自语"(private speech)作为语言支架,但随着认知能力的成熟,这种外部化的支架会逐渐被内化的思维所取代。这一认知科学洞见为论文的理论动机提供了坚实的基础:如果我们能够为大语言模型也构建一个类似的工作记忆系统,使其能够在不生成任何中间文本标记(token)的情况下进行内部推理,那么我们就可以在保留推理能力的同时,彻底摆脱语言生成带来的计算开销和语法约束。

基于这一动机,作者提出了 Reasoning in Memory(RiM,内存推理)方法。其核心创新在于引入了一组固定排列的特殊标记序列,称为记忆块(memory blocks)。这些记忆块由边界标记 <b></b> 以及内部的 <m> 标记组成,它们的位置和身份是预先固定的,因此可以在单次前向传播(single forward pass)中被模型同时处理。这与现有的隐式推理方法形成了鲜明对比——无论是 CoT 的显式文本生成,还是 Coconut 等隐式推理方法中连续表征(continuous representations)的自回归生成,都要求中间计算结果必须先被"外部化",后续计算才能以其为条件。RiM 的记忆块则提供了一个固定的潜在工作空间,模型可以在其中编码中间推理信息,而无需逐一生成这些表征。

为了让模型学会利用这些原本没有预定义计算角色的记忆块,作者设计了一个精心构建的两阶段课程(two-stage curriculum)。第一阶段(Stage 1)中,每个记忆块之后都设置一个读出分支(readout),要求该分支从已见过的记忆块中恢复出对应的下一个显式推理步骤。这相当于用一种预测性表征学习(predictive representation learning)的方式,强制模型将任务相关的中间信息编码到记忆块的上下文表征中。第二阶段(Stage 2)则移除了中间推理步骤的监督信号,转而要求每个记忆块后的读出直接预测最终答案,鼓励模型将已习得的工作记忆能力用于逐步精炼答案。这种从"教会记忆块做什么"到"让记忆块自己做"的渐进式训练策略,使得模型最终能够在没有任何显式推理痕迹的情况下,仅通过固定记忆块的潜在计算就给出正确答案。

实验结果令人信服。在 GSM8K(领域内测试集)和 GSM-Hard(领域外测试集)上的评估表明,RiM 在 GPT-2 和 Llama-3.2 两个模型家族的多个规模上,均达到或超过了现有的显式 CoT 和隐式推理基线方法。更重要的是,RiM 的推理延迟与直接回答模型(SFT w/o CoT)几乎完全相同——因为记忆块是固定输入,只需单次前向传播即可处理。相比之下,Coconut 的延迟约为 RiM 的 7 倍,而显式 CoT 的延迟更是达到了约 27 倍。这意味着 RiM 在实现了高质量的隐式推理的同时,几乎没有牺牲任何推理效率。

这项工作的深远影响在于,它从根本上挑战了"推理必须通过某种形式的外部化过程完成"这一隐含假设。它证明了大语言模型可以被训练成使用工作记忆作为有效的推理机制,为构建更快速、更节能、更不受语言语法约束的推理系统开辟了全新的道路。这一范式转变不仅有望降低部署推理系统的计算成本,也为理解模型内部的信息处理过程提供了新的研究视角——毕竟,当推理发生在固定的潜在空间中时,我们或许能够比追踪千变万化的文本输出更容易地分析和解释模型的思考过程。

理论框架 (Theoretical Framework)

要深入理解 RiM 的理论根基,我们需要追溯从显式推理到隐式推理的思想演变脉络,以及认知科学中工作记忆概念的启发意义。

现代测试时计算(test-time compute)范式的兴起源于 Nye 等人在 2021 年的开创性工作,他们提出在问题和答案之间插入一个"文本工作区"(scratchpad),让模型在生成最终答案前先写下中间计算步骤。这一思想在 Wei 等人 2022 年的 chain-of-thought 工作中得到了系统性的验证和扩展:通过 few-shot 提示词中包含逐步推理的示例,可以显著激发大语言模型的多步推理能力。然而,正如 Kojima 等人 2022 年以及 Wei 等人自己所指出的,这种推理方式存在根本性的张力——语言是为了沟通而进化的,而非为了计算。当我们强迫模型用自然语言表达每一个中间步骤时,相当一部分计算预算被分配给了生成语法正确、语义连贯的文本,而非纯粹的内部信息处理。

近年来,隐式推理(latent reasoning)领域的探索正是为了缓解这一张力。Hao 等人 2025 年提出的 Coconut 方法是这一方向的代表性工作:它将离散的推理标记替换为连续的表征(continuous thoughts, CTs),这些连续表征由模型的隐藏状态产生,并作为下一个解码步骤的输入嵌入被反馈回模型。Cheng 和 Durme 2024 年的压缩思维链、Shen 等人 2025 年的 CODI 自蒸馏方法、以及 Wang 等人 2025 年的 Synadapt 等工作,都在不同方面改进了连续表征的训练和使用方式。然而,这些方法共享一个核心特征:中间计算仍然以自回归(autoregressive)的方式被"外部化"——每个连续表征必须先被生成,下一个表征才能以它为条件。换言之,它们是在连续空间中"大声思考",但依然没有摆脱计算的串行瓶颈。

另一条相关的工作线探索了填充标记(filler tokens)的潜力。Lanham 等人 2023 年发现,简单地在输入中添加填充标记并不能提升准确率,甚至可能在长上下文中降低性能。Pfau 等人 2024 年展示了填充标记可以在合成算法任务上支持计算,但需要特定的密集监督才能收敛。Goyal 等人 2024 年将这一方法扩展到真实世界下游任务,但发现收益主要来自在预训练和微调两个阶段都使用填充标记。Deng 等人 2024 年的 DART 方法则表明,填充标记可以在没有特定预训练的情况下被训练,但微调需要一个双路径自蒸馏框架和多个辅助损失函数。这些结果共同暗示:让没有预定义语义内容的标记承担计算角色是可能的,但需要一个精心设计的训练信号。

正是在这个理论交汇点上,认知科学中的工作记忆概念提供了关键的启发。Baddeley 1992 年对工作记忆的经典定义将其描述为一个用于暂时保持和操作任务相关信息的内部认知工作空间。Vygotsky 1978 年在发展心理学中的观察更具启发性:儿童最初依赖外部语言作为认知支架,但随着能力的发展,这种外部化会逐渐内化。这暗示了一个重要的设计原则——从外部化的推理(显式 CoT)到内部化的推理(隐式工作记忆)不仅是可行的,而且是认知能力成熟的标志。

RiM 的理论框架正是在这些思想的基础上构建的。其核心概念是记忆块(memory block),形式化地定义为:

mk=[b,m,,m,/b]

其中每个记忆块 mk 包含 M<m> 标记,由边界标记 <b></b> 界定。在全部实验中,M 被固定为 2。这些记忆块的身份和位置是完全固定的——它们不是由模型生成的,而是作为输入的一部分预先插入到问题序列之后。因此,对于 K 个记忆块,模型处理的完整序列是 (x,m1:K),其中 x 是问题标记序列。由于这些标记位置固定,整个增强序列可以在单次前向传播中被处理:

hidden states=LLM(x,m1,m2,,mK)

这与 CoT 推理形成了鲜明对比。在标准 CoT 中,对于包含 C 个推理标记的推理痕迹 r=(r1,,rC),每个推理标记都是从条件分布 pw(ri|x,r<i) 中自回归解码的,需要 C 次顺序解码步骤。在 Coconut 等显式隐式推理中,对于 L 个连续表征 z1:L,每个表征 zRd 都依赖于先前生成的表征 z<,因此仍需要 L 次顺序解码步骤(尽管通常 L<C)。而 RiM 通过将计算移入固定的记忆块,将复杂度从 O(C)O(L) 的顺序步骤降低到 O(1) 的单次前向传播,这是其计算效率优势的数学根源。

关于训练的理论框架,RiM 采用了 JEPA(Joint-Embedding Predictive Architecture,联合嵌入预测架构)风格的预测性表征学习思想。LeCun 2022 年提出的 JEPA 框架强调,通过学习预测缺失的结构,固定状态可以转化为并行化的潜在状态。Stage 1 正是这一思想的直接应用:固定记忆块学习预测缺失的推理结构,从而转变为并行的工作记忆状态。这种预测性约束确保了记忆块的上下文表征不能是任意的——它们必须包含足够的信息来重构下一步的推理内容。

RiM 的理论假设也隐含了对其适用范围的一些限定。首先,它假设语言模型已经具备足够的先验知识,只需被引导如何将这些知识组织到固定的工作记忆中,而非从零学习推理本身。其次,记忆块的数量 K 需要与任务的推理深度相匹配——对于 GSM8K-Aug 中最多 13 步的推理问题,Stage 1 使用最多 13 个记忆块,而 Stage 2 则固定使用 8 个记忆块作为推理预算。这种预设的"记忆容量"限制了模型能够处理的问题复杂度,但也正是这种限制使得计算变得可控和高效。

技术架构 (Technical Architecture)

RiM 的技术实现可以视为一个围绕固定记忆块构建的精密训练系统,其核心在于如何通过精心设计的注意力掩码(attention mask)和两阶段课程目标,教会一个预训练的语言模型将原本"空洞"的特殊标记转化为承载任务相关信息的潜在工作空间。

从系统整体架构来看,RiM 可以分解为三个相互协作的组件:记忆块注入层、定制注意力掩码、以及两阶段训练目标。记忆块注入层负责在输入问题序列之后追加固定数量的特殊标记序列。每个记忆块由 <b> 起始标记、两个 <m> 内容标记(实验中 M=2)和 </b> 结束标记组成。为了避免干扰预训练词汇表的语义,所有现有词汇标记的嵌入被冻结,仅更新新引入的特殊标记的嵌入。这一设计选择具有深刻的考量:它确保模型的改进来源于学会使用记忆块作为工作空间,而非通过修改已有词汇的表征来间接适应任务。特殊标记在其他方面被当作标准输入标记处理,通过语言模型的常规前向传播机制进行计算。

定制注意力掩码是 RiM 架构中最精妙的设计之一,也是实现密集监督(dense supervision)的关键。标准的因果注意力(causal attention)要求序列中的每个位置只能关注它之前的位置。如果直接应用标准掩码,那么后面的记忆块读出分支将能够直接看到前面的显式推理步骤(在 Stage 1 中作为监督目标使用),从而绕过记忆块本身进行预测。RiM 的解决方案是将序列划分为两个流:记忆块流(memory block stream)和监督读出分支(supervised readout branches)。

具体来说,未来的记忆块可以关注问题以及之前的记忆块,但绝不可以关注任何监督读出分支。每个读出分支可以关注问题和截至其位置为止的所有可用记忆块,但不能关注其他读出分支。这种架构使得所有推理步骤(Stage 1)或答案目标(Stage 2)可以在单次前向传播中同时接收教师强制(teacher-forced)监督,同时防止后面的潜在状态"偷看"到前面步骤的真实答案。从信息论的角度来看,这确保了每个读出分支必须从记忆块中提取信息来完成预测任务,从而强制中间计算真正发生在记忆块内部。

论文还探讨了一种有趣的变体:块内双向注意力(bidirectional within-block attention)。在这种变体中,同一个记忆块内部的标记可以相互双向关注,而块与块之间的注意力仍然保持因果性。这增加了块内部的信息交换能力,但论文报告称实验结果参差不齐,没有呈现一致的趋势,因此将系统性的研究留给了未来的工作。

两阶段课程构成了 RiM 训练策略的核心。第一阶段的目标是让模型学会"记忆块是什么"——即为这些原本没有计算角色的标记赋予中间推理的语义。具体而言,对于每个训练样本 (x,r,y),其中 x 是问题标记序列,r 是显式推理痕迹,y 是最终答案标记序列,作者首先将推理痕迹分割为 T 个推理步骤 r1:T=(r1,,rT)。然后,为每个推理步骤分配一个记忆块,并训练第 t 个记忆块后的读出分支预测下一个推理步骤 rt+1(约定 rT+1=y)。

Stage 1 的目标函数是一个加权的推理步骤负对数似然:

LS1(w)=t=1Tλt(s)logpw(rt+1|x,mt)

这里 λt(s)[0,1] 控制第 t 个记忆块读出分支的监督强度,随训练步数 s 而变化。论文采用线性相对退火(linear relative annealing)策略:每个 λt(s) 从 1 线性递减到 0,退火顺序由 t 决定(即较早的推理步骤先失去监督)。这种相对加权——相对于每个样本自身的推理步骤数 T 而非固定的全局最大步数——确保了短样本的监督不会被过早移除。实验证明,这种策略优于绝对加权和无加权策略:绝对加权会过早地为短样本移除监督,而无加权则从不放松密集监督,导致模型无法自主组织推理流程。

当 Stage 1 完成后,记忆块已经获得了有意义的计算角色——它们的上下文表征编码了任务相关的中间信息。此时,第二阶段的目标转变为让模型学会"记忆块能做什么"——即利用已习得的工作记忆来直接精炼最终答案。Stage 2 丢弃显式推理痕迹 r,对所有样本使用固定数量的 K 个记忆块(论文中为 8 个)。每个记忆块后的读出分支被训练为预测最终答案 y

LS2(w)=k=1Kαklogpw(y|x,mk)

权重 αk[0,1] 控制第 k 个记忆块读出分支的监督强度。同样采用线性加权策略,但这一次权重随着 k 增加而增大,反映了一个直观的假设:后面的记忆块可以访问更多的潜在计算,因此应该产生更强的最终答案。Stage 2 的灵感来自迭代隐式推理模型如 HRM(Hierarchical Reasoning Model)和 TRM(Tiny Recursive Model),但 RiM 通过沿序列维度的水平扩展而非递归模块来实现精炼:额外的记忆块在答案读出之前提供了额外的潜在计算,既不需要递归精炼模块,也不需要自回归生成潜在状态。

在两个阶段之间,作者遵循 Deng 等人和 Hao 等人的做法,重置优化器状态和学习率调度器,并使用更低的学习率和更高的 dropout 来减少密集答案监督下的过拟合。这种"硬切换"很重要,因为两个阶段优化的是截然不同的目标:Stage 1 为中间计算奠定基础,而 Stage 2 则在这个基础上进行最终答案的精炼。实验表明,缺少 Stage 1 的直接 Stage 2 训练虽然能快速提升最终答案准确率,但会远低于先经过 Stage 1 grounding 的模型;而仅有 Stage 1 则无法产生可靠的最终答案读出,因为模型从未被训练过如何从记忆块中直接提取答案。

实验评估 (Experimental Evaluation)

RiM 的实验设计围绕三个核心研究问题展开:记忆块是否真的被用于中间计算?RiM 在性能和延迟方面与先前方法相比如何?推理时的记忆预算变化是否会影响性能稳定性?为了回答这些问题,作者在公认的小学数学推理基准上进行了系统性的评估,涵盖了模型家族、参数规模和分布内/分布外泛化等多个维度。

实验设置上,训练数据采用 GSM8K-Aug——一个包含 386K 道小学数学题的扩充数据集,其中每道题附带最多 13 个以数学表达式形式呈现的显式推理步骤。分布内(In-Distribution, ID)测试集使用标准的 GSM8K,包含约 1,319 道测试题;分布外(Out-of-Distribution, OOD)测试集使用 GSM-Hard,这是一个更具挑战性的数学问题集合。模型方面覆盖了 GPT-2(Radford 等,2018)和 Llama-3.2(Dubey 等,2024)两个主流家族,包括 GPT-2、Llama-3.2-1B 和 Llama-3.2-3B 三种规模。这一设置使得实验结果可以直接与 Hao 等、Jiang 等、Deng 等、Goyal 等、Shen 等先前在隐式推理领域的工作进行比较。

基线方法的选择体现了作者对公平比较的审慎考量。最直接相关的非隐式基线是 SFT w/o CoT(监督微调但不使用思维链),它直接在问题-答案对上训练,是与 RiM 最接近的非潜在基线,因为两者在测试时都不能写出显式推理痕迹。SFT w/ CoT 则作为显式推理的黄金标准,在训练和测试时都使用完整的自然语言推理痕迹,虽然其推理成本远高于 RiM。隐式推理方面,主要比较对象是 Coconut(Hao 等,2025)——当前最广泛使用的隐式推理基线。论文比较了两种 Coconut 变体:带有 Stage 0(显式推理热身阶段)的 Coconut w/ Stage 0,以及省略该阶段的 Coconut w/o Stage 0。此外,还与 DART(Jiang 等,2025)的官方报告数字进行了比较,尽管 DART 使用了更多的训练轮次和双路径训练目标,这种比较对 RiM 在训练成本方面是保守的。

评估协议上的一个关键细节体现了作者对实验严谨性的重视。先前工作通常报告在评估基准上表现最好的检查点(checkpoint),但这引入了选择偏误(selection overfitting),可能夸大真实性能。为了避免这一问题,作者采用了 k 折交叉验证协议:将 GSM8K 测试集分为 16 份,每份保留 264 个样本用于检查点选择,选择在该保留集上贪婪准确率最高的检查点。除非特别说明,所有后续分析都使用这些被选中的检查点。

关于记忆块是否承载有意义计算的验证,作者进行了细致的表征分析。训练 Llama-3.2-1B 6 个 epoch 的 Stage 1(约 18,000 步)和 2 个 epoch 的 Stage 2(约 6,000 步),每 1,000 步保存一次检查点。在每个检查点处,收集所有 GSM8K 测试题上每个记忆块在倒数第二层的表征,并将它们投影到一个共享的 PCA 基上(解释了 25% 的总方差)。观察到的现象极具说服力:在训练轨迹方面,不同记忆块的表征沿着平滑且块特定的轨迹演化,表明模型正在系统性地组织潜在工作空间,而非随机扰动表征。在基础模型与最终模型的对比中,未经训练的基线模型的记忆块表征基本是坍缩的,而训练后的最终模型的表征则形成了广泛的、与样本相关的簇,不同的问题诱导出截然不同的潜在工作空间方向。对于 PCA 投影中距离最远的两个样本,在原表征空间中计算余弦相似度低至接近零,证实了观察到的分离确实反映了原始表征空间中的真实差异。

主要性能结果总结在表 1 中。RiM 的"最终块读出"(Final block)——即使用最后一个记忆块的答案预测作为模型输出——在所有模型规模和两个测试集上都一致地超过了最强的 Coconut 变体。在 GSM8K 上,RiM 相比 Coconut w/ Stage 0 的提升幅度在 GPT-2 上为 2.5 个百分点,在 Llama-3.2-1B 上为 5.2 个百分点,在 Llama-3.2-3B 上为 7.5 个百分点。在更具挑战性的 GSM-Hard 上,提升幅度为 0.7 到 1.8 个百分点。与直接回答基线 SFT w/o CoT 相比,提升更为显著:在 GSM8K 上达到 12.6 到 18.2 个百分点,在 GSM-Hard 上为 3.5 到 5.2 个百分点。Pass@8 指标(使用温度 1 采样 8 个答案,若任一正确则计为正确)呈现出相同的模式,表明 RiM 不仅仅在贪婪解码下运气好,而是真正地分配了更多的概率质量给正确答案。

延迟方面的对比尤为引人注目。如表 1 和表 7 所示,RiM 的首标记时间(Time to First Token, TTFT)与 SFT w/o CoT 几乎完全相同——在 Llama-3.2-1B 上均为 16.1 毫秒。这是因为记忆块是固定输入标记,只需单次前向传播即可处理。相比之下,Coconut 的延迟约为 RiM 的 7 倍(108.3 毫秒),而显式 CoT 更是达到了约 27 倍(420.3 毫秒)。在完整的答案生成延迟上(表 7),RiM 与 SFT w/o CoT 的差异仅为 0.5 毫秒(126.0 毫秒 vs 126.5 毫秒),而 Coconut 慢了 178.7 毫秒,SFT w/ CoT 慢了 982.7 毫秒。这意味着 RiM 在实现了接近显式 CoT 甚至超越部分隐式基线的推理质量的同时,保留了直接回答模型的推理速度——这对于延迟敏感的实际应用来说是一个决定性的优势。

Model Method TTFT (ms) GSM8K Greedy (%) GSM8K Pass@8 (%) GSM-Hard Greedy (%) GSM-Hard Pass@8 (%)
GPT-2 SFT w/o CoT 7.6 15.4±0.2 33.3±0.3 3.5±0.1 7.6±0.1
Coconut w/ Stage 0 53.4 31.1±0.2 45.0±0.2 7.1±0.0 10.7±0.1
RiM (ours) 7.6 33.6±0.2 49.1±0.2 7.8±0.1 11.2±0.1
Llama-3.2-1B SFT w/o CoT 16.1 23.9±0.2 41.7±0.3 5.3±0.1 9.5±0.1
Coconut w/ Stage 0 108.3 36.9±0.2 51.1±0.2 8.5±0.0 12.2±0.0
RiM (ours) 16.1 42.1±0.2 56.1±0.3 10.5±0.0 13.8±0.0
Llama-3.2-3B SFT w/o CoT 27.9 36.2±0.2 45.9±0.2 8.5±0.1 10.8±0.1
Coconut w/ Stage 0 188.8 41.3±0.2 55.5±0.5 10.2±0.1 13.5±0.1
RiM (ours) 27.9 48.8±0.2 58.8±0.2 12.0±0.0 14.1±0.0

表 1:主要结果。使用交叉验证协议在 GSM8K 上选择检查点。数值为 16 次分割重复的均值 ± 标准误。

推理预算的稳健性分析回答了第三个核心问题。图 6a 展示了在 Stage 1 和 Stage 2 后,改变记忆块中 <m> 标记数量 M(垂直轴)和记忆块数量 K(水平轴)对贪婪准确率的影响。Stage 1 后的准确率在第一块读出时达到约 27%,但随着 K 增加而下降——这符合预期,因为 Stage 1 将记忆块与中间推理步骤绑定,而非最终答案。然而 Stage 2 后,这种依赖性基本消失:准确率提升至约 43%,并在广泛的记忆预算范围内保持稳定。这表明 Stage 2 成功地将 Stage 1 中习得的 grounded 记忆块转化为一个固定序列的潜在计算,可以在不同位置被可靠读出。图 6b 进一步追踪了跨记忆块的答案转换情况:在最终模型中,仍有相当一部分问题的答案会在不同记忆块之间发生变化,且累积净效应为正(更多问题从错误转为正确而非相反)。这说明潜在工作空间在 Stage 2 中并未坍缩,而是继续跨记忆块精炼预测。

关于阶段切换的消融实验(图 9)为训练策略提供了有力的支持证据。仅使用 Stage 2 虽然能快速提升最终答案准确率,但很快达到远低于先经过 Stage 1 grounding 的模型的平台期。反之,仅使用 Stage 1 虽然能产生很高的"任一块准确率"(any-block accuracy),但最终块准确率(final-block accuracy)始终很低——因为模型从未被训练过如何从记忆块中直接提取答案。过早的阶段切换也会削弱性能,表明记忆块必须先获得足够的计算角色,答案精炼才能有效。此外,将 RiM 的课程替换为 Coconut 风格的渐进课程(图 10)后,性能显著下降,证实记忆块确实需要密集监督信号来强制通过潜在工作空间进行中间计算。

案例研究 (Case Studies)

为了更直观地理解 RiM 如何在实践中运作,我们可以从论文提供的训练数据和测试样本中选取典型案例进行走读分析,同时结合表征层面的可视化证据来揭示模型内部的信息处理过程。

论文在图 7 展示了代表性的数据集样本。图 7a 是一个 GSM8K-Aug 训练样本:"John 上午 8:30 开始徒步,下午 6:30 结束。他中午休息 20 分钟,下午休息 15 分钟,结束前休息 30 分钟。他徒步了多少小时?"该样本附带了四个显式推理步骤:首先计算总时长 18.58.5=10,然后累加休息时间 20+15+30=65,接着将分钟转换为小时 65/60=1.08333,最后得到净徒步时间 101.08333=8.91667。在 Stage 1 中,RiM 为这四个推理步骤分配四个记忆块,每个记忆块后的读出分支被训练为预测下一个推理步骤。第一个记忆块必须编码"计算时间跨度"的信息,以便读出分支能预测出 18.58.5=10;第二个记忆块必须同时保留"总时长为 10 小时"和"总休息时间为 65 分钟"的信息,以支持下一步的分钟到小时转换。这种多步信息的累积和转换正是工作记忆的核心功能。

图 7b 展示了 GSM-Hard 的一个测试样本——这是一个更复杂的房屋购买问题:"Cruz 夫人的预算是 40 万美元。她看到一处售价 35 万美元的房子。此外买家需要支付售价 5% 的经纪费和 12% 的转让费。总价比预算多多少?"值得注意的是,GSM-Hard 测试样本只提供问题和最终答案(9500.0),没有中间推理步骤。这意味着在测试时,模型必须完全依赖在 GSM8K-Aug 上通过显式推理步骤习得的工作记忆能力,来内化处理这个分布外的复杂问题。RiM 的成功在于,它不需要在测试时生成任何中间文本,而是通过记忆块的潜在计算直接得出答案。

表征层面的案例研究更为深刻。图 4 和图 11 展示了记忆块表征在训练过程中的演化轨迹。以 Llama-3.2-1B 的第 12 层为例,在训练开始时(基础模型),所有记忆块的表征几乎坍缩到同一个点——这说明未经训练的模型只是将这些特殊标记当作无意义的占位符。经过 Stage 1 的训练后,表征开始分离:不同记忆块沿着不同的方向演化,且同一记忆块内的表征也因问题而异。到了 Stage 2,这种分离变得更加显著。论文特别标注了两个在 PCA 投影中距离最远的样本,它们在第 12 层的余弦相似度低至约 0.09,第 16 层更是降至 0.01。这意味着某些问题诱导出的记忆块表征几乎是正交的——模型确实在使用记忆块编码截然不同的计算路径。

另一个有趣的案例来自线性探针(linear probe)实验(表 5)。作者从 256 个 held-out GSM8K 样本上训练轻量级线性探针,基于记忆块表征预测该块后的读出是否正确。在所有记忆块中,探针的 AUROC 约为 84.8% 到 85.0%,AUPRC 约为 80.7% 到 82.3%。更惊人的是,在条件于"至少一个记忆块产生了正确答案"的子集上,基于探针概率选择答案的准确率达到 90.0%。这表明记忆块表征中蕴含了丰富的正确性信息——即使最终块的答案可能是错的,模型在潜在工作空间的某个中间状态中实际上已经"知道"了正确答案。这一发现暗示了 RiM 的一个潜在扩展方向:通过简单的线性探针从记忆块序列中选择最佳答案,而非仅依赖最终块。

关于 Coconut 课程的消融案例(图 10)则从另一个角度说明了训练策略的重要性。当保持 RiM 的固定记忆块但将课程替换为 Coconut 风格的渐进式替换时,性能在整个训练过程中都显著低于 RiM 的两阶段密集监督。这就像一个学生如果只是偶尔被要求自己思考,大部分时间仍依赖老师的逐步提示,那么他的内化推理能力发展就会慢得多。相比之下,RiM 的 Stage 1 立即要求所有记忆块同时承担预测下一步推理的责任,这种"全面投入"的密集监督迫使模型迅速建立起工作记忆的使用习惯。

综合价值与局限 (Synthesis — Value and Limitations)

RiM 的提出标志着隐式推理领域的一个重要范式转变,其价值不仅体现在实验数字上,更在于它从根本上重新思考了推理与生成之间的关系。

从理论层面来看,这项工作最深刻的贡献在于将认知科学中的工作记忆概念系统地引入了大语言模型的推理架构设计。先前的方法,无论是显式的 CoT 还是隐式的 Coconut,都在某种意义上将推理视为一种"独白"——模型必须以一种可观察的形式(文本或连续表征)依次展开其思考过程。RiM 则打破了这种独白模式,提出推理可以在一个内部的、不可直接观察的潜在空间中并行进行。这一视角转变具有重要的认识论意义:它暗示我们过去对模型推理能力的评估可能过度依赖了外部化的思维痕迹,而忽视了模型在纯粹内部计算中的潜力。此外,RiM 的两阶段课程与 JEPA 预测性表征学习框架的联系,也为理解记忆块如何获得语义提供了理论锚点——记忆块之所以变得有意义,是因为它们被训练去预测外部环境中的结构(推理步骤)。

从实用层面来看,RiM 的延迟优势是最直接的工程价值。在 Llama-3.2-1B 上,RiM 的首标记时间为 16.1 毫秒,与直接回答模型的 16.1 毫秒持平;而 Coconut 需要 108.3 毫秒,显式 CoT 需要 420.3 毫秒。在 GPT-2 上,RiM 与 SFT w/o CoT 均为 7.6 毫秒,Coconut 则需要 53.4 毫秒。对于需要低延迟响应的生产系统——例如实时对话助手、在线数学辅导或嵌入式推理应用——这种数量级的延迟差异可能是决定性的。更重要的是,RiM 保持了单次前向传播的简洁性,这意味着它在部署时不需要复杂的生成调度或多步解码逻辑,大大降低了工程实现的复杂度。

然而,论文也存在一些诚实的局限性值得讨论。首先是记忆容量的硬性约束。RiM 需要在推理时预先确定记忆块的数量 K(论文中 Stage 2 使用 8 个),这对于超出其处理能力的复杂问题可能会成为瓶颈。相比之下,显式 CoT 可以自适应地生成所需长度的推理痕迹,Coconut 虽然也是固定预算但在理论上可以通过调整连续表征数量来灵活扩展。RiM 的固定架构意味着它必须在效率和灵活性之间做出预设的权衡。

其次是可解释性的双刃剑效应。一方面,RiM 的潜在推理痕迹不再以自然语言形式呈现,这使得直接的人类可读性分析变得困难——我们无法像阅读 CoT 那样轻易地检查模型"在哪里犯了错"。另一方面,固定位置的记忆块可能为表征层面的分析提供了比千变万化的文本输出更稳定的研究对象。论文中的 PCA 可视化和线性探针实验已经展示了这种可能性,但如何将表征层面的发现转化为可操作的诊断工具仍是开放问题。

第三个局限是训练数据的依赖性。Stage 1 需要高质量的显式推理步骤作为监督信号,这意味着 RiM 的适用范围目前仍局限于那些可以获得逐步推理标注的领域(如数学问题)。对于缺乏此类标注的领域——例如常识推理、因果推断或开放式创作——如何设计 Stage 1 的 grounding 监督仍是一个未解决的挑战。此外,论文的实验集中在相对短链的数学推理(最多 13 步),更长或更分支化的推理结构(如程序综合、定理证明)是否能在固定记忆块中得到有效编码,还需要进一步的验证。

最后,论文的结论部分已经指出了两个非常有前景的未来方向。其一是探索 Stage 2 中基于强化学习(Reinforcement Learning, RL)的最终答案奖励机制——如果模型能够从答案级别的反馈而非逐词监督中学习优化其记忆块的使用方式,可能进一步释放工作记忆的潜力。其二是混合方法,将 RiM 与显式生成结合,用于处理更复杂的问题。这些方向如果得到充分探索,可能催生出既具有 RiM 的效率优势,又能处理更广泛问题类型的下一代推理架构。

延伸阅读与思考 (Further Reading and Reflection)

要全面理解 RiM 在更广阔学术图景中的位置,我们需要将其与几个关键的相关工作脉络联系起来,并思考这些脉络交汇处的开放问题。

在显式推理的谱系上,RiM 最直接的前身是 Wei 等人 2022 年的 chain-of-thought 工作和 Nye 等人 2021 年的 scratchpad 方法。这些工作确立了"在问题和答案之间插入中间计算可以显著提升推理能力"这一基本原则,但代价是将推理与语言生成耦合。沿着降低这种耦合的方向,后续工作从多个角度进行了探索:Zelikman 等人 2022 年的 STaR 和 Muennighoff 等人 2025 年的 s1 通过自举或监督微调扩展了显式推理的覆盖范围;DeepSeek-AI 2025 年的 DeepSeek-R1 展示了强化学习在激励推理能力方面的巨大潜力;而 Brown 等人 2024 年的"大语言猴子"则通过重复采样来放大推理的探索空间。RiM 与这些工作的关系并非替代,而是互补——它解决的问题是如何在保留推理能力的前提下消除显式生成的计算开销。

在隐式推理的领域,Coconut(Hao 等,2025)是 RiM 最直接的比较对象和技术对话者。Coconut 的核心洞见是推理不需要离散的语言标记,连续表征同样能够传递信息。RiM 继承并发展了这一洞见,进一步提出连连续表征的自回归生成都不是必需的。其他相关方法各有侧重:Cheng 和 Durme 2024 年的压缩思维链追求表征密度的最大化;Shen 等人 2025 年的 CODI 通过自蒸馏将显式 CoT 压缩为连续空间;Wang 等人 2025a 的 HRM 和 Jolicoeur-Martineau 2025 年的 TRM 通过递归模块实现迭代的潜在精炼;Liu 等人 2025 年的 MARCOS 将马尔可夫链结构引入连续思维。RiM 的独特之处在于它用水平序列中的固定位置替代了垂直/递归的迭代结构,用单次前向传播替代了多步生成循环。

填充标记(filler token)的研究线为 RiM 提供了重要的背景参考。Lanham 等人 2023 年的负面结果、Pfau 等人 2024 年在合成任务上的突破、Goyal 等人 2024 年在真实任务上的扩展,以及 Deng 等人 2024 年 DART 的双路径框架,共同勾勒出一个核心教训:让"空洞"的标记变得有用是可能的,但需要精心设计的训练信号。RiM 通过两阶段课程和定制注意力掩码,为这一教训提供了一个更简洁、更高效的实现方案。

最令人兴奋的未来方向可能是跨领域的扩展。目前 RiM 的验证局限于数学推理,但工作记忆的概念在认知科学中具有普遍性——人类在理解文本、规划行动、进行视觉空间推理时都依赖工作记忆。如果 RiM 的原理能够迁移到代码生成(程序中间状态的潜在编码)、多步决策(强化学习中的潜在策略评估)、甚至科学发现(假设空间的潜在探索),那么它的影响将远超当前的数学推理基准。另一个开放问题是记忆块表征的可解释性:线性探针已经展示了这些表征中蕴含丰富的正确性信息,但能否进一步解码出具体的"子计算"——例如某个记忆块编码了"累加操作",另一个编码了"数值比较"?如果答案是肯定的,RiM 可能成为连接神经网络黑箱和可解释人工智能(Explainable AI, XAI)研究的重要桥梁。

对我而言,这篇论文最发人深省的地方在于它揭示了一个被长期忽视的简单可能性:我们之所以一直让模型"大声思考",可能只是因为我们习惯了阅读它们的输出,而非因为思考本身需要语言。人类大脑的工作记忆系统并不依赖语言——失语症患者仍然能够进行逻辑推理,婴儿在习得语言之前已经表现出因果推理能力。RiM 的成功暗示,大语言模型或许也拥有类似的非语言推理潜能,只是我们之前的训练范式未能有效地解锁它。这让我想知道:如果进一步增大记忆块的数量或改变其内部结构(例如论文中提到的双向注意力变体),模型能否处理更复杂的、需要长期依赖和分支探索的推理任务?当工作记忆的容量扩大时,模型是否会自发地涌现出更复杂的内部组织模式,就像人类工作记忆中从简单的信息保持发展到复杂的操作和策略一样?这些问题没有现成答案,但 RiM 为探索它们提供了一个坚实的技术平台。

Topics:

Powered by Forestry.md