diff --git a/pgvector/bit.py b/pgvector/bit.py index 26a9d8d..eee4b9b 100644 --- a/pgvector/bit.py +++ b/pgvector/bit.py @@ -62,9 +62,12 @@ def from_binary(cls, value): @classmethod def _to_db(cls, value): + if value is None: + return value + if not isinstance(value, cls): - raise ValueError('expected bit') - + value = cls(value) + return value.to_text() @classmethod diff --git a/pgvector/sqlalchemy/bit.py b/pgvector/sqlalchemy/bit.py index 1ea85c3..fb7f026 100644 --- a/pgvector/sqlalchemy/bit.py +++ b/pgvector/sqlalchemy/bit.py @@ -1,6 +1,8 @@ +import asyncpg +from sqlalchemy.dialects.postgresql.asyncpg import PGDialect_asyncpg from sqlalchemy.dialects.postgresql.base import ischema_names from sqlalchemy.types import UserDefinedType, Float - +from .. import Bit class BIT(UserDefinedType): cache_ok = True @@ -24,7 +26,7 @@ def process(value): return value return process else: - return super().bind_processor(dialect) + return lambda value: Bit._to_db(value) class comparator_factory(UserDefinedType.Comparator): def hamming_distance(self, other): diff --git a/tests/test_sqlalchemy.py b/tests/test_sqlalchemy.py index c59c12e..af34bee 100644 --- a/tests/test_sqlalchemy.py +++ b/tests/test_sqlalchemy.py @@ -311,6 +311,13 @@ def test_bit(self, engine): item = session.get(Item, 1) assert item.binary_embedding == '101' + def test_boolean_list_bit(self, engine): + with Session(engine) as session: + session.add(Item(id=1, binary_embedding=[True, False, True])) + session.commit() + item = session.get(Item, 1) + assert item.binary_embedding == '101' + def test_bit_hamming_distance(self, engine): create_items() with Session(engine) as session: @@ -567,7 +574,6 @@ def test_halfvec_array(self, engine): item = session.get(Item, 1) assert item.half_embeddings == [HalfVector([1, 2, 3]), HalfVector([4, 5, 6])] - @pytest.mark.parametrize('engine', async_engines) class TestSqlalchemyAsync: def setup_method(self):