如何在24GB的GPU上运行DeepSeek-R1-Distill-Qwen-32B
如何在24GB的GPU上运行DeepSeek-R1-Distill-Qwen-32B
- 一、背景
- 二、解决方案
- 三、操作步骤
- 1.下载模型
- 2.安装依赖
- 3.量化
- 4.生成推理代码
- 5.运行
- A.缓存上限为128条
- B.不限制缓存上限
- C.输出内容
一、背景
随着深度学习的不断发展,大型语言模型(LLM,Large Language Model)在自然语言处理领域展现出了强大的能力。然而,伴随着模型参数规模的指数级增长,运行这些模型所需的计算资源也变得异常庞大,尤其是对显存(GPU内存)的需求。因此,如何在有限的GPU显存下有效地运行超大规模的LLM,成为了一个亟待解决的挑战。
本文验证在GPU显存受限的情况下,如何高效地运行超出GPU内存容量的LLM模型。通过对模型权重的量化和内存管理策略的优化,期望能够突破硬件瓶颈,为大型模型的部署和应用提供新的思路。
二、解决方案
下面的方案,主要包括权重量化、内存缓存机制以及自定义Linear的设计。具体方案如下:
-
权重的INT4块量化
- 量化策略:将模型的权重参数进行INT4(4位整数)块量化处理,量化的块大小设定为128。这种量化方式能够大幅度减少模型权重所占用的存储空间。
- 内存优势:经过INT4量化后的权重占用空间显著降低,使得所有权重可以加载到主机(HOST)内存中。这不仅缓解了GPU显存的压力,还为后续的高效读取奠定了基础。
-
减少磁盘I/O操作
- 全量加载:将所有量化后的INT4权重一次性加载到HOST内存中,避免了在模型运行过程中频繁进行磁盘读写操作。这种方式有效减少了磁盘I/O带来的时间开销和性能瓶颈。
-
设备内存缓存机制
- 缓存设计:在GPU设备内存中建立一个缓存机制,设定最大缓存条目数为N。N的取值与具体的GPU配置相关,目的是充分利用可用的设备内存,最大化其占用率,提升数据读取效率。
- 动态管理:缓存机制需要智能地管理内存的分配和释放,确保在不超过设备内存上限的情况下,高效地存取所需的数据。
-
权重预加载线程
- 职责分离:引入一个专门的权重预加载线程,负责将HOST内存中的INT4权重进行反量化处理(即将INT4还原为计算所需的格式),并将处理后的权重加载到GPU设备内存的缓存中。
- 效率优化:通过预加载线程的异步处理,提升了数据准备的效率,确保模型在需要数据时可以及时获取,最大程度减少等待时间。
-
自定义Linear模块
- 模块替换:将原有的
nn.Linear
层替换为自定义的Module。在模型构建和加载过程中,使用该自定义模块来承载线性计算任务。 - 运行机制:自定义的Module在前向传播(forward)过程中,从设备内存的缓存中获取所需的权重进行计算。计算完成后,立即释放权重占用的设备内存,以供后续的计算任务使用。
- 优势:这种动态加载和释放的机制,避免了在整个计算过程中权重长时间占用设备内存,极大地提高了内存的利用效率。
- 模块替换:将原有的
三、操作步骤
1.下载模型
# 模型介绍: https://www.modelscope.cn/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B
# 下载模型
apt install git-lfs -y
git clone https://www.modelscope.cn/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B.git
2.安装依赖
MAX_JOBS=4 pip install flash-attn==2.3.6
pip install torch-tb-profiler
3.量化
cat > extract_weights.py << EOF
import torch
import os
from tqdm import tqdm
from glob import glob
import torch
import sys
from safetensors.torch import safe_open, save_file
def quantize_tensor_int4(tensor):
"""
将bfloat16的Tensor按照块大小128进行量化为int4,并返回每个块的scale。
参数:
tensor (torch.Tensor): bfloat16类型的输入Tensor。
返回:
int4_tensor (torch.Tensor): 量化后的uint8类型的Tensor,存储int4值,每个元素包含两个int4值。
scales (torch.Tensor): 每个块对应的bfloat16类型的scale值。
"""
# 确保输入Tensor为bfloat16类型
tensor = tensor.to(torch.bfloat16)
# 将Tensor展平为一维
flat_tensor = tensor.flatten()
N = flat_tensor.numel()
block_size = 128
num_blocks = (N + block_size - 1) // block_size # 计算块的数量
# 计算每个元素的块索引
indices = torch.arange(N, device=flat_tensor.device)
block_indices = indices // block_size # shape: [N]
# 计算每个块的x_max
abs_tensor = flat_tensor.abs()
zeros_needed = num_blocks * block_size - N
# 对张量进行填充,使其长度为num_blocks * block_size
if zeros_needed > 0:
padded_abs_tensor = torch.cat([abs_tensor, torch.zeros(zeros_needed, device=abs_tensor.device, dtype=abs_tensor.dtype)])
else:
padded_abs_tensor = abs_tensor
reshaped_abs_tensor = padded_abs_tensor.view(num_blocks, block_size)
x_max = reshaped_abs_tensor.max(dim=1).values # shape: [num_blocks]
# 处理x_max为0的情况,避免除以0
x_max_nonzero = x_max.clone()
x_max_nonzero[x_max_nonzero == 0] = 1.0 # 防止除以0
# 计算scale
scales = x_max_nonzero / 7.0 # shape: [num_blocks]
scales = scales.to(torch.bfloat16)
# 量化
scales_expanded = scales[block_indices] # shape: [N]
q = torch.round(flat_tensor / scales_expanded).clamp(-8, 7).to(torch.int8)
# 将有符号int4转换为无符号表示
q_unsigned = q & 0x0F # 将范围[-8,7]映射到[0,15]
# 如果元素数量是奇数,补充一个零
if N % 2 != 0:
q_unsigned = torch.cat([q_unsigned, torch.zeros(1, dtype=torch.int8, device=q.device)])
# 打包两个int4到一个uint8
q_pairs = q_unsigned.view(-1, 2)
int4_tensor = (q_pairs[:, 0].to(torch.uint8) << 4) | q_pairs[:, 1].to(torch.uint8)
return int4_tensor, scales
torch.set_default_device("cuda")
if len(sys.argv)!=3:
print(f"{sys.argv[0]} input_model_dir output_dir")
else:
input_model_dir=sys.argv[1]
output_dir=sys.argv[2]
if not os.path.exists(output_dir):
os.makedirs(output_dir)
state_dicts = {}
for file_path in tqdm(glob(os.path.join(input_model_dir, "*.safetensors"))):
with safe_open(file_path, framework="pt", device="cuda") as f:
for name in f.keys():
param: torch.Tensor = f.get_tensor(name)
#print(name,param.shape,param.dtype)
if "norm" in name or "embed" in name:
state_dicts[name] = param
else:
if "weight" in name:
int4_tensor, scales=quantize_tensor_int4(param)
state_dict={}
state_dict["w"]=int4_tensor.data
state_dict["scales"]=scales.data
state_dict["shape"]=param.shape
torch.save(state_dict, os.path.join(output_dir, f"{name}.pt"))
else:
torch.save(param.data, os.path.join(output_dir, f"{name}.pt"))
torch.save(state_dicts, os.path.join(output_dir, "others.pt"))
EOF
python extract_weights.py DeepSeek-R1-Distill-Qwen-32B ./data
4.生成推理代码
cat > infer.py << EOF
import sys
import os
from transformers import AutoModelForCausalLM, AutoTokenizer,BitsAndBytesConfig
import torch
import time
import numpy as np
import json
from torch.utils.data import Dataset, DataLoader
import threading
from torch import Tensor
from tqdm import tqdm
import triton
import triton.language as tl
import time
import queue
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@triton.jit
def dequantize_kernel(
int4_ptr, # 量化后的 int4 张量指针
scales_ptr, # 每个块的 scale 值指针
output_ptr, # 输出张量指针
N, # 总元素数量
num_blocks, # 总块数
BLOCK_SIZE: tl.constexpr # 每个线程块处理的元素数量
):
# 计算全局元素索引
pid = tl.program_id(axis=0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offs < N
# 计算 int4 张量中的索引
int4_idxs = offs // 2 # 每个 uint8 包含两个 int4 值
int4_vals = tl.load(int4_ptr + int4_idxs, mask=int4_idxs < (N + 1) // 2)
# 提取高 4 位和低 4 位的 int4 值
shift = 4 * (1 - (offs % 2))
q = (int4_vals >> shift) & 0x0F
q = q.to(tl.int8)
# 将无符号 int4 转换为有符号表示
q = (q + 8) % 16 - 8 # 将范围 [0, 15] 映射回 [-8, 7]
# 计算每个元素所属的块索引
block_size = 128
block_idxs = offs // block_size
scales = tl.load(scales_ptr + block_idxs, mask=block_idxs < num_blocks)
# 反量化
dequantized = q.to(tl.float32) * scales
# 存储结果
tl.store(output_ptr + offs, dequantized, mask=mask)
def dequantize_tensor_int4_triton(int4_tensor, scales, original_shape):
N = original_shape.numel()
num_blocks = scales.numel()
output = torch.empty(N, dtype=torch.bfloat16, device=int4_tensor.device)
# 动态调整块大小(A100建议512-1024)
BLOCK_SIZE = min(1024, triton.next_power_of_2(N))
grid = (triton.cdiv(N, BLOCK_SIZE),)
dequantize_kernel[grid](int4_tensor, scales, output,N, scales.numel(), BLOCK_SIZE=BLOCK_SIZE)
output = output.view(original_shape)
return output
def load_pinned_tensor(path):
data = torch.load(path, map_location='cpu',weights_only=True) # 先加载到CPU
# 递归遍历所有对象,对Tensor设置pin_memory
def _pin(tensor):
if isinstance(tensor, torch.Tensor):
return tensor.pin_memory()
elif isinstance(tensor, dict):
return {k: _pin(v) for k, v in tensor.items()}
elif isinstance(tensor, (list, tuple)):
return type(tensor)(_pin(x) for x in tensor)
else:
return tensor
return _pin(data)
class WeightCache:
def __init__(self, weight_names, weight_dir, max_cache_size):
self.weight_names = weight_names
self.weight_dir = weight_dir
if max_cache_size==-1:
self.max_cache_size = len(weight_names)
else:
self.max_cache_size = max_cache_size
self.cache = {}
self.cache_lock = threading.Lock()
self.condition = threading.Condition(self.cache_lock)
self.index = 0
self.weight_cpu = []
self.dequantized = {}
self.accessed_weights = set() # 用于记录被 get 过的权值
for name in tqdm(self.weight_names):
weight_path = os.path.join(self.weight_dir, name + ".pt")
self.weight_cpu.append(load_pinned_tensor(weight_path))
self.loader_thread = threading.Thread(target=self._loader)
self.loader_thread.daemon = True
self.loader_thread.start()
self.last_ts = time.time()
def _loader(self):
stream = torch.cuda.Stream()
while True:
with self.condition:
while len(self.cache) > self.max_cache_size:
# 尝试删除已被 get 过的权值
removed = False
for weight_name in list(self.cache.keys()):
if weight_name in self.accessed_weights:
del self.cache[weight_name]
self.accessed_weights.remove(weight_name)
removed = True
break # 每次删除一个
if not removed:
self.condition.wait()
# 加载新的权值到缓存
if self.index >= len(self.weight_names):
self.index = 0
weight_name = self.weight_names[self.index]
if weight_name in self.cache:
time.sleep(0.01)
continue
w = self.weight_cpu[self.index]
with torch.cuda.stream(stream):
if "weight" in weight_name:
new_weight = {"w": w['w'].to(device, non_blocking=False),
"scales": w['scales'].to(device, non_blocking=False),
"shape": w['shape']}
else:
new_weight = w.to(device, non_blocking=False)
with self.condition:
self.cache[weight_name] = new_weight
self.index += 1
self.condition.notify_all()
def wait_full(self):
with self.condition:
while len(self.cache) < self.max_cache_size:
self.condition.wait()
print(len(self.cache), self.max_cache_size)
def get(self, weight_name):
with self.condition:
while weight_name not in self.cache:
self.condition.wait()
weight = self.cache[weight_name] # 不再从缓存中删除
self.accessed_weights.add(weight_name) # 记录被 get 过的权值
self.condition.notify_all()
return weight
class TextGenerationDataset(Dataset):
def __init__(self, json_data):
self.data = json.loads(json_data)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
input_text = item['input']
expected_output = item['expected_output']
return input_text, expected_output
class Streamer:
def __init__(self, tokenizer):
self.cache = []
self.tokenizer = tokenizer
self.start_time = None # 用于记录开始时间
self.token_count = 0 # 用于记录生成的令牌数量
def put(self, token):
if self.start_time is None:
self.start_time = time.time() # 初始化开始时间
decoded = self.tokenizer.decode(token[0], skip_special_tokens=True)
self.cache.append(decoded)
self.token_count += token.numel() # 增加令牌计数
elapsed_time = time.time() - self.start_time
tokens_per_sec = self.token_count / elapsed_time if elapsed_time > 0 else 0
print(f"{tokens_per_sec:.2f} tokens/sec| {''.join(self.cache)}", end="\r", flush=True)
def end(self):
total_time = time.time() - self.start_time if self.start_time else 0
print("\nGeneration complete.")
if total_time > 0:
avg_tokens_per_sec = self.token_count / total_time
print(f"总令牌数: {self.token_count}, 总耗时: {total_time:.2f}s, 平均速度: {avg_tokens_per_sec:.2f} tokens/sec.")
else:
print("总耗时过短,无法计算每秒生成的令牌数。")
class MyLinear(torch.nn.Module):
__constants__ = ['in_features', 'out_features']
in_features: int
out_features: int
weight: Tensor
def __init__(self, in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = True
if bias:
self.bias = True
else:
self.bias = False
def forward(self, x):
w = self.weight_cache.get(f"{self.w_name}.weight")
weight=dequantize_tensor_int4_triton(w['w'], w['scales'],w['shape'])
if self.bias:
bias = self.weight_cache.get(f"{self.w_name}.bias")
else:
bias = None
return torch.nn.functional.linear(x,weight,bias)
def set_linear_name(model,weight_cache):
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
module.w_name=name
module.weight_cache=weight_cache
torch.nn.Linear=MyLinear
input_model_dir=sys.argv[1]
input_weights_dir=sys.argv[2]
cache_queue_size=int(sys.argv[3])
torch.set_default_device('cuda')
tokenizer = AutoTokenizer.from_pretrained(input_model_dir)
from transformers.models.qwen2 import Qwen2ForCausalLM,Qwen2Config
config=Qwen2Config.from_pretrained(f"{input_model_dir}/config.json")
config.use_cache=True
config.torch_dtype=torch.float16
config._attn_implementation="flash_attention_2"
model =Qwen2ForCausalLM(config).bfloat16().bfloat16().to(device)
checkpoint=torch.load(f"{input_weights_dir}/others.pt",weights_only=True)
model.load_state_dict(checkpoint)
weight_map=[]
with open(os.path.join(input_model_dir,"model.safetensors.index.json")) as f:
for name in json.load(f)["weight_map"].keys():
if "norm" in name or "embed" in name:
pass
else:
weight_map.append(name)
json_data =r'''
[
{"input": "1.1+2.3=?", "expected_output": "TODO"}
]
'''
weight_cache = WeightCache(weight_map,input_weights_dir,cache_queue_size)
print("wait done")
set_linear_name(model,weight_cache)
model.eval()
test_dataset = TextGenerationDataset(json_data)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
dataloader_iter = iter(test_dataloader)
input_text, expected_output=next(dataloader_iter)
inputs = tokenizer(input_text, return_tensors="pt").to(device)
streamer = Streamer(tokenizer)
if True:
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
with torch.inference_mode():
#outputs = model.generate(**inputs, max_length=4096,streamer=streamer,do_sample=True,pad_token_id=tokenizer.eos_token_id,num_beams=1,repetition_penalty=1.1)
outputs = model.generate(**inputs, max_length=4096,streamer=streamer,use_cache=config.use_cache)
else:
def trace_handler(prof):
print(prof.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1))
prof.export_chrome_trace("output.json")
import torch.autograd.profiler as profiler
from torch.profiler import profile, record_function, ProfilerActivity
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA],on_trace_ready=trace_handler) as p:
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
with torch.inference_mode():
outputs = model.generate(**inputs, max_length=4096,streamer=streamer,use_cache=True,do_sample=True,pad_token_id=tokenizer.eos_token_id,num_beams=1,repetition_penalty=1.1)
#outputs = model.generate(**inputs, max_length=4096,streamer=streamer)
#outputs = model.generate(**inputs, max_length=8)
p.step()
EOF
5.运行
A.缓存上限为128条
export TRITON_CACHE_DIR=$PWD/cache
python infer.py DeepSeek-R1-Distill-Qwen-32B data 128
性能
总令牌数: 403, 总耗时: 572.34s, 平均速度: 0.70 tokens/sec.
GPU利用率
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 3090 Off | 00000000:03:00.0 Off | N/A |
| 71% 62C P0 186W / 350W | 10354MiB / 24576MiB | 95% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 0 N/A N/A 47129 C python 10258MiB |
+-----------------------------------------------------------------------------------------+
B.不限制缓存上限
export TRITON_CACHE_DIR=$PWD/cache
python infer.py DeepSeek-R1-Distill-Qwen-32B data -1
性能
总令牌数: 403, 总耗时: 72.84s, 平均速度: 5.53 tokens/sec.
GPU利用率
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 3090 Off | 00000000:03:00.0 Off | N/A |
| 73% 65C P0 330W / 350W | 22678MiB / 24576MiB | 97% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 0 N/A N/A 47903 C python 22582MiB |
+-----------------------------------------------------------------------------------------+
C.输出内容
1.1+2.3=? Let me think.
Okay, so I need to add 1.1 and 2.3 together. Hmm, let me visualize this. I remember that when adding decimals, it’s important to line up the decimal points to make sure each place value is correctly added. So, I can write them one under the other like this:
1.1
+2.3
------
Starting from the rightmost digit, which is the tenths place. 1 (from 1.1) plus 3 (from 2.3) equals 4. So, I write down 4 in the tenths place.
Next, moving to the units place. 1 (from 1.1) plus 2 (from 2.3) equals 3. So, I write down 3 in the units place.
Putting it all together, the sum is 3.4. Let me double-check to make sure I didn’t make a mistake. 1.1 plus 2 is 3.1, and then adding the 0.3 more gives me 3.4. Yep, that seems right.
I think I got it! The answer should be 3.4.
To add 1.1 and 2.3, follow these steps:
- Align the decimal points:
1.1
+2.3
------
- Add the tenths place:
1 (from 1.1) + 3 (from 2.3) = 4
- Add the units place:
1 (from 1.1) + 2 (from 2.3) = 3
- Combine the results:
3.4
Final Answer: \boxed{3.4}
Generation complete.