当前位置: 首页 > news >正文

第三十八课:实战案例-飞鸟和飞机的识别

PyTorch飞鸟与飞机识别大作战 🐦✈️

大家好!这节课我们要用PyTorch打造一个"鸟类观察家兼航空管制员"AI系统,让它能区分天上飞的是鸟还是飞机。我们会用CIFAR-10数据集(已经包含这两类图片),全程笑料不断,保证学得开心!

1. 准备"望远镜"(环境设置)

首先安装必要的库(如果你还没安装的话):

pip install torch torchvision matplotlib

2. 获取"观察记录"(数据集)

CIFAR-10数据集就像一本包含10类物体的相册,其中正好有鸟(class 2)和飞机(class 0):

import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt# 定义我们的"图像增强望远镜"(数据增强)
transform = transforms.Compose([transforms.ToTensor(),  # 把图片变成数字张量transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化
])# 下载"观察记录本"(数据集)
train_data = datasets.CIFAR10(root='data', train=True,download=True,  # 如果本地没有就下载transform=transform
)test_data = datasets.CIFAR10(root='data',train=False,download=True,transform=transform
)# 只看鸟和飞机(类别0和2)
bird_plane_train = [(img, label) for img, label in train_data if label in [0, 2]]
bird_plane_test = [(img, label) for img, label in test_data if label in [0, 2]]print(f"训练集大小: {len(bird_plane_train)} 张")
print(f"测试集大小: {len(bird_plane_test)} 张")

3. 看看我们的"观察对象"

让我们随机看看几张图片,猜猜是鸟还是飞机:

# 定义类别标签
classes = {0: '飞机', 2: '鸟'}# 显示一些样本
fig, axes = plt.subplots(1, 5, figsize=(12, 3))
for i, ax in enumerate(axes):img, label = bird_plane_train[i]img = img / 2 + 0.5  # 反标准化ax.imshow(img.permute(1, 2, 0))  # 调整维度顺序ax.set_title(classes[label])ax.axis('off')
plt.suptitle("天上飞的到底是...?", fontsize=16)
plt.show()

4. 打造"智能望远镜"(模型)

我们要建一个CNN模型,就像给AI配了台高级望远镜:

import torch.nn as nn
import torch.nn.functional as Fclass BirdPlaneSpotter(nn.Module):def __init__(self):super().__init__()# 第一组"镜片"self.conv1 = nn.Conv2d(3, 16, 3, padding=1)  # 3通道输入,16个滤镜,3x3卷积核self.pool = nn.MaxPool2d(2, 2)  # 缩小图像尺寸# 第二组"镜片"self.conv2 = nn.Conv2d(16, 32, 3, padding=1)# 智能分析系统self.fc1 = nn.Linear(32 * 8 * 8, 256)  # 全连接层self.fc2 = nn.Linear(256, 2)  # 输出2类: 鸟或飞机self.dropout = nn.Dropout(0.2)  # 防止过度关注细节def forward(self, x):# 第一轮观察x = self.pool(F.relu(self.conv1(x)))# 第二轮更仔细的观察x = self.pool(F.relu(self.conv2(x)))# 展平图像特征x = x.view(-1, 32 * 8 * 8)# 分析决策x = self.dropout(F.relu(self.fc1(x)))x = self.fc2(x)return x# 创建我们的智能望远镜
model = BirdPlaneSpotter()
print("我们的智能望远镜结构:")
print(model)

5. 训练"观察员"(模型训练)

现在要教我们的AI怎么区分鸟和飞机了:

from torch.utils.data import DataLoader
import torch.optim as optim# 准备数据加载器
train_loader = DataLoader(bird_plane_train, batch_size=32, shuffle=True)
test_loader = DataLoader(bird_plane_test, batch_size=32)# 选择"学习教材"(损失函数和优化器)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 开始训练!
epochs = 10
train_losses, test_losses = [], []print("\n开始训练观察员...")
for epoch in range(epochs):# 训练模式model.train()running_loss = 0.0for images, labels in train_loader:# 把标签中的2(鸟)改成1,因为我们要二分类labels = torch.where(labels == 2, torch.tensor(1), torch.tensor(0))optimizer.zero_grad()  # 清空之前的观察笔记outputs = model(images)  # 进行观察loss = criterion(outputs, labels)  # 计算错误loss.backward()  # 反思哪里看错了optimizer.step()  # 调整观察方法running_loss += loss.item()# 测试模式model.eval()test_loss = 0.0accuracy = 0.0with torch.no_grad():  # 测试时不需要记笔记for images, labels in test_loader:labels = torch.where(labels == 2, torch.tensor(1), torch.tensor(0))outputs = model(images)test_loss += criterion(outputs, labels).item()# 计算准确率_, predicted = torch.max(outputs, 1)accuracy += (predicted == labels).sum().item()# 统计本轮表现train_loss = running_loss / len(train_loader)test_loss = test_loss / len(test_loader)accuracy = accuracy / len(bird_plane_test)train_losses.append(train_loss)test_losses.append(test_loss)print(f"第{epoch+1}期训练 | "f"训练损失: {train_loss:.3f} | "f"测试损失: {test_loss:.3f} | "f"准确率: {accuracy*100:.2f}%")

6. 查看"训练成果"

让我们看看AI学得怎么样:

# 绘制学习曲线
plt.plot(train_losses, label='训练损失')
plt.plot(test_losses, label='测试损失')
plt.legend()
plt.title('观察员的学习进步曲线')
plt.show()# 随机测试几个样本
model.eval()
fig, axes = plt.subplots(2, 3, figsize=(12, 6))
for i, ax in enumerate(axes.flat):img, label = bird_plane_test[i]true_label = '鸟' if label == 2 else '飞机'with torch.no_grad():output = model(img.unsqueeze(0))_, predicted = torch.max(output, 1)predicted_label = '鸟' if predicted == 1 else '飞机'img = img / 2 + 0.5  # 反标准化ax.imshow(img.permute(1, 2, 0))ax.set_title(f"真: {true_label}\n预测: {predicted_label}")ax.axis('off')
plt.tight_layout()
plt.suptitle('AI观察员的测试表现', y=1.02, fontsize=16)
plt.show()

7. 保存"优秀观察员"

训练好的模型可以保存下来以后使用:

# 保存模型
torch.save(model.state_dict(), 'bird_plane_spottter.pth')# 以后加载模型
# model = BirdPlaneSpotter()
# model.load_state_dict(torch.load('bird_plane_spottter.pth'))

8.AI观察员成长记

  • 数据集:就像给AI的"鸟类和飞机图鉴"
  • CNN模型:是AI的"智能望远镜系统"
  • 训练过程:相当于老观察员带徒弟:
    • 徒弟:“那个是鸟吗?”
    • 师傅:“笨蛋!那是飞机!看翅膀形状!”
  • 测试阶段:就像毕业考试,看看能不能独立工作

记住这个口诀:

想要区分鸟和机,
PyTorch帮你造神器,
CNN是核心科技,
数据加载要仔细,
训练就像教徒弟,
测试别忘eval(),
保存模型记心里,
下次直接能装逼!

现在你的AI已经是个合格的"天空观察员"了!试着用你自己的图片测试它吧(可能需要调整大小到32x32)~ 🦅🛫

相关文章:

  • 《性能之巅》第三章 操作系统
  • AI时代,学习力进化指南:如何成为知识的主人?
  • Java(网络编程)
  • unittest 和 pytest 框架
  • 浅谈软件开发工作流
  • Vue3 Router 使用指南:从基础到高级用法
  • openEuler虚拟机中容器化部署
  • springboot+mybatis面试题
  • CQF预备知识:Python相关库 -- 插值过渡指南 scipy.interpolate
  • 接口测试常用工具及测试方法(基础篇)
  • [SKE]CPU 与 GPU 之间数据加密传输的认证与异常处理
  • 触觉智能RK3576核心板工业应用之软硬件全国产化,成功适配开源鸿蒙OpenHarmony5.0
  • aws s3 sdk c++使用指南、适配阿里云oss和aws
  • OCCT 中 BRepBuilderAPI_MakePolygon与BRepBuilderAPI_MakeWire
  • 5种常见的网络保密通信协议
  • 如何从 Ansys SpaceClaim 模型中提取 CAD 数据,该模型是在我计算机上安装的未来版本中创建的?
  • 亚马逊云服务器配置推荐
  • SMB协议在Windows内网中的核心地位
  • 华为:eSight网管平台使用snmp纳管交换机
  • React---Hooks深入
  • 国外域名注册做违法网站/西安sem竞价托管
  • 河北廊坊建设银行网站/怎么找需要做推广的公司
  • 赣州住房与城乡建设厅网站/天津关键词排名推广
  • 门户网站流程图/百度服务电话6988
  • 总公司网站备案后 分公司网站还需要备案吗/长春seo公司
  • 建站行业/新媒体seo培训