这项由高通AI研究院的Ramchalam Kinattinkara Ramakrishnan、Zhaocong Yuan等七位研究人员共同完成的研究,发表于2025年7月3日,论文编号为arXiv:2507.02659v1。感兴趣的读者可以通过arXiv平台获取完整论文内容。这项研究解决了一个在人工智能领域颇为头疼的问题:如何让一个小巧的AI模型为各种不同的大型AI模型提供高效的"草稿服务"。
当我们使用ChatGPT、Claude这样的大型语言模型时,它们需要逐字逐句地生成回答,就像一个作家在稿纸上一个字一个字地写作。这个过程很慢,特别是在手机或其他移动设备上使用时更是如此。为了解决这个问题,研究人员想出了一个巧妙的办法:让一个小而快的"草稿员"模型先快速写出初稿,然后让大模型来检查和修正这个初稿。这就像是让一个速记员先快速记录,然后让专业编辑来润色一样。
然而,现实中存在一个棘手的问题。不同的AI模型就像来自不同国家的人,它们使用着不同的"词汇表"。一个专门为Llama模型训练的草稿员,无法直接为Qwen或其他模型提供草稿服务,因为它们对同一个词汇的理解方式不同。这就好比一个习惯了美式英语的速记员,突然要为一个只懂英式英语的编辑工作,两者之间的词汇差异会造成很多误解。
高通AI研究院的团队提出了一个名为OmniDraft的解决方案,它的核心思想是创建一个"万能翻译官",让同一个小型草稿模型能够为任何大型目标模型提供服务。这个方案包含了三个巧妙的创新。
一、跨词汇表的智能翻译系统
研究团队首先解决的是不同模型之间的"语言障碍"问题。他们设计了一个叫做"n-gram缓存"的翻译系统,这个系统就像是一个智能词典,能够记住不同模型之间的词汇对应关系。
传统的做法是只处理两个模型词汇表中完全相同的词汇,这就像两个人只能用共同认识的词汇交流,大大限制了交流的丰富性。而OmniDraft的n-gram缓存更加聪明,它能够处理更复杂的对应关系。比如,草稿模型可能将"snowflake"(雪花)分解为"snow"、"f"、"la"、"ke"四个部分,而目标模型可能将其识别为"snow"和"flake"两个部分。n-gram缓存能够学会这种对应关系,将草稿模型的四个片段正确地组合成目标模型能理解的两个词汇。
这个过程就像一个经验丰富的翻译官,不仅能翻译单个词汇,还能理解不同语言中词汇组合的方式。当草稿模型提出一系列词汇片段时,翻译系统会查看缓存,看看这些片段是否能组合成目标模型更喜欢的形式。如果找到了匹配的组合,就会将多个小片段合并成一个完整的词汇,大大提高了被目标模型接受的可能性。
更重要的是,这个缓存系统是动态学习的。每当系统遇到新的词汇对应关系时,都会将其记录下来,供将来使用。这就像一个翻译官在工作中不断积累经验,遇到的对应关系越多,翻译能力就越强。
二、在线混合蒸馏训练
解决了翻译问题后,研究团队面临的第二个挑战是如何让草稿模型更好地理解目标模型的"思维方式"。他们开发了一种叫做"在线混合蒸馏"的训练方法。
这个过程可以比作师傅带徒弟的学习方式。草稿模型(徒弟)在实际工作中观察目标模型(师傅)的表现,然后调整自己的行为来更好地配合师傅。具体来说,当目标模型接受了草稿模型的建议时,草稿模型会记住这次成功的经验;当目标模型拒绝建议并给出修正时,草稿模型也会从这次"纠错"中学习。
传统的训练方法通常是离线进行的,就像学生在考试前突击复习一样。而OmniDraft采用的是在线学习方式,更像是边工作边学习的学徒制。这种方法的优势在于,草稿模型能够根据具体的使用场景和用户数据不断调整自己,而不是一成不变地使用固定的知识。
混合蒸馏的"混合"体现在训练方法的灵活性上。对于可以直接对应的词汇,系统使用一种叫做"逆向KL散度"的方法来对齐两个模型的概率分布,这就像让徒弟学习师傅对同一个问题的判断方式。对于需要通过n-gram缓存翻译的词汇,系统则使用"最大似然估计"的方法,重点提高这些词汇被正确预测的概率。
研究团队还引入了一个动态权重参数λ,用来平衡这两种训练方式的重要性。这个参数可以根据实际情况调整,比如当遇到的翻译词汇较多时,可以增加翻译相关训练的权重;当直接对应的词汇较多时,则增加概率对齐训练的权重。
三、自适应草稿长度调整
OmniDraft的第三个创新是智能的草稿长度调整机制。这个机制就像一个经验丰富的秘书,能够根据不同情况调整汇报的详细程度。
在实际应用中,草稿模型需要决定每次应该提供多少个词汇建议。提供太少的建议可能无法充分利用加速的潜力,而提供太多的建议则可能导致大部分被拒绝,反而浪费计算资源。传统的做法是使用固定的草稿长度,但这显然不够灵活。
OmniDraft引入了一个"接受率预测头",这个小型神经网络能够预测每个词汇建议被目标模型接受的可能性。基于这些预测,系统会动态计算继续提供更多建议的风险。如果预测显示后续建议被拒绝的概率很高,系统就会提前停止,避免浪费计算资源。
这个预测机制使用了一种叫做"sigmoid函数"的数学工具来估计接受概率,然后计算所有建议中至少有一个被拒绝的总体概率。当这个概率超过预设的阈值时,系统就会停止生成更多建议。这就像一个精明的销售员,能够判断客户的兴趣程度,在合适的时候结束推销。
在在线学习环境中,这个预测头面临着一个特殊的挑战:随着草稿模型不断改进,词汇被接受的概率也在变化,这意味着预测头需要同步调整。研究团队提出了两种解决方案。
第一种是"联合训练"方法,让草稿模型和预测头同时更新。这种方法简单直接,但可能因为两个组件的学习速度不同而产生不稳定性。第二种是"交替训练"方法,为预测头维护一个更大的数据缓冲区,包含历史数据,这样可以提供更稳定的训练环境。实验表明,交替训练方法通常能获得更好的性能。
四、实验验证与性能表现
为了验证OmniDraft的有效性,研究团队进行了大量的实验测试。他们选择了一个仅有68M参数的Llama小模型作为草稿员,并测试了它与多个不同大型模型的配合效果,包括Llama3-8B、Qwen2-7B和Vicuna-7B。
实验涵盖了四个不同的任务领域。在数学推理任务中,他们使用了GSM8K数据集,这个数据集包含了各种小学数学应用题。在编程任务中,他们结合了MBPP和HumanEval两个代码生成数据集。在文本生成方面,他们使用了Alpaca指令跟随数据集。在文本摘要任务中,他们采用了XSum新闻摘要数据集。
实验结果令人印象深刻。在跨词汇表的场景中,传统的直接映射方法(SpDDM)几乎无法工作,接受率通常在0.1左右,加速比甚至低于1,这意味着不仅没有加速反而变慢了。而使用OmniDraft的LDM(直接映射训练)方法能将接受率提升到0.2-0.4的范围,加速比达到1.2-1.6倍。
当加入n-gram损失项后,性能进一步提升。LDM + λLN-gram方法在所有任务上都表现出色,接受率通常能达到0.2-0.4,加速比在1.2-1.7倍之间。特别值得注意的是,在GSM8K数学推理任务上,无论是Llama3-8B还是Qwen2-7B作为目标模型,都能获得最大的加速效果,这可能是因为数学推理任务具有更强的结构性和可预测性。
研究团队还测试了使用LoRA(Low-Rank Adaptation)技术的效果。LoRA是一种参数高效的微调方法,只需要更新模型的一小部分参数。实验显示,即使使用LoRA这种"轻量级"的训练方式,OmniDraft仍然能够获得显著的性能提升,虽然效果略低于全参数微调,但对于资源受限的边缘设备来说,这种方案提供了很好的性能和效率平衡。
在自适应草稿长度调整的实验中,研究团队发现了一些有趣的现象。联合训练方法虽然能够获得更高的接受率,但在某些任务上的加速比反而不如交替训练方法。这表明高接受率不一定直接转化为更好的加速效果,可能是因为联合训练方法容易低估接受概率,导致过早停止生成建议。