使用bert或roberta模型做分类训练时,分类数据不平衡时,可以采取哪些优化的措施
📝 学习笔记:BERT/RoBERTa 文本分类中的类别不平衡问题优化策略
一、问题背景
在使用 BERT、RoBERTa 等预训练语言模型进行文本分类时,若训练数据中各类别样本数量严重不均衡(如 95% 正例、5% 负例),模型容易偏向多数类,导致对少数类识别能力差。需从数据、模型、训练、评估四个维度综合优化。
二、优化策略汇总
✅ 1. 数据层面处理
方法 | 说明 | 注意事项 |
---|---|---|
过采样(Oversampling) | 复制或增强少数类样本 | 避免简单复制导致过拟合 |
欠采样(Undersampling) | 随机删除多数类样本 | 可能丢失重要信息,慎用 |
数据增强 | 生成语义相近的新样本: • 回译(Back-Translation) • 同义词替换 • EDA(随机插入/删除/交换) • 使用 T5/BART 生成 | 文本数据不适合 SMOTE,应注重语义一致性 |
✅ 2. 模型与损失函数优化
方法 | 原理 | 实现方式 |
---|---|---|
类别加权交叉熵 | 为少数类分配更高损失权重 | sklearn.utils.class_weight.compute_class_weight('balanced', ...) nn.CrossEntropyLoss(weight=class_weights) |
Focal Loss | 降低易分样本权重,聚焦难样本 | 公式:( FL(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t) ) 适合极端不平衡场景 |
Label Smoothing | 软化标签,防止模型对多数类过度自信 | 间接提升泛化能力 |
✅ 3. 训练策略改进
策略 | 作用 | 工具示例 |
---|---|---|
分层采样(Stratified Sampling) | 每个 batch 中保持类别比例 | torch.utils.data.WeightedRandomSampler |
两阶段训练 | 先在平衡数据上微调,再在原始数据上精调 | 第一阶段:过采样数据;第二阶段:原始数据 + 小学习率 |
集成学习 | 多模型融合提升鲁棒性 | 不同采样策略训练多个 BERT 模型,投票或加权平均 |
✅ 4. 评估与后处理
方法 | 说明 |
---|---|
使用合理指标 | 避免 Accuracy!推荐: • Macro-F1(各类别平等对待) • PR-AUC(对少数类敏感) • MCC(Matthews 相关系数) |
调整分类阈值 | 默认 0.5 不适用 → 在验证集上搜索最优阈值(如最大化 F1) → 多分类可对每类单独调阈值 |
✅ 5. 高级技巧(进阶)
- Prompt-based Learning / Few-shot Learning:利用 BERT 的强泛化能力,在极少数样本下学习。
- 课程学习(Curriculum Learning):先学简单样本,再逐步引入难样本和少数类。
- 使用 Hugging Face Trainer 自定义损失:重写
compute_loss
方法集成 Focal Loss 或加权损失。
三、实践建议(优先级排序)
-
首选组合:
类别加权损失 + 分层采样(WeightedRandomSampler)
→ 简单、高效、无需修改数据。 -
若少数类样本极少(<100):
结合 数据增强(如回译) + Focal Loss。 -
评估必须看 Macro-F1 / PR-AUC,不能只看 Accuracy!
-
调阈值:训练完成后,在验证集上优化分类阈值。
四、代码片段速查
# 1. 计算类别权重
from sklearn.utils.class_weight import compute_class_weight
class_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)# 2. PyTorch 加权损失
criterion = nn.CrossEntropyLoss(weight=torch.tensor(class_weights, dtype=torch.float))# 3. 分层采样
from torch.utils.data import WeightedRandomSampler
weights = 1. / torch.tensor(class_counts, dtype=torch.float)
sample_weights = weights[dataset.targets]
sampler = WeightedRandomSampler(sample_weights, len(sample_weights))
dataloader = DataLoader(dataset, sampler=sampler, batch_size=16)
五、总结
类别不平衡 ≠ 模型不行,而是训练策略需调整!
通过“加权损失 + 智能采样 + 合理评估 + 阈值调优”四步法,可显著提升 BERT/RoBERTa 在不平衡分类任务中的表现。