Skip to content

Commit 556bf56

Browse files
giacbrdgiacbrdDouweM
authored
Use model class names as tags in format_as_xml and add option to include field titles and descriptions as attributes (#2313)
Co-authored-by: giacbrd <[email protected]> Co-authored-by: Douwe Maan <[email protected]>
1 parent 961a666 commit 556bf56

File tree

2 files changed

+504
-20
lines changed

2 files changed

+504
-20
lines changed

pydantic_ai_slim/pydantic_ai/format_prompt.py

Lines changed: 109 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,25 @@
11
from __future__ import annotations as _annotations
22

33
from collections.abc import Iterable, Iterator, Mapping
4-
from dataclasses import asdict, dataclass, is_dataclass
4+
from dataclasses import asdict, dataclass, field, fields, is_dataclass
55
from datetime import date
6-
from typing import Any
6+
from typing import Any, Literal
77
from xml.etree import ElementTree
88

99
from pydantic import BaseModel
1010

1111
__all__ = ('format_as_xml',)
1212

13+
from pydantic.fields import ComputedFieldInfo, FieldInfo
14+
1315

1416
def format_as_xml(
1517
obj: Any,
1618
root_tag: str | None = None,
1719
item_tag: str = 'item',
1820
none_str: str = 'null',
1921
indent: str | None = ' ',
22+
include_field_info: Literal['once'] | bool = False,
2023
) -> str:
2124
"""Format a Python object as XML.
2225
@@ -33,6 +36,10 @@ def format_as_xml(
3336
for dataclasses and Pydantic models.
3437
none_str: String to use for `None` values.
3538
indent: Indentation string to use for pretty printing.
39+
include_field_info: Whether to include attributes like Pydantic `Field` attributes and dataclasses `field()`
40+
`metadata` as XML attributes. In both cases the allowed `Field` attributes and `field()` metadata keys are
41+
`title` and `description`. If a field is repeated in the data (e.g. in a list) by setting `once`
42+
the attributes are included only in the first occurrence of an XML element relative to the same field.
3643
3744
Returns:
3845
XML representation of the object.
@@ -51,7 +58,12 @@ def format_as_xml(
5158
'''
5259
```
5360
"""
54-
el = _ToXml(item_tag=item_tag, none_str=none_str).to_xml(obj, root_tag)
61+
el = _ToXml(
62+
data=obj,
63+
item_tag=item_tag,
64+
none_str=none_str,
65+
include_field_info=include_field_info,
66+
).to_xml(root_tag)
5567
if root_tag is None and el.text is None:
5668
join = '' if indent is None else '\n'
5769
return join.join(_rootless_xml_elements(el, indent))
@@ -63,11 +75,26 @@ def format_as_xml(
6375

6476
@dataclass
6577
class _ToXml:
78+
data: Any
6679
item_tag: str
6780
none_str: str
68-
69-
def to_xml(self, value: Any, tag: str | None) -> ElementTree.Element:
70-
element = ElementTree.Element(self.item_tag if tag is None else tag)
81+
include_field_info: Literal['once'] | bool
82+
# a map of Pydantic and dataclasses Field paths to their metadata:
83+
# a field unique string representation and its class
84+
_fields_info: dict[str, tuple[str, FieldInfo | ComputedFieldInfo]] = field(default_factory=dict)
85+
# keep track of fields we have extracted attributes from
86+
_included_fields: set[str] = field(default_factory=set)
87+
# keep track of class names for dataclasses and Pydantic models, that occur in lists
88+
_element_names: dict[str, str] = field(default_factory=dict)
89+
# flag for parsing dataclasses and Pydantic models once
90+
_is_info_extracted: bool = False
91+
_FIELD_ATTRIBUTES = ('title', 'description')
92+
93+
def to_xml(self, tag: str | None = None) -> ElementTree.Element:
94+
return self._to_xml(value=self.data, path='', tag=tag)
95+
96+
def _to_xml(self, value: Any, path: str, tag: str | None = None) -> ElementTree.Element:
97+
element = self._create_element(self.item_tag if tag is None else tag, path)
7198
if value is None:
7299
element.text = self.none_str
73100
elif isinstance(value, str):
@@ -79,31 +106,96 @@ def to_xml(self, value: Any, tag: str | None) -> ElementTree.Element:
79106
elif isinstance(value, date):
80107
element.text = value.isoformat()
81108
elif isinstance(value, Mapping):
82-
self._mapping_to_xml(element, value) # pyright: ignore[reportUnknownArgumentType]
109+
if tag is None and path in self._element_names:
110+
element.tag = self._element_names[path]
111+
self._mapping_to_xml(element, value, path) # pyright: ignore[reportUnknownArgumentType]
83112
elif is_dataclass(value) and not isinstance(value, type):
113+
self._init_structure_info()
84114
if tag is None:
85-
element = ElementTree.Element(value.__class__.__name__)
86-
dc_dict = asdict(value)
87-
self._mapping_to_xml(element, dc_dict)
115+
element.tag = value.__class__.__name__
116+
self._mapping_to_xml(element, asdict(value), path)
88117
elif isinstance(value, BaseModel):
118+
self._init_structure_info()
89119
if tag is None:
90-
element = ElementTree.Element(value.__class__.__name__)
91-
self._mapping_to_xml(element, value.model_dump(mode='python'))
120+
element.tag = value.__class__.__name__
121+
# by dumping the model we loose all metadata in nested data structures,
122+
# but we have collected it when called _init_structure_info
123+
self._mapping_to_xml(element, value.model_dump(), path)
92124
elif isinstance(value, Iterable):
93-
for item in value: # pyright: ignore[reportUnknownVariableType]
94-
item_el = self.to_xml(item, None)
95-
element.append(item_el)
125+
for n, item in enumerate(value): # pyright: ignore[reportUnknownVariableType,reportUnknownArgumentType]
126+
element.append(self._to_xml(value=item, path=f'{path}.[{n}]' if path else f'[{n}]'))
96127
else:
97128
raise TypeError(f'Unsupported type for XML formatting: {type(value)}')
98129
return element
99130

100-
def _mapping_to_xml(self, element: ElementTree.Element, mapping: Mapping[Any, Any]) -> None:
131+
def _create_element(self, tag: str, path: str) -> ElementTree.Element:
132+
element = ElementTree.Element(tag)
133+
if path in self._fields_info:
134+
field_repr, field_info = self._fields_info[path]
135+
if self.include_field_info and self.include_field_info != 'once' or field_repr not in self._included_fields:
136+
field_attributes = self._extract_attributes(field_info)
137+
for k, v in field_attributes.items():
138+
element.set(k, v)
139+
self._included_fields.add(field_repr)
140+
return element
141+
142+
def _init_structure_info(self):
143+
"""Create maps with all data information (fields info and class names), if not already created."""
144+
if not self._is_info_extracted:
145+
self._parse_data_structures(self.data)
146+
self._is_info_extracted = True
147+
148+
def _mapping_to_xml(
149+
self,
150+
element: ElementTree.Element,
151+
mapping: Mapping[Any, Any],
152+
path: str = '',
153+
) -> None:
101154
for key, value in mapping.items():
102155
if isinstance(key, int):
103156
key = str(key)
104157
elif not isinstance(key, str):
105158
raise TypeError(f'Unsupported key type for XML formatting: {type(key)}, only str and int are allowed')
106-
element.append(self.to_xml(value, key))
159+
element.append(self._to_xml(value=value, path=f'{path}.{key}' if path else key, tag=key))
160+
161+
def _parse_data_structures(
162+
self,
163+
value: Any,
164+
path: str = '',
165+
):
166+
"""Parse data structures as dataclasses or Pydantic models to extract element names and attributes."""
167+
if value is None or isinstance(value, (str | int | float | date | bytearray | bytes | bool)):
168+
return
169+
elif isinstance(value, Mapping):
170+
for k, v in value.items(): # pyright: ignore[reportUnknownVariableType]
171+
self._parse_data_structures(v, f'{path}.{k}' if path else f'{k}')
172+
elif is_dataclass(value) and not isinstance(value, type):
173+
self._element_names[path] = value.__class__.__name__
174+
for field in fields(value):
175+
new_path = f'{path}.{field.name}' if path else field.name
176+
if self.include_field_info and field.metadata:
177+
attributes = {k: v for k, v in field.metadata.items() if k in self._FIELD_ATTRIBUTES}
178+
if attributes:
179+
field_repr = f'{value.__class__.__name__}.{field.name}'
180+
self._fields_info[new_path] = (field_repr, FieldInfo(**attributes))
181+
self._parse_data_structures(getattr(value, field.name), new_path)
182+
elif isinstance(value, BaseModel):
183+
self._element_names[path] = value.__class__.__name__
184+
for model_fields in (value.__class__.model_fields, value.__class__.model_computed_fields):
185+
for field, info in model_fields.items():
186+
new_path = f'{path}.{field}' if path else field
187+
if self.include_field_info and (isinstance(info, ComputedFieldInfo) or not info.exclude):
188+
field_repr = f'{value.__class__.__name__}.{field}'
189+
self._fields_info[new_path] = (field_repr, info)
190+
self._parse_data_structures(getattr(value, field), new_path)
191+
elif isinstance(value, Iterable):
192+
for n, item in enumerate(value): # pyright: ignore[reportUnknownVariableType,reportUnknownArgumentType]
193+
new_path = f'{path}.[{n}]' if path else f'[{n}]'
194+
self._parse_data_structures(item, new_path)
195+
196+
@classmethod
197+
def _extract_attributes(cls, info: FieldInfo | ComputedFieldInfo) -> dict[str, str]:
198+
return {attr: str(value) for attr in cls._FIELD_ATTRIBUTES if (value := getattr(info, attr, None)) is not None}
107199

108200

109201
def _rootless_xml_elements(root: ElementTree.Element, indent: str | None) -> Iterator[str]:

0 commit comments

Comments
 (0)