
A no-frills CNN: recognizing machining line segments and deleting annotation line segments


Model address: annotation / machining line-segment recognition CNN model · Model Library (模型库)

Example prediction request (returns an HTML report):

http://localhost:5000/predict?image_path=E:\code\cad_ppo\output\pictures\5_LINE_3914.png&format=html

No — the F1 Score is not used for backpropagation. To clarify the difference between the loss function and the evaluation metric:

What actually drives backpropagation is the loss function. In this code, FocalLoss is the loss that is backpropagated: inside the training loop, loss.backward() computes the gradients and the optimizer updates the model parameters.

The F1 Score is an evaluation metric. It is only used to measure model performance and takes no part in the gradient computation. In the code, the F1 Score decides when to save the best model (based on the validation-set F1 Score).
```python
# This is the actual training step, driven by criterion (FocalLoss)
outputs = model(inputs)
loss = criterion(outputs, labels)  # compute the loss with FocalLoss
loss.backward()                    # backpropagate through the loss
optimizer.step()                   # update the model parameters

# This is the evaluation step: the F1 Score is only used to pick the best model
f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
if f1_score > best_f1_score:
    best_f1_score = f1_score
    torch.save(model.state_dict(), model_save_path)  # save the best model
```
Why can't the F1 Score be used directly for backpropagation?
The F1 Score is a discrete, non-differentiable metric, so in general no gradient can be computed from it.
A loss function such as FocalLoss is continuous and differentiable, which makes it suitable for gradient-descent optimization.
By optimizing the loss function we indirectly improve evaluation metrics such as the F1 Score.
In short: the model is trained (backpropagation and parameter updates) with FocalLoss, while the F1 Score is only used to measure performance and select the best model.
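To make the non-differentiability point concrete, here is a minimal sketch (hypothetical tensors, not taken from the training script): a focal-style loss keeps a gradient path back to the logits, while an F1 Score computed from argmax-ed predictions does not.

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 2, requires_grad=True)   # hypothetical model outputs
labels = torch.tensor([1, 0, 1, 1])

# Differentiable path: a focal-style loss built on cross entropy
ce = F.cross_entropy(logits, labels, reduction='none')
pt = torch.exp(-ce)
focal = ((1 - pt) ** 2 * ce).mean()
focal.backward()
print(logits.grad is not None)   # True: gradients flow back to the logits

# Non-differentiable path: hard predictions cut the computation graph
preds = logits.argmax(dim=1)     # argmax produces integers with no gradient
tp = ((preds == 1) & (labels == 1)).sum().float()
fp = ((preds == 1) & (labels == 0)).sum().float()
fn = ((preds == 0) & (labels == 1)).sum().float()
f1 = 2 * tp / (2 * tp + fp + fn + 1e-8)
print(f1.requires_grad)          # False: nothing to backpropagate through
```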
The F1 Score is an evaluation metric used widely in machine learning to measure the performance of classification models, and it is especially useful on class-imbalanced datasets. It is the harmonic mean of precision and recall.

1. Basic concepts
Suppose we have a binary classification problem (for example, deciding whether an email is spam):
Precision: of the samples predicted to be the positive class, how many really are positive.
Formula: Precision = TP / (TP + FP)
Intuition: when the model says "yes", how often it is right.
Recall (also called sensitivity): of all the truly positive samples, how many the model actually found.
Formula: Recall = TP / (TP + FN)
Intuition: of everything that really is a "yes", how much the model recovered.
Where:
TP (True Positive): a positive sample correctly predicted as positive.
FP (False Positive): a negative sample wrongly predicted as positive.
FN (False Negative): a positive sample wrongly predicted as negative.
2. F1 Score formula
$$
F1 = 2 \times \frac{Precision \times Recall}{Precision + Recall}
$$
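A quick worked example with hypothetical counts (TP = 45, FP = 5, FN = 15), using the same zero-division guards as the training code:

```python
tp, fp, fn = 45, 5, 15   # hypothetical confusion-matrix counts

precision = tp / (tp + fp) if (tp + fp) > 0 else 0   # 45 / 50 = 0.90
recall = tp / (tp + fn) if (tp + fn) > 0 else 0      # 45 / 60 = 0.75
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
print(f"precision={precision:.2f}, recall={recall:.2f}, F1={f1:.3f}")   # F1 ≈ 0.818
```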
3. Why use the F1 Score?
Precision and recall usually trade off against each other: pushing precision up tends to pull recall down, and vice versa (see the toy illustration below).
The F1 Score combines both into a single, balanced measure of performance.
It is the right metric when you care both about "are the positive predictions correct?" (precision) and "are any positives being missed?" (recall).
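A toy illustration of that trade-off, assuming a hypothetical test set of 100 samples with only 10 positives:

```python
from sklearn.metrics import precision_score, recall_score, f1_score

y_true = [1] * 10 + [0] * 90          # hypothetical labels: 10 positives, 90 negatives
candidates = {
    "flag everything as positive": [1] * 100,       # perfect recall, poor precision
    "flag a single true positive": [1] + [0] * 99,  # perfect precision, poor recall
}
for name, y_pred in candidates.items():
    print(name,
          f"precision={precision_score(y_true, y_pred):.2f}",
          f"recall={recall_score(y_true, y_pred):.2f}",
          f"F1={f1_score(y_true, y_pred):.2f}")
# flag everything as positive precision=0.10 recall=1.00 F1=0.18
# flag a single true positive precision=1.00 recall=0.10 F1=0.18
```

Both extreme strategies end up with the same low F1, which is exactly why selecting models by F1 forces a balance between the two.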
4. What the F1 Score means in your training log
From your log:
Epoch 1: the F1 Score is 0.0000, which means the model identified essentially no positive samples on the validation set (it probably predicted everything as negative, driving recall or precision to 0).
Epoch 2: the F1 Score rises to 0.6943, which means the model has started to identify positives effectively, with a reasonable balance of precision and recall.
"Best model saved with validation F1 score: 0.6943" means this is the best validation performance so far, so that version of the model was saved.
5. Range of the F1 Score
It lies between 0 and 1.
1: perfect; precision and recall are both 1.
0: worst; precision or recall is 0.
Your 0.6943 is a moderately good score, with room for improvement.

Summary: the F1 Score is an important measure of a classifier's overall performance, and when the classes are imbalanced it is more informative than plain accuracy. Your model's F1 Score jumping from 0 to 0.69 at epoch 2 shows that learning is working; the goal of further training and tuning is to push this value as close to 1 as possible.
Data distribution - delete: 314, retain: 28
Balanced data distribution - delete: 314, retain: 314
Using device: cuda:0        
Epoch [1/20]  Train Loss: 0.0855, Train Acc: 54.38% | Val Loss: 0.0570, Val Acc: 46.83%, F1 Score: 0.0000
Epoch [2/20]  Train Loss: 0.0528, Train Acc: 65.54% | Val Loss: 0.0800, Val Acc: 53.17%, F1 Score: 0.6943 | Best model saved with validation F1 score: 0.6943
Epoch [3/20]  Train Loss: 0.0351, Train Acc: 80.08% | Val Loss: 0.0235, Val Acc: 87.30%, F1 Score: 0.8841 | Best model saved with validation F1 score: 0.8841
Epoch [4/20]  Train Loss: 0.0252, Train Acc: 85.86% | Val Loss: 0.0290, Val Acc: 89.68%, F1 Score: 0.8926 | Best model saved with validation F1 score: 0.8926
Epoch [5/20]  Train Loss: 0.0236, Train Acc: 87.65% | Val Loss: 0.0262, Val Acc: 90.48%, F1 Score: 0.9016 | Best model saved with validation F1 score: 0.9016
Epoch [6/20]  Train Loss: 0.0201, Train Acc: 88.25% | Val Loss: 0.0240, Val Acc: 87.30%, F1 Score: 0.8644
Epoch [7/20]  Train Loss: 0.0154, Train Acc: 93.43% | Val Loss: 0.0151, Val Acc: 90.48%, F1 Score: 0.9016
Epoch [8/20]  Train Loss: 0.0174, Train Acc: 91.04% | Val Loss: 0.0195, Val Acc: 92.86%, F1 Score: 0.9280 | Best model saved with validation F1 score: 0.9280
Epoch [9/20]  Train Loss: 0.0169, Train Acc: 90.84% | Val Loss: 0.0172, Val Acc: 92.86%, F1 Score: 0.9280
Epoch [10/20]  Train Loss: 0.0134, Train Acc: 92.83% | Val Loss: 0.0175, Val Acc: 92.86%, F1 Score: 0.9280
Epoch [11/20]  Train Loss: 0.0104, Train Acc: 94.22% | Val Loss: 0.0223, Val Acc: 92.86%, F1 Score: 0.9280
Epoch [12/20]  Train Loss: 0.0086, Train Acc: 95.02% | Val Loss: 0.0140, Val Acc: 93.65%, F1 Score: 0.9365 | Best model saved with validation F1 score: 0.9365
Epoch [13/20]  Train Loss: 0.0093, Train Acc: 95.22% | Val Loss: 0.0238, Val Acc: 93.65%, F1 Score: 0.9365
Epoch [14/20]  Train Loss: 0.0087, Train Acc: 93.63% | Val Loss: 0.0615, Val Acc: 80.95%, F1 Score: 0.7818
Epoch [15/20]  Train Loss: 0.0071, Train Acc: 95.82% | Val Loss: 0.0234, Val Acc: 90.48%, F1 Score: 0.9016
Epoch [16/20]  Train Loss: 0.0070, Train Acc: 97.41% | Val Loss: 0.0219, Val Acc: 92.06%, F1 Score: 0.9194
Epoch [17/20]  Train Loss: 0.0077, Train Acc: 95.42% | Val Loss: 0.0187, Val Acc: 92.06%, F1 Score: 0.9194
Epoch [18/20]  Train Loss: 0.0046, Train Acc: 97.61% | Val Loss: 0.0542, Val Acc: 74.60%, F1 Score: 0.6863
Epoch [19/20]  Train Loss: 0.0075, Train Acc: 96.61% | Val Loss: 0.0171, Val Acc: 92.06%, F1 Score: 0.9194
Epoch [20/20]  Train Loss: 0.0062, Train Acc: 96.41% | Val Loss: 0.0204, Val Acc: 92.06%, F1 Score: 0.9194
Training completed. Best validation F1 score: 0.9365    
127.0.0.1 - - [21/Oct/2025 22:23:48] "GET /train?json_file=output/combined_labels.json&image_dir=output/pictures&model_save_path=models/improved_dxf_entity_cnn_model.pth&num_epochs=20 HTTP/1.1" 200 -
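The first two lines of the log ("Data distribution - delete: 314, retain: 28" and the balanced 314/314) come from the oversampling step in _balance_dataset in the script below. A small sketch of that arithmetic, with a placeholder list standing in for the actual sample dicts:

```python
# Retain is the minority class here (28 vs. 314), so it is oversampled
delete_n, retain_n = 314, 28
factor = delete_n // retain_n + 1                      # 314 // 28 + 1 = 12
retain_samples = (["retain_sample"] * retain_n * factor)[:delete_n]
print(len(retain_samples))                             # 314 -> both classes now have 314 samples
```

Note that the balancing happens before the random train/validation split, so oversampled duplicates of the 28 retain samples can appear in both splits, which likely makes the validation F1 scores a little optimistic. The full training script (delete_cnn.py) follows, with its comments translated to English.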
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import json
import os
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix


class ImprovedDXFEntityDataset(Dataset):
    """Improved DXF entity dataset with optional class balancing and augmentation."""

    def __init__(self, json_file, root_dir, transform=None, balance_data=True):
        """
        Args:
            json_file (string): path to the JSON file describing the samples
            root_dir (string): root directory of the image files
            transform (callable, optional): optional image transform
            balance_data (bool): whether to balance the dataset
        """
        with open(json_file, 'r') as f:
            self.data_info = json.load(f)
        self.root_dir = root_dir
        self.transform = transform
        self.samples = self.data_info.get('samples', [])
        # Analyse the class distribution
        self.delete_count = sum(1 for s in self.samples if s['label']['action'] == 1)
        self.retain_count = len(self.samples) - self.delete_count
        print(f"Data distribution - delete: {self.delete_count}, retain: {self.retain_count}")
        # Balance the data if requested
        if balance_data and self.delete_count != self.retain_count:
            self._balance_dataset()

    def _balance_dataset(self):
        """Balance the dataset so the delete and retain classes have similar sample counts."""
        delete_samples = [s for s in self.samples if s['label']['action'] == 1]
        retain_samples = [s for s in self.samples if s['label']['action'] == 0]
        # Oversample the minority class with a fairly aggressive upsampling strategy
        if len(delete_samples) < len(retain_samples):
            # Upsample the delete samples
            factor = len(retain_samples) // len(delete_samples) + 1
            delete_samples = (delete_samples * factor)[:len(retain_samples)]
        elif len(retain_samples) < len(delete_samples):
            # Upsample the retain samples
            factor = len(delete_samples) // len(retain_samples) + 1
            retain_samples = (retain_samples * factor)[:len(delete_samples)]
        self.samples = delete_samples + retain_samples
        print(f"Balanced data distribution - delete: {len(delete_samples)}, retain: {len(retain_samples)}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample = self.samples[idx]
        img_path = sample['image_path']
        # Check that the image exists
        if not os.path.exists(img_path):
            print(f"Warning: image file does not exist: {img_path}")
            # Could return a default image or skip; here we raise
            raise FileNotFoundError(f"Image file not found: {img_path}")
        # Load the image
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Warning: failed to load image {img_path}: {e}")
            # Fall back to a default (black) image
            image = Image.new('RGB', (96, 96), color='black')
        # Label: 1 means delete, 0 means retain
        label = sample['label']['action']
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(label, dtype=torch.long)


class ImprovedDXFEntityCNN(nn.Module):
    """Improved CNN for deciding whether a DXF entity should be deleted."""

    def __init__(self, num_classes=2):
        super(ImprovedDXFEntityCNN, self).__init__()
        # Convolutional feature extractor with batch normalization
        self.features = nn.Sequential(
            # First convolutional block
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout2d(0.25),
            # Second convolutional block
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout2d(0.25),
            # Third convolutional block
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout2d(0.25),
        )
        # Classifier head
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(128 * 12 * 12, 512),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(128),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x


# Focal Loss to handle class imbalance
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = nn.CrossEntropyLoss(reduction='none')(inputs, targets)
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss


def train_improved_cnn_model(json_file, image_dir, model_save_path='improved_dxf_cnn_model.pth',
                             num_epochs=50, learning_rate=0.001, batch_size=32):
    """Improved CNN training routine.

    Args:
        json_file: path to the JSON file describing the training data
        image_dir: directory containing the images
        model_save_path: where to save the model
        num_epochs: number of training epochs
        learning_rate: learning rate
        batch_size: batch size
    """
    # Data preprocessing and augmentation, with stronger augmentation to help the minority class
    train_transform = transforms.Compose([
        transforms.Resize((96, 96)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(15),  # larger rotation angle
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.2),
        transforms.RandomAffine(degrees=0, translate=(0.2, 0.2)),  # add translation augmentation
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    val_transform = transforms.Compose([
        transforms.Resize((96, 96)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # Build the dataset
    full_dataset = ImprovedDXFEntityDataset(json_file=json_file, root_dir=image_dir,
                                            transform=train_transform, balance_data=True)
    # Split into training and validation sets
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])
    # Use a different transform for the validation set
    val_dataset.dataset.transform = val_transform
    # Data loaders
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    # Use the GPU if available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    # Model, loss function and optimizer
    model = ImprovedDXFEntityCNN(num_classes=2).to(device)
    # Focal Loss to handle class imbalance
    criterion = FocalLoss(alpha=0.25, gamma=2)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)
    # Train the model
    best_val_loss = float('inf')
    best_f1_score = 0.0
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        for inputs, labels in train_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            # Zero the gradients
            optimizer.zero_grad()
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # Backward pass and optimization
            loss.backward()
            optimizer.step()
            # Statistics
            train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()
        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        val_tp = 0  # True positives
        val_fp = 0  # False positives
        val_fn = 0  # False negatives
        with torch.no_grad():
            for inputs, labels in val_dataloader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
                # Accumulate the counts needed for the F1 score
                val_tp += ((predicted == 1) & (labels == 1)).sum().item()
                val_fp += ((predicted == 1) & (labels == 0)).sum().item()
                val_fn += ((predicted == 0) & (labels == 1)).sum().item()
        # Compute the F1 score
        precision = val_tp / (val_tp + val_fp) if (val_tp + val_fp) > 0 else 0
        recall = val_tp / (val_tp + val_fn) if (val_tp + val_fn) > 0 else 0
        f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        # Average losses
        avg_train_loss = train_loss / len(train_dataloader)
        avg_val_loss = val_loss / len(val_dataloader)
        train_accuracy = 100 * train_correct / train_total
        val_accuracy = 100 * val_correct / val_total
        # Update the learning rate
        scheduler.step(avg_val_loss)
        print(f'Epoch [{epoch+1}/{num_epochs}]')
        print(f'  Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.2f}%')
        print(f'  Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.2f}%, F1 Score: {f1_score:.4f}')
        # Save the best model (based on the F1 score)
        if f1_score > best_f1_score:
            best_f1_score = f1_score
            torch.save(model.state_dict(), model_save_path)
            print(f'  Best model saved with validation F1 score: {best_f1_score:.4f}')
    print(f'Training completed. Best validation F1 score: {best_f1_score:.4f}')
    return model


def evaluate_detailed_model(model, json_file, image_dir):
    """Evaluate the model in detail, including a confusion matrix and classification report."""
    transform = transforms.Compose([
        transforms.Resize((96, 96)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    dataset = ImprovedDXFEntityDataset(json_file=json_file, root_dir=image_dir,
                                       transform=transform, balance_data=False)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=2)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    all_predictions = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    # Detailed metrics
    accuracy = 100 * sum(np.array(all_predictions) == np.array(all_labels)) / len(all_labels)
    print(f'Overall Accuracy: {accuracy:.2f}%')
    # Confusion matrix
    cm = confusion_matrix(all_labels, all_predictions)
    print("Confusion Matrix:")
    print(cm)
    # Classification report
    target_names = ['Retain', 'Delete']
    print("\nClassification Report:")
    print(classification_report(all_labels, all_predictions, target_names=target_names))
    return accuracy


# Append the following code to the end of delete_cnn.py
from flask import Flask, request, jsonify
import io
import base64
from PIL import Image as PILImage
import torchvision.transforms as transforms
import urllib.parse

app = Flask(__name__)

# Global model and configuration
model = None
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def init_model(model_path):
    """Initialize the model."""
    global model
    model = ImprovedDXFEntityCNN(num_classes=2)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()


@app.route('/train', methods=['GET'])
def train_model_api():
    """Train the model via the API (GET).
    ---
    Parameters:
    - json_file: path to the training-data JSON file
    - image_dir: image directory
    - model_save_path: where to save the model
    - num_epochs: number of epochs (optional, default 50)
    - learning_rate: learning rate (optional, default 0.001)
    - batch_size: batch size (optional, default 32)
    """
    try:
        # Read values from the query parameters
        json_file = request.args.get('json_file', 'output/combined_labels.json')
        image_dir = request.args.get('image_dir', 'output/pictures')
        model_save_path = request.args.get('model_save_path', 'models/improved_dxf_entity_cnn_model.pth')
        num_epochs = int(request.args.get('num_epochs', 50))
        learning_rate = float(request.args.get('learning_rate', 0.001))
        batch_size = int(request.args.get('batch_size', 32))
        # Make sure the model directory exists
        os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
        # Run the training routine
        trained_model = train_improved_cnn_model(
            json_file=json_file,
            image_dir=image_dir,
            model_save_path=model_save_path,
            num_epochs=num_epochs,
            learning_rate=learning_rate,
            batch_size=batch_size
        )
        # Update the global model
        global model
        model = trained_model
        return jsonify({
            'status': 'success',
            'message': 'Model training completed successfully'
        })
    except Exception as e:
        return jsonify({
            'status': 'error',
            'message': str(e)
        }), 500


@app.route('/evaluate', methods=['GET'])
def evaluate_model_api():
    """Evaluate the model via the API (GET).
    ---
    Parameters:
    - json_file: path to the evaluation-data JSON file
    - image_dir: image directory
    """
    try:
        # Read values from the query parameters
        json_file = request.args.get('json_file', 'output/combined_labels.json')
        image_dir = request.args.get('image_dir', 'output/pictures')
        if model is None:
            return jsonify({
                'status': 'error',
                'message': 'Model not initialized. Please train or load a model first.'
            }), 400
        # Run the evaluation
        accuracy = evaluate_detailed_model(model, json_file, image_dir)
        return jsonify({
            'status': 'success',
            'accuracy': accuracy
        })
    except Exception as e:
        return jsonify({
            'status': 'error',
            'message': str(e)
        }), 500


@app.route('/predict', methods=['GET'])
def predict_single_image():
    """Predict a single image via the API (GET).
    ---
    Parameters:
    - image_path: path of the image to classify
    - format: response format ("json" or "html"), default "json"
    """
    try:
        if model is None:
            return jsonify({
                'status': 'error',
                'message': 'Model not initialized. Please train or load a model first.'
            }), 400
        # Read the image path and response format from the query parameters
        image_path = request.args.get('image_path')
        format_type = request.args.get('format', 'json')  # JSON is the default response format
        if not image_path:
            return jsonify({
                'status': 'error',
                'message': 'image_path parameter is required'
            }), 400
        # Image preprocessing
        transform = transforms.Compose([
            transforms.Resize((96, 96)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        # Load and preprocess the image
        image = PILImage.open(image_path).convert('RGB')
        input_tensor = transform(image).unsqueeze(0).to(device)
        # Run the prediction
        with torch.no_grad():
            output = model(input_tensor)
            probabilities = torch.softmax(output, dim=1)
            prediction = torch.argmax(output, dim=1)
        # Parse the result
        action = "delete" if prediction.item() == 1 else "retain"
        confidence = probabilities[0][prediction.item()].item()
        result = {
            'status': 'success',
            'prediction': {
                'action': action,
                'confidence': confidence,
                'probabilities': {
                    'retain': probabilities[0][0].item(),
                    'delete': probabilities[0][1].item()
                }
            },
            'image_path': image_path
        }
        # Return JSON or HTML depending on the format parameter
        if format_type.lower() == 'html':
            # Read the image and base64-encode it so it can be embedded in the HTML
            with open(image_path, "rb") as img_file:
                img_data = base64.b64encode(img_file.read()).decode('utf-8')
            html_template = f"""<!DOCTYPE html>
<html>
<head>
    <title>DXF Entity Prediction Result</title>
    <style>
        body {{ font-family: Arial, sans-serif; margin: 20px; background-color: #f5f5f5; }}
        .container {{ max-width: 800px; margin: 0 auto; background-color: white; padding: 20px; border-radius: 8px; box-shadow: 0 2px 10px rgba(0,0,0,0.1); }}
        .result-card {{ border: 1px solid #ddd; border-radius: 5px; padding: 15px; margin: 15px 0; background-color: #f9f9f9; }}
        .prediction-image {{ max-width: 100%; height: auto; border: 1px solid #ccc; border-radius: 4px; }}
        .confidence-high {{ color: #28a745; font-weight: bold; }}
        .confidence-medium {{ color: #ffc107; font-weight: bold; }}
        .confidence-low {{ color: #dc3545; font-weight: bold; }}
        .action-delete {{ color: #dc3545; font-weight: bold; }}
        .action-retain {{ color: #28a745; font-weight: bold; }}
    </style>
</head>
<body>
    <div class="container">
        <h1>DXF Entity Prediction Result</h1>
        <div class="result-card">
            <h2>Prediction Result</h2>
            <p><strong>Image Path:</strong> {image_path}</p>
            <p><strong>Action:</strong> <span class="action-{'delete' if action == 'delete' else 'retain'}">{action.upper()}</span></p>
            <p><strong>Confidence:</strong> <span class="{'confidence-high' if confidence > 0.9 else 'confidence-medium' if confidence > 0.7 else 'confidence-low'}">{confidence:.2%}</span></p>
            <p><strong>Probabilities:</strong></p>
            <ul>
                <li>Retain: {(probabilities[0][0].item()):.2%}</li>
                <li>Delete: {(probabilities[0][1].item()):.2%}</li>
            </ul>
        </div>
        <div class="result-card">
            <h2>Entity Image</h2>
            <img src="data:image/png;base64,{img_data}" alt="DXF Entity" class="prediction-image">
        </div>
        <div class="result-card">
            <h2>Interpretation</h2>
            <p>The AI model predicts this DXF entity should be <strong>{action.upper()}</strong> with <strong>{confidence:.2%}</strong> confidence.</p>
            {'''<p style="color:#28a745;"><strong>Recommendation:</strong> The entity is likely safe to delete as indicated by high confidence.</p>''' if action == 'delete' and confidence > 0.9 else ''}
            {'''<p style="color:#ffc107;"><strong>Recommendation:</strong> Consider manual review before deleting as confidence is moderate.</p>''' if confidence <= 0.9 and confidence > 0.7 else ''}
            {'''<p style="color:#dc3545;"><strong>Recommendation:</strong> Low confidence - strongly recommend manual verification before taking action.</p>''' if confidence <= 0.7 else ''}
        </div>
    </div>
</body>
</html>"""
            return html_template
        # Default: return JSON
        return jsonify(result)
    except Exception as e:
        if request.args.get('format', 'json').lower() == 'html':
            return f"""<!DOCTYPE html>
<html>
<head>
    <title>Error - DXF Entity Prediction</title>
    <style>
        body {{ font-family: Arial, sans-serif; margin: 20px; background-color: #f5f5f5; }}
        .container {{ max-width: 800px; margin: 0 auto; background-color: white; padding: 20px; border-radius: 8px; box-shadow: 0 2px 10px rgba(0,0,0,0.1); }}
        .error {{ color: #dc3545; padding: 15px; border: 1px solid #dc3545; border-radius: 5px; background-color: #f8d7da; }}
    </style>
</head>
<body>
    <div class="container">
        <h1>Error - DXF Entity Prediction</h1>
        <div class="error">
            <h2>Error occurred:</h2>
            <p>{str(e)}</p>
        </div>
    </div>
</body>
</html>""", 500
        else:
            return jsonify({
                'status': 'error',
                'message': str(e)
            }), 500


@app.route('/predict_batch', methods=['GET'])
def predict_batch_images():
    """Predict a batch of images via the API (GET).
    ---
    Parameters:
    - json_file: JSON file listing the image paths
    - image_dir: image directory (optional)
    """
    try:
        if model is None:
            return jsonify({
                'status': 'error',
                'message': 'Model not initialized. Please train or load a model first.'
            }), 400
        # Read values from the query parameters
        json_file = request.args.get('json_file')
        image_dir = request.args.get('image_dir', '')
        if not json_file:
            return jsonify({
                'status': 'error',
                'message': 'json_file parameter is required'
            }), 400
        # Load the JSON file
        with open(json_file, 'r') as f:
            samples_data = json.load(f)
        samples = samples_data.get('samples', [])
        # Image preprocessing
        transform = transforms.Compose([
            transforms.Resize((96, 96)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        results = []
        for sample in samples:
            try:
                image_path = sample.get('image_path')
                if image_dir and not os.path.isabs(image_path):
                    image_path = os.path.join(image_dir, image_path)
                # Load and preprocess the image
                image = PILImage.open(image_path).convert('RGB')
                input_tensor = transform(image).unsqueeze(0).to(device)
                # Run the prediction
                with torch.no_grad():
                    output = model(input_tensor)
                    probabilities = torch.softmax(output, dim=1)
                    prediction = torch.argmax(output, dim=1)
                # Parse the result
                action = "delete" if prediction.item() == 1 else "retain"
                confidence = probabilities[0][prediction.item()].item()
                results.append({
                    'image_path': image_path,
                    'prediction': {
                        'action': action,
                        'confidence': confidence
                    },
                    'ground_truth': sample.get('label', {})
                })
            except Exception as e:
                results.append({
                    'image_path': image_path,
                    'error': str(e)
                })
        return jsonify({
            'status': 'success',
            'results': results
        })
    except Exception as e:
        return jsonify({
            'status': 'error',
            'message': str(e)
        }), 500


@app.route('/load_model', methods=['GET'])
def load_model_api():
    """Load a previously trained model via the API (GET).
    ---
    Parameters:
    - model_path: path to the model file
    """
    try:
        # Read the model path from the query parameters
        model_path = request.args.get('model_path')
        if not model_path or not os.path.exists(model_path):
            return jsonify({
                'status': 'error',
                'message': 'Model file not found'
            }), 400
        init_model(model_path)
        return jsonify({
            'status': 'success',
            'message': 'Model loaded successfully'
        })
    except Exception as e:
        return jsonify({
            'status': 'error',
            'message': str(e)
        }), 500


# Replace the __main__ section at the end of the file with the following code
if __name__ == "__main__":
    # Default model path
    default_model_path = "models/improved_dxf_entity_cnn_model.pth"
    # Parse command-line arguments
    import argparse
    parser = argparse.ArgumentParser(description='Run DXF Entity CNN API')
    parser.add_argument('--model-path', type=str, help='Path to the trained model')
    args = parser.parse_args()
    # Prefer the command-line argument, otherwise fall back to the default path
    model_path = args.model_path if args.model_path else default_model_path
    # Load the model if the file exists
    if os.path.exists(model_path):
        init_model(model_path)
        print(f"Model loaded from {model_path}")
    else:
        print(f"Model file not found at {model_path}, starting server without preloaded model")
        print("You can train a model using the /train endpoint")
    app.run(host='0.0.0.0', port=5000, debug=True)
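Once the script is running (python delete_cnn.py --model-path models/improved_dxf_entity_cnn_model.pth), the endpoints can be exercised with plain GET requests. A minimal sketch using the requests library, reusing the paths from the examples above (adjust them to your environment):

```python
import requests

BASE = "http://localhost:5000"

# Load a previously trained model (alternatively start the server with --model-path)
print(requests.get(f"{BASE}/load_model",
                   params={"model_path": "models/improved_dxf_entity_cnn_model.pth"}).json())

# Predict a single entity image; format=json returns raw probabilities,
# format=html returns the self-contained report page shown earlier
resp = requests.get(f"{BASE}/predict",
                    params={"image_path": r"E:\code\cad_ppo\output\pictures\5_LINE_3914.png",
                            "format": "json"})
print(resp.json())

# Kick off training through the API (same query parameters as the request log above)
# requests.get(f"{BASE}/train", params={
#     "json_file": "output/combined_labels.json",
#     "image_dir": "output/pictures",
#     "model_save_path": "models/improved_dxf_entity_cnn_model.pth",
#     "num_epochs": 20,
# })
```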
