当前位置: 首页 > news >正文

李沐-第六章-LeNet训练中的pycharm jupyter-notebook Animator类的显示问题

问题

系统: win11
显卡:5060
CUDA:12.8
pytorch:2.7.1+cu128
pycharm:2025.2

使用d2l中的动图显示Animator类和train_ch6训练函数,在浏览器的jupyter notebook中可以正常显示动图,但是在Pycharm的集成jupyter中,开始训练时动图显示一闪而过,只有在训练结束时才显示。
原Animator类:

class Animator:"""For plotting data in animation."""def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,figsize=(3.5, 2.5)):"""Defined in :numref:`sec_utils`"""# Incrementally plot multiple linesif legend is None:legend = []d2l.use_svg_display()self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)if nrows * ncols == 1:self.axes = [self.axes, ]# Use a lambda function to capture argumentsself.config_axes = lambda: d2l.set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)self.X, self.Y, self.fmts = None, None, fmtsdef add(self, x, y):# Add multiple data points into the figureif not hasattr(y, "__len__"):y = [y]n = len(y)if not hasattr(x, "__len__"):x = [x] * nif not self.X:self.X = [[] for _ in range(n)]if not self.Y:self.Y = [[] for _ in range(n)]for i, (a, b) in enumerate(zip(x, y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)self.axes[0].cla()for x, y, fmt in zip(self.X, self.Y, self.fmts):self.axes[0].plot(x, y, fmt)self.config_axes()display.display(self.fig)display.clear_output(wait=True) # 需要改的就是这里!

原train_ch6函数:

def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""Train a model with a GPU (defined in Chapter 6).Defined in :numref:`sec_utils`"""def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# Sum of training loss, sum of training accuracy, no. of examplesmetric = d2l.Accumulator(3)net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()with torch.no_grad():metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')

解决方法

在Animator类的add函数中,将最后一行的:

display.clear_output(wait=True)

改为:

display.clear_output(wait=False)

这是为了不用等图像内容刷新就显示。解决显示图像一闪而过的问题。

在train_ch6函数里,在最后两行print函数前,加入:

display.clear_output(wait=False)

这是为了解决最终输出会有两张图的问题。

效果

动态图正常显示。
这个方法只是解决了pycharm的问题,如果在浏览器中使用jupyter则会出现图像闪的问题。所以,为了避免和其他部分冲突,可以自己定义新函数,不修改d2l库代码。

pycharm jupyter中图像正常显示不会一闪而过

http://www.dtcms.com/a/330971.html

相关文章:

  • 轻松同步 Outlook 联系人到 Android
  • 深入解析SAE自动驾驶分级标准(0-5级)及典型落地实例
  • Ubuntu 软件源版本不匹配导致的依赖冲突问题及解决方法
  • C++ 23种设计模式的分类总结
  • C++23输出革命:std::print的崛起与工业界标准滞后的现实困境
  • 18.12 BERT问答系统核心难题:3步攻克Tokenizer答案定位与动态填充实战
  • c/c++ UNIX 域Socket和共享内存实现本机通信
  • 2021睿抗决赛 猛犸不上 Ban
  • diffusers库学习--pipeline,模型,调度器的基础使用
  • 深入解析Prompt缓存机制:原理、优化与实践经验
  • Centos9傻瓜式linux部署CRMEB 开源商城系统(PHP)
  • 流式数据服务端怎么传给前端,前端怎么接收?
  • Keil 微库(MicroLib)深度解析
  • USB 3.0 协议层 包定义
  • 微软对传统网页设计工具在2010年停止开发
  • Sql server 命令行和控制台使用二三事
  • web网站开发,在线%射击比赛成绩管理%系统开发demo,基于html,css,jquery,python,django,model,orm,mysql数据库
  • 一文讲透Go语言并发模型
  • Pytest本地插件定制及发布指南
  • Redis7学习--十大数据类型 bitmap、Hyperloglog、GEO、Stream、bitfield
  • PAT乙级_1073 多选题常见计分法_Python_AC解法_含疑难点
  • mysql查询中的filesort是指什么
  • Linux软件编程:进程和线程
  • 火山引擎数智平台发布 Data Agent“一客一策“与 AI 数据湖“算子广场“
  • 【Python】新手入门:什么是python字符编码?python标识符?什么是pyhon保留字?
  • 【数据集介绍】多种飞机检测的YOLO数据集介绍
  • 服务器数据恢复—误删服务器卷数据的数据恢复案例
  • 配置docker pull走http代理
  • 集成电路学习:什么是Video Processing视频处理
  • 网络原理-HTTP