当前位置: 首页 > news >正文

使用tensorRT10部署低光照补偿模型

1.低光照补偿模型的简单介绍

作者介绍一种Zero-Reference Deep Curve Estimation (Zero-DCE)的方法用于在没有参考图像的情况下增强低光照图像的效果。

具体来说,它将低光照图像增强问题转化为通过深度网络进行图像特定曲线估计的任务。训练了一个轻量级的深度网络 DCE-Net,来估计像素级和高阶曲线,以对给定图像进行动态范围调整。这种曲线估计考虑了像素值范围、单调性和可微性等因素。

Zero-DCE 的优点在于它不需要任何成对或不成对的数据进行训练,它通过一系列精心设计的非参考损失函数来实现这一点,这些函数能隐式地衡量增强质量并驱动网络学习。该方法通过直观且简单的非线性曲线映射实现图像增强,并且在多种照明条件下都具有很好的适用性。

文章还通过大量的实验来证明 Zero-DCE 在亮度、色彩、对比度和自然度等方面的视觉效果优于现有的先进方法,而其他方法在处理极暗背光或生成彩色伪影方面可能会失败。相比之下,Zero-DCE 的训练方式也与其他深度学习方法不同,并且它在黑暗环境下的面部检测方面也具有潜在优势。

这篇论文的方案以及低光照补偿结果如下:

文章源码地址:https://github.com/Li-Chongyi/Zero-DCE.git

2. zero_dce源码的简单介绍

2.1模型设计

模型设计比较简单,常规常见的算子

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
#import pytorch_colors as colors
import numpy as npclass enhance_net_nopool(nn.Module):def __init__(self):super(enhance_net_nopool, self).__init__()self.relu = nn.ReLU(inplace=True)number_f = 32self.e_conv1 = nn.Conv2d(3,number_f,3,1,1,bias=True) self.e_conv2 = nn.Conv2d(number_f,number_f,3,1,1,bias=True) self.e_conv3 = nn.Conv2d(number_f,number_f,3,1,1,bias=True) self.e_conv4 = nn.Conv2d(number_f,number_f,3,1,1,bias=True) self.e_conv5 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True) self.e_conv6 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True) self.e_conv7 = nn.Conv2d(number_f*2,24,3,1,1,bias=True) self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False)self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)def forward(self, x):x1 = self.relu(self.e_conv1(x))# p1 = self.maxpool(x1)x2 = self.relu(self.e_conv2(x1))# p2 = self.maxpool(x2)x3 = self.relu(self.e_conv3(x2))# p3 = self.maxpool(x3)x4 = self.relu(self.e_conv4(x3))x5 = self.relu(self.e_conv5(torch.cat([x3,x4],1)))# x5 = self.upsample(x5)x6 = self.relu(self.e_conv6(torch.cat([x2,x5],1)))x_r = F.tanh(self.e_conv7(torch.cat([x1,x6],1)))r1,r2,r3,r4,r5,r6,r7,r8 = torch.split(x_r, 3, dim=1)x = x + r1*(torch.pow(x,2)-x)x = x + r2*(torch.pow(x,2)-x)x = x + r3*(torch.pow(x,2)-x)enhance_image_1 = x + r4*(torch.pow(x,2)-x)		x = enhance_image_1 + r5*(torch.pow(enhance_image_1,2)-enhance_image_1)		x = x + r6*(torch.pow(x,2)-x)	x = x + r7*(torch.pow(x,2)-x)enhance_image = x + r8*(torch.pow(x,2)-x)r = torch.cat([r1,r2,r3,r4,r5,r6,r7,r8],1)return enhance_image_1,enhance_image,r

2.2模型训练和损失函数

模型的损失函数设计部分比较复杂,在训练过程中使用

	L_color = Myloss.L_color()L_spa = Myloss.L_spa()L_exp = Myloss.L_exp(16,0.6)L_TV = Myloss.L_TV()

而这里的损失函数全都在文件中间的Myloss.py文件中,在训练的过程中:

for epoch in range(config.num_epochs):for iteration, img_lowlight in enumerate(train_loader):img_lowlight = img_lowlight.cuda()enhanced_image_1,enhanced_image,A  = DCE_net(img_lowlight)Loss_TV = 200*L_TV(A)loss_spa = torch.mean(L_spa(enhanced_image, img_lowlight))loss_col = 5*torch.mean(L_color(enhanced_image))loss_exp = 10*torch.mean(L_exp(enhanced_image))# best_lossloss =  Loss_TV + loss_spa + loss_col + loss_exp

2.3 图像的前处理

源码中的图像前处理部分如下:

def __getitem__(self, index):data_lowlight_path = self.data_list[index]data_lowlight = Image.open(data_lowlight_path)data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)data_lowlight = (np.asarray(data_lowlight)/255.0) data_lowlight = torch.from_numpy(data_lowlight).float()return data_lowlight.permute(2,0,1)

在源码的lowlight_test中也可以看到图像的这个模型的前处理的代码:

    data_lowlight = Image.open(image_path)data_lowlight = (np.asarray(data_lowlight)/255.0)data_lowlight = torch.from_numpy(data_lowlight).float()data_lowlight = data_lowlight.permute(2,0,1)data_lowlight = data_lowlight.cuda().unsqueeze(0)

2.4 源码的后处理代码

源码直接使用torchvisopn.utils.save_image()方法保存了推理的结果

    _,enhanced_image,_ = DCE_net(data_lowlight)end_time = (time.time() - start)print(end_time)image_path = image_path.replace('test_data','result')result_path = image_pathif not os.path.exists(image_path.replace('/'+image_path.split("/")[-1],'')):os.makedirs(image_path.replace('/'+image_path.split("/")[-1],''))torchvision.utils.save_image(enhanced_image, result_path)

点开save_image()方法

def save_image(tensor: Union[torch.Tensor, List[torch.Tensor]],fp: Union[str, pathlib.Path, BinaryIO],format: Optional[str] = None,**kwargs,
) -> None:"""Save a given Tensor into an image file.Args:tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,saves the tensor as a grid of images by calling ``make_grid``.fp (string or file object): A filename or a file objectformat(Optional):  If omitted, the format to use is determined from the filename extension.If a file object was used instead of a filename, this parameter should always be used.**kwargs: Other arguments are documented in ``make_grid``."""if not torch.jit.is_scripting() and not torch.jit.is_tracing():_log_api_usage_once(save_image)grid = make_grid(tensor, **kwargs)# Add 0.5 after unnormalizing to [0, 255] to round to the nearest integerndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()im = Image.fromarray(ndarr)im.save(fp, format=format)

3. 导出模型

由于这个代码本身没有复杂的算子和其他恶心的操作,我这边直接使用yolov5的环境测试这个lowlight_test.py的文件,发现可以直接运行。这里需要需要注意

DCE_net.load_state_dict(torch.load('/home/zpec/workspace/source-cnn/Zero-DCE-master/Zero-DCE_code/snapshots/Epoch99.pth'))

filePath = '/home/zpec/workspace/source-cnn/Zero-DCE-master/Zero-DCE_code/data/test_data/'

这里使用完整的路径。

在Zero-DCE_code文件夹下面创建export_onnx.py的文件,写如下的导出代码

import torch
import model  def convert_to_static_onnx():# 设备配置device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 加载预训练模型DCE_net = model.enhance_net_nopool().to(device)DCE_net.load_state_dict(torch.load('/home/zpec/workspace/source-cnn/Zero-DCE-master/Zero-DCE_code/snapshots/Epoch99.pth', map_location=device))DCE_net.eval()static_height = 640   static_width = 640    # 创建固定尺寸的虚拟输入dummy_input = torch.randn(1, 3, static_height, static_width).to(device)# 导出为静态模型torch.onnx.export(DCE_net,dummy_input,"ZeroDCE_static640.onnx",verbose=True,input_names=["input"],output_names=["output1","output2","output3"],opset_version=12, )if __name__ == "__main__":convert_to_static_onnx()

运行即可生成对应模型的onnx文件,onnx文件可视化如下:

4. 使用onnx加载推理模型试验

这里加载python版本的onnxruntime来试验推理模型,完整的推理代码如下:

import onnxruntime as ort
import numpy as np
import cv2def preprocess_image_cv2(image_path, input_shape):# 读取图像img = cv2.imread(image_path)# 转换为 RGB 格式img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)# 调整大小img = cv2.resize(img, (input_shape[2], input_shape[1]))# 归一化img = img / 255.0# 转换为通道优先格式 (C, H, W)img = img.transpose(2, 0, 1)# 添加批次维度 (1, C, H, W)img = np.expand_dims(img, axis=0).astype(np.float32)return imgdef postprocess_image_cv2(output, output_shape):# 去除批次维度output = np.squeeze(output, axis=0)# 转换为 HWC 格式output = output.transpose(1, 2, 0)# 调整大小到原始图像大小output = cv2.resize(output, output_shape)# 转换为 BGR 格式output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)return output# 示例用法
if __name__ == '__main__':# 示例图像路径image_path = '/home/zpec/workspace/source-cnn/Zero-DCE-master/Zero-DCE_code/data/test_data/DICM/06.jpg'# 模型输入形状 (例如: [3, 640, 640])input_shape = [3, 640, 640]# 预处理图像input_image = preprocess_image_cv2(image_path, input_shape)ort_session = ort.InferenceSession('ZeroDCE_static640.onnx')# 运行推理input_name = ort_session.get_inputs()[0].nameoutput_name = ort_session.get_outputs()[1].name# 检查模型的输入和输出的节点# 获取所有输入节点的信息inputs_info = ort_session.get_inputs()# 获取所有输出节点的信息outputs_info = ort_session.get_outputs()# 打印输入节点的信息print("Input nodes:")for idx, input_info in enumerate(inputs_info):print(f"Input node {idx}:")print(f"  Name: {input_info.name}")print(f"  Shape: {input_info.shape}")print(f"  Type: {input_info.type}")print()# 打印输出节点的信息print("Output nodes:")for idx, output_info in enumerate(outputs_info):print(f"Output node {idx}:")print(f"  Name: {output_info.name}")print(f"  Shape: {output_info.shape}")print(f"  Type: {output_info.type}")print()outputs = ort_session.run([output_name], {input_name: input_image})# 后处理输出# output_image = postprocess_image_cv2(outputs[0], (input_shape[2], input_shape[1]))output_image = postprocess_image_cv2(outputs[0], (640, 480))# 保存结果output_image_path = 'enhanced_image.jpg'cv2.imwrite(output_image_path, (output_image * 255).astype(np.uint8))print(f"Output image saved to {output_image_path}")

推理效果展示如下:

说明我们的前处理和后处理没有任何问题。现在开始tensorRT的模型部署和推理吧。

5. 使用tensorRT10部署模型

项目地址:GitHub - YLXA321/ZERO_DCE_model-tensorRT10: 基于tensorRT10部署低光照补偿代码

图像前处理的代码

void preprocess_cpu(cv::Mat &srcImg, float* dstDevData, const int width, const int height) {if (srcImg.data == nullptr) {std::cerr << "ERROR: Image file not found! Program terminated" << std::endl;return;}cv::Mat dstimg;if (srcImg.rows != height || srcImg.cols != width) {cv::resize(srcImg, dstimg, cv::Size(width, height), cv::INTER_AREA);} else {dstimg = srcImg.clone();}// BGR→RGB转换 + HWC→CHW转换int index = 0;int offset_ch0 = width * height * 0;  // R通道int offset_ch1 = width * height * 1;  // G通道int offset_ch2 = width * height * 2;  // B通道for (int i = 0; i < height; i++) {for (int j = 0; j < width; j++) {index = i * width * 3 + j * 3;// 从BGR数据中提取并赋值到目标通道dstDevData[offset_ch0++] = dstimg.data[index + 2] / 255.0f;  // RdstDevData[offset_ch1++] = dstimg.data[index + 1] / 255.0f;  // GdstDevData[offset_ch2++] = dstimg.data[index + 0] / 255.0f;  // B}}
}

图像后处理的代码:

cv::Mat decode_cpu(const float* model_output, const int KInputW, const int KInputH, const int src_width, const int src_height) {cv::Mat src_image;if (model_output == nullptr) {std::cerr << "ERROR: Model output is null." << std::endl;return cv::Mat();}// 创建临时浮点图像(HWC格式,RGB顺序)cv::Mat temp_image(KInputH, KInputW, CV_32FC3);float* temp_data = reinterpret_cast<float*>(temp_image.data);  // 直接操作内存// 计算各通道的起始指针const int channel_size = KInputH * KInputW;const float* r_channel = model_output + 0;           // R通道起始地址const float* g_channel = model_output + channel_size; // G通道起始地址const float* b_channel = model_output + 2 * channel_size; // B通道起始地址// 并行化填充(OpenCV自动优化)for (int i = 0; i < KInputH; ++i) {for (int j = 0; j < KInputW; ++j) {const int pixel_idx = (i * KInputW + j) * 3;  // HWC中每个像素的起始位置const int ch_idx = i * KInputW + j;          // CHW中当前像素的通道内索引temp_data[pixel_idx]     = r_channel[ch_idx]; // Rtemp_data[pixel_idx + 1] = g_channel[ch_idx]; // Gtemp_data[pixel_idx + 2] = b_channel[ch_idx]; // B}}// 反归一化并转为8UC3(与Python一致)temp_image.convertTo(temp_image, CV_8UC3, 255.0);// Resize到目标尺寸(使用INTER_LINEAR)if (KInputW != src_width || KInputH != src_height) {cv::resize(temp_image, src_image, cv::Size(src_width, src_height), cv::INTER_LINEAR);} else {src_image = temp_image.clone();}// RGB转BGR(与Python的cv2.COLOR_RGB2BGR一致)cv::cvtColor(src_image, src_image, cv::COLOR_RGB2BGR);return src_image;}

构建模型的推理引擎:

bool genEngine(std::string onnx_file_path, std::string save_engine_path, trtlogger::Logger level, int maxbatch){auto logger = std::make_shared<trtlogger::Logger>(level);// 创建builderauto builder = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(*logger));if(!builder){std::cout<<" (T_T)~~~, Failed to create builder."<<std::endl;return false;}auto network = std::unique_ptr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(0U));if(!network){std::cout<<" (T_T)~~~, Failed to create network."<<std::endl;return false;}// 创建 configauto config = std::unique_ptr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());if(!config){std::cout<<" (T_T)~~~, Failed to create config."<<std::endl;return false;}// 创建parser 从onnx自动构建模型,否则需要自己构建每个算子auto parser = std::unique_ptr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, *logger));if(!parser){std::cout<<" (T_T)~~~, Failed to create parser."<<std::endl;return false;}// 读取onnx模型文件开始构建模型auto parsed = parser->parseFromFile(onnx_file_path.c_str(), 1);if(!parsed){std::cout<<" (T_T)~~~ ,Failed to parse onnx file."<<std::endl;return false;}{auto input = network->getInput(0);auto input_dims = input->getDimensions();auto profile = builder->createOptimizationProfile(); // 配置最小、最优、最大范围input_dims.d[0] = 1;                                                         profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMIN, input_dims);profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kOPT, input_dims);input_dims.d[0] = maxbatch;profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMAX, input_dims);config->addOptimizationProfile(profile);// 判断是否使用半精度优化模型// if(FP16)  config->setFlag(nvinfer1::BuilderFlag::kFP16);config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA);// 设置默认设备类型为 DLAconfig->setDefaultDeviceType(nvinfer1::DeviceType::kDLA);// 获取 DLA 核心支持情况int numDLACores = builder->getNbDLACores();if (numDLACores > 0) {std::cout << "DLA is available. Number of DLA cores: " << numDLACores << std::endl;// 设置 DLA 核心int coreToUse = 0; // 选择第一个 DLA 核心(可以根据实际需求修改)config->setDLACore(coreToUse);std::cout << "Using DLA core: " << coreToUse << std::endl;} else {std::cerr << "DLA not available on this platform, falling back to GPU." << std::endl;// 如果 DLA 不可用,则设置 GPU 回退config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);config->setDefaultDeviceType(nvinfer1::DeviceType::kGPU);}};config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, 1 << 28);      /*在新的版本中被使用*/// 创建序列化引擎文件auto plan = std::unique_ptr<nvinfer1::IHostMemory>(builder->buildSerializedNetwork(*network, *config));if(!plan){std::cout<<" (T_T)~~~, Failed to SerializedNetwork."<<std::endl;return false;}//! 检查输入部分是否符合要求auto numInput = network->getNbInputs();std::cout<<"模型的输入个数是:"<<numInput<<std::endl;for(auto i = 0; i<numInput; ++i){std::cout<<"    模型的第"<<i<<"个输入:";auto mInputDims = network->getInput(i)->getDimensions();std::cout<<"  ✨~ model input dims: "<<mInputDims.nbDims <<std::endl;for(size_t ii=0; ii<mInputDims.nbDims; ++ii){std::cout<<"  ✨^_^ model input dim"<<ii<<": "<<mInputDims.d[ii] <<std::endl;}}auto numOutput = network->getNbOutputs();std::cout<<"模型的输出个数是:"<<numOutput<<std::endl;for(auto i=0; i<numOutput; ++i){std::cout<<"    模型的第"<<i<<"个输出:";auto mOutputDims = network->getOutput(i)->getDimensions();std::cout<<"  ✨~ model output dims: "<<mOutputDims.nbDims <<std::endl;for(size_t jj=0; jj<mOutputDims.nbDims; ++jj){std::cout<<"  ✨^_^ model output dim"<<jj<<": "<<mOutputDims.d[jj] <<std::endl;}}// 序列化保存推理引擎文件文件std::ofstream engine_file(save_engine_path, std::ios::binary);if(!engine_file.good()){std::cout<<" (T_T)~~~, Failed to open engine file"<<std::endl;return false;}engine_file.write((char *)plan->data(), plan->size());engine_file.close();std::cout << " ~~Congratulations! 🎉🎉🎉~  Engine build success!!! ✨✨✨~~ " << std::endl;return true;}

创建runtime部分:

bool ZeroDCEModel::Runtime(std::string engine_file_path, trtlogger::Logger level,int maxBatch){auto logger = std::make_shared<trtlogger::Logger>(level);// 初始化trt插件
//    initLibNvInferPlugins(&logger, "");std::ifstream engineFile(engine_file_path, std::ios::binary);long int fsize = 0;engineFile.seekg(0, engineFile.end);fsize = engineFile.tellg();engineFile.seekg(0, engineFile.beg);std::vector<char> engineString(fsize);engineFile.read(engineString.data(), fsize);if (engineString.size() == 0) { std::cout << "Failed getting serialized engine!" << std::endl; return false; }// 创建推理引擎m_runtime.reset(nvinfer1::createInferRuntime(*logger));if(!m_runtime){std::cout<<" (T_T)~~~, Failed to create runtime."<<std::endl;return false;}// 反序列化推理引擎m_engine.reset(m_runtime->deserializeCudaEngine(engineString.data(), fsize));if(!m_engine){std::cout<<" (T_T)~~~, Failed to deserialize."<<std::endl;return false;}// 获取优化后的模型的输入维度和输出维度// int nbBindings = m_engine->getNbBindings(); // trt8.5 以前版本int nbBindings = m_engine->getNbIOTensors();  // trt8.5 以后版本// 推理执行上下文m_context.reset(m_engine->createExecutionContext());if(!m_context){std::cout<<" (T_T)~~~, Failed to create ExecutionContext."<<std::endl;return false;}auto input_dims = m_context->getTensorShape("input");input_dims.d[0] = maxBatch;m_context->setInputShape("input", input_dims);std::cout << " ~~Congratulations! 🎉🎉🎉~  create execution context success!!! ✨✨✨~~ " << std::endl;return true; 
}

申请内存,并且绑定模型输入输出:

bool ZeroDCEModel::trtIOMemory() {m_inputDims = m_context->getTensorShape("input"); // 模型输入m_outputDims[0] = m_context->getTensorShape("output1"); //第一个输出m_outputDims[1] = m_context->getTensorShape("output2"); //第二个输出m_outputDims[2] = m_context->getTensorShape("output3"); //第三个输出this->kInputH = m_inputDims.d[2];this->kInputW = m_inputDims.d[3];m_inputSize = m_inputDims.d[0] * m_inputDims.d[1] * m_inputDims.d[2] * m_inputDims.d[3] * sizeof(float);m_outputSize[0] = m_outputDims[0].d[0] * m_outputDims[0].d[1] * m_outputDims[0].d[2] * m_outputDims[0].d[3] * sizeof(float);m_outputSize[1] = m_outputDims[1].d[0] * m_outputDims[1].d[1] * m_outputDims[1].d[2] * m_outputDims[1].d[3] * sizeof(float);m_outputSize[2] = m_outputDims[2].d[0] * m_outputDims[2].d[1] * m_outputDims[2].d[2] * m_outputDims[2].d[3] * sizeof(float);// 声明cuda的内存大小checkRuntime(cudaMalloc(&buffers[0], m_inputSize));checkRuntime(cudaMalloc(&buffers[1], m_outputSize[0]));checkRuntime(cudaMalloc(&buffers[2], m_outputSize[1]));checkRuntime(cudaMalloc(&buffers[3], m_outputSize[2]));// 声明cpu内存大小checkRuntime(cudaMallocHost(&cpu_buffers[0], m_inputSize));checkRuntime(cudaMallocHost(&cpu_buffers[1], m_outputSize[0]));checkRuntime(cudaMallocHost(&cpu_buffers[2], m_outputSize[1]));checkRuntime(cudaMallocHost(&cpu_buffers[3], m_outputSize[2]));m_context->setTensorAddress("input", buffers[0]);m_context->setTensorAddress("output1", buffers[1]);m_context->setTensorAddress("output2", buffers[2]);m_context->setTensorAddress("output3", buffers[3]);checkRuntime(cudaStreamCreate(&m_stream));return true; 
}

推理模型:

cv::Mat ZeroDCEModel::doInference(cv::Mat& frame) {if(useGPU){zero_dce_preprocess::preprocess_gpu(frame, (float*)buffers[0], kInputH, kInputW,  m_stream);}else{zero_dce_preprocess::preprocess_cpu(frame, cpu_buffers[0], kInputW, kInputH);// Preprocess -- 将host的数据移动到device上checkRuntime(cudaMemcpyAsync(buffers[0], cpu_buffers[0], m_inputSize, cudaMemcpyHostToDevice, m_stream));}bool status = this->m_context->enqueueV3(m_stream);if (!status) std::cerr << "(T_T)~~~, Failed to create ExecutionContext." << std::endl;// 将gpu推理的结果返回到cpu上面处理checkRuntime(cudaMemcpyAsync(cpu_buffers[1], buffers[1], m_outputSize[0], cudaMemcpyDeviceToHost, m_stream));checkRuntime(cudaMemcpyAsync(cpu_buffers[2], buffers[2], m_outputSize[1], cudaMemcpyDeviceToHost, m_stream));checkRuntime(cudaMemcpyAsync(cpu_buffers[3], buffers[3], m_outputSize[2], cudaMemcpyDeviceToHost, m_stream));checkRuntime(cudaStreamSynchronize(m_stream));int height = frame.rows;int width = frame.cols;cv::Mat enhance_image;if(useGPU){enhance_image = zero_dce_postprocess::decode_gpu(buffers[2],kInputW,kInputH,width,height);}else{// cv::Mat enhance_image_1 = zero_dce_postprocess::decode_cpu(cpu_buffers[1],kInputW,kInputH,height,width);enhance_image = zero_dce_postprocess::decode_cpu(cpu_buffers[2],kInputW,kInputH,width,height);// cv::Mat r = zero_dce_postprocess::decode_cpu(cpu_buffers[3],kInputW,kInputH,height,width);}return enhance_image;}

部署代码的增强图展示:

至此,完成模型zero_dce_model模型的部署代码。

6. 低光照补偿代码的使用

这是一个低光照补偿的模型部署,一般情况下需要配合其他模型使用。比如在检测模型中,发现实际检测场景比较暗,这个时候可以先配合检查图像的暗亮程度,如果过暗的话,可以使用这个模型先增加图像的亮度,然后再次输入到检测模型中开始检测。

检测图像的暗亮程度,对这个图像灰度化,然后求取图像的平均亮度作为判断条件,然后再使用其他模型。

int main(){cv::VideoCapture cap("media/6.mp4");// 检查视频是否成功打开if (!cap.isOpened()) {std::cerr << "无法打开视频文件或摄像头!" << std::endl;return -1;}// 创建一个窗口用于显示视频cv::namedWindow("Video", cv::WINDOW_NORMAL);cv::Mat frame;while (true) {// 读取一帧if (!cap.read(frame)) {std::cerr << "无法读取视频帧!" << std::endl;break;}//----------------判断的亮度------------------------cv::Mat gray;// 将彩色图转换为灰度图cv::cvtColor(frame, gray, cv::COLOR_BGR2GRAY);cv::Scalar mean_value = cv::mean(gray);std::cout << "[OpenCV] Average: " << mean_value[0] << std::endl;if (mean_value[0]< 30){frame = zero_model.doInference(frame);}//----------------判断的亮度------------------------auto detections = model.doInference(frame);model.draw(frame,detections);// 显示这一帧cv::imshow("Video", frame);// 按下 'q' 键退出循环if (cv::waitKey(30) == 'q') {break;}}// 释放资源并关闭窗口cap.release();cv::destroyAllWindows();return 0;
}

相关文章:

  • MySQL相关
  • [强化学习的数学原理—赵世钰老师]学习笔记02-贝尔曼方程
  • 支持selenium的chrome driver更新到136.0.7103.94
  • 【2025年软考中级】第一章1.6 安全性、可靠性、性能评价
  • Python爬虫实战:获取1688商品信息
  • 无需配置光猫,使用网管交换机配合路由器的IPTV功能实现单线复用
  • Uniapp开发鸿蒙应用时如何运行和调试项目
  • Kotlin与机器学习实战:Android端集成TensorFlow Lite全指南
  • 从神经架构到万物自动化的 AI 革命:解码深度学习驱动的智能自动化新范式
  • 人工智能100问☞第25问:什么是循环神经网络(RNN)?
  • 基于OpenCV的SIFT特征和FLANN匹配器的指纹认证
  • 互联网大厂Java面试:从Spring到微服务的全面探讨
  • Spring Initializr快速创建项目案例
  • QT使用QXlsx读取excel表格中的图片
  • OGGMA 21c 微服务 (MySQL) 安装避坑指南
  • 25、DeepSeek-R1论文笔记
  • 设计模式7大原则与UML类图详解
  • C++学习:六个月从基础到就业——C++11/14:列表初始化
  • 数学复习笔记 19
  • JDK 21新特性全面解析
  • 穆迪下调美国主权信用评级
  • 试点首发进口消费品检验便利化措施,上海海关与上海商务委发文
  • 习近平就乌拉圭前总统穆希卡逝世向乌拉圭总统奥尔西致唁电
  • 国寿资产获批参与第三批保险资金长期投资改革试点
  • 《日出》华丽的悲凉,何赛飞和赵文瑄演绎出来了
  • 阿联酋与美国达成超过2000亿美元协议