Reproducing nanoGPT: train.py (a detailed breakdown)
The original train.py opens with a very long block of parameter definitions that is dizzying to read, so I reorganized them into modules and placed each group next to the code that actually uses it. In practice, though, keeping everything at the very top turns out to be the better choice, because defining parameters mid-file can run into ordering and override problems. The split here is purely to make the code easier to understand.
Two parameters, out_dir and device, come before everything else, since almost every module needs them.
out_dir = 'out'
device = 'cuda'
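In addition, all of the snippets below assume the same imports as the original train.py (GPTConfig and GPT come from the accompanying model.py, as in nanoGPT; adjust the module name if yours lives elsewhere):
import os
import time
import math
import pickle

import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

from model import GPTConfig, GPT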
1. Generating the training / validation data
#-------------- data generation ---------------------------------
dataset = 'hongloumeng_char'
data_dir = os.path.join('data', dataset)
device_type = 'cuda' if 'cuda' in device else 'cpu'
batch_size = 64
def get_batch(split):
    if split == 'train':
        data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
    else:
        data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
    # sample batch_size random starting positions (block_size is defined in section 3)
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+block_size+1]).astype(np.int64)) for i in ix])
    if device_type == 'cuda':
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y
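As a quick sanity check, one call to get_batch should return input/target tensors of shape (batch_size, block_size). The snippet below is a hypothetical check, assuming train.bin already exists under data/hongloumeng_char and that block_size from section 3 is in scope:
x, y = get_batch('train')
print(x.shape, y.shape)   # expected: torch.Size([64, 256]) torch.Size([64, 256])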
2. Learning-rate schedule
#--------------- learning rate schedule -----------------------------
warm_up_iter = 100
learning_rate = 1e-3
lr_decay_iter = 5000
min_lr = 1e-4
decay_lr = True
def get_lr(it):
    # linear warmup for the first warm_up_iter iterations
    if it < warm_up_iter:
        return learning_rate * (it + 1) / (warm_up_iter + 1)
    # past the decay window, hold at min_lr
    elif it > lr_decay_iter:
        return min_lr
    # cosine decay from learning_rate down to min_lr in between
    else:
        lr_ratio = (it - warm_up_iter) / (lr_decay_iter - warm_up_iter)
        coeff = 0.5 * (math.cos(lr_ratio * math.pi) + 1)
        return min_lr + coeff * (learning_rate - min_lr)
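Printing the schedule at a few iterations makes the three phases visible; the values below follow directly from the settings above (warmup to 1e-3, cosine decay to 1e-4 by iteration 5000):
for it in [0, 50, 100, 2550, 5000, 6000]:
    print(it, f"{get_lr(it):.2e}")
# roughly: 9.90e-06, 5.05e-04, 1.00e-03, 5.50e-04, 1.00e-04, 1.00e-04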
3. Model initialization
#-------------- model initialization --------------------------------
"""
class Config:block_size: int = 1024vocab_size: int = 50304n_embd: int = 768bias: bool = Truen_layer: int = 12n_head: int = 12dropout: float = 0.0
#模型定义时的config
"""block_size = 256
n_layer = 6
n_head = 6
n_embd = 384
dropout = 0.2
bias = False
init_form = "scratch"
meta_vocab_size = None
model_args = dict(block_size=block_size, vocab_size=None, n_embd=n_embd,
                  bias=bias, n_layer=n_layer, n_head=n_head, dropout=dropout)
iter_num = 0
best_val_loss = 1e9
meta_path = os.path.join(data_dir, 'meta.pkl')
if os.path.exists(meta_path):
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    meta_vocab_size = meta['vocab_size']
    print(f"found vocab_size={meta_vocab_size} (inside {meta_path})")
if init_form == "scratch":
    print("Initializing a new model from scratch")
    if meta_vocab_size is None:
        print("defaulting to vocab_size of GPT-2 to 50304")
    model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
elif init_form == "resume":
    out_dir = 'out-hongloumeng-char'
    print(f"Resuming training from {out_dir}")
    ckpt_path = os.path.join(out_dir, "ckpt.pt")
    checkpoint = torch.load(ckpt_path, map_location=device)
    checkpoint_conf_arg = checkpoint['model_args']
    for k in ['block_size', 'vocab_size', 'n_embd', 'bias', 'n_layer', 'n_head']:
        model_args[k] = checkpoint_conf_arg[k]
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
    # strip the '_orig_mod.' prefix that torch.compile prepends to state_dict keys
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    iter_num = checkpoint['iter_num']
    best_val_loss = checkpoint['best_val_loss']
elif init_form.startswith('gpt2'):
    print(f"Initializing from OpenAI GPT-2 weights: {init_form}")
    override_args = dict(dropout=dropout)
    model = GPT.from_pretrained(init_form, override_args)
    for k in ['block_size', 'vocab_size', 'n_embd', 'bias', 'n_layer', 'n_head']:
        model_args[k] = getattr(model.config, k)
if block_size <= model.config.block_size:
    model.crop_block_size(block_size)
    model_args['block_size'] = block_size
model.to(device)
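The meta.pkl read above is produced by the dataset's prepare step. Below is a minimal sketch of what such a prepare script might write, assuming it follows the convention of nanoGPT's shakespeare_char/prepare.py (the actual hongloumeng_char script may differ in details such as the input file name):
# hypothetical prepare-time snippet: build a char-level vocab and save meta.pkl
text = open(os.path.join(data_dir, 'input.txt'), encoding='utf-8').read()
chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
with open(os.path.join(data_dir, 'meta.pkl'), 'wb') as f:
    pickle.dump({'vocab_size': len(chars), 'itos': itos, 'stoi': stoi}, f)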
4. Compilation
#-------------- compile -------------------------------------
compile = True
if compile:
    print("compiling the model ... (takes about a minute)")
    unoptimized_model = model
    model = torch.compile(model)
However, compile seems to work only on UNIX-like systems, because it relies on the triton library, which apparently only ships for those platforms. If you hit related errors, you can add the following at the top of the file:
# import torch._dynamo
# torch._dynamo.config.suppress_errors = True
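If you would rather not suppress dynamo errors globally, another option is simply to skip compilation where triton is unlikely to be available. This is only a sketch of an alternative to the compile block above, and the platform check is an assumption rather than an exhaustive test:
# hypothetical guard: only attempt torch.compile outside Windows
import sys
compile = compile and (sys.platform != 'win32')
if compile:
    model = torch.compile(model)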
5. Optimizer setup
#---------------optimizer----------------------------------
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
#beta2 = 0.99
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
if init_form == 'resume':
    optimizer.load_state_dict(checkpoint['optimizer'])
checkpoint = None  # free the checkpoint memory
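configure_optimizers itself lives in model.py. In nanoGPT it puts all 2D parameters (matmul weights and embeddings) in a weight-decay group, everything else (biases, LayerNorm gains) in a no-decay group, and builds AdamW on those groups, using the fused kernel when it is available on CUDA. A rough sketch of that logic, paraphrased from memory rather than copied, so treat the details as approximate:
def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
    params = [p for p in self.parameters() if p.requires_grad]
    decay_params = [p for p in params if p.dim() >= 2]     # weights, embeddings
    nodecay_params = [p for p in params if p.dim() < 2]    # biases, layernorms
    optim_groups = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': nodecay_params, 'weight_decay': 0.0},
    ]
    return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)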
6. DDP
#--------------- DDP (and master_process) -------------------
gradient_accumulation_steps = 1
backend = 'nccl'
ddp = int(os.environ.get('RANK', -1)) != -1
if ddp:
    init_process_group(backend=backend)
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0
    seed_offset = ddp_rank
    assert gradient_accumulation_steps % ddp_world_size == 0
    gradient_accumulation_steps //= ddp_world_size
else:
    master_process = True
    seed_offset = 0
    ddp_world_size = 1
tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
print(f"tokens per iteration will be :{tokens_per_iter}")if master_process: #避免多个进程创建os.makedirs(out_dir, exist_ok=True)
torch.manual_seed(1337 + seed_offset)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])
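Note that the ddp branch only activates when the script is launched through a distributed launcher that sets the RANK / LOCAL_RANK / WORLD_SIZE environment variables, for example torchrun --standalone --nproc_per_node=4 train.py on a single machine with 4 GPUs (adjust the GPU count to your hardware). Running python train.py directly falls through to the single-process branch.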
7. Loss estimation
#-------------- loss estimation ---------------------------------
from contextlib import nullcontext
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
ptdtype = {'float32':torch.float32, 'bfloat16':torch.bfloat16, 'float16':torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
eval_iters = 200
@torch.no_grad
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            with ctx:
                logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean().item()
    model.train()
    return out
8. Main training loop
#--------------- main loop -------------------------------------
eval_interval = 250
# collect all config values for logging
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
config = {k: globals()[k] for k in config_keys}
# logging
wandb_log = False
wandb_project = 'hongloumeng-char'
wandb_run_name = 'mini-gpt' # 'run' + str(time.time())
always_keep_checkpoint = False
running_mfu = -1.0
raw_model = model.module if ddp else model
if wandb_log:
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name, config=config)
eval_only = False
X, Y = get_batch('train')
scaler = torch.amp.GradScaler('cuda', enabled=(dtype == 'float16'))
grad_clip = 1.0
t0 = time.time()
log_interval = 10
local_iter_num = 0
max_iters = 5000
while True:
    # set the learning rate for this iteration
    lr = get_lr(iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # evaluate, log, and save checkpoints
    if iter_num % eval_interval == 0 and master_process:
        loss = estimate_loss()
        print(f"step {iter_num} training loss:{loss['train']:.4f}, val loss:{loss['val']:.4f}")
        if wandb_log:
            wandb.log({
                'iter': iter_num,
                'loss/val': loss['val'],
                'loss/train': loss['train'],
                'lr': lr,
                'mfu': running_mfu * 100,
            })
        if loss['val'] <= best_val_loss or always_keep_checkpoint:
            best_val_loss = loss['val']
            if iter_num > 0:
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'model_args': model_args,
                    'iter_num': iter_num,
                    'best_val_loss': best_val_loss,
                    'config': config,
                }
                print(f"saving checkpoint to {out_dir}")
                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
    if iter_num == 0 and eval_only:
        break

    # forward/backward with gradient accumulation, then gradient clipping and the optimizer step
    for micro_step in range(gradient_accumulation_steps):
        if ddp:
            model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
        with ctx:
            logits, loss = model(X, Y)
            loss = loss / gradient_accumulation_steps
        X, Y = get_batch('train')
        scaler.scale(loss).backward()
    if grad_clip != 0.0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)

    # timing and MFU estimation
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % log_interval == 0 and master_process:
        lossf = loss.item() * gradient_accumulation_steps
        if local_iter_num >= 5:
            mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
            running_mfu = mfu if running_mfu == -1.0 else 0.9 * running_mfu + 0.1 * mfu
        print(f"iter{iter_num}:loss{lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
    iter_num += 1
    local_iter_num += 1
    if iter_num > max_iters:
        break
if ddp:
    destroy_process_group()
When implementing this yourself, it helps to write the training-loop backbone first and then gradually add the pieces above as the backbone needs them; a minimal version of that backbone is sketched below.
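A minimal single-process sketch of that backbone, stripped of AMP, DDP, evaluation, and checkpointing, and assuming model, optimizer, get_batch, get_lr and the hyperparameters from the sections above are already defined:
# minimal training skeleton to grow from, not a drop-in replacement for the loop above
iter_num = 0
while iter_num <= max_iters:
    lr = get_lr(iter_num)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    X, Y = get_batch('train')
    logits, loss = model(X, Y)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    if iter_num % log_interval == 0:
        print(f"iter {iter_num}: loss {loss.item():.4f}")
    iter_num += 1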