当前位置: 首页 > news >正文

LLaMA: Open and Efficient Foundation Language Models 论文阅读

目录

摘要

Introduction

Approach

预训练数据

结构

RMSNorm

SwiGLU

RoPE

Optimizer

高效训练的方法

主要结果

指令微调

模型风险评估

Conclusion

LLaMA 实现全部代码


Introduction

作为背景,讨论了是否更大的模型训练更大的参数就会有更好的训练效果。

这里引入了缩放定律scaling laws,作者认为缩放定律忽略了推理的成本,相比于训练更快的模型,作者认为应该选择推理更快的模型,因此提出小的 LLM 配大数据训练更好,因为小 LLM 推理更友好。

Approach

预训练数据

LLaMa 预训练数据大约包含 1.4T tokens,对于绝大部分的训练数据,训练期间只使用一次。

下图展示了 LLaMa 预训练数据的占比:

结构

LLaMA 为 decoder-only 结构,和之前其他模型相比最大的3个改进:

  • 对每个Transformer子层的输入使用 RMSNorm 归一化函数进行归一化,而不是对输出进行归一化。
  • 用 SwiGLU 激活函数替换 ReLU 非线性,以提高性能。
  • 删除了绝对位置嵌入,使用 RoPE 旋转位置嵌入。

RMSNorm

可以增强训练稳定性。

代码实现:

class RMSNorm(torch.nn.Module):def __init__(self, dim: int, eps: float = 1e-6):super().__init__()self.eps = epsself.weight = nn.Parameter(torch.ones(dim))def _norm(self, x):return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)def forward(self, x):output = self._norm(x.float()).type_as(x)return output * self.weight

SwiGLU

可以提高模型性能。

LLaMa 使用 SwiGLU 激活函数替换 ReLU 非线性以提高性能,SwiGLU 激活函数结合了 Swish 激活函数和 GLU(Gated Linear Unit)的机制。

RoPE

可以更好地建模长序列数据。

来源于苏剑林大神,不直接将位置信息作为向量与词向量相加,而是在注意力机制的 Query(查询)和 Key(键)计算时,将输入向量视为复数,在复数平面上进行旋转,通过旋转操作将位置信息融入输入嵌入中。

优点:
旋转操作可以通过复数操作简化,计算复杂度低
可以捕捉相对位置关系
对长序列友好

LLaMA中 RoPE 的实现代码:

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):  freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  t = torch.arange(end, device=freqs.device)  # type: ignore  freqs = torch.outer(t, freqs).float()  # type: ignore  freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  return freqs_cis  def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):  ndim = x.ndim  assert 0 <= 1 < ndim  assert freqs_cis.shape == (x.shape[1], x.shape[-1])  shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]  return freqs_cis.view(*shape)  def apply_rotary_emb(  xq: torch.Tensor,  xk: torch.Tensor,  freqs_cis: torch.Tensor,  
) -> Tuple[torch.Tensor, torch.Tensor]:  xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  freqs_cis = reshape_for_broadcast(freqs_cis, xq_)  xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)  xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)  return xq_out.type_as(xq), xk_out.type_as(xk)

Optimizer

使用 AdamW 优化器,超参数如下:β1 = 0.9,β2 = 0.95,最终学习率等于最大学习率的10 %。

如图为7B、13B、33B 和 65B tokens 的模型训练损失:

高效训练的方法

  1. 内存与速度优化:用 xformers 库实现高效因果多头注意力,手动实现 Transformer 层反向传播,选择性保存高计算成本激活值以减少重计算。
  2. 分布式策略:模型并行 + 序列并行减少内存占用;重叠激活计算与 GPU 间通信。
  3. 训练效率:65B 模型用 2048 张 A100(80GB)GPU,1.4T tokens 训练耗时约 21 天。

主要结果

该模块给出 LLaMA 在常识推理、闭卷问答、阅读理解6项具体任务中,其表现优于或匹敌主流模型。

指令微调

该部分指出小样本微调可以进一步提高了模型对指令的follow能力。下图比较了大小适中的模型在MMLU上有指令微调和无指令微调的情况:

模型风险评估

该部分分为Bias, Toxicity and Misinformation,即毒性、偏见、真实性三方面展开,核心结论如下:

  1. 毒性:用 RealToxicityPrompts 评估,模型规模越大毒性越强,如 65B 模型 “基础提示”“尊重提示” 毒性得分(0.128、0.141)高于 7B 模型(0.106、0.081),与 OPT 等开源模型趋势一致;
  2. 偏见:通过 CrowS-Pairs 和 WinoGender 评估,65B 模型平均偏见得分(66.6%)略低于 GPT-3、OPT-175B,但宗教领域偏见突出(79.0%,超 OPT 10%);且对 “their/them” 代词共指分辨率高于 “her/she”“his/he”,“gotcha” 案例错误率高,存在性别偏见;
  3. 真实性:用 TruthfulQA 评估,65B 模型 “真实回答”“真实且有用” 占比(57%、53%)高于 GPT-3(28%、25%),但仍有幻觉风险。

Conclusion

强调仅用公开可用数据训练大模型就能达到最先进性能,无需依赖专有数据集,这也是本文的标题所在:开源高效的大语言模型。

LLaMA 实现全部代码

# Copyright (c) Meta Platforms, Inc. and affiliates.  
# This software may be used and distributed according to the terms of the GNU General Public License version 3.  from typing import Optional, Tuple  
from dataclasses import dataclass  
import math  import torch  
from torch import nn  
import torch.nn.functional as F  import fairscale.nn.model_parallel.initialize as fs_init  
from fairscale.nn.model_parallel.layers import (  ParallelEmbedding,  RowParallelLinear,  ColumnParallelLinear,  
)  @dataclass  
class ModelArgs:  dim: int = 512  n_layers: int = 8  n_heads: int = 8  vocab_size: int = -1  # defined later by tokenizer  multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2  norm_eps: float = 1e-5  max_batch_size: int = 32  max_seq_len: int = 2048  class RMSNorm(torch.nn.Module):  def __init__(self, dim: int, eps: float = 1e-6):  super().__init__()  self.eps = eps  self.weight = nn.Parameter(torch.ones(dim))  def _norm(self, x):  return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)  def forward(self, x):  output = self._norm(x.float()).type_as(x)  return output * self.weight  def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):  freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  t = torch.arange(end, device=freqs.device)  # type: ignore  freqs = torch.outer(t, freqs).float()  # type: ignore  freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  return freqs_cis  def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):  ndim = x.ndim  assert 0 <= 1 < ndim  assert freqs_cis.shape == (x.shape[1], x.shape[-1])  shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]  return freqs_cis.view(*shape)  def apply_rotary_emb(  xq: torch.Tensor,  xk: torch.Tensor,  freqs_cis: torch.Tensor,  
) -> Tuple[torch.Tensor, torch.Tensor]:  xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  freqs_cis = reshape_for_broadcast(freqs_cis, xq_)  xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)  xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)  return xq_out.type_as(xq), xk_out.type_as(xk)  class Attention(nn.Module):  def __init__(self, args: ModelArgs):  super().__init__()  self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()  self.head_dim = args.dim // args.n_heads  self.wq = ColumnParallelLinear(  args.dim,  args.n_heads * self.head_dim,  bias=False,  gather_output=False,  init_method=lambda x: x,  )  self.wk = ColumnParallelLinear(  args.dim,  args.n_heads * self.head_dim,  bias=False,  gather_output=False,  init_method=lambda x: x,  )  self.wv = ColumnParallelLinear(  args.dim,  args.n_heads * self.head_dim,  bias=False,  gather_output=False,  init_method=lambda x: x,  )  self.wo = RowParallelLinear(  args.n_heads * self.head_dim,  args.dim,  bias=False,  input_is_parallel=True,  init_method=lambda x: x,  )  self.cache_k = torch.zeros(  (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)  ).cuda()  self.cache_v = torch.zeros(  (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)  ).cuda()  def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):  bsz, seqlen, _ = x.shape  xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)  xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)  xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)  xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)  xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)  self.cache_k = self.cache_k.to(xq)  self.cache_v = self.cache_v.to(xq)  self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk  self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv  keys = self.cache_k[:bsz, : start_pos + seqlen]  values = self.cache_v[:bsz, : start_pos + seqlen]  xq = xq.transpose(1, 2)  keys = keys.transpose(1, 2)  values = values.transpose(1, 2)  scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)  if mask is not None:  scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)  scores = F.softmax(scores.float(), dim=-1).type_as(xq)  output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)  output = output.transpose(  1, 2  ).contiguous().view(bsz, seqlen, -1)  return self.wo(output)  class FeedForward(nn.Module):  def __init__(  self,  dim: int,  hidden_dim: int,  multiple_of: int,  ):  super().__init__()  hidden_dim = int(2 * hidden_dim / 3)  hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)  self.w1 = ColumnParallelLinear(  dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x  )  self.w2 = RowParallelLinear(  hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x  )  self.w3 = ColumnParallelLinear(  dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x  )  def forward(self, x):  return self.w2(F.silu(self.w1(x)) * self.w3(x))  class TransformerBlock(nn.Module):  def __init__(self, layer_id: int, args: ModelArgs):  super().__init__()  self.n_heads = args.n_heads  self.dim = args.dim  self.head_dim = args.dim // args.n_heads  self.attention = Attention(args)  self.feed_forward = FeedForward(  dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of  )  self.layer_id = layer_id  self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)  self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)  def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):  h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)  out = h + self.feed_forward.forward(self.ffn_norm(h))  return out  class Transformer(nn.Module):  def __init__(self, params: ModelArgs):  super().__init__()  self.params = params  self.vocab_size = params.vocab_size  self.n_layers = params.n_layers  self.tok_embeddings = ParallelEmbedding(  params.vocab_size, params.dim, init_method=lambda x: x  )  self.layers = torch.nn.ModuleList()  for layer_id in range(params.n_layers):  self.layers.append(TransformerBlock(layer_id, params))  self.norm = RMSNorm(params.dim, eps=params.norm_eps)  self.output = ColumnParallelLinear(  params.dim, params.vocab_size, bias=False, init_method=lambda x: x  )  self.freqs_cis = precompute_freqs_cis(  self.params.dim // self.params.n_heads, self.params.max_seq_len * 2  )  @torch.inference_mode()  def forward(self, tokens: torch.Tensor, start_pos: int):  _bsz, seqlen = tokens.shape  h = self.tok_embeddings(tokens)  self.freqs_cis = self.freqs_cis.to(h.device)  freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]  mask = None  if seqlen > 1:  mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)  mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)  for layer in self.layers:  h = layer(h, start_pos, freqs_cis, mask)  h = self.norm(h)  output = self.output(h[:, -1, :])  # only compute last logits  return output.float()

http://www.dtcms.com/a/464768.html

相关文章:

  • LeetCode——Hot 100【全排列】
  • 云南大理拍婚纱照价格表建网站优化
  • 双目测距实战1-环境配置
  • 2025人工智能在无人机数据处理中的应用
  • 阿里开源Qwen3-Omni-30B-A3B三剑客——Instruct、Thinking 和 Captioner
  • 长春建站程序湖南网络科技有限公司
  • xtuoj Can you raed it croretcly?
  • 异构动作空间
  • 【Nginx开荒攻略】Nginx虚拟主机配置:从域名、端口到IP的完整指南
  • 小杰深度学习(nine)——CUDA与CuDNN安装
  • 鸿蒙NEXT USB Host模式开发完全指南
  • MinerU2.5 windows 本地部署
  • UIkit中使用新版UICollectionViewCompositionalLayout进行复杂布局(二)
  • 网站建设的技术问题苏州吴江建设局招标网站
  • 河南省村镇建设处网站网站配色与布局 教材
  • Prometheus运维之路(ES监控接入)
  • OpenAMP专题(一):一文了解OpenAMP全貌
  • C++ 中 rfind 方法详解
  • SpringBoot 教程(十四) SpringBoot之集成 Redis(优化版)
  • 【Linux】线程同步与互斥(上)
  • 图观 模型编辑器
  • Win11 输入延迟与鼠标卡顿:系统化排查与优化指南
  • 【开题答辩全过程】以 爱运动健身小程序的设计与实现为例,包含答辩的问题和答案
  • Linux 内核IIO sensor驱动
  • 《Linux系统编程之入门基础》【Linux的前世今生】
  • 活动汪活动策划网站龙岗建设网站
  • Apache IoTDB 架构特性与 Prometheus+Grafana 监控体系部署实践
  • LLM时代基于unstructured解析非结构化pdf
  • uniapp tab切换及tab锚点效果(wx小程序及H5端)
  • Hadoop面试题及详细答案 110题 (71-85)-- 集群部署与运维