当前位置: 首页 > news >正文

autogen_core中的DataclassJsonMessageSerializer类

源代码

import json
from dataclasses import asdict, dataclass, fields
from typing import Any, ClassVar, Dict, List, Protocol, Sequence, TypeVar, cast, get_args, get_origin, runtime_checkable, Union
from pydantic import BaseModel

from types import NoneType, UnionType


def _type_name(cls: type[Any] | Any) -> str:
    if isinstance(cls, type):
        return cls.__name__
    else:
        return cast(str, cls.__class__.__name__)


def is_union(t: object) -> bool:
    origin = get_origin(t)
    return origin is Union or origin is UnionType

T = TypeVar("T")


class MessageSerializer(Protocol[T]):
    @property
    def data_content_type(self) -> str: ...

    @property
    def type_name(self) -> str: ...

    def deserialize(self, payload: bytes) -> T: ...

    def serialize(self, message: T) -> bytes: ...


@runtime_checkable
class IsDataclass(Protocol):
    # as already noted in comments, checking for this attribute is currently
    # the most reliable way to ascertain that something is a dataclass
    __dataclass_fields__: ClassVar[Dict[str, Any]]


def is_dataclass(cls: type[Any]) -> bool:
    return hasattr(cls, "__dataclass_fields__")


def has_nested_dataclass(cls: type[IsDataclass]) -> bool:
    # iterate fields and check if any of them are dataclasses
    return any(is_dataclass(f.type) for f in cls.__dataclass_fields__.values())


def contains_a_union(cls: type[IsDataclass]) -> bool:
    return any(is_union(f.type) for f in cls.__dataclass_fields__.values())


def has_nested_base_model(cls: type[IsDataclass]) -> bool:
    for f in fields(cls):
        field_type = f.type
        # Resolve forward references and other annotations
        origin = get_origin(field_type)
        args = get_args(field_type)

        # If the field type is directly a subclass of BaseModel
        if isinstance(field_type, type) and issubclass(field_type, BaseModel):
            return True

        # If the field type is a generic type like List[BaseModel], Tuple[BaseModel, ...], etc.
        if origin is not None and args:
            for arg in args:
                # Recursively check the argument types
                if isinstance(arg, type) and issubclass(arg, BaseModel):
                    return True
                elif get_origin(arg) is not None:
                    # Handle nested generics like List[List[BaseModel]]
                    if has_nested_base_model_in_type(arg):
                        return True
        # Handle Union types
        elif args:
            for arg in args:
                if isinstance(arg, type) and issubclass(arg, BaseModel):
                    return True
                elif get_origin(arg) is not None:
                    if has_nested_base_model_in_type(arg):
                        return True
    return False


def has_nested_base_model_in_type(tp: Any) -> bool:
    """Helper function to check if a type or its arguments is a BaseModel subclass."""
    origin = get_origin(tp)
    args = get_args(tp)

    if isinstance(tp, type) and issubclass(tp, BaseModel):
        return True
    if origin is not None and args:
        for arg in args:
            if has_nested_base_model_in_type(arg):
                return True
    return False


DataclassT = TypeVar("DataclassT", bound=IsDataclass)

JSON_DATA_CONTENT_TYPE = "application/json"
"""JSON data content type"""

# TODO: what's the correct content type? There seems to be some disagreement over what it should be
PROTOBUF_DATA_CONTENT_TYPE = "application/x-protobuf"
"""Protobuf data content type"""


class DataclassJsonMessageSerializer(MessageSerializer[DataclassT]):
    def __init__(self, cls: type[DataclassT]) -> None:
        if contains_a_union(cls):
            raise ValueError("Dataclass has a union type, which is not supported. To use a union, use a Pydantic model")

        if has_nested_dataclass(cls) or has_nested_base_model(cls):
            raise ValueError(
                "Dataclass has nested dataclasses or base models, which are not supported. To use nested types, use a Pydantic model"
            )

        self.cls = cls

    @property
    def data_content_type(self) -> str:
        return JSON_DATA_CONTENT_TYPE

    @property
    def type_name(self) -> str:
        return _type_name(self.cls)

    def deserialize(self, payload: bytes) -> DataclassT:
        message_str = payload.decode("utf-8")
        return self.cls(**json.loads(message_str))

    def serialize(self, message: DataclassT) -> bytes:
        return json.dumps(asdict(message)).encode("utf-8")

代码解释

这段代码定义了一个用于序列化和反序列化数据类(dataclass)和 Pydantic 模型的序列化器,特别是针对 JSON 格式。它的目标是提供一种结构化的方式来序列化和反序列化消息,并检查不支持的特性,例如嵌套的数据类、嵌套的 Pydantic 模型以及数据类中联合类型(union type)。

1. 导入

import json
from dataclasses import asdict, dataclass, fields
from typing import Any, ClassVar, Dict, List, Protocol, Sequence, TypeVar, cast, get_args, get_origin, runtime_checkable, Union
from pydantic import BaseModel

from types import NoneType, UnionType

这部分导入了必要的模块:

  • json: 用于 JSON 序列化/反序列化。
  • dataclasses: 用于处理数据类 (asdict, dataclass, fields)。
  • typing: 用于类型提示和处理泛型 (Any, ClassVar, Dict, List, Protocol, Sequence, TypeVar, cast, get_args, get_origin, runtime_checkable, Union)。
  • pydantic: 用于处理 Pydantic 模型 (BaseModel)。
  • types: 用于处理类型提示,包括 NoneTypeUnionType

2. 辅助函数

  • _type_name(cls): 返回类或对象的名称。
  • is_union(t): 检查类型 t 是否为 UnionUnionType
  • is_dataclass(cls): 检查类 cls 是否为数据类。
  • has_nested_dataclass(cls): 检查数据类 cls 是否包含任何嵌套的数据类作为字段。
  • contains_a_union(cls): 检查数据类 cls 是否包含任何 Union 类型作为字段类型。
  • has_nested_base_model(cls): 检查数据类 cls 是否包含任何嵌套的 Pydantic BaseModel 子类作为字段,包括处理 List[BaseModel] 等泛型和嵌套泛型。
  • has_nested_base_model_in_type(tp): has_nested_base_model 使用的辅助函数,用于递归检查类型或其参数是否为 BaseModel 子类。

3. 类型变量和常量

  • T: 一个泛型类型变量。
  • DataclassT: 绑定到 IsDataclass 的类型变量。
  • JSON_DATA_CONTENT_TYPE: JSON 内容类型的常量。
  • PROTOBUF_DATA_CONTENT_TYPE: Protobuf 内容类型的常量(虽然在 JSON 序列化器中未使用)。

4. MessageSerializer 协议

class MessageSerializer(Protocol[T]):
    @property
    def data_content_type(self) -> str: ...

    @property
    def type_name(self) -> str: ...

    def deserialize(self, payload: bytes) -> T: ...

    def serialize(self, message: T) -> bytes: ...

定义了消息序列化器的协议(接口)。任何实现此协议的类都必须定义指定的属性和方法:

  • data_content_type: 序列化数据的 MIME 类型。
  • type_name: 被序列化类型的名称。
  • deserialize(payload): 将字节负载反序列化为 T 类型的对象。
  • serialize(message): 将 T 类型的对象序列化为字节负载。

5. IsDataclass 协议

@runtime_checkable
class IsDataclass(Protocol):
    __dataclass_fields__: ClassVar[Dict[str, Any]]

一个运行时可检查的协议,用于确定类是否为数据类。它检查 __dataclass_fields__ 属性是否存在。

6. DataclassJsonMessageSerializer

class DataclassJsonMessageSerializer(MessageSerializer[DataclassT]):
    # ...

此类实现了 MessageSerializer 协议,用于数据类的 JSON 序列化。

  • __init__(self, cls): 构造函数,它接受数据类类型 cls 作为参数。它执行关键检查:
    • 如果数据类包含 Union 类型,则引发 ValueError(使用 Pydantic 处理联合类型)。
    • 如果数据类包含嵌套的数据类或 Pydantic 模型,则引发 ValueError(使用 Pydantic 处理嵌套类型)。
  • data_content_type 属性:返回 JSON_DATA_CONTENT_TYPE
  • type_name 属性:返回数据类的名称。
  • deserialize(payload): 将 JSON 负载反序列化为数据类的实例。
  • serialize(message): 将数据类实例序列化为 JSON 负载。

总结

这段代码提供了一个专门用于简单数据类(没有联合或嵌套的数据类/模型)的 JSON 序列化器。它强调使用 Pydantic 来处理更复杂的场景,从而实现了清晰的关注点分离,并使序列化过程更具可预测性。这是一种很好的方法,可以避免直接在数据类中泛型序列化嵌套结构和联合类型的复杂性。

几个例子

@dataclass
class Point:
    x: int
    y: int

serializer = DataclassJsonMessageSerializer(Point)

point = Point(x=10, y=20)
serialized_data = serializer.serialize(point)
print(f"Serialized data: {serialized_data}")  # Output: b'{"x": 10, "y": 20}'

deserialized_point = serializer.deserialize(serialized_data)
print(f"Deserialized point: {deserialized_point}")  # Output: Point(x=10, y=20)

print(isinstance(deserialized_point, Point)) # Output: True
Serialized data: b'{"x": 10, "y": 20}'
Deserialized point: Point(x=10, y=20)
True
@dataclass
class Line:
    points: list[Point]

line = Line(points=[Point(x=1, y=2), Point(x=3, y=4)])
serializer_line = DataclassJsonMessageSerializer(Line)

serialized_line = serializer_line.serialize(line)
print(f"Serialized line: {serialized_line}") # Output: b'{"points": [{"x": 1, "y": 2}, {"x": 3, "y": 4}]}'

deserialized_line = serializer_line.deserialize(serialized_line)
print(f"Deserialized line: {deserialized_line}") # Output: Line(points=[Point(x=1, y=2), Point(x=3, y=4)])

print(isinstance(deserialized_line, Line)) # Output: True
Serialized line: b'{"points": [{"x": 1, "y": 2}, {"x": 3, "y": 4}]}'
Deserialized line: Line(points=[{'x': 1, 'y': 2}, {'x': 3, 'y': 4}])
True
@dataclass
class Rectangle:
    top_left: Point
    bottom_right: Point

# 这会抛出 ValueError,因为 Rectangle 包含嵌套的数据类 Point
try:
    serializer_rectangle = DataclassJsonMessageSerializer(Rectangle)
except ValueError as e:
    print(f"Error: {e}")  # Output: Error: Dataclass has nested dataclasses or base models, which are not supported. To use nested types, use a Pydantic model
Error: Dataclass has nested dataclasses or base models, which are not supported. To use nested types, use a Pydantic model
from typing import Union

@dataclass
class Shape:
    shape_type: Union[str, int]

try:
    serializer_shape = DataclassJsonMessageSerializer(Shape)
except ValueError as e:
    print(f"Error: {e}") # Output: Error: Dataclass has a union type, which is not supported. To use a union, use a Pydantic model
Error: Dataclass has a union type, which is not supported. To use a union, use a Pydantic model
http://www.dtcms.com/a/20216.html

相关文章:

  • Mybatis高级(动态SQL)
  • 基于CanMV IDE 开发软件对K210图像识别模块的开发
  • 2025 (ISC)²CCSP 回忆录
  • 【前端】 react项目使用bootstrap、useRef和useState之间的区别和应用
  • AWS上基于高德地图API验证Amazon Redshift里国内地址数据正确性的设计方案
  • AI法理学与责任归属:技术演进下的法律重构与伦理挑战
  • 【问】强学如何支持 迁移学习呢?
  • 网络安全威胁是什么
  • 【STM32】江科大STM32学习笔记汇总(已完结)
  • Ubuntu 系统迁移
  • C语言(枚举类型)
  • C++ ——this指针
  • Rhel Centos环境开关机自动脚本
  • flutter本地推送 flutter_local_notifications的使用记录
  • Java面试题总结 - Java集合篇(附答案)
  • 一种访问网络中主机图片的方法
  • 深度学习框架PyTorch
  • 4090单卡挑战DeepSeek r1 671b:尝试量化后的心得的分享
  • 鸿蒙Next开发-添加水印以及点击穿透设置
  • c++中什么时候应该使用final关键字?
  • 143,【3】 buuctf web [GYCTF2020]EasyThinking
  • 【ISO 14229-1:2023 UDS诊断(会话控制0x10服务)测试用例CAPL代码全解析③】
  • 强化学习-NPG
  • Zbrush导入笔刷
  • 解锁电商数据宝藏:淘宝商品详情API实战指南
  • 内容中台构建高效数字化内容管理新范式
  • 【ISO 14229-1:2023 UDS诊断全量测试用例清单系列:第十三节】
  • 硬件开发笔记(三十四):AHD转MIPI国产方案详解XS9922B(一):芯片方案介绍
  • kubekey一键部署k8s高可用与kubesphere
  • 图像质量评价指标-UCIQE-UIQM