Skip to content

Commit b2f3fda

Browse files
authored
fix: Pydantic "exclude" option is not working #756 (#1154)
* fix: respect exclude option in Pydantic fields
1 parent 68ca070 commit b2f3fda

File tree

6 files changed

+63
-12
lines changed

6 files changed

+63
-12
lines changed

beanie/odm/utils/dump.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,12 @@ def get_dict(
1616
exclude = set()
1717
if document.id is None:
1818
exclude.add("_id")
19-
if not document.get_settings().use_revision:
20-
exclude.add("revision_id")
21-
encoder = Encoder(exclude=exclude, to_db=to_db, keep_nulls=keep_nulls)
19+
include = set()
20+
if document.get_settings().use_revision:
21+
include.add("revision_id")
22+
encoder = Encoder(
23+
exclude=exclude, include=include, to_db=to_db, keep_nulls=keep_nulls
24+
)
2225
return encoder.encode(document)
2326

2427

beanie/odm/utils/encoder.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ class Encoder:
8383
"""
8484

8585
exclude: Container[str] = frozenset()
86+
include: Container[str] = frozenset()
8687
custom_encoders: Mapping[type, SingleArgCallable] = dc.field(
8788
default_factory=dict
8889
)
@@ -154,15 +155,39 @@ def encode(self, obj: Any) -> Any:
154155
def _iter_model_items(
155156
self, obj: pydantic.BaseModel
156157
) -> Iterable[Tuple[str, Any]]:
157-
exclude, keep_nulls = self.exclude, self.keep_nulls
158+
keep_nulls = self.keep_nulls
158159
get_model_field = get_model_fields(obj).get
159160
for key, value in obj.__iter__():
160161
field_info = get_model_field(key)
161162
if field_info is not None:
162163
key = field_info.alias or key
163-
if key not in exclude and (value is not None or keep_nulls):
164+
if not self._should_exclude_field(key, field_info) and (
165+
value is not None or keep_nulls
166+
):
164167
yield key, value
165168

169+
def _should_exclude_field(
170+
self, key: str, field_info: Optional[pydantic.fields.FieldInfo]
171+
):
172+
exclude, include = (
173+
self.exclude,
174+
self.include,
175+
)
176+
177+
if key in include:
178+
return False
179+
180+
is_pydantic_excluded_field = (
181+
field_info is not None
182+
and (
183+
field_info.exclude
184+
if IS_PYDANTIC_V2
185+
else getattr(field_info.field_info, "exclude")
186+
)
187+
is True
188+
)
189+
return key in exclude or is_pydantic_excluded_field
190+
166191

167192
def _get_encoder(
168193
obj: Any, custom_encoders: Mapping[type, SingleArgCallable]

tests/odm/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
DocumentWithDecimalField,
4949
DocumentWithDeprecatedHiddenField,
5050
DocumentWithEnumKeysDict,
51+
DocumentWithExcludedField,
5152
DocumentWithExtras,
5253
DocumentWithHttpUrlField,
5354
DocumentWithIndexedObjectId,
@@ -209,6 +210,7 @@
209210
LongSelfLink,
210211
BsonRegexDoc,
211212
NativeRegexDoc,
213+
DocumentWithExcludedField,
212214
]
213215

214216

tests/odm/models.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ class DocumentTestModel(Document):
165165
test_int: int
166166
test_doc: SubDocument
167167
test_str: str
168-
test_list: List[SubDocument] = Field(exclude=True)
168+
test_list: List[SubDocument]
169169

170170
class Settings:
171171
use_cache = True
@@ -268,9 +268,11 @@ class Settings:
268268

269269
class DocumentWithDeprecatedHiddenField(Document):
270270
if IS_PYDANTIC_V2:
271-
test_hidden: List[str] = Field(json_schema_extra={"hidden": True})
271+
test_hidden: Optional[List[str]] = Field(
272+
default=None, json_schema_extra={"hidden": True}
273+
)
272274
else:
273-
test_hidden: List[str] = Field(hidden=True)
275+
test_hidden: Optional[List[str]] = Field(default=None, hidden=True)
274276

275277

276278
class DocumentWithCustomIdUUID(Document):
@@ -568,7 +570,7 @@ class House(Document):
568570
roof: Optional[Link[Roof]] = None
569571
yards: Optional[List[Link[Yard]]] = None
570572
height: Indexed(int) = 2
571-
name: Indexed(str) = Field(exclude=True)
573+
name: Indexed(str)
572574

573575
if IS_PYDANTIC_V2:
574576
model_config = ConfigDict(
@@ -905,6 +907,11 @@ class Settings:
905907
use_state_management = True
906908

907909

910+
class DocumentWithExcludedField(Document):
911+
included_field: int
912+
excluded_field: Optional[int] = Field(default=None, exclude=True)
913+
914+
908915
class ReleaseElemMatch(BaseModel):
909916
major_ver: int
910917
minor_ver: int

tests/odm/test_encoder.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
DocumentWithComplexDictKey,
1919
DocumentWithDecimalField,
2020
DocumentWithEnumKeysDict,
21+
DocumentWithExcludedField,
2122
DocumentWithHttpUrlField,
2223
DocumentWithKeepNullsFalse,
2324
DocumentWithStringField,
@@ -138,6 +139,13 @@ def test_keep_nulls_false():
138139
assert encoded_doc == {"m": {"i": 10}}
139140

140141

142+
async def test_excluded():
143+
doc = DocumentWithExcludedField(included_field=1, excluded_field=2)
144+
encoded_doc = Encoder().encode(doc)
145+
assert "included_field" in encoded_doc
146+
assert "excluded_field" not in encoded_doc
147+
148+
141149
@pytest.mark.skipif(not IS_PYDANTIC_V2, reason="Test only for Pydantic v2")
142150
def test_should_encode_pydantic_v2_url_correctly():
143151
url = AnyUrl("https://example.com")

tests/odm/test_fields.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
DocumentWithBsonEncodersFiledsTypes,
2020
DocumentWithCustomFiledsTypes,
2121
DocumentWithDeprecatedHiddenField,
22+
DocumentWithExcludedField,
2223
Sample,
2324
)
2425

@@ -111,11 +112,16 @@ async def test_custom_filed_types():
111112

112113

113114
async def test_excluded(document):
114-
document = await DocumentTestModel.find_one()
115+
doc = DocumentWithExcludedField(included_field=1, excluded_field=2)
116+
await doc.insert()
117+
stored_doc = await DocumentWithExcludedField.get(doc.id)
118+
assert stored_doc is not None
115119
if IS_PYDANTIC_V2:
116-
assert "test_list" not in document.model_dump()
120+
assert "included_field" in stored_doc.model_dump()
121+
assert "excluded_field" not in stored_doc.model_dump()
117122
else:
118-
assert "test_list" not in document.dict()
123+
assert "included_field" in stored_doc.dict()
124+
assert "excluded_field" not in stored_doc.dict()
119125

120126

121127
async def test_hidden(deprecated_init_beanie):

0 commit comments

Comments
 (0)