ARConv Reproduction Workflow
Environment
Python 3.10.16
torch 2.1.1+cu118
torchvision 0.16.1+cu118
Install the remaining dependencies from the requirements.txt provided with the official code; a quick version check is sketched below the repository link.
GitHub - WangXueyang-uestc/ARConv: Official repo for Adaptive Rectangular Convolution
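As a quick sanity check that the installed environment matches the versions above, the short snippet below (a minimal sketch, not part of the official code) prints the torch/torchvision versions and confirms that CUDA is visible:

# Environment check: print installed versions and CUDA availability
import torch
import torchvision

print("torch:", torch.__version__)              # expected: 2.1.1+cu118
print("torchvision:", torchvision.__version__)  # expected: 0.16.1+cu118
print("CUDA available:", torch.cuda.is_available())
print("GPU count:", torch.cuda.device_count())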
Data Preparation
Download the PanCollection dataset from the official project page (PanCollection for Survey Paper).
Taking the WV3 Dataset as an example, download both the training and testing sets (a quick sanity check on the downloaded H5 files is sketched after the list):
[1] Training Dataset (5.76 GB): [Baidu Cloud]
[2] Testing Dataset (20 examples per class): [ReducedData (H5 Format)] [FullData (H5 Format)]
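Once the files are downloaded, it helps to confirm they open correctly with h5py. A minimal sketch (the 'lms', 'ms', and 'pan' keys are the ones read by the test scripts later; other keys such as a ground-truth entry may also be present depending on the file):

import h5py

# List the datasets stored in a PanCollection H5 file, with shapes and dtypes
with h5py.File('pansharpening/training_data/train_wv3.h5', 'r') as f:
    for key in f.keys():
        print(key, f[key].shape, f[key].dtype)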
Training
Here I did not use the officially recommended .sh launch script; instead I run trainer.py directly. This requires modifying two files so that the model can be imported, which comes down to relative vs. absolute imports:
ARConv/models/models.py
from ARConv import ARConv -> from .ARConv import ARConv
ARConv/trainer.py
from .models import ARNet -> from models import ARNet
Run trainer.py to start training. The command below uses only GPU 0:
CUDA_VISIBLE_DEVICES="0" python trainer.py --batch_size 16 --epochs 600 --lr 0.0006 --ckpt 20 --train_set_path ./pansharpening/training_data/train_wv3.h5 --checkpoint_save_path ./workdir/wv3 --hw_range 1 18 --task 'wv3'
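Checkpoints end up under the directory passed to --checkpoint_save_path, with names like checkpoint_160_2025-05-02-16-06-33.pth (as seen in the test scripts below). A small helper to list them and pick the most recent one for evaluation, a sketch that only assumes the .pth extension:

import glob
import os

# List all saved checkpoints and print the most recently written one
ckpts = sorted(glob.glob('workdir/wv3/*.pth'), key=os.path.getmtime)
for p in ckpts:
    print(p)
print('Latest checkpoint:', ckpts[-1] if ckpts else 'none found')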
Testing
After training, the model weight .pth files are saved in the directory you specified. Based on the author's reply and some additions of my own, I wrote two Python scripts, getFullmat.py and getReducedmat.py, which generate the model outputs to be evaluated in MATLAB. Just change checkpoint_path in each script to the path of your own .pth file.
getFullmat.py
import torch
import torch.nn as nn
import os
import scipy.io as sio
from einops import rearrange
from models import ARNet
import h5py
import numpy as np
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

def load_set(file_path):
    # Read lms / ms / pan from the H5 file and normalize to [0, 1] by dividing by 2047
    data = h5py.File(file_path)
    lms = torch.from_numpy(np.array(data['lms'][...], dtype=np.float32) / 2047.).unsqueeze(dim=0).permute([1, 0, 2, 3, 4]).float()
    ms = torch.from_numpy(np.array(data['ms'][...], dtype=np.float32) / 2047.).unsqueeze(dim=0).permute([1, 0, 2, 3, 4]).float()
    pan = torch.from_numpy(np.array(data['pan'][...], dtype=np.float32) / 2047.).unsqueeze(dim=0).permute([1, 0, 2, 3, 4]).float()
    return lms, ms, pan

# Path settings (change to your own paths)
checkpoint_path = r'workdir/wv3/checkpoint_160_2025-05-02-16-06-33.pth'
test_data_path = r'pansharpening/test_data/WV3/test_wv3_OrigScale_multiExm1.h5'
save_dir = r'2_DL_Result/PanCollection/WV3_Full/RRNet/results/'

# Create the output directory
os.makedirs(save_dir, exist_ok=True)

# Load the model
model = ARNet().cuda()
model = nn.DataParallel(model)
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model'])
model.eval()

# Load the test data
lms, ms, pan = load_set(test_data_path)

# Run inference on all test images
with torch.no_grad():
    print('Running model inference...')
    for i in range(pan.shape[0]):
        output = model(pan[i], lms[i], 1000, [1, 18])
        output = rearrange(output, 'b c h w -> b h w c') * 2047
        output_np = output[0].cpu().numpy()
        save_mat_path = os.path.join(save_dir, f'output_mulExm_{i}.mat')
        sio.savemat(save_mat_path, {'sr': output_np})
        print(f"Saved .mat to {save_mat_path}")
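After running getFullmat.py, you can spot-check one of the exported .mat files before moving on to MATLAB. A minimal sketch (the 'sr' key and the 2047 scaling follow from the script above; the exact shape depends on the test set):

import scipy.io as sio

# Load one exported result and check its shape and value range
mat = sio.loadmat('2_DL_Result/PanCollection/WV3_Full/RRNet/results/output_mulExm_0.mat')
sr = mat['sr']
print('shape:', sr.shape)                   # H x W x C (channels last)
print('value range:', sr.min(), sr.max())   # roughly within [0, 2047]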
getReducedmat.py
import torch
import torch.nn as nn
import os
import scipy.io as sio
from einops import rearrange
from models import ARNet
import h5py
import numpy as np
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

def load_set(file_path):
    # Read lms / ms / pan from the H5 file and normalize to [0, 1] by dividing by 2047
    data = h5py.File(file_path)
    lms = torch.from_numpy(np.array(data['lms'][...], dtype=np.float32) / 2047.).unsqueeze(dim=0).permute([1, 0, 2, 3, 4]).float()
    ms = torch.from_numpy(np.array(data['ms'][...], dtype=np.float32) / 2047.).unsqueeze(dim=0).permute([1, 0, 2, 3, 4]).float()
    pan = torch.from_numpy(np.array(data['pan'][...], dtype=np.float32) / 2047.).unsqueeze(dim=0).permute([1, 0, 2, 3, 4]).float()
    return lms, ms, pan

# Path settings (change to your own paths)
checkpoint_path = r'workdir/wv3/checkpoint_160_2025-05-02-16-06-33.pth'
test_data_path = r'pansharpening/test_data/WV3/test_wv3_multiExm1.h5'
save_dir = r'2_DL_Result/PanCollection/WV3_Reduced/RRNet/results/'

# Create the output directory
os.makedirs(save_dir, exist_ok=True)

# Load the model
model = ARNet().cuda()
model = nn.DataParallel(model)
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model'])
model.eval()

# Load the test data
lms, ms, pan = load_set(test_data_path)

# Run inference on all test images
with torch.no_grad():
    print('Running model inference...')
    for i in range(pan.shape[0]):
        output = model(pan[i], lms[i], 1000, [1, 18])
        output = rearrange(output, 'b c h w -> b h w c') * 2047
        output_np = output[0].cpu().numpy()
        save_mat_path = os.path.join(save_dir, f'output_mulExm_{i}.mat')
        sio.savemat(save_mat_path, {'sr': output_np})
        print(f"Saved .mat to {save_mat_path}")
Then place the 2_DL_Result folder into ARConv\MetricCode.
In Demo1_Reduced_Resolution_MultiExm_wv3.m and Demo2_Full_Resolution_multi_wv3.m, change the file_test path so that it points to the files where the test sets are stored.
I changed them to, respectively:
Demo1_Reduced_Resolution_MultiExm_wv3.m:
opts.file = 'test_wv3_multiExm1';
file_test = strcat('pansharpening/test_data/WV3/', opts.file,'.h5');
Demo2_Full_Resolution_multi_wv3.m:
opts.file = 'test_wv3_OrigScale_multiExm1';
file_test = strcat('pansharpening/test_data/WV3/', opts.file,'.h5');
and placed the two test-set files in that path:
test_wv3_multiExm1.h5和test_wv3_OrigScale_multiExm1.h5
Running the two demo scripts then completes the evaluation successfully.