pyspark自定义udf函数
在 PySpark 中,UDF(User-Defined Function,用户自定义函数) 是扩展 Spark 功能的核心工具,用于处理内置函数(如pyspark.sql.functions中的函数)无法覆盖的自定义逻辑(如复杂字符串处理、自定义数值计算、多列联动计算等)。
Spark 是强类型计算引擎,因此自定义 UDF 必须明确声明输入 / 输出数据类型,且需避免行式处理的性能瓶颈(推荐使用基于 Arrow 的 Pandas UDF 优化)。本文将从基础概念、普通 UDF、高性能 Pandas UDF、注意事项四个维度全面讲解 PySpark UDF 的使用。
一、UDF 核心概念
作用:对 DataFrame 的列数据进行自定义转换,补充 Spark SQL 内置函数的不足(如自定义正则匹配、业务逻辑计算等)。
依赖:
导入 UDF 装饰器 / 函数:from pyspark.sql.functions import udf
导入数据类型(声明返回类型):from pyspark.sql.types import StringType, IntegerType, ArrayType, StructType等
执行原理:
普通 Python UDF:基于行式处理,需在 Python 解释器与 Spark JVM 之间进行序列化 / 反序列化(存在 GIL 锁瓶颈,数据量大时效率低)。
Pandas UDF(Vectorized UDF):基于Apache Arrow批量处理数据,直接操作 Pandas Series/DataFrame,性能比普通 UDF 提升 5-10 倍。
二、普通 UDF 的创建与使用
普通 UDF 适用于简单逻辑或小规模数据,创建步骤分为:定义 Python 函数 → 包装为 UDF → 应用到 DataFrame。
- 基础示例:字符串转换(无参数 UDF)
需求:将 DataFrame 的字符串列转为小写,并处理null值。
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType# 1. 初始化SparkSession
spark = SparkSession.builder.appName("BasicUDF").getOrCreate()# 2. 构造测试数据
df = spark.createDataFrame(data=[("Hello PySpark",), ("UDF Example",), (None,)], # 包含null值schema=["original_text"]
)# 3. 定义Python自定义函数(需处理null,否则返回null)
def to_lower_case(s: str) -> str:if s is None: # 显式处理null,避免逻辑错误return "empty"return s.lower()# 4. 包装为PySpark UDF(指定返回类型为StringType)
# 方式1:用udf()函数包装
to_lower_udf = udf(to_lower_case, returnType=StringType())
# 方式2:用@udf装饰器(更简洁)
# @udf(returnType=StringType())
# def to_lower_case(s: str) -> str: ...# 5. 应用UDF到DataFrame(用withColumn/select)
df_result = df.withColumn("lower_text", # 新列名to_lower_udf("original_text") # 对original_text列应用UDF
)# 查看结果
df_result.show(truncate=False)
输出:
+----------------+----------------+
|original_text |lower_text |
+----------------+----------------+
|Hello PySpark |hello pyspark |
|UDF Example |udf example |
|null |empty |
+----------------+----------------+
- 带参数的 UDF
需求:给数值列的每个元素加上一个固定参数值(如给所有数值加 5)。
需用functools.partial传递固定参数,或用 lambda 表达式包装。
from pyspark.sql.functions import col
from pyspark.sql.types import IntegerType
from functools import partial# 1. 构造测试数据
df_num = spark.createDataFrame(data=[(10,), (20,), (30,), (None,)],schema=["num"]
)# 2. 定义带参数的Python函数
def add_fixed_value(x: int, fixed: int) -> int:if x is None:return 0return x + fixed# 3. 传递固定参数(如fixed=5),包装为UDF
# 方式1:用partial固定参数
add_5_udf = udf(partial(add_fixed_value, fixed=5), IntegerType())
# 方式2:用lambda包装(适合简单参数)
# add_5_udf = udf(lambda x: add_fixed_value(x, 5), IntegerType())# 4. 应用UDF
df_num_result = df_num.withColumn("num_add_5",add_5_udf(col("num")) # 仅传入DataFrame列作为第一个参数
)df_num_result.show()
输出:
+----+----------+
| num|num_add_5 |
+----+----------+
| 10| 15|
| 20| 25|
| 30| 35|
|null| 0|
+----+----------+
- 处理复杂类型的 UDF(Array/Map/Struct)
Spark 支持复杂数据类型(如ArrayType数组、MapType字典、StructType结构体),UDF 可直接处理这些类型。
示例:计算数组列的总和
from pyspark.sql.types import ArrayType, IntegerType# 1. 构造含数组列的DataFrame
df_array = spark.createDataFrame(data=[([1,2,3],), ([4,5,6,7],), (None,), ([])], # 空数组、null数组schema=["numbers"]
)# 2. 定义处理数组的函数
def calculate_array_sum(arr: list) -> int:if arr is None or len(arr) == 0: # 处理null和空数组return 0return sum(arr)# 3. 包装UDF(输入为ArrayType,输出为IntegerType)
array_sum_udf = udf(calculate_array_sum, IntegerType())# 4. 应用UDF
df_array_result = df_array.withColumn("array_total",array_sum_udf("numbers")
)df_array_result.show()
输出:
+------------+-----------+
| numbers|array_total|
+------------+-----------+
| [1, 2, 3]| 6|
|[4, 5, 6, 7]| 22|
| null| 0|
| []| 0|
+------------+-----------+
三、高性能 Pandas UDF(Vectorized UDF)
普通 Python UDF 因行式处理效率低,数据量超过 100 万行时强烈推荐使用 Pandas UDF。其基于 Apache Arrow 实现数据零拷贝传输,批量处理 Pandas Series/DataFrame,性能显著提升。
Pandas UDF 分为两类:
Scalar Pandas UDF:输入 / 输出为标量(Pandas Series),对应普通 UDF 的批量版本。
Grouped Map Pandas UDF:分组处理(输入为分组后的 Pandas DataFrame,输出为新 DataFrame)。
- 基础:Scalar Pandas UDF
需求:计算数值列的平方(批量处理)。
需用@pandas_udf装饰器,并指定返回类型,函数参数和返回值均为 Pandas Series。
from pyspark.sql.functions import pandas_udf
import pandas as pd# 1. 开启Arrow加速(Spark 3.0+默认开启,低版本需显式设置)
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")# 2. 构造测试数据
df_pd = spark.createDataFrame(data=[(1,), (2,), (3,), (4,), (5,)],schema=["num"]
)# 3. 定义Scalar Pandas UDF(用@pandas_udf装饰,指定返回类型)
@pandas_udf(IntegerType()) # 括号内为返回类型
def square_pd(s: pd.Series) -> pd.Series:# 直接用Pandas Series的矢量化操作,无需循环return s ** 2# 4. 应用UDF(用法与普通UDF一致)
df_pd_result = df_pd.withColumn("num_square",square_pd("num")
)df_pd_result.show()
输出:
+---+----------+
|num|num_square|
+---+----------+
| 1| 1|
| 2| 4|
| 3| 9|
| 4| 16|
| 5| 25|
+---+----------+
- Grouped Map Pandas UDF
需求:按分组计算每个组的平均值,并为每组的每行添加 “组内平均值” 列。
需指定输出 DataFrame 的 Schema,函数输入为分组后的 Pandas DataFrame,输出为处理后的 Pandas DataFrame。
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, FloatType# 1. 构造带分组列的DataFrame
df_group = spark.createDataFrame(data=[("A", 10), ("A", 20), ("A", 30), ("B", 40), ("B", 50)],schema=["group_id", "value"]
)# 2. 定义输出Schema(必须与函数返回的DataFrame结构一致)
output_schema = StructType([StructField("group_id", StringType(), nullable=True),StructField("value", IntegerType(), nullable=True),StructField("group_avg", FloatType(), nullable=True) # 新增的组内平均值列
])# 3. 定义Grouped Map Pandas UDF
@pandas_udf(output_schema, functionType="pandas_udf_grouped_map")
def add_group_average(df: pd.DataFrame) -> pd.DataFrame:# df为单个分组的Pandas DataFramegroup_avg = df["value"].mean() # 计算组内平均值df["group_avg"] = group_avg # 为该组所有行添加平均值列return df# 4. 按group_id分组后应用UDF
df_group_result = df_group.groupby("group_id").apply(add_group_average)df_group_result.show()
输出:
+--------+-----+---------+
|group_id|value|group_avg|
+--------+-----+---------+
| A| 10| 20.0|
| A| 20| 20.0|
| A| 30| 20.0|
| B| 40| 45.0|
| B| 50| 45.0|
+--------+-----+---------+
四、UDF 关键注意事项
优先使用内置函数:
Spark 内置函数(如pyspark.sql.functions.lower()、sum())是 C++/Scala 优化后的原生函数,性能远高于自定义 UDF。仅当内置函数无法满足需求时才用 UDF。
明确数据类型:
UDF 的returnType必须与实际返回值类型一致(如返回字符串需指定StringType,返回数组需指定ArrayType(IntegerType())),否则会抛出TypeError。
处理 Null 值:
Python 函数中若未显式处理null(如if x is None),UDF 会默认返回null,可能导致业务逻辑错误(如计算总和时忽略null而非视为 0)。
避免副作用:
UDF 中不可修改外部变量(如全局列表、字典),因 Executor 会并行执行 UDF,修改外部变量会导致数据不一致(如多 Executor 同时写入同一列表)。
调试技巧:
小规模测试:先用df.limit(10)取少量数据测试 UDF,避免全量数据报错。
捕获异常:在 Python 函数中用try-except捕获错误,便于定位问题(如def func(x): try: … except Exception as e: return str(e))。
SQL 中使用 UDF:
通过spark.udf.register()将 UDF 注册为 SQL 函数,可在 Spark SQL 中直接调用:
# 注册UDF为SQL函数(函数名:to_lower_sql)
spark.udf.register("to_lower_sql", to_lower_case, StringType())
# 在SQL中使用
df.createOrReplaceTempView("text_table")
spark.sql("SELECT original_text, to_lower_sql(original_text) AS lower_text FROM text_table").show()
总结
普通 UDF:适用于简单逻辑、小规模数据,实现快但性能低。
Pandas UDF:适用于大规模数据、复杂计算,基于 Arrow 批量处理,性能最优。
核心原则:能不用 UDF 就不用(优先内置函数),必须用则优先选 Pandas UDF。