autogen_core中的DataclassJsonMessageSerializer类
源代码
import json
from dataclasses import asdict, dataclass, fields
from typing import Any, ClassVar, Dict, List, Protocol, Sequence, TypeVar, cast, get_args, get_origin, runtime_checkable, Union
from pydantic import BaseModel
from types import NoneType, UnionType
def _type_name(cls: type[Any] | Any) -> str:
if isinstance(cls, type):
return cls.__name__
else:
return cast(str, cls.__class__.__name__)
def is_union(t: object) -> bool:
origin = get_origin(t)
return origin is Union or origin is UnionType
# Type of the message object a serializer consumes and produces.
T = TypeVar("T")


class MessageSerializer(Protocol[T]):
    """Structural interface for turning messages of type ``T`` into bytes and back.

    Implementations expose two read-only properties describing the payload
    and two conversion methods. All members here are stubs.
    """

    @property
    def data_content_type(self) -> str:
        """Content-type string for the serialized payload (e.g. ``JSON_DATA_CONTENT_TYPE``)."""
        ...

    @property
    def type_name(self) -> str:
        """Name identifying the message type this serializer handles."""
        ...

    def deserialize(self, payload: bytes) -> T:
        """Build a message of type ``T`` from its serialized bytes."""
        ...

    def serialize(self, message: T) -> bytes:
        """Encode ``message`` into bytes."""
        ...
@runtime_checkable
class IsDataclass(Protocol):
    """Structural type matched by anything that exposes ``__dataclass_fields__``.

    Used as the bound for dataclass-typed parameters in this module. The
    ``@runtime_checkable`` decorator makes the protocol usable with
    ``isinstance``.
    """

    # as already noted in comments, checking for this attribute is currently
    # the most reliable way to ascertain that something is a dataclass
    __dataclass_fields__: ClassVar[Dict[str, Any]]
def is_dataclass(cls: type[Any]) -> bool:
    """Tell whether ``cls`` carries the dataclass marker attribute.

    Args:
        cls: The object to probe (normally a class).

    Returns:
        True when ``cls`` has a ``__dataclass_fields__`` attribute.
    """
    marker = "__dataclass_fields__"
    return hasattr(cls, marker)
def has_nested_dataclass(cls: type[IsDataclass]) -> bool:
    """Check whether any field of ``cls`` is annotated with a dataclass.

    Only the annotation object of each field is tested with ``is_dataclass``;
    type arguments of generic annotations are not examined by this function.

    Args:
        cls: The dataclass whose ``__dataclass_fields__`` are scanned.

    Returns:
        True as soon as one field's ``type`` passes ``is_dataclass``,
        otherwise False.
    """
    for declared in cls.__dataclass_fields__.values():
        if is_dataclass(declared.type):
            return True
    return False
def contains_a_union(cls: type[IsDataclass]) -> bool:
    """Check whether any field of ``cls`` is annotated with a union type.

    Each field's annotation is passed to ``is_union`` as-is, so only a union
    at the top level of the annotation is detected.

    Args:
        cls: The dataclass whose ``__dataclass_fields__`` are scanned.

    Returns:
        True as soon as one field's ``type`` passes ``is_union``, otherwise
        False.
    """
    for declared in cls.__dataclass_fields__.values():
        if is_union(declared.type):
            return True
    return False
def has_nested_base_model(cls: type[IsDataclass]) -> bool:
    """Check whether any field of ``cls`` involves a Pydantic ``BaseModel``.

    A field matches when its annotation is a ``BaseModel`` subclass, or is a
    parameterised generic (``List[Model]``, ``Dict[str, List[Model]]``, ...)
    with a ``BaseModel`` subclass anywhere among its type arguments.

    Annotations are inspected as stored on the field. String (forward
    reference) annotations are not resolved, so they never match.

    Args:
        cls: The dataclass whose fields (as reported by
            ``dataclasses.fields``) are scanned.

    Returns:
        True if at least one field matches, otherwise False.
    """
    # has_nested_base_model_in_type already covers the direct-subclass case
    # and recurses through generic arguments, so each field needs one call.
    # The former "origin is None but args is non-empty" branch was dropped:
    # get_args() only returns arguments for objects that also have an origin.
    return any(has_nested_base_model_in_type(f.type) for f in fields(cls))
def has_nested_base_model_in_type(tp: Any) -> bool:
    """Recursively look for a ``BaseModel`` subclass inside an annotation.

    Args:
        tp: An annotation object: a class, a parameterised generic, or any
            other object.

    Returns:
        True when ``tp`` is a ``BaseModel`` subclass, or when ``tp`` has both
        an origin and type arguments and any argument matches recursively.
        False otherwise.
    """
    if isinstance(tp, type) and issubclass(tp, BaseModel):
        return True

    type_args = get_args(tp)
    if get_origin(tp) is None or not type_args:
        # Not a parameterised generic: nothing further to descend into.
        return False

    return any(has_nested_base_model_in_type(inner) for inner in type_args)
# Type variable for the message class handled by the dataclass serializer;
# bound to IsDataclass, i.e. types exposing ``__dataclass_fields__``.
DataclassT = TypeVar("DataclassT", bound=IsDataclass)

JSON_DATA_CONTENT_TYPE = "application/json"
"""JSON data content type"""

# TODO: what's the correct content type? There seems to be some disagreement over what it should be
PROTOBUF_DATA_CONTENT_TYPE = "application/x-protobuf"
"""Protobuf data content type"""
class DataclassJsonMessageSerializer(MessageSerializer[DataclassT]):
    """JSON serializer for flat dataclass messages.

    The wrapped dataclass is vetted at construction time. Classes with a
    union-typed field, a field annotated with another dataclass, or a field
    involving a Pydantic ``BaseModel`` are rejected with ``ValueError``.
    """

    def __init__(self, cls: type[DataclassT]) -> None:
        """Validate ``cls`` and remember it as the target message type.

        Args:
            cls: The dataclass this serializer converts to and from JSON.

        Raises:
            ValueError: If ``cls`` has a union-typed field, or a field that
                is a nested dataclass or Pydantic model.
        """
        # The union check runs first, so its message wins when both
        # problems are present.
        if contains_a_union(cls):
            raise ValueError("Dataclass has a union type, which is not supported. To use a union, use a Pydantic model")

        uses_nested_types = has_nested_dataclass(cls) or has_nested_base_model(cls)
        if uses_nested_types:
            raise ValueError(
                "Dataclass has nested dataclasses or base models, which are not supported. To use nested types, use a Pydantic model"
            )

        self.cls = cls

    @property
    def data_content_type(self) -> str:
        """Content type of the produced payload: always ``JSON_DATA_CONTENT_TYPE``."""
        content_type = JSON_DATA_CONTENT_TYPE
        return content_type

    @property
    def type_name(self) -> str:
        """Name of the wrapped dataclass, as computed by ``_type_name``."""
        target = self.cls
        return _type_name(target)

    def deserialize(self, payload: bytes) -> DataclassT:
        """Decode UTF-8 JSON bytes and build an instance of the wrapped dataclass.

        The decoded JSON value is unpacked as keyword arguments into the
        dataclass constructor; no per-field conversion is performed.
        """
        decoded = payload.decode("utf-8")
        field_values = json.loads(decoded)
        return self.cls(**field_values)

    def serialize(self, message: DataclassT) -> bytes:
        """Convert ``message`` with ``asdict`` and return it as UTF-8 JSON bytes."""
        as_mapping = asdict(message)
        return json.dumps(as_mapping).encode("utf-8")
代码解释
这段代码定义了一个用于序列化和反序列化数据类(dataclass)的 JSON 序列化器。它的目标是提供一种结构化的方式来序列化和反序列化消息,并在构造序列化器时检查不支持的特性,例如嵌套的数据类、嵌套的 Pydantic 模型以及数据类中的联合类型(union type)。
1. 导入
import json
from dataclasses import asdict, dataclass, fields
from typing import Any, ClassVar, Dict, List, Protocol, Sequence, TypeVar, cast, get_args, get_origin, runtime_checkable, Union
from pydantic import BaseModel
from types import NoneType, UnionType
这部分导入了必要的模块:
- json: 用于 JSON 序列化/反序列化。
- dataclasses: 用于处理数据类(asdict、dataclass、fields)。
- typing: 用于类型提示和处理泛型(Any、ClassVar、Dict、List、Protocol、Sequence、TypeVar、cast、get_args、get_origin、runtime_checkable、Union)。
- pydantic: 用于处理 Pydantic 模型(BaseModel)。
- types: 用于处理类型提示,包括 NoneType 和 UnionType。
2. 辅助函数
- _type_name(cls): 返回类或对象的名称。
- is_union(t): 检查类型 t 是否为 Union 或 UnionType。
- is_dataclass(cls): 检查类 cls 是否为数据类。
- has_nested_dataclass(cls): 检查数据类 cls 是否包含任何嵌套的数据类作为字段。
- contains_a_union(cls): 检查数据类 cls 是否包含任何 Union 类型作为字段类型。
- has_nested_base_model(cls): 检查数据类 cls 是否包含任何嵌套的 Pydantic BaseModel 子类作为字段,包括处理 List[BaseModel] 等泛型和嵌套泛型。
- has_nested_base_model_in_type(tp): has_nested_base_model 使用的辅助函数,用于递归检查类型或其参数是否为 BaseModel 子类。
3. 类型变量和常量
- T: 一个泛型类型变量。
- DataclassT: 绑定到 IsDataclass 的类型变量。
- JSON_DATA_CONTENT_TYPE: JSON 内容类型的常量。
- PROTOBUF_DATA_CONTENT_TYPE: Protobuf 内容类型的常量(虽然在 JSON 序列化器中未使用)。
4. MessageSerializer 协议
# Excerpt: the serializer interface; every member is a stub.
class MessageSerializer(Protocol[T]):
    @property
    def data_content_type(self) -> str: ...

    @property
    def type_name(self) -> str: ...

    def deserialize(self, payload: bytes) -> T: ...

    def serialize(self, message: T) -> bytes: ...
定义了消息序列化器的协议(接口)。任何实现此协议的类都必须定义指定的属性和方法:
- data_content_type: 序列化数据的 MIME 类型。
- type_name: 被序列化类型的名称。
- deserialize(payload): 将字节负载反序列化为 T 类型的对象。
- serialize(message): 将 T 类型的对象序列化为字节负载。
5. IsDataclass 协议
# Excerpt: matches any type that exposes the __dataclass_fields__ attribute.
@runtime_checkable
class IsDataclass(Protocol):
    __dataclass_fields__: ClassVar[Dict[str, Any]]
一个运行时可检查的协议,用于确定类是否为数据类。它检查 __dataclass_fields__ 属性是否存在。
6. DataclassJsonMessageSerializer 类
class DataclassJsonMessageSerializer(MessageSerializer[DataclassT]):
# ...
此类实现了 MessageSerializer 协议,用于数据类的 JSON 序列化。
- __init__(self, cls): 构造函数,它接受数据类类型 cls 作为参数。它执行关键检查:
  - 如果数据类包含 Union 类型,则引发 ValueError(使用 Pydantic 处理联合类型)。
  - 如果数据类包含嵌套的数据类或 Pydantic 模型,则引发 ValueError(使用 Pydantic 处理嵌套类型)。
- data_content_type 属性:返回 JSON_DATA_CONTENT_TYPE。
- type_name 属性:返回数据类的名称。
- deserialize(payload): 将 JSON 负载反序列化为数据类的实例。
- serialize(message): 将数据类实例序列化为 JSON 负载。
总结
这段代码提供了一个专门用于简单数据类(没有联合类型,也没有嵌套的数据类/模型)的 JSON 序列化器。它强调使用 Pydantic 来处理更复杂的场景,从而实现了清晰的关注点分离,并使序列化过程更具可预测性。这是一种很好的方法,可以避免直接为数据类通用地序列化嵌套结构和联合类型所带来的复杂性。
几个例子
# Round-trip a flat dataclass whose fields are plain ints.
@dataclass
class Point:
    x: int
    y: int


serializer = DataclassJsonMessageSerializer(Point)
point = Point(x=10, y=20)
serialized_data = serializer.serialize(point)
print(f"Serialized data: {serialized_data}") # Output: b'{"x": 10, "y": 20}'
deserialized_point = serializer.deserialize(serialized_data)
print(f"Deserialized point: {deserialized_point}") # Output: Point(x=10, y=20)
print(isinstance(deserialized_point, Point)) # Output: True
Serialized data: b'{"x": 10, "y": 20}'
Deserialized point: Point(x=10, y=20)
True
# A list[Point] field is accepted by the constructor: the nested-dataclass
# check tests only the field annotation itself, not its type arguments.
@dataclass
class Line:
    points: list[Point]


line = Line(points=[Point(x=1, y=2), Point(x=3, y=4)])
serializer_line = DataclassJsonMessageSerializer(Line)
serialized_line = serializer_line.serialize(line)
print(f"Serialized line: {serialized_line}") # Output: b'{"points": [{"x": 1, "y": 2}, {"x": 3, "y": 4}]}'
deserialized_line = serializer_line.deserialize(serialized_line)
# deserialize() passes the decoded JSON straight to Line(**...), so the list
# elements come back as plain dicts, not Point instances.
print(f"Deserialized line: {deserialized_line}") # Output: Line(points=[{'x': 1, 'y': 2}, {'x': 3, 'y': 4}])
print(isinstance(deserialized_line, Line)) # Output: True
Serialized line: b'{"points": [{"x": 1, "y": 2}, {"x": 3, "y": 4}]}'
Deserialized line: Line(points=[{'x': 1, 'y': 2}, {'x': 3, 'y': 4}])
True
@dataclass
class Rectangle:
    top_left: Point
    bottom_right: Point


# This raises ValueError because Rectangle contains the nested dataclass Point
try:
    serializer_rectangle = DataclassJsonMessageSerializer(Rectangle)
except ValueError as e:
    print(f"Error: {e}") # Output: Error: Dataclass has nested dataclasses or base models, which are not supported. To use nested types, use a Pydantic model
Error: Dataclass has nested dataclasses or base models, which are not supported. To use nested types, use a Pydantic model
from typing import Union


@dataclass
class Shape:
    shape_type: Union[str, int]


# A Union-typed field makes the constructor raise ValueError.
try:
    serializer_shape = DataclassJsonMessageSerializer(Shape)
except ValueError as e:
    print(f"Error: {e}") # Output: Error: Dataclass has a union type, which is not supported. To use a union, use a Pydantic model
Error: Dataclass has a union type, which is not supported. To use a union, use a Pydantic model