当前位置: 首页 > news >正文

【论文阅读—深度学习处理表格数据】ResNet-like FT Transformer

[2106.11959] Revisiting Deep Learning Models for Tabular Data(论文链接)

https://github.com/yandex-research/tabular-dl-revisiting-models(代码链接)

这篇论文主要讲的是:

  • 现在很多研究人员提出了各种新的深度学习方法来处理表格数据(像 Excel 里的那种数据)。

  • 这些方法都说自己效果不错,但问题是:大家用的实验数据集和测试方法都不一样,所以根本没法直接比较,谁更好也不清楚

  • 此外,目前还缺少一些“简单又好用”的基准模型(baseline),让别人能方便地拿来对比。

于是,作者做了两件事:

  1. 总结梳理了现有的主要深度学习方法,看看它们的思路和优缺点。

  2. 提出了两个简单但效果很强的模型

    • 第一个是类似 ResNet 的结构(在图像识别里很常见),结果发现它在表格数据上也能作为一个强有力的基线模型。

    • 第二个是稍微改造过的 Transformer,用来处理表格数据,效果比很多其他模型都好。

他们用统一的实验流程,把这些模型在很多任务上做了公平比较。最后还把深度学习模型和“梯度提升树”(一种传统但很强的机器学习方法,比如 XGBoost)做对比,结论是:
👉 目前还没有一个方法能在所有情况下一直是最优的。

Introduction

研究背景

  • 深度学习在图像、语音、文本等领域非常成功,所以大家很想把它用到表格数据上(比如工业生产里的传感器数据、金融数据、医疗数据、比赛里的数据表)。

  • 在表格数据任务里,深度学习的主要对手是 梯度提升决策树(GBDT,比如 XGBoost、LightGBM、CatBoost),它们已经很强大了。

  • 如果深度学习能在表格数据上表现好,还有一个额外的好处:它可以和图像、语音等数据结合,做成“多模态”的端到端模型,用梯度优化一起训练,非常方便。

目前的问题

  • 已经有很多研究提出了新的深度学习模型来处理表格数据,但大家用的数据集和实验流程都不一样。

  • 这就导致:

    • 论文之间没法公平比较,不知道到底哪个模型更好。

    • 也不清楚深度学习是否真的比 GBDT 更强。

  • 此外,虽然模型很多,但还缺少一些 既简单又稳定的“基准模型”,能让大家更容易对比和复现。现在大家常用的 MLP(多层感知机)太弱,根本挑战不到其他模型。

作者的思路

  • 他们怀疑:其实一些在图像和 NLP 里已经验证过的经典深度学习模块,在表格任务上可能被忽视了。

  • 所以,他们尝试把这些“成熟的架构”改造一下用在表格数据上,结果提出了两个简单的模型:

    • ResNet-like(残差网络的变体)

    • FT-Transformer(他们改造过的 Transformer)

论文主要贡献

  • 系统地评估了各种表格数据的深度学习模型。

  • 发现 ResNet-like 是一个很强的基线,推荐以后大家用它做比较。

  • 提出了 FT-Transformer,一个简单但强大的新方案。

  • 最后确认,目前没有任何模型能在所有任务上完全超越 GBDT。

Related work

传统方法:决策树集成(GBDT)

目前在表格数据任务里,GBDT(梯度提升树) 是最常用、最强的“浅层”方法。有三个非常流行的库:XGBoost、LightGBM、CatBoost,它们在细节上有区别,但在大多数任务上性能差不多。所以,GBDT 一直是机器学习比赛和工业界的首选。

深度学习方法:主要有三大类

可微分的树

传统的树模型不能通过梯度下降训练(因为分裂是“硬”的,没法求导)。有些研究者就把树的分裂函数“平滑化”,让树也能用梯度优化,方便和深度学习结合。这些方法在部分任务上能超过 GBDT,但作者的实验发现,它们并不比一个调好参数的 ResNet 稳定

基于注意力的模型

因为注意力机制(Transformer)在图像、文本等领域大获成功,有人也尝试把它用在表格数据上。结果表明:虽然一些注意力模型表现还行,但经过精调的 ResNet 依然更强。不过,作者提出了一种更合适的 Transformer 改造版本(FT-Transformer),在大多数任务上超过了 ResNet。

显示建模特征交叉

在推荐系统和点击率预测中,很多人批评 MLP(多层感知机),因为它不擅长表达特征之间的“乘积关系”(比如“年龄 × 收入”这种交互)。有些工作尝试在 MLP 里加上“特征乘积”的模块。但作者的实验发现,这些方法并不比调好的基线模型更好。

其他杂项模型

还有一些方法不能明确归类到上面三类,比如一些比较特殊的架构设计,但效果也没有形成统一结论。

Models for tabular data

研究目标

找出简单但好用的深度学习基线模型,而不是复杂又难调的架构。尽量使用大家熟悉的深度学习“积木”(比如 MLP、ResNet、Transformer)来搭建模型。这样模型更容易理解、实现,也更容易调出好结果。

统一任务设定

数据集包含数值型特征(numerical)和类别型特征(categorical)。数据被分成 训练集(train)、验证集(val)、测试集(test)。任务分为三类:二分类、多分类和回归。

三种主要模型

MLP(多层感知机)
  • 最基础的深度学习模型。

  • 结构:线性层 → ReLU → Dropout → 线性层 → … → 输出

  • 简单易用,但在表格数据任务里,往往不是最优解。

ResNet(残差网络改造版)
  • 原本 ResNet 在图像识别很成功(解决了“深层网络难训练”的问题)。

  • 作者尝试在表格数据上用一个简化版 ResNet

  • 核心思想:在每一层里,加一个“捷径连接”(skip connection),让输入可以直接绕过部分计算流向输出。

  • 这样有助于训练,能学到更深层次的特征。

  • 作者认为它会比 MLP 更适合一些需要“深层表达”的任务,其实就是在原来MLP的基础上加上残差链接

FT-Transformer(特征标记器 + Transformer)
  • 这是作者的创新点。

  • 步骤:

    1. 把每个特征(数值/类别)先转换成embedding(向量表示)

    2. Transformer 模型处理这些特征向量(像处理句子里的词一样)。

    3. 最后用 [CLS] token 的输出作为整体表示,再做预测。

  • 好处:Transformer 的自注意力机制可以自动学到特征之间的关系,特别适合表格这种“混合信息”的数据。

  • 作者发现,这个方法在多数任务上能超过 ResNet。

Experiments

作者想回答两个核心问题:

  1. 不同的深度学习(DL)架构在表格数据上的真实表现差别有多大?

  2. 和传统的梯度提升树(GBDT,例如 XGBoost、CatBoost)相比,DL 模型到底有没有优势?

数据集和实验

我们使用了11 个多样化的公开数据集(详细描述见附录)。对于每个数据集,都只有一个固定的训练集–验证集–测试集划分,因此所有算法都使用相同的划分。

这些数据集包括:

  • California Housing (CA):房价预测(房地产数据,Kelley Pace 和 Barry, 1997)

  • Adult (AD):收入预测(Kohavi, 1996)

  • Helena (HE):匿名数据集(Guyon 等, 2019)

  • Jannis (JA):匿名数据集(Guyon 等, 2019)

  • Higgs (HI):模拟的物理粒子数据(Baldi 等, 2014;使用 OpenML 上的 98K 样本版本,Vanschoren 等, 2014)

  • ALOI (AL):图像数据(Geusebroek 等, 2005)

  • Epsilon (EP):模拟物理实验数据

  • Year (YE):音频特征数据(Bertin-Mahieux 等, 2011)

  • Covertype (CO):森林特征数据(Blackard 和 Dean, 2000)

  • Yahoo (YA):搜索查询数据(Chapelle 和 Chang, 2011)

  • Microsoft (MI):搜索查询数据(Qin 和 Liu, 2013)

在学习排序(learning-to-rank)任务上,我们采用逐点(pointwise)方法,并把 MicrosoftYahoo 的排序问题视为回归问题来处理。

实验结果(简单总结)

  1. MLP 依然能打:普通的多层感知机(MLP)经过认真调参,表现不差,是很好的基线。

  2. ResNet-like 很强:在 MLP 上加残差连接,效果显著提升,变成一个强力的 baseline。

  3. FT-Transformer 最稳:改造过的 Transformer(FT-T)在大多数任务上表现最好,而且在各种数据集上都很稳定。

  4. NODE(树状神经网络)也不错,但结构复杂,而且在很多数据集上还是比不过 ResNet。

  5. 集成(Ensemble)有效果:把多个模型预测结果平均,性能更强;ResNet 和 FT-Transformer 的提升尤其明显。

Analysis

FT-Transformer 什么时候比 ResNet 更好?

作者做了一个合成实验

  • 构造一些虚拟任务,把目标值设计成 一部分像 GBDT 容易拟合的函数(fGBDT),另一部分像 深度学习模型容易拟合的函数(fDL)

  • 用一个参数 α 来控制这两部分的比例。

结果:

  • 在更偏向 ResNet 友好的任务上(fDL 占主导),ResNet 和 FT-Transformer 表现差不多,都比 CatBoost 强。

  • 但当任务更接近 GBDT 风格时,ResNet 的性能掉下去了,而 FT-Transformer 依然稳健。

结论:FT-Transformer 更“万能”,既能处理 DL 友好的任务,也能处理 GBDT 擅长的任务,而 ResNet 更挑剔。

Ablation Study(消融实验)

作者测试了 FT-Transformer 的一些设计细节到底有没有用:

和 AutoInt 的比较

  • AutoInt 也是基于注意力的模型,但细节设计和 Transformer 不一样。

  • 实验表明:FT-Transformer 的主干(标准 Transformer 结构)更强,整体优于 AutoInt。

Feature biases(特征偏置项)有没有必要?

  • FT-Transformer 在 embedding 层里为每个特征加了一个偏置(bias)。

  • 去掉偏置再跑实验 → 性能下降。

  • 说明偏置设计是有帮助的。

从注意力图获取特征重要性

作者提出用 Transformer 的注意力图(attention maps) 来解释模型:

  • 计算 [CLS] token 的注意力分布 → 得到每个特征的重要性分数。

  • 优点:很高效,只要一次前向传播就能得到。

对比方法:

  • Integrated Gradients (IG):通用但很慢。

  • Permutation Test (PT):可靠但要做很多次前向传播(特征数 + 1)。

实验结果:

  • 注意力图方法(AM)的特征重要性和 PT 的排名高度相关,效果和 IG 差不多。

  • 但效率远高于 IG 和 PT。

结论:用注意力图来解释 FT-Transformer 的特征重要性,是一种 便宜又有效的可解释性方法

加载FT-Transformer处理表格数据

import torch
import torch.nn as nn
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from torch.utils.data import Dataset, DataLoader
from rtdl_revisiting_models import FTTransformer  # 注意要安装 rtdl-revisiting# ========= 1. 加载数据 =========
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
columns = ["age", "workclass", "fnlwgt", "education", "education-num","marital-status", "occupation", "relationship", "race", "sex","capital-gain", "capital-loss", "hours-per-week", "native-country", "income"
]
df = pd.read_csv(url, names=columns, sep=",", skipinitialspace=True)# ========= 2. 预处理 =========
# 标签:>50K / <=50K
df["income"] = LabelEncoder().fit_transform(df["income"])  # <=50K=0, >50K=1# 数值特征 / 类别特征
num_features = ["age", "fnlwgt", "education-num", "capital-gain", "capital-loss", "hours-per-week"]
cat_features = [col for col in df.columns if col not in num_features + ["income"]]# 数值标准化
scaler = StandardScaler()
df[num_features] = scaler.fit_transform(df[num_features])# 类别编码
for col in cat_features:df[col] = LabelEncoder().fit_transform(df[col])X = df[num_features + cat_features].values
y = df["income"].valuesX_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# ========= 3. Dataset & Dataloader =========
class TabDataset(Dataset):def __init__(self, X, y):self.X = torch.tensor(X, dtype=torch.float32)self.y = torch.tensor(y, dtype=torch.long)def __len__(self): return len(self.X)def __getitem__(self, idx): return self.X[idx], self.y[idx]train_loader = DataLoader(TabDataset(X_train, y_train), batch_size=256, shuffle=True)
test_loader = DataLoader(TabDataset(X_test, y_test), batch_size=256)# ========= 4. 构建 FT-Transformer =========
device = "cuda" if torch.cuda.is_available() else "cpu"model = FTTransformer(d_numerical=len(num_features),categories=[int(df[col].nunique()) for col in cat_features],  # 每个类别特征的取值个数d_token=32,n_layers=3,n_heads=4,d_ffn_factor=2.0,attention_dropout=0.1,ffn_dropout=0.1,residual_dropout=0.1,d_out=2  # 二分类
).to(device)criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)# ========= 5. 训练 =========
for epoch in range(5):  # 演示只训练 5 轮model.train()for xb, yb in train_loader:xb, yb = xb.to(device), yb.to(device)out = model(xb)loss = criterion(out, yb)optimizer.zero_grad()loss.backward()optimizer.step()print(f"Epoch {epoch+1} | Loss {loss.item():.4f}")# ========= 6. 测试 =========
model.eval()
correct, total = 0, 0
with torch.no_grad():for xb, yb in test_loader:xb, yb = xb.to(device), yb.to(device)out = model(xb)preds = out.argmax(1)correct += (preds == yb).sum().item()total += yb.size(0)print(f"Test Accuracy: {correct / total:.4f}")

http://www.dtcms.com/a/366848.html

相关文章:

  • 当电力设计遇上AI:良策金宝AI如何重构行业效率边界?
  • 学习嵌入式的第三十三天——网络编程
  • HTB Sau
  • 服务器异常磁盘写排查手册 · 已删除文件句柄篇
  • 稠密矩阵和稀疏矩阵的对比
  • C++面试突击(1)
  • 【面试】MySQL 面试常见优化问题
  • 面试官:如何确保动态线程池任务都执行完?
  • 计算机网络模型入门指南:分层原理与各层作用
  • EasyExcel:阿里开源的高效 Excel 处理工具,轻松解决 POI 内存溢出问题
  • SolidWorks对电脑的硬件配置要求具体有哪些
  • [Sublime Text]-显示菜单栏
  • 《云原生深坑实录:让团队卡壳的不是配置,是底层逻辑盲区》
  • 【Dify】使用工具节点实现 API 接口调用与 JSON 处理
  • 25高教社杯数模国赛【B题超高质量思路+问题分析】
  • 具身智能多模态感知与场景理解:视觉探索
  • 第二阶段WinForm-13:图表控件,N层架构,Dapper
  • 数据结构与排序算法:从理论到场景,解锁高效数据处理的核心逻辑
  • 【项目思路】基于STM32+ZigBee的智能家居--浴室场景设计
  • 服务器异常负载排查手册 · 隐蔽进程篇
  • QT面经(含相关知识)
  • elasticsearch学习(五)文档CRUD
  • 前端跨域终极指南:3 种优雅解决方案 + 可运行 Demo
  • App UI 自动化环境搭建指南
  • Java Stream 流式操作举例
  • QT Creator 使用
  • 【一文了解】C#泛型
  • 数据库集成:使用 SQLite 与 Electron
  • 新电脑硬盘如何分区?3个必知技巧避免“空间浪费症”!
  • [技术革命]Harmonizer:仅20MB模型如何实现8K图像_视频的完美和谐化?