
SparkML: a Pipeline approach to sharing one labelEncoder across multiple columns

Background

For example, a dataset has two city columns: a from-city and a to-city.

The requirement is that the same city must receive the same code after label encoding, no matter which of the two columns it appears in.
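To see why the stock approach falls short: fitting one StringIndexer per column orders labels by frequency within that column only, so the same city can be assigned different indices in the two columns. A minimal contrast sketch (the data and column names here are illustrative, not part of the solution below):

from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer

spark = SparkSession.builder.appName("naive_per_column_indexer").getOrCreate()
df = spark.createDataFrame(
    [(1, "北京", "上海"), (2, "上海", "北京"), (3, "广州", "深圳")],
    ["id", "origin_city", "dest_city"],
)
# Two independent indexers: each one only sees its own column.
naive = Pipeline(stages=[
    StringIndexer(inputCol="origin_city", outputCol="origin_city_idx"),
    StringIndexer(inputCol="dest_city", outputCol="dest_city_idx"),
])
# The indices for the same city in the two *_idx columns generally disagree.
naive.fit(df).transform(df).show(truncate=False)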

Code

"""
需求说明
- 两列城市字段(origin_city、dest_city)表达同一语义,需要共享一套 Label 编码映射。
- 使用 PySpark 框架实现,且编码器可复用,并可整合进 Spark ML Pipeline。
"""from __future__ import annotationsfrom typing import Dict, Iterable, List, Optional, Tuplefrom pyspark.sql import SparkSession, functions as F, types as T
from pyspark.ml import Estimator, Model, Pipeline, PipelineModel
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.param.shared import Param, Params, TypeConverters
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritableclass SharedStringIndexer(Estimator, DefaultParamsReadable, DefaultParamsWritable):"""一个 Estimator:基于多列拟合一次 StringIndexer 的 labels,并产出可对多列统一编码的 Model。Params- inputCols: List[str] 要共享映射的输入列- outputCols: List[str] 对应输出列名(与 inputCols 等长)- handleInvalid: keep/skip/error(含义同 StringIndexer)"""inputCols = Param(Params._dummy(), "inputCols", "input columns", typeConverter=TypeConverters.toListString)outputCols = Param(Params._dummy(), "outputCols", "output columns", typeConverter=TypeConverters.toListString)handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid labels", typeConverter=TypeConverters.toString)def __init__(self, inputCols: List[str], outputCols: List[str], handleInvalid: str = "keep"):super().__init__()if len(inputCols) != len(outputCols):raise ValueError("inputCols 与 outputCols 长度需一致")self._set(inputCols=inputCols, outputCols=outputCols, handleInvalid=handleInvalid)def _fit(self, dataset):# 将多列堆叠为单列 value 后,用 StringIndexer 拟合一次,得到统一 labelsstacked = Nonefor c in self.getOrDefault(self.inputCols):col_df = dataset.select(F.col(c).cast(T.StringType()).alias("value")).na.fill({"value": ""})stacked = col_df if stacked is None else stacked.unionByName(col_df)indexer = StringIndexer(inputCol="value", outputCol="value_idx", handleInvalid="keep")model = indexer.fit(stacked)labels = list(model.labels)return SharedStringIndexerModel().setParams(inputCols=self.getOrDefault(self.inputCols),outputCols=self.getOrDefault(self.outputCols),handleInvalid=self.getOrDefault(self.handleInvalid),labels=labels,)class SharedStringIndexerModel(Model, DefaultParamsReadable, DefaultParamsWritable):"""Transformer:将拟合得到的 labels 作为共享映射,对多列输出统一索引。为了能够被 PipelineModel.save/load 序列化,labels 作为一个 Param 保存。"""inputCols = Param(Params._dummy(), "inputCols", "input columns", typeConverter=TypeConverters.toListString)outputCols = Param(Params._dummy(), "outputCols", "output columns", typeConverter=TypeConverters.toListString)handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid labels", typeConverter=TypeConverters.toString)labels = Param(Params._dummy(), "labels", "shared label list", typeConverter=TypeConverters.toListString)def __init__(self):# 必须无参构造以支持反序列化super().__init__()def setParams(self, **kwargs):self._set(**kwargs)return selfdef _transform(self, dataset):label_list = self.getOrDefault(self.labels) or []mapping = {v: i for i, v in enumerate(label_list)}unknown_index = len(label_list)handle_invalid = self.getOrDefault(self.handleInvalid)bmap = dataset.sparkSession.sparkContext.broadcast(mapping)def map_value(v: Optional[str]) -> Optional[int]:if v is None:return None if handle_invalid == "skip" else unknown_index if handle_invalid == "keep" else Noneidx = bmap.value.get(v)if idx is not None:return idxif handle_invalid == "keep":return unknown_indexif handle_invalid == "skip":return Noneraise ValueError(f"未知标签: {v}")enc = F.udf(map_value, T.IntegerType())out = datasetfor src, dst in zip(self.getOrDefault(self.inputCols), self.getOrDefault(self.outputCols)):out = out.withColumn(dst, enc(F.col(src).cast(T.StringType())))return outdef main():spark = SparkSession.builder.appName("shared_string_indexer_pipeline").getOrCreate()spark.sparkContext.setLogLevel("ERROR")data = [(1, "北京", "上海", 1),(2, "上海", "北京", 0),(3, "广州", "深圳", 1),(4, "深圳", "广州", 0),(5, "北京", "广州", 1),(6, "上海", "深圳", 0),]columns = ["id", "origin_city", "dest_city", "label"]df = spark.createDataFrame(data, schema=columns)df_test = 
spark.createDataFrame([(1, "北京111", "上海", 1)], schema=columns)shared_indexer = SharedStringIndexer(inputCols=["origin_city", "dest_city"],outputCols=["origin_city_idx", "dest_city_idx"],handleInvalid="keep",)assembler = VectorAssembler(inputCols=["origin_city_idx", "dest_city_idx"],outputCol="features",)pipeline = Pipeline(stages=[shared_indexer, assembler])model = pipeline.fit(df)out = model.transform(df)out.select("id", "origin_city", "dest_city", "origin_city_idx", "dest_city_idx", "features").show(truncate=False)# 复用:保存/加载整个 PipelineModel(包含共享映射)# 也可以仅保存 shared_indexer 的 model(通过 pipeline.stages[0] 的写接口)model.write().overwrite().save("./shared_indexer_pipeline_model")print('新数据转换:')   # handleInvalid="keep" 所以这里新枚举值不报错 model.transform(df_test).show(truncate=False)print("加载导出后的模型 新数据转换")loaded_model = PipelineModel.load("./shared_indexer_pipeline_model")loaded_model.transform(df_test).show(truncate=False)# spark.stop()main()
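As the comments in main() note, the whole PipelineModel is not the only artifact worth persisting: the fitted SharedStringIndexerModel inherits DefaultParamsWritable/DefaultParamsReadable, so it can also be saved and reloaded on its own. A minimal sketch, assuming model and df_test from main() are in scope (the save path is illustrative):

# Stage 0 of the fitted pipeline is the SharedStringIndexerModel.
shared_model = model.stages[0]
shared_model.write().overwrite().save("./shared_indexer_model_only")

# Reload it later and apply the same city -> index mapping to any DataFrame
# that contains the configured input columns.
reloaded = SharedStringIndexerModel.load("./shared_indexer_model_only")
reloaded.transform(df_test).show(truncate=False)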

Output

+---+-----------+---------+---------------+-------------+---------+
|id |origin_city|dest_city|origin_city_idx|dest_city_idx|features |
+---+-----------+---------+---------------+-------------+---------+
|1  |北京       |上海     |1              |0            |[1.0,0.0]|
|2  |上海       |北京     |0              |1            |[0.0,1.0]|
|3  |广州       |深圳     |2              |3            |[2.0,3.0]|
|4  |深圳       |广州     |3              |2            |[3.0,2.0]|
|5  |北京       |广州     |1              |2            |[1.0,2.0]|
|6  |上海       |深圳     |0              |3            |[0.0,3.0]|
+---+-----------+---------+---------------+-------------+---------+

Transform new data:
+---+-----------+---------+-----+---------------+-------------+---------+
|id |origin_city|dest_city|label|origin_city_idx|dest_city_idx|features |
+---+-----------+---------+-----+---------------+-------------+---------+
|1  |北京111    |上海     |1    |4              |0            |[4.0,0.0]|
+---+-----------+---------+-----+---------------+-------------+---------+

Load the exported model and transform new data
+---+-----------+---------+-----+---------------+-------------+---------+
|id |origin_city|dest_city|label|origin_city_idx|dest_city_idx|features |
+---+-----------+---------+-----+---------------+-------------+---------+
|1  |北京111    |上海     |1    |4              |0            |[4.0,0.0]|
+---+-----------+---------+-----+---------------+-------------+---------+
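Why does the unseen value 北京111 come out as 4? The shared label list fitted from the training data has four entries, and with handleInvalid="keep" any value not in that list is mapped to len(labels). A small inspection snippet, assuming the fitted model from above is still in scope:

shared = model.stages[0]
labels = shared.getOrDefault(shared.labels)
print(labels)       # ['上海', '北京', '广州', '深圳'], matching the indices in the tables above
print(len(labels))  # 4 -> the index assigned to any unseen city such as "北京111"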
