第三十八课:实战案例-飞鸟和飞机的识别
PyTorch飞鸟与飞机识别大作战 🐦✈️
大家好!这节课我们要用PyTorch打造一个"鸟类观察家兼航空管制员"AI系统,让它能区分天上飞的是鸟还是飞机。我们会用CIFAR-10数据集(已经包含这两类图片),全程笑料不断,保证学得开心!
1. 准备"望远镜"(环境设置)
首先安装必要的库(如果你还没安装的话):
pip install torch torchvision matplotlib
2. 获取"观察记录"(数据集)
CIFAR-10数据集就像一本包含10类物体的相册,其中正好有鸟(class 2)和飞机(class 0):
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt# 定义我们的"图像增强望远镜"(数据增强)
transform = transforms.Compose([transforms.ToTensor(), # 把图片变成数字张量transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 标准化
])# 下载"观察记录本"(数据集)
train_data = datasets.CIFAR10(root='data', train=True,download=True, # 如果本地没有就下载transform=transform
)test_data = datasets.CIFAR10(root='data',train=False,download=True,transform=transform
)# 只看鸟和飞机(类别0和2)
bird_plane_train = [(img, label) for img, label in train_data if label in [0, 2]]
bird_plane_test = [(img, label) for img, label in test_data if label in [0, 2]]print(f"训练集大小: {len(bird_plane_train)} 张")
print(f"测试集大小: {len(bird_plane_test)} 张")
3. 看看我们的"观察对象"
让我们随机看看几张图片,猜猜是鸟还是飞机:
# 定义类别标签
classes = {0: '飞机', 2: '鸟'}# 显示一些样本
fig, axes = plt.subplots(1, 5, figsize=(12, 3))
for i, ax in enumerate(axes):img, label = bird_plane_train[i]img = img / 2 + 0.5 # 反标准化ax.imshow(img.permute(1, 2, 0)) # 调整维度顺序ax.set_title(classes[label])ax.axis('off')
plt.suptitle("天上飞的到底是...?", fontsize=16)
plt.show()
4. 打造"智能望远镜"(模型)
我们要建一个CNN模型,就像给AI配了台高级望远镜:
import torch.nn as nn
import torch.nn.functional as Fclass BirdPlaneSpotter(nn.Module):def __init__(self):super().__init__()# 第一组"镜片"self.conv1 = nn.Conv2d(3, 16, 3, padding=1) # 3通道输入,16个滤镜,3x3卷积核self.pool = nn.MaxPool2d(2, 2) # 缩小图像尺寸# 第二组"镜片"self.conv2 = nn.Conv2d(16, 32, 3, padding=1)# 智能分析系统self.fc1 = nn.Linear(32 * 8 * 8, 256) # 全连接层self.fc2 = nn.Linear(256, 2) # 输出2类: 鸟或飞机self.dropout = nn.Dropout(0.2) # 防止过度关注细节def forward(self, x):# 第一轮观察x = self.pool(F.relu(self.conv1(x)))# 第二轮更仔细的观察x = self.pool(F.relu(self.conv2(x)))# 展平图像特征x = x.view(-1, 32 * 8 * 8)# 分析决策x = self.dropout(F.relu(self.fc1(x)))x = self.fc2(x)return x# 创建我们的智能望远镜
model = BirdPlaneSpotter()
print("我们的智能望远镜结构:")
print(model)
5. 训练"观察员"(模型训练)
现在要教我们的AI怎么区分鸟和飞机了:
from torch.utils.data import DataLoader
import torch.optim as optim# 准备数据加载器
train_loader = DataLoader(bird_plane_train, batch_size=32, shuffle=True)
test_loader = DataLoader(bird_plane_test, batch_size=32)# 选择"学习教材"(损失函数和优化器)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 开始训练!
epochs = 10
train_losses, test_losses = [], []print("\n开始训练观察员...")
for epoch in range(epochs):# 训练模式model.train()running_loss = 0.0for images, labels in train_loader:# 把标签中的2(鸟)改成1,因为我们要二分类labels = torch.where(labels == 2, torch.tensor(1), torch.tensor(0))optimizer.zero_grad() # 清空之前的观察笔记outputs = model(images) # 进行观察loss = criterion(outputs, labels) # 计算错误loss.backward() # 反思哪里看错了optimizer.step() # 调整观察方法running_loss += loss.item()# 测试模式model.eval()test_loss = 0.0accuracy = 0.0with torch.no_grad(): # 测试时不需要记笔记for images, labels in test_loader:labels = torch.where(labels == 2, torch.tensor(1), torch.tensor(0))outputs = model(images)test_loss += criterion(outputs, labels).item()# 计算准确率_, predicted = torch.max(outputs, 1)accuracy += (predicted == labels).sum().item()# 统计本轮表现train_loss = running_loss / len(train_loader)test_loss = test_loss / len(test_loader)accuracy = accuracy / len(bird_plane_test)train_losses.append(train_loss)test_losses.append(test_loss)print(f"第{epoch+1}期训练 | "f"训练损失: {train_loss:.3f} | "f"测试损失: {test_loss:.3f} | "f"准确率: {accuracy*100:.2f}%")
6. 查看"训练成果"
让我们看看AI学得怎么样:
# 绘制学习曲线
plt.plot(train_losses, label='训练损失')
plt.plot(test_losses, label='测试损失')
plt.legend()
plt.title('观察员的学习进步曲线')
plt.show()# 随机测试几个样本
model.eval()
fig, axes = plt.subplots(2, 3, figsize=(12, 6))
for i, ax in enumerate(axes.flat):img, label = bird_plane_test[i]true_label = '鸟' if label == 2 else '飞机'with torch.no_grad():output = model(img.unsqueeze(0))_, predicted = torch.max(output, 1)predicted_label = '鸟' if predicted == 1 else '飞机'img = img / 2 + 0.5 # 反标准化ax.imshow(img.permute(1, 2, 0))ax.set_title(f"真: {true_label}\n预测: {predicted_label}")ax.axis('off')
plt.tight_layout()
plt.suptitle('AI观察员的测试表现', y=1.02, fontsize=16)
plt.show()
7. 保存"优秀观察员"
训练好的模型可以保存下来以后使用:
# 保存模型
torch.save(model.state_dict(), 'bird_plane_spottter.pth')# 以后加载模型
# model = BirdPlaneSpotter()
# model.load_state_dict(torch.load('bird_plane_spottter.pth'))
8.AI观察员成长记
- 数据集:就像给AI的"鸟类和飞机图鉴"
- CNN模型:是AI的"智能望远镜系统"
- 训练过程:相当于老观察员带徒弟:
- 徒弟:“那个是鸟吗?”
- 师傅:“笨蛋!那是飞机!看翅膀形状!”
- 测试阶段:就像毕业考试,看看能不能独立工作
记住这个口诀:
想要区分鸟和机,
PyTorch帮你造神器,
CNN是核心科技,
数据加载要仔细,
训练就像教徒弟,
测试别忘eval(),
保存模型记心里,
下次直接能装逼!
现在你的AI已经是个合格的"天空观察员"了!试着用你自己的图片测试它吧(可能需要调整大小到32x32)~ 🦅🛫