当前位置: 首页 > news >正文

【动手学深度学习】4.10 实战Kaggle比赛:预测房价


目录

    • 4.10 实战Kaggle比赛:预测房价
      • 1)数据预处理
      • 2)模型定义与训练
      • 3)模型评估与预测
      • 4)模型训练与预测提交
      • 5)示例超参数(可调)


4.10 实战Kaggle比赛:预测房价

数据来源:Kaggle房价预测比赛

.

1)数据预处理

读取数据

import pandas as pdtrain_data = pd.read_csv('../data/kaggle_house_pred_train.csv')
test_data = pd.read_csv('../data/kaggle_house_pred_test.csv')all_features = pd.concat((train_data.iloc[:, 1:-1], test_data.iloc[:, 1:]))

处理数值和类别特征

# 数值特征标准化
numeric_feats = all_features.dtypes[all_features.dtypes != 'object'].index
all_features[numeric_feats] = all_features[numeric_feats].apply(lambda x: (x - x.mean()) / x.std()
)
all_features[numeric_feats] = all_features[numeric_feats].fillna(0)# 类别特征独热编码
all_features = pd.get_dummies(all_features, dummy_na=True)

转换为张量

import torchn_train = train_data.shape[0]
X = torch.tensor(all_features[:n_train].values, dtype=torch.float32)
X_test = torch.tensor(all_features[n_train:].values, dtype=torch.float32)
y = torch.tensor(train_data.SalePrice.values, dtype=torch.float32).reshape(-1, 1)

.

2)模型定义与训练

模型定义

from torch import nndef get_net():net = nn.Sequential(nn.Linear(X.shape[1], 1))return net

损失函数:对数均方根误差(log RMSE)

import torch.nn.functional as Fdef log_rmse(net, features, labels):clipped_preds = torch.clamp(net(features), 1, float('inf'))rmse = torch.sqrt(F.mse_loss(torch.log(clipped_preds), torch.log(labels)))return rmse.item()

训练函数

def train(net, train_features, train_labels, test_features, test_labels,num_epochs, learning_rate, weight_decay, batch_size):train_ls, test_ls = [], []dataset = torch.utils.data.TensorDataset(train_features, train_labels)train_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True)optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)for epoch in range(num_epochs):for X_batch, y_batch in train_iter:optimizer.zero_grad()loss = F.mse_loss(net(X_batch), y_batch)loss.backward()optimizer.step()train_ls.append(log_rmse(net, train_features, train_labels))if test_labels is not None:test_ls.append(log_rmse(net, test_features, test_labels))return train_ls, test_ls

.

3)模型评估与预测

K折交叉验证

def get_k_fold_data(k, i, X, y):fold_size = X.shape[0] // kX_train, y_train = None, Nonefor j in range(k):idx = slice(j * fold_size, (j + 1) * fold_size)X_part, y_part = X[idx], y[idx]if j == i:X_valid, y_valid = X_part, y_partelif X_train is None:X_train, y_train = X_part, y_partelse:X_train = torch.cat((X_train, X_part), 0)y_train = torch.cat((y_train, y_part), 0)return X_train, y_train, X_valid, y_validdef k_fold(k, X_train, y_train, num_epochs, lr, weight_decay, batch_size):train_l_sum, valid_l_sum = 0, 0for i in range(k):data = get_k_fold_data(k, i, X_train, y_train)net = get_net()train_ls, valid_ls = train(net, *data, num_epochs, lr, weight_decay, batch_size)train_l_sum += train_ls[-1]valid_l_sum += valid_ls[-1]print(f'Fold {i+1}, Train log rmse: {train_ls[-1]:.4f}, Valid log rmse: {valid_ls[-1]:.4f}')return train_l_sum / k, valid_l_sum / k

.

4)模型训练与预测提交

使用全部数据训练并预测

def train_and_pred(train_features, test_features, train_labels, test_data,num_epochs, lr, weight_decay, batch_size):net = get_net()train(net, train_features, train_labels, None, None, num_epochs, lr, weight_decay, batch_size)preds = net(test_features).detach().numpy()submission = pd.DataFrame({"Id": test_data.Id,"SalePrice": preds.flatten()})submission.to_csv('submission.csv', index=False)

.

5)示例超参数(可调)

k, num_epochs, lr, weight_decay, batch_size = 5, 100, 5, 0, 64
train_l, valid_l = k_fold(k, X, y, num_epochs, lr, weight_decay, batch_size)
print(f'{k}-fold validation: Avg train log rmse: {train_l:.4f}, Avg valid log rmse: {valid_l:.4f}')
# 最终提交
train_and_pred(X, X_test, y, test_data, num_epochs, lr, weight_decay, batch_size)

.

总结:

  • 核心流程:数据预处理 → 建模 → K折验证 → 全数据训练 → 生成提交文件。

  • 模型简单但有效:线性回归 + 标准化 + One-Hot。

  • log_rmse 是比赛评分标准的重要转化。

.


声明:资源可能存在第三方来源,若有侵权请联系删除!

http://www.dtcms.com/a/271344.html

相关文章:

  • Android API Level 到底是什么?和安卓什么关系?应用发布如何知道自己的版本?优雅草卓伊凡
  • 深度学习预备知识
  • MyBatisPlus-03-扩展功能
  • 基于Matlab多特征融合的可视化指纹识别系统
  • 常见 HTTP 方法的成功状态码200,204,202,201
  • whitt算法之特征向量的尺度
  • 利用编码ai工具cursor写单元测试
  • springMVC06-注解+配置类实现springMVC
  • Java位运算
  • Electron的setContentProtection()会被哪个层级的API捕获?
  • 【TCP/IP】3. IP 地址
  • 储能系统防孤岛保护测试:电网安全的“守门人”
  • C#字符串相关库函数运用梳理总结 + 正则表达式详解
  • 基于YOLOv11的CF-YOLO,如何突破无人机小目标检测?
  • 光伏无人机3D建模:毫秒级精度设计
  • HarmonyOS从入门到精通:自定义组件开发指南(六):组件生命周期详解
  • vue3.2 前端动态分页算法
  • [Python] 区分方法 函数
  • 企业级智能体平台怎么选?字节、腾讯、360、FastGPT选哪个?
  • 【牛客刷题】小欧的选数乘积
  • K8S使用命令多集群管理配置
  • EUDR法案的核心内容,EUDR未来展望,EUDR对全球供应链的影响
  • Excel 常用高级用法
  • [特殊字符] Python 批量生成词云:读取词频 Excel + 自定义背景 + Excel to.png 流程解析
  • 【踩坑】python写超长字符到excel中被截断
  • TDengine 集群部署及启动、扩容、缩容常见问题与解决方案
  • 自建ELK vs 云商日志服务:成本对比分析
  • Apache Tomcat SessionExample 漏洞分析与防范
  • AMIS全栈低代码开发
  • 【NVIDIA-H100】基于 nvidia-smi 数据H100 GPU 功耗异常深度分析与解决方案