In the previous article we explored federated learning and privacy-preserving techniques. This article takes a deep dive into Neural Architecture Search (NAS), an automated machine learning technique that designs high-performing neural network architectures without manual effort. We implement the gradient-based DARTS method in PyTorch and validate it on the CIFAR-10 dataset.
I. Neural Architecture Search Fundamentals
Neural architecture search is one of the core techniques of AutoML; its goal is to automate the process of designing neural networks.
1. Core Components of NAS
| Component | Description | Typical Implementations |
|---|---|---|
| Search space | Defines the set of candidate architectures | Cell-based structures, macro architectures |
| Search strategy | How the search space is explored | Reinforcement learning, evolutionary algorithms, gradient-based optimization |
| Performance estimation | How architecture quality is assessed | Proxy metrics, weight sharing |
2. Comparison of Mainstream NAS Methods

```python
from enum import Enum

class NASMethod(Enum):
    RL_BASED = "reinforcement learning"       # Google's early approach
    EVOLUTIONARY = "evolutionary algorithms"  # proposed by Google Brain
    GRADIENT_BASED = "gradient-based"         # exemplified by DARTS
    ONESHOT = "weight sharing"                # ENAS, ProxylessNAS
```
3. The Math Behind DARTS
DARTS (Differentiable Architecture Search) relaxes the discrete architecture search into a continuous optimization problem:
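On each edge $(i, j)$ of the cell DAG, the discrete choice among candidate operations $\mathcal{O}$ is relaxed into a softmax-weighted mixture (this is the standard formulation from the DARTS paper):

$$\bar{o}^{(i,j)}(x) = \sum_{o \in \mathcal{O}} \frac{\exp\big(\alpha_o^{(i,j)}\big)}{\sum_{o' \in \mathcal{O}} \exp\big(\alpha_{o'}^{(i,j)}\big)}\, o(x)$$

The architecture parameters $\alpha$ and the network weights $w$ are then trained jointly as a bilevel optimization problem:

$$\min_{\alpha}\ \mathcal{L}_{\text{val}}\big(w^{*}(\alpha), \alpha\big) \quad \text{s.t.} \quad w^{*}(\alpha) = \arg\min_{w}\ \mathcal{L}_{\text{train}}(w, \alpha)$$

After the search, the discrete architecture is recovered by keeping the operation with the largest $\alpha_o^{(i,j)}$ on each edge.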
II. DARTS in Practice: CIFAR-10 Image Classification
1. Environment Setup

```bash
pip install torch torchvision matplotlib graphviz
```
2. Implementing Differentiable Architecture Search
2.1 Defining the Search Space
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
from matplotlib import pyplot as plt
import copy
from graphviz import Digraph

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Candidate operations (the search space on each edge)
OPS = {
    'none':         lambda C, stride: Zero(stride),
    'skip_connect': lambda C, stride: Identity() if stride == 1 else FactorizedReduce(C, C),
    'conv_3x3':     lambda C, stride: ConvBNReLU(C, C, 3, stride, 1),
    'conv_5x5':     lambda C, stride: ConvBNReLU(C, C, 5, stride, 2),
    'dil_conv_3x3': lambda C, stride: DilConv(C, C, 3, stride, 2, 2),
    'dil_conv_5x5': lambda C, stride: DilConv(C, C, 5, stride, 4, 2),
    'max_pool_3x3': lambda C, stride: PoolBN('max', C, 3, stride, 1),
    'avg_pool_3x3': lambda C, stride: PoolBN('avg', C, 3, stride, 1),
}

# Basic operation modules
class ConvBNReLU(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride, padding):
        super().__init__()
        self.op = nn.Sequential(
            nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(C_out),
            nn.ReLU(inplace=False),
        )

    def forward(self, x):
        return self.op(x)


class DilConv(nn.Module):
    """Depthwise-separable dilated convolution."""
    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
        super().__init__()
        self.op = nn.Sequential(
            nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation,
                      groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, 1, padding=0, bias=False),
            nn.BatchNorm2d(C_out),
            nn.ReLU(inplace=False),
        )

    def forward(self, x):
        return self.op(x)


class PoolBN(nn.Module):
    """Pooling followed by batch norm."""
    def __init__(self, pool_type, C, kernel_size, stride, padding):
        super().__init__()
        if pool_type == 'max':
            self.pool = nn.MaxPool2d(kernel_size, stride, padding)
        elif pool_type == 'avg':
            self.pool = nn.AvgPool2d(kernel_size, stride, padding)
        else:
            raise ValueError(f"unknown pool type: {pool_type}")
        self.bn = nn.BatchNorm2d(C)

    def forward(self, x):
        return self.bn(self.pool(x))


class Identity(nn.Module):
    def forward(self, x):
        return x


class Zero(nn.Module):
    """The 'none' operation: outputs zeros (subsampled when stride > 1)."""
    def __init__(self, stride):
        super().__init__()
        self.stride = stride

    def forward(self, x):
        if self.stride == 1:
            return x.mul(0.)
        return x[:, :, ::self.stride, ::self.stride].mul(0.)


class FactorizedReduce(nn.Module):
    """Halve the spatial resolution with two offset 1x1 convs, then concat."""
    def __init__(self, C_in, C_out):
        super().__init__()
        self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out)

    def forward(self, x):
        return self.bn(torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1))
```
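As a quick sanity check (a hypothetical snippet, not part of the original pipeline), every candidate operation should map an input of shape `(N, C, H, W)` to `(N, C, H/stride, W/stride)`:

```python
# Hypothetical sanity check: verify output shapes for every op and stride
x = torch.randn(2, 16, 32, 32)
for name, op_fn in OPS.items():
    for stride in (1, 2):
        y = op_fn(16, stride)(x)
        assert y.shape == (2, 16, 32 // stride, 32 // stride), (name, stride, y.shape)
print("All candidate operations produce the expected shapes.")
```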
2.2 Implementing the Differentiable Cell
```python
class MixedOp(nn.Module):
    """Weighted mixture of all candidate operations on one edge."""
    def __init__(self, C, stride):
        super().__init__()
        self._ops = nn.ModuleList()
        for primitive in OPS.keys():
            self._ops.append(OPS[primitive](C, stride))

    def forward(self, x, weights):
        return sum(w * op(x) for w, op in zip(weights, self._ops))


class Cell(nn.Module):
    """Differentiable cell: a small DAG whose edges are MixedOps."""
    def __init__(self, steps, multiplier, C_prev_prev, C_prev, C,
                 reduction, reduction_prev):
        super().__init__()
        self.reduction = reduction
        self.steps = steps
        self.multiplier = multiplier

        # Preprocess the two input nodes to a uniform channel count
        if reduction_prev:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C)
        else:
            self.preprocess0 = ConvBNReLU(C_prev_prev, C, 1, 1, 0)
        self.preprocess1 = ConvBNReLU(C_prev, C, 1, 1, 0)

        # Build the DAG: node i receives an edge from every earlier node
        self._ops = nn.ModuleList()
        for i in range(self.steps):
            for j in range(2 + i):
                # In reduction cells, edges from the two inputs use stride 2
                stride = 2 if reduction and j < 2 else 1
                self._ops.append(MixedOp(C, stride))

    def forward(self, s0, s1, weights):
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)

        states = [s0, s1]
        offset = 0
        for i in range(self.steps):
            s = sum(self._ops[offset + j](h, weights[offset + j])
                    for j, h in enumerate(states))
            offset += len(states)
            states.append(s)

        # Output: concatenation of the last `multiplier` intermediate nodes
        return torch.cat(states[-self.multiplier:], dim=1)
```
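To make the `weights` argument concrete: with `steps=4` the cell DAG has 2 + 3 + 4 + 5 = 14 edges, so `weights` must have shape `(14, len(OPS))`. A minimal smoke test (hypothetical values, for illustration only):

```python
# Hypothetical smoke test for one normal cell
cell = Cell(steps=4, multiplier=4, C_prev_prev=48, C_prev=48, C=16,
            reduction=False, reduction_prev=False)
s0 = s1 = torch.randn(2, 48, 32, 32)
num_edges = sum(2 + i for i in range(4))  # 14 edges
weights = F.softmax(torch.randn(num_edges, len(OPS)), dim=-1)
out = cell(s0, s1, weights)
print(out.shape)  # torch.Size([2, 64, 32, 32]) -- multiplier * C output channels
```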
2.3 The Complete Search Network
```python
class Network(nn.Module):
    """Differentiable architecture search network."""
    def __init__(self, C, num_classes, layers, criterion,
                 steps=4, multiplier=4, stem_multiplier=3):
        super().__init__()
        self._C = C
        self._num_classes = num_classes
        self._layers = layers
        self._criterion = criterion
        self._steps = steps
        self._multiplier = multiplier

        C_curr = stem_multiplier * C
        self.stem = nn.Sequential(
            nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
            nn.BatchNorm2d(C_curr),
        )

        # Stack cells; double the channels at the two reduction cells
        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
        self.cells = nn.ModuleList()
        reduction_prev = False
        for i in range(layers):
            if i in [layers // 3, 2 * layers // 3]:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr,
                        reduction, reduction_prev)
            reduction_prev = reduction
            self.cells.append(cell)
            C_prev_prev, C_prev = C_prev, multiplier * C_curr

        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)

        # Architecture parameters: one row per edge, one column per op
        k = sum(2 + i for i in range(steps))
        num_ops = len(OPS)
        self._alphas = nn.Parameter(1e-3 * torch.randn(k, num_ops))  # small random init
        # (The architecture optimizer lives in the DARTS class below.)

    def forward(self, x):
        s0 = s1 = self.stem(x)
        # Softmax over ops turns alphas into mixture weights (shared by all cells)
        weights = F.softmax(self._alphas, dim=-1)
        for cell in self.cells:
            s0, s1 = s1, cell(s0, s1, weights)
        out = self.global_pooling(s1)
        logits = self.classifier(out.view(out.size(0), -1))
        return logits

    def _loss(self, input, target):
        logits = self(input)
        # Regularization on the architecture parameters
        # (exp(-alpha) penalizes small alphas; note this is not a true L1 term)
        reg_loss = 0.01 * torch.sum(torch.exp(-self._alphas))
        return self._criterion(logits, target) + reg_loss

    def arch_parameters(self):
        return [self._alphas]

    def genotype(self):
        """Derive a discrete architecture from the architecture parameters."""
        def _parse(weights):
            gene = []
            start = 0
            for i in range(self._steps):
                end = start + i + 2
                W = weights[start:end].copy()
                edges = []
                for j in range(2 + i):
                    # Keep the single highest-weighted op on each edge
                    k_best = None
                    for k in range(len(W[j])):
                        if k_best is None or W[j][k] > W[j][k_best]:
                            k_best = k
                    edges.append((list(OPS.keys())[k_best], j))
                gene.append(edges)
                start = end
            return gene

        return _parse(F.softmax(self._alphas, dim=-1).data.cpu().numpy())

    def plot_genotype(self, filename):
        """Visualize the genotype as a graph."""
        dot = Digraph(format='png')
        for i, edges in enumerate(self.genotype()):
            for op, j in edges:
                dot.edge(str(j), str(i + 2), label=op)
        dot.node("0", fillcolor='lightblue', style='filled')
        dot.node("1", fillcolor='lightblue', style='filled')
        dot.render(filename, view=True)  # view=True tries to open the rendered image
```
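To see what `_parse` selects, here is a toy example (hypothetical weights): for every edge it keeps the single highest-scoring operation. Note that, unlike the original DARTS (which drops `'none'` and keeps only the top-2 incoming edges per node), this simplified version retains all edges, which is why `'none'` can appear in the final genotype shown later.

```python
# Hypothetical illustration of the argmax-based parsing in genotype()
ops = list(OPS.keys())
toy = np.zeros((2, len(ops)))  # weights for the 2 edges feeding the first node
toy[0, ops.index('conv_3x3')] = 5.0
toy[1, ops.index('skip_connect')] = 5.0
print([(ops[row.argmax()], j) for j, row in enumerate(toy)])
# -> [('conv_3x3', 0), ('skip_connect', 1)]
```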
3. Implementing the Search Algorithm
```python
class DARTS:
    def __init__(self, model, train_loader, val_loader, epochs=50):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.epochs = epochs

        # Weight optimizer with cosine-annealed learning rate
        self.optimizer = torch.optim.SGD(model.parameters(), lr=0.025,
                                         momentum=0.9, weight_decay=3e-4)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, epochs, eta_min=0.001)

        # Architecture-parameter optimizer
        self.arch_optimizer = torch.optim.Adam(model.arch_parameters(),
                                               lr=3e-4, betas=(0.5, 0.999))

    def _train(self):
        self.model.train()
        train_loss = 0
        correct = 0
        total = 0

        for inputs, targets in self.train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            # Update architecture parameters
            # (simplified first-order scheme: the same training batch is used for
            # both updates; the original DARTS uses a validation batch here)
            self.arch_optimizer.zero_grad()
            arch_loss = self.model._loss(inputs, targets)
            arch_loss.backward()
            self.arch_optimizer.step()

            # Update model weights
            self.optimizer.zero_grad()
            loss = self.model._loss(inputs, targets)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
            self.optimizer.step()

            train_loss += loss.item()
            _, predicted = self.model(inputs).max(1)  # extra forward pass for accuracy
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        return train_loss / len(self.train_loader), 100. * correct / total

    def _validate(self):
        self.model.eval()
        val_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, targets in self.val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = self.model(inputs)
                loss = self.model._loss(inputs, targets)

                val_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        return val_loss / len(self.val_loader), 100. * correct / total

    def search(self):
        best_acc = 0
        best_genotype = self.model.genotype()  # fallback if no epoch improves
        history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

        for epoch in range(self.epochs):
            train_loss, train_acc = self._train()
            val_loss, val_acc = self._validate()
            self.scheduler.step()

            history['train_loss'].append(train_loss)
            history['val_loss'].append(val_loss)
            history['train_acc'].append(train_acc)
            history['val_acc'].append(val_acc)

            if val_acc > best_acc:
                best_acc = val_acc
                best_genotype = copy.deepcopy(self.model.genotype())

            print(f"Epoch: {epoch + 1}/{self.epochs} | "
                  f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
                  f"Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%")

        return best_genotype, history
```
4. The Complete Training Pipeline
```python
# Data preparation
def prepare_data(batch_size=64, val_ratio=0.1):
    # Note: this augmentation transform is applied to both splits
    # (see the sketch below for a deterministic validation transform)
    transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    full_dataset = datasets.CIFAR10(root='./data', train=True,
                                    download=True, transform=transform)
    val_size = int(val_ratio * len(full_dataset))
    train_size = len(full_dataset) - val_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size,
                            shuffle=False, num_workers=2)
    return train_loader, val_loader
```
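As noted in the comment above, validation batches here also go through random cropping and flipping. A cleaner variant (a sketch, not in the original code) gives the validation split its own deterministic transform while reusing the split indices:

```python
# Sketch: deterministic preprocessing for the validation split
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
val_base = datasets.CIFAR10(root='./data', train=True, download=True,
                            transform=val_transform)
# random_split returns Subset objects; reuse their indices to stay aligned
val_dataset = torch.utils.data.Subset(val_base, val_dataset.indices)
```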
```python
# Main entry point
def main():
    train_loader, val_loader = prepare_data()

    # Initialize the model
    criterion = nn.CrossEntropyLoss().to(device)
    model = Network(C=16, num_classes=10, layers=8, criterion=criterion)

    # Run the search
    darts = DARTS(model, train_loader, val_loader, epochs=50)
    best_genotype, history = darts.search()

    # Report and save the results
    print("Best Genotype:", best_genotype)
    model.plot_genotype("best_architecture")

    # Plot the training curves
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train')
    plt.plot(history['val_loss'], label='Validation')
    plt.title('Loss Curve')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='Train')
    plt.plot(history['val_acc'], label='Validation')
    plt.title('Accuracy Curve')
    plt.legend()

    plt.savefig('search_progress.png')
    plt.show()

if __name__ == "__main__":
    main()
```
Output:
Using device: cuda
Files already downloaded and verified
Epoch: 1/50 | Train Loss: 2.3124 | Val Loss: 2.1910 | Train Acc: 46.93% | Val Acc: 55.88%
Epoch: 2/50 | Train Loss: 1.3163 | Val Loss: 1.1209 | Train Acc: 67.39% | Val Acc: 68.60%
Epoch: 3/50 | Train Loss: 1.0027 | Val Loss: 0.8909 | Train Acc: 75.54% | Val Acc: 75.54%
Epoch: 4/50 | Train Loss: 0.8162 | Val Loss: 0.7608 | Train Acc: 80.52% | Val Acc: 79.12%
Epoch: 5/50 | Train Loss: 0.7098 | Val Loss: 0.7313 | Train Acc: 83.37% | Val Acc: 79.44%
Epoch: 6/50 | Train Loss: 0.6268 | Val Loss: 0.6517 | Train Acc: 85.73% | Val Acc: 82.10%
Epoch: 7/50 | Train Loss: 0.5662 | Val Loss: 0.6084 | Train Acc: 87.38% | Val Acc: 83.04%
Epoch: 8/50 | Train Loss: 0.5164 | Val Loss: 0.5669 | Train Acc: 88.96% | Val Acc: 84.56%
Epoch: 9/50 | Train Loss: 0.4790 | Val Loss: 0.5206 | Train Acc: 90.14% | Val Acc: 85.66%
Epoch: 10/50 | Train Loss: 0.4447 | Val Loss: 0.5097 | Train Acc: 91.23% | Val Acc: 85.64%
Epoch: 11/50 | Train Loss: 0.4135 | Val Loss: 0.5081 | Train Acc: 92.16% | Val Acc: 85.78%
Epoch: 12/50 | Train Loss: 0.3887 | Val Loss: 0.5135 | Train Acc: 92.81% | Val Acc: 85.76%
Epoch: 13/50 | Train Loss: 0.3687 | Val Loss: 0.4952 | Train Acc: 93.40% | Val Acc: 86.02%
Epoch: 14/50 | Train Loss: 0.3490 | Val Loss: 0.4915 | Train Acc: 94.02% | Val Acc: 86.72%
Epoch: 15/50 | Train Loss: 0.3323 | Val Loss: 0.5027 | Train Acc: 94.69% | Val Acc: 86.20%
Epoch: 16/50 | Train Loss: 0.3109 | Val Loss: 0.4722 | Train Acc: 95.34% | Val Acc: 87.44%
Epoch: 17/50 | Train Loss: 0.2952 | Val Loss: 0.4687 | Train Acc: 95.74% | Val Acc: 87.14%
Epoch: 18/50 | Train Loss: 0.2780 | Val Loss: 0.4605 | Train Acc: 96.38% | Val Acc: 87.92%
Epoch: 19/50 | Train Loss: 0.2591 | Val Loss: 0.4469 | Train Acc: 96.82% | Val Acc: 88.26%
Epoch: 20/50 | Train Loss: 0.2474 | Val Loss: 0.4479 | Train Acc: 97.22% | Val Acc: 88.04%
Epoch: 21/50 | Train Loss: 0.2371 | Val Loss: 0.4765 | Train Acc: 97.46% | Val Acc: 87.90%
Epoch: 22/50 | Train Loss: 0.2257 | Val Loss: 0.4213 | Train Acc: 97.78% | Val Acc: 89.00%
Epoch: 23/50 | Train Loss: 0.2100 | Val Loss: 0.4625 | Train Acc: 98.21% | Val Acc: 88.38%
Epoch: 24/50 | Train Loss: 0.2045 | Val Loss: 0.4474 | Train Acc: 98.24% | Val Acc: 88.74%
Epoch: 25/50 | Train Loss: 0.1859 | Val Loss: 0.4511 | Train Acc: 98.60% | Val Acc: 88.48%
Epoch: 26/50 | Train Loss: 0.1790 | Val Loss: 0.4307 | Train Acc: 98.81% | Val Acc: 89.54%
Epoch: 27/50 | Train Loss: 0.1644 | Val Loss: 0.4390 | Train Acc: 99.08% | Val Acc: 89.80%
Epoch: 28/50 | Train Loss: 0.1541 | Val Loss: 0.4344 | Train Acc: 99.17% | Val Acc: 89.60%
Epoch: 29/50 | Train Loss: 0.1449 | Val Loss: 0.4176 | Train Acc: 99.32% | Val Acc: 90.34%
Epoch: 30/50 | Train Loss: 0.1352 | Val Loss: 0.3915 | Train Acc: 99.47% | Val Acc: 90.64%
Epoch: 31/50 | Train Loss: 0.1261 | Val Loss: 0.4300 | Train Acc: 99.58% | Val Acc: 90.20%
Epoch: 32/50 | Train Loss: 0.1183 | Val Loss: 0.3936 | Train Acc: 99.67% | Val Acc: 91.10%
Epoch: 33/50 | Train Loss: 0.1056 | Val Loss: 0.3889 | Train Acc: 99.77% | Val Acc: 91.00%
Epoch: 34/50 | Train Loss: 0.0990 | Val Loss: 0.3937 | Train Acc: 99.81% | Val Acc: 91.00%
Epoch: 35/50 | Train Loss: 0.0949 | Val Loss: 0.3694 | Train Acc: 99.77% | Val Acc: 92.16%
Epoch: 36/50 | Train Loss: 0.0862 | Val Loss: 0.3788 | Train Acc: 99.89% | Val Acc: 91.72%
Epoch: 37/50 | Train Loss: 0.0815 | Val Loss: 0.3893 | Train Acc: 99.90% | Val Acc: 91.52%
Epoch: 38/50 | Train Loss: 0.0768 | Val Loss: 0.3847 | Train Acc: 99.92% | Val Acc: 91.92%
Epoch: 39/50 | Train Loss: 0.0729 | Val Loss: 0.3602 | Train Acc: 99.95% | Val Acc: 91.90%
Epoch: 40/50 | Train Loss: 0.0689 | Val Loss: 0.3846 | Train Acc: 99.94% | Val Acc: 91.68%
Epoch: 41/50 | Train Loss: 0.0656 | Val Loss: 0.3361 | Train Acc: 99.95% | Val Acc: 92.62%
Epoch: 42/50 | Train Loss: 0.0625 | Val Loss: 0.3563 | Train Acc: 99.96% | Val Acc: 92.18%
Epoch: 43/50 | Train Loss: 0.0598 | Val Loss: 0.3475 | Train Acc: 99.96% | Val Acc: 92.28%
Epoch: 44/50 | Train Loss: 0.0579 | Val Loss: 0.3468 | Train Acc: 99.94% | Val Acc: 92.22%
Epoch: 45/50 | Train Loss: 0.0561 | Val Loss: 0.3680 | Train Acc: 99.94% | Val Acc: 91.64%
Epoch: 46/50 | Train Loss: 0.0532 | Val Loss: 0.3334 | Train Acc: 99.95% | Val Acc: 92.40%
Epoch: 47/50 | Train Loss: 0.0509 | Val Loss: 0.3381 | Train Acc: 99.96% | Val Acc: 92.50%
Epoch: 48/50 | Train Loss: 0.0493 | Val Loss: 0.3517 | Train Acc: 99.95% | Val Acc: 92.16%
Epoch: 49/50 | Train Loss: 0.0474 | Val Loss: 0.3305 | Train Acc: 99.95% | Val Acc: 92.30%
Epoch: 50/50 | Train Loss: 0.0458 | Val Loss: 0.3305 | Train Acc: 99.94% | Val Acc: 92.84%
Best Genotype: [[('conv_5x5', 0), ('conv_5x5', 1)], [('none', 0), ('conv_5x5', 1), ('conv_5x5', 2)], [('conv_5x5', 0), ('conv_5x5', 1), ('conv_5x5', 2), ('conv_5x5', 3)], [('conv_5x5', 0), ('conv_5x5', 1), ('conv_5x5', 2), ('conv_5x5', 3), ('conv_5x5', 4)]]
perl: warning: Setting locale failed.
perl: warning: Please check that your locale settings:
    LANGUAGE = (unset),
    LC_ALL = (unset),
    LC_CTYPE = "C.UTF-8",
    LANG = "en_US.UTF-8"
    are supported and installed on your system.
perl: warning: Falling back to the standard locale ("C").
Error: no "view" mailcap rules found for type "image/png"
/usr/bin/xdg-open: 869: www-browser: not found
/usr/bin/xdg-open: 869: links2: not found
/usr/bin/xdg-open: 869: elinks: not found
/usr/bin/xdg-open: 869: links: not found
/usr/bin/xdg-open: 869: lynx: not found
/usr/bin/xdg-open: 869: w3m: not found
xdg-open: no method available for opening 'best_architecture.png'
The message `Error: no "view" mailcap rules found for type "image/png"` appears because the system has no image viewer (such as a browser) installed.
It does not affect the results: the file best_architecture.png is still generated; it just cannot be opened in a preview window automatically.
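On a headless machine the simplest fix is to render without trying to open a viewer, i.e. pass `view=False` in `plot_genotype` (or install an image viewer/browser):

```python
# In plot_genotype: write the PNG without auto-opening it
dot.render(filename, view=False)  # best_architecture.png is still produced
```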
III. Advanced Topics
1. Search Space Design Tips
```python
class MacroSearchSpace:
    """Example macro-architecture search space."""
    def __init__(self):
        self.resolutions = [224, 192, 160, 128]  # input resolutions
        self.depths = [3, 4, 5, 6]               # network depths
        self.widths = [32, 64, 96, 128]          # initial channel counts
        self.ops = list(OPS.keys())              # operation types
```
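A macro search space like this is usually paired with a sampler. A minimal random sampler over the space above might look like this (hypothetical helper, for illustration):

```python
import random

def sample_macro_config(space: MacroSearchSpace) -> dict:
    """Draw one random architecture configuration from the macro space."""
    depth = random.choice(space.depths)
    return {
        'resolution': random.choice(space.resolutions),
        'depth': depth,
        'width': random.choice(space.widths),
        'ops': [random.choice(space.ops) for _ in range(depth)],  # one op per layer
    }

print(sample_macro_config(MacroSearchSpace()))
```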
2. Multi-Objective NAS
```python
class MultiObjectiveNAS:
    """Optimize accuracy and latency jointly."""
    def __init__(self, model, latency_predictor):
        self.model = model
        self.latency_predictor = latency_predictor

    def evaluate(self, genotype):
        # Predict latency
        latency = self.latency_predictor(genotype)
        # Evaluate accuracy (evaluate_accuracy is assumed defined elsewhere)
        accuracy = evaluate_accuracy(genotype)
        return {
            'accuracy': accuracy,
            'latency': latency,
            'score': accuracy * (latency ** -0.07),  # trade-off exponent
        }
```
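The score is a weighted product in the style of MnasNet's reward (accuracy times latency raised to a negative exponent, here -0.07). Usage with a stub latency predictor might look like this (both helpers below are hypothetical placeholders):

```python
# Hypothetical stand-ins for demonstration only
def dummy_latency_predictor(genotype):
    # Toy cost model: one unit of latency per edge in the genotype
    return float(sum(len(edges) for edges in genotype))

def evaluate_accuracy(genotype):
    # Placeholder: in practice, train the architecture (or score it with
    # shared weights) and measure validation accuracy
    return 0.90

nas = MultiObjectiveNAS(model=None, latency_predictor=dummy_latency_predictor)
print(nas.evaluate(best_genotype))  # best_genotype from the search above
```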
3. Practical Recommendations

| Scenario | Recommended Method | Rationale |
|---|---|---|
| Mobile deployment | ProxylessNAS | Directly optimizes metrics on the target device |
| Research exploration | DARTS | Flexible and extensible |
| Industrial applications | ENAS | High search efficiency |
IV. Summary and Outlook
This article implemented a DARTS-based neural architecture search system. Highlights include:
- A complete differentiable architecture search: mixed operations, cell structure, and bilevel optimization
- Visualization of the search process: graphical rendering of architecture genotypes
- Practical training techniques: cosine-annealed learning rate and other optimization strategies
In the next article we will explore graph generative models and molecular design, showing how deep generative models can be used to design novel molecular structures.