当前位置: 首页 > news >正文

基于FashionMnist数据集的自监督学习(生成式自监督学习AE算法)

目录

一,生成式自监督学习

1.1 简介

1.2 核心思想

1.3 常见算法

1.3.1 自动编码器(Autoencoder)

1.3.2 生成对抗网络(GANs)

1.3.3 变分自编码器(VAE)

1.3.4 Transformer-based 模型(如 BERT、GPT)

1.3.5 扩散模型(Diffusion Models)

1.3.6 自回归模型(Autoregressive Models)

1.3.7 对比总结

二,代码逻辑分析

2.1 数据处理

2.2 模型定义

2.3 模型训练

2.4 主函数逻辑

三,测试结果

3.1 图片重建效果

3.2 分类测试效果

3.3 总结

四,完整代码


一,生成式自监督学习

1.1 简介

        生成式自监督学习(Generative Self-Supervised Learning)是机器学习中一种利用数据自身结构进行无监督学习的方法,其核心思想是通过生成模型构建自监督信号,让模型从无标注数据中自动学习数据的潜在规律和特征表示。这种方法无需人工标注标签,而是利用数据本身的内在关联(如上下文关系、时序依赖、结构特征等)生成训练目标,从而提升模型对数据的理解和生成能力。

1.2 核心思想

        生成式自监督学习就像让机器自己跟自己玩 “猜谜游戏”—— 不用别人告诉它 “答案是什么”,它自己从海量无标注的数据(比如网上的文字、图片)里找规律。比如,给它一段被遮住几个词的句子,它会猜缺失的词是什么;给它一张模糊的图片,它会试着还原清晰的样子;甚至还能根据 “星空下的森林” 这样的描述画出一幅画。通过不断 “猜谜”“还原”“创造”,机器就能自己学会数据里的隐藏逻辑(比如语言的顺序、图像的色彩搭配),实现无师自通,现在很多 AI 写文章、生成图片的能力,背后靠的就是这种 “自己教自己” 的本事。

1.3 常见算法

1.3.1 自动编码器(Autoencoder)

        自动编码器(Autoencoder,AE) 是一种简单且经典的无监督学习模型,核心思想是通过 “压缩 - 重建” 数据来学习数据的潜在特征。它特别适合用于 特征提取、数据压缩、去噪 等任务,尤其在图像领域应用广泛。假设你有一张 28x28 的服装图片(如 T 恤),AE 会先 “压缩” 图片成一个更小的 “特征向量”(比如 64 维),这个向量包含了图片的核心信息(如轮廓、纹理);然后再用这个向量 “重建” 出原始图片。让重建的图片尽可能接近原图,迫使模型学习到最关键的特征。通过自监督任务(重建自身),让模型自动挖掘数据的内在结构,无需人工标注标签。

              输入图像 → 编码器 → 隐向量(特征) → 解码器 → 重建图像 → 与原图比较


1.3.2 生成对抗网络(GANs)

        生成对抗网络(GANs)是一种通过 “对抗博弈” 机制实现数据生成的机器学习模型,其核心逻辑类似 “造假者” 与 “鉴伪专家” 的攻防战:生成器负责将随机噪声 “加工” 成假数据(如伪造的衣服图片),试图以假乱真;判别器则专注鉴别输入数据是真实样本(如 FashionMNIST 真实图片)还是生成器的 “伪造品”,力求火眼金睛。两者在训练中互相博弈 —— 生成器不断优化造假技术让假数据更逼真,判别器持续升级鉴别能力识破套路,最终当生成器的输出能让判别器无法区分真假(概率接近 50%)时,模型便成功学会生成以假乱真的新数据。GANs 的优势在于能创造全新样本(如 FashionMNIST 中不存在的服饰款式),常用于数据增强、图像生成、风格迁移等场景,但其训练难度高,需平衡两者能力以防 “崩溃”。与自动编码器(AE)相比,AE 侧重数据压缩与还原(类似复印机),而 GANs 专注 “无中生有” 的创造性生成(类似艺术家),更适合需要生成新样本的自监督任务,但直接用于分类时提取特征的效率可能不如 AE。


1.3.3 变分自编码器(VAE)

        变分自动编码器(VAE)就像一个会 “猜可能性” 的智能画家:它先观察大量衣服图片(比如 FashionMNIST),学会把每张图 “翻译” 成一个带 “概率标签” 的密码(比如 “这件 T 恤有 70% 可能是蓝色、圆领,30% 可能有条纹”),这个密码不是一个固定的数字,而是一个 “可能性范围”(用均值和方差表示)。然后,它能从这个可能性范围里随机 “抽样”,画出符合这些特征的新衣服(比如生成一件没见过的蓝白条纹 T 恤)。

        它的核心是让生成的新图既要 “像真的”(尽量接近原图,避免变成裤子),又要让所有密码的 “可能性范围” 均匀分布(避免只记住几种固定款式)。相比只能复制原图的 AE(像复印机),VAE 能生成多样化的新样本(比如不同花纹的鞋子);相比靠对抗博弈生成的 GANs(像造假者和警察打架),它更稳定,虽然生成的图可能没那么逼真,但胜在 “可控”(比如能按 “圆领”“长袖” 等特征生成)。在自监督分类里,它提取的密码自带 “衣服类型” 的隐藏信息(比如鞋子和包包的密码差异很大),可以直接用来训练分类器,是一种简单又实用的 “数据翻译官”。


1.3.4 Transformer-based 模型(如 BERT、GPT)

        Transformer-based 模型是一种让 AI 拥有 “全局思维” 的智能架构,核心是通过 “自注意力机制” 让模型处理数据时能像人类一样 “抓重点、理关系”。比如读句子 “小狗叼着骨头跑向主人,因为它饿了”,模型会让 “它” 主动 “看向” 前面的 “小狗”,不管句子多长都能准确建立关联(解决长距离依赖难题)。它的结构分为 “分析员” Encoder 和 “创造者” Decoder:前者负责拆解输入数据(如文本、图像块),用自注意力给每个元素标上 “重点标签”(比如 “骨头” 是 “叼” 的宾语);后者则根据这些标签生成内容(如翻译后的中文句子、对应文字描述的图片),生成时还会反复 “回看” 分析结果,确保逻辑连贯(比如 “跑向主人” 要对应正确的动作方向)。

这种模型的厉害之处在于:

并行处理效率高:不像传统模型只能逐字逐句处理,它能同时分析所有元素的关系(比如同时判断 “小狗”“骨头”“主人” 的关联),处理长文本或大规模数据更快;

跨领域通用:在文本领域能让 ChatGPT 流畅聊天、帮 GPT 写文章,在图像领域能让 ViT 分类图片,甚至在多模态场景(如 Stable Diffusion)中,能把 “夕阳下的沙滩” 文字描述 “翻译” 成逼真图像,靠的就是自注意力把文字和图像块 “配对” 的能力;

长距离记忆强:哪怕前后文隔得很远(如开头的 “小猫” 和结尾的 “它”),也能精准 “牵线”,避免 “失忆”。

Transformer 就像给 AI 大脑装了一个 “全局导航系统”,让它看数据时能快速锁定重点、理清逻辑关系,无论是写文章、翻译、生成图片还是理解复杂内容,都能驾轻就熟,是如今 AI 领域的 “万能底座”,撑起了从聊天机器人到生成式 AI 的核心能力。


1.3.5 扩散模型(Diffusion Models)

        扩散模型(Diffusion Models)是一种通过 “渐进去噪” 实现高质量数据生成的 AI 技术,核心原理类似 “从模糊到清晰还原画面” 的逆向修复过程。它先通过扩散过程(如往清水中滴墨水)将清晰数据(如图像)逐步转化为随机噪声(从隐约可见轮廓到完全杂乱像素),再通过逆扩散过程(反向去噪)让神经网络学会从噪声中逐层还原真实数据 —— 就像剥洋葱一样,每一步都用名为 U-Net 的 “对称漏斗状神经网络” 分析当前噪声图,预测并去除最关键的噪声颗粒,最终 “雕刻” 出逼真的图像、视频等内容。训练时,模型通过大量 “原图 + 不同噪声程度版本” 的数据对,学习 “噪声变化规律”(如 “猫的眼睛区域该去掉哪种噪声”),确保每一步去噪都符合真实数据的分布逻辑。其生成内容细节丰富(如 Stable Diffusion 能画出毛发纹理),稳定性远超传统对抗模型(如 GAN),但需数十步去噪计算,速度较慢。简单说,扩散模型是 AI 界的 “精细雕刻师”,通过数学上的渐进去噪魔法,从混沌噪声中还原或创造出高质量的视觉内容,成为当前图像生成、修复等领域的标杆技术。


1.3.6 自回归模型(Autoregressive Models)

        自回归模型(Autoregressive Models)是一种让 AI 实现 “按顺序创作” 的核心技术,其本质是让模型像人类说话一样,根据已生成的内容逐步预测下一个元素—— 比如写 “今天天气” 时,会基于 “今天” 和 “天气” 的语境,推测下一个词可能是 “晴朗”“炎热” 或 “多变”。它的工作逻辑类似 “接龙游戏”,每一步生成都依赖于前面所有结果,通过数学上的概率计算(如极大似然估计)最大化 “下一词符合语境” 的可能性,例如先算 “猫” 出现的概率,再算 “追” 在 “猫” 之后的概率,以此类推串联成完整内容。现代自回归模型(如 GPT 系列)采用自注意力机制升级 “记忆系统”,让模型在生成时能 “全局回看” 所有历史内容(如开头的 “小猫” 和结尾的 “它” 直接关联),解决了传统循环神经网络(RNN)长距离依赖差的问题,使生成的文本、语音等序列数据逻辑更连贯。其应用覆盖文本生成(写文章、代码、对话)、语音合成等领域,特点是输出自然流畅,但因需逐个元素生成(如逐词写句子),速度较慢且无法并行处理。简单说,自回归模型是 AI 的‘顺序创作引擎’,通过‘步步依赖、层层生成’的方式,让机器学会像人类一样‘先说前半句,再顺理成章接后半句’,成为 ChatGPT 等生成式 AI 的底层技术支撑。


1.3.7 对比总结

方法核心思想生成特点优缺点典型应用场景
自动编码器(AE)压缩数据再还原,学关键特征(类似 “压缩包解压”)- 还原输入,适合提取核心特征
- 生成新样本能力弱(只能模仿,难创新)
- 优点:简单高效,适合数据降维、去噪
- 缺点:生成能力差,新内容质量低
图像压缩、医学图像去噪、异常检测
生成对抗网络(GANs)两个模型对抗:一个造假,一个打假,越打越真(类似 “猫鼠游戏”)- 生成样本逼真(如人脸、风景)
- 容易 “偷懒”(只生成几种类型,缺乏多样性)
- 优点:图像效果逼真
- 缺点:训练难(易崩溃),结果不可控
虚拟人物生成、艺术创作、风格迁移
变分自编码器(VAE)给压缩的特征加 “概率滤镜”,能随机生成新样本(类似 “特征抽奖”)- 生成多样新样本(如不同风格的猫)
- 样本可能模糊(细节不清晰)
- 优点:能创造新样本,适合扩展数据
- 缺点:画质 / 音质一般,不够清晰
数据增强(生成同类变体)、药物分子设计
Transformer(如 GPT)按顺序预测下一个词 / 元素,学长距离逻辑(类似 “接龙游戏”)- 文本逻辑连贯(能写文章、对话)
- 生成速度慢(一个字一个字蹦)
- 优点:擅长长文本,会 “理解” 上下文
- 缺点:耗算力(需要超大规模训练)
写文章、聊天机器人、代码生成、翻译
扩散模型从 “噪声” 中一点点还原清晰数据(类似 “擦除马赛克”)- 生成质量极高(细节拉满,如复杂场景画图)
- 计算慢(需要几十步 “擦除”)
- 优点:图像 / 视频生成天花板
- 缺点:耗时耗显卡(训练要几周,生成要几十步)
高质量图像生成(DALL・E、MidJourney)、视频生成
自回归模型按顺序生成(如先写第一个词,再根据第一个词写第二个词)- 严格按顺序生成(适合文本、语音)
- 长序列效率低(如生成很长的句子会变慢)
- 优点:适合 “按步骤” 生成(如逐字、逐像素)
- 缺点:并行能力差(不能同时生成多个部分)
文本生成、语音合成、图像逐块生成

二,代码逻辑分析

2.1 数据处理

# 设置随机种子
torch.manual_seed(42)
np.random.seed(42)# 数据预处理 - 调整为32×32输入
transform = transforms.Compose([transforms.Resize((32, 32)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 加载FashionMNIST数据集
train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform
)# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

        将 28×28 的 FashionMNIST 图像调整为 32×32,适配后续网络结构,随机水平翻转图像,增强数据多样性,将像素值归一化到 [-1, 1] 区间(均值 0.5,标准差 0.5),batch_size=256,训练集shuffle=True确保数据随机打乱,测试集不打乱以保持顺序。


2.2 模型定义

# 定义残差块
class ResidualBlock(nn.Module):...# 定义32×32输入的自动编码器模型
class AdvancedAutoencoder(nn.Module):def __init__(self):# 编码器(适配32×32输入)self.encoder = nn.Sequential(...)# 解码器(适配32×32输出)self.decoder = nn.Sequential(...)def forward(self, x):x = self.encoder(x)x = self.decoder(x)return xdef extract_features(self, x):return self.encoder(x)# 定义线性分类器
class LinearClassifier(nn.Module):def __init__(self, input_dim=512 * 2 * 2, num_classes=10):super().__init__()self.linear = nn.Linear(input_dim, num_classes)def forward(self, x):x = x.view(x.size(0), -1)return self.linear(x)

        ResidualBlock残差连接(Shortcut Connection),缓解深层网络的梯度消失问题,输入输出通道数或尺寸不一致时,通过 1×1 卷积调整维度。

        AdvancedAutoencoder:编码器:通过 5 次下采样(卷积 + 残差块)将 32×32 图像压缩为 2×2×512 的特征表示。解码器:通过 5 次上采样(反卷积 + 残差块)将特征重构为 32×32 图像。extract_features方法:仅使用编码器提取特征,用于后续分类任务。

        LinearClassifier:输入维度:512×2×2=2048(编码器输出展平后)。单层线性映射,直接连接到 10 个分类类别,用于评估特征质量。


2.3 模型训练

# 训练AE模型
def train_ae(model, train_loader, criterion, optimizer, epochs, device):...# 训练线性分类器
def train_linear_classifier(ae_model, train_loader, test_loader, num_classes, device):# 冻结AE参数for param in ae_model.parameters():param.requires_grad = False# 仅训练线性分类器feature_extractor = ae_model.extract_featuresclassifier = LinearClassifier(input_dim, num_classes).to(device)...

train_ae():训练自动编码器,目标是最小化重构误差(MSE 损失)。使用 Adam 优化器,学习率 1e-3,训练 50 个 epochs。

train_linear_classifier():冻结 AE 的参数,仅训练线性分类器,使用预训练的 AE 提取特征,输入到线性层进行分类,评估特征的质量(线性分类准确率反映特征的判别能力)。


2.4 主函数逻辑

def main():# 1. 初始化设备device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 2. 训练自动编码器ae_model = AdvancedAutoencoder().to(device)ae_model = train_ae(ae_model, train_loader, ..., epochs=50)# 3. 可视化重构效果visualize_reconstructions(ae_model, test_loader, device)# 4. 训练线性分类器(评估特征质量)classifier, test_acc = train_linear_classifier(ae_model, ...)# 5. 保存模型torch.save(ae_model.state_dict(), 'fashion_mnist_32_ae.pth')torch.save(classifier.state_dict(), 'fashion_mnist_32_classifier.pth')

自监督学习阶段:训练 AE 学习图像的特征表示(无标签数据),通过重构质量评估 AE 性能。

线性评估阶段:使用冻结的 AE 提取特征,训练线性分类器,分类准确率反映特征的质量(是否包含类别判别信息)。


三,测试结果

3.1 图片重建效果

经过50个epcho训练后,loss大概能到0.00x的水平,说明损失也是非常小了

可以看到这里推图片的还原程度还是很高的

3.2 分类测试效果

最终线性分类的准确度大概能到91%

3.3 总结

        总的来看,AE方法的训练成本比较低而且准确度较高,在本次实验中,发现调参时epcho不能过大,不然最终的classifier acc基本会维持在百分之90以下。AE对Fashionminst数据集的处理也比较合适,如果想达到更好的准确度,就更需要细细调参,或者是改用效率更高的结构。


四,完整代码

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np# 设置随机种子以确保结果可复现
torch.manual_seed(42)
np.random.seed(42)# 数据预处理 - 将输入图像调整为32×32尺寸
transform = transforms.Compose([transforms.Resize((32, 32)),  # 将原始28×28图像调整为32×32,便于后续卷积操作transforms.RandomHorizontalFlip(),  # 随机水平翻转图像,增加数据多样性transforms.ToTensor(),  # 将图像转换为Tensortransforms.Normalize((0.5,), (0.5,))  # 归一化处理,将像素值缩放到[-1, 1]范围
])# 加载FashionMNIST数据集(服装分类数据集,包含10个类别)
train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform
)# 创建数据加载器,用于批量加载数据
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)  # 训练集打乱顺序
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)  # 测试集不打乱顺序# 定义残差块,用于构建深度网络,解决梯度消失问题
class ResidualBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1):super(ResidualBlock, self).__init__()# 第一个卷积层:3×3卷积,保持特征图尺寸或减半(由stride控制)self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)  # 批量归一化,加速训练并提高稳定性self.relu = nn.ReLU(inplace=True)  # ReLU激活函数,引入非线性# 第二个卷积层:3×3卷积,保持特征图尺寸不变self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)# 捷径连接:当输入输出通道数或尺寸不一致时,使用1×1卷积调整self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels))def forward(self, x):# 前向传播:第一个卷积块out = self.relu(self.bn1(self.conv1(x)))# 第二个卷积块out = self.bn2(self.conv2(out))# 残差连接:将输入直接加到输出上out += self.shortcut(x)# 最后通过ReLU激活out = self.relu(out)return out# 定义32×32输入的自动编码器模型(自监督学习)
class AdvancedAutoencoder(nn.Module):def __init__(self):super(AdvancedAutoencoder, self).__init__()# 编码器:将输入图像压缩为低维特征表示self.encoder = nn.Sequential(# 第一层:保持尺寸32×32,增加通道数到32nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),  # 输出: 32×32×32nn.BatchNorm2d(32),nn.ReLU(),ResidualBlock(32, 32),  # 残差块,保持通道数不变# 第一次下采样:尺寸减半为16×16,通道数翻倍到64nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # 输出: 16×16×64nn.BatchNorm2d(64),nn.ReLU(),ResidualBlock(64, 64),# 第二次下采样:尺寸减半为8×8,通道数翻倍到128nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # 输出: 8×8×128nn.BatchNorm2d(128),nn.ReLU(),ResidualBlock(128, 128),# 第三次下采样:尺寸减半为4×4,通道数翻倍到256nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),  # 输出: 4×4×256nn.BatchNorm2d(256),nn.ReLU(),ResidualBlock(256, 256),# 第四次下采样:尺寸减半为2×2,通道数翻倍到512nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),  # 输出: 2×2×512nn.BatchNorm2d(512),nn.ReLU())# 解码器:将低维特征重构为原始图像self.decoder = nn.Sequential(# 第一次上采样:尺寸翻倍为4×4,通道数减半到256nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),  # 输出: 4×4×256nn.BatchNorm2d(256),nn.ReLU(),ResidualBlock(256, 256),# 第二次上采样:尺寸翻倍为8×8,通道数减半到128nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),  # 输出: 8×8×128nn.BatchNorm2d(128),nn.ReLU(),ResidualBlock(128, 128),# 第三次上采样:尺寸翻倍为16×16,通道数减半到64nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),  # 输出: 16×16×64nn.BatchNorm2d(64),nn.ReLU(),ResidualBlock(64, 64),# 第四次上采样:尺寸翻倍为32×32,通道数减半到32nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),  # 输出: 32×32×32nn.BatchNorm2d(32),nn.ReLU(),ResidualBlock(32, 32),# 最后一层:保持尺寸32×32,将通道数压缩到1(原始图像通道数)nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1),  # 输出: 32×32×1nn.Tanh()  # 将输出值限制在[-1, 1]范围内,与输入归一化范围一致)def forward(self, x):# 完整的前向传播:编码 -> 解码x = self.encoder(x)x = self.decoder(x)return x# 提取特征的方法:仅使用编码器部分def extract_features(self, x):return self.encoder(x)# 定义线性分类器:用于评估自动编码器提取的特征质量
class LinearClassifier(nn.Module):def __init__(self, input_dim=512 * 2 * 2, num_classes=10):super(LinearClassifier, self).__init__()# 线性层:将编码器输出的特征向量映射到分类类别self.linear = nn.Linear(input_dim, num_classes)def forward(self, x):# 将特征张量展平为一维向量x = x.view(x.size(0), -1)  # 从[batch_size, 512, 2, 2]变为[batch_size, 2048]return self.linear(x)# 训练自动编码器的函数
def train_ae(model, train_loader, criterion, optimizer, epochs, device):model.train()  # 设置为训练模式for epoch in range(epochs):running_loss = 0.0for data, _ in train_loader:  # 忽略标签(自监督学习)data = data.to(device)  # 将数据移至GPU(如果可用)optimizer.zero_grad()  # 清零梯度outputs = model(data)  # 前向传播loss = criterion(outputs, data)  # 计算重构损失loss.backward()  # 反向传播optimizer.step()  # 更新参数running_loss += loss.item()# 打印每个epoch的平均损失avg_loss = running_loss / len(train_loader)print(f'AE Epoch [{epoch + 1}/{epochs}], Loss: {avg_loss:.4f}')return model# 训练线性分类器的函数(使用预训练的自动编码器提取特征)
def train_linear_classifier(ae_model, train_loader, test_loader, num_classes, device):# 冻结自动编码器的所有参数,仅训练线性分类器for param in ae_model.parameters():param.requires_grad = Falsefeature_extractor = ae_model.extract_features  # 获取特征提取器input_dim = 512 * 2 * 2  # 编码器输出的特征维度classifier = LinearClassifier(input_dim, num_classes).to(device)  # 创建线性分类器criterion = nn.CrossEntropyLoss()  # 分类任务使用交叉熵损失optimizer = optim.Adam(classifier.parameters(), lr=1e-3)  # 仅优化分类器参数classifier.train()  # 设置为训练模式epochs = 30for epoch in range(epochs):running_loss = 0.0correct = 0total = 0for data, labels in train_loader:data, labels = data.to(device), labels.to(device)  # 数据移至GPU# 使用预训练的AE提取特征(不需要梯度计算)with torch.no_grad():features = feature_extractor(data)outputs = classifier(features)  # 通过分类器预测loss = criterion(outputs, labels)  # 计算分类损失optimizer.zero_grad()  # 清零梯度loss.backward()  # 反向传播optimizer.step()  # 更新参数# 计算准确率_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()running_loss += loss.item()# 打印每个epoch的损失和训练准确率train_acc = 100. * correct / totalavg_loss = running_loss / len(train_loader)print(f'Classifier Epoch [{epoch + 1}/{epochs}], Loss: {avg_loss:.4f}, Acc: {train_acc:.2f}%')# 在测试集上评估分类器性能test_acc = evaluate_classifier(classifier, feature_extractor, test_loader, device)print(f'Test Accuracy: {test_acc:.2f}%')return classifier, test_acc# 评估分类器性能的函数
def evaluate_classifier(classifier, feature_extractor, test_loader, device):classifier.eval()  # 设置为评估模式correct = 0total = 0# 不计算梯度,加速推理with torch.no_grad():for data, labels in test_loader:data, labels = data.to(device), labels.to(device)features = feature_extractor(data)  # 提取特征outputs = classifier(features)  # 分类预测# 计算准确率_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()return 100. * correct / total# 可视化原始图像和重构图像的函数
def visualize_reconstructions(model, test_loader, device, num_samples=5):model.eval()  # 设置为评估模式# 不计算梯度,加速推理with torch.no_grad():data, _ = next(iter(test_loader))  # 获取一批测试数据data = data[:num_samples].to(device)  # 取前几个样本reconstructions = model(data)  # 生成重构图像# 转换为CPU张量并调整维度,从[B,1,H,W]转为[B,H,W]data = data.cpu().numpy().squeeze(1)reconstructions = reconstructions.cpu().numpy().squeeze(1)# 创建图像对比图fig, axes = plt.subplots(2, num_samples, figsize=(15, 8))for i in range(num_samples):# 显示原始图像axes[0, i].imshow(data[i], cmap='gray')axes[0, i].set_title('Original (32×32)')axes[0, i].axis('off')# 显示重构图像axes[1, i].imshow(reconstructions[i], cmap='gray')axes[1, i].set_title('Reconstructed (32×32)')axes[1, i].axis('off')plt.tight_layout()plt.show()# 主函数:程序入口点
def main():device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f"Using device: {device}")# 创建并训练自动编码器ae_model = AdvancedAutoencoder().to(device)criterion = nn.MSELoss()  # 使用均方误差损失函数optimizer = optim.Adam(ae_model.parameters(), lr=1e-3)  # Adam优化器print("Training Advanced Autoencoder...")ae_model = train_ae(ae_model, train_loader, criterion, optimizer, epochs=50, device=device)# 可视化重构效果,检查AE训练质量visualize_reconstructions(ae_model, test_loader, device)# 训练并评估线性分类器print("\nTraining Linear Classifier...")num_classes = len(train_dataset.classes)  # 获取类别数量(10类)classifier, test_acc = train_linear_classifier(ae_model, train_loader, test_loader, num_classes, device)# 保存模型torch.save(ae_model.state_dict(), 'fashion_mnist_32_ae.pth')torch.save(classifier.state_dict(), 'fashion_mnist_32_classifier.pth')print(f"Models saved: fashion_mnist_32_ae.pth, fashion_mnist_32_classifier.pth")if __name__ == "__main__":main()  # 执行主函数

相关文章:

  • C++基础算法————贪心
  • 那些常用的运维工具
  • b. 组合数
  • C++:参数传递方法(Parameter Passing Methods)
  • 用户认证的魔法配方:从模型设计到密码安全的奇幻之旅
  • HackMyVM-First
  • Linux【工具 04】Java等常用工具的多版本管理工具SDKMAN安装使用实例
  • SpringBoot整合MyBatis完整实践指南
  • Android任务栈管理策略总结
  • # CppCon 2014 学习: Quick game development with C++11/C++14
  • 构建多模型协同的Ollama智能对话系统
  • WEB3——为什么做NFT铸造平台?
  • 2025.5.29 学习日记 docker概念以及基本指令
  • 算法:滑动窗口
  • MySQL项目实战演练:搭建用户管理系统的完整数据库结构【MySQL系列】
  • 如何实现一个请求库?【面试场景题】
  • 牛客小白月赛117
  • 实施ESOP投入收益研究报告
  • 趋势直线指标
  • C语言学习——C语言强制类型转换2023.12.20
  • 在什么网站可以接活做/新冠病毒最新消息
  • 黄色视频做爰网站安全/aso安卓优化
  • 眉山做网站/百度seo和谷歌seo有什么区别
  • 用javascript做的网站/深圳市网络营销推广服务公司
  • 对网站建设培训的建议/排名优化关键词公司
  • 拟定一个物流网站的建设方案/画质优化app下载