
[GPU RAM] Real-Time Monitoring of GPU Memory Allocation (Part 1)

Capturing memory snapshots with torch.cuda.memory

Commands used

torch.cuda.memory._record_memory_history(max_entries=100000)   # start recording; keep at most 100,000 entries
torch.cuda.memory._dump_snapshot(file_name)                     # save a snapshot to file_name
torch.cuda.memory._record_memory_history(enabled=None)          # stop recording

max_entries=100000: keep at most 100,000 GPU memory allocation/free events in the recorded history leading up to the snapshot. Concretely, an event is recorded whenever a block of GPU memory is allocated (for example, a tensor is created) or freed (for example, a tensor is deleted or its lifetime ends).
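To make the three calls concrete, here is a minimal sketch, assuming a CUDA device is available; the file name memory_snapshot.pickle and the toy tensor workload are illustrative choices only, and _record_memory_history / _dump_snapshot are underscore-prefixed (private) APIs whose behavior may change between PyTorch versions:

import torch

# Start recording alloc/free events (keep at most the last 100,000).
torch.cuda.memory._record_memory_history(max_entries=100000)

# Any GPU work in between produces allocation/free events.
x = torch.randn(1024, 1024, device="cuda")  # allocation
y = x @ x                                   # more allocations
del x, y                                    # frees

# Save the recorded history to disk, then stop recording.
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)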

Helper functions for recording memory allocations

Recording only happens when a GPU is actually in use; each helper below returns early if CUDA is unavailable.


# These helpers rely on `torch`, `socket`, the `logger`, `datetime`,
# TIME_FORMAT_STR and MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT being set up,
# exactly as in the Example section below.

def start_record_memory_history() -> None:
    if not torch.cuda.is_available():
        logger.info("CUDA unavailable. Not recording memory history")
        return
    logger.info("Starting snapshot record_memory_history")
    torch.cuda.memory._record_memory_history(max_entries=MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT)


def stop_record_memory_history() -> None:
    if not torch.cuda.is_available():
        logger.info("CUDA unavailable. Not recording memory history")
        return
    logger.info("Stopping snapshot record_memory_history")
    torch.cuda.memory._record_memory_history(enabled=None)


def export_memory_snapshot() -> None:
    if not torch.cuda.is_available():
        logger.info("CUDA unavailable. Not exporting memory snapshot")
        return
    # Prefix for file names.
    host_name = socket.gethostname()
    timestamp = datetime.now().strftime(TIME_FORMAT_STR)
    file_prefix = f"{host_name}_{timestamp}"
    try:
        logger.info(f"Saving snapshot to local file: {file_prefix}.pickle")
        torch.cuda.memory._dump_snapshot(f"{file_prefix}.pickle")
    except Exception as e:
        logger.error(f"Failed to capture memory snapshot {e}")
        return

Example

Take BERT pretraining code as an example (the lines marked with #!!!!!!!!!!!!!!!!!!! are the snapshot-recording calls):

date_str = datetime.now().strftime("%Y-%m-%d")
logging.basicConfig(format="%(levelname)s:%(asctime)s %(message)s",
                    level=logging.INFO,
                    filename="log_mem_snap.txt",
                    datefmt="%Y-%m-%d %H:%M:%S")
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)

TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S"

# Keep a max of 100,000 alloc/free events in the recorded history
# leading up to the snapshot.
MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT: int = 100000

# start_record_memory_history(), stop_record_memory_history() and
# export_memory_snapshot() are the helper functions defined in the
# previous section.


def train(config):
    model = BertForPretrainingModel(config)
    t_min, t_max = -1, 0  # for the input layer
    t_min, t_max = model.bert.bert_embeddings.set_params(t_min, t_max)
    for i, layer in enumerate(model.bert.bert_encoder.bert_layers):
        t_tuple = layer.set_params(t_min, t_max)
        t_min, t_max = t_tuple[-2:]
        # print(t_min, t_max, "----------")
    config.T_RES = t_max

    last_epoch = -1
    if os.path.exists(config.model_save_path):
        checkpoint = torch.load(config.model_save_path)
        last_epoch = checkpoint['last_epoch']
        loaded_paras = checkpoint['model_state_dict']
        model.load_state_dict(loaded_paras)
        logging.info("## Successfully loaded the existing model and continue training. ......")
    model = model.to(config.device)
    model.train()

    bert_tokenize = BertTokenizer.from_pretrained(config.pretrained_model_dir).tokenize
    data_loader = LoadBertPretrainingDataset(vocab_path=config.vocab_path,
                                             tokenizer=bert_tokenize,
                                             batch_size=config.batch_size,
                                             max_sen_len=config.max_sen_len,
                                             max_position_embeddings=config.max_position_embeddings,
                                             pad_index=config.pad_index,
                                             is_sample_shuffle=config.is_sample_shuffle,
                                             random_state=config.random_state,
                                             data_name=config.data_name,
                                             masked_rate=config.masked_rate,
                                             masked_token_rate=config.masked_token_rate,
                                             masked_token_unchanged_rate=config.masked_token_unchanged_rate)
    train_iter, test_iter, val_iter = \
        data_loader.load_train_val_test_data(test_file_path=config.test_file_path,
                                             train_file_path=config.train_file_path,
                                             val_file_path=config.val_file_path)

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         "weight_decay": config.weight_decay,
         "initial_lr": config.learning_rate},
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0,
         "initial_lr": config.learning_rate},
    ]
    optimizer = AdamW(optimizer_grouped_parameters)
    scheduler = get_polynomial_decay_schedule_with_warmup(optimizer,
                                                          int(len(train_iter) * 0),
                                                          int(config.epochs * len(train_iter)),
                                                          last_epoch=last_epoch)
    max_acc = 0
    state_dict = None
    for epoch in range(config.epochs):
        losses = 0
        start_time = time.time()

        # Start recording memory snapshot history       #!!!!!!!!!!!!!!!!!!!
        start_record_memory_history()

        for idx, (b_token_ids, b_segs, b_mask, b_mlm_label, b_nsp_label) in enumerate(train_iter):
            b_token_ids = b_token_ids.to(config.device)  # [src_len, batch_size]
            b_segs = b_segs.to(config.device)
            b_mask = b_mask.to(config.device)
            b_mlm_label = b_mlm_label.to(config.device)
            b_nsp_label = b_nsp_label.to(config.device)
            with record_function("## forward ##"):
                loss, mlm_logits, nsp_logits = model(input_ids=b_token_ids,
                                                     attention_mask=b_mask,
                                                     token_type_ids=b_segs,
                                                     masked_lm_labels=b_mlm_label,
                                                     next_sentence_labels=b_nsp_label)
            optimizer.zero_grad()
            with record_function("## backward ##"):
                loss.backward()
            with record_function("## optimizer ##"):
                optimizer.step()
                scheduler.step()
            losses += loss.item()
            mlm_acc, _, _, nsp_acc, _, _ = accuracy(mlm_logits, nsp_logits, b_mlm_label,
                                                    b_nsp_label, data_loader.PAD_IDX)
            if idx % 20 == 0:
                logging.info(f"Epoch: [{epoch + 1}/{config.epochs}], Batch[{idx}/{len(train_iter)}], "
                             f"Train loss :{loss.item():.3f}, Train mlm acc: {mlm_acc:.3f},"
                             f"nsp acc: {nsp_acc:.3f}")
                config.writer.add_scalar('Training/Loss', loss.item(), scheduler.last_epoch)
                config.writer.add_scalar('Training/Learning Rate', scheduler.get_last_lr()[0], scheduler.last_epoch)
                config.writer.add_scalars(main_tag='Training/Accuracy',
                                          tag_scalar_dict={'NSP': nsp_acc, 'MLM': mlm_acc},
                                          global_step=scheduler.last_epoch)

        # Create the memory snapshot file               #!!!!!!!!!!!!!!!!!!!
        export_memory_snapshot()
        # Stop recording memory snapshot history        #!!!!!!!!!!!!!!!!!!!
        stop_record_memory_history()

        end_time = time.time()
        train_loss = losses / len(train_iter)
        logging.info(f"Epoch: [{epoch + 1}/{config.epochs}], Train loss: "
                     f"{train_loss:.3f}, Epoch time = {(end_time - start_time):.3f}s")
        if (epoch + 1) % config.model_val_per_epoch == 0:
            mlm_acc, nsp_acc = evaluate(config, val_iter, model, data_loader.PAD_IDX)
            logging.info(f" ### MLM Accuracy on val: {round(mlm_acc, 4)}, "
                         f"NSP Accuracy on val: {round(nsp_acc, 4)}")
            config.writer.add_scalars(main_tag='Testing/Accuracy',
                                      tag_scalar_dict={'NSP': nsp_acc, 'MLM': mlm_acc},
                                      global_step=scheduler.last_epoch)

Viewing the results

Drag the generated .pickle file into https://docs.pytorch.org/memory_viz. Within each training iteration, memory usage first grows, then shrinks after optimizer.step() once the gradients that are no longer needed have been freed.
(Figure: how memory allocation changes as training progresses)
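Instead of the web viewer, the referenced PyTorch blog post also describes an offline route: the _memory_viz.py script that ships with PyTorch can render a snapshot into a standalone HTML file. A sketch of the command, where snapshot.pickle stands for whichever file was exported and the script path assumes a PyTorch source checkout (it may differ in your installation or version):

# Render the snapshot timeline to a self-contained HTML file.
python torch/cuda/_memory_viz.py trace_plot snapshot.pickle -o snapshot.html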

Reference

  • https://pytorch.org/blog/understanding-gpu-memory-1/
