当前位置: 首页 > news >正文

生产级编排AI工作流套件:Flyte全面使用指南 — Data input/output

生产级编排AI工作流套件:Flyte全面使用指南 — Data input/output

Flyte 是一个开源编排器,用于构建生产级数据和机器学习流水线。它以 Kubernetes 作为底层平台,注重可扩展性和可重复性。借助 Flyte,用户团队可以使用 Python SDK 构建流水线,并将其无缝部署在云端和本地环境中,从而实现分布式处理和高效的资源利用。

文中内容仅限技术学习与代码实践参考,市场存在不确定性,技术分析需谨慎验证,不构成任何投资建议。

Flyte

数据输入/输出

Flyte作为一个数据感知的编排平台,类型在其中起着至关重要的作用。本节将介绍Flyte支持的广泛数据类型。这些类型具有双重作用:不仅用于数据验证,还能实现本地存储与云存储之间的无缝数据传输。它们支持以下功能:

  • 数据溯源
  • 缓存记忆
  • 自动并行化
  • 简化数据访问
  • 自动生成的CLI和启动界面

如需更深入理解Flyte如何管理数据,请参阅理解Flyte数据处理机制

Python到Flyte类型的映射

Flytekit会自动将多数Python类型转换为Flyte类型。以下是详细的映射关系:

Python 类型Flyte 类型转换方式说明
intInteger自动转换使用Python 3类型提示
floatFloat自动转换使用Python 3类型提示
strString自动转换使用Python 3类型提示
boolBoolean自动转换使用Python 3类型提示
bytes/bytearrayBinary不支持可选择使用自定义类型转换器
complex不支持不支持可选择使用自定义类型转换器
datetime.timedeltaDuration自动转换使用Python 3类型提示
datetime.datetimeDatetime自动转换使用Python 3类型提示
datetime.dateDatetime自动转换使用Python 3类型提示
typing.List[T] / list[T]Collection [T]自动转换使用typing.List[T]list[T],其中T可以是表中列出的其他支持类型
typing.Iterator[T]Collection [T]自动转换使用typing.Iterator[T],其中T可以是表中列出的其他支持类型
File / file-like / os.PathLikeFlyteFile自动转换使用fileos.PathLike对象时,Flyte默认使用二进制协议。当使用FlyteFile["protocol"]时,假设文件属于指定协议(如’jpg’,‘png’,'hdf5’等)
DirectoryFlyteDirectory自动转换使用FlyteDirectory["protocol"]时,假设所有文件属于指定协议
typing.Dict[str, V] / dict[str, V]Map[str, V]自动转换使用typing.Dict[str, V]dict[str, V],其中V可以是表中列出的其他支持类型(包括嵌套字典)
dictJSON (struct.pb)自动转换使用未类型化的字典时,假设可转换为JSON格式。但可能无法转换并导致RuntimeError
@dataclassStruct自动转换类必须是用@dataclass装饰器注解的纯值类
np.ndarrayFile自动转换使用np.ndarray作为类型提示
pandas.DataFrameStructured Dataset自动转换使用pandas.DataFrame作为类型提示,不保留列类型信息
polars.DataFrameStructured Dataset自动转换使用polars.DataFrame作为类型提示,不保留列类型信息
polars.LazyFrameStructured Dataset自动转换使用polars.LazyFrame作为类型提示,不保留列类型信息
pyspark.DataFrameStructured Dataset需安装flytekitplugins-spark插件使用pyspark.DataFrame作为类型提示
pydantic.BaseModelMap需安装pydantic模块使用pydantic.BaseModel作为类型提示
torch.Tensor / torch.nn.ModuleFile需安装torch使用torch.Tensortorch.nn.Module及其派生类型作为类型提示
tf.keras.ModelFile需安装tensorflow使用tf.keras.Model及其派生类型
sklearn.base.BaseEstimatorFile需安装scikit-learn使用sklearn.base.BaseEstimator及其派生类型
User defined types任意类型自定义转换器默认使用FlytePickle转换器,也可定义自定义转换器。构建自定义类型转换器请参考此章节

FlyteFile 与 FlyteDirectory

FlyteFile

文件是 Python 用户最常使用的基础实体之一,Flyte 对其提供了完整支持。在 IDL 中,它们被称为 Blob 字面量,由 blob 类型 提供支持。

假设我们的任务很简单:下载几个 CSV 文件链接,使用 Python 内置的 csv.DictReader 函数读取,对预设列进行标准化处理,最后将标准化列输出到新的 CSV 文件。

要克隆并运行本页示例代码,请访问 Flytesnacks 仓库。

首先导入库:

import csv
from collections import defaultdict
from pathlib import Path
from typing import Listimport flytekit as fl

定义接受 FlyteFile 作为输入的任务。以下任务接收 FlyteFile、列名列表和待标准化列列表,最终输出包含标准化列的 CSV 文件。本例使用 z-score 标准化(均值中心化与标准差缩放):

@fl.task
def normalize_columns(csv_url: fl.FlyteFile,column_names: List[str],columns_to_normalize: List[str],output_location: str,
) -> fl.FlyteFile:# 从原始 CSV 文件读取数据parsed_data = defaultdict(list)with open(csv_url, newline="\n") as input_file:reader = csv.DictReader(input_file, fieldnames=column_names)next(reader)  # 跳过表头for row in reader:for column in columns_to_normalize:parsed_data[column].append(float(row[column].strip()))# 标准化数据normalized_data = defaultdict(list)for colname, values in parsed_data.items():mean = sum(values) / len(values)std = (sum([(x - mean) ** 2 for x in values]) / len(values)) ** 0.5normalized_data[colname] = [(x - mean) / std for x in values]# 写入本地路径out_path = str(Path(fl.current_context().working_directory) / f"normalized-{Path(csv_url.path).stem}.csv")with open(out_path, mode="w") as output_file:writer = csv.DictWriter(output_file, fieldnames=columns_to_normalize)writer.writeheader()for row in zip(*normalized_data.values()):writer.writerow({k: row[i] for i, k in enumerate(columns_to_normalize)})if output_location:return fl.FlyteFile(path=str(out_path), remote_path=output_location)else:return fl.FlyteFile(path=str(out_path))

FlyteFile 字面量可使用字符串限定范围(插入到 Blob 类型的格式中)。格式参数完全可选,未指定时默认为 ""。常用文件格式的预定义别名可在此处 查看。

当图像 URL 发送到任务时,系统会将其转换为本地驱动器上的 FlyteFile 对象(但不会立即下载)。调用 download() 方法会触发下载,path 属性可用于 open 文件。

如果指定了 output_location 参数,它将传递给 FlyteFileremote_path 参数,使用该路径作为存储位置而非随机位置(Flyte 对象存储)。

任务完成后,系统返回 FlyteFile 实例,将文件上传到指定位置,并创建指向该位置的 blob 字面量。

最后定义工作流。normalize_csv_files 工作流的 output_location 参数传递给任务的 location 输入。若非空字符串,任务将尝试将文件上传到该位置:

@fl.workflow
def normalize_csv_file(csv_url: fl.FlyteFile,column_names: List[str],columns_to_normalize: List[str],output_location: str = "",
) -> fl.FlyteFile:return normalize_columns(csv_url=csv_url,column_names=column_names,columns_to_normalize=columns_to_normalize,output_location=output_location,)

本地运行工作流:

if __name__ == "__main__":default_files = [("https://raw.githubusercontent.com/flyteorg/flytesnacks/refs/heads/master/examples/data_types_and_io/test_data/biostats.csv",["Name", "Sex", "Age", "Heights (in)", "Weight (lbs)"],["Age"],),("https://raw.githubusercontent.com/flyteorg/flytesnacks/refs/heads/master/examples/data_types_and_io/test_data/faithful.csv",["Index", "Eruption length (mins)", "Eruption wait (mins)"],["Eruption length (mins)"],),]print(f"Running {__file__} main...")for index, (csv_url, column_names, columns_to_normalize) in enumerate(default_files):normalized_columns = normalize_csv_file(csv_url=csv_url,column_names=column_names,columns_to_normalize=columns_to_normalize,)print(f"Running normalize_csv_file workflow on {csv_url}: {normalized_columns}")

安装 python-magic 包后可启用类型验证:

Mac OS

brew install libmagic

Linux

sudo apt-get install libmagic1

当前类型验证仅支持 Mac OSLinux 平台。

流式支持

Flyte 1.5 通过 fsspec 库引入对 FlyteFile 类型的流式支持。该集成实现了对远程文件的高效按需访问,无需完全下载到本地存储。

此功能标记为实验性。我们期待您对 API 的反馈!(@Peeter 此处应提供反馈链接?)

以下示例演示从 CSV 文件中删除部分行并写入新文件:

@fl.task()
def remove_some_rows(ff: fl.FlyteFile) -> fl.FlyteFile:"""删除 city 列为 'Seattle' 的行这是流式支持的示例"""new_file = fl.FlyteFile.new_remote_file("data_without_seattle.csv")with ff.open("r") as r:with new_file.open("w") as w:df = pd.read_csv(r)df = df[df["City"] != "Seattle"]df.to_csv(w, index=False)

FlyteDirectory

除文件外,文件夹是另一个基础操作系统原语。Flyte 以 多部分 blob 形式支持文件夹。

要克隆并运行本页示例代码,请访问 Flytesnacks 仓库。

首先导入库:

import csv
import urllib.request
from collections import defaultdict
from pathlib import Path
from typing import Listimport flytekit as fl

基于 FlyteFile 章节 的示例,继续考虑 CSV 文件列标准化。

以下任务下载 CSV 文件 URL 列表,并以 FlyteDirectory 对象返回文件夹路径:

@fl.task
def download_files(csv_urls: List[str]) -> union.FlyteDirectory:working_dir = fl.current_context().working_directorylocal_dir = Path(working_dir) / "csv_files"local_dir.mkdir(exist_ok=True)# 计算保留文件顺序所需的填充位数zfill_len = len(str(len(csv_urls)))for idx, remote_location in enumerate(csv_urls):# 在文件名前添加原列表中的索引位置local_image = Path(local_dir) / f"{str(idx).zfill(zfill_len)}_{Path(remote_location).name}"urllib.request.urlretrieve(remote_location, local_image)return fl.FlyteDirectory(path=str(local_dir))

当需要分批下载/上传目录内容时,可使用 FlyteDirectory 注解:

@fl.task
def t1(directory: Annotated[fl.FlyteDirectory, BatchSize(10)]) -> Annotated[fl.FlyteDirectory, BatchSize(100)]:...return fl.FlyteDirectory(...)# Flytekit 将以 10 文件块高效下载输入目录文件
# 加载到内存后写入本地磁盘
# 输出目录以 100 文件块上传

定义原地标准化列的辅助函数:

def normalize_columns(local_csv_file: str,column_names: List[str],columns_to_normalize: List[str],
):# 从原始 CSV 文件读取数据parsed_data = defaultdict(list)with open(local_csv_file, newline="\n") as input_file:reader = csv.DictReader(input_file, fieldnames=column_names)for row in (x for i, x in enumerate(reader) if i > 0):for column in columns_to_normalize:parsed_data[column].append(float(row[column].strip()))# 标准化数据normalized_data = defaultdict(list)for colname, values in parsed_data.items():mean = sum(values) / len(values)std = (sum([(x - mean) ** 2 for x in values]) / len(values)) ** 0.5normalized_data[colname] = [(x - mean) / std for x in values]# 用标准化列覆盖原 CSV 文件with open(local_csv_file, mode="w") as output_file:writer = csv.DictRow(output_file, fieldnames=columns_to_normalize)writer.writeheader()for row in zip(*normalized_data.values()):writer.writerow({k: row[i] for i, k in enumerate(columns_to_normalize)})

定义任务接收下载目录及元数据:

@fl.task
def normalize_all_files(csv_files_dir: fl.FlyteDirectory,columns_metadata: List[List[str]],columns_to_normalize_metadata: List[List[str]],
) -> union.FlyteDirectory:for local_csv_file, column_names, columns_to_normalize in zip(# 排序目录文件以保持原始 URL 顺序list(sorted(Path(csv_files_dir).iterdir())),columns_metadata,columns_to_normalize_metadata,):normalize_columns(local_csv_file, column_names, columns_to_normalize)return fl.FlyteDirectory(path=csv_files_dir.path)

组合工作流:

@fl.workflow
def download_and_normalize_csv_files(csv_urls: List[str],columns_metadata: List[List[str]],columns_to_normalize_metadata: List[List[str]],
) -> fl.FlyteDirectory:directory = download_files(csv_urls=csv_urls)return normalize_all_files(csv_files_dir=directory,columns_metadata=columns_metadata,columns_to_normalize_metadata=columns_to_normalize_metadata,)

本地运行工作流:

if __name__ == "__main__":csv_urls = ["https://raw.githubusercontent.com/flyteorg/flytesnacks/refs/heads/master/examples/data_types_and_io/test_data/biostats.csv","https://raw.githubusercontent.com/flyteorg/flytesnacks/refs/heads/master/examples/data_types_and_io/test_data/faithful.csv",]columns_metadata = [["Name", "Sex", "Age", "Heights (in)", "Weight (lbs)"],["Index", "Eruption length (mins)", "Eruption wait (mins)"],]columns_to_normalize_metadata = [["Age"],["Eruption length (mins)"],]print(f"Running {__file__} main...")directory = download_and_normalize_csv_files(csv_urls=csv_urls,columns_metadata=columns_metadata,columns_to_normalize_metadata=columns_to_normalize_metadata,)print(f"Running download_and_normalize_csv_files on {csv_urls}: {directory}")

修改数据上传位置

上传位置

使用 Flyte Serverless 时,FlyteFileFlyteDirectory 的容器本地文件始终上传到 Flyte 内部对象存储的随机生成(全局唯一)位置,不可更改。

使用 Flyte BYOC 时,上传位置可配置。

默认情况下,Flyte 将本地文件/目录上传到默认的 原始数据存储(Flyte 专用内部对象存储)。可通过设置原始数据前缀或指定 remote_path 修改上传位置。

设置自有对象存储桶

配置自有对象存储桶请参考云服务商指南:

  • 启用 AWS S3
  • 启用 Google Cloud Storage
  • 启用 Azure Blob Storage

修改原始数据前缀

若要将文件/目录上传到自有存储桶,可在注册时设置工作流级别的原始数据前缀参数,或在命令行/UI 中按执行指定。Flyte 将在您的存储桶中创建具有唯一随机名称的目录,确保数据不会被覆盖。

指定 remote_path

初始化 FlyteFileFlyteDirectory 时若指定 remote_path,数据将直接写入该路径,不进行随机化。

使用 remote_path 会覆盖数据

若将 remote_path 设为静态字符串,相同任务的后续运行将覆盖文件。如需使用动态生成路径,需自行生成。

远程示例

远程文件示例

上述示例从本地文件开始。要在任务边界保留文件,Flyte 会先将其上传到对象存储。也可直接使用远程文件:

@fl.task
def task_1() -> fl.FlyteFile:remote_path = "https://people.sc.fsu.edu/~jburkardt/data/csv/biostats.csv"return fl.FlyteFile(path=remote_path)

此时无需上传操作,源文件已在远程位置。传递对象时转换为带远程 URI 的 Blob。后续任务中可照常调用 FlyteFile.open()

若不需要传递文件,只需在任务内打开远程文件内容,可使用 from_source

@fl.task
def load_json():uri = "gs://my-bucket/my-directory/example.json"my_json = FlyteFile.from_source(uri)# 加载 JSON 文件并打印with open(my_json, "r") as json_file:data = json.load(json_file)print(data)

初始化 FlyteFile 时支持 fsspec 的所有 URI 方案,包括 httphttps(网页)、gs(Google 云存储)、s3(AWS S3)、abfsabfss(Azure Blob 文件系统)。

远程目录示例

@fl.task
def task1() -> fl.FlyteDirectory:p = "https://people.sc.fsu.edu/~jburkardt/data/csv/"return fl.FlyteDirectory(p)@fl.task
def task2(fd: fl.FlyteDirectory):# 获取目录内容列表并显示第一个 CSVfiles = fl.FlyteDirectory.listdir(fd)with open(files[0], mode="r") as f:d = f.read()print(f"第一个 CSV 内容:\n{d}")@fl.workflow
def workflow():fd = task1()task2(fd=fd)

流式处理

上述示例通过 FlyteFile.open() 访问文件内容,返回流对象。对于大文件可迭代处理:

@fl.task
def task_1() -> fl.FlyteFile:remote_path = "https://sample-videos.com/csv/Sample-Spreadsheet-100000-rows.csv"return fl.FlyteFile(path=remote_path)@fl.task
def task_2(ff: fl.FlyteFile):with ff.open(mode="r") as f:for row in f:do_something(row)

下载

可通过 隐式显式 两种方式下载文件。

隐式下载

当调用需要文件路径的函数时,FlyteFile 会自动下载到容器本地文件系统。FlyteFile 实现 os.PathLike 接口,其 __fspath__() 方法执行下载操作并返回本地路径。

典型示例是使用 Python 内置 open()

@fl.task
def task_2(ff: fl.FlyteFile):with open(ff, mode="r") as f:file_contents = f.read()

open() vs ff.open()

注意区别:

  • ff.open(mode="r"):返回迭代器,不下载文件
  • open(ff, mode="r"):调用内置函数,下载文件并返回文件句柄

更多信息参考 使用 FlyteFile 和 FlyteDirectory 下载。

显式下载

调用 FlyteFile.download() 显式下载:

@fl.task
def task_2(ff: fl.FlyteFile):local_path = ff.download()

此方法适用于需要下载文件但不立即读取的场景。

类型别名

Flytekit SDK 定义了一些特定类型的 FlyteFile 别名:

  • HDF5EncodedFile
  • HTMLPage
  • JoblibSerializedFile
  • JPEGImageFile
  • PDFFile
  • PNGImageFile
  • PythonPickledFile
  • PythonNotebook
  • SVGImageFile

类似地,FlyteDirectory 有以下别名:

  • TensorboardLogs
  • TFRecordsDirectory

这些别名可作为类型标记用于任务函数签名,但不会实际验证文件内容。更多信息参考 FlyteFile 别名 和 FlyteDirectory 别名。

使用 FlyteFile 和 FlyteDirectory 进行下载

FlyteFileFlyteDirectory 的核心概念是它们代表远程存储中的文件和目录。在任务中操作这些对象时,实际上是在操作远程文件和目录的引用。

当需要访问这些文件和目录的实际内容时,必须将其下载到任务容器的本地文件系统中。FlyteFileFlyteDirectory 的实际内容会通过两种方式下载到任务容器的本地文件系统:

  • 显式下载:通过调用 download 方法
  • 隐式下载:通过自动下载机制。当对 FlyteFileFlyteDirectory 调用外部函数(其内部会调用 __fspath__ 方法)时触发

理解下载发生的具体时机对于编写高效的任务和工作流代码至关重要。以下示例将展示 FlyteFileFlyteDirectory 对象内容下载到任务容器本地文件系统的具体场景。

FlyteFile

通过调用 download 方法显式下载

@fl.task
def my_task(ff: FlyteFile):print(os.path.isfile(ff.path))  # 将输出False,因为尚未下载任何内容ff.download()print(os.path.isfile(ff.path))  # 将输出True,因为FlyteFile已下载

注意我们使用 ff.path(类型为 typing.Union[str, os.PathLike])而非直接使用 ff。下例将展示直接使用 os.path.isfile(ff) 会触发 __fspath__ 方法导致隐式下载。

通过 __fspath__ 隐式下载

为支持类似 os.path.isfile 等常规文件路径操作,FlyteFile 实现了 __fspath__ 方法,该方法会将远程内容下载到容器本地的 path 路径。

@fl.task
def my_task(ff: FlyteFile):print(os.path.isfile(ff.path))  # 输出False(未下载)print(os.path.isfile(ff))  # 输出True(通过__fspath__触发下载)print(os.path.isfile(ff.path))  # 再次输出True(文件已下载)

需特别注意可能触发 __fspath__ 导致下载的操作,例如:直接在 FlyteFile 上调用 open(ff, mode="r") 获取路径内容(而非使用 path 属性),或直接对 FlyteFile 使用 shutil.copypathlib.Path 等操作。

FlyteDirectory

通过调用 download 方法显式下载

@fl.task
def my_task(fd: FlyteDirectory):print(os.listdir(fd.path))  # 输出空列表(目录未下载)fd.download()print(os.listdir(fd.path))  # 输出目录中的文件列表(已下载)

FlyteFile 类似,此处使用 fd.path(类型为 typing.Union[str, os.PathLike])而非直接使用 fd。直接调用 os.listdir(fd) 会触发 __fspath__ 方法。

通过 __fspath__ 隐式下载

为支持类似 os.listdir 等目录操作,FlyteDirectory 实现了 __fspath__ 方法,该方法会将远程内容下载到容器本地的 path 路径。

@fl.task
def my_task(fd: FlyteDirectory):print(os.listdir(fd.path))  # 输出空列表(未下载)print(os.listdir(fd))  # 输出文件列表(通过__fspath__触发下载)print(os.listdir(fd.path))  # 再次输出文件列表(已下载)

需特别注意可能触发 __fspath__ 导致下载的操作,例如:直接在 FlyteDirectory 上调用 os.stat 获取路径状态(而非使用 path 属性),或使用 os.path.isdir 检查目录是否存在。

使用 crawl 方法免下载检查目录内容

如前所述,在 FlyteDirectory 上使用 os.listdir 查看远程Blob存储内容会触发下载。若需避免下载,可使用 crawl 方法检查目录内容(不调用 __fspath__)。

@fl.task
def task1() -> FlyteDirectory:p = os.path.join(current_context().working_directory, "my_new_directory")os.makedirs(p)# 创建并写入两个文件with open(os.path.join(p, "file_1.txt"), 'w') as file1:file1.write("This is file 1.")with open(os.path.join(p, "file_2.txt"), 'w') as file2:file2.write("This is file 2.")return FlyteDirectory(p)@fl.task
def task2(fd: FlyteDirectory):print(os.listdir(fd.path))  # 输出空列表(未下载)print(list(fd.crawl()))  # 输出远程Blob存储中的文件列表# 示例输出:[('s3://union-contoso/ke/fe503def6ebe04fa7bba-n0-0/160e7266dcaffe79df85489771458d80', 'file_1.txt'), ...]print(list(fd.crawl(detail=True)))  # 输出带详细信息的文件列表# 示例输出包含类型、创建时间等信息print(os.listdir(fd.path))  # 仍输出空列表(未触发下载)

任务输入与输出

Flyte 工作流引擎自动管理任务间数据传递及工作流输出。

该机制通过强制要求任务函数参数和返回值的强类型实现。这使得工作流引擎能够高效地在任务容器之间编组(marshall)和解编组(unmarshall)数据值。

实际数据会临时存储在您数据平面内的 Flyte 内部对象存储中(具体取决于云服务提供商,可能是 AWS S3、Google Cloud Storage 或 Azure Blob Storage)。

元数据与原始数据

Flyte 区分元数据和原始数据:

  • 原始值(intstr 等)直接存储在元数据存储中
  • 复杂数据对象(pandas.DataFrameFlyteFile 等)通过引用存储,其中引用指针在元数据存储中,实际数据在原始数据存储中

元数据存储

元数据存储位于您数据平面中的专用 Flyte 对象存储中。根据云服务提供商的不同,可能是 AWS S3、Google Cloud Storage 或 Azure Blob Storage 存储桶。

控制平面可访问此数据。它用于运行和管理工作流,并会在 UI 界面中展示。

原始数据存储

默认情况下,原始数据存储也位于您数据平面中的专用 Flyte 对象存储中。

但可以通过原始数据前缀(raw data prefix)参数,按工作流或按执行覆盖此位置。

控制平面无法访问原始数据存储中的数据,除非您的代码显式展示(例如在 Deck 中)。

更多细节请参阅 理解 Flyte 如何处理数据。

更改原始数据存储位置

有多种方式可以更改原始数据存储位置:

  • 注册工作流时
    • 使用 uctl register 时,添加 --files.outputLocationPrefix 标志
    • 使用 pyflyte register 时,添加 --raw-data-prefix 标志
  • 执行层面
    • 在 UI 界面中,通过执行对话框设置 Raw output data config 参数

这些选项会更改所有大型类型FlyteFileFlyteDirectoryDataFrame 及其他大型数据对象)的原始数据存储位置。

若只需控制 FlyteFileFlyteDirectory 使用的原始数据存储位置,可以在任务代码中初始化这些类型的对象时设置 remote_path 参数。

设置自有对象存储

默认情况下,当 Flyte 在任务间编组值时,会将元数据和原始数据都存储在其专用的对象存储桶中。虽然该存储桶位于您的 Flyte BYOC 数据平面中并由您控制,但它属于 Flyte 实现的一部分,不应被您的任务代码直接访问或修改。

因此,当更改默认原始数据位置时,目标位置应该是您自行设置的、独立于 Flyte 实现存储桶的其他存储桶。

有关如何设置自有存储桶并启用访问权限的信息,请根据云服务提供商参考:

  • 启用 AWS S3
  • 启用 Google Cloud Storage
  • 启用 Azure Blob Storage

访问属性

在 Flyte 中,您可以直接访问列表、字典、数据类及其组合类型输出承诺的属性。需要注意的是,虽然这些功能可能看起来是 Python 的正常行为,但 @workflow 函数中的代码实际上并非 Python,而是一个由 Flyte 编译的类 Python DSL。因此,这种属性访问方式实际上是 Flyte 专门实现的功能。该功能支持在工作流中直接传递输出属性,极大地方便了复杂数据结构的操作。

Flytekit 版本 >= v1.14.0 支持 Pydantic BaseModel V2,您也可以在 Pydantic BaseModel V2 上进行属性访问。

要克隆并运行本页示例代码,请参考 Flytesnacks 代码库。

首先导入所需依赖并定义后续使用的通用任务:

from dataclasses import dataclass
import flytekit as fl@fl.task
def print_message(message: str):print(message)return

列表

可以通过索引访问输出列表。

Flyte 当前不支持通过列表切片访问输出承诺:

@fl.task
def list_task() -> list[str]:return ["apple", "banana"]@fl.workflow
def list_wf():items = list_task()first_item = items[0]print_message(message=first_item)

字典

通过指定键访问输出字典:

@fl.task
def dict_task() -> dict[str, str]:return {"fruit": "banana"}@fl.workflow
def dict_wf():fruit_dict = dict_task()print_message(message=fruit_dict["fruit"])

数据类

直接访问数据类的属性:

@dataclass
class Fruit:name: str@fl.task
def dataclass_task() -> Fruit:return Fruit(name="banana")@fl.workflow
def dataclass_wf():fruit_instance = dataclass_task()print_message(message=fruit_instance.name)

复合类型

列表、字典和数据类的组合类型也能有效工作:

@fl.task
def advance_task() -> (dict[str, list[str]], list[dict[str, str]], dict[str, Fruit]):return {"fruits": ["banana"]}, [{"fruit": "banana"}], {"fruit": Fruit(name="banana")}@fl.task
def print_list(fruits: list[str]):print(fruits)@fl.task
def print_dict(fruit_dict: dict[str, str]):print(fruit_dict)@fl.workflow
def advanced_workflow():dictionary_list, list_dict, dict_dataclass = advance_task()print_message(message=dictionary_list["fruits"][0])print_message(message=list_dict[0]["fruit"])print_message(message=dict_dataclass["fruit"].name)print_list(fruits=dictionary_list["fruits"])print_dict(fruit_dict=list_dict[0])

可以通过以下方式在本地运行所有工作流:

if __name__ == "__main__":list_wf()dict_wf()dataclass_wf()advanced_workflow()

失败场景

以下工作流会因尝试访问越界索引和键而失败:

from flytekit import WorkflowFailurePolicy@fl.task
def failed_task() -> (list[str], dict[str, str], Fruit):return ["apple", "banana"], {"fruit": "banana"}, Fruit(name="banana")@fl.workflow(# 只要其他可执行节点仍可用,当某个节点出错时工作流仍能继续执行failure_policy=WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE
)
def failed_workflow():fruits_list, fruit_dict, fruit_instance = failed_task()print_message(message=fruits_list[100])  # 访问不存在的索引print_message(message=fruit_dict["fruits"])  # 访问不存在的键print_message(message=fruit_instance.fruit)  # 访问不存在的参数

数据类

当需要在 Flyte 实体间传递多个值时,可以使用 dataclass

Flytekit 使用 Mashumaro 库 来实现数据类的序列化和反序列化。

在 1.14 版本中,flytekit 采用 MessagePack 作为数据类的序列化格式,解决了早期版本将数据序列化为 Protobuf struct 中 JSON 字符串的主要限制。

早期版本中,Protobuf 的 struct 会将整型转换为浮点类型,需要用户编写样板代码来规避此问题。

若使用 Flytekit 版本 < v1.11.1,需要导入 from dataclasses_json import dataclass_json 并使用 @dataclass_json 装饰数据类。

Flytekit 版本 < v1.14.0 会为数据类生成 protobuf struct 字面量。

Flytekit 版本 >= v1.14.0 会为数据类生成 msgpack 字节字面量。

若使用 Flytekit 版本 >= v1.14.0 但希望为数据类生成 protobuf struct 字面量,可设置环境变量 FLYTE_USE_OLD_DC_FORMATtrue

更多细节请参考 MSGPACK IDL RFC:https://github.com/flyteorg/flyte/blob/master/rfc/system/5741-binary-idl-with-message-pack.md

要克隆并运行本页示例代码,请访问 Flytesnacks 仓库。

首先导入必要依赖:

import os
import tempfile
from dataclasses import dataclassimport pandas as pd
import flytekit as fl
from flytekit.types.structured import StructuredDataset

使用 ImageSpec 构建自定义镜像:

image_spec = union.ImageSpec(registry="ghcr.io/flyteorg",packages=["pandas", "pyarrow"],
)

Python 类型

定义包含 intstrdict 作为数据类型的 dataclass

@dataclass
class Datum:x: inty: strz: dict[int, str]

可以在不同语言编写的任务间传递 dataclass,并通过 Flyte UI 以原始 JSON 格式输入。

数据类中的所有变量必须标注类型,否则会导致错误。

声明后,数据类可作为输出返回或作为输入接受:

@fl.task(container_image=image_spec)
def stringify(s: int) -> Datum:"""数据类返回值将被视为单个复杂 JSON 返回"""return Datum(x=s, y=str(s), z={s: str(s)})@fl.task(container_image=image_spec)
def add(x: Datum, y: Datum) -> Datum:x.z.update(y.z)return Datum(x=x.x + y.x, y=x.y + y.y, z=x.z)

Flyte 类型

定义接受 StructuredDatasetFlyteFileFlyteDirectory 的数据类:

@dataclass
class FlyteTypes:dataframe: StructuredDatasetfile: union.FlyteFiledirectory: union.FlyteDirectory@fl.task(container_image=image_spec)
def upload_data() -> FlyteTypes:df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})temp_dir = tempfile.mkdtemp(prefix="flyte-")df.to_parquet(temp_dir + "/df.parquet")file_path = tempfile.NamedTemporaryFile(delete=False)file_path.write(b"Hello, World!")fs = FlyteTypes(dataframe=StructuredDataset(dataframe=df),file=union.FlyteFile(file_path.name),directory=union.FlyteDirectory(temp_dir),)return fs@fl.task(container_image=image_spec)
def download_data(res: FlyteTypes):assert pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}).equals(res.dataframe.open(pd.DataFrame).all())f = open(res.file, "r")assert f.read() == "Hello, World!"assert os.listdir(res.directory) == ["df.parquet"]

数据类支持使用 Python 类型、其他数据类、FlyteFile、FlyteDirectory 和结构化数据集等关联数据。

定义调用上述任务的工作流:

@fl.workflow
def dataclass_wf(x: int, y: int) -> (Datum, FlyteTypes):o1 = add(x=stringify(s=x), y=stringify(s=y))o2 = upload_data()download_data(res=o2)return o1, o2

要通过 pyflyte run 触发接受数据类作为输入的任务,可提供 JSON 文件作为输入:

pyflyte run dataclass.py add --x dataclass_input.json --y dataclass_input.json

另一个通过 pyflyte run 触发接受数据类作为输入任务的示例:

$ pyflyte run \https://raw.githubusercontent.com/flyteorg/flytesnacks/69dbe4840031a85d79d9ded25f80397c6834752d/examples/data_types_and_io/data_types_and_io/dataclass.py \add --x dataclass_input.json --y dataclass_input.json

Enum 类型

有时您可能需要将输入或输出的有效值限制为预定义的集合。这种常见需求通常通过编程语言中的 Enum 类型来实现。

您可以创建 Python Enum 类型并将其用作任务的输入或输出。Flytekit 会自动进行类型转换,并将输入输出限制在预定义的值集合中。

当前仅支持字符串值作为有效的 Enum 值。Flyte 假定列表中的第一个值为默认值,且 Enum 类型不能为可选类型。因此在定义 Enum 时,务必将第一个值设计为有效的默认值。

我们定义了一个 Enum 和简单的咖啡机工作流,该工作流接受订单并据此冲泡咖啡 ☕️。假设咖啡机只能识别 Enum 类型的输入:

# coffee_maker.pyfrom enum import Enumimport flytekit as flclass Coffee(Enum):ESPRESSO = "espresso"AMERICANO = "americano"LATTE = "latte"CAPPUCCINO = "cappucccino"@fl.task
def take_order(coffee: str) -> Coffee:return Coffee(coffee)@fl.task
def prep_order(coffee_enum: Coffee) -> str:return f"Preparing {coffee_enum.value} ..."@fl.workflow
def coffee_maker(coffee: str) -> str:coffee_enum = take_order(coffee=coffee)return prep_order(coffee_enum=coffee_enum)# 工作流也可以接受枚举值
@fl.workflow
def coffee_maker_enum(coffee_enum: Coffee) -> str:return prep_order(coffee_enum=coffee_enum)

您可以在运行时为 coffee_enum 参数指定值:

pyflyte run coffee_maker.py coffee_maker_enum --coffee_enum="latte"

Pickle 类型

Flyte 通过利用类型信息来编译任务和工作流,从而强制执行类型安全,这使得静态分析和条件分支等多种功能成为可能。

然而,我们也致力于为终端用户提供灵活性,使他们无需在体验 Flyte 价值之前就投入大量精力理解数据结构。

Flyte 支持 FlytePickle 转换器,该转换器会将所有无法识别的类型提示转换为 FlytePickle,从而实现 Python 值到 pickle 文件的序列化/反序列化。

Pickle 只能在完全相同的 Python 版本之间传输对象。为了获得最佳性能,建议使用 Flyte 支持的 Python 类型或注册自定义转换器,因为使用 pickle 类型可能导致性能下降。

本示例演示了如何在不注册转换器的情况下使用自定义对象。

要克隆并运行本页示例代码,请访问 Flytesnacks 仓库。

import flytekit as fl

Superhero 表示用户定义的复杂类型,Flytekit 可将其序列化为 pickle 文件,并作为输入/输出数据在任务间传递。

您也可以将此对象转换为 dataclass 以获得更好的性能。此处使用简单对象仅用于演示目的。

class Superhero:def __init__(self, name, power):self.name = nameself.power = power@fl.task
def welcome_superhero(name: str, power: str) -> Superhero:return Superhero(name, power)@fl.task
def greet_superhero(superhero: Superhero) -> str:return f"👋 Hello {superhero.name}! Your superpower is {superhero.power}."@fl.workflow
def superhero_wf(name: str = "Thor", power: str = "Flight") -> str:superhero = welcome_superhero(name=name, power=power)return greet_superhero(superhero=superhero)

Pydantic BaseModel

flytekit版本 >=1.14 原生支持 Pydantic BaseModel生成的JSON格式,增强了 Pydantic BaseModel 与 Flyte 类型系统的互操作性。

注意:Pydantic BaseModel V2 仅在使用 flytekit 版本 >= v1.14.0 时可用。

自 1.14 版本起,flytekit采用MessagePack作为 Pydantic BaseModel的序列化格式,克服了旧版本将数据序列化为 Protobuf struct数据类型中的 JSON 字符串的主要限制:

Protobuf 的struct会将int类型转换为float,迫使开发者需要编写样板代码来规避此问题。

默认情况下,flytekit >= 1.14在序列化时会生成msgpack字节字面量,保留BaseModel类中定义的类型。如果您使用 flytekit 版本 >= v1.14.0 序列化BaseModel但希望生成 Protobuf struct字面量,可将环境变量FLYTE_USE_OLD_DC_FORMAT设为true

更多技术细节请参考 MESSAGEPACK IDL RFC:https://github.com/flyteorg/flyte/blob/master/rfc/system/5741-binary-idl-with-message-pack.md

要克隆并运行本页示例代码,请访问 Flytesnacks 代码库

您可以在 pydantic BaseModel 中使用数据类和 Flyte 类型(FlyteFile、FlyteDirectory、FlyteSchema 和 StructuredDataset)。

首先导入必要的依赖:

import os
import tempfile
import pandas as pd
from flytekit
from flytekit.types.structured import StructuredDataset
from pydantic import BaseModel

使用 ImageSpec 构建自定义镜像:

image_spec = union.ImageSpec(registry="ghcr.io/flyteorg",packages=["pandas", "pyarrow", "pydantic"],
)

Python 类型

定义包含intstrdict数据类型的pydantic basemodel

class Datum(BaseModel):x: inty: strz: dict[int, str]

您可以在不同语言编写的任务间传递pydantic basemodel,并通过 Flyte 控制台以原始 JSON 格式输入。

数据类中的所有变量必须使用类型注解,否则会导致错误。

声明后,数据类可作为输出返回或作为输入接收:

@fl.task(container_image=image_spec)
def stringify(s: int) -> Datum:"""Pydantic 模型的返回将被视为单个复杂 JSON 返回值"""return Datum(x=s, y=str(s), z={s: str(s)})@fl.task(container_image=image_spec)
def add(x: Datum, y: Datum) -> Datum:x.z.update(y.z)return Datum(x=x.x + y.x, y=x.y + y.y, z=x.z)

Flyte 类型

定义接收StructuredDatasetFlyteFileFlyteDirectory的数据类:

class FlyteTypes(BaseModel):dataframe: StructuredDatasetfile: union.FlyteFiledirectory: union.FlyteDirectory@fl.task(container_image=image_spec)
def upload_data() -> FlyteTypes:df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})temp_dir = tempfile.mkdtemp(prefix="flyte-")df.to_parquet(os.path.join(temp_dir, "df.parquet"))file_path = tempfile.NamedTemporaryFile(delete=False)file_path.write(b"Hello, World!")file_path.close()fs = FlyteTypes(dataframe=StructuredDataset(dataframe=df),file=fl.FlyteFile(file_path.name),directory=fl.FlyteDirectory(temp_dir),)return fs@fl.task(container_image=image_spec)
def download_data(res: FlyteTypes):expected_df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})actual_df = res.dataframe.open(pd.DataFrame).all()assert expected_df.equals(actual_df), "DataFrames 不匹配!"with open(res.file, "r") as f:assert f.read() == "Hello, World!", "文件内容不匹配!"assert os.listdir(res.directory) == ["df.parquet"], "目录内容不匹配!"

数据类支持使用 Python 类型、数据类、FlyteFile、FlyteDirectory 和 StructuredDataset 相关数据。

定义调用上述任务的工作流:

@fl.workflow
def basemodel_wf(x: int, y: int) -> (Datum, FlyteTypes):o1 = add(x=stringify(s=x), y=stringify(s=y))o2 = upload_data()download_data(res=o2)return o1, o2

使用pyflyte run触发接受数据类作为输入的任务时,可提供 JSON 文件作为输入:

pyflyte run dataclass.py basemodel_wf --x 1 --y 2

触发接受数据类作为输入的任务:

$ pyflyte run \https://raw.githubusercontent.com/flyteorg/flytesnacks/b71e01d45037cea883883f33d8d93f258b9a5023/examples/data_types_and_io/data_types_and_io/pydantic_basemodel.py \basemodel_wf --x 1 --y 2

PyTorch 类型

Flyte 提倡使用强类型数据来简化健壮且可测试的管道的开发。除了在数据工程中的应用,Flyte 主要应用于机器学习领域。为了优化 Flyte 任务间的通信(特别是在处理张量和模型时),我们引入了对 PyTorch 类型的支持。

张量与模块

有时您可能需要在工作流中传递张量和模块(模型)。在缺乏原生 PyTorch 张量和模块支持的情况下,Flytekit 依赖 pickle 进行这些实体及任何未知类型的序列化和反序列化。但这种方式并非最高效的解决方案。因此,我们将 PyTorch 的序列化与反序列化支持集成到了 Flyte 类型系统中。

要克隆并运行本页示例代码,请参考 Flytesnacks 仓库。

@fl.task
def generate_tensor_2d() -> torch.Tensor:return torch.tensor([[1.0, -1.0, 2], [1.0, -1.0, 9], [0, 7.0, 3]])@fl.task
def reshape_tensor(tensor: torch.Tensor) -> torch.Tensor:# 将2D转换为3Dtensor.unsqueeze_(-1)return tensor.expand(3, 3, 2)@fl.task
def generate_module() -> torch.nn.Module:bn = torch.nn.BatchNorm1d(3, track_running_stats=True)return bn@fl.task
def get_model_weight(model: torch.nn.Module) -> torch.Tensor:return model.weightclass MyModel(torch.nn.Module):def __init__(self):super(MyModel, self).__init__()self.l0 = torch.nn.Linear(4, 2)self.l1 = torch.nn.Linear(2, 1)def forward(self, input):out0 = self.l0(input)out0_relu = torch.nn.functional.relu(out0)return self.l1(out0_relu)@fl.task
def get_l1() -> torch.nn.Module:model = MyModel()return model.l1@fl.workflow
def pytorch_native_wf():reshape_tensor(tensor=generate_tensor_2d())get_model_weight(model=generate_module())get_l1()

现在传递张量和模块不再繁琐!

检查点

PyTorchCheckpoint 是专门用于序列化和反序列化 PyTorch 模型的检查点。它检查点保存 torch.nn.Module 的状态、超参数和优化器状态。

此模块检查点与标准检查点的区别在于它专门捕获模块的 state_dict。因此在恢复模块时,必须将模块的 state_dict 与实际模块结合使用。根据 PyTorch 文档 建议,虽然序列化在两种情况下都有效,但推荐存储模块的 state_dict 而非模块本身。

from dataclasses import dataclassimport torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from dataclasses_json import dataclass_json
from flytekit.extras.pytorch import PyTorchCheckpoint@dataclass_json
@dataclass
class Hyperparameters:epochs: intloss: floatclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x@fl.task
def generate_model(hyperparameters: Hyperparameters) -> PyTorchCheckpoint:bn = Net()optimizer = optim.SGD(bn.parameters(), lr=0.001, momentum=0.9)return PyTorchCheckpoint(module=bn, hyperparameters=hyperparameters, optimizer=optimizer)@fl.task
def load(checkpoint: PyTorchCheckpoint):new_bn = Net()new_bn.load_state_dict(checkpoint["module_state_dict"])optimizer = optim.SGD(new_bn.parameters(), lr=0.001, momentum=0.9)optimizer.load_state_dict(checkpoint["optimizer_state_dict"])@fl.workflow
def pytorch_checkpoint_wf():checkpoint = generate_model(hyperparameters=Hyperparameters(epochs=10, loss=0.1))load(checkpoint=checkpoint)

PyTorchCheckpoint 支持序列化 dictNamedTupledataclass 类型的超参数。

自动 GPU-CPU 转换

并非所有 PyTorch 计算都需要 GPU。在某些情况下(特别是在 GPU 上训练模型后),将计算转移到 CPU 可能更有利。要利用 GPU 的强大能力,典型的结构是使用:to(torch.device("cuda"))

当在 CPU 上处理 GPU 变量时,需要使用 to(torch.device("cpu")) 结构将变量传输到 CPU。但 PyTorch 推荐的这种手动转换方式可能不够友好。为此,我们增加了对 PyTorch 类型自动 GPU-CPU 转换(反之亦然)的支持。

import flytekit as fl
from typing import Tuple@fl.task(requests=union.Resources(gpu="1"))
def train() -> Tuple[PyTorchCheckpoint, torch.Tensor, torch.Tensor, torch.Tensor]:...device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = Model(X_train.shape[1])model.to(device)...X_train, X_test = X_train.to(device), X_test.to(device)y_train, y_test = y_train.to(device), y_test.to(device)...return PyTorchCheckpoint(module=model), X_train, X_test, y_test@fl.task
def predict(checkpoint: PyTorchCheckpoint,X_train: torch.Tensor,X_test: torch.Tensor,y_test: torch.Tensor,
):new_bn = Model(X_train.shape[1])new_bn.load_state_dict(checkpoint["module_state_dict"])accuracy_list = np.zeros((5,))with torch.no_grad():y_pred = new_bn(X_test)correct = (torch.argmax(y_pred, dim=1) == y_test).type(torch.FloatTensor)accuracy_list = correct.mean()

predict 任务将在 CPU 上运行,GPU 到 CPU 的设备转换将由 Flytekit 自动处理。

StructuredDataset

与大多数类型系统类似,Python 拥有基本类型、容器类型(如映射和元组),并支持用户自定义结构。然而,尽管存在丰富的 DataFrame 类(Pandas、Spark、Pandas 等),但 Python 本身并没有原生的抽象 DataFrame 类型。这正是 StructuredDataset 类型所要填补的空白。它提供以下优势:

  • 消除将文件对象序列化/反序列化为 DataFrame 实例所需的样板代码
  • 消除用于传递文件中表格数据格式的额外输入/输出
  • 增加 DataFrame 文件加载方式的灵活性
  • 提供一系列 DataFrame 专属功能 - 强制不同模式的兼容性(不仅限于编译时,运行时也有效,因为类型信息会随字面量传递),存储第三方模式定义,未来还可能实现样本数据渲染、统计摘要等功能

使用方式

要使用 StructuredDataset 类型,需导入 pandas 并定义返回 Pandas DataFrame 的任务。Flytekit 将检测 Pandas DataFrame 返回签名,并将任务接口转换为 StructuredDataset 类型。

示例

本示例演示如何使用 Flyte 实体处理结构化数据集。

使用 StructuredDataset 类型仅需导入 pandas。以下其他导入仅为本示例所需。

要克隆并运行本页示例代码,请参考 Flytesnacks 代码库。

首先导入示例依赖:

import typing
from dataclasses import dataclass
from pathlib import Pathimport numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import flytekit as fl
from flytekit.models import literals
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.types.structured.structured_dataset import (PARQUET,StructuredDataset,StructuredDatasetDecoder,StructuredDatasetEncoder,StructuredDatasetTransformerEngine,
)
from typing_extensions import Annotated

定义返回 Pandas DataFrame 的任务:

@fl.task(container_image=image_spec)
def generate_pandas_df(a: int) -> pd.DataFrame:return pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [a, 22], "Height": [160, 178]})

但使用此简单形式时,用户无法设置上文提到的附加 DataFrame 信息:

  • 列类型信息
  • 序列化字节格式
  • 存储驱动和位置
  • 附加第三方模式信息

这是有意设计,因为我们希望默认情况能满足大多数用例,且尽可能减少对现有代码的改动。通过 Python 变量注解可以轻松指定这些信息,该功能专为使用任意元数据补充类型而设计。

列类型信息

若需提取 DataFrame 的实际列子集并指定其类型进行验证,只需在结构化数据集类型注解中指定列名及其类型。

首先初始化需要从 StructuredDataset 中提取的列类型:

all_cols = fl.kwtypes(Name=str, Age=int, Height=int)
col = fl.kwtypes(Age=int)

定义通过调用 all() 打开结构化数据集的任务。当使用 pandas.DataFrame 调用 all() 时,Flyte 引擎会下载 S3 上的 Parquet 文件并反序列化为 pandas.DataFrame。注意可以使用结构化数据集支持的任何 DataFrame 类型调用 open(),例如使用 pa.Table 将 Pandas DataFrame 转换为 PyArrow 表:

@fl.task(container_image=image_spec)
def get_subset_pandas_df(df: Annotated[StructuredDataset, all_cols]) -> Annotated[StructuredDataset, col]:df = df.open(pd.DataFrame).all()df = pd.concat([df, pd.DataFrame([[30]], columns=["Age"])])return StructuredDataset(dataframe=df)@fl.workflow
def simple_sd_wf(a: int = 19) -> Annotated[StructuredDataset, col]:pandas_df = generate_pandas_df(a=a)return get_subset_pandas_df(df=pandas_df)

若列不匹配,代码可能在运行时失败。输入 df 包含 NameAgeHeight 列,而输出结构化数据集将仅保留 Age 列。

序列化字节格式

可使用自定义序列化格式序列化 DataFrame。以下是注册已有的 Pandas 到 CSV 处理器,并通过注解结构化数据集启用 CSV 序列化的方法:

from flytekit.types.structured import register_csv_handlers
from flytekit.types.structured.structured_dataset import CSVregister_csv_handlers()@fl.task(container_image=image_spec)
def pandas_to_csv(df: pd.DataFrame) -> Annotated[StructuredDataset, CSV]:return StructuredDataset(dataframe=df)@fl.workflow
def pandas_to_csv_wf() -> Annotated[StructuredDataset, CSV]:pandas_df = generate_pandas_df(a=19)return pandas_to_csv(df=pandas_df)

存储驱动和位置

默认情况下,数据将写入与其他指针类型(FlyteFile、FlyteDirectory 等)相同的位置。这由 Flyte 的输出数据前缀选项控制,该选项支持多级配置。

也就是说,在简单默认情况下,Flytekit 会:

  • 查找默认格式(例如 Pandas DataFrame)
  • 根据原始输出前缀设置查找默认存储位置
  • 使用这两个设置选择编码器并调用

那么什么是编码器?让我们通过结构化数据集插件的工作原理来理解。

结构化数据集插件的内部机制

与 Flyte 交互时,任何 DataFrame 实例都需要进行两个操作:

  • 将 Python 实例序列化/反序列化为字节(按上述指定格式)
  • 将这些字节传输到/从某处检索

每个结构化数据集插件(称为编码器或解码器)都需要执行这两个步骤。Flytekit 根据三个属性决定调用哪个已加载的插件:

  • 字节格式
  • 存储位置
  • 任务或工作流签名中的 Python 类型

这三个键唯一标识编码器(用于将内存中的 Python DataFrame 转换为 Flyte 值,例如任务完成返回 DataFrame 时)或解码器(用于从 Flyte 值水合生成内存中的 DataFrame,例如任务启动并接收 DataFrame 输入时)。

但要求用户在每个签名上使用 typing.Annotated 会比较笨拙。因此 Flytekit 为每个注册的 Python DataFrame 类型提供默认字节格式。

uri 参数

BigQuery 的 uri 参数允许使用 uri 从云端加载和检索数据。uri 由存储桶名称和前缀为 gs:// 的文件名组成。若为结构化数据集指定 BigQuery uri,BigQuery 将在 uri 指定位置创建表。结构化数据集中的 uri 可读写 S3、GCP、BigQuery 等存储。

在将 DataFrame 写入 BigQuery 表前需:

  1. 创建 GCP 账户 并创建服务账号
  2. 创建项目并在 .bashrc 文件中添加 GOOGLE_APPLICATION_CREDENTIALS 环境变量
  3. 在项目中创建数据集

以下是定义将 Pandas DataFrame 转换为 BigQuery 表的任务:

@fl.task
def pandas_to_bq() -> StructuredDataset:df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})return StructuredDataset(dataframe=df, uri="gs://<BUCKET_NAME>/<FILE_NAME>")

BUCKET_NAME 替换为 GCS 存储桶名称,FILE_NAME 替换为目标文件名。

注意:未在结构化数据集构造函数或签名中指定格式。BigQuery 编码器是如何被调用的?

这是因为内置的 BigQuery 编码器以空格式加载到 Flytekit 中。Flytekit 的 StructuredDatasetTransformerEngine 将其解释为通用编码器(或解码器),若未找到更具体的格式,则可以跨格式工作。

以下是定义将 BigQuery 表转换为 Pandas DataFrame 的任务:

@fl.task
def bq_to_pandas(sd: StructuredDataset) -> pd.DataFrame:return sd.open(pd.DataFrame).all()

执行 BigQuery 查询时,Flyte 会在项目的数据集中创建表。

如何从任务返回多个 DataFrame?

例如,任务如何返回两个 DataFrame:

  • 第一个 DataFrame 需写入 BigQuery 并使用其库序列化
  • 第二个需序列化为 CSV 并写入与通用指针数据存储桶不同的 GCS 指定位置

若需要默认行为(其本身可根据加载的插件进行配置),直接使用当前原始 DataFrame 类即可:

@fl.task
def t1() -> typing.Tuple[StructuredDataset, StructuredDataset]:...return StructuredDataset(df1, uri="bq://project:flyte.table"), \StructuredDataset(df2, uri="gs://auxiliary-bucket/data")

若要自定义 Flyte 交互行为,需将 DataFrame 包装在 StructuredDataset 包装对象中。

如何定义自定义结构化数据集插件?

StructuredDataset 自带编码器和解码器,分别处理 Python 值到 Flyte 字面量的转换。以下演示如何构建 NumPy 编码器/解码器,使 2D NumPy 数组成为结构化数据集的有效类型。

NumPy 编码器

继承 StructuredDatasetEncoder 并实现 encode 函数。该函数将 NumPy 数组转换为中间格式(本例为 Parquet 文件格式):

class NumpyEncodingHandler(StructuredDatasetEncoder):def encode(self,ctx: fl.FlyteContext,structured_dataset: StructuredDataset,structured_dataset_type: union.StructuredDatasetType,) -> literals.StructuredDataset:df = typing.cast(np.ndarray, structured_dataset.dataframe)name = ["col" + str(i) for i in range(len(df))]table = pa.Table.from_arrays(df, name)path = ctx.file_access.get_random_remote_directory()local_dir = ctx.file_access.get_random_local_directory()local_path = Path(local_dir) / f"{0:05}"pq.write_table(table, str(local_path))ctx.file_access.upload_directory(local_dir, path)return literals.StructuredDataset(uri=path,metadata=StructuredDatasetMetadata(structured_dataset_type=union.StructuredDatasetType(format=PARQUET)),)

NumPy 解码器

继承 StructuredDatasetDecoder 并实现 StructuredDatasetDecoder.decode 函数。该函数将 Parquet 文件转换为 numpy.ndarray

class NumpyDecodingHandler(StructuredDatasetDecoder):def decode(self,ctx: fl.FlyteContext,flyte_value: literals.StructuredDataset,current_task_metadata: StructuredDatasetMetadata,) -> np.ndarray:local_dir = ctx.file_access.get_random_local_directory()ctx.file_access.get_data(flyte_value.uri, local_dir, is_multipart=True)table = pq.read_table(local_dir)return table.to_pandas().to_numpy()

NumPy 渲染器

创建 NumPy 数组的默认渲染器,Flytekit 将使用此渲染器在 Deck 中显示 NumPy 数组模式:

class NumpyRenderer:def to_html(self, df: np.ndarray) -> str:assert isinstance(df, np.ndarray)name = ["col" + str(i) for i in range(len(df))]table = pa.Table.from_arrays(df, name)return pd.DataFrame(table.schema).to_html(index=False)

最后,将编码器、解码器和渲染器注册到 StructuredDatasetTransformerEngine。指定要注册的 Python 类型(np.ndarray)、存储引擎(若未指定则假定适用于所有存储后端)和字节格式(本例为 PARQUET):

StructuredDatasetTransformerEngine.register(NumpyEncodingHandler(np.ndarray, None, PARQUET))
StructuredDatasetTransformerEngine.register(NumpyDecodingHandler(np.ndarray, None, PARQUET))
StructuredDatasetTransformerEngine.register_renderer(np.ndarray, NumpyRenderer())

现在可以使用 numpy.ndarray 将 Parquet 文件反序列化为 NumPy,并将任务输出(NumPy 数组)序列化为 Parquet 文件:

@fl.task(container_image=image_spec)
def generate_pd_df_with_str() -> pd.DataFrame:return pd.DataFrame({"Name": ["Tom", "Joseph"]})@fl.task(container_image=image_spec)
def to_numpy(sd: StructuredDataset) -> Annotated[StructuredDataset, None, PARQUET]:numpy_array = sd.open(np.ndarray).all()return StructuredDataset(dataframe=numpy_array)@fl.workflow
def numpy_wf() -> Annotated[StructuredDataset, None, PARQUET]:return to_numpy(sd=generate_pd_df_with_str())

当 DataFrame 包含整数时,pyarrow 会抛出 Expected bytes, got a 'int' object 错误。

可本地运行代码如下:

if __name__ == "__main__":sd = simple_sd_wf()print(f"简单 Pandas DataFrame 工作流: {sd.open(pd.DataFrame).all()}")print(f"使用 CSV 作为序列化器: {pandas_to_csv_wf().open(pd.DataFrame).all()}")print(f"NumPy 编码器和解码器: {numpy_wf().open(np.ndarray).all()}")

嵌套类型列

与大多数存储格式(如 Avro、Parquet 和 BigQuery)类似,StructuredDataset 支持嵌套字段结构。

Flytekit 版本 > 1.11.0 时可运行嵌套字段 StructuredDataset:

data = [{"company": "XYZ pvt ltd","location": "London","info": {"president": "Rakesh Kapoor", "contacts": {"email": "contact@xyz.com", "tel": "9876543210"}},},{"company": "ABC pvt ltd","location": "USA","info": {"president": "Kapoor Rakesh", "contacts": {"email": "contact@abc.com", "tel": "0123456789"}},},
]@dataclass
class ContactsField:email: strtel: str@dataclass
class InfoField:president: strcontacts: ContactsField@dataclass
class CompanyField:location: strinfo: InfoFieldcompany: strMyArgDataset = Annotated[StructuredDataset, union.kwtypes(company=str)]
MyTopDataClassDataset = Annotated[StructuredDataset, CompanyField]
MyTopDictDataset = Annotated[StructuredDataset, {"company": str, "location": str}]MyDictDataset = Annotated[StructuredDataset, union.kwtypes(info={"contacts": {"tel": str}})]
MyDictListDataset = Annotated[StructuredDataset, union.kwtypes(info={"contacts": {"tel": str, "email": str}})]
MySecondDataClassDataset = Annotated[StructuredDataset, union.kwtypes(info=InfoField)]
MyNestedDataClassDataset = Annotated[StructuredDataset, union.kwtypes(info=union.kwtypes(contacts=ContactsField))]image = union.ImageSpec(packages=["pandas", "pyarrow", "pandas", "tabulate"], registry="ghcr.io/flyteorg")@fl.task(container_image=image)
def create_parquet_file() -> StructuredDataset:from tabulate import tabulatedf = pd.json_normalize(data, max_level=0)print("原始 DataFrame: \n", tabulate(df, headers="keys", tablefmt="psql"))return StructuredDataset(dataframe=df)@fl.task(container_image=image)
def print_table_by_arg(sd: MyArgDataset) -> pd.DataFrame:from tabulate import tabulatet = sd.open(pd.DataFrame).all()print("MyArgDataset DataFrame: \n", tabulate(t, headers="keys", tablefmt="psql"))return t@fl.task(container_image=image)
def print_table_by_dict(sd: MyDictDataset) -> pd.DataFrame:from tabulate import tabulatet = sd.open(pd.DataFrame).all()print("MyDictDataset DataFrame: \n", tabulate(t, headers="keys", tablefmt="psql"))return t@fl.task(container_image=image)
def print_table_by_list_dict(sd: MyDictListDataset) -> pd.DataFrame:from tabulate import tabulatet = sd.open(pd.DataFrame).all()print("MyDictListDataset DataFrame: \n", tabulate(t, headers="keys", tablefmt="psql"))return t@fl.task(container_image=image)
def print_table_by_top_dataclass(sd: MyTopDataClassDataset) -> pd.DataFrame:from tabulate import tabulatet = sd.open(pd.DataFrame).all()print("MyTopDataClassDataset DataFrame: \n", tabulate(t, headers="keys", tablefmt="psql"))return t@fl.task(container_image=image)
def print_table_by_top_dict(sd: MyTopDictDataset) -> pd.DataFrame:from tabulate import tabulatet = sd.open(pd.DataFrame).all()print("MyTopDictDataset DataFrame: \n", tabulate(t, headers="keys", tablefmt="psql"))return t@fl.task(container_image=image)
def print_table_by_second_dataclass(sd: MySecondDataClassDataset) -> pd.DataFrame:from tabulate import tabulatet = sd.open(pd.DataFrame).all()print("MySecondDataClassDataset DataFrame: \n", tabulate(t, headers="keys", tablefmt="psql"))return t@fl.task(container_image=image)
def print_table_by_nested_dataclass(sd: MyNestedDataClassDataset) -> pd.DataFrame:from tabulate import tabulatet = sd.open(pd.DataFrame).all()print("MyNestedDataClassDataset DataFrame: \n", tabulate(t, headers="keys", tablefmt="psql"))return t@fl.workflow
def contacts_wf():sd = create_parquet_file()print_table_by_arg(sd=sd)print_table_by_dict(sd=sd)print_table_by_list_dict(sd=sd)print_table_by_top_dataclass(sd=sd)print_table_by_top_dict(sd=sd)print_table_by_second_dataclass(sd=sd)print_table_by_nested_dataclass(sd=sd)

TensorFlow 类型

本文概述了 Flyte 中可用的 TensorFlow 类型,这些类型有助于在 Flyte 工作流中集成 TensorFlow 模型和数据集。

导入必要的库和模块

import fl
from flytekit.types.directory import TFRecordsDirectory
from flytekit.types.file import TFRecordFilecustom_image = fl.ImageSpec(packages=["tensorflow", "tensorflow-datasets", "flytekitplugins-kftensorflow"],registry="ghcr.io/flyteorg",
)import tensorflow as tf

TensorFlow 模型

Flyte 支持使用 TensorFlow SavedModel 格式对 tf.keras.Model 实例进行序列化和反序列化。TensorFlowModelTransformer 负责处理这些转换。

转换器

  • 名称: TensorFlow 模型
  • : TensorFlowModelTransformer
  • Python 类型: tf.keras.Model
  • Blob 格式: TensorFlowModel
  • 维度: MULTIPART

使用方式

TensorFlowModelTransformer 允许您将 TensorFlow 模型保存到远程存储位置,并在后续 Flyte 工作流中检索。

要克隆并运行本页示例代码,请参阅 Flytesnacks 仓库

@fl.task
def train_model() -> tf.keras.Model:model = tf.keras.Sequential([tf.keras.layers.Dense(128, activation="relu"), tf.keras.layers.Dense(10, activation="softmax")])model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])return model@fl.task
def evaluate_model(model: tf.keras.Model, x: tf.Tensor, y: tf.Tensor) -> float:loss, accuracy = model.evaluate(x, y)return accuracy@fl.workflow
def training_workflow(x: tf.Tensor, y: tf.Tensor) -> float:model = train_model()return evaluate_model(model=model, x=x, y=y)

TFRecord 文件

Flyte 通过 TFRecordFile 类型支持 TFRecord 文件,该类型可以处理序列化的 TensorFlow 记录。TensorFlowRecordFileTransformer 负责管理 TFRecord 文件与 Flyte 字面量之间的转换。

转换器

  • 名称: TensorFlow 记录文件
  • : TensorFlowRecordFileTransformer
  • Blob 格式: TensorFlowRecord
  • 维度: SINGLE

使用方式

TensorFlowRecordFileTransformer 使您能够处理单个 TFRecord 文件,方便读写 TensorFlow 的 TFRecord 格式数据。

@fl.task
def process_tfrecord(file: TFRecordFile) -> int:count = 0for record in tf.data.TFRecordDataset(file):count += 1return count@fl.workflow
def tfrecord_workflow(file: TFRecordFile) -> int:return process_tfrecord(file=file)

TFRecord 目录

Flyte 通过 TFRecordsDirectory 类型支持包含多个 TFRecord 文件的目录。TensorFlowRecordsDirTransformer 负责管理 TFRecord 目录与 Flyte 字面量之间的转换。

转换器

  • 名称: TensorFlow 记录目录
  • : TensorFlowRecordsDirTransformer
  • Python 类型: TFRecordsDirectory
  • Blob 格式: TensorFlowRecord
  • 维度: MULTIPART

使用方式

TensorFlowRecordsDirTransformer 使您能够处理包含多个 TFRecord 文件的目录,这对处理跨多个文件分割的大型数据集非常有用。

示例
@fl.task
def process_tfrecords_dir(dir: TFRecordsDirectory) -> int:count = 0for record in tf.data.TFRecordDataset(dir.path):count += 1return count@fl.workflow
def tfrecords_dir_workflow(dir: TFRecordsDirectory) -> int:return process_tfrecords_dir(dir=dir)

配置类:TFRecordDatasetConfig

TFRecordDatasetConfig 类是用于配置创建 tf.data.TFRecordDataset 参数的数据结构,该数据集支持高效读取 TFRecord 文件。此类使用 DataClassJsonMixin 实现便捷的 JSON 序列化。

属性

  • compression_type:(可选)指定 TFRecord 文件使用的压缩方法。可选值包括空字符串(无压缩)、“ZLIB” 或 “GZIP”
  • buffer_size:(可选)定义读取缓冲区的字节大小。如果未设置,将根据本地或远程文件系统使用默认值
  • num_parallel_reads:(可选)确定并行读取的文件数量。大于 1 的值将输出交错排序的记录
  • name:(可选)为操作分配名称以便在流水线中更易识别

该配置对于优化 TFRecord 数据集的读取过程至关重要,特别是在处理大型数据集或需要特定性能调优时。

风险提示与免责声明
本文内容基于公开信息研究整理,不构成任何形式的投资建议。历史表现不应作为未来收益保证,市场存在不可预见的波动风险。投资者需结合自身财务状况及风险承受能力独立决策,并自行承担交易结果。作者及发布方不对任何依据本文操作导致的损失承担法律责任。市场有风险,投资须谨慎。

相关文章:

  • 【常用算法:查找篇】9.AVL树深度解析:动态平衡二叉树的原理、实现与应用
  • USB传输速率 和 RS-232/RS-485串口协议速率 的倍数关系
  • 备忘录模式
  • 类的加载过程详解
  • LINQ:统一查询语法的强大工具
  • 服务端HttpServletRequest、HttpServletResponse、HttpSession
  • 前端动画库 Anime.js 的V4 版本,兼容 Vue、React
  • 初始C++:类和对象(中)
  • 游戏引擎学习第293天:移动Familiars
  • 线程池核心线程永续机制:从源码到实战的深度解析
  • 继MCP、A2A之上的“AG-UI”协议横空出世,人机交互迈入新纪元
  • 学习黑客Active Directory 入门指南(五)
  • 32LED心形灯程序源代码
  • Java大师成长计划之第26天:Spring生态与微服务架构之消息驱动的微服务
  • 4:OpenCV—保存图像
  • Spring AI Alibaba集成阿里云百炼大模型应用
  • 05 部署Nginx反向代理
  • 【Linux高级全栈开发】2.1.2 事件驱动reactor的原理与实现
  • 【运营商查询】批量手机号码归属地和手机运营商高速查询分类,按省份城市,按运营商移动联通电信快速分类导出Excel表格,基于WPF的实现方案
  • ChatGPT:OpenAI Codex—一款基于云的软件工程 AI 代理,赋能 ChatGPT,革新软件开发模式
  • 上海徐汇 “家 + 书屋”,创新服务广大家庭
  • 内蒙古赤峰市城建集团董事长孙广通拟任旗县区党委书记
  • 缅甸发生5.0级地震
  • 澎湃与七猫联合启动百万奖金征文,赋能非虚构与现实题材创作
  • 钕铁硼永磁材料龙头瞄准人形机器人,正海磁材:已向下游客户完成小批量供货
  • 上海黄浦江挡潮闸工程建设指挥部成立,组成人员名单公布