Skip to content

Commit 51c6b0f

Browse files
committed
feat: add support computed fields
1 parent b2f3fda commit 51c6b0f

File tree

8 files changed

+173
-12
lines changed

8 files changed

+173
-12
lines changed

beanie/odm/fields.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,61 @@ def __deepcopy__(self, memo):
276276
return self
277277

278278

279+
class ExpressionFieldProperty(property):
280+
def __init__(
281+
self, original_property: property, expression_field: ExpressionField
282+
):
283+
self._original = original_property
284+
self._expression_field = expression_field
285+
super().__init__(
286+
original_property.fget,
287+
original_property.fset,
288+
original_property.fdel,
289+
original_property.__doc__,
290+
)
291+
292+
def __getitem__(self, item):
293+
return ExpressionField(f"{self._expression_field}.{item}")
294+
295+
def __getattr__(self, item):
296+
return ExpressionField(f"{self._expression_field}.{item}")
297+
298+
def __hash__(self):
299+
return hash(str(self._expression_field))
300+
301+
def __eq__(self, other):
302+
if isinstance(other, ExpressionField):
303+
return super(ExpressionField, self._expression_field).__eq__(other)
304+
return Eq(field=self, other=other)
305+
306+
def __gt__(self, other):
307+
return GT(field=self._expression_field, other=other)
308+
309+
def __ge__(self, other):
310+
return GTE(field=self._expression_field, other=other)
311+
312+
def __lt__(self, other):
313+
return LT(field=self._expression_field, other=other)
314+
315+
def __le__(self, other):
316+
return LTE(field=self._expression_field, other=other)
317+
318+
def __ne__(self, other):
319+
return NE(field=self._expression_field, other=other)
320+
321+
def __pos__(self):
322+
return self._expression_field, SortDirection.ASCENDING
323+
324+
def __neg__(self):
325+
return self._expression_field, SortDirection.DESCENDING
326+
327+
def __copy__(self):
328+
return self._expression_field
329+
330+
def __deepcopy__(self, memo):
331+
return self._expression_field
332+
333+
279334
class DeleteRules(str, Enum):
280335
DO_NOTHING = "DO_NOTHING"
281336
DELETE_LINKS = "DELETE_LINKS"

beanie/odm/utils/encoder.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from beanie.odm.utils.pydantic import (
2828
IS_PYDANTIC_V2,
2929
IS_PYDANTIC_V2_10,
30+
get_model_all_items,
3031
get_model_fields,
3132
)
3233

@@ -157,7 +158,7 @@ def _iter_model_items(
157158
) -> Iterable[Tuple[str, Any]]:
158159
keep_nulls = self.keep_nulls
159160
get_model_field = get_model_fields(obj).get
160-
for key, value in obj.__iter__():
161+
for key, value in get_model_all_items(obj).items():
161162
field_info = get_model_field(key)
162163
if field_info is not None:
163164
key = field_info.alias or key
@@ -167,7 +168,9 @@ def _iter_model_items(
167168
yield key, value
168169

169170
def _should_exclude_field(
170-
self, key: str, field_info: Optional[pydantic.fields.FieldInfo]
171+
self,
172+
key: str,
173+
field_info: Any,
171174
):
172175
exclude, include = (
173176
self.exclude,
@@ -180,7 +183,7 @@ def _should_exclude_field(
180183
is_pydantic_excluded_field = (
181184
field_info is not None
182185
and (
183-
field_info.exclude
186+
getattr(field_info, "exclude", None)
184187
if IS_PYDANTIC_V2
185188
else getattr(field_info.field_info, "exclude")
186189
)

beanie/odm/utils/init.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from beanie.odm.utils.pydantic import (
77
IS_PYDANTIC_V2,
88
get_extra_field_info,
9+
get_field_type,
910
get_model_fields,
1011
parse_model,
1112
)
@@ -37,6 +38,7 @@
3738
from beanie.odm.fields import (
3839
BackLink,
3940
ExpressionField,
41+
ExpressionFieldProperty,
4042
Link,
4143
LinkInfo,
4244
LinkTypes,
@@ -215,9 +217,9 @@ def detect_link(
215217
:param field: ModelField
216218
:return: Optional[LinkInfo]
217219
"""
218-
219-
origin = get_origin(field.annotation)
220-
args = get_args(field.annotation)
220+
annotation = get_field_type(field)
221+
origin = get_origin(annotation)
222+
args = get_args(annotation)
221223
classes = [
222224
Link,
223225
BackLink,
@@ -226,8 +228,8 @@ def detect_link(
226228
for cls in classes:
227229
# Check if annotation is one of the custom classes
228230
if (
229-
isinstance(field.annotation, _GenericAlias)
230-
and field.annotation.__origin__ is cls
231+
isinstance(annotation, _GenericAlias)
232+
and annotation.__origin__ is cls
231233
):
232234
if cls is Link:
233235
return LinkInfo(
@@ -398,7 +400,14 @@ def init_document_fields(self, cls) -> None:
398400
cls._link_fields = {}
399401
for k, v in get_model_fields(cls).items():
400402
path = v.alias or k
401-
setattr(cls, k, ExpressionField(path))
403+
attr = getattr(cls, k, None)
404+
expression_field = ExpressionField(path)
405+
if isinstance(attr, property):
406+
setattr(
407+
cls, k, ExpressionFieldProperty(attr, expression_field)
408+
)
409+
else:
410+
setattr(cls, k, expression_field)
402411

403412
link_info = self.detect_link(v, k)
404413
depth_level = cls.get_settings().max_nesting_depths_per_field.get(

beanie/odm/utils/pydantic.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
if IS_PYDANTIC_V2:
1212
from pydantic import TypeAdapter
13+
from pydantic.fields import ComputedFieldInfo
1314
else:
1415
from pydantic import parse_obj_as
1516

@@ -23,18 +24,36 @@ def parse_object_as(object_type: Type, data: Any):
2324

2425
def get_field_type(field):
2526
if IS_PYDANTIC_V2:
26-
return field.annotation
27+
if isinstance(field, ComputedFieldInfo):
28+
return field.return_type
29+
else:
30+
return field.annotation
2731
else:
2832
return field.outer_type_
2933

3034

3135
def get_model_fields(model):
3236
if IS_PYDANTIC_V2:
33-
return model.model_fields
37+
return {**model.model_fields, **model.model_computed_fields}
3438
else:
3539
return model.__fields__
3640

3741

42+
def get_model_all_items(model):
43+
if IS_PYDANTIC_V2:
44+
return {
45+
**dict(model.__iter__()),
46+
**{
47+
key: getattr(model, key)
48+
for key in {
49+
**model.model_computed_fields,
50+
}.keys()
51+
},
52+
}
53+
else:
54+
return dict(model.__iter__())
55+
56+
3857
def parse_model(model_type: Type[BaseModel], data: Any):
3958
if IS_PYDANTIC_V2:
4059
return model_type.model_validate(data)

tests/odm/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
DocumentWithBsonBinaryField,
4242
DocumentWithBsonEncodersFiledsTypes,
4343
DocumentWithComplexDictKey,
44+
DocumentWithComputedField,
4445
DocumentWithCustomFiledsTypes,
4546
DocumentWithCustomIdInt,
4647
DocumentWithCustomIdUUID,
@@ -211,6 +212,7 @@
211212
BsonRegexDoc,
212213
NativeRegexDoc,
213214
DocumentWithExcludedField,
215+
DocumentWithComputedField,
214216
]
215217

216218

tests/odm/models.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
from beanie.odm.utils.pydantic import IS_PYDANTIC_V2
6262

6363
if IS_PYDANTIC_V2:
64-
from pydantic import RootModel, validate_call
64+
from pydantic import RootModel, computed_field, validate_call
6565

6666
if sys.version_info >= (3, 10):
6767

@@ -912,6 +912,30 @@ class DocumentWithExcludedField(Document):
912912
excluded_field: Optional[int] = Field(default=None, exclude=True)
913913

914914

915+
class DocumentWithComputedField(Document):
916+
num: int
917+
918+
_cached_uuid: Optional[str] = None
919+
920+
if IS_PYDANTIC_V2:
921+
922+
@computed_field
923+
@property
924+
def doubled(self) -> int:
925+
return self.num * 2
926+
927+
@computed_field
928+
@property
929+
def cacheable_uuid(self) -> str:
930+
if self._cached_uuid is None:
931+
self._cached_uuid = str(uuid4())
932+
return self._cached_uuid
933+
934+
@cacheable_uuid.setter
935+
def cacheable_uuid(self, new: str) -> None:
936+
self._cached_uuid = new
937+
938+
915939
class ReleaseElemMatch(BaseModel):
916940
major_ver: int
917941
minor_ver: int

tests/odm/test_encoder.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
DocumentForEncodingTest,
1717
DocumentForEncodingTestDate,
1818
DocumentWithComplexDictKey,
19+
DocumentWithComputedField,
1920
DocumentWithDecimalField,
2021
DocumentWithEnumKeysDict,
2122
DocumentWithExcludedField,
@@ -146,6 +147,15 @@ async def test_excluded():
146147
assert "excluded_field" not in encoded_doc
147148

148149

150+
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Test only for Pydantic v2")
151+
def test_computed_field():
152+
doc = DocumentWithComputedField(num=1)
153+
encoded_doc = Encoder().encode(doc)
154+
print(doc)
155+
print(encoded_doc)
156+
assert encoded_doc["doubled"] == 2
157+
158+
149159
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Test only for Pydantic v2")
150160
def test_should_encode_pydantic_v2_url_correctly():
151161
url = AnyUrl("https://example.com")

tests/odm/test_fields.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
DocumentTestModel,
1818
DocumentTestModelIndexFlagsAnnotated,
1919
DocumentWithBsonEncodersFiledsTypes,
20+
DocumentWithComputedField,
2021
DocumentWithCustomFiledsTypes,
2122
DocumentWithDeprecatedHiddenField,
2223
DocumentWithExcludedField,
@@ -124,6 +125,44 @@ async def test_excluded(document):
124125
assert "excluded_field" not in stored_doc.dict()
125126

126127

128+
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Test only for Pydantic v2")
129+
async def test_computed_field():
130+
doc = DocumentWithComputedField(num=1)
131+
assert doc.doubled == 2
132+
133+
await doc.insert()
134+
stored_doc = await DocumentWithComputedField.get(doc.id)
135+
assert stored_doc and stored_doc.doubled == 2
136+
137+
stored_doc.num = 2
138+
assert stored_doc.doubled == 4
139+
140+
await stored_doc.replace()
141+
replaced_doc = await DocumentWithComputedField.get(doc.id)
142+
assert replaced_doc and replaced_doc.doubled == 4
143+
144+
145+
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Test only for Pydantic v2")
146+
async def test_computed_field_setter():
147+
doc = DocumentWithComputedField(num=1)
148+
await doc.insert()
149+
cached_uui = doc.cacheable_uuid
150+
db_raw_data = (
151+
await DocumentWithComputedField.get_motor_collection().find_one(
152+
{"_id": doc.id}
153+
)
154+
)
155+
assert db_raw_data == {
156+
"_id": doc.id,
157+
"num": 1,
158+
"doubled": 2,
159+
"cacheable_uuid": cached_uui,
160+
}
161+
162+
fetched_doc = await DocumentWithComputedField.get(doc.id)
163+
assert fetched_doc and fetched_doc.cacheable_uuid != cached_uui
164+
165+
127166
async def test_hidden(deprecated_init_beanie):
128167
document = DocumentWithDeprecatedHiddenField(test_hidden=["abc", "def"])
129168
await document.insert()

0 commit comments

Comments
 (0)