diff --git a/chat2db/database/postgres.py b/chat2db/database/postgres.py index c3b8bf84b1598c19ed3612a10db06a3901e791e8..380eb3a35072e7b7fddbc28ab087f6229f3518dc 100644 --- a/chat2db/database/postgres.py +++ b/chat2db/database/postgres.py @@ -33,15 +33,26 @@ class TableInfo(Base): TIMESTAMP(timezone=True), server_default=func.current_timestamp(), onupdate=func.current_timestamp()) - __table_args__ = ( - Index( - 'table_note_vector_index', - table_note_vector, - postgresql_using='hnsw', - postgresql_with={'m': 16, 'ef_construction': 200}, - postgresql_ops={'table_note_vector': 'vector_cosine_ops'} - ), - ) + if 'opengauss' in config['DATABASE_URL']: + __table_args__ = ( + Index( + 'table_note_vector_index', + table_note_vector, + opengauss_using='hnsw', + opengauss_with={'m': 16, 'ef_construction': 200}, + opengauss_ops={'vector': 'vector_cosine_ops'} + ), + ) + else: + __table_args__ = ( + Index( + 'table_note_vector_index', + table_note_vector, + postgresql_using='hnsw', + postgresql_with={'m': 16, 'ef_construction': 200}, + postgresql_ops={'table_note_vector': 'vector_cosine_ops'} + ), + ) class ColumnInfo(Base): @@ -66,15 +77,26 @@ class SqlExample(Base): TIMESTAMP(timezone=True), server_default=func.current_timestamp(), onupdate=func.current_timestamp()) - __table_args__ = ( - Index( - 'question_vector_index', - question_vector, - postgresql_using='hnsw', - postgresql_with={'m': 16, 'ef_construction': 200}, - postgresql_ops={'question_vector': 'vector_cosine_ops'} - ), - ) + if 'opengauss' in config['DATABASE_URL']: + __table_args__ = ( + Index( + 'question_vector_index', + question_vector, + opengauss_using='hnsw', + opengauss_with={'m': 16, 'ef_construction': 200}, + opengauss_ops={'vector': 'vector_cosine_ops'} + ), + ) + else: + __table_args__ = ( + Index( + 'question_vector_index', + question_vector, + postgresql_using='hnsw', + postgresql_with={'m': 16, 'ef_construction': 200}, + postgresql_ops={'question_vector': 'vector_cosine_ops'} + ), + ) class PostgresDB: @@ -94,6 +116,7 @@ class PostgresDB: if 'opengauss' in config['DATABASE_URL']: from sqlalchemy import event from opengauss_sqlalchemy.register_async import register_vector + @event.listens_for(cls.engine.sync_engine, "connect") def connect(dbapi_connection, connection_record): dbapi_connection.run_async(register_vector)