C#使用OpenVinoSharp+魔搭社区的读光中英文OCR ONNX模型进行文字检测(仅检测,不做识别)
效果如下:
模型链接:读光中英文OCR ONNX · 模型库
模型信息:
全部代码如下:
using System;
using System.Collections.Generic;
using System.Drawing;
using System.Linq;
using System.Runtime.InteropServices;
using OpenCvSharp;
using OpenVinoSharp;
using Point = OpenCvSharp.Point;
using Size = OpenCvSharp.Size;
namespace DBNetTextDetection
{
public class DBNetTextDetector : IDisposable
{
// OpenVINO runtime objects. They are created in the constructor and released in Dispose().
private Core _core;
private Model _model;
private CompiledModel _compiledModel;
private InferRequest _inferRequest;
// Model input/output metadata (names and shapes), read from the compiled model in the constructor.
private string _inputName;
private string _outputName;
private int[] _inputShape;
private int[] _outputShape;
// Image-processing parameters: every image is resized to ModelWidth x ModelHeight before inference,
// and the model output is interpreted as a ModelHeight x ModelWidth map during post-processing.
private const int ModelWidth = 800;
private const int ModelHeight = 800;
/// <summary>
/// Loads the model file, compiles it for the requested device, creates an inference request
/// and caches the names and shapes of the model's input and output.
/// </summary>
/// <param name="modelPath">Path of the model file passed to <c>Core.read_model</c>.</param>
/// <param name="device">
/// OpenVINO device name passed to <c>Core.compile_model</c>. Defaults to "GPU.0", the value that
/// used to be hard-coded; pass another device name (for example "CPU") on machines without that GPU.
/// </param>
/// <exception cref="ArgumentException">
/// <paramref name="modelPath"/> or <paramref name="device"/> is null, empty or whitespace.
/// </exception>
public DBNetTextDetector(string modelPath, string device = "GPU.0")
{
    if (string.IsNullOrWhiteSpace(modelPath))
    {
        throw new ArgumentException("Model path must not be null or empty.", nameof(modelPath));
    }
    if (string.IsNullOrWhiteSpace(device))
    {
        throw new ArgumentException("Device name must not be null or empty.", nameof(device));
    }

    try
    {
        // Initialise the OpenVINO runtime.
        _core = new Core();
        // Load the model.
        _model = _core.read_model(modelPath);
        // Compile the model for the requested device.
        _compiledModel = _core.compile_model(_model, device);
        // Create the inference request reused by every DetectText call.
        _inferRequest = _compiledModel.create_infer_request();

        // Read input/output metadata.
        var input = _compiledModel.get_input();
        var output = _compiledModel.get_output();

        _inputName = input.get_name();
        // Copy every dimension instead of assuming a fixed rank, so models whose
        // input is not 4-D (or whose output is not 3-D/4-D) no longer fail here.
        var inputDims = input.get_shape();
        _inputShape = new int[inputDims.Count];
        for (int i = 0; i < _inputShape.Length; i++)
        {
            _inputShape[i] = (int)inputDims[i];
        }

        _outputName = output.get_name();
        var outputDims = output.get_shape();
        _outputShape = new int[outputDims.Count];
        for (int i = 0; i < _outputShape.Length; i++)
        {
            _outputShape[i] = (int)outputDims[i];
        }
    }
    catch
    {
        // Construction failed part-way: the caller never receives an instance it could
        // dispose, so release whatever native objects were already created.
        _inferRequest?.Dispose();
        _compiledModel?.Dispose();
        _model?.Dispose();
        _core?.Dispose();
        throw;
    }

    Console.WriteLine($"模型加载成功: {modelPath}");
    Console.WriteLine($"输入名称: {_inputName}, 形状: [{string.Join(", ", _inputShape)}]");
    Console.WriteLine($"输出名称: {_outputName}, 形状: [{string.Join(", ", _outputShape)}]");
}
/// <summary>
/// Runs text detection on <paramref name="image"/>: preprocesses it, copies it into the model's
/// input tensor, runs inference and converts the output tensor into rotated boxes.
/// </summary>
/// <param name="image">Source image; its width and height are used to scale the boxes back.</param>
/// <param name="threshold">Binarisation threshold forwarded to PostProcess.</param>
/// <param name="boxThreshold">Box threshold forwarded to PostProcess.</param>
/// <returns>The boxes produced by PostProcess.</returns>
/// <exception cref="ArgumentNullException"><paramref name="image"/> is null.</exception>
public List<RotatedRect> DetectText(Mat image, float threshold = 0.3f, float boxThreshold = 0.5f)
{
    if (image == null)
    {
        throw new ArgumentNullException(nameof(image));
    }

    // The preprocessed image is only needed until its pixels are copied into the input
    // tensor; dispose it right after so each call does not leak a native buffer.
    using (Mat preprocessedImage = PreprocessImage(image))
    {
        var inputTensor = _inferRequest.get_input_tensor();
        FillTensorWithImage(inputTensor, preprocessedImage);
    }

    // Run inference.
    _inferRequest.infer();

    // Fetch the output and turn it into text boxes.
    var outputTensor = _inferRequest.get_output_tensor();
    return PostProcess(outputTensor, image.Width, image.Height, threshold, boxThreshold);
}
/// <summary>
/// Resizes <paramref name="image"/> to ModelWidth x ModelHeight, makes sure it has three
/// channels, and converts it to 32-bit float scaled by 1/255.
/// </summary>
/// <param name="image">Source image with 1, 3 or 4 channels.</param>
/// <returns>A new 3-channel float Mat; the caller owns it and must dispose it.</returns>
/// <exception cref="ArgumentException">The image has an unsupported channel count.</exception>
private Mat PreprocessImage(Mat image)
{
    Mat threeChannel = null;
    Mat floatImage = null;
    using (Mat resized = new Mat())
    {
        try
        {
            // Resize to the fixed model input size.
            Cv2.Resize(image, resized, new Size(ModelWidth, ModelHeight));

            // ConvertTo only changes the depth, never the channel count, so grayscale or
            // BGRA inputs must be expanded/reduced to 3 channels explicitly; otherwise the
            // tensor fill would copy the wrong number of planes.
            Mat source = resized;
            int channelCount = resized.Channels();
            if (channelCount == 1)
            {
                threeChannel = new Mat();
                Cv2.CvtColor(resized, threeChannel, ColorConversionCodes.GRAY2BGR);
                source = threeChannel;
            }
            else if (channelCount == 4)
            {
                threeChannel = new Mat();
                Cv2.CvtColor(resized, threeChannel, ColorConversionCodes.BGRA2BGR);
                source = threeChannel;
            }
            else if (channelCount != 3)
            {
                throw new ArgumentException(
                    "Unsupported channel count: " + channelCount + " (expected 1, 3 or 4).", nameof(image));
            }

            // Convert to float32 and scale by 1/255.
            // NOTE(review): channel order stays BGR (OpenCV default). If the model expects RGB,
            // a BGR2RGB conversion is needed here; the original code left it disabled.
            floatImage = new Mat();
            source.ConvertTo(floatImage, MatType.CV_32FC3, 1.0 / 255.0);
            return floatImage;
        }
        catch
        {
            // Do not leak the result Mat when a step above fails.
            floatImage?.Dispose();
            throw;
        }
        finally
        {
            threeChannel?.Dispose();
        }
    }
}
/// <summary>
/// Copies a 32-bit float image into <paramref name="tensor"/>, converting the interleaved
/// (HWC) pixel layout into planar (CHW) layout: channel 0 first, then channel 1, and so on.
/// </summary>
/// <param name="tensor">Destination tensor; its buffer is written as 32-bit floats
/// (assumes the tensor element type is float32 — TODO confirm against the model).</param>
/// <param name="image">Source image with 32-bit float depth.</param>
/// <exception cref="ArgumentException">The image depth is not 32-bit float.</exception>
/// <exception cref="InvalidOperationException">The tensor has fewer elements than the image.</exception>
private unsafe void FillTensorWithImage(Tensor tensor, Mat image)
{
    int channels = image.Channels();
    int height = image.Height;
    int width = image.Width;

    // The copy below reinterprets pixel memory as float, so any other depth would be garbage.
    if (image.Depth() != MatType.CV_32F)
    {
        throw new ArgumentException("Image must have 32-bit float depth.", nameof(image));
    }

    // Refuse to write past the end of the tensor buffer when the model input is
    // smaller than the preprocessed image.
    var tensorShape = tensor.get_shape();
    long tensorElements = 1;
    for (int i = 0; i < tensorShape.Count; i++)
    {
        tensorElements *= (long)tensorShape[i];
    }
    long planeElements = (long)height * width;
    long requiredElements = planeElements * channels;
    if (tensorElements < requiredElements)
    {
        throw new InvalidOperationException(
            "Input tensor holds " + tensorElements + " elements but the image needs " + requiredElements + ".");
    }

    float* tensorPtr = (float*)tensor.data().ToPointer();
    long rowBytes = (long)width * sizeof(float);

    // Split the interleaved image into one single-channel Mat per channel.
    Mat[] planes;
    Cv2.Split(image, out planes);
    try
    {
        for (int c = 0; c < planes.Length; c++)
        {
            float* planeStart = tensorPtr + c * planeElements;
            for (int h = 0; h < height; h++)
            {
                // Copy row by row through the row pointer so the copy stays correct
                // even if a plane's rows are not stored contiguously.
                Buffer.MemoryCopy(
                    planes[c].Ptr(h).ToPointer(),
                    planeStart + (long)h * width,
                    rowBytes,
                    rowBytes);
            }
        }
    }
    finally
    {
        // Release every plane, including those not reached when a copy fails.
        foreach (Mat plane in planes)
        {
            plane?.Dispose();
        }
    }
}
/// <summary>
/// When true, PostProcess shows the intermediate binary map in an OpenCV window named
/// "binaryMap". Off by default so the detector can run without a GUI.
/// </summary>
public bool ShowDebugWindows { get; set; }

// Enlargement factors applied to the scaled box width and height. The values come from the
// original implementation; presumably they compensate for the model predicting shrunken
// text regions — TODO confirm and replace with a proper unclip step.
private const double BoxWidthExpansion = 2.3;
private const double BoxHeightExpansion = 2.0;

/// <summary>
/// Converts the model output into rotated text boxes in original-image coordinates:
/// binarises the map, finds contours, drops small or low-scoring ones, and scales the
/// minimum-area rectangle of each remaining contour back to the original image size.
/// </summary>
/// <param name="outputTensor">Model output; its first ModelHeight*ModelWidth values are read as the map.</param>
/// <param name="originalWidth">Width of the image passed to DetectText.</param>
/// <param name="originalHeight">Height of the image passed to DetectText.</param>
/// <param name="threshold">Values above this become foreground in the binary map.</param>
/// <param name="boxThreshold">Minimum mean map value inside a contour for its box to be kept.</param>
/// <returns>One RotatedRect per accepted contour.</returns>
/// <exception cref="InvalidOperationException">The output has fewer values than one full map.</exception>
private List<RotatedRect> PostProcess(Tensor outputTensor, int originalWidth, int originalHeight,
float threshold, float boxThreshold)
{
    List<RotatedRect> textBoxes = new List<RotatedRect>();

    // Fetch the raw output values.
    float[] outputData = GetTensorData(outputTensor);

    // The Mat below reads ModelHeight*ModelWidth floats straight from the array;
    // a shorter array would be read out of bounds.
    int mapElements = ModelHeight * ModelWidth;
    if (outputData.Length < mapElements)
    {
        throw new InvalidOperationException(
            "Model output holds " + outputData.Length + " values but a " + ModelWidth + "x" + ModelHeight + " map needs " + mapElements + ".");
    }

    // Scale factors from model space back to the original image (loop-invariant).
    float scaleX = (float)originalWidth / ModelWidth;
    float scaleY = (float)originalHeight / ModelHeight;

    // View the flat array as a 2-D single-channel float map.
    using (Mat probMap = new Mat(ModelHeight, ModelWidth, MatType.CV_32FC1, outputData))
    using (Mat binaryMap = new Mat())
    {
        // Binarise, then convert to 8-bit because FindContours needs an 8-bit image.
        Cv2.Threshold(probMap, binaryMap, threshold, 1.0, ThresholdTypes.Binary);
        binaryMap.ConvertTo(binaryMap, MatType.CV_8UC1, 255.0);

        if (ShowDebugWindows)
        {
            Cv2.ImShow("binaryMap", binaryMap);
        }

        // Find contours.
        Point[][] contours;
        HierarchyIndex[] hierarchy;
        Cv2.FindContours(binaryMap, out contours, out hierarchy,
            RetrievalModes.List, ContourApproximationModes.ApproxSimple);

        foreach (var contour in contours)
        {
            if (contour.Length < 5)
            {
                continue;
            }

            // Minimum-area bounding rectangle in model space.
            RotatedRect rotatedRect = Cv2.MinAreaRect(contour);

            // Drop boxes that are too small.
            if (rotatedRect.Size.Width < 2 || rotatedRect.Size.Height < 2)
            {
                continue;
            }

            // Drop boxes whose mean map value is below boxThreshold. The parameter
            // used to be accepted but silently ignored.
            if (ContourScore(probMap, contour) < boxThreshold)
            {
                continue;
            }

            // Scale back to the original image size.
            // NOTE(review): Width/Height are the rectangle's own sides while scaleX/scaleY are
            // image-axis ratios, so for rotated boxes with scaleX != scaleY this is approximate.
            RotatedRect scaledRect = new RotatedRect(
                new Point2f(rotatedRect.Center.X * scaleX, rotatedRect.Center.Y * scaleY),
                new Size2f(rotatedRect.Size.Width * scaleX * BoxWidthExpansion, rotatedRect.Size.Height * scaleY * BoxHeightExpansion),
                rotatedRect.Angle);
            textBoxes.Add(scaledRect);
        }
    }

    return textBoxes;
}

// Returns the mean value of probMap over the pixels inside the given contour.
private static double ContourScore(Mat probMap, Point[] contour)
{
    Rect bounds = Cv2.BoundingRect(contour);

    // Shift the contour into the coordinate frame of its bounding rectangle so that the
    // mask only needs to be as large as the region of interest.
    Point[] localContour = new Point[contour.Length];
    for (int i = 0; i < contour.Length; i++)
    {
        localContour[i] = new Point(contour[i].X - bounds.X, contour[i].Y - bounds.Y);
    }

    using (Mat roi = new Mat(probMap, bounds))
    using (Mat mask = new Mat(bounds.Height, bounds.Width, MatType.CV_8UC1, Scalar.All(0)))
    {
        Cv2.FillPoly(mask, new[] { localContour }, Scalar.All(255));
        return Cv2.Mean(roi, mask).Val0;
    }
}
/// <summary>
/// Copies the contents of <paramref name="tensor"/> into a managed float array whose length
/// is the product of the tensor's dimensions.
/// </summary>
/// <param name="tensor">Tensor to read; its buffer is interpreted as 32-bit floats
/// (assumes the element type is float32 — TODO confirm against the model).</param>
/// <returns>A new array holding every element of the tensor.</returns>
/// <exception cref="InvalidOperationException">The element count is negative or exceeds int.MaxValue.</exception>
private float[] GetTensorData(Tensor tensor)
{
    // Query the shape once instead of on every loop iteration.
    var shape = tensor.get_shape();

    // Accumulate in a checked long so a very large shape cannot silently wrap around.
    long elementCount = 1;
    for (int i = 0; i < shape.Count; i++)
    {
        elementCount = checked(elementCount * (long)shape[i]);
    }

    if (elementCount < 0 || elementCount > int.MaxValue)
    {
        throw new InvalidOperationException("Unsupported tensor element count: " + elementCount + ".");
    }
    if (elementCount == 0)
    {
        // Nothing to copy; avoid touching the data pointer of an empty tensor.
        return new float[0];
    }

    float[] data = new float[elementCount];
    Marshal.Copy(tensor.data(), data, 0, (int)elementCount);
    return data;
}
/// <summary>
/// Returns a copy of <paramref name="image"/> with each box outlined in green
/// (2 px lines) and its centre marked with a filled red dot of radius 3.
/// The input image itself is not modified.
/// </summary>
/// <param name="image">Image to draw on (cloned before drawing).</param>
/// <param name="textBoxes">Rotated boxes to draw.</param>
/// <returns>A new Mat owned by the caller.</returns>
public Mat DrawDetectionResult(Mat image, List<RotatedRect> textBoxes)
{
    Mat canvas = image.Clone();
    Scalar outlineColor = new Scalar(0, 255, 0);
    Scalar centerColor = new Scalar(0, 0, 255);

    foreach (RotatedRect rect in textBoxes)
    {
        Point2f[] corners = rect.Points();

        // Outline: connect each corner to the next one, wrapping from the last back to the first.
        for (int start = 0; start < 4; start++)
        {
            int end = (start == 3) ? 0 : start + 1;
            Point from = (Point)corners[start];
            Point to = (Point)corners[end];
            Cv2.Line(canvas, from, to, outlineColor, 2);
        }

        // Centre marker (thickness -1 draws a filled circle).
        Point center = (Point)rect.Center;
        Cv2.Circle(canvas, center, 3, centerColor, -1);
    }

    return canvas;
}
/// <summary>
/// Releases the OpenVINO objects in reverse order of creation (inference request,
/// compiled model, model, core). Safe to call more than once: each field is cleared
/// after it is disposed, so later calls do nothing.
/// </summary>
public void Dispose()
{
    _inferRequest?.Dispose();
    _inferRequest = null;

    _compiledModel?.Dispose();
    _compiledModel = null;

    _model?.Dispose();
    _model = null;

    _core?.Dispose();
    _core = null;
}
}
}