Reproducing the OpenSTL PredRNNv2 Model and Training on a Custom Dataset
Overview
This article walks through reproducing the PredRNNv2 model from OpenSTL and training and predicting with a custom dataset in NPY format. Starting from environment setup, it covers data preprocessing, model construction, the training loop, and inference, with the end goal of feeding in a sequence of consecutive 500×500 images and producing the same number of predicted frames.
Contents
- Environment Setup and Dependency Installation
- Dataset Preparation and Preprocessing
- PredRNNv2 Model Principles and Architecture
- Data Loader Implementation
- Model Training Pipeline
- Prediction and Result Visualization
- Model Evaluation and Optimization
- Full Implementation
- Common Issues and Solutions
- Summary and Outlook
1. Environment Setup and Dependency Installation
First, create a suitable Python environment and install all the required packages.
# Create a conda environment
conda create -n openstl python=3.8
conda activate openstl

# Install PyTorch (choose the build matching your CUDA version)
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116

# Install the remaining dependencies
pip install numpy==1.21.6
pip install opencv-python==4.7.0.72
pip install matplotlib==3.5.3
pip install tensorboard==2.11.2
pip install scikit-learn==1.0.2
pip install tqdm==4.64.1
pip install nni==2.8
pip install timm==0.6.12
pip install einops==0.6.0
Next, clone the OpenSTL repository and install it in editable mode:
git clone https://github.com/chengtan9907/OpenSTL.git
cd OpenSTL
git checkout OpenSTL-Lightning
pip install -e .
2. Dataset Preparation and Preprocessing
Our dataset consists of NPY files, each holding a 500×500 image, with consecutive files forming a continuous time series. The expected directory layout is:
dataset/
├── train/
│ ├── sequence_001/
│ │ ├── frame_001.npy
│ │ ├── frame_002.npy
│ │ └── ...
│ ├── sequence_002/
│ └── ...
├── valid/
└── test/
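To exercise the pipeline before real data is available, a throwaway script like the following can create a miniature dataset in exactly this layout. The sequence counts, frame counts, and the reduced 64×64 frame size are arbitrary choices for a quick smoke test, not values from the article.

```python
import os
import numpy as np

def make_dummy_dataset(root, num_sequences=2, frames_per_seq=25, size=(64, 64)):
    """Create a miniature random dataset matching the layout above."""
    for split in ('train', 'valid', 'test'):
        for s in range(1, num_sequences + 1):
            seq_dir = os.path.join(root, split, f'sequence_{s:03d}')
            os.makedirs(seq_dir, exist_ok=True)
            for f in range(1, frames_per_seq + 1):
                frame = np.random.rand(*size).astype(np.float32)
                np.save(os.path.join(seq_dir, f'frame_{f:03d}.npy'), frame)

# Small frames keep the smoke test fast; swap in (500, 500) for realistic sizes
make_dummy_dataset('./dataset', size=(64, 64))
```

With 10 input and 10 predicted frames, each sequence needs at least 20 frames, which is why 25 frames per sequence is used here.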
2.1 Implementing the Data Preprocessing Class
We need a dataset class that converts the NPY files into model-ready tensors:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
import cv2


class NPYDataset(Dataset):
    def __init__(self, data_root, mode='train', input_frames=10, output_frames=10,
                 future_frames=10, transform=None, preprocess=True):
        """Initialize the NPY dataset.

        Args:
            data_root: root directory of the dataset
            mode: one of 'train', 'valid', 'test'
            input_frames: number of input frames
            output_frames: number of output frames
            future_frames: number of future frames to predict
            transform: data transform function
            preprocess: whether to normalize the data
        """
        self.data_root = os.path.join(data_root, mode)
        self.mode = mode
        self.input_frames = input_frames
        self.output_frames = output_frames
        self.future_frames = future_frames
        self.transform = transform
        self.preprocess = preprocess

        # Collect all sequences with enough frames
        self.sequences = []
        for seq_name in os.listdir(self.data_root):
            seq_path = os.path.join(self.data_root, seq_name)
            if os.path.isdir(seq_path):
                frames = sorted([f for f in os.listdir(seq_path) if f.endswith('.npy')])
                if len(frames) >= input_frames + future_frames:
                    self.sequences.append((seq_path, frames))

        # Data normalizer
        self.scaler = None
        if preprocess:
            self._init_scaler()

    def _init_scaler(self):
        """Initialize the data normalizer."""
        print(f"Initializing scaler for {self.mode} mode...")
        all_data = []
        for seq_path, frames in self.sequences:
            # Use up to the first 100 frames per sequence to estimate statistics
            for frame_name in frames[:min(100, len(frames))]:
                frame_path = os.path.join(seq_path, frame_name)
                data = np.load(frame_path)
                all_data.append(data.flatten())
        all_data = np.concatenate(all_data).reshape(-1, 1)
        self.scaler = StandardScaler()
        self.scaler.fit(all_data)
        print("Scaler initialized.")

    def _preprocess_data(self, data):
        """Normalize the data."""
        if self.preprocess and self.scaler is not None:
            original_shape = data.shape
            data = data.flatten().reshape(-1, 1)
            data = self.scaler.transform(data)
            data = data.reshape(original_shape)
        return data

    def _postprocess_data(self, data):
        """Undo the normalization."""
        if self.preprocess and self.scaler is not None:
            original_shape = data.shape
            data = data.flatten().reshape(-1, 1)
            data = self.scaler.inverse_transform(data)
            data = data.reshape(original_shape)
        return data

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq_path, frames = self.sequences[idx]

        # Randomly pick a start frame (training only)
        total_frames = len(frames)
        max_start = total_frames - self.input_frames - self.future_frames
        start_idx = np.random.randint(0, max_start + 1) if self.mode == 'train' else 0

        # Load input frames
        input_frames = []
        for i in range(start_idx, start_idx + self.input_frames):
            frame_path = os.path.join(seq_path, frames[i])
            frame_data = np.load(frame_path)
            frame_data = self._preprocess_data(frame_data)
            input_frames.append(frame_data)

        # Load target frames
        target_frames = []
        for i in range(start_idx + self.input_frames,
                       start_idx + self.input_frames + self.future_frames):
            frame_path = os.path.join(seq_path, frames[i])
            frame_data = np.load(frame_path)
            frame_data = self._preprocess_data(frame_data)
            target_frames.append(frame_data)

        # Stack into arrays
        input_seq = np.stack(input_frames, axis=0)
        target_seq = np.stack(target_frames, axis=0)

        # Add a channel dimension
        input_seq = np.expand_dims(input_seq, axis=1)    # [T, 1, H, W]
        target_seq = np.expand_dims(target_seq, axis=1)  # [T, 1, H, W]

        # Convert to tensors
        input_seq = torch.FloatTensor(input_seq)
        target_seq = torch.FloatTensor(target_seq)

        if self.transform:
            input_seq = self.transform(input_seq)
            target_seq = self.transform(target_seq)

        return input_seq, target_seq


# Data augmentation transforms
class RandomRotate:
    def __init__(self, angles=[0, 90, 180, 270]):
        self.angles = angles

    def __call__(self, x):
        angle = np.random.choice(self.angles)
        if angle == 0:
            return x
        # Map the sampled angle to the matching OpenCV rotation code
        rot_code = {90: cv2.ROTATE_90_CLOCKWISE,
                    180: cv2.ROTATE_180,
                    270: cv2.ROTATE_90_COUNTERCLOCKWISE}[angle]
        # Rotate every frame in the sequence
        rotated = []
        for i in range(x.shape[0]):
            frame = x[i].numpy()
            # For 3D frames, rotate each channel separately
            if len(frame.shape) == 3:
                frame_rotated = np.stack([cv2.rotate(frame[c], rot_code)
                                          for c in range(frame.shape[0])], axis=0)
            else:
                frame_rotated = cv2.rotate(frame, rot_code)
            rotated.append(frame_rotated)
        return torch.FloatTensor(np.stack(rotated, axis=0))


class RandomFlip:
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, x):
        if np.random.random() < self.p:
            # Horizontal flip
            return x.flip(-1)
        return x
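The StandardScaler round trip inside `_preprocess_data`/`_postprocess_data` is just a global z-score and its inverse. A numpy-only sketch (with made-up value ranges) shows what those two methods do without the sklearn dependency:

```python
import numpy as np

# Hypothetical raw sensor values, shape [frames, pixels]
frames = np.random.rand(10, 100 * 100) * 300.0
mean, std = frames.mean(), frames.std()

normalized = (frames - mean) / std   # what _preprocess_data applies
restored = normalized * std + mean   # what _postprocess_data applies

# By construction, the normalized data has mean ~0 and std ~1,
# and the round trip recovers the original values.
assert np.allclose(restored, frames)
```

Remember to apply `_postprocess_data` to model outputs before saving or visualizing them, otherwise the predicted NPY files stay in normalized units.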
3. PredRNNv2 Model Principles and Architecture
PredRNNv2 is an improved recurrent neural network designed for video prediction. It introduces a spatiotemporal memory (STM) cell to better capture spatiotemporal dynamics.
3.1 Core Components
import torch
import torch.nn as nn
class SpatioTemporalLSTMCell(nn.Module):
    def __init__(self, in_channel, num_hidden, height, width, filter_size, stride, layer_norm):
        super(SpatioTemporalLSTMCell, self).__init__()
        self.num_hidden = num_hidden
        self.padding = filter_size // 2
        self._forget_bias = 1.0

        # Convolutional projections for input, hidden state, memory, and output gate
        self.conv_x = nn.Sequential(
            nn.Conv2d(in_channel, num_hidden * 7, kernel_size=filter_size,
                      stride=stride, padding=self.padding, bias=False),
            nn.LayerNorm([num_hidden * 7, height, width]))
        self.conv_h = nn.Sequential(
            nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size,
                      stride=stride, padding=self.padding, bias=False),
            nn.LayerNorm([num_hidden * 4, height, width]))
        self.conv_m = nn.Sequential(
            nn.Conv2d(num_hidden, num_hidden * 3, kernel_size=filter_size,
                      stride=stride, padding=self.padding, bias=False),
            nn.LayerNorm([num_hidden * 3, height, width]))
        self.conv_o = nn.Sequential(
            nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=filter_size,
                      stride=stride, padding=self.padding, bias=False),
            nn.LayerNorm([num_hidden, height, width]))
        self.conv_last = nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=1,
                                   stride=1, padding=0, bias=False)

    def forward(self, x_t, h_t, c_t, m_t):
        # Compute gating signals
        x_concat = self.conv_x(x_t)
        h_concat = self.conv_h(h_t)
        m_concat = self.conv_m(m_t)
        i_x, f_x, g_x, i_x_prime, f_x_prime, g_x_prime, o_x = torch.split(x_concat, self.num_hidden, dim=1)
        i_h, f_h, g_h, o_h = torch.split(h_concat, self.num_hidden, dim=1)
        i_m, f_m, g_m = torch.split(m_concat, self.num_hidden, dim=1)

        # Standard temporal memory C
        i_t = torch.sigmoid(i_x + i_h)
        f_t = torch.sigmoid(f_x + f_h + self._forget_bias)
        g_t = torch.tanh(g_x + g_h)
        c_new = f_t * c_t + i_t * g_t

        # Spatiotemporal memory M
        i_t_prime = torch.sigmoid(i_x_prime + i_m)
        f_t_prime = torch.sigmoid(f_x_prime + f_m + self._forget_bias)
        g_t_prime = torch.tanh(g_x_prime + g_m)
        m_new = f_t_prime * m_t + i_t_prime * g_t_prime

        # The output gate fuses both memories
        mem = torch.cat((c_new, m_new), 1)
        o_t = torch.sigmoid(o_x + o_h + self.conv_o(mem))
        h_new = o_t * torch.tanh(self.conv_last(mem))
        return h_new, c_new, m_new


class PredRNNv2(nn.Module):
    def __init__(self, configs):
        super(PredRNNv2, self).__init__()
        self.configs = configs
        self.frame_channel = configs.patch_size * configs.patch_size * configs.img_channel
        self.num_layers = len(configs.num_hidden)
        self.num_hidden = configs.num_hidden
        self.device = configs.device

        # Build the stacked ST-LSTM cells
        cell_list = []
        height = configs.img_height // configs.patch_size
        width = configs.img_width // configs.patch_size
        for i in range(self.num_layers):
            in_channel = self.frame_channel if i == 0 else self.num_hidden[i-1]
            cell_list.append(SpatioTemporalLSTMCell(
                in_channel, self.num_hidden[i], height, width,
                configs.filter_size, configs.stride, configs.layer_norm))
        self.cell_list = nn.ModuleList(cell_list)

        # Output projection back to frame space
        self.conv_last = nn.Conv2d(self.num_hidden[self.num_layers-1], self.frame_channel,
                                   kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, frames_tensor, mask_true=None):
        # frames_tensor: [batch, length, channel, height, width]
        batch = frames_tensor.shape[0]
        height = frames_tensor.shape[3]
        width = frames_tensor.shape[4]

        # Initialize hidden and cell states for every layer
        next_frames = []
        h_t = []
        c_t = []
        for i in range(self.num_layers):
            zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(self.device)
            h_t.append(zeros)
            c_t.append(zeros)

        # A single spatiotemporal memory is threaded through the layer stack
        memory = torch.zeros([batch, self.num_hidden[0], height, width]).to(self.device)

        x_gen = None
        # total_length already includes the input frames
        for t in range(self.configs.total_length - 1):
            if self.configs.reverse_scheduled_sampling == 1:
                # Reverse scheduled sampling: sample from step 1 onwards
                if t == 0:
                    net = frames_tensor[:, t]
                elif mask_true is not None and t < frames_tensor.shape[1]:
                    mask = mask_true[:, t - 1]
                    net = mask * frames_tensor[:, t] + (1 - mask) * x_gen
                else:
                    net = x_gen
            else:
                # Standard training: ground truth inside the input window,
                # sampled frames afterwards
                if t < self.configs.input_length:
                    net = frames_tensor[:, t]
                elif mask_true is not None and t < frames_tensor.shape[1]:
                    mask = mask_true[:, t - self.configs.input_length]
                    net = mask * frames_tensor[:, t] + (1 - mask) * x_gen
                else:
                    # Inference (or no sampling mask): feed back the previous prediction
                    net = x_gen

            # First layer consumes the frame
            h_t[0], c_t[0], memory = self.cell_list[0](net, h_t[0], c_t[0], memory)
            # Higher layers consume the hidden state of the layer below
            for i in range(1, self.num_layers):
                h_t[i], c_t[i], memory = self.cell_list[i](h_t[i-1], h_t[i], c_t[i], memory)

            # Generate the next frame
            x_gen = self.conv_last(h_t[self.num_layers-1])
            next_frames.append(x_gen)

        # [length, batch, channel, height, width] -> [batch, length, channel, height, width]
        next_frames = torch.stack(next_frames, dim=0).permute(1, 0, 2, 3, 4)
        return next_frames
4. Data Loader Implementation
Next, we implement the function that builds the data loaders:
def create_data_loaders(configs):
    """Create the train, validation, and test data loaders."""
    from torchvision import transforms

    # Data augmentation (training only). Note: nn.Sequential only accepts
    # nn.Module instances, so plain callables are chained with Compose.
    if configs.data_augmentation:
        train_transform = transforms.Compose([RandomRotate(), RandomFlip()])
    else:
        train_transform = None

    # Datasets
    train_dataset = NPYDataset(
        data_root=configs.data_root,
        mode='train',
        input_frames=configs.input_length,
        output_frames=configs.total_length - configs.input_length,
        future_frames=configs.total_length - configs.input_length,
        transform=train_transform,
        preprocess=configs.preprocess_data)
    valid_dataset = NPYDataset(
        data_root=configs.data_root,
        mode='valid',
        input_frames=configs.input_length,
        output_frames=configs.total_length - configs.input_length,
        future_frames=configs.total_length - configs.input_length,
        transform=None,
        preprocess=configs.preprocess_data)
    test_dataset = NPYDataset(
        data_root=configs.data_root,
        mode='test',
        input_frames=configs.input_length,
        output_frames=configs.total_length - configs.input_length,
        future_frames=configs.total_length - configs.input_length,
        transform=None,
        preprocess=configs.preprocess_data)

    # Data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=configs.batch_size,
        shuffle=True,
        num_workers=configs.num_workers,
        pin_memory=True)
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=configs.batch_size,
        shuffle=False,
        num_workers=configs.num_workers,
        pin_memory=True)
    test_loader = DataLoader(
        test_dataset,
        batch_size=configs.batch_size,
        shuffle=False,
        num_workers=configs.num_workers,
        pin_memory=True)
    return train_loader, valid_loader, test_loader
5. Model Training Pipeline
Now we implement the full training pipeline, including the loss function, optimizer, and learning-rate scheduler:
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import time
import os
import numpy as np
from tqdm import tqdm


class Trainer:
    def __init__(self, configs, model, train_loader, valid_loader, test_loader):
        self.configs = configs
        self.model = model
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader
        self.device = configs.device

        # Loss function
        self.criterion = nn.MSELoss()

        # Optimizer
        self.optimizer = optim.Adam(
            model.parameters(),
            lr=configs.lr,
            weight_decay=configs.weight_decay)

        # Learning-rate scheduler
        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=5,
            verbose=True)

        # Training history
        self.train_losses = []
        self.valid_losses = []
        self.best_loss = float('inf')

        # Checkpoint directory
        os.makedirs(configs.save_dir, exist_ok=True)

    def train_epoch(self, epoch):
        """Train for one epoch."""
        self.model.train()
        total_loss = 0
        progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch}')
        for batch_idx, (inputs, targets) in enumerate(progress_bar):
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            # Forward pass: the model consumes the full sequence and emits
            # total_length - 1 frames; only the predicted window enters the loss
            self.optimizer.zero_grad()
            frames = torch.cat([inputs, targets], dim=1)
            outputs = self.model(frames, mask_true=None)[:, -targets.shape[1]:]
            loss = self.criterion(outputs, targets)

            # Backward pass
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            progress_bar.set_postfix({'loss': loss.item()})
        avg_loss = total_loss / len(self.train_loader)
        self.train_losses.append(avg_loss)
        return avg_loss

    def validate(self):
        """Validate the model."""
        self.model.eval()
        total_loss = 0
        with torch.no_grad():
            for inputs, targets in self.valid_loader:
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)
                frames = torch.cat([inputs, targets], dim=1)
                outputs = self.model(frames, mask_true=None)[:, -targets.shape[1]:]
                loss = self.criterion(outputs, targets)
                total_loss += loss.item()
        avg_loss = total_loss / len(self.valid_loader)
        self.valid_losses.append(avg_loss)
        return avg_loss

    def test(self):
        """Test the model."""
        self.model.eval()
        total_loss = 0
        all_outputs = []
        all_targets = []
        with torch.no_grad():
            for inputs, targets in self.test_loader:
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)
                frames = torch.cat([inputs, targets], dim=1)
                outputs = self.model(frames, mask_true=None)[:, -targets.shape[1]:]
                loss = self.criterion(outputs, targets)
                total_loss += loss.item()
                # Keep results for later analysis
                all_outputs.append(outputs.cpu().numpy())
                all_targets.append(targets.cpu().numpy())
        avg_loss = total_loss / len(self.test_loader)
        return avg_loss, np.concatenate(all_outputs, axis=0), np.concatenate(all_targets, axis=0)

    def save_checkpoint(self, epoch, is_best=False):
        """Save a checkpoint."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'train_losses': self.train_losses,
            'valid_losses': self.valid_losses,
            'best_loss': self.best_loss
        }
        # Latest checkpoint
        torch.save(checkpoint, os.path.join(self.configs.save_dir, 'latest_checkpoint.pth'))
        # Best checkpoint
        if is_best:
            torch.save(checkpoint, os.path.join(self.configs.save_dir, 'best_checkpoint.pth'))

    def load_checkpoint(self, checkpoint_path):
        """Load a checkpoint."""
        checkpoint = torch.load(checkpoint_path)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.train_losses = checkpoint['train_losses']
        self.valid_losses = checkpoint['valid_losses']
        self.best_loss = checkpoint['best_loss']
        return checkpoint['epoch']

    def train(self, num_epochs):
        """Full training loop."""
        start_epoch = 0

        # Resume from the latest checkpoint if requested
        latest = os.path.join(self.configs.save_dir, 'latest_checkpoint.pth')
        if self.configs.resume and os.path.exists(latest):
            print("Loading checkpoint...")
            start_epoch = self.load_checkpoint(latest)
            print(f"Resumed from epoch {start_epoch}")

        for epoch in range(start_epoch, num_epochs):
            print(f"\nEpoch {epoch+1}/{num_epochs}")

            # Train
            train_loss = self.train_epoch(epoch)
            print(f"Train Loss: {train_loss:.6f}")

            # Validate
            valid_loss = self.validate()
            print(f"Valid Loss: {valid_loss:.6f}")

            # Update learning rate
            self.scheduler.step(valid_loss)

            # Save checkpoints
            is_best = valid_loss < self.best_loss
            if is_best:
                self.best_loss = valid_loss
            self.save_checkpoint(epoch, is_best)

            # Evaluate on the test set every 5 epochs
            if (epoch + 1) % 5 == 0:
                test_loss, _, _ = self.test()
                print(f"Test Loss: {test_loss:.6f}")

        # Final test
        print("\nFinal Testing...")
        test_loss, outputs, targets = self.test()
        print(f"Final Test Loss: {test_loss:.6f}")
        return test_loss, outputs, targets
6. Prediction and Result Visualization
Implement prediction and result visualization:
import matplotlib.pyplot as plt
import os
import numpy as np
import torch
from mpl_toolkits.axes_grid1 import ImageGrid


class Predictor:
    def __init__(self, configs, model):
        self.configs = configs
        self.model = model
        self.device = configs.device
        if self.model is not None:
            self.model.eval()

    def predict(self, input_seq):
        """Predict future frames from an input sequence."""
        with torch.no_grad():
            input_seq = input_seq.to(self.device)
            output_seq = self.model(input_seq, mask_true=None)
            # The model emits total_length - 1 frames; keep the predicted window
            pred_len = self.configs.total_length - self.configs.input_length
            output_seq = output_seq[:, -pred_len:]
        return output_seq.cpu()

    def visualize_results(self, inputs, targets, predictions, save_path=None):
        """Visualize inputs, targets, and predictions side by side."""
        # Visualize the first sample of the batch
        inputs = inputs[0].squeeze()            # [T, H, W]
        targets = targets[0].squeeze()          # [T, H, W]
        predictions = predictions[0].squeeze()  # [T, H, W]

        # One row per role, one column per frame
        n_cols = max(inputs.shape[0], targets.shape[0], predictions.shape[0])
        fig = plt.figure(figsize=(20, 10))
        grid = ImageGrid(fig, 111, nrows_ncols=(3, n_cols), axes_pad=0.1)

        # Row 0: input frames
        for i in range(inputs.shape[0]):
            ax = grid[i]
            ax.imshow(inputs[i], cmap='viridis')
            ax.set_title(f'Input {i+1}')
            ax.axis('off')
        # Row 1: target frames
        for i in range(targets.shape[0]):
            ax = grid[n_cols + i]
            ax.imshow(targets[i], cmap='viridis')
            ax.set_title(f'Target {i+1}')
            ax.axis('off')
        # Row 2: predicted frames
        for i in range(predictions.shape[0]):
            ax = grid[2 * n_cols + i]
            ax.imshow(predictions[i], cmap='viridis')
            ax.set_title(f'Pred {i+1}')
            ax.axis('off')

        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

    def save_predictions(self, predictions, save_dir):
        """Save predicted frames as NPY files."""
        os.makedirs(save_dir, exist_ok=True)
        for i, pred_seq in enumerate(predictions):
            for j, frame in enumerate(pred_seq):
                frame_path = os.path.join(save_dir, f'batch_{i}_frame_{j}.npy')
                np.save(frame_path, frame.squeeze())

    def evaluate_metrics(self, targets, predictions):
        """Compute evaluation metrics for the predictions."""
        from sklearn.metrics import mean_squared_error, mean_absolute_error

        # Flatten
        targets_flat = targets.flatten()
        predictions_flat = predictions.flatten()

        # Pixel-wise error metrics
        mse = mean_squared_error(targets_flat, predictions_flat)
        mae = mean_absolute_error(targets_flat, predictions_flat)
        rmse = np.sqrt(mse)

        # PSNR
        max_val = np.max(targets_flat)
        psnr = 20 * np.log10(max_val / rmse) if rmse > 0 else float('inf')

        # SSIM, averaged over individual 2D frames (requires scikit-image)
        try:
            from skimage.metrics import structural_similarity as ssim_func
            t = targets.reshape(-1, targets.shape[-2], targets.shape[-1])
            p = predictions.reshape(-1, targets.shape[-2], targets.shape[-1])
            ssim = np.mean([ssim_func(t[i], p[i], data_range=max_val)
                            for i in range(t.shape[0])])
        except ImportError:
            ssim = 0
            print("SSIM calculation requires scikit-image. Install with: pip install scikit-image")

        return {'MSE': mse, 'MAE': mae, 'RMSE': rmse, 'PSNR': psnr, 'SSIM': ssim}
7. Model Evaluation and Optimization
Implement model evaluation and hyperparameter optimization:
def hyperparameter_optimization(configs):
    """Hyperparameter optimization with NNI."""
    import nni

    # Fetch the next set of hyperparameters from NNI
    optimized_params = nni.get_next_parameter()
    configs.lr = optimized_params.get('lr', configs.lr)
    configs.batch_size = optimized_params.get('batch_size', configs.batch_size)
    configs.num_hidden = optimized_params.get('num_hidden', configs.num_hidden)

    # Build the model and data loaders
    model = PredRNNv2(configs).to(configs.device)
    train_loader, valid_loader, test_loader = create_data_loaders(configs)

    # Train the model
    trainer = Trainer(configs, model, train_loader, valid_loader, test_loader)
    test_loss, _, _ = trainer.train(configs.epoch)

    # Report the final result to NNI
    nni.report_final_result(test_loss)
    return test_loss


def analyze_results(configs, outputs, targets):
    """Analyze prediction results."""
    predictor = Predictor(configs, None)
    metrics = predictor.evaluate_metrics(targets, outputs)

    print("Evaluation Metrics:")
    for metric, value in metrics.items():
        print(f"{metric}: {value:.4f}")

    # Plot predictions against targets
    plt.figure(figsize=(10, 6))
    plt.plot(outputs.flatten(), label='Predictions', alpha=0.7)
    plt.plot(targets.flatten(), label='Targets', alpha=0.7)
    plt.xlabel('Sample Index')
    plt.ylabel('Value')
    plt.title('Predictions vs Targets')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(configs.save_dir, 'predictions_vs_targets.png'), dpi=300)
    plt.show()
    return metrics
8. Full Implementation
Finally, we tie all the components together into a single script:
import argparse
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from models import PredRNNv2
from data_loader import NPYDataset, create_data_loaders
from trainer import Trainer
from predictor import Predictor
from utils import analyze_results


def str2bool(v):
    """Parse boolean command-line flags; argparse's type=bool treats any non-empty string as True."""
    return str(v).lower() in ('true', '1', 'yes')


def parse_args():
    parser = argparse.ArgumentParser(description='PredRNNv2 for NPY dataset')
    # Data
    parser.add_argument('--data_root', type=str, default='./dataset', help='dataset root directory')
    parser.add_argument('--input_length', type=int, default=10, help='number of input frames')
    parser.add_argument('--total_length', type=int, default=20, help='total frames (input + predicted)')
    parser.add_argument('--img_width', type=int, default=500, help='image width')
    parser.add_argument('--img_height', type=int, default=500, help='image height')
    parser.add_argument('--img_channel', type=int, default=1, help='number of image channels')
    parser.add_argument('--preprocess_data', type=str2bool, default=True, help='whether to normalize the data')
    parser.add_argument('--data_augmentation', type=str2bool, default=True, help='whether to use data augmentation')
    # Model
    parser.add_argument('--num_hidden', type=int, nargs='+', default=[64, 64, 64, 64], help='hidden units per layer')
    parser.add_argument('--filter_size', type=int, default=5, help='convolution kernel size')
    parser.add_argument('--stride', type=int, default=1, help='convolution stride')
    parser.add_argument('--patch_size', type=int, default=1, help='patch size')
    parser.add_argument('--layer_norm', type=str2bool, default=True, help='whether to use layer normalization')
    parser.add_argument('--reverse_scheduled_sampling', type=int, default=0, help='reverse scheduled sampling')
    # Training
    parser.add_argument('--batch_size', type=int, default=4, help='batch size')
    parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
    parser.add_argument('--weight_decay', type=float, default=0, help='weight decay')
    parser.add_argument('--epoch', type=int, default=100, help='number of training epochs')
    parser.add_argument('--num_workers', type=int, default=4, help='number of data-loading workers')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='device')
    parser.add_argument('--save_dir', type=str, default='./checkpoints', help='save directory')
    parser.add_argument('--resume', type=str2bool, default=False, help='resume training from checkpoint')
    # Misc
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'test', 'predict'], help='run mode')
    parser.add_argument('--checkpoint_path', type=str, default='', help='checkpoint path')
    return parser.parse_args()


def main():
    # Parse arguments
    configs = parse_args()

    # Create the save directory
    os.makedirs(configs.save_dir, exist_ok=True)

    # Build the model
    model = PredRNNv2(configs).to(configs.device)
    print(f"Number of model parameters: {sum(p.numel() for p in model.parameters()):,}")

    if configs.mode == 'train':
        # Data loaders
        train_loader, valid_loader, test_loader = create_data_loaders(configs)
        # Train
        trainer = Trainer(configs, model, train_loader, valid_loader, test_loader)
        test_loss, outputs, targets = trainer.train(configs.epoch)
        # Analyze
        analyze_results(configs, outputs, targets)

    elif configs.mode == 'test':
        # Load checkpoint
        if configs.checkpoint_path:
            checkpoint = torch.load(configs.checkpoint_path)
            model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Loaded checkpoint from {configs.checkpoint_path}")
        # Data loader
        _, _, test_loader = create_data_loaders(configs)
        # Test
        trainer = Trainer(configs, model, None, None, test_loader)
        test_loss, outputs, targets = trainer.test()
        print(f"Test Loss: {test_loss:.6f}")
        # Analyze
        metrics = analyze_results(configs, outputs, targets)
        # Save raw results
        np.save(os.path.join(configs.save_dir, 'test_outputs.npy'), outputs)
        np.save(os.path.join(configs.save_dir, 'test_targets.npy'), targets)

    elif configs.mode == 'predict':
        # Load checkpoint
        if configs.checkpoint_path:
            checkpoint = torch.load(configs.checkpoint_path)
            model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Loaded checkpoint from {configs.checkpoint_path}")
        # Predictor
        predictor = Predictor(configs, model)
        # Data to predict on; this assumes a separate 'predict' split exists
        predict_dataset = NPYDataset(
            data_root=configs.data_root,
            mode='predict',
            input_frames=configs.input_length,
            output_frames=configs.total_length - configs.input_length,
            future_frames=configs.total_length - configs.input_length,
            transform=None,
            preprocess=configs.preprocess_data)
        predict_loader = DataLoader(
            predict_dataset,
            batch_size=configs.batch_size,
            shuffle=False,
            num_workers=configs.num_workers,
            pin_memory=True)

        all_predictions = []
        all_inputs = []
        with torch.no_grad():
            for inputs, _ in predict_loader:
                inputs = inputs.to(configs.device)
                predictions = predictor.predict(inputs)
                all_predictions.append(predictions.numpy())
                all_inputs.append(inputs.cpu().numpy())
        all_predictions = np.concatenate(all_predictions, axis=0)
        all_inputs = np.concatenate(all_inputs, axis=0)

        # Save predictions
        output_dir = os.path.join(configs.save_dir, 'predictions')
        os.makedirs(output_dir, exist_ok=True)
        for i, (input_seq, pred_seq) in enumerate(zip(all_inputs, all_predictions)):
            # Input sequence
            for j, frame in enumerate(input_seq):
                frame_path = os.path.join(output_dir, f'sequence_{i:03d}_input_{j:03d}.npy')
                np.save(frame_path, frame.squeeze())
            # Predicted sequence
            for j, frame in enumerate(pred_seq):
                frame_path = os.path.join(output_dir, f'sequence_{i:03d}_pred_{j:03d}.npy')
                np.save(frame_path, frame.squeeze())
        print(f"Predictions saved to {output_dir}")

        # Visualize one sample (no ground-truth targets here, so the
        # predictions stand in for the target row)
        if len(all_inputs) > 0:
            sample_idx = 0
            predictor.visualize_results(
                all_inputs[sample_idx:sample_idx+1],
                all_predictions[sample_idx:sample_idx+1],
                all_predictions[sample_idx:sample_idx+1],
                save_path=os.path.join(output_dir, 'sample_prediction.png'))


if __name__ == '__main__':
    main()
9. Common Issues and Solutions
9.1 Out-of-Memory Errors
Large 500×500 frames can exhaust GPU memory. Possible remedies:
- Patch the data: split large frames into smaller blocks before processing
- Reduce the batch size: process fewer samples per step
- Use mixed-precision training: half-precision floats reduce memory usage
# Mixed-precision training example
from torch.cuda.amp import autocast, GradScaler

def train_epoch_with_amp(self, epoch):
    """Train one epoch with automatic mixed precision."""
    self.model.train()
    total_loss = 0
    scaler = GradScaler()
    progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch}')
    for batch_idx, (inputs, targets) in enumerate(progress_bar):
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)

        # Run the forward pass under autocast
        with autocast():
            frames = torch.cat([inputs, targets], dim=1)
            outputs = self.model(frames, mask_true=None)[:, -targets.shape[1]:]
            loss = self.criterion(outputs, targets)

        # Scale the loss and backpropagate
        self.optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(self.optimizer)
        scaler.update()

        total_loss += loss.item()
        progress_bar.set_postfix({'loss': loss.item()})
    avg_loss = total_loss / len(self.train_loader)
    self.train_losses.append(avg_loss)
    return avg_loss
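The first remedy, splitting large frames into blocks, is what the `patch_size` config is for: a space-to-depth reshape folds each p×p block into the channel dimension, shrinking the spatial resolution the recurrent cells operate on. A numpy-only sketch of that transform and its inverse (the function names here are illustrative, not OpenSTL's API; H and W must be divisible by the patch size, so a 500×500 frame would first be padded, e.g. to 512×512):

```python
import numpy as np

def reshape_patch(frames, patch_size):
    """Space-to-depth: [T, C, H, W] -> [T, C*p*p, H/p, W/p]."""
    T, C, H, W = frames.shape
    assert H % patch_size == 0 and W % patch_size == 0
    x = frames.reshape(T, C, H // patch_size, patch_size, W // patch_size, patch_size)
    x = x.transpose(0, 1, 3, 5, 2, 4)  # move patch dims next to channels
    return x.reshape(T, C * patch_size ** 2, H // patch_size, W // patch_size)

def reshape_patch_back(patches, patch_size):
    """Inverse of reshape_patch: [T, C*p*p, H/p, W/p] -> [T, C, H, W]."""
    T, Cpp, Hp, Wp = patches.shape
    C = Cpp // patch_size ** 2
    x = patches.reshape(T, C, patch_size, patch_size, Hp, Wp)
    x = x.transpose(0, 1, 4, 2, 5, 3)  # interleave patch dims back into space
    return x.reshape(T, C, Hp * patch_size, Wp * patch_size)

# A 512x512 single-channel sequence with patch_size=4 becomes 16-channel 128x128
frames = np.zeros((10, 1, 512, 512), dtype=np.float32)
print(reshape_patch(frames, 4).shape)  # (10, 16, 128, 128)
```

The recurrent state tensors then live at 128×128 instead of 512×512, a 16× reduction in their spatial footprint at the cost of more channels.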
9.2 Training Instability
PredRNNv2 training can be unstable; the following can help:
- Gradient clipping: prevent exploding gradients
- Learning-rate scheduling: adjust the learning rate dynamically
- Weight initialization: use a suitable initialization scheme
# Gradient clipping example
def train_epoch_with_gradient_clipping(self, epoch, clip_value=1.0):
    """Train one epoch with gradient clipping."""
    self.model.train()
    total_loss = 0
    progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch}')
    for batch_idx, (inputs, targets) in enumerate(progress_bar):
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)

        self.optimizer.zero_grad()
        frames = torch.cat([inputs, targets], dim=1)
        outputs = self.model(frames, mask_true=None)[:, -targets.shape[1]:]
        loss = self.criterion(outputs, targets)
        loss.backward()

        # Clip the global gradient norm to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip_value)
        self.optimizer.step()

        total_loss += loss.item()
        progress_bar.set_postfix({'loss': loss.item()})
    avg_loss = total_loss / len(self.train_loader)
    self.train_losses.append(avg_loss)
    return avg_loss
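For the learning-rate scheduling point: the Trainer uses PyTorch's `ReduceLROnPlateau`, whose behavior is easy to misread. The pure-Python sketch below approximately replays its patience logic over a list of validation losses, purely to illustrate when the rate actually drops; the function name and simplified rules are ours, not PyTorch's.

```python
def plateau_lr_schedule(val_losses, lr=1e-3, factor=0.5, patience=5, min_delta=0.0):
    """Approximate replay of ReduceLROnPlateau: halve the rate after the
    validation loss fails to improve for more than `patience` epochs."""
    best = float('inf')
    bad_epochs = 0
    history = []
    for loss in val_losses:
        if loss < best - min_delta:
            best = loss          # improvement: reset the patience counter
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs > patience:
                lr *= factor     # plateau exceeded patience: reduce
                bad_epochs = 0
        history.append(lr)
    return history
```

With `patience=5`, a flat loss curve keeps the initial rate for six epochs and only then drops it, which matches the intuition that the scheduler reacts slowly by design.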
9.3 Overfitting
If the model does well on the training set but poorly on the validation set, it may be overfitting:
- Data augmentation: increase data diversity
- Regularization: use Dropout or weight decay
- Early stopping: stop training when the validation loss stops improving
# Early-stopping implementation
class EarlyStopping:
    def __init__(self, patience=10, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0
        return self.early_stop

# Using early stopping inside the training loop
early_stopping = EarlyStopping(patience=10)
for epoch in range(num_epochs):
    # train and validate ...
    if early_stopping(valid_loss):
        print("Early stopping triggered")
        break
10. Summary and Outlook
This article has shown how to reproduce the PredRNNv2 model from OpenSTL and train and predict on a custom NPY-format dataset, covering the full pipeline from environment setup and data preprocessing through model construction to training and evaluation.
10.1 Key Results
- A complete data pipeline: loading, preprocessing, and augmentation for NPY-format data
- A PredRNNv2 reproduction: the model's core components and full architecture
- A training framework: full training, validation, and testing loops with a loss function, optimizer, and learning-rate scheduling
- Prediction and visualization: inference plus result plots for analyzing model performance
- Troubleshooting: fixes for common problems (out-of-memory errors, unstable training, overfitting)
10.2 Future Directions
- Model upgrades: try more recent video-prediction models such as SimVP or PhyDNet
- Multimodal fusion: incorporate other sensor data (e.g. weather or geographic information) to improve accuracy
- Real-time prediction: optimize inference speed for real-time use
- Uncertainty quantification: estimate the uncertainty of the predictions
- Deployment: move the model to production and support large-scale data processing
With the guidance and code in this article, you should be able to reproduce PredRNNv2 and train and predict on your own dataset. We hope this work serves as a useful reference for research and applications in video prediction.