AI技术实战:从零搭建图像分类系统全流程详解
AI技术实战:从零搭建图像分类系统全流程详解
人工智能学习 https://www.captainbed.cn/ccc
前言
本文将以图像分类任务为切入点,手把手教你完成AI模型从数据准备到工业部署的全链路开发。通过一个完整的Kaggle猫狗分类项目(代码兼容PyTorch/TensorFlow),覆盖以下核心技能:
- 数据清洗与增强的工程化实现
- 模型构建与训练技巧
- 模型压缩与TensorRT部署优化
- 可视化监控与性能调优
所有代码均提供可运行的Colab链接,建议边阅读边实践。
目录
- 1. 环境搭建与数据准备
- 1.1 本地/云端开发环境配置
- 1.2 数据爬取与清洗脚本开发
- 1.3 自动化数据清洗
- 2. 图像分类模型实战
- 2.1 手写CNN模型构建(带可运行代码)
- 2.2 迁移学习Fine-tuning技巧
- 2.3 训练过程可视化监控
- 3. 模型优化与部署
- 3.1 模型剪枝与量化压缩
- 3.2 ONNX格式转换与TensorRT加速
- 3.3 RESTful API服务封装
- 4. 工业级增强技巧
- 4.1 解决类别不平衡问题
- 4.2 应对小样本学习的策略
- 4.3 模型热更新方案
1. 环境搭建与数据准备
1.1 开发环境配置(PyTorch示例)
# Create and activate a dedicated conda environment (Python 3.8) for the tutorial
conda create -n ai_tutorial python=3.8
conda activate ai_tutorial
# Install the core dependencies: the CUDA 11.3 builds of torch/torchvision come
# from the PyTorch wheel index given via --extra-index-url; the remaining
# packages are installed from the default index
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
pip install opencv-python albumentations pandas
1.2 数据爬取实战
# 使用Bing图片搜索API(Image Search v7)批量获取数据
import requests
def download_images(keyword, count=100):
    """Download Bing image-search results for ``keyword`` into ``dataset/``.

    Sends one query to the Bing Image Search v7 endpoint and stores every
    result as ``dataset/{keyword}_{idx}.jpg``, where ``idx`` is the position
    of the result in the response. Images that cannot be fetched are
    reported and skipped so that a single dead link does not abort the batch.

    Args:
        keyword: Search term; also used as the file-name prefix.
        count: Number of results requested from the API.

    Raises:
        requests.HTTPError: If the search request itself is rejected, for
            example because the subscription key is invalid.
    """
    import os

    # Without a timeout a stalled server would block the script forever.
    timeout_seconds = 10
    headers = {'Ocp-Apim-Subscription-Key': 'YOUR_API_KEY'}
    params = {'q': keyword, 'count': count}
    response = requests.get('https://api.bing.microsoft.com/v7.0/images/search',
                            headers=headers, params=params,
                            timeout=timeout_seconds)
    # Fail loudly on 4xx/5xx instead of hitting a KeyError on the error body.
    response.raise_for_status()
    # open() below does not create missing directories.
    os.makedirs('dataset', exist_ok=True)
    for idx, img in enumerate(response.json().get('value', [])):
        try:
            img_response = requests.get(img['contentUrl'],
                                        timeout=timeout_seconds)
            img_response.raise_for_status()
        except requests.RequestException as exc:
            print(f"skip {img['contentUrl']}: {exc}")
            continue
        with open(f'dataset/{keyword}_{idx}.jpg', 'wb') as f:
            f.write(img_response.content)
# Run the download once per class keyword
for keyword in ('cat', 'dog'):
    download_images(keyword)
1.3 自动化数据清洗
# 使用OpenCV过滤损坏图片
import cv2
import os
def clean_dataset(folder):
    """Delete files in ``folder`` that are not usable images.

    A file is removed when its extension is not .jpg/.jpeg/.png
    (case-insensitive) or when OpenCV cannot decode it. Sub-directories are
    left untouched.

    Args:
        folder: Directory whose direct children are checked.
    """
    valid_extensions = {'.jpg', '.jpeg', '.png'}
    for filename in os.listdir(folder):
        filepath = os.path.join(folder, filename)
        # os.remove() cannot delete directories, so skip anything that is
        # not a regular file instead of crashing on it.
        if not os.path.isfile(filepath):
            continue
        # Check the extension first: it is cheap and avoids decoding files
        # that are going to be removed anyway.
        if os.path.splitext(filename)[1].lower() in valid_extensions:
            try:
                img = cv2.imread(filepath)
            except cv2.error:
                img = None
            if img is not None and img.size > 0:
                continue
            print(f"删除损坏文件: {filename}")
        try:
            os.remove(filepath)
        except OSError as exc:
            # Report and keep going; one undeletable file should not stop
            # the cleanup of the rest of the folder.
            print(f"无法删除 {filename}: {exc}")
# Run the cleanup on the training image folder
clean_dataset('dataset/train')
2. 图像分类模型实战
2.1 自定义CNN模型(PyTorch实现)
import torch.nn as nn
class SimpleCNN(nn.Module):
    """Small three-stage convolutional classifier for 3-channel images.

    Each stage is Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2), so the
    spatial size is halved three times. The feature map is then resized to
    ``pooled_size x pooled_size`` before the fully connected head, which
    makes the network independent of the input resolution.

    Args:
        num_classes: Number of output logits.
        pooled_size: Side length of the feature map fed to the classifier.
            The default of 28 is what a 224x224 input produces after the
            three poolings, so for that input the resize step changes
            nothing.
    """

    def __init__(self, num_classes=2, pooled_size=28):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),  # 3 input channels
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Kept outside `features` so the layer indices inside `features`
        # stay the same; the layer has no parameters, so the state-dict keys
        # are unaffected as well.
        self.pool = nn.AdaptiveAvgPool2d(pooled_size)
        self.classifier = nn.Sequential(
            nn.Linear(128 * pooled_size * pooled_size, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return the class logits for a batch of images."""
        x = self.features(x)
        x = self.pool(x)
        # Flatten everything except the batch dimension.
        x = x.flatten(1)
        return self.classifier(x)
2.2 迁移学习实战(ResNet50微调)
from torchvision import models
# Load ResNet-50 with ImageNet weights. The `weights=` enum replaces the
# `pretrained=True` flag, which torchvision deprecated in 0.13;
# IMAGENET1K_V1 is the checkpoint that flag used to load.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
# Freeze the whole pretrained backbone first ...
for param in model.parameters():
    param.requires_grad = False
# ... then re-enable training for the last residual stage only.
for param in model.layer4.parameters():
    param.requires_grad = True
# Replace the final fully connected layer AFTER freezing: newly created
# layers default to requires_grad=True, so the new head stays trainable.
# Building it before the freeze loop would freeze the head as well and
# leave it at its random initialisation.
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(num_ftrs, 256),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(256, 2),
)
2.3 训练过程可视化(TensorBoard集成)
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
for epoch in range(epochs):
    # training code ...
    writer.add_scalar('Loss/train', loss.item(), epoch)
    writer.add_scalar('Accuracy/train', acc, epoch)
    # Log feature maps every 10 epochs
    if epoch % 10 == 0:
        # NOTE(review): assumes `model` exposes a `features` Sequential the
        # way SimpleCNN does - confirm which model is being trained here.
        # detach() keeps the logged tensor out of the autograd graph.
        feature_maps = model.features[0](images[:4]).detach()
        # add_images() accepts 1-, 3- or 4-channel images only, while a conv
        # layer emits many channels. Log every channel as its own grayscale
        # image: (N, C, H, W) -> (N * C, 1, H, W).
        feature_maps = feature_maps.flatten(0, 1).unsqueeze(1)
        # Rescale to [0, 1] so the activations are displayable.
        low, high = feature_maps.min(), feature_maps.max()
        feature_maps = (feature_maps - low) / (high - low).clamp_min(1e-8)
        writer.add_images('Feature Maps', feature_maps, epoch)
# Flush pending events and close the event file.
writer.close()
3. 模型优化与部署
3.1 模型剪枝实战
import torch.nn.utils.prune as prune
# Global L1 unstructured pruning over the conv layers at indices 0 and 3 of
# `model.features`
conv_indices = (0, 3)
parameters_to_prune = tuple(
    (model.features[idx], 'weight') for idx in conv_indices
)
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,  # prune 20% of the weights
)
3.2 TensorRT加速部署
import subprocess

import torch

# Export the model to ONNX
torch.onnx.export(model,
                  dummy_input,
                  "model.onnx",
                  opset_version=11)
# Build a TensorRT engine from the ONNX file with trtexec.
# The command is passed as an argument list (no shell involved) and
# check=True raises CalledProcessError when trtexec fails, so a broken
# conversion is not silently ignored.
# NOTE(review): newer TensorRT releases replace --workspace with
# --memPoolSize=workspace:<MiB> - confirm against the installed version.
trt_cmd = [
    "trtexec",
    "--onnx=model.onnx",
    "--saveEngine=model.trt",
    "--fp16",
    "--workspace=2048",
]
subprocess.run(trt_cmd, check=True)
3.3 封装Flask API服务
from flask import Flask, request
import trt_inference # 自定义TRT推理模块
app = Flask(__name__)


@app.route('/predict', methods=['POST'])
def predict():
    """Classify the image uploaded in the ``image`` multipart form field.

    Returns:
        ``{'class_id': <int>}`` on success, or ``{'error': <message>}`` with
        HTTP status 400 when the upload is missing or empty.
    """
    file = request.files.get('image')
    if file is None:
        # Answer with JSON like the success path instead of the default
        # HTML error page.
        return {'error': "missing 'image' file field"}, 400
    payload = file.read()
    if not payload:
        return {'error': 'uploaded image is empty'}, 400
    img = preprocess(payload)
    output = trt_inference.run(img)
    return {'class_id': int(output.argmax())}


if __name__ == '__main__':
    # Listens on all interfaces, port 5000.
    app.run(host='0.0.0.0', port=5000)
4. 工业级增强技巧
4.1 类别不平衡解决方案
# 使用加权采样器
from torch.utils.data import WeightedRandomSampler
# Weight each class by the inverse of its sample count
class_counts = [num_cat, num_dog]
weights = torch.tensor(class_counts, dtype=torch.float).reciprocal()
# One weight per sample, looked up through the sample's label
# (assumes `labels` holds integer class indices - TODO confirm)
samples_weights = weights[labels]
sampler = WeightedRandomSampler(
    samples_weights,
    len(samples_weights),
    replacement=True,
)
4.2 小样本学习方案
# 使用MixUp数据增强
def mixup_data(x, y, alpha=0.2):
    """Blend every sample of a batch with a randomly chosen partner (MixUp).

    Args:
        x: Input batch; the first dimension is the batch dimension.
        y: Labels aligned with ``x`` along the first dimension.
        alpha: Parameter of the symmetric Beta(alpha, alpha) distribution
            the mixing coefficient is drawn from. Values <= 0 disable
            mixing (the coefficient is fixed to 1).

    Returns:
        A tuple ``(mixed_x, y_a, y_b, lam)``: the blended inputs, the
        original labels, the partners' labels and the mixing coefficient.
    """
    import numpy as np
    import torch

    # np.random.beta() rejects non-positive parameters, so treat alpha <= 0
    # as "no mixing" instead of raising.
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0
    batch_size = x.size(0)
    # Create the permutation on the same device as the data.
    index = torch.randperm(batch_size, device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index.to(y.device)]
    return mixed_x, y_a, y_b, lam
# MixUp loss: combine the losses against both label sets, weighted by the
# mixing coefficient
criterion = nn.CrossEntropyLoss()
loss_a = criterion(output, y_a)
loss_b = criterion(output, y_b)
loss = lam * loss_a + (1 - lam) * loss_b