2026年7月9日 周四晚上19:30,报名腾讯会议了解“如何构建自进化的动态知识库(Brain)”(限30人)
免费POC, 零成本试错
FDE知识库

FDE知识库

学习大模型的前沿技术与行业落地应用


收藏

前女友面试官:大模型内存占用机制是怎样的?

发布日期:2024-08-30 06:23:25 浏览次数: 5420
作者:丁师兄大模型

微信搜一搜,关注“丁师兄大模型”






大模型时代能够充分利用 GPU 的显存是一项非常有必要的技能。本文将在仅考虑单卡的情况下为大家讲明白大模型的内存占用机制,相信对大家后续训练、使用大模型都非常有帮助。

下面将围绕以下三个问题展开:
  • 告诉你一个模型的参数量,你要怎么估算出训练和推理时的显存占用?
  • Lora相比于全参训练节省的显存是哪一部分?Qlora相比Lora呢?
  • 混合精度训练的具体流程是怎么样的?

这是我曾在面试中被问到的问题,为了巩固相关的知识,打算系统的写一篇文章,帮助自己复习备战秋招的同时,希望也能帮到各位小伙伴。

这篇文章将围绕大模型在单卡训练或推理时的显存占用进行系统学习分析,其中有的知识点可能不会涉及太过深入点到为止(因为我也不会),但尽量保证整个读下来逻辑通畅,通俗易懂(只有小白最懂小白!)。

01
数据精度
想要计算显存,从“原子”层面来看,就需要知道我们的使用数据的精度,因为精度代表了数据存储的方式,决定了一个数据占多少 bit。

我们都知道:

  • 1 byte = 8 bits
  • 1 KB = 1,024 bytes
  • 1 MB = 1,024 KB
  • 1 GB = 1,024 MB

由此可以明白,一个含有 1G 参数的模型,如果每一个参数都是 32bit(4byte),那么直接加载模型就会占用 4x1G 的显存。

(1)常见的几种精度类型

个人认为只需掌握下图几个常见的数据类型就好,对于更多的精度类型都是可以做到触类旁通发,图源英伟达安培架构白皮书:

各种精度的数据结构

可以非常直观地看到,浮点数主要是由符号位(sign)、指数位(exponent)和小数位(mantissa)三部分组成。

符号位都是 1 位(0 表示正,1 表示负),指数位影响浮点数范围,小数位影响精度。

其中 TF32 并不是有 32bit,只有 19bit 不要记错了。BF16 指的是 Brain Float 16,由 Google Brain 团队提出。

(2)具体计算例子

我说实话,讲太多不如一个形象的图片或者例子来得直接,下面我们将通过一个例子来深入理解如何通过这三个部分来得到我们最终的数据。

我以 BF16,如今业界用的最广泛的精度类型来举个栗子,下面的数完全是我用克劳德大哥随机画的:

题目:
随机生成的 BF16 精度数据

先给出具体计算公式:

然后 step by step 地分析(不是,怎么还对自己使用上 Cot 了)。

符号位 Sign = 1,代表是负数,指数位 Exponent = 17,中间一坨是:,小数位 Mantissa = 3,后面那一坨是:

最终结果:三个部分乘起来就是最终结果:-8.004646331359449e-34。

注意事项:中间唯一需要注意的地方就是指数位是的全 0 和全 1 状态是特殊情况,不能用公式。

02
全参训练和推理的显存分析
OK 了,我们知道了数据精度对应存储的方式和大小,相当于我们了解了工厂里不同规格的机器零件。
但我们还需要了解整个生产线的运作流程,我们才能准确估算出整个工厂(也就是我们的模型训练过程)在运行时所需的资源(显存)。

那么就以目前最常见的混合精度训练方法作为参考,来看一看显存都去哪了。

(1)混合精度训练

原理介绍

顾名思义,混合精度训练就是将多种不同的精度数据混合在一起训练,《 MIXED PRECISION TRAINING 》这篇论文里将 FP16 和 FP32 混合,优化器用的是 Adam,如下图所示:

MIXED PRECISION TRAINING 论文里的训练流程图

按照训练运行的逻辑来讲:

  • Step1:优化器会先备份一份 FP32 精度的模型权重,初始化好 FP32 精度的一阶和二阶动量(用于更新权重)。
  • Step2:开辟一块新的存储空间,将 FP32 精度的模型权重转换为 FP16 精度的模型权重。
  • Step3:运行 forward 和 backward,产生的梯度和激活值都用 FP16 精度存储。
  • Step4:优化器利用 FP16 的梯度和 FP32 精度的一阶和二阶动量去更新备份的 FP32 的模型权重。
  • Step5:重复 Step2 到 Step4 训练,直到模型收敛。

我们可以看到训练过程中显存主要被用在四个模块上:

  • 模型权重本身(FP32+FP16)
  • 梯度(FP16)
  • 优化器(FP32)
  • 激活值(FP16)

三个小问题

写到这里,我就有 3 个小问题,第一个问题,为什么不全都用 FP16,那不是计算更快、内存更少?

根据我们第一章的知识,我们可以知道 FP16 精度的范围比 FP32 窄了很多,这就会产生数据溢出和舍入误差两个问题,这会导致梯度消失无法训练,所以我们不能全都用 FP16,还需要 FP32 来进行精度保证。

看到这里你也许会想到可以用 BF16 代替,是的,这也是为什么如今很多训练都是 BF16 的原因,至少 BF16 不会产生数据溢出了,业界的实际使用也反馈出比起精度,大模型更在意范围。

第二个问题,为什么我们只对激活值和梯度进行了半精度优化,却新添加了一个 FP32 精度的模型副本,这样子显存不会更大吗?

答案是不会,激活值和 batch_size 以及 seq_length 相关,实际训练的时候激活值对显存的占用会很大,对于激活值的正向优化大于备份模型参数的负向优化,最终的显存是减少的。

第三个问题,我们知道显存和内存一样,有静态和动态之分别,那么上面提到的哪些是静态哪些是动态呢?

应该很多人都能猜到:

  • 静态:优化器状态、模型参数
  • 动态:激活值、梯度值

也就是说,我们其实没法特别准确的计算出我们实际运行时候的显存大小,如果在面试的时候,就可以忽略掉激活值的计算,梯度当做静态计算就好。

动态监控显存图

来个测试吧!

写到这里,我们应该对于分析大模型训练时候的显存问题应该不在话下了(除了动态部分),那么我们就来实测一下,正在阅读的小伙伴也可以先自己尝试计算一下,看看是不是真的懂了。

对于 llama3.1 8B 模型,FP32 和 BF16 混合精度训练,用的是 AdamW 优化器,请问模型训练时占用显存大概为多少?

解:

  • 模型参数:16(BF16) + 32(PF32)= 48G
  • 梯度参数:16(BF16)= 16G
  • 优化器参数:32(PF32) + 32(PF32)= 64G
  • 不考虑激活值的情况下,总显存大约占用 (48 + 16 + 64) = 128G

(2)推理与 KV Cache

原理理解

推理的时候,显存几乎只考虑模型参数本身,除此之外就是现在广泛使用的 KV cache 也会占用显存。

KV cache 与之前讲的如何减少显存不一样,KV cache 的目的是减少延迟,也就是为了推理的速度牺牲显存。

具体 KV cache 是什么我就不展开讲了,我贴一张动图就可以非常清晰地明白了。

记住一点,我们推理就是在不断重复地做”生成下一个 token“的任务,生成当前 token 仅仅与当前的 QKV 和之前所有 KV 有关,那么我们就可以去维护这个 KV 并不断更新。

KV Cache 动态实现

顺便回答一个很多小白经常会问的问题,为什么没有Q Cache呢?

因为生成当前的 token 只依赖当前的 Q,那为什么生成当前的 token 只依赖当前的 Q 呢,因为 Self-Attention 的公式决定的。

S 代表 Softmax 激活函数:

我们可以看到,在序列 t 的位置,也就是第 t 行,只跟 有关系,也就是说,Attention 的计算公式就决定了我们不需要保存每一步的 Q,再深入地说,矩阵乘法的数学特性决定了我们不需要保存每一步的 Q。

计算 KV Cache 显存

如何计算 KV Cache 的显存是我这篇文章想要关心的事情,先给出公式:

前面的四个参数相乘应该很好理解,就是 KV 对应在模型每一层的所有隐藏向量的总和,第一个 2 指的是 KV 两部分,第二个 2 指的是半精度对应的字节数。

举个栗子,对于 llama7B,hiddensize = 4096,seqlength = 2048 , batchsize = 64,layers = 32 计算得到:

可以看到,KV Cache 在大批量长句子的情况下,显存占用率也是很大的。

68G 看着是相对模型本身很大,但这是在 batch 很大的情况下,在单 batch 下,KV Cache 就仅占有 1G 左右的显存了,就仅仅占用模型参数一半的显存。

MQA 和 GQA

什么,你觉得 KV Cache 用的显存还是太多了,不错,对于推理落地侧,再怎么严苛要求也是合理的,MQA 和 GQA 就是被用来进一步减少显存的方法,现在的大模型也几乎都用到了这个方法,我们就来讲一讲。

三种 KV 处理方式

其实方法不难理解,看这张图一目了然,关键词就是“共享多头 KV”,很朴素的删除模型冗余结构的思路。

最左侧就是最基础的 MHA 多头自注意力,中间的 GQA 就是保留几组 KV 头,右侧 MQA 就是只保留 1 组 KV 头,目前用的比较多的是 GQA,降低显存提速的同时也不会太过于影响性能。

上一小节我们知道 MHA 的 KV Cache 占用显存的计算公式是:

有一个小细节,可以重头开始训练 MQA 和 GQA 的模型,也可以像 GQA 论文里面一样基于开源模型,修改模型结构后继续预训练。目前基本上都是从头开始训练的,因为要保持训练和推理的模型结构一致。

03
Lora和Qlora显存分析
上面两章详细对全参微调训练和推理进行了显存分析,聪明的小伙伴就发现了一个问题,现在都用 PEFT(高效参数微调)了,谁有那么多资源全参训练啊推理阶段也是要量化的,这样又该怎么进行显存分析呢。
那么我们这一章就来解决这个问题,我相信完全理解前两章的小伙伴理解起来会非常轻松,所谓的显存分析,只要知道了具体的流程和数据精度,那么分析的方法都是类似的。
OK,我们将会在这一章里详细分析目前前业界最火的 Lora 和 Qlora 方法的显存占用情况,中间也会涉及到相关的原理知识,冲!

(1)Lora

能看到这里的人,我想对于 Lora 的原理应该都很了解了,就浅浅提一下,如下图所示,就是在原来的权重矩阵的旁路新建一对低秩的可训练权重,训练的时候只训练旁路,大大降低了训练的权重数量,参数量 d*d 降为 2*d*r。

Lora 原理图

有了前面的全参情况下训练的显存分析,现在分析起来就比较通顺了,我们一步一步来,还是以 BF16 半精度模型 Adamw 优化器训练为例子,lora 部分的参数精度也是 BF16,并且设 1 字节模型参数对应的显存大小 φ。

首先是模型权重本身的权重,这个肯定是要加载原始模型和 lora 旁路模型的,因为 lora 部分占比小于 2 个数量级,所以显存分析的时候忽略不计,显存占用 2φ。

然后就是优化器部分,优化器也不需要对原模型进行备份了,因为优化器是针对于需要更新参数的模型权重部分进行处理,也就是说优化器只包含 Lora 模型权重相关的内容,考虑到数量级太小,也忽略不计,故优化器部分占用显存 0φ。

其实容易搞错混淆的部分就是梯度的显存了,我看了不少的博客文章,有说原始模型也要参与反向传播,所以是要占用一份梯度显存的,也有的说原始模型都不更新梯度,肯定只需要 Lora 部分的梯度显存,搞得我头很大。

那么究竟正确答案是哪一种呢,这里直接给出答案,不需要计算原始模型部分的梯度,也基本不占用显存。也就是说梯度部分占用显存也可以近似为 0φ。

总的来说,不考虑激活值的情况下,Lora 微调训练的显存占用只有 2φ,一个 7B 的模型 Lora 训练只需要占用显存大约 14G 左右。

验证一下,我们来看 Llama Factory 里给出训练任务的显存预估表格:

Llama Factory 的表格

可以看到 7B 模型的 Lora 训练的显存消耗与我们估计得也差不多,同时也还可以复习一下全参训练、混合精度训练的显存分析,也是基本符合我们之前的分析的。

(2)QLora

上面 Llama Factory 的那张表也是稍微剧透了一下我们接下来要讲的内容,也就是 QLora,继 Lora 之后也是在业界落地非常广泛通用的一种大模型 PEFT 方法。

QLora,也叫做量化 Lora,顾名思义,也就是进一步压缩模型的精度,然后用 Lora 训练,他的核心思路很好理解,但实际上涉及的知识点细节却并不少。

我同样也不会太过深入地去介绍这个中细节,我主要是想按照显存占用的思路去分析 Qlora,理解思路永远比死的知识点更加重要。

Qlora 的整体思路

Qlora 来自于《 QLORA: Efficient Finetuning of Quantized LLMs 》这篇论文,实际上这篇论文的核心在于提出了一种新的量化方法,重点在于量化而不是 Lora。

很多不了解的人看到量化 lora 这个名字就以为是对 Lora 部分的参数进行量化,因为他们认为毕竟只有 Lora 部分的参数参与了训练。

但理解了上面一节的小伙伴就明白实际并不是这样,原始模型的本身参数虽然不更新参数,但是仍然需要前向和反向传播,QLora 优化的正是 Lora 里显存占大头的模型参数本身。

那么 Qlora 就是把原始模型参数从 16bit 压缩到 4bit,然后更新这个 4bit 参数吗?

非也非也,这里需要区分两个概念,一个是计算参数,一个是存储参数,计算参数就是在前向、反向传播参与实际计算的参数,存储参数就是不参与计算一开始加载的原始参数。

QLora 的方法就是,加载并且量化 16bit 的模型原始参数为 4bit 作为存储参数,但是在具体需要计算的时候,将该部分的 4bit 参数反量化为 16bit 作为计算参数。

也就是说,QLora 实际上我们训练计算里用到的所有数据的精度都是和 Lora 一样的,只是加载的模型是 4bit,会进行一个反量化到 16bit 的方法,用完即释放。

前面说到的都是模型原始参数本身,不包括 lora 部分的参数,Lora 部分的参数不需要量化,一直都是 16bit。

看到这里机智的你应该也想到了,这比 Lora 多了一个量化反量化的操作,那训练时间是不是会更长,没错一般来讲 Qlora 训练会比 Lora 多用 30% 左右的时间。

Qlora 的技术细节

基本的思路讲完了,那么其中包含了哪些具体的实现细节呢?

Qlora 主要包括三个创新点,这里我只简单提及,应付面试足够的程度:

  • NF4 量化:常见的量化分布都是基于参数是均匀分布的假设,而这个方法基于参数是正态分布的假设,这样使得量化精度大大提升。
  • 双重量化:对于第一次量化后得到的用于计算反量化时的锚点参数,我们对这个锚点参数进行量化,可以进一步降低显存。
  • 优化器分页:为了防止 OOM,可以在 GPU 显存紧张的时候利用 CPU 内存进行加载参数。

显存分析

想必已经理解 QLora 运行思路的小伙伴,应该可以很轻松的分析出 Qlora 占用显存的部分了吧,这就是理清楚思路的好处。

没错,Qlora 占用的显存主要就是 4Bit 量化后的模型本身也就是 0.5φ,这里没有考虑少量的 Lora 部分的参数和量化计算中可能产生的显存。可以回过头去看看刚才的表格,也是基本符合预期的。

最后我们用一个表格来总结所有之前我们提到的显存分析:

来源:https://zhuanlan.zhihu.com/p/713256008



END


加入学习




 我是丁师兄,专注于智能驾驶大模型,持续分享LLM面试干货。


 大模型1v1辅导,已帮助多名同学成功上岸

53AI,企业落地大模型首选服务商

产品:场景落地咨询+大模型应用平台+行业解决方案

承诺:免费POC验证,效果达标后再合作。零风险落地应用大模型,已交付160+中大型企业

联系我们

售前咨询
186 6662 7370
预约演示
185 8882 0121

微信扫码

添加专属顾问

回到顶部

加载中...

扫码咨询

扫码登录
登录即表示您同意《53AI网站服务协议》
服务协议

欢迎您使用【53AI 官方网站】(以下简称“本网站”或“我们”)。本《会员服务协议》(以下简称“本协议”)是您(以下简称“会员”或“用户”)与【深圳市博思协创网络科技有限公司】之间关于注册、登录及使用本网站会员服务所订立的法律协议。

在您注册或登录前,请务必审慎阅读、充分理解各条款内容,特别是免除或限制责任的条款、知识产权条款、争议解决条款等。此类条款将以加粗形式提示您注意。 当您通过微信公众号授权、手机验证码验证或其他方式成功登录本网站时,即视为您已完全理解并同意接受本协议的全部内容。

一、 定义

本网站:指由【深圳市博思协创网络科技有限公司】运营的,域名为【53ai.com】的网站及相关移动端页面。

会员服务:指本网站向注册会员提供的知识库文章查阅、内容检索及其他相关增值服务。

知识库内容:指本网站发布的包括但不限于文字、图表、数据、研究报告、行业分析等数字化内容资源。

二、 账号注册与登录

登录方式:本网站支持以下登录方式,您可根据实际情况选择:

微信公众号授权登录:您同意将您的微信OpenID信息授权给本网站,用于创建或关联会员账号。

手机验证码登录:您需提供真实有效的手机号码,并通过短信验证码完成身份验证与登录/注册。

账号安全:您的账号仅限您本人使用,禁止赠与、借用、租用、转让或售卖。因您保管不善导致的账号被盗、密码泄露等损失,由您自行承担。

实名认证:根据相关法律法规要求,我们可能要求您在特定功能下完成实名认证。如您拒绝提供,可能无法使用部分或全部服务。

未成年人保护:若您未满18周岁,请在法定监护人的陪同下阅读本协议,并在征得监护人同意后使用本服务。

三、 服务内容与规范

知识库查阅权限:会员登录后,有权按照其会员等级对应的权限范围,在线浏览、检索本网站知识库中的相关文章及内容。

服务变更:我们有权根据业务发展需要,调整、变更或终止部分服务内容,并将以网站公告、公众号消息等方式提前通知。

禁止行为:您在使用服务时不得实施以下行为:

利用技术手段批量爬取、下载、转存知识库内容;

将知识库内容用于商业目的或未经授权地向第三方传播;

干扰本网站正常运行或侵犯其他用户合法权益;

发布违法违规信息或从事违反公序良俗的活动。

四、 知识产权声明

权利归属:本网站知识库中的排版设计、软件代码等内容的知识产权均归【公司全称】或原权利人所有,受《中华人民共和国著作权法》等法律保护。

有限许可:本网站授予会员一项非独占、不可转让、不可转授权的普通许可,仅限于个人学习、研究之目的在线查阅知识库内容。

侵权追责:未经书面许可,任何单位或个人不得以任何形式复制、转载、摘编、镜像、汇编或以其他方式使用上述内容。一经发现,我们保留追究其法律责任的权利。

五、 个人信息保护

我们重视对您个人信息的保护。关于我们如何收集、使用、存储和保护您的个人信息,请单独阅读 《隐私政策》。

您通过微信公众号授权或手机号验证所提供的信息,我们将严格按照《个人信息保护法》的规定处理,仅用于身份识别、服务提供及安全验证等必要用途。

您可以随时通过网站设置或联系客服行使查阅、更正、删除个人信息及撤回授权同意的权利。

六、 免责声明

内容准确性:知识库内容仅供参考,不构成专业建议。我们不对其完整性、准确性、时效性作任何明示或暗示的保证,您应自行判断并承担使用风险。

不可抗力:因自然灾害、政策法规变化、网络故障、第三方平台接口异常(如微信接口维护、运营商短信通道故障)等不可抗力导致的服务中断或延迟,我们不承担违约责任。

第三方链接:本网站可能包含指向第三方网站的链接,该等网站的内容和服务不受我们控制,请您自行甄别风险。

七、 违约责任

如您违反本协议约定,我们有权视情节采取警告、限制功能、暂停服务、注销账号等措施,并保留要求赔偿损失的权利。

如因您的违约行为导致我们遭受行政处罚、第三方索赔或商誉损失,您应承担全部赔偿责任(包括但不限于罚款、赔偿金、律师费、公证费等)。

八、 法律适用与争议解决

本协议的订立、执行和解释均适用中华人民共和国大陆地区法律。

因本协议产生的或与本协议有关的任何争议,双方应友好协商解决;协商不成的,任何一方均可向【公司所在地】有管辖权的人民法院提起诉讼。

九、 其他

本协议构成双方就本服务达成的完整协议,取代此前任何口头或书面约定。

本协议任一条款被认定为无效或不可执行的,不影响其他条款的效力。

我们对本协议享有最终解释权,并在法律允许的范围内保留随时修改的权利。修改后的协议一经公布即生效,继续使用服务即视为同意修订内容。


已查阅