多模型协同预测在风机故障预测的应用(demo)
- 数据加载和预处理的真实性:
  - 下面的代码中,`DummyDataset` 和数据加载部分仍然是高度简化的占位实现。为了让这个训练循环真正有效,您必须用自己的数据加载逻辑替换它。
  - 这意味着您需要创建一个 `torch.utils.data.Dataset` 的子类,它能够正确地从您的数据源(例如 CSV 文件、数据库、文件夹中的原始信号文件)加载每个样本的多种传感器数据。
  - 在 `__getitem__` 方法中,您需要调用 `DataProcessor` 的相应 `process_...` 方法来提取特征,然后进行归一化(如果希望模型直接处理归一化后的特征,而不是等到 `predict` 中才做),并将所有数据转换成模型期望的张量格式和形状。
  - `fit_scalers` 的调用时机:`DataProcessor` 中的 `fit_scalers` 方法必须在创建 `DataLoader` 并开始训练之前,使用从整个训练集提取出的特征进行调用。这一步至关重要。
- 特征准备的复杂性:
  - 在批处理训练中,为每个模型(CNN、LSTM、Electrical)准备输入可能很复杂,特别是 LSTM 需要特征序列。您可能需要在 `Dataset` 的 `__getitem__` 中就准备好这些序列,或者设计一个高效的批处理函数。
  - 规则引擎在验证批次中的使用也需要仔细考虑,因为它通常处理的是单个样本的特征。
- 模型和融合权重的调优:
  - 此代码只提供了一个结构。要获得高准确率,您仍然需要对各个模型的超参数以及 `model_weights`(融合权重)进行仔细的实验和调优。
- 计算资源:
  - 训练深度学习模型(尤其是同时训练多个模型)可能需要大量的计算资源(GPU)和时间。
import numpy as np
import torch
import torch.nn as nn
from scipy import signal as sig
from scipy.stats import kurtosis, skew
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import pickle
from typing import Dict, List, Tuple, Optional, Any
import os  # used to create directories

# --- Configuration constants (adjust as needed) ---
VIBRATION_RAW_SEQ_LEN = 1024
VIBRATION_SAMPLING_RATE = 10000
PRESSURE_SAMPLING_RATE = 100
LSTM_FEATURE_SIZE = 10  # temperature (5) + pressure (4) + single current feature (1)
LSTM_SEQ_LEN = 5  # example sequence length for the LSTM input features
ELECTRICAL_FEATURE_SIZE = 2
NUM_FAULT_CLASSES = 10

# Directory and file paths used to save models and results
MODEL_SAVE_DIR = "saved_models"
SCALER_SAVE_PATH = os.path.join(MODEL_SAVE_DIR, "fitted_scalers.pkl")
BEST_CNN_PATH = os.path.join(MODEL_SAVE_DIR, "best_cnn.pth")
BEST_LSTM_PATH = os.path.join(MODEL_SAVE_DIR, "best_lstm.pth")
BEST_ELECTRICAL_PATH = os.path.join(MODEL_SAVE_DIR, "best_electrical.pth")

# Import-time side effect kept from the original: make sure the checkpoint
# directory exists. exist_ok avoids the exists()/makedirs() race.
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

# Set to True while FaultDiagnosisSystem.predict is running so that
# DataProcessor.normalize_features refuses to silently skip an unfitted scaler.
predicting_now = False


class DataProcessor:
    """Feature extraction and scaling for every sensor modality.

    Holds one ``StandardScaler`` per feature group in ``scalers``, a
    ``fitted_scalers`` dict recording which of them have been fitted, and the
    ``fault_types`` mapping from class index to fault name.
    """

    def __init__(self):
        self.scalers = {
            'vibration_features': StandardScaler(),
            'temperature': StandardScaler(),
            'pressure': StandardScaler(),
            'blade_angle': StandardScaler(),
            'oil_particles': StandardScaler(),
            'current': StandardScaler(),
        }
        self.fitted_scalers = {}
        self.fault_types = {
            0: "正常", 1: "叶轮不平衡", 2: "轴承失效", 3: "动叶卡涩",
            4: "喘振", 5: "积灰堵塞", 6: "电机绕组故障", 7: "传感器失效",
            8: "基础松动", 9: "密封失效",
        }
        assert len(self.fault_types) == NUM_FAULT_CLASSES, "NUM_FAULT_CLASSES 和 fault_types 长度不匹配"

    def fault_types_str_to_int(self, fault_name_str: str) -> int:
        """Return the class index of ``fault_name_str``.

        The original attached ``get_fault_type_int`` to the class directly, so
        the bound instance was passed where the ``fault_types`` dict was
        expected and every call failed. This method forwards the dict.

        Raises:
            ValueError: if the name is not a known fault type.
        """
        return get_fault_type_int(self.fault_types, fault_name_str)

    def fit_scalers(self, training_features_dict: Dict[str, List[np.ndarray]]):
        """Fit each known scaler on the list of feature vectors supplied for it.

        Keys without a matching scaler are reported and skipped; fitting errors
        are reported and leave that scaler marked as unfitted.
        """
        print("正在拟合缩放器...")
        for feature_type, features_list in training_features_dict.items():
            if feature_type in self.scalers and features_list:
                all_features_for_type = np.array(features_list)
                if all_features_for_type.ndim == 1:
                    # A list of scalars becomes a single-feature column.
                    all_features_for_type = all_features_for_type.reshape(-1, 1)
                if all_features_for_type.shape[0] == 0:
                    print(f"警告:未提供用于拟合 {feature_type} 缩放器的数据。")
                    continue
                try:
                    self.scalers[feature_type].fit(all_features_for_type)
                    self.fitted_scalers[feature_type] = True
                    print(f"{feature_type} 的缩放器已拟合,数据形状: {all_features_for_type.shape}")
                except Exception as e:
                    print(f"拟合 {feature_type} 缩放器时出错 (形状 {all_features_for_type.shape}): {e}")
            elif feature_type not in self.scalers:
                print(f"未为特征类型定义缩放器: {feature_type}")

    def save_scalers(self, path=SCALER_SAVE_PATH):
        """Pickle the whole ``scalers`` dict to ``path``."""
        with open(path, 'wb') as f:
            pickle.dump(self.scalers, f)
        print(f"已拟合的缩放器保存至 {path}")

    def load_scalers(self, path=SCALER_SAVE_PATH):
        """Load a pickled ``scalers`` dict from ``path``.

        A missing file is reported and leaves the current scalers untouched.
        Any other failure is reported and resets the processor to fresh,
        unfitted scalers.
        """
        try:
            # SECURITY: pickle executes arbitrary code on load; only read
            # scaler files that this application wrote itself.
            with open(path, 'rb') as f:
                self.scalers = pickle.load(f)
            # Only mark scalers that were actually fitted before being saved;
            # a fitted StandardScaler exposes n_samples_seen_.
            for key, scaler in self.scalers.items():
                if hasattr(scaler, "n_samples_seen_"):
                    self.fitted_scalers[key] = True
                else:
                    self.fitted_scalers.pop(key, None)
            print(f"缩放器从 {path} 加载成功")
        except FileNotFoundError:
            print(f"缩放器文件 {path} 未找到。如果不是在训练阶段,则初始化新的缩放器。")
        except Exception as e:
            print(f"加载缩放器错误: {e}。如果不是在训练阶段,则初始化新的缩放器。")
            self.__init__()

    def process_vibration(self, vibration_signal: np.ndarray,
                          sampling_rate: int = VIBRATION_SAMPLING_RATE) -> np.ndarray:
        """Extract 8 vibration features from a 1-D signal.

        Returns ``[rms, peak, kurtosis, skewness, psd@50, psd@100, psd@150,
        low_freq_energy]`` where the PSD comes from Welch's method and
        ``low_freq_energy`` is the share of PSD between 5 and 50 (inclusive).
        Signals shorter than 2 samples yield zeros.

        Raises:
            ValueError: if the input is not a 1-D numpy array.
        """
        if not isinstance(vibration_signal, np.ndarray) or vibration_signal.ndim != 1:
            raise ValueError("振动信号必须是一维numpy数组。")
        if len(vibration_signal) < 2:
            return np.zeros(8)

        rms = np.sqrt(np.mean(vibration_signal ** 2))
        peak = np.max(np.abs(vibration_signal))
        kurtosis_val = kurtosis(vibration_signal)
        skewness_val = skew(vibration_signal)

        nperseg_val = min(len(vibration_signal), 1024)
        f, psd = sig.welch(vibration_signal, sampling_rate, nperseg=nperseg_val)

        def get_freq_component(freq_target):
            """PSD at the bin nearest ``freq_target``, or 0.0 if none is within one bin width."""
            if f.size > 0 and psd.size > 0:
                distance = np.abs(f - freq_target)
                idx = np.argmin(distance)
                bin_width = (f[1] - f[0]) if len(f) > 1 else sampling_rate / 2
                if distance[idx] < bin_width:
                    return psd[idx]
            return 0.0

        freq_1x = get_freq_component(50)
        freq_2x = get_freq_component(100)
        freq_3x = get_freq_component(150)

        low_freq_energy = 0.0
        total_psd = np.sum(psd)
        if psd.size > 0 and total_psd > 0:
            relevant_psd = psd[(f >= 5) & (f <= 50)]
            if relevant_psd.size > 0:
                low_freq_energy = np.sum(relevant_psd) / total_psd

        return np.array([rms, peak, kurtosis_val, skewness_val,
                         freq_1x, freq_2x, freq_3x, low_freq_energy])

    def process_temperature(self, temp_series: np.ndarray) -> np.ndarray:
        """Return ``[last, mean, max, mean first difference, second sample]``.

        The fifth value is named ``stator_temp`` and is taken from index 1
        (index 0 for a single-sample series). Invalid or empty input yields
        zeros.
        """
        if not isinstance(temp_series, np.ndarray) or temp_series.ndim != 1 or len(temp_series) == 0:
            return np.zeros(5)
        has_multiple = len(temp_series) > 1
        current_temp = temp_series[-1]
        avg_temp = np.mean(temp_series)
        max_temp = np.max(temp_series)
        temp_rate = np.diff(temp_series).mean() if has_multiple else 0.0
        stator_temp = temp_series[1] if has_multiple else temp_series[0]
        return np.array([current_temp, avg_temp, max_temp, temp_rate, stator_temp])

    def process_pressure(self, pressure_series: np.ndarray,
                         fs: int = PRESSURE_SAMPLING_RATE) -> np.ndarray:
        """Return ``[mean, std, max abs step, PSD share in the 0.5-2.0 band]``.

        Invalid or empty input yields zeros.
        """
        if not isinstance(pressure_series, np.ndarray) or pressure_series.ndim != 1 or len(pressure_series) == 0:
            return np.zeros(4)
        mean_pressure = np.mean(pressure_series)
        std_pressure = np.std(pressure_series)
        max_fluctuation = (np.max(np.abs(np.diff(pressure_series)))
                           if len(pressure_series) > 1 else 0.0)
        low_freq_energy = self._calculate_low_freq_energy(pressure_series, fs, 0.5, 2.0)
        return np.array([mean_pressure, std_pressure, max_fluctuation, low_freq_energy])

    def process_blade_angle(self, angle_series: np.ndarray, target_angle: float) -> np.ndarray:
        """Return ``[last angle, target, |last - target|, mean |step|, stuck ratio]``.

        The stuck ratio is the number of steps with absolute change below 0.5
        divided by the series length. Invalid or empty input yields zeros.
        """
        if not isinstance(angle_series, np.ndarray) or angle_series.ndim != 1 or len(angle_series) == 0:
            return np.zeros(5)
        current_angle = angle_series[-1]
        angle_deviation = np.abs(current_angle - target_angle)
        if len(angle_series) > 1:
            abs_steps = np.abs(np.diff(angle_series))
            angle_rate = abs_steps.mean()
            stuck_points = np.sum(abs_steps < 0.5) / len(angle_series)
        else:
            angle_rate = 0.0
            stuck_points = 0.0
        return np.array([current_angle, target_angle, angle_deviation, angle_rate, stuck_points])

    def process_oil_analysis(self, particles: float, viscosity: float) -> np.ndarray:
        """Pack the two oil readings into ``[particles, viscosity]``."""
        return np.array([particles, viscosity])

    def process_current(self, current_series: np.ndarray) -> np.ndarray:
        """Return ``[harmonic_ratio, unbalance_metric]`` for a current signal.

        For 1-D input: ``std / mean`` and ``max |x - mean| / mean``. For any
        other rank the ratio is computed on the flattened data and the
        unbalance metric is 0. Non-array, empty, or zero-mean input yields
        zeros.
        """
        # .size (rather than len()) also copes with 0-d arrays, on which
        # len() raises TypeError.
        if not isinstance(current_series, np.ndarray) or current_series.size == 0:
            return np.zeros(2)
        mean_curr = np.mean(current_series)
        if mean_curr == 0:
            return np.array([0.0, 0.0])
        if current_series.ndim == 1:
            harmonic_ratio = np.std(current_series) / mean_curr
            unbalance_metric = np.max(np.abs(current_series - mean_curr)) / mean_curr
        else:
            # NOTE: the original also had a branch requiring ndim == 0 and a
            # 1-D shape at once; it could never run and was dropped.
            harmonic_ratio = np.std(current_series.flatten()) / mean_curr
            unbalance_metric = 0.0
        return np.array([harmonic_ratio, unbalance_metric])

    def _calculate_low_freq_energy(self, input_signal: np.ndarray, fs: float,
                                   f_low: float, f_high: float) -> float:
        """Share of Welch PSD that falls inside ``[f_low, f_high]`` (0.0 if undefined)."""
        if not isinstance(input_signal, np.ndarray) or len(input_signal) == 0:
            return 0.0
        nperseg_val = min(len(input_signal), 128)
        if nperseg_val < 2:
            return 0.0
        f, psd = sig.welch(input_signal, fs=fs, nperseg=nperseg_val)
        total_psd = np.sum(psd)
        if psd.size == 0 or total_psd == 0:
            return 0.0
        relevant_psd = psd[(f >= f_low) & (f <= f_high)]
        if relevant_psd.size == 0:
            return 0.0
        return np.sum(relevant_psd) / total_psd

    def normalize_features(self, features: np.ndarray, feature_type: str) -> np.ndarray:
        """Scale one feature vector with the scaler registered for ``feature_type``.

        Unknown feature types are returned unchanged. An unfitted scaler
        returns the input unchanged, except while the module-level
        ``predicting_now`` flag is set, in which case it is an error.

        Raises:
            RuntimeError: scaler not fitted while a prediction is in progress.
            ValueError: ``features`` is neither 1-D nor of shape ``(1, n)``.
        """
        if feature_type not in self.scalers:
            return features
        if not self.fitted_scalers.get(feature_type):
            if globals().get("predicting_now"):
                raise RuntimeError(f"{feature_type} 的缩放器必须在预测前拟合。")
            return features
        if features.ndim == 1:
            features_reshaped = features.reshape(1, -1)
        elif features.ndim == 2 and features.shape[0] == 1:
            features_reshaped = features
        else:
            raise ValueError(f"{feature_type} 的特征形状对于缩放不符合预期: {features.shape}")
        return self.scalers[feature_type].transform(features_reshaped).flatten()


class CNNModel(nn.Module):
    """1-D CNN classifier: three conv/ReLU/max-pool stages and a two-layer head."""

    def __init__(self, input_channels: int = 1, num_classes: int = NUM_FAULT_CLASSES,
                 example_input_len: int = VIBRATION_RAW_SEQ_LEN):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv1d(input_channels, 32, kernel_size=7, stride=1, padding=3), nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(32, 64, kernel_size=5, stride=1, padding=2), nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(64, 128, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
        )
        # Probe the conv stack once to size the first linear layer; inputs at
        # run time must therefore have length example_input_len.
        dummy_input = torch.randn(1, input_channels, example_input_len)
        with torch.no_grad():
            conv_output_size = self.conv_layers(dummy_input).view(1, -1).size(1)
        self.fc = nn.Sequential(
            nn.Linear(conv_output_size, 256), nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Map ``(batch, channels, length)`` input to class logits."""
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)


class LSTMModel(nn.Module):
    """Two-layer bidirectional LSTM followed by a two-layer classification head."""

    def __init__(self, input_size: int = LSTM_FEATURE_SIZE, hidden_size: int = 128,
                 num_classes: int = NUM_FAULT_CLASSES):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=2,
                            batch_first=True, dropout=0.3, bidirectional=True)
        self.fc = nn.Sequential(
            nn.Linear(hidden_size * 2, 256), nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Map ``(batch, seq, features)`` input to class logits.

        Uses the output at the last time step for both directions.
        """
        lstm_out, _ = self.lstm(x)
        # NOTE(review): at the last time step the backward direction has only
        # seen one element; confirm this is the intended summary.
        x = lstm_out[:, -1, :]
        return self.fc(x)


class DSEvidenceFusion:
    """Fuse per-model class probabilities.

    Despite the name, the implementation is a normalised weighted average of
    the evidence vectors, not Dempster's rule of combination.
    """

    def __init__(self, num_classes: int = NUM_FAULT_CLASSES):
        self.num_classes = num_classes

    def combine_evidence(self, evidences: List[np.ndarray],
                         weights: Optional[List[float]] = None) -> np.ndarray:
        """Return the weighted sum of ``evidences``.

        Args:
            evidences: vectors of shape ``(num_classes,)``. An empty list
                yields the uniform distribution.
            weights: one weight per evidence, normalised to sum to 1. Equal
                weights are used when omitted.

        Raises:
            ValueError: on a weight/evidence count mismatch, a zero weight
                sum, or an evidence vector of the wrong shape.
        """
        if not evidences:
            return np.ones(self.num_classes) / self.num_classes

        if weights is None:
            weights = [1.0 / len(evidences)] * len(evidences)
        else:
            if len(weights) != len(evidences):
                raise ValueError("权重数量必须与证据数量匹配。")
            weight_total = np.sum(weights)
            if weight_total == 0:
                # Guard added: the original divided by zero here.
                raise ValueError("权重之和不能为零。")
            weights = np.array(weights) / weight_total

        combined_belief = np.zeros(self.num_classes)
        for weight, evidence in zip(weights, evidences):
            evidence = np.asarray(evidence)
            if evidence.shape != (self.num_classes,):
                raise ValueError(f"证据形状不匹配。期望 ({self.num_classes},),得到 {evidence.shape}")
            combined_belief += weight * evidence
        return combined_belief


class FaultDiagnosisSystem:
    """CNN + LSTM + electrical MLP + rule engine, fused into one prediction."""

    def __init__(self, cnn_input_len=VIBRATION_RAW_SEQ_LEN, lstm_feat_size=LSTM_FEATURE_SIZE,
                 electrical_feat_size=ELECTRICAL_FEATURE_SIZE):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.data_processor = DataProcessor()
        self.num_classes = len(self.data_processor.fault_types)
        self.cnn_model = CNNModel(num_classes=self.num_classes,
                                  example_input_len=cnn_input_len).to(self.device)
        self.lstm_model = LSTMModel(input_size=lstm_feat_size,
                                    num_classes=self.num_classes).to(self.device)
        self.electrical_model = nn.Sequential(
            nn.Linear(electrical_feat_size, 32), nn.ReLU(),
            nn.Linear(32, self.num_classes),
        ).to(self.device)
        self.fusion = DSEvidenceFusion(num_classes=self.num_classes)
        # Fusion weights, in order: CNN, LSTM, electrical model, rule engine.
        self.model_weights = [0.4, 0.3, 0.2, 0.1]
        self.design_pressure_default = 100.0

    def load_models(self, cnn_path=BEST_CNN_PATH, lstm_path=BEST_LSTM_PATH,
                    electrical_path=BEST_ELECTRICAL_PATH, scalers_path=SCALER_SAVE_PATH):
        """Load the three network state dicts, switch to eval mode, then load scalers.

        Any error while loading the networks is reported and re-raised.
        """
        try:
            self.cnn_model.load_state_dict(torch.load(cnn_path, map_location=self.device))
            self.lstm_model.load_state_dict(torch.load(lstm_path, map_location=self.device))
            self.electrical_model.load_state_dict(torch.load(electrical_path, map_location=self.device))
            print("神经网络模型加载成功。")
        except Exception as e:
            print(f"加载神经网络模型时出错: {e}。请确保路径正确且模型匹配。")
            raise
        self.cnn_model.eval()
        self.lstm_model.eval()
        self.electrical_model.eval()
        if scalers_path:
            self.data_processor.load_scalers(scalers_path)
        else:
            print("警告:未提供缩放器路径。除非单独拟合/加载缩放器,否则特征不会被归一化。")

    def save_models(self, cnn_path=BEST_CNN_PATH, lstm_path=BEST_LSTM_PATH,
                    electrical_path=BEST_ELECTRICAL_PATH, scalers_path=SCALER_SAVE_PATH):
        """Save the three network state dicts and, if a path is given, the scalers."""
        torch.save(self.cnn_model.state_dict(), cnn_path)
        torch.save(self.lstm_model.state_dict(), lstm_path)
        torch.save(self.electrical_model.state_dict(), electrical_path)
        print(f"神经网络模型已保存至 {cnn_path}, {lstm_path}, {electrical_path}")
        if scalers_path:
            self.data_processor.save_scalers(scalers_path)

    def _rule_engine_predict(self, **kwargs) -> np.ndarray:
        """Score faults with fixed threshold rules and return normalised probabilities.

        Accepted keyword arguments (each defaults to zeros of the matching
        length): ``vibration_features``, ``temperature_features``,
        ``pressure_features``, ``blade_angle_features``, ``oil_features``,
        ``current_features``, plus ``design_pressure``.

        The thresholds (e.g. 85 for temperature, ``1.1 * design_pressure``)
        are absolute values, so callers should pass un-normalised features.
        """
        fault_probs = np.zeros(self.num_classes)
        fault_probs[0] = 0.1  # small prior on the "normal" class

        v_feats = kwargs.get('vibration_features', np.zeros(8))
        t_feats = kwargs.get('temperature_features', np.zeros(5))
        p_feats = kwargs.get('pressure_features', np.zeros(4))
        b_feats = kwargs.get('blade_angle_features', np.zeros(5))
        o_feats = kwargs.get('oil_features', np.zeros(2))
        c_feats = kwargs.get('current_features', np.zeros(2))
        design_pressure = kwargs.get('design_pressure', self.design_pressure_default)

        def raise_belief(fault_name: str, level: float) -> None:
            """Lift the score of ``fault_name`` to at least ``level``."""
            idx = self.data_processor.fault_types_str_to_int(fault_name)
            fault_probs[idx] = max(fault_probs[idx], level)

        if v_feats[2] > 5:
            raise_belief("轴承失效", 0.7)
        if v_feats[7] > 0.1:
            raise_belief("基础松动", 0.6)
        if t_feats[0] > 85:
            raise_belief("轴承失效", 0.7)
        if b_feats[2] > 5:
            raise_belief("动叶卡涩", 0.8)
        if p_feats[3] > 0.2:
            raise_belief("喘振", 0.7)
        if p_feats[0] > 1.1 * design_pressure:
            raise_belief("积灰堵塞", 0.6)
        if c_feats[0] > 0.1:
            raise_belief("电机绕组故障", 0.7)
        if c_feats[1] > 0.1:
            raise_belief("电机绕组故障", 0.6)
        if o_feats[0] > 50:
            raise_belief("轴承失效", 0.5)
            raise_belief("密封失效", 0.4)
        if o_feats[1] < 70:
            raise_belief("轴承失效", 0.4)
            raise_belief("密封失效", 0.3)

        return self._normalize_prob(fault_probs)

    def _normalize_prob(self, probs: np.ndarray) -> np.ndarray:
        """Scale ``probs`` to sum to 1; an all-zero vector becomes uniform."""
        total = np.sum(probs)
        if total == 0:
            return np.ones_like(probs) / len(probs)
        return probs / total

    def predict(self, vibration_raw: np.ndarray, temperature_series: np.ndarray,
                pressure_series: np.ndarray, blade_angle_series: np.ndarray,
                oil_particles_val: float, oil_viscosity_val: float,
                current_signal: np.ndarray, target_blade_angle: float,
                design_pressure: float,
                vibration_sampling_rate: int = VIBRATION_SAMPLING_RATE,
                pressure_sampling_rate: int = PRESSURE_SAMPLING_RATE) -> Dict[str, float]:
        """Run all four predictors on one sample and return fused probabilities.

        Returns:
            Mapping from fault name to fused probability, in class-index order.

        Raises:
            RuntimeError: a required scaler has not been fitted or loaded.
            ValueError: invalid vibration input or unexpected feature sizes.
        """
        global predicting_now
        predicting_now = True
        try:
            dp = self.data_processor
            vibration_features = dp.process_vibration(vibration_raw, vibration_sampling_rate)
            temperature_features = dp.process_temperature(temperature_series)
            pressure_features = dp.process_pressure(pressure_series, pressure_sampling_rate)
            blade_angle_features = dp.process_blade_angle(blade_angle_series, target_blade_angle)
            oil_features = dp.process_oil_analysis(oil_particles_val, oil_viscosity_val)
            current_features = dp.process_current(current_signal)

            # Every scaler is exercised so that a missing fit is reported
            # before any model runs, exactly as in the original.
            dp.normalize_features(vibration_features, 'vibration_features')
            norm_temperature_f = dp.normalize_features(temperature_features, 'temperature')
            norm_pressure_f = dp.normalize_features(pressure_features, 'pressure')
            dp.normalize_features(blade_angle_features, 'blade_angle')
            dp.normalize_features(oil_features, 'oil_particles')
            norm_current_f = dp.normalize_features(current_features, 'current')

            # The CNN consumes the raw signal, zero-padded or truncated to a fixed length.
            if len(vibration_raw) < VIBRATION_RAW_SEQ_LEN:
                vibration_padded = np.pad(
                    vibration_raw, (0, VIBRATION_RAW_SEQ_LEN - len(vibration_raw)), 'constant')
            else:
                vibration_padded = vibration_raw[:VIBRATION_RAW_SEQ_LEN]
            cnn_input = torch.tensor(vibration_padded.reshape(1, 1, -1),
                                     dtype=torch.float32).to(self.device)

            lstm_input_feats_combined = np.concatenate(
                [norm_temperature_f, norm_pressure_f, norm_current_f[:1]])
            if lstm_input_feats_combined.shape[0] != LSTM_FEATURE_SIZE:
                raise ValueError(f"LSTM 输入特征大小不匹配。期望 {LSTM_FEATURE_SIZE}, 得到 {lstm_input_feats_combined.shape[0]}")
            # NOTE(review): this feeds a length-1 sequence, whereas the demo
            # dataset tiles the step LSTM_SEQ_LEN times for training.
            lstm_input = torch.tensor(lstm_input_feats_combined.reshape(1, 1, -1),
                                      dtype=torch.float32).to(self.device)

            if norm_current_f.shape[0] != ELECTRICAL_FEATURE_SIZE:
                raise ValueError(f"电气模型输入特征大小不匹配。期望 {ELECTRICAL_FEATURE_SIZE}, 得到 {norm_current_f.shape[0]}")
            electrical_input = torch.tensor(norm_current_f.reshape(1, -1),
                                            dtype=torch.float32).to(self.device)

            with torch.no_grad():
                cnn_output_probs = torch.softmax(
                    self.cnn_model(cnn_input), dim=1).cpu().numpy()[0]
                lstm_output_probs = torch.softmax(
                    self.lstm_model(lstm_input), dim=1).cpu().numpy()[0]
                electrical_output_probs = torch.softmax(
                    self.electrical_model(electrical_input), dim=1).cpu().numpy()[0]

            # Fix: the rule thresholds are absolute values, so the rule engine
            # gets the raw features; z-scored inputs made them meaningless.
            rule_output_probs = self._rule_engine_predict(
                vibration_features=vibration_features,
                temperature_features=temperature_features,
                pressure_features=pressure_features,
                blade_angle_features=blade_angle_features,
                oil_features=oil_features,
                current_features=current_features,
                design_pressure=design_pressure,
            )

            evidences = [cnn_output_probs, lstm_output_probs,
                         electrical_output_probs, rule_output_probs]
            fused_probs = self.fusion.combine_evidence(evidences, self.model_weights)
            return {self.data_processor.fault_types[i]: float(fused_probs[i])
                    for i in range(len(fused_probs))}
        finally:
            # Always clear the flag, even when prediction raised.
            predicting_now = False

    @staticmethod
    def _train_step(model, optimizer, criterion, inputs, labels) -> float:
        """Run one optimisation step for ``model`` and return the batch loss."""
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        return loss.item()

    def _modality_probs(self, model, data_dict, key: str, batch_size: int,
                        add_channel_dim: bool = False) -> np.ndarray:
        """Softmax output of ``model`` on ``data_dict[key]``; zeros if the key is missing."""
        if key not in data_dict:
            print(f"验证中缺少 '{key}'")
            return np.zeros((batch_size, self.num_classes), dtype=np.float32)
        inputs = data_dict[key]
        if add_channel_dim:
            inputs = inputs.unsqueeze(1)
        return torch.softmax(model(inputs.to(self.device)), dim=1).cpu().numpy()

    def _validate(self, val_loader) -> Tuple[List[int], List[int]]:
        """Return ``(labels, fused predictions)`` over the validation loader."""
        self.cnn_model.eval()
        self.lstm_model.eval()
        self.electrical_model.eval()
        all_labels: List[int] = []
        all_preds: List[int] = []
        with torch.no_grad():
            for batch_val_content in val_loader:
                if not isinstance(batch_val_content, (tuple, list)) or len(batch_val_content) != 2:
                    print("警告: val_loader 的批次内容格式不符合预期。跳过验证批次。")
                    continue
                data_dict_val, labels_batch_val = batch_val_content
                if not isinstance(data_dict_val, dict) or not isinstance(labels_batch_val, torch.Tensor):
                    print("警告: val_loader 的 data_dict 或 labels 格式不符合预期。跳过验证批次。")
                    continue
                batch_size_val = labels_batch_val.size(0)

                cnn_probs_b = self._modality_probs(
                    self.cnn_model, data_dict_val, 'vibration_raw',
                    batch_size_val, add_channel_dim=True)
                lstm_probs_b = self._modality_probs(
                    self.lstm_model, data_dict_val, 'lstm_input_feature_sequence', batch_size_val)
                electrical_probs_b = self._modality_probs(
                    self.electrical_model, data_dict_val, 'electrical_features', batch_size_val)

                for i in range(batch_size_val):
                    # TODO: extract the raw (un-normalised) features of sample
                    # i and call self._rule_engine_predict(...) with them.
                    # A random placeholder stands in for the rule output here.
                    rule_output_probs_sample_i = self._normalize_prob(
                        np.random.rand(self.num_classes))
                    fused_probs_sample_i = self.fusion.combine_evidence(
                        [cnn_probs_b[i], lstm_probs_b[i], electrical_probs_b[i],
                         rule_output_probs_sample_i],
                        self.model_weights)
                    all_preds.append(int(np.argmax(fused_probs_sample_i)))
                all_labels.extend(labels_batch_val.tolist())
        return all_labels, all_preds

    def train(self, train_loader, val_loader, epochs=10, lr=0.001):
        """Train the three networks independently and validate the fused output.

        Each loader must yield ``(data_dict, labels)`` where ``data_dict`` may
        contain ``'vibration_raw'``, ``'lstm_input_feature_sequence'`` and
        ``'electrical_features'`` tensors. A model whose key is missing is
        skipped for that batch. The models are saved to the default paths
        whenever the fused validation accuracy improves.
        """
        if not all(self.data_processor.fitted_scalers.get(s_type)
                   for s_type in self.data_processor.scalers):
            # A stricter implementation could raise RuntimeError here instead.
            print("警告:并非所有缩放器都已标记为已拟合。请确保在有代表性的训练数据上调用了 fit_scalers。")

        cnn_optimizer = torch.optim.Adam(self.cnn_model.parameters(), lr=lr)
        lstm_optimizer = torch.optim.Adam(self.lstm_model.parameters(), lr=lr)
        electrical_optimizer = torch.optim.Adam(self.electrical_model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        best_val_accuracy = 0.0
        num_train_batches = len(train_loader)
        class_indices = list(range(self.num_classes))
        class_names = [self.data_processor.fault_types[i] for i in class_indices]

        for epoch in range(epochs):
            self.cnn_model.train()
            self.lstm_model.train()
            self.electrical_model.train()
            total_train_loss_cnn, total_train_loss_lstm, total_train_loss_elec = 0, 0, 0

            for batch_idx, batch_content in enumerate(train_loader):
                if not isinstance(batch_content, (tuple, list)) or len(batch_content) != 2:
                    print(f"警告: train_loader 的批次内容格式不符合预期。跳过批次 {batch_idx}。")
                    continue
                data_dict, labels_batch_cpu = batch_content
                if not isinstance(data_dict, dict) or not isinstance(labels_batch_cpu, torch.Tensor):
                    print(f"警告: train_loader 的 data_dict 或 labels 格式不符合预期。跳过批次 {batch_idx}。")
                    continue
                labels_batch = labels_batch_cpu.to(self.device)

                if 'vibration_raw' not in data_dict:
                    print(f"警告: 批次 {batch_idx} 中缺少 'vibration_raw'。跳过CNN训练。")
                else:
                    # Add the channel dimension expected by Conv1d.
                    cnn_input_batch = data_dict['vibration_raw'].unsqueeze(1).to(self.device)
                    total_train_loss_cnn += self._train_step(
                        self.cnn_model, cnn_optimizer, criterion, cnn_input_batch, labels_batch)

                if 'lstm_input_feature_sequence' not in data_dict:
                    print(f"警告: 批次 {batch_idx} 中缺少 'lstm_input_feature_sequence'。跳过LSTM训练。")
                else:
                    lstm_input_batch = data_dict['lstm_input_feature_sequence'].to(self.device)
                    total_train_loss_lstm += self._train_step(
                        self.lstm_model, lstm_optimizer, criterion, lstm_input_batch, labels_batch)

                if 'electrical_features' not in data_dict:
                    print(f"警告: 批次 {batch_idx} 中缺少 'electrical_features'。跳过Electrical模型训练。")
                else:
                    electrical_input_batch = data_dict['electrical_features'].to(self.device)
                    total_train_loss_elec += self._train_step(
                        self.electrical_model, electrical_optimizer, criterion,
                        electrical_input_batch, labels_batch)

                if batch_idx > 0 and batch_idx % 50 == 0:
                    # batch_idx is zero-based, so batch_idx + 1 batches have run.
                    seen = batch_idx + 1
                    print(f"Epoch {epoch+1}, Batch {batch_idx}/{num_train_batches} | "
                          f"CNN Loss: {total_train_loss_cnn/seen:.4f} | "
                          f"LSTM Loss: {total_train_loss_lstm/seen:.4f} | "
                          f"Elec Loss: {total_train_loss_elec/seen:.4f}")

            avg_loss_cnn = total_train_loss_cnn / num_train_batches if num_train_batches > 0 else 0
            avg_loss_lstm = total_train_loss_lstm / num_train_batches if num_train_batches > 0 else 0
            avg_loss_elec = total_train_loss_elec / num_train_batches if num_train_batches > 0 else 0
            print(f"Epoch {epoch+1}/{epochs} Training Complete. Avg Losses: CNN={avg_loss_cnn:.4f}, LSTM={avg_loss_lstm:.4f}, Elec={avg_loss_elec:.4f}")

            # --- Validation phase ---
            all_val_labels_list, all_val_preds_fused_indices = self._validate(val_loader)
            if all_val_labels_list:
                val_accuracy = accuracy_score(all_val_labels_list, all_val_preds_fused_indices)
                print(f"Epoch {epoch+1}/{epochs}, 验证准确率 (融合后): {val_accuracy:.4f}")
                # labels= keeps target_names aligned even when some classes
                # are absent from this validation set.
                print(classification_report(
                    all_val_labels_list, all_val_preds_fused_indices,
                    labels=class_indices, target_names=class_names, zero_division=0))
                if val_accuracy > best_val_accuracy:
                    best_val_accuracy = val_accuracy
                    self.save_models()  # default paths
                    print(f"新最佳模型已保存,验证准确率: {best_val_accuracy:.4f}")
            else:
                print(f"Epoch {epoch+1}/{epochs}, 验证: 未处理数据。")

        print(f"训练完成。最佳验证准确率: {best_val_accuracy:.4f}")


def get_fault_type_int(fault_types_dict, fault_name_str):
    """Return the key in ``fault_types_dict`` whose value equals ``fault_name_str``.

    Raises:
        ValueError: if no entry has that name.
    """
    for i, name in fault_types_dict.items():
        if name == fault_name_str:
            return i
    raise ValueError(f"故障名称 '{fault_name_str}' 不在 fault_types 中。")


def main():
    """Demo: fit scalers on synthetic data, train briefly, reload, and predict."""
    system = FaultDiagnosisSystem()
    print(f"系统已初始化。设备: {system.device}")
    print(f"故障类别数量: {system.num_classes}")

    print("\n--- (占位符) 准备数据并拟合缩放器 ---")
    num_training_samples = 100  # larger sample count gives a better scaler fit

    # --- Synthetic raw sensor readings, one dict per sample ---
    raw_training_data_for_scalers = []
    for _ in range(num_training_samples):
        sample_data = {
            # variable-length vibration record
            'vibration_raw_unprocessed': np.random.normal(
                0, 1, VIBRATION_RAW_SEQ_LEN + np.random.randint(-100, 100)),
            'temperature_series_unprocessed': np.random.normal(60, 5, np.random.randint(5, 15)),
            'pressure_series_unprocessed': np.random.normal(100, 10, np.random.randint(50, 150)),
            'blade_angle_series_unprocessed': np.random.normal(45, 2, np.random.randint(10, 30)),
            'oil_particles_val_unprocessed': np.random.uniform(10, 100),
            'oil_viscosity_val_unprocessed': np.random.uniform(60, 90),
            # assumed to be three phase-current values
            'current_signal_unprocessed': np.random.normal(50, 5, 3),
            'target_blade_angle_unprocessed': 45.0,
            'design_pressure_unprocessed': 100.0,
        }
        raw_training_data_for_scalers.append(sample_data)

    # Extract features with the same methods used at train/predict time, so
    # the scalers see the distribution they will later transform.
    dp = system.data_processor
    features_for_scaler_fitting = {
        'vibration_features': [], 'temperature': [], 'pressure': [],
        'blade_angle': [], 'oil_particles': [], 'current': [],
    }
    for sample_raw_data in raw_training_data_for_scalers:
        features_for_scaler_fitting['vibration_features'].append(
            # truncate to the fixed CNN length
            dp.process_vibration(
                sample_raw_data['vibration_raw_unprocessed'][:VIBRATION_RAW_SEQ_LEN]))
        features_for_scaler_fitting['temperature'].append(
            dp.process_temperature(sample_raw_data['temperature_series_unprocessed']))
        features_for_scaler_fitting['pressure'].append(
            dp.process_pressure(sample_raw_data['pressure_series_unprocessed'],
                                fs=PRESSURE_SAMPLING_RATE))
        features_for_scaler_fitting['blade_angle'].append(
            dp.process_blade_angle(sample_raw_data['blade_angle_series_unprocessed'],
                                   sample_raw_data['target_blade_angle_unprocessed']))
        # Note: the 'oil_particles' scaler is fitted on [particles, viscosity].
        features_for_scaler_fitting['oil_particles'].append(
            dp.process_oil_analysis(sample_raw_data['oil_particles_val_unprocessed'],
                                    sample_raw_data['oil_viscosity_val_unprocessed']))
        features_for_scaler_fitting['current'].append(
            dp.process_current(sample_raw_data['current_signal_unprocessed']))

    dp.fit_scalers(features_for_scaler_fitting)
    dp.save_scalers()

    print("\n--- (占位符) 训练模型 ---")

    class AdvancedDummyDataset(torch.utils.data.Dataset):
        """Synthetic dataset yielding ``(model_inputs_dict, label)`` with random labels."""

        def __init__(self, num_samples, num_classes, data_processor_ref: DataProcessor,
                     for_scaler_data):
            self.num_samples = num_samples
            self.num_classes = num_classes
            # Reference to the DataProcessor that holds the fitted scalers.
            self.data_processor = data_processor_ref
            # Reuse the raw data generated for scaler fitting.
            self.raw_sensor_data_list = for_scaler_data
            self.labels = torch.randint(0, num_classes, (num_samples,))

        def __len__(self):
            return self.num_samples

        def __getitem__(self, idx):
            # Wrap around when num_samples exceeds the number of raw samples.
            raw_sample_data = self.raw_sensor_data_list[idx % len(self.raw_sensor_data_list)]
            processor = self.data_processor

            # 1. Raw vibration for the CNN, zero-padded or truncated.
            vib_raw_unpr = raw_sample_data['vibration_raw_unprocessed']
            if len(vib_raw_unpr) < VIBRATION_RAW_SEQ_LEN:
                cnn_vib_input_np = np.pad(
                    vib_raw_unpr, (0, VIBRATION_RAW_SEQ_LEN - len(vib_raw_unpr)), 'constant')
            else:
                cnn_vib_input_np = vib_raw_unpr[:VIBRATION_RAW_SEQ_LEN]

            # 2. LSTM feature sequence. For simplicity a single time step is
            # computed and repeated; real data needs true temporal features.
            temp_f = processor.normalize_features(
                processor.process_temperature(raw_sample_data['temperature_series_unprocessed']),
                'temperature')
            pres_f = processor.normalize_features(
                processor.process_pressure(raw_sample_data['pressure_series_unprocessed'],
                                           fs=PRESSURE_SAMPLING_RATE),
                'pressure')
            curr_f_all = processor.normalize_features(
                processor.process_current(raw_sample_data['current_signal_unprocessed']),
                'current')
            lstm_single_step_features_np = np.concatenate(
                [temp_f, pres_f, curr_f_all[:1]])  # 5 + 4 + 1 = 10
            lstm_feature_sequence_np = np.tile(lstm_single_step_features_np, (LSTM_SEQ_LEN, 1))

            # 3. The electrical model uses the full current feature vector.
            electrical_features_np = curr_f_all

            processed_data_dict = {
                'vibration_raw': torch.tensor(cnn_vib_input_np, dtype=torch.float32),
                'lstm_input_feature_sequence': torch.tensor(lstm_feature_sequence_np,
                                                            dtype=torch.float32),
                'electrical_features': torch.tensor(electrical_features_np, dtype=torch.float32),
            }
            return processed_data_dict, self.labels[idx]

    # The dataset receives the DataProcessor because it holds the fitted scalers.
    train_dataset = AdvancedDummyDataset(
        num_training_samples, system.num_classes, system.data_processor,
        raw_training_data_for_scalers)
    # Ideally validation would use an independent set of raw data.
    val_dataset = AdvancedDummyDataset(
        num_training_samples // 2, system.num_classes, system.data_processor,
        raw_training_data_for_scalers[:num_training_samples // 2])

    # num_workers=0 for simplicity
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=16 if num_training_samples >= 16 else 1,
        shuffle=True, num_workers=0)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=16 if num_training_samples // 2 >= 16 else 1,
        num_workers=0)

    print("\n--- 开始实际训练(使用虚拟数据) ---")
    system.train(train_loader, val_loader, epochs=3)  # a few epochs as a demo

    print("\n--- 加载最佳模型 (如果训练中保存了) 并进行预测 ---")
    system_for_prediction = FaultDiagnosisSystem()
    try:
        system_for_prediction.load_models()  # default paths
    except Exception as e:
        print(f"无法加载已训练的模型 (可能是因为验证准确率未提高,未保存最佳模型): {e}")
        print("为了预测结构演示,继续使用新初始化的模型 (需要加载scaler)。")
        system_for_prediction.data_processor.load_scalers()  # make sure scalers are loaded

    print("\n--- 对新样本数据进行预测 ---")
    new_vibration_raw = np.random.normal(0.5, 0.2, VIBRATION_RAW_SEQ_LEN - 50)  # exercises padding
    new_temperature_series = np.array([70, 72, 130, 73, 74.5])
    new_pressure_series = np.random.normal(105, 12, 200)
    new_blade_angle_series = np.array([44, 44.5, 45, 45.1, 44.8])
    new_oil_particles = 55.0
    new_oil_viscosity = 78.0
    new_current_signal = np.array([51.0, 48.5, 50.5])
    target_angle_setting = 45.0
    current_design_pressure = 100.0

    try:
        prediction_result = system_for_prediction.predict(
            vibration_raw=new_vibration_raw,
            temperature_series=new_temperature_series,
            pressure_series=new_pressure_series,
            blade_angle_series=new_blade_angle_series,
            oil_particles_val=new_oil_particles,
            oil_viscosity_val=new_oil_viscosity,
            current_signal=new_current_signal,
            target_blade_angle=target_angle_setting,
            design_pressure=current_design_pressure,
        )
        print("\n预测的故障概率:")
        for fault_name, probability in sorted(prediction_result.items(),
                                              key=lambda item: -item[1]):
            if probability > 0.001:
                print(f" {fault_name}: {probability:.4f}")
        # The result dict is in class-index order, so argmax is a class index.
        predicted_fault_idx = int(np.argmax(list(prediction_result.values())))
        print(f"主要预测故障: {system_for_prediction.data_processor.fault_types[predicted_fault_idx]}")
    except RuntimeError as e:
        if "的缩放器必须在预测前拟合" in str(e):
            print(f"预测失败: {e}。如果缩放器未正确拟合/加载,这是预期的。")
        else:
            print(f"预测期间发生运行时错误: {e}")
    except Exception:
        import traceback
        print(f"预测期间发生意外错误:\n{traceback.format_exc()}")


if __name__ == "__main__":
    main()