当前位置: 首页 > news >正文

pyspark自定义udf函数

在 PySpark 中,UDF(User-Defined Function,用户自定义函数) 是扩展 Spark 功能的核心工具,用于处理内置函数(如pyspark.sql.functions中的函数)无法覆盖的自定义逻辑(如复杂字符串处理、自定义数值计算、多列联动计算等)。

Spark 是强类型计算引擎,因此自定义 UDF 必须明确声明输入 / 输出数据类型,且需避免行式处理的性能瓶颈(推荐使用基于 Arrow 的 Pandas UDF 优化)。本文将从基础概念、普通 UDF、高性能 Pandas UDF、注意事项四个维度全面讲解 PySpark UDF 的使用。
一、UDF 核心概念
作用:对 DataFrame 的列数据进行自定义转换,补充 Spark SQL 内置函数的不足(如自定义正则匹配、业务逻辑计算等)。
依赖:
导入 UDF 装饰器 / 函数:from pyspark.sql.functions import udf
导入数据类型(声明返回类型):from pyspark.sql.types import StringType, IntegerType, ArrayType, StructType等
执行原理:
普通 Python UDF:基于行式处理,需在 Python 解释器与 Spark JVM 之间进行序列化 / 反序列化(存在 GIL 锁瓶颈,数据量大时效率低)。
Pandas UDF(Vectorized UDF):基于Apache Arrow批量处理数据,直接操作 Pandas Series/DataFrame,性能比普通 UDF 提升 5-10 倍。
二、普通 UDF 的创建与使用
普通 UDF 适用于简单逻辑或小规模数据,创建步骤分为:定义 Python 函数 → 包装为 UDF → 应用到 DataFrame。

  1. 基础示例:字符串转换(无参数 UDF)
    需求:将 DataFrame 的字符串列转为小写,并处理null值。
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType# 1. 初始化SparkSession
spark = SparkSession.builder.appName("BasicUDF").getOrCreate()# 2. 构造测试数据
df = spark.createDataFrame(data=[("Hello PySpark",), ("UDF Example",), (None,)],  # 包含null值schema=["original_text"]
)# 3. 定义Python自定义函数(需处理null,否则返回null)
def to_lower_case(s: str) -> str:if s is None:  # 显式处理null,避免逻辑错误return "empty"return s.lower()# 4. 包装为PySpark UDF(指定返回类型为StringType)
# 方式1:用udf()函数包装
to_lower_udf = udf(to_lower_case, returnType=StringType())
# 方式2:用@udf装饰器(更简洁)
# @udf(returnType=StringType())
# def to_lower_case(s: str) -> str: ...# 5. 应用UDF到DataFrame(用withColumn/select)
df_result = df.withColumn("lower_text",  # 新列名to_lower_udf("original_text")  # 对original_text列应用UDF
)# 查看结果
df_result.show(truncate=False)

输出:

+----------------+----------------+
|original_text   |lower_text      |
+----------------+----------------+
|Hello PySpark   |hello pyspark   |
|UDF Example     |udf example     |
|null            |empty           |
+----------------+----------------+
  1. 带参数的 UDF
    需求:给数值列的每个元素加上一个固定参数值(如给所有数值加 5)。
    需用functools.partial传递固定参数,或用 lambda 表达式包装。
from pyspark.sql.functions import col
from pyspark.sql.types import IntegerType
from functools import partial# 1. 构造测试数据
df_num = spark.createDataFrame(data=[(10,), (20,), (30,), (None,)],schema=["num"]
)# 2. 定义带参数的Python函数
def add_fixed_value(x: int, fixed: int) -> int:if x is None:return 0return x + fixed# 3. 传递固定参数(如fixed=5),包装为UDF
# 方式1:用partial固定参数
add_5_udf = udf(partial(add_fixed_value, fixed=5), IntegerType())
# 方式2:用lambda包装(适合简单参数)
# add_5_udf = udf(lambda x: add_fixed_value(x, 5), IntegerType())# 4. 应用UDF
df_num_result = df_num.withColumn("num_add_5",add_5_udf(col("num"))  # 仅传入DataFrame列作为第一个参数
)df_num_result.show()

输出:

+----+----------+
| num|num_add_5 |
+----+----------+
|  10|        15|
|  20|        25|
|  30|        35|
|null|         0|
+----+----------+
  1. 处理复杂类型的 UDF(Array/Map/Struct)
    Spark 支持复杂数据类型(如ArrayType数组、MapType字典、StructType结构体),UDF 可直接处理这些类型。
    示例:计算数组列的总和
from pyspark.sql.types import ArrayType, IntegerType# 1. 构造含数组列的DataFrame
df_array = spark.createDataFrame(data=[([1,2,3],), ([4,5,6,7],), (None,), ([])],  # 空数组、null数组schema=["numbers"]
)# 2. 定义处理数组的函数
def calculate_array_sum(arr: list) -> int:if arr is None or len(arr) == 0:  # 处理null和空数组return 0return sum(arr)# 3. 包装UDF(输入为ArrayType,输出为IntegerType)
array_sum_udf = udf(calculate_array_sum, IntegerType())# 4. 应用UDF
df_array_result = df_array.withColumn("array_total",array_sum_udf("numbers")
)df_array_result.show()

输出:

+------------+-----------+
|     numbers|array_total|
+------------+-----------+
|   [1, 2, 3]|          6|
|[4, 5, 6, 7]|         22|
|        null|          0|
|          []|          0|
+------------+-----------+

三、高性能 Pandas UDF(Vectorized UDF)
普通 Python UDF 因行式处理效率低,数据量超过 100 万行时强烈推荐使用 Pandas UDF。其基于 Apache Arrow 实现数据零拷贝传输,批量处理 Pandas Series/DataFrame,性能显著提升。

Pandas UDF 分为两类:

Scalar Pandas UDF:输入 / 输出为标量(Pandas Series),对应普通 UDF 的批量版本。
Grouped Map Pandas UDF:分组处理(输入为分组后的 Pandas DataFrame,输出为新 DataFrame)。

  1. 基础:Scalar Pandas UDF
    需求:计算数值列的平方(批量处理)。
    需用@pandas_udf装饰器,并指定返回类型,函数参数和返回值均为 Pandas Series。
from pyspark.sql.functions import pandas_udf
import pandas as pd# 1. 开启Arrow加速(Spark 3.0+默认开启,低版本需显式设置)
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")# 2. 构造测试数据
df_pd = spark.createDataFrame(data=[(1,), (2,), (3,), (4,), (5,)],schema=["num"]
)# 3. 定义Scalar Pandas UDF(用@pandas_udf装饰,指定返回类型)
@pandas_udf(IntegerType())  # 括号内为返回类型
def square_pd(s: pd.Series) -> pd.Series:# 直接用Pandas Series的矢量化操作,无需循环return s ** 2# 4. 应用UDF(用法与普通UDF一致)
df_pd_result = df_pd.withColumn("num_square",square_pd("num")
)df_pd_result.show()

输出:

+---+----------+
|num|num_square|
+---+----------+
|  1|         1|
|  2|         4|
|  3|         9|
|  4|        16|
|  5|        25|
+---+----------+
  1. Grouped Map Pandas UDF
    需求:按分组计算每个组的平均值,并为每组的每行添加 “组内平均值” 列。
    需指定输出 DataFrame 的 Schema,函数输入为分组后的 Pandas DataFrame,输出为处理后的 Pandas DataFrame。
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, FloatType# 1. 构造带分组列的DataFrame
df_group = spark.createDataFrame(data=[("A", 10), ("A", 20), ("A", 30), ("B", 40), ("B", 50)],schema=["group_id", "value"]
)# 2. 定义输出Schema(必须与函数返回的DataFrame结构一致)
output_schema = StructType([StructField("group_id", StringType(), nullable=True),StructField("value", IntegerType(), nullable=True),StructField("group_avg", FloatType(), nullable=True)  # 新增的组内平均值列
])# 3. 定义Grouped Map Pandas UDF
@pandas_udf(output_schema, functionType="pandas_udf_grouped_map")
def add_group_average(df: pd.DataFrame) -> pd.DataFrame:# df为单个分组的Pandas DataFramegroup_avg = df["value"].mean()  # 计算组内平均值df["group_avg"] = group_avg    # 为该组所有行添加平均值列return df# 4. 按group_id分组后应用UDF
df_group_result = df_group.groupby("group_id").apply(add_group_average)df_group_result.show()

输出:

+--------+-----+---------+
|group_id|value|group_avg|
+--------+-----+---------+
|       A|   10|     20.0|
|       A|   20|     20.0|
|       A|   30|     20.0|
|       B|   40|     45.0|
|       B|   50|     45.0|
+--------+-----+---------+

四、UDF 关键注意事项
优先使用内置函数:
Spark 内置函数(如pyspark.sql.functions.lower()、sum())是 C++/Scala 优化后的原生函数,性能远高于自定义 UDF。仅当内置函数无法满足需求时才用 UDF。
明确数据类型:
UDF 的returnType必须与实际返回值类型一致(如返回字符串需指定StringType,返回数组需指定ArrayType(IntegerType())),否则会抛出TypeError。
处理 Null 值:
Python 函数中若未显式处理null(如if x is None),UDF 会默认返回null,可能导致业务逻辑错误(如计算总和时忽略null而非视为 0)。
避免副作用:
UDF 中不可修改外部变量(如全局列表、字典),因 Executor 会并行执行 UDF,修改外部变量会导致数据不一致(如多 Executor 同时写入同一列表)。
调试技巧:
小规模测试:先用df.limit(10)取少量数据测试 UDF,避免全量数据报错。
捕获异常:在 Python 函数中用try-except捕获错误,便于定位问题(如def func(x): try: … except Exception as e: return str(e))。
SQL 中使用 UDF:
通过spark.udf.register()将 UDF 注册为 SQL 函数,可在 Spark SQL 中直接调用:

# 注册UDF为SQL函数(函数名:to_lower_sql)
spark.udf.register("to_lower_sql", to_lower_case, StringType())
# 在SQL中使用
df.createOrReplaceTempView("text_table")
spark.sql("SELECT original_text, to_lower_sql(original_text) AS lower_text FROM text_table").show()

总结
普通 UDF:适用于简单逻辑、小规模数据,实现快但性能低。
Pandas UDF:适用于大规模数据、复杂计算,基于 Arrow 批量处理,性能最优。
核心原则:能不用 UDF 就不用(优先内置函数),必须用则优先选 Pandas UDF。


文章转载自:

http://CnzMHoXi.fbxdp.cn
http://S9TbEvNx.fbxdp.cn
http://ATDOXAi4.fbxdp.cn
http://OSXxM7RQ.fbxdp.cn
http://nO6qXXQR.fbxdp.cn
http://gXum88sD.fbxdp.cn
http://a3LaBIF0.fbxdp.cn
http://2sOaoBFo.fbxdp.cn
http://Skk3MttY.fbxdp.cn
http://SvdzipDY.fbxdp.cn
http://VtfCqj0J.fbxdp.cn
http://d2Sya4Fm.fbxdp.cn
http://mbh1aXFt.fbxdp.cn
http://dUplMcrb.fbxdp.cn
http://6WI9oMJG.fbxdp.cn
http://OeKd5SjJ.fbxdp.cn
http://AcuJJ77R.fbxdp.cn
http://xRP8A0co.fbxdp.cn
http://THFkn6ER.fbxdp.cn
http://s5ujVQuJ.fbxdp.cn
http://TmUcfxn0.fbxdp.cn
http://ZGyD2p0C.fbxdp.cn
http://ltC1mGm5.fbxdp.cn
http://5gulkOqo.fbxdp.cn
http://ZTfzDTOT.fbxdp.cn
http://VrlDJ7hW.fbxdp.cn
http://mE7fHvzl.fbxdp.cn
http://EWNefNIb.fbxdp.cn
http://KkWNfrlk.fbxdp.cn
http://CvPozOpZ.fbxdp.cn
http://www.dtcms.com/a/386481.html

相关文章:

  • SpringBoot MySQL
  • 【GOTO判断素数输出孪生10对】2022-11-14
  • 【STL库】哈希表的原理 | 哈希表模拟实现
  • A股大盘数据-20250916分析
  • mysql 获取时间段之间的差值
  • 系统间文件复制文档
  • Vtaskdelay任务阻塞深入了解
  • 智慧城市与“一网统管”:重塑未来城市治理新范式
  • 消息队列kafka的事务特性
  • Python 抓包教程 Python 抓包工具推荐、HTTPS 抓包方法与 iOS 抓包实践全攻略
  • SVN 安装及常用命令
  • 服务器硬盘管理与 RAID 维护完全指南
  • 【Java后端】Spring 如何解决循环依赖:原理 + 源码解读
  • 进程之间的通信(共享内存 + 其他IPC原理)
  • AI 提示词学习笔记
  • PHP通过命令行调用Ghostscript把pdf转换成图片集
  • AWS 弹性伸缩(Auto Scaling)详解:服务器如何自动顶住流量洪峰?
  • 企业级AI应用落地实战(一):落地历程分享
  • 主数据管理:标准化缺失的潜在三大风险
  • LLC--开关损耗及软开关
  • 计算机视觉 - 对比学习(下)不用负样本 BYOL + SimSiam 融合Transformer MoCo-v3 + DINO
  • 内存与网络的字节序:大端 vs 小端
  • Linux网络:网络基础
  • [视图功能3] 排序与分组在业务数据分析中的应用
  • 架构师成长之路-集群
  • 《WINDOWS 环境下32位汇编语言程序设计》学习17章 PE文件(1)
  • cursor中配置qwen3-coder模型使用
  • 智慧健康驿站:AI与IoT赋能下的健康社区建设新引擎
  • 贪心算法应用:MEC任务卸载问题详解
  • Linux基础知识-安装jdk8与jmeter