基于HetEmotionNet框架的多模态情绪识别系统
下面设计一个基于HetEmotionNet框架的多模态情绪识别系统,并给出分步实现方案:
一、系统架构设计
class TwoStreamHetGNN(nn.Module):
    """Heterogeneous-graph network for multimodal emotion recognition.

    Pipeline: per-modality encoders -> heterogeneous graph construction ->
    graph-level recurrent processing -> multimodal fusion -> classifier.

    Args:
        modalities: Modalities to encode; each must be ``'eeg'`` or
            ``'video'``. Only encoders for the listed modalities are built and
            run, and ``forward`` only reads those keys from ``inputs``.
        graph_recurrent: Optional module applied to the graph returned by the
            graph builder. Defaults to ``nn.Identity()`` (pass-through) so the
            network is runnable until a real recurrent graph module is given.

    Raises:
        ValueError: If ``modalities`` names an unsupported modality.
    """

    def __init__(self, modalities=('eeg', 'video'), graph_recurrent=None):
        super().__init__()
        # Copy into a fresh list: the previous default was a shared mutable
        # list, and a caller's sequence should not be aliased.
        self.modalities = list(modalities)
        encoder_factories = {'eeg': EEGEncoder, 'video': VideoEncoder}
        unsupported = [mod for mod in self.modalities if mod not in encoder_factories]
        if unsupported:
            raise ValueError(f'Unsupported modalities: {unsupported}')
        # Modality-specific encoders, only for the requested modalities.
        self.encoders = nn.ModuleDict(
            {mod: encoder_factories[mod]() for mod in self.modalities}
        )
        # Heterogeneous graph construction.
        self.graph_builder = HetGraphConstructor()
        # Recurrent graph processing. forward() always calls this, so it must
        # exist; previously it was never assigned (AttributeError).
        self.graph_recurrent = (
            graph_recurrent if graph_recurrent is not None else nn.Identity()
        )
        # Multimodal fusion.
        self.fusion = MultiModalFusion()
        # Emotion classifier.
        self.classifier = EmotionClassifier()

    def forward(self, inputs):
        """Classify one batch.

        Args:
            inputs: Mapping from modality name to that modality's input
                tensor; must contain every name in ``self.modalities``.

        Returns:
            Whatever ``self.classifier`` returns for the fused representation.
        """
        # Per-modality feature extraction.
        features = {mod: encoder(inputs[mod]) for mod, encoder in self.encoders.items()}
        # Build the heterogeneous graph.
        graph = self.graph_builder(features)
        # Recurrent processing over the graph.
        graph = self.graph_recurrent(graph)
        # Multimodal fusion.
        fused = self.fusion(graph)
        # Emotion classification.
        return self.classifier(fused)
二、关键组件实现
- 时空动态卷积网络(STDCN)
class STDCN(nn.Module):
    """Spatio-temporal 3-D convolution followed by an LSTM over time.

    The 3-D convolution extracts local spatio-temporal features. The spatial
    axes are then average-pooled, so each time step becomes a vector of
    ``out_channels`` features. That matches the LSTM's input size for any
    input height/width. The LSTM outputs are mean-pooled over time.

    Args:
        in_channels: Channels of the input volume.
        out_channels: Channels produced by the 3-D convolution.
        hidden_size: LSTM hidden size, i.e. the output feature size
            (default 256, as before).

    Shapes:
        input:  ``[B, in_channels, T, H, W]``
        output: ``[B, hidden_size]``
    """

    def __init__(self, in_channels, out_channels, hidden_size=256):
        super().__init__()
        # kernel 3 with padding 1 keeps T, H and W unchanged.
        self.conv3d = nn.Conv3d(
            in_channels, out_channels, kernel_size=(3, 3, 3), padding=(1, 1, 1)
        )
        self.lstm = nn.LSTM(out_channels, hidden_size, batch_first=True)

    def forward(self, x):
        x = self.conv3d(x)  # [B, C', T, H, W]
        # Pool over H and W. Flattening C'*H*W (the old behaviour) gave a
        # feature size that only equalled the LSTM input size (C') when
        # H == W == 1. For that case this is numerically identical.
        x = x.mean(dim=(3, 4))  # [B, C', T]
        x = x.transpose(1, 2).contiguous()  # [B, T, C']
        x, _ = self.lstm(x)  # [B, T, hidden_size]
        return x.mean(dim=1)  # mean-pool over time
- 图Transformer块(GTblock)
class GTblock(nn.Module):
    """Graph-transformer block: graph attention then feed-forward, post-norm.

    Each sub-layer sits inside a residual connection followed by LayerNorm::

        x = norm1(x + attn(x, adj))
        x = norm2(x + mlp(x))

    The feed-forward network expands to ``2 * in_dim`` and projects back, so
    the node feature size is preserved.

    Args:
        in_dim: Node feature size.
        n_heads: Passed to ``GAT`` as its second positional argument.

    NOTE(review): ``GAT`` is defined outside this snippet. If it is
    ``torch_geometric.nn.models.GAT``, the call ``GAT(in_dim, n_heads)`` does
    not match its signature (``in_channels, hidden_channels, num_layers, ...``)
    and its forward expects ``edge_index`` rather than a dense ``adj`` --
    confirm which implementation is intended.
    """

    def __init__(self, in_dim, n_heads):
        super().__init__()
        expanded_dim = 2 * in_dim
        # Sub-modules are created in the same order as before (attention
        # first) so parameter initialisation consumes RNG identically.
        self.attn = GAT(in_dim, n_heads)
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, expanded_dim),
            nn.ReLU(),
            nn.Linear(expanded_dim, in_dim),
        )
        self.norm1 = nn.LayerNorm(in_dim)
        self.norm2 = nn.LayerNorm(in_dim)

    def forward(self, x, adj):
        attn_residual = x + self.attn(x, adj)
        h = self.norm1(attn_residual)
        ffn_residual = h + self.mlp(h)
        return self.norm2(ffn_residual)
- 异构图构建模块
class HetGraphConstructor(nn.Module):
    """Build a heterogeneous graph from per-modality feature tensors.

    Each modality becomes a node type. Its node features are the modality's
    features linearly projected to a shared ``hidden_dim``. For every ordered
    pair of distinct modalities ``(src, dst)``, an ``(src, 'inter', dst)``
    edge type connects every ``src`` node to every ``dst`` node.

    Args:
        in_dims: Mapping modality -> input feature size. Defaults to
            ``{'eeg': 128, 'video': 256}`` (the previous hard-coded values).
        hidden_dim: Shared projected feature size (default 64).
        modalities: Optional order in which modalities are used. Defaults to
            the key order of the ``features`` dict passed to ``forward``.
    """

    def __init__(self, in_dims=None, hidden_dim=64, modalities=None):
        super().__init__()
        if in_dims is None:
            in_dims = {'eeg': 128, 'video': 256}
        # forward() needs a modality list; previously self.modalities was read
        # but never assigned (AttributeError).
        self.modalities = list(modalities) if modalities is not None else None
        self.modal_proj = nn.ModuleDict(
            {mod: nn.Linear(dim, hidden_dim) for mod, dim in in_dims.items()}
        )

    def forward(self, features):
        """Project features into node sets and connect modalities pairwise.

        Args:
            features: Dict modality -> tensor whose first dimension indexes
                that modality's nodes and whose last dimension equals
                ``in_dims[modality]``. Presumably ``[num_nodes, in_dim]`` --
                TODO confirm against the encoders' output shapes.

        Returns:
            ``HeteroData`` with ``data[mod].x`` for each modality and, for each
            ordered pair of distinct modalities, a fully connected bipartite
            ``data[src, 'inter', dst].edge_index`` of shape ``[2, n_src * n_dst]``
            and dtype long.
        """
        modalities = self.modalities if self.modalities is not None else list(features)
        data = HeteroData()

        # Intra-modal nodes.
        for mod in modalities:
            data[mod].x = self.modal_proj[mod](features[mod])

        # Inter-modal edges. These are real index pairs rather than the random
        # dense matrices used before, which were not valid edge_index tensors
        # and changed on every call.
        for src in modalities:
            for dst in modalities:
                if src == dst:
                    continue
                n_src = data[src].x.size(0)
                n_dst = data[dst].x.size(0)
                device = data[src].x.device
                src_idx = torch.arange(n_src, device=device).repeat_interleave(n_dst)
                dst_idx = torch.arange(n_dst, device=device).repeat(n_src)
                data[src, 'inter', dst].edge_index = torch.stack([src_idx, dst_idx])
        return data
三、训练策略
def train(net=None, train_loader=None, epochs=100, lr=1e-4, log_every=100):
    """Train the emotion-recognition model with Adam and cross-entropy.

    Args:
        net: Model to train. Defaults to the module-level ``model``.
        train_loader: Iterable of batch dicts with ``'eeg'``, ``'video'`` and
            ``'label'`` entries. Defaults to a shuffled ``DataLoader`` over
            ``DEAPDataset()`` with batch size 32.
        epochs: Number of passes over ``train_loader`` (default 100).
        lr: Adam learning rate (default 1e-4).
        log_every: Print the loss every ``log_every`` batches (default 100).
    """
    if net is None:
        net = model  # module-level model, as the original script used
    if train_loader is None:
        # Shuffle so consecutive batches are not always the same samples.
        train_loader = DataLoader(DEAPDataset(), batch_size=32, shuffle=True)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    # test() puts the model in eval mode; re-enable dropout/batch-norm updates.
    net.train()

    for epoch in range(epochs):
        # enumerate() provides batch_idx, which was previously undefined.
        for batch_idx, batch in enumerate(train_loader):
            # Multimodal preprocessing. .float() because np.load'ed data may
            # be float64 or uint8, while conv/linear layers default to float32.
            inputs = {
                # DEAPDataset documents EEG samples as [40, 8064] = [C, T], so
                # a batch is already [B, C, T]; the old permute(0, 2, 1) turned
                # it into [B, T, C].
                'eeg': batch['eeg'].float(),
                # [B, T, H, W, C] -> [B, C, T, H, W]
                'video': batch['video'].permute(0, 4, 1, 2, 3).float(),
            }
            # Forward pass.
            outputs = net(inputs)
            # NOTE(review): DEAPDataset documents labels as [N, 4] ratings,
            # while CrossEntropyLoss expects class indices of shape [B] (or
            # per-class probabilities). Confirm labels are discretised upstream.
            loss = criterion(outputs, batch['label'])
            # Backward pass.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Monitoring.
            if batch_idx % log_every == 0:
                print(f'Epoch: {epoch}, Loss: {loss.item()}')
四、数据集适配
- DEAP数据集处理
class DEAPDataset(Dataset):
    """Map-style dataset of per-trial EEG, video and label arrays from ``.npy``.

    Args:
        eeg_path: EEG array file, documented as ``[N, 40, 8064]``.
        video_path: Video array file, documented as ``[N, 300, 224, 224, 3]``.
        label_path: Label array file, documented as ``[N, 4]``.
        mmap_mode: Forwarded to ``np.load``. ``None`` (default) reads the
            arrays fully into memory, as before. ``'r'`` memory-maps them,
            which helps because the video array holds 300*224*224*3 values per
            sample.

    Raises:
        ValueError: If the three arrays disagree on the number of samples N.
    """

    def __init__(self, eeg_path='deap_eeg.npy', video_path='deap_video.npy',
                 label_path='deap_labels.npy', mmap_mode=None):
        super().__init__()
        self.eeg_data = np.load(eeg_path, mmap_mode=mmap_mode)  # [N, 40, 8064]
        self.video_data = np.load(video_path, mmap_mode=mmap_mode)  # [N, 300, 224, 224, 3]
        self.labels = np.load(label_path, mmap_mode=mmap_mode)  # [N, 4]
        sizes = (len(self.eeg_data), len(self.video_data), len(self.labels))
        if len(set(sizes)) != 1:
            raise ValueError(
                'EEG, video and label arrays must have the same number of '
                f'samples, got {sizes[0]}, {sizes[1]}, {sizes[2]}'
            )

    def __len__(self):
        # Required by DataLoader's default sampler; it was missing before.
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            'eeg': self.eeg_data[idx],  # [40, 8064]
            'video': self.video_data[idx],  # [300, 224, 224, 3]
            'label': self.labels[idx],
        }
- MAHNOB-HCI测试流程
def test(net=None, test_loader=None, preprocess=None):
    """Evaluate classification accuracy, print it and return it.

    Args:
        net: Model to evaluate. Defaults to the module-level ``model``.
        test_loader: Iterable of batch dicts that contain a ``'label'`` entry.
            Defaults to ``DataLoader(MAHNOBDataset(), batch_size=16)``.
        preprocess: Optional callable mapping a batch to the model input. If
            omitted, the batch dict is passed to the model unchanged (the
            original behaviour).

    Returns:
        Accuracy in ``[0, 1]`` over the samples actually seen, or ``nan`` if
        the loader yielded no samples (previously a ZeroDivisionError).

    NOTE(review): train() permutes the video tensor before calling the model,
    but batches here are passed through unchanged unless ``preprocess`` is
    given. Confirm that MAHNOBDataset already yields the model's input layout.
    """
    if net is None:
        net = model  # module-level model, as the original script used
    if test_loader is None:
        test_loader = DataLoader(MAHNOBDataset(), batch_size=16)

    # Remember the caller's mode so evaluation does not leave the model in
    # eval mode as a side effect.
    was_training = net.training
    net.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for batch in test_loader:
                inputs = preprocess(batch) if preprocess is not None else batch
                preds = torch.argmax(net(inputs), dim=1)
                labels = batch['label']
                correct += (preds == labels).sum().item()
                # Count samples actually evaluated (also correct with
                # drop_last and for loaders without a .dataset attribute).
                total += len(labels)
    finally:
        net.train(was_training)

    accuracy = correct / total if total else float('nan')
    print(f'Test Accuracy: {accuracy:.4f}')
    return accuracy
五、优化建议
- 模态对齐策略
- 时间对齐:使用动态时间规整(DTW)处理不同步的模态数据
- 特征对齐:通过共享子空间学习(如CCA)对齐模态特征
- 正则化方法
- 图Dropout:在GAT层中对节点特征或边施加Dropout(如DropEdge)
- 对比学习:引入模态间对比损失
- 多任务学习:联合预测valence/arousal维度
- 可视化分析
from torch_geometric.utils import to_networkx
def visualize_graph(graph, seed=None):
    """Draw ``graph`` with networkx/matplotlib and show the figure.

    Nodes are coloured by ``graph.y`` only when it exists and has one value
    per node. Otherwise networkx's default node colour is used (previously a
    missing ``y`` raised AttributeError and a non-per-node ``y`` broke drawing).

    Args:
        graph: PyG graph object accepted by ``to_networkx``.
        seed: Optional seed for ``nx.spring_layout``, for a reproducible
            layout. ``None`` (default) keeps the previous random layout.

    NOTE(review): ``nx`` (networkx) and ``plt`` (matplotlib.pyplot) are not
    imported in the visible code; only ``to_networkx`` is. Confirm they are
    imported at module level.
    """
    nx_graph = to_networkx(graph, to_undirected=True)
    pos = nx.spring_layout(nx_graph, seed=seed)
    draw_kwargs = {}
    # Assumes y, when present, is a torch tensor (PyG convention) -- confirm.
    node_labels = getattr(graph, 'y', None)
    if node_labels is not None and node_labels.numel() == nx_graph.number_of_nodes():
        draw_kwargs['node_color'] = node_labels.tolist()
    nx.draw(nx_graph, pos, with_labels=True, **draw_kwargs)
    plt.show()
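这个方案整合了时空特征提取、异构图建模和多模态融合技术,面向DEAP和MAHNOB-HCI数据集的多模态特性而设计。需要注意,上述代码是骨架实现:EEGEncoder、VideoEncoder、MultiModalFusion、EmotionClassifier、MAHNOBDataset等组件尚需另行实现,实际效果需通过实验验证。建议使用PyTorch Geometric进行图结构处理,利用预训练的I3D模型初始化视频流,EEG流可使用在BCI竞赛数据上预训练的权重进行初始化。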