SparkML: a pipeline approach for sharing one LabelEncoder across multiple columns
Background
Suppose a dataset has two city columns, an origin city and a destination city. The requirement is that the same city must receive the same encoded value after label encoding, no matter which of the two columns it appears in.
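The naive approach of fitting one StringIndexer per column does not satisfy this: each indexer orders labels by frequency within its own column, so the same city can end up with different indices. A minimal sketch of the problem (toy data and column names chosen only for illustration, not part of the script below):

from pyspark.ml.feature import StringIndexer
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [("北京", "上海"), ("上海", "上海"), ("北京", "广州")],
    ["origin_city", "dest_city"],
)
for col in ["origin_city", "dest_city"]:
    # One indexer per column -> each mapping is fit independently.
    labels = StringIndexer(inputCol=col, outputCol=col + "_idx").fit(df).labels
    print(col, labels)  # "上海" gets index 1 in origin_city but index 0 in dest_city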
Code
"""
Requirements
- The two city columns (origin_city, dest_city) carry the same semantics and must share one label-encoding mapping.
- Implemented with PySpark; the encoder must be reusable and must fit into a Spark ML Pipeline.
"""from __future__ import annotationsfrom typing import Dict, Iterable, List, Optional, Tuplefrom pyspark.sql import SparkSession, functions as F, types as T
from pyspark.ml import Estimator, Model, Pipeline, PipelineModel
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.param.shared import Param, Params, TypeConverters
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable


class SharedStringIndexer(Estimator, DefaultParamsReadable, DefaultParamsWritable):
    """An Estimator that fits one StringIndexer over the values of several columns and
    produces a Model that encodes all of those columns with the shared label list.

    Params
    - inputCols: List[str], the input columns that share one mapping
    - outputCols: List[str], the corresponding output column names (same length as inputCols)
    - handleInvalid: keep/skip/error (same meaning as in StringIndexer)
    """

    inputCols = Param(Params._dummy(), "inputCols", "input columns", typeConverter=TypeConverters.toListString)
    outputCols = Param(Params._dummy(), "outputCols", "output columns", typeConverter=TypeConverters.toListString)
    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid labels", typeConverter=TypeConverters.toString)

    def __init__(self, inputCols: List[str], outputCols: List[str], handleInvalid: str = "keep"):
        super().__init__()
        if len(inputCols) != len(outputCols):
            raise ValueError("inputCols and outputCols must have the same length")
        self._set(inputCols=inputCols, outputCols=outputCols, handleInvalid=handleInvalid)

    def _fit(self, dataset):
        # Stack all input columns into a single "value" column, fit one StringIndexer,
        # and take its labels as the shared mapping.
        stacked = None
        for c in self.getOrDefault(self.inputCols):
            col_df = dataset.select(F.col(c).cast(T.StringType()).alias("value")).na.fill({"value": ""})
            stacked = col_df if stacked is None else stacked.unionByName(col_df)
        indexer = StringIndexer(inputCol="value", outputCol="value_idx", handleInvalid="keep")
        model = indexer.fit(stacked)
        labels = list(model.labels)
        return SharedStringIndexerModel().setParams(
            inputCols=self.getOrDefault(self.inputCols),
            outputCols=self.getOrDefault(self.outputCols),
            handleInvalid=self.getOrDefault(self.handleInvalid),
            labels=labels,
        )


class SharedStringIndexerModel(Model, DefaultParamsReadable, DefaultParamsWritable):
    """Transformer that applies the fitted labels as a shared mapping and writes a
    consistent index for every configured column.

    labels is stored as a Param so the model can be serialized via PipelineModel.save/load.
    """

    inputCols = Param(Params._dummy(), "inputCols", "input columns", typeConverter=TypeConverters.toListString)
    outputCols = Param(Params._dummy(), "outputCols", "output columns", typeConverter=TypeConverters.toListString)
    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid labels", typeConverter=TypeConverters.toString)
    labels = Param(Params._dummy(), "labels", "shared label list", typeConverter=TypeConverters.toListString)

    def __init__(self):
        # Must be constructible without arguments so it can be deserialized.
        super().__init__()

    def setParams(self, **kwargs):
        self._set(**kwargs)
        return self

    def _transform(self, dataset):
        label_list = self.getOrDefault(self.labels) or []
        mapping = {v: i for i, v in enumerate(label_list)}
        unknown_index = len(label_list)
        handle_invalid = self.getOrDefault(self.handleInvalid)
        bmap = dataset.sparkSession.sparkContext.broadcast(mapping)

        def map_value(v: Optional[str]) -> Optional[int]:
            if v is None:
                # Nulls: "keep" maps them to the extra index, otherwise emit null.
                return unknown_index if handle_invalid == "keep" else None
            idx = bmap.value.get(v)
            if idx is not None:
                return idx
            if handle_invalid == "keep":
                return unknown_index
            if handle_invalid == "skip":
                # Note: "skip" here yields null rather than dropping the row.
                return None
            raise ValueError(f"Unknown label: {v}")

        enc = F.udf(map_value, T.IntegerType())
        out = dataset
        for src, dst in zip(self.getOrDefault(self.inputCols), self.getOrDefault(self.outputCols)):
            out = out.withColumn(dst, enc(F.col(src).cast(T.StringType())))
        return out


def main():
    spark = SparkSession.builder.appName("shared_string_indexer_pipeline").getOrCreate()
    spark.sparkContext.setLogLevel("ERROR")

    data = [
        (1, "北京", "上海", 1),
        (2, "上海", "北京", 0),
        (3, "广州", "深圳", 1),
        (4, "深圳", "广州", 0),
        (5, "北京", "广州", 1),
        (6, "上海", "深圳", 0),
    ]
    columns = ["id", "origin_city", "dest_city", "label"]
    df = spark.createDataFrame(data, schema=columns)
    df_test = spark.createDataFrame([(1, "北京111", "上海", 1)], schema=columns)

    shared_indexer = SharedStringIndexer(
        inputCols=["origin_city", "dest_city"],
        outputCols=["origin_city_idx", "dest_city_idx"],
        handleInvalid="keep",
    )
    assembler = VectorAssembler(
        inputCols=["origin_city_idx", "dest_city_idx"],
        outputCol="features",
    )
    pipeline = Pipeline(stages=[shared_indexer, assembler])
    model = pipeline.fit(df)

    out = model.transform(df)
    out.select("id", "origin_city", "dest_city", "origin_city_idx", "dest_city_idx", "features").show(truncate=False)

    # Reuse: save/load the whole PipelineModel (which includes the shared mapping).
    # Alternatively, save only the fitted shared indexer (via the write interface of pipeline stage 0).
    model.write().overwrite().save("./shared_indexer_pipeline_model")

    print("Transform new data:")  # handleInvalid="keep", so the unseen city does not raise an error
    model.transform(df_test).show(truncate=False)

    print("Transform new data with the reloaded model:")
    loaded_model = PipelineModel.load("./shared_indexer_pipeline_model")
    loaded_model.transform(df_test).show(truncate=False)
    # spark.stop()


main()
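As the comment in main() suggests, the fitted shared indexer can also be persisted on its own, without the VectorAssembler stage. A minimal sketch, assuming it runs inside main() after the pipeline has been fit; the path ./shared_indexer_model is only an example:

shared_model = model.stages[0]  # the fitted SharedStringIndexerModel
shared_model.write().overwrite().save("./shared_indexer_model")  # via DefaultParamsWritable

reloaded = SharedStringIndexerModel.load("./shared_indexer_model")  # via DefaultParamsReadable
reloaded.transform(df_test).show(truncate=False)  # same shared mapping, index columns only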
Output
+---+-----------+---------+---------------+-------------+---------+
|id |origin_city|dest_city|origin_city_idx|dest_city_idx|features |
+---+-----------+---------+---------------+-------------+---------+
|1  |北京         |上海       |1              |0            |[1.0,0.0]|
|2  |上海         |北京       |0              |1            |[0.0,1.0]|
|3  |广州         |深圳       |2              |3            |[2.0,3.0]|
|4  |深圳         |广州       |3              |2            |[3.0,2.0]|
|5  |北京         |广州       |1              |2            |[1.0,2.0]|
|6  |上海         |深圳       |0              |3            |[0.0,3.0]|
+---+-----------+---------+---------------+-------------+---------+

Transform new data:
+---+-----------+---------+-----+---------------+-------------+---------+
|id |origin_city|dest_city|label|origin_city_idx|dest_city_idx|features |
+---+-----------+---------+-----+---------------+-------------+---------+
|1  |北京111      |上海       |1    |4              |0            |[4.0,0.0]|
+---+-----------+---------+-----+---------------+-------------+---------+

Transform new data with the reloaded model:
+---+-----------+---------+-----+---------------+-------------+---------+
|id |origin_city|dest_city|label|origin_city_idx|dest_city_idx|features |
+---+-----------+---------+-----+---------------+-------------+---------+
|1  |北京111      |上海       |1    |4              |0            |[4.0,0.0]|
+---+-----------+---------+-----+---------------+-------------+---------+
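To double-check that both columns really share one mapping, the label list can be read back from the loaded model. A small sketch, assuming the loaded_model from the script above is still in scope; the printed values follow from the output shown:

shared_model = loaded_model.stages[0]            # SharedStringIndexerModel
labels = shared_model.getOrDefault("labels")     # the shared label list stored as a Param
print(labels)       # ['上海', '北京', '广州', '深圳'] -> indices 0..3, used by both columns
print(len(labels))  # 4: the index assigned to unseen values when handleInvalid="keep"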