diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 1925a48f04..ca61a34011 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -239,6 +239,7 @@ def Field( max_length: Optional[int] = None, allow_mutation: bool = True, regex: Optional[str] = None, + pattern: Optional[str] = None, discriminator: Optional[str] = None, repr: bool = True, primary_key: Union[bool, UndefinedType] = Undefined, @@ -284,6 +285,7 @@ def Field( max_length: Optional[int] = None, allow_mutation: bool = True, regex: Optional[str] = None, + pattern: Optional[str] = None, discriminator: Optional[str] = None, repr: bool = True, primary_key: Union[bool, UndefinedType] = Undefined, @@ -338,6 +340,7 @@ def Field( max_length: Optional[int] = None, allow_mutation: bool = True, regex: Optional[str] = None, + pattern: Optional[str] = None, discriminator: Optional[str] = None, repr: bool = True, sa_column: Union[Column[Any], UndefinedType] = Undefined, @@ -373,6 +376,7 @@ def Field( max_length: Optional[int] = None, allow_mutation: bool = True, regex: Optional[str] = None, + pattern: Optional[str] = None, discriminator: Optional[str] = None, repr: bool = True, primary_key: Union[bool, UndefinedType] = Undefined, @@ -388,6 +392,16 @@ def Field( schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: current_schema_extra = schema_extra or {} + + if IS_PYDANTIC_V2: + current_schema_extra.update( + pattern=pattern or regex or current_schema_extra.get("pattern") + ) + else: + current_schema_extra.update( + regex=regex or pattern or current_schema_extra.get("pattern") + ) + field_info = FieldInfo( default, default_factory=default_factory, @@ -410,7 +424,6 @@ def Field( min_length=min_length, max_length=max_length, allow_mutation=allow_mutation, - regex=regex, discriminator=discriminator, repr=repr, primary_key=primary_key, diff --git a/tests/test_pydantic/test_field.py b/tests/test_pydantic/test_field.py index 9d7bc77625..bbdf212627 100644 --- a/tests/test_pydantic/test_field.py +++ b/tests/test_pydantic/test_field.py @@ -55,3 +55,19 @@ class Model(SQLModel): instance = Model(id=123, foo="bar") assert "foo=" not in repr(instance) + + +@pytest.mark.parametrize("param", ["regex", "pattern"]) +def test_field_regex_param(param: str): + class DateModel(SQLModel): + date_1: str = Field(**{param: r"^\d{2}-\d{2}-\d{4}$"}) + + DateModel(date_1="12-31-2024") # Validates correctly + + +def test_field_pattern_via_schema_extra(): + class DateModel(SQLModel): + date_1: str = Field(schema_extra={"pattern": r"^\d{2}-\d{2}-\d{4}$"}) + + with pytest.raises(ValidationError): + DateModel(date_1="incorrect")