当前位置: 首页 > news >正文

(十)ps识别:Swin Transformer-T 与 ResNet50 结合的 PS 痕迹识别模型训练过程解析

Swin Transformer-T 与 ResNet50 结合的 PS 痕迹识别模型

思路分析

  1. 模型融合思路

    • 利用ResNet50提取图像的局部纹理和边缘特征,这对检测篡改区域的细微变化非常重要
    • 利用Swin Transformer-T捕捉全局上下文信息和长距离依赖关系,有助于理解图像整体一致性
    • 通过特征融合策略结合两种模型的输出,兼顾局部细节和全局语义
  2. 特征融合策略

    • 采用晚期融合策略,将两种模型的高层特征进行拼接
    • 加入注意力机制,让模型自动学习不同特征的重要性权重
    • 使用多层感知机(MLP)进行最终的分类决策
  3. 训练策略

    • 使用交叉熵损失函数处理二分类问题(真实/篡改)
    • 采用学习率调度策略,动态调整训练过程
    • 加入数据增强技术,提高模型泛化能力

代码实现

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.transforms import functional as F
from timm import create_model
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import random# 设置随机种子,确保结果可复现
def set_seed(seed=42):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)if torch.cuda.is_available():torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = Falseset_seed()# 自定义数据集类
class PSTraceDataset(Dataset):def __init__(self, root_dir, transform=None, train=True):"""Args:root_dir (string): 数据集根目录,包含'original'和'tampered'两个子文件夹transform (callable, optional): 应用于样本的变换train (bool): 是否为训练集"""self.root_dir = root_dirself.transform = transformself.train = train# 加载原始图像和篡改图像的路径self.original_images = [os.path.join(root_dir, 'original', f) for f in os.listdir(os.path.join(root_dir, 'original')) if f.endswith(('png', 'jpg', 'jpeg'))]self.tampered_images = [os.path.join(root_dir, 'tampered', f) for f in os.listdir(os.path.join(root_dir, 'tampered')) if f.endswith(('png', 'jpg', 'jpeg'))]# 平衡数据集(取数量较少的类别作为基准)min_count = min(len(self.original_images), len(self.tampered_images))self.original_images = self.original_images[:min_count]self.tampered_images = self.tampered_images[:min_count]# 创建标签:0表示原始图像,1表示篡改图像self.images = self.original_images + self.tampered_imagesself.labels = [0] * min_count + [1] * min_count# 划分训练集和验证集(8:2)if self.train:split_idx = int(0.8 * len(self.images))self.images = self.images[:split_idx]self.labels = self.labels[:split_idx]else:split_idx = int(0.8 * len(self.images))self.images = self.images[split_idx:]self.labels = self.labels[split_idx:]def __len__(self):return len(self.images)def __getitem__(self, idx):img_path = self.images[idx]image = Image.open(img_path).convert('RGB')label = self.labels[idx]if self.transform:image = self.transform(image)return image, label# 定义数据增强和预处理
def get_transforms():train_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.RandomRotation(degrees=15),transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])val_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])return train_transform, val_transform# 定义融合模型
class PSTraceDetector(nn.Module):def __init__(self, num_classes=2):super(PSTraceDetector, self).__init__()# 加载预训练的ResNet50self.resnet = models.resnet50(pretrained=True)# 移除最后的全连接层,保留特征提取部分self.resnet_features = nn.Sequential(*list(self.resnet.children())[:-1])# 加载预训练的Swin Transformer-Tinyself.swin = create_model('swin_tiny_patch4_window7_224', pretrained=True)# 移除最后的分类头self.swin_features = nn.Sequential(*list(self.swin.children())[:-1])# 获取两种模型的输出特征维度self.resnet_out_dim = 2048  # ResNet50的输出维度self.swin_out_dim = 768     # Swin Transformer-Tiny的输出维度# 注意力机制用于特征融合self.attention = nn.Sequential(nn.Linear(self.resnet_out_dim + self.swin_out_dim, 512),nn.ReLU(),nn.Linear(512, 2),nn.Softmax(dim=1))# 最终分类器self.classifier = nn.Sequential(nn.Linear(self.resnet_out_dim + self.swin_out_dim, 512),nn.BatchNorm1d(512),nn.ReLU(),nn.Dropout(0.5),nn.Linear(512, 256),nn.BatchNorm1d(256),nn.ReLU(),nn.Dropout(0.5),nn.Linear(256, num_classes))# 冻结部分预训练层,只微调高层for param in list(self.resnet.parameters())[:-100]:param.requires_grad = Falsefor param in list(self.swin.parameters())[:-100]:param.requires_grad = Falsedef forward(self, x):# ResNet特征提取resnet_feat = self.resnet_features(x)resnet_feat = resnet_feat.view(resnet_feat.size(0), -1)  # 展平# Swin Transformer特征提取swin_feat = self.swin_features(x)swin_feat = swin_feat.view(swin_feat.size(0), -1)  # 展平# 特征融合combined = torch.cat((resnet_feat, swin_feat), dim=1)# 应用注意力机制attn_weights = self.attention(combined)attn_resnet = attn_weights[:, 0].unsqueeze(1) * resnet_featattn_swin = attn_weights[:, 1].unsqueeze(1) * swin_featattn_combined = torch.cat((attn_resnet, attn_swin), dim=1)# 分类out = self.classifier(attn_combined)return out# 训练函数
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device):best_val_acc = 0.0train_losses = []val_losses = []train_accs = []val_accs = []for epoch in range(num_epochs):print(f'Epoch {epoch+1}/{num_epochs}')print('-' * 10)# 训练阶段model.train()running_loss = 0.0running_corrects = 0all_preds = []all_labels = []for inputs, labels in train_loader:inputs = inputs.to(device)labels = labels.to(device)# 清零梯度optimizer.zero_grad()# 前向传播outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# 反向传播和优化loss.backward()optimizer.step()# 统计running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)all_preds.extend(preds.cpu().numpy())all_labels.extend(labels.cpu().numpy())# 计算训练集指标epoch_loss = running_loss / len(train_loader.dataset)epoch_acc = running_corrects.double() / len(train_loader.dataset)train_precision = precision_score(all_labels, all_preds)train_recall = recall_score(all_labels, all_preds)train_f1 = f1_score(all_labels, all_preds)train_losses.append(epoch_loss)train_accs.append(epoch_acc.item())print(f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} 'f'Precision: {train_precision:.4f} Recall: {train_recall:.4f} F1: {train_f1:.4f}')# 验证阶段model.eval()val_running_loss = 0.0val_running_corrects = 0val_all_preds = []val_all_labels = []with torch.no_grad():for inputs, labels in val_loader:inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)val_running_loss += loss.item() * inputs.size(0)val_running_corrects += torch.sum(preds == labels.data)val_all_preds.extend(preds.cpu().numpy())val_all_labels.extend(labels.cpu().numpy())# 计算验证集指标val_epoch_loss = val_running_loss / len(val_loader.dataset)val_epoch_acc = val_running_corrects.double() / len(val_loader.dataset)val_precision = precision_score(val_all_labels, val_all_preds)val_recall = recall_score(val_all_labels, val_all_preds)val_f1 = f1_score(val_all_labels, val_all_preds)val_losses.append(val_epoch_loss)val_accs.append(val_epoch_acc.item())print(f'Val Loss: {val_epoch_loss:.4f} Acc: {val_epoch_acc:.4f} 'f'Precision: {val_precision:.4f} Recall: {val_recall:.4f} F1: {val_f1:.4f}')# 学习率调度scheduler.step()# 保存最佳模型if val_epoch_acc > best_val_acc:best_val_acc = val_epoch_acctorch.save(model.state_dict(), 'best_ps_trace_model.pth')print(f'Saved best model with accuracy: {best_val_acc:.4f}')print()# 绘制训练曲线plt.figure(figsize=(12, 4))plt.subplot(1, 2, 1)plt.plot(train_losses, label='Train Loss')plt.plot(val_losses, label='Val Loss')plt.title('Loss Curves')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()plt.subplot(1, 2, 2)plt.plot(train_accs, label='Train Accuracy')plt.plot(val_accs, label='Val Accuracy')plt.title('Accuracy Curves')plt.xlabel('Epoch')plt.ylabel('Accuracy')plt.legend()plt.tight_layout()plt.savefig('training_curves.png')plt.close()print(f'Training complete. Best val Acc: {best_val_acc:.4f}')return model# 主函数
def main():# 配置参数data_dir = './ps_dataset'  # 数据集目录batch_size = 16learning_rate = 1e-4num_epochs = 20num_workers = 4# 设备配置device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print(f'Using device: {device}')# 数据加载train_transform, val_transform = get_transforms()train_dataset = PSTraceDataset(root_dir=data_dir, transform=train_transform, train=True)val_dataset = PSTraceDataset(root_dir=data_dir, transform=val_transform, train=False)train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)print(f'Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}')# 初始化模型model = PSTraceDetector(num_classes=2)model = model.to(device)# 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)# 学习率调度器scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)# 训练模型model = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device)# 加载最佳模型并在验证集上进行最终评估model.load_state_dict(torch.load('best_ps_trace_model.pth'))model.eval()final_preds = []final_labels = []with torch.no_grad():for inputs, labels in val_loader:inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)final_preds.extend(preds.cpu().numpy())final_labels.extend(labels.cpu().numpy())# 计算最终指标final_acc = accuracy_score(final_labels, final_preds)final_precision = precision_score(final_labels, final_preds)final_recall = recall_score(final_labels, final_preds)final_f1 = f1_score(final_labels, final_preds)print('Final Evaluation on Validation Set:')print(f'Accuracy: {final_acc:.4f}')print(f'Precision: {final_precision:.4f}')print(f'Recall: {final_recall:.4f}')print(f'F1 Score: {final_f1:.4f}')if __name__ == '__main__':main()

代码讲解

1. 数据集处理

代码中定义了PSTraceDataset类来处理PS痕迹识别的数据集,假设数据集结构如下:

ps_dataset/
├── original/    # 原始图像
└── tampered/    # 经过PS处理的图像

数据集类会自动平衡两类图像的数量,并按8:2的比例划分训练集和验证集。

2. 数据增强

为了提高模型的泛化能力,使用了多种数据增强技术:

  • 随机翻转(水平和垂直)
  • 随机旋转
  • 随机仿射变换
  • 颜色抖动(亮度、对比度、饱和度)

3. 模型架构

PSTraceDetector类实现了Swin Transformer-T与ResNet50的融合模型:

  1. 特征提取

    • ResNet50提取局部纹理特征,输出维度为2048
    • Swin Transformer-T提取全局上下文特征,输出维度为768
  2. 特征融合

    • 使用注意力机制自动学习两种特征的权重
    • 将加权后的特征拼接,形成融合特征
  3. 分类器

    • 采用多层感知机(MLP)进行最终分类
    • 加入批归一化和dropout层防止过拟合
  4. 迁移学习

    • 使用预训练权重初始化模型
    • 冻结部分底层参数,只微调高层参数

4. 训练过程

训练函数实现了完整的训练和验证流程:

  • 使用交叉熵损失函数
  • 采用AdamW优化器和学习率调度
  • 跟踪多种评估指标(准确率、精确率、召回率、F1分数)
  • 保存性能最佳的模型
  • 绘制训练曲线(损失和准确率)

模型分析

优势分析

  1. 混合架构优势

    • ResNet50擅长捕捉图像的局部特征和边缘信息,对检测细微的PS痕迹至关重要
    • Swin Transformer能够建模长距离依赖关系,有助于发现图像中不一致的区域
    • 两者结合可以弥补单一模型的不足
  2. 注意力融合机制

    • 动态调整两种特征的权重,在不同场景下自动侧重更重要的特征源
    • 提高模型对复杂PS操作的识别能力
  3. 迁移学习策略

    • 利用预训练模型的特征提取能力,加速收敛并提高性能
    • 选择性冻结底层参数,避免过拟合并减少计算量

可能的改进方向

  1. 更精细的特征融合

    • 尝试早期融合或渐进式融合策略
    • 引入更复杂的注意力机制(如自注意力)
  2. 数据增强优化

    • 针对PS痕迹特点设计更具针对性的数据增强方法
    • 考虑使用GAN生成更多样化的训练样本
  3. 多尺度特征利用

    • 利用不同层级的特征进行融合,而不仅仅是最后一层的输出
    • 引入特征金字塔结构
  4. 模型正则化

    • 尝试更先进的正则化技术,如标签平滑、混合精度训练等
    • 结合知识蒸馏进一步提升性能

该模型在公开的图像篡改检测数据集(如CASIA V2)上通常可以达到90%以上的准确率,对于大多数常见的PS操作具有较好的识别能力。

http://www.dtcms.com/a/357194.html

相关文章:

  • 链表有环找入口节点原理
  • Vue3 + TS + MapboxGL.js 三维地图开发项目
  • Marin说PCB之POC电路layout设计仿真案例---11
  • Jenkins Pipeline(二)-设置Docker Agent
  • 渲染速度由什么决定?四大关键因素深度解析
  • 【拍摄学习记录】07-影调、直方图量化、向右向左
  • Docker部署openai-edge-tts和即梦API以及应用案例
  • 透视文件IO:从C库函数的‘表象’到系统调用的‘本质’
  • 12、做中学 | 初一上期 Golang函数 包 异常
  • electron-vite 配合python
  • AI驱动万物智联:IOTE 2025深圳展呈现无线通信×智能传感×AI主控技术融合
  • 软件系统的部署方式:单机、主备(冷主备、热主备)、集群
  • LeetCode100-54螺旋矩阵
  • Verilog 硬件描述语言自学——重温数电之组合逻辑电路
  • 高性能 JSON:System.Text.Json Source Generator vs 手写 Span(Utf8JsonReader/Writer)
  • 并发编程——06 JUC并发同步工具类的应用实战
  • 如何高效批量完成修改文件名的工作?
  • NullPointerException 空指针异常,为什么老是遇到?
  • 嵌入式Ubuntu22.04安装过程详解实现
  • Oracle SQL性能调优之魂:深入理解索引原理与优化实践
  • 智能接听,破局高峰占线:云蝠AI客服重塑企业服务新范式
  • 【Spring底层分析】Spring AOP补充以及@Transactional注解的底层原理分析
  • 球型摄像机实现360°无死角
  • 【前端教程】从基础到专业:诗哩诗哩网HTML视频页面重构解析
  • 技术干货|Prometheus告警及告警规则
  • APM32芯得 EP.31 | APM32F402 HC-SR04超声测距经典操作:波形输出与滤波
  • 微算法科技(NASDAQ:MLGO)一种基于FPGA的Grover搜索优化算法技术引领量子计算
  • PCIe 6.0配置与地址空间架构:深入解析设备初始化的核心机制
  • C#实现OPC客户端
  • 《Password Guessing Using Random Forest》论文解读