当前位置: 首页 > news >正文

使用pytorch和opencv根据颜色相似性提取图像

需求:将下图中的花朵提取出来。

代码:

import cv2
import torch
import numpy as np
import time

def get_similar_colors(image, color_list, threshold):
    # 将图像和颜色列表转换为torch张量
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image_tensor = torch.from_numpy(image.astype(np.float32)).to(device)
    color_tensor = torch.tensor(color_list, dtype=torch.float32).to(device)

    # 计算每个像素与颜色列表中每个颜色的距离
    distances = torch.cdist(image_tensor.view(-1, 3), color_tensor, p=2).view(image_tensor.shape[0], image_tensor.shape[1], -1)

    # 找到最小距离及其索引
    min_distances, _ = torch.min(distances, dim=-1)

    # 创建掩码,标记接近目标颜色的像素
    mask = min_distances < threshold

    # 根据掩码提取接近颜色的部分
    result = torch.where(mask.unsqueeze(-1), image_tensor, torch.zeros_like(image_tensor))

    # 将结果转换回numpy数组
    result_np = result.cpu().numpy().astype(np.uint8)

    return result_np
# 读取图像s
image = cv2.imread('flower2.jpg')
# 定义颜色列表,每个颜色用BGR格式表示
color_list = [(15, 220, 255),(30, 50, 220)]
# 定义颜色接近度的阈值
threshold = 100
time_start = time.time()
# 提取接近颜色的部分
extracted_image = get_similar_colors(image, color_list, threshold)
time_end = time.time()
time = time_end - time_start
print("time: ", time)

# 显示原始图像和提取结果
cv2.imshow('Original Image', image)
cv2.imshow('Extracted Image', extracted_image)
cv2.waitKey(0)
cv2.destroyAllWindows()

 

进一步,输出掩码部分的黑白图像

import cv2
import torch
import numpy as np
import time

def get_similar_colors(image, color_list, threshold):
    # 将图像和颜色列表转换为torch张量
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image_tensor = torch.from_numpy(image.astype(np.float32)).to(device)
    color_tensor = torch.tensor(color_list, dtype=torch.float32).to(device)

    # 计算每个像素与颜色列表中每个颜色的距离
    distances = torch.cdist(image_tensor.view(-1, 3), color_tensor, p=2).view(image_tensor.shape[0], image_tensor.shape[1], -1)

    # 找到最小距离及其索引
    min_distances, _ = torch.min(distances, dim=-1)

    # 创建掩码,标记接近目标颜色的像素
    mask = min_distances < threshold

    # 将符合条件的像素设置为黑色
    result = np.ones_like(image_tensor)
    result[mask] = [0, 0, 0]  # 设置为黑色

    return result
# 读取图像s
image = cv2.imread('your/image/path')
# 定义颜色列表,每个颜色用BGR格式表示
color_list = [(50, 15, 0), (45, 10, 0), (30, 10, 0)]
# 定义颜色接近度的阈值
threshold = 100
time_start = time.time()
# 提取接近颜色的部分
extracted_image = get_similar_colors(image, color_list, threshold)
time_end = time.time()
time = time_end - time_start
print("time: ", time)

# 显示原始图像和提取结果
cv2.imshow('Original Image', image)
cv2.imshow('Extracted Image', extracted_image)
cv2.waitKey(0)
cv2.destroyAllWindows()

相关文章:

  • 2025-03-04 学习记录--C/C++-PTA 习题5-5 使用函数统计指定数字的个数
  • Golang语法特性总结
  • 物联网 全部技术栈和实现方案
  • FastGPT 源码:controller.ts 主要定义
  • android13打基础: 控件checkbox
  • 期权帮|股指期货入门知识:什么是股指期货基差?什么是股指期货价差?
  • Flink学习方法
  • 除了合并接口,还有哪些优化 Flask API 的方法?
  • android接入rocketmq
  • CentOS 7 安装Nginx-1.26.3
  • OCCT 学习笔记:创建瓶子教程的三个关键知识点
  • 【金融量化】Ptrade中交易环境支持的业务类型
  • Compose Multiplatform开发记录之文件选择器封装
  • Rust 面向对象特性解析:对象、封装与继承
  • 手机号码归属地的实现
  • jwt 存在的无状态的安全问题与解决方案
  • 解锁高效编程:深度剖析C++11核心语法与标准库实战精要
  • python的运行--命令行
  • 安卓开发相机功能
  • Linux 下查看 CPU 使用率
  • 做网站工作图/seo外包
  • 济南高端网站设计/怎样创建自己的电商平台
  • 佛山网站优化效果/推广优化厂商联系方式
  • 1688官网下载/优化大师怎么卸载
  • 微信公众号 手机网站开发/软文推广的标准类型
  • 泰安集团网站建设报价/百度怎么创建自己的网站