2026年7月2日 周四晚上19:30,报名腾讯会议了解“如何构建自进化的动态知识库(Brain)”(限30人)
免费POC, 零成本试错
FDE知识库

FDE知识库

学习大模型的前沿技术与行业落地应用


收藏

大模型知识蒸馏的两种方式

发布日期:2024-09-07 07:23:15 浏览次数: 7565
作者:NetRookie

微信搜一搜,关注“NetRookie”

      上个月llama3.1的405B已经发布,除了感叹开源模型效果的厉害之外,另一个普遍的感受就是,跑不动,根本跑不动,没资源,就算能训练,也部署不起。所以很多人就自然而然关注到了知识蒸馏,通过将大模型能力迁移到小模型能力上。于是大概调研了下,本文主要是对清华的《MiniLLM:Knowledge Distillation of Large Language Models》和Meta的《Distilling System2 into System1》一些解读,刚好他们分别作为白盒蒸馏和黑盒蒸馏的一个典型代表。

       在知乎搜了一下minillm相关的文章,如《 吃果冻不吐果冻皮:大模型知识蒸馏概述 》总结性的介绍了下minillm的逆向kl散度的思路。即最小化前向 Kullback-Leibler 散度 (KLD) 的挑战为教师分布中不太可能的区域出现概率过高,从而在自由运行生成过程中导致不可能的样本 。为了解决这个问题,MINILLM 选择最小化逆向 KLD。这种方法可以防止学生高估教师分布中的低概率区域,从而提高生成样本的质量。但具体原因只在论文中才更清楚,于是部分细节整理如下,本文主要对FKL和RKL差异以及从强化学习视角看MiniLLM做一些介绍,其他论文细节没有涉及太多。

MiniLLM蒸馏

Motivation

前向KL散度倾向于学习mean-seeking,反向kl散度学习mode-seeking

首先需要明确KL散度的非对称性质,即前向和后向是当前仅当两个分布完全相等时才等价的,然后我们分开看一下两个kl散度的具体公式。了解前向KL散度和KL散度分别会导致mode seeking 和 mean seeking 产生的原因在于:

  • 当p(x)较大时,qθ(x) 也需要比较大且比p(x)相对更大,否则公式右边很大的情况下,FKL整体就无法达到最小;

  • 当p(x)较小时,p(x) 在 log 外趋于0占主导,FKL整体总是能比较小,跟qθ(x)关系较小。所以在优化的时候,qθ(x) 会覆盖p(x)的所有mode,即便此时有可能导致高估 p(x) 很小的部分,对应上述图中的橙色部分。

  • 当qθ(x) 较大时,为了在优化时候降低RKL,p(x) 必须较大,因此 p(x) 概率最大的 mode 也要对应 qθ(x) 概率最大的地方,p(x) 概率很小的地方必须对应 qθ(x) 概率为0的地方,也就是说 qθ(x) 拟合了 p(x) 概率最大的部分。对应前述图中的绿色部分。

  • 当 qθ(x) 等于0时,p(x) 取什么样的值都不影响优化。

由此可以看下MiniLLM具体的方法图:

RKL和Inverse RL的等价的数学推导

论文中的另一个视角个人觉得特别好,就是将RKL和逆强化学习进行对比,并给出了数学说明,可以看一下

公式说明:这里的公式序号均来自论文本身,目的是结合论文一起看可能更好,不破坏原有公式顺序。


既然可以这么类比,RKL约等于逆强化学习,FKL等价于模仿学习,而在实际应用和理论说明中,逆强化学习的效果都会比模仿学习更优,虽然更加难以训练,但其泛化性能,理论上限肯定会更高,所以结论是MiniLLM的RKL理论上是更优的。 模仿学习和逆强化学习这个说明可以查看:https://www.zhihu.com/question/470949607/answer/2450111740?utm_id=0

实际怎么训练

上面两个部分,其实都在说明,MiniLLM理论证明上是更优的蒸馏方法。所以我们可以去进行大胆尝试。实际的训练过程,类似于RLHF的训练方式,教师模型在训练中只推理,作为奖励信号去训练模型。作者也提供了类似ranking loss的更简单的平替方式去优化,对比传统Bert时代的蒸馏方法都会有提升。感谢作者!

# https://github.com/microsoft/LMOps/blob/main/minillm/finetune.py#L166
# 这里是实际蒸馏loss的计算
def get_distil_loss(args, tokenizer, model, teacher_model, model_batch, no_model_batch, logits):
with torch.no_grad():
teacher_model.eval()
teacher_outputs = teacher_model(**model_batch, use_cache=False) # 教师模型推理
teacher_logits = teacher_outputs.logits # 获取教师分布logits
if args.model_parallel:
distil_losses = mpu.parallel_soft_cross_entropy_loss(logits.float(), teacher_logits.float())
distil_losses = distil_losses.view(-1)
loss_mask = no_model_batch["loss_mask"].view(-1)
distil_loss = (distil_losses * loss_mask).sum(-1) / loss_mask.sum(-1)
else:
teacher_probs = F.softmax(teacher_logits, dim=-1, dtype=torch.float32) #
inf_mask = torch.isinf(logits)
#log_softmax实际上是在教师和学生的交叉熵;交叉熵损失在形式上等价于KL散度减去一个常数项(分布P 的熵)在最小化KL散度时可以忽略
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
prod_probs = torch.masked_fill(teacher_probs * logprobs, inf_mask, 0)
x = torch.sum(prod_probs, dim=-1).view(-1)
mask = (no_model_batch["label"] != -100).int()
distil_loss = -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)

return distil_loss

一般来说实际训练中还会加上sft数据的loss确保不跑偏,类似rlhf中的reference model的作用
output = self.model(inputs, attention_mask=attention_mask, return_output=True)
sft_loss = self.loss_fn(output.logits, labels)

需要注意教师模型和学生模型需要使用同源的模型。即相同的tokenizer,对于国产模型来说,qwen、deepseek、yi等都有相同tokenizer不同尺寸的模型可供选择。

System2到System1蒸馏

整体说明:

人类认知系统中的两种推理系统,系统1和系统2,系统1被认为是无意识的,能够快速识别和迅速判断,也叫做快思考,系统2被认为是处理复杂问题如数学和逻辑问题,需要深思熟虑,也叫做慢思考。

在大模型中,可以将中间的流程如多次调用大模型、中间的思考tokens类比为深思熟虑的过程,这些方法如cot、RaR等等带来更好的推理效果,但与此同时,耗时问题会导致这些方法很难用于生产落地。于是很多方法都在尝试将系统2的效果蒸馏到系统1当中(毕竟自2023年 gpt4出来后,应该有非常多的黑盒蒸馏gpt4数据训练到各家系统中;还有很久之前llama2的ghost attention:在每一句中都加入system prompt让 Llama 2 有效地遵循多轮指令,都是一些蒸馏的有效形式)。

这篇论文的主要与之前差异点在于,显式的提出System2的推理能力蒸馏到System1中,并做了很多实验进行验证。可以理解为论文提供了非常好的一种数据合成的范式,通过使用这些数据进行指令微调等方法,提升System1的推理能力。

以下几个公式是对System1和System2的形式化说明:

也就是说,通过上述公式3可以得到的大量训练数据,但是实际会存在质量问题。论文主要通过一致性标准进行过滤。

  • 输出一致性:输入不变,对输出进行N次采样,通过投票实现,少数服从多数

  • 输入扰动下的一致性:输出不变,对输入增加扰动,比如选择题改顺序但答案没变化,不一致则过滤

但猜测实际可能有更多更精细化的方式实现。

然后就是四种方式在不同数据集上的效果,我觉得给出Prompt可能是最好的方法体现形式

RePhrase And Respond Distillation

Prompt:

"{question}"\nRephrase and expand the question, and respond.

让模型先改写,改写可能提供更丰富的文本信息,然后再回答,能让大模型用自己的知识体现理解问题,回答问题。

System2 Attention Distillation

让大模型过滤无效信息,去除有偏信息和不相干上下文,然后再改写基础上进行回答

Branch-Solve-Merge Distillation

Chain Of Thought Distillation

论文通过这四种System2的方式,蒸馏到System1当中,做了很多实验,结果就不一一贴了,都是差的也不会发paper,总结下来,整体是有效的,如RaR蒸馏可用于澄清任务指令相关任务、S2A能有效提升有偏任务,Branch-Solve-Merge蒸馏能作为LLM-Judge评估任务,但是在复杂推理任务上的蒸馏,目前还做得不好。这可能也是一个共识,需要持续研究。

总结,不管是黑盒蒸馏,还是白盒蒸馏,都是现如今非常好的将更大模型的知识注入到较小模型中去的方式,不断提升小模型的知识密度,这样可以再更多的落地场景中应用。期待这个方向后续更多的工作。


53AI,企业落地大模型首选服务商

产品:场景落地咨询+大模型应用平台+行业解决方案

承诺:免费POC验证,效果达标后再合作。零风险落地应用大模型,已交付160+中大型企业

联系我们

售前咨询
186 6662 7370
预约演示
185 8882 0121

微信扫码

添加专属顾问

回到顶部

加载中...

扫码咨询

扫码登录
登录即表示您同意《53AI网站服务协议》
服务协议

欢迎您使用【53AI 官方网站】(以下简称“本网站”或“我们”)。本《会员服务协议》(以下简称“本协议”)是您(以下简称“会员”或“用户”)与【深圳市博思协创网络科技有限公司】之间关于注册、登录及使用本网站会员服务所订立的法律协议。

在您注册或登录前,请务必审慎阅读、充分理解各条款内容,特别是免除或限制责任的条款、知识产权条款、争议解决条款等。此类条款将以加粗形式提示您注意。 当您通过微信公众号授权、手机验证码验证或其他方式成功登录本网站时,即视为您已完全理解并同意接受本协议的全部内容。

一、 定义

本网站:指由【深圳市博思协创网络科技有限公司】运营的,域名为【53ai.com】的网站及相关移动端页面。

会员服务:指本网站向注册会员提供的知识库文章查阅、内容检索及其他相关增值服务。

知识库内容:指本网站发布的包括但不限于文字、图表、数据、研究报告、行业分析等数字化内容资源。

二、 账号注册与登录

登录方式:本网站支持以下登录方式,您可根据实际情况选择:

微信公众号授权登录:您同意将您的微信OpenID信息授权给本网站,用于创建或关联会员账号。

手机验证码登录:您需提供真实有效的手机号码,并通过短信验证码完成身份验证与登录/注册。

账号安全:您的账号仅限您本人使用,禁止赠与、借用、租用、转让或售卖。因您保管不善导致的账号被盗、密码泄露等损失,由您自行承担。

实名认证:根据相关法律法规要求,我们可能要求您在特定功能下完成实名认证。如您拒绝提供,可能无法使用部分或全部服务。

未成年人保护:若您未满18周岁,请在法定监护人的陪同下阅读本协议,并在征得监护人同意后使用本服务。

三、 服务内容与规范

知识库查阅权限:会员登录后,有权按照其会员等级对应的权限范围,在线浏览、检索本网站知识库中的相关文章及内容。

服务变更:我们有权根据业务发展需要,调整、变更或终止部分服务内容,并将以网站公告、公众号消息等方式提前通知。

禁止行为:您在使用服务时不得实施以下行为:

利用技术手段批量爬取、下载、转存知识库内容;

将知识库内容用于商业目的或未经授权地向第三方传播;

干扰本网站正常运行或侵犯其他用户合法权益;

发布违法违规信息或从事违反公序良俗的活动。

四、 知识产权声明

权利归属:本网站知识库中的排版设计、软件代码等内容的知识产权均归【公司全称】或原权利人所有,受《中华人民共和国著作权法》等法律保护。

有限许可:本网站授予会员一项非独占、不可转让、不可转授权的普通许可,仅限于个人学习、研究之目的在线查阅知识库内容。

侵权追责:未经书面许可,任何单位或个人不得以任何形式复制、转载、摘编、镜像、汇编或以其他方式使用上述内容。一经发现,我们保留追究其法律责任的权利。

五、 个人信息保护

我们重视对您个人信息的保护。关于我们如何收集、使用、存储和保护您的个人信息,请单独阅读 《隐私政策》。

您通过微信公众号授权或手机号验证所提供的信息,我们将严格按照《个人信息保护法》的规定处理,仅用于身份识别、服务提供及安全验证等必要用途。

您可以随时通过网站设置或联系客服行使查阅、更正、删除个人信息及撤回授权同意的权利。

六、 免责声明

内容准确性:知识库内容仅供参考,不构成专业建议。我们不对其完整性、准确性、时效性作任何明示或暗示的保证,您应自行判断并承担使用风险。

不可抗力:因自然灾害、政策法规变化、网络故障、第三方平台接口异常(如微信接口维护、运营商短信通道故障)等不可抗力导致的服务中断或延迟,我们不承担违约责任。

第三方链接:本网站可能包含指向第三方网站的链接,该等网站的内容和服务不受我们控制,请您自行甄别风险。

七、 违约责任

如您违反本协议约定,我们有权视情节采取警告、限制功能、暂停服务、注销账号等措施,并保留要求赔偿损失的权利。

如因您的违约行为导致我们遭受行政处罚、第三方索赔或商誉损失,您应承担全部赔偿责任(包括但不限于罚款、赔偿金、律师费、公证费等)。

八、 法律适用与争议解决

本协议的订立、执行和解释均适用中华人民共和国大陆地区法律。

因本协议产生的或与本协议有关的任何争议,双方应友好协商解决;协商不成的,任何一方均可向【公司所在地】有管辖权的人民法院提起诉讼。

九、 其他

本协议构成双方就本服务达成的完整协议,取代此前任何口头或书面约定。

本协议任一条款被认定为无效或不可执行的,不影响其他条款的效力。

我们对本协议享有最终解释权,并在法律允许的范围内保留随时修改的权利。修改后的协议一经公布即生效,继续使用服务即视为同意修订内容。


已查阅