使用pytorch和opencv根据颜色相似性提取图像
需求:将下图中的花朵提取出来。
代码:
import cv2
import torch
import numpy as np
import time
def get_similar_colors(image, color_list, threshold):
# 将图像和颜色列表转换为torch张量
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
image_tensor = torch.from_numpy(image.astype(np.float32)).to(device)
color_tensor = torch.tensor(color_list, dtype=torch.float32).to(device)
# 计算每个像素与颜色列表中每个颜色的距离
distances = torch.cdist(image_tensor.view(-1, 3), color_tensor, p=2).view(image_tensor.shape[0], image_tensor.shape[1], -1)
# 找到最小距离及其索引
min_distances, _ = torch.min(distances, dim=-1)
# 创建掩码,标记接近目标颜色的像素
mask = min_distances < threshold
# 根据掩码提取接近颜色的部分
result = torch.where(mask.unsqueeze(-1), image_tensor, torch.zeros_like(image_tensor))
# 将结果转换回numpy数组
result_np = result.cpu().numpy().astype(np.uint8)
return result_np
# 读取图像s
image = cv2.imread('flower2.jpg')
# 定义颜色列表,每个颜色用BGR格式表示
color_list = [(15, 220, 255),(30, 50, 220)]
# 定义颜色接近度的阈值
threshold = 100
time_start = time.time()
# 提取接近颜色的部分
extracted_image = get_similar_colors(image, color_list, threshold)
time_end = time.time()
time = time_end - time_start
print("time: ", time)
# 显示原始图像和提取结果
cv2.imshow('Original Image', image)
cv2.imshow('Extracted Image', extracted_image)
cv2.waitKey(0)
cv2.destroyAllWindows()
进一步,输出掩码部分的黑白图像
import cv2
import torch
import numpy as np
import time
def get_similar_colors(image, color_list, threshold):
# 将图像和颜色列表转换为torch张量
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
image_tensor = torch.from_numpy(image.astype(np.float32)).to(device)
color_tensor = torch.tensor(color_list, dtype=torch.float32).to(device)
# 计算每个像素与颜色列表中每个颜色的距离
distances = torch.cdist(image_tensor.view(-1, 3), color_tensor, p=2).view(image_tensor.shape[0], image_tensor.shape[1], -1)
# 找到最小距离及其索引
min_distances, _ = torch.min(distances, dim=-1)
# 创建掩码,标记接近目标颜色的像素
mask = min_distances < threshold
# 将符合条件的像素设置为黑色
result = np.ones_like(image_tensor)
result[mask] = [0, 0, 0] # 设置为黑色
return result
# 读取图像s
image = cv2.imread('your/image/path')
# 定义颜色列表,每个颜色用BGR格式表示
color_list = [(50, 15, 0), (45, 10, 0), (30, 10, 0)]
# 定义颜色接近度的阈值
threshold = 100
time_start = time.time()
# 提取接近颜色的部分
extracted_image = get_similar_colors(image, color_list, threshold)
time_end = time.time()
time = time_end - time_start
print("time: ", time)
# 显示原始图像和提取结果
cv2.imshow('Original Image', image)
cv2.imshow('Extracted Image', extracted_image)
cv2.waitKey(0)
cv2.destroyAllWindows()