浏览代码

add pgvecto_rs support and upgrade SQLAlchemy (#3833)

Jyong 1 年之前
父节点
当前提交
3e9dbe3e0a

+ 3 - 1
.github/workflows/api-tests.yml

@@ -61,19 +61,21 @@ jobs:
       - name: Run Workflow
       - name: Run Workflow
         run: dev/pytest/pytest_workflow.sh
         run: dev/pytest/pytest_workflow.sh
 
 
-      - name: Set up Vector Stores (Weaviate, Qdrant and Milvus)
+      - name: Set up Vector Stores (Weaviate, Qdrant, Milvus, PgVecto-RS)
         uses: hoverkraft-tech/compose-action@v2.0.0
         uses: hoverkraft-tech/compose-action@v2.0.0
         with:
         with:
           compose-file: |
           compose-file: |
             docker/docker-compose.middleware.yaml
             docker/docker-compose.middleware.yaml
             docker/docker-compose.qdrant.yaml
             docker/docker-compose.qdrant.yaml
             docker/docker-compose.milvus.yaml
             docker/docker-compose.milvus.yaml
+            docker/docker-compose.pgvecto-rs.yaml
           services: |
           services: |
             weaviate
             weaviate
             qdrant
             qdrant
             etcd
             etcd
             minio
             minio
             milvus-standalone
             milvus-standalone
+            pgvecto-rs
 
 
       - name: Test Vector Stores
       - name: Test Vector Stores
         run: dev/pytest/pytest_vdb.sh
         run: dev/pytest/pytest_vdb.sh

+ 8 - 1
api/.env.example

@@ -62,7 +62,7 @@ GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON=your-google-service-account-json-base64-stri
 WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 
 
-# Vector database configuration, support: weaviate, qdrant, milvus, relyt
+# Vector database configuration, support: weaviate, qdrant, milvus, relyt, pgvecto_rs
 VECTOR_STORE=weaviate
 VECTOR_STORE=weaviate
 
 
 # Weaviate configuration
 # Weaviate configuration
@@ -92,6 +92,13 @@ RELYT_USER=postgres
 RELYT_PASSWORD=postgres
 RELYT_PASSWORD=postgres
 RELYT_DATABASE=postgres
 RELYT_DATABASE=postgres
 
 
+# PGVECTO_RS configuration
+PGVECTO_RS_HOST=localhost
+PGVECTO_RS_PORT=5431
+PGVECTO_RS_USER=postgres
+PGVECTO_RS_PASSWORD=difyai123456
+PGVECTO_RS_DATABASE=postgres
+
 # Upload configuration
 # Upload configuration
 UPLOAD_FILE_SIZE_LIMIT=15
 UPLOAD_FILE_SIZE_LIMIT=15
 UPLOAD_FILE_BATCH_LIMIT=5
 UPLOAD_FILE_BATCH_LIMIT=5

+ 7 - 0
api/config.py

@@ -251,6 +251,13 @@ class Config:
         self.RELYT_PASSWORD = get_env('RELYT_PASSWORD')
         self.RELYT_PASSWORD = get_env('RELYT_PASSWORD')
         self.RELYT_DATABASE = get_env('RELYT_DATABASE')
         self.RELYT_DATABASE = get_env('RELYT_DATABASE')
 
 
+        # pgvecto rs settings
+        self.PGVECTO_RS_HOST = get_env('PGVECTO_RS_HOST')
+        self.PGVECTO_RS_PORT = get_env('PGVECTO_RS_PORT')
+        self.PGVECTO_RS_USER = get_env('PGVECTO_RS_USER')
+        self.PGVECTO_RS_PASSWORD = get_env('PGVECTO_RS_PASSWORD')
+        self.PGVECTO_RS_DATABASE = get_env('PGVECTO_RS_DATABASE')
+
         # ------------------------
         # ------------------------
         # Mail Configurations.
         # Mail Configurations.
         # ------------------------
         # ------------------------

+ 1 - 1
api/controllers/console/datasets/datasets.py

@@ -476,7 +476,7 @@ class DatasetRetrievalSettingApi(Resource):
     @account_initialization_required
     @account_initialization_required
     def get(self):
     def get(self):
         vector_type = current_app.config['VECTOR_STORE']
         vector_type = current_app.config['VECTOR_STORE']
-        if vector_type == 'milvus' or vector_type == 'relyt':
+        if vector_type == 'milvus' or vector_type == 'pgvecto_rs' or vector_type == 'relyt':
             return {
             return {
                 'retrieval_method': [
                 'retrieval_method': [
                     'semantic_search'
                     'semantic_search'

+ 0 - 0
api/core/rag/datasource/vdb/pgvecto_rs/__init__.py


+ 12 - 0
api/core/rag/datasource/vdb/pgvecto_rs/collection.py

@@ -0,0 +1,12 @@
+from uuid import UUID
+
+from numpy import ndarray
+from sqlalchemy.orm import DeclarativeBase, Mapped
+
+
+class CollectionORM(DeclarativeBase):
+    __tablename__: str
+    id: Mapped[UUID]
+    text: Mapped[str]
+    meta: Mapped[dict]
+    vector: Mapped[ndarray]

+ 224 - 0
api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py

@@ -0,0 +1,224 @@
+import logging
+from typing import Any
+from uuid import UUID, uuid4
+
+from numpy import ndarray
+from pgvecto_rs.sqlalchemy import Vector
+from pydantic import BaseModel, root_validator
+from sqlalchemy import Float, String, create_engine, insert, select, text
+from sqlalchemy import text as sql_text
+from sqlalchemy.dialects import postgresql
+from sqlalchemy.orm import Mapped, Session, mapped_column
+
+from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
+
+logger = logging.getLogger(__name__)
+
+
+class PgvectoRSConfig(BaseModel):
+    host: str
+    port: int
+    user: str
+    password: str
+    database: str
+
+    @root_validator()
+    def validate_config(cls, values: dict) -> dict:
+        if not values['host']:
+            raise ValueError("config PGVECTO_RS_HOST is required")
+        if not values['port']:
+            raise ValueError("config PGVECTO_RS_PORT is required")
+        if not values['user']:
+            raise ValueError("config PGVECTO_RS_USER is required")
+        if not values['password']:
+            raise ValueError("config PGVECTO_RS_PASSWORD is required")
+        if not values['database']:
+            raise ValueError("config PGVECTO_RS_DATABASE is required")
+        return values
+
+
+class PGVectoRS(BaseVector):
+
+    def __init__(self, collection_name: str, config: PgvectoRSConfig, dim: int):
+        super().__init__(collection_name)
+        self._client_config = config
+        self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
+        self._client = create_engine(self._url)
+        with Session(self._client) as session:
+            session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
+            session.commit()
+        self._fields = []
+
+        class _Table(CollectionORM):
+            __tablename__ = collection_name
+            __table_args__ = {"extend_existing": True}  # noqa: RUF012
+            id: Mapped[UUID] = mapped_column(
+                postgresql.UUID(as_uuid=True),
+                primary_key=True,
+            )
+            text: Mapped[str] = mapped_column(String)
+            meta: Mapped[dict] = mapped_column(postgresql.JSONB)
+            vector: Mapped[ndarray] = mapped_column(Vector(dim))
+
+        self._table = _Table
+        self._distance_op = "<=>"
+
+    def get_type(self) -> str:
+        return 'pgvecto-rs'
+
+    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+        self.create_collection(len(embeddings[0]))
+        self.add_texts(texts, embeddings)
+
+    def create_collection(self, dimension: int):
+        lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
+            if redis_client.get(collection_exist_cache_key):
+                return
+            index_name = f"{self._collection_name}_embedding_index"
+            with Session(self._client) as session:
+                create_statement = sql_text(f"""
+                    CREATE TABLE IF NOT EXISTS {self._collection_name} (
+                        id UUID PRIMARY KEY,
+                        text TEXT NOT NULL,
+                        meta JSONB NOT NULL,
+                        vector vector({dimension}) NOT NULL
+                    ) using heap; 
+                """)
+                session.execute(create_statement)
+                index_statement = sql_text(f"""
+                        CREATE INDEX IF NOT EXISTS {index_name}
+                        ON {self._collection_name} USING vectors(vector vector_l2_ops)
+                        WITH (options = $$
+                                optimizing.optimizing_threads = 30
+                                segment.max_growing_segment_size = 2000
+                                segment.max_sealed_segment_size = 30000000
+                                [indexing.hnsw]
+                                m=30
+                                ef_construction=500
+                                $$);
+                    """)
+                session.execute(index_statement)
+                session.commit()
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
+
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        pks = []
+        with Session(self._client) as session:
+            for document, embedding in zip(documents, embeddings):
+                pk = uuid4()
+                session.execute(
+                    insert(self._table).values(
+                        id=pk,
+                        text=document.page_content,
+                        meta=document.metadata,
+                        vector=embedding,
+                    ),
+                )
+                pks.append(pk)
+            session.commit()
+
+        return pks
+
+    def delete_by_document_id(self, document_id: str):
+        ids = self.get_ids_by_metadata_field('document_id', document_id)
+        if ids:
+            with Session(self._client) as session:
+                select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
+                session.execute(select_statement, {'ids': ids})
+                session.commit()
+
+    def get_ids_by_metadata_field(self, key: str, value: str):
+        result = None
+        with Session(self._client) as session:
+            select_statement = sql_text(
+                f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; "
+            )
+            result = session.execute(select_statement).fetchall()
+        if result:
+            return [item[0] for item in result]
+        else:
+            return None
+
+    def delete_by_metadata_field(self, key: str, value: str):
+
+        ids = self.get_ids_by_metadata_field(key, value)
+        if ids:
+            with Session(self._client) as session:
+                select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
+                session.execute(select_statement, {'ids': ids})
+                session.commit()
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        with Session(self._client) as session:
+            select_statement = sql_text(
+                f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = ANY (:doc_ids); "
+            )
+            result = session.execute(select_statement, {'doc_ids': ids}).fetchall()
+        if result:
+            ids = [item[0] for item in result]
+            if ids:
+                with Session(self._client) as session:
+                    select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
+                    session.execute(select_statement, {'ids': ids})
+                    session.commit()
+
+    def delete(self) -> None:
+        with Session(self._client) as session:
+            session.execute(sql_text(f"DROP TABLE IF EXISTS {self._collection_name}"))
+            session.commit()
+
+    def text_exists(self, id: str) -> bool:
+        with Session(self._client) as session:
+            select_statement = sql_text(
+                f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = '{id}' limit 1; "
+            )
+            result = session.execute(select_statement).fetchall()
+        return len(result) > 0
+
+    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        with Session(self._client) as session:
+            stmt = (
+                select(
+                    self._table,
+                    self._table.vector.op(self._distance_op, return_type=Float)(
+                        query_vector,
+                    ).label("distance"),
+                )
+                .limit(kwargs.get('top_k', 2))
+                .order_by("distance")
+            )
+            res = session.execute(stmt)
+            results = [(row[0], row[1]) for row in res]
+
+        # Organize results.
+        docs = []
+        for record, dis in results:
+            metadata = record.meta
+            score = 1 - dis
+            metadata['score'] = score
+            score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
+            if score > score_threshold:
+                doc = Document(page_content=record.text,
+                               metadata=metadata)
+                docs.append(doc)
+        return docs
+
+    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        # with Session(self._client) as session:
+        #     select_statement = sql_text(
+        #         f"SELECT text, meta FROM {self._collection_name} WHERE to_tsvector(text) @@ '{query}'::tsquery"
+        #     )
+        #     results = session.execute(select_statement).fetchall()
+        # if results:
+        #     docs = []
+        #     for result in results:
+        #         doc = Document(page_content=result[0],
+        #                        metadata=result[1])
+        #         docs.append(doc)
+        #     return docs
+        return []

+ 1 - 1
api/core/rag/datasource/vdb/relyt/relyt_vector.py

@@ -235,7 +235,7 @@ class RelytVector(BaseVector):
         docs = []
         docs = []
         for document, score in results:
         for document, score in results:
             score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
             score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
-            if score > score_threshold:
+            if 1 - score > score_threshold:
                 docs.append(document)
                 docs.append(document)
         return docs
         return docs
 
 

+ 25 - 0
api/core/rag/datasource/vdb/vector_factory.py

@@ -139,6 +139,31 @@ class Vector:
                 ),
                 ),
                 group_id=self._dataset.id
                 group_id=self._dataset.id
             )
             )
+        elif vector_type == "pgvecto_rs":
+            from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig
+            if self._dataset.index_struct_dict:
+                class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
+                collection_name = class_prefix.lower()
+            else:
+                dataset_id = self._dataset.id
+                collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
+                index_struct_dict = {
+                    "type": 'pgvecto_rs',
+                    "vector_store": {"class_prefix": collection_name}
+                }
+                self._dataset.index_struct = json.dumps(index_struct_dict)
+            dim = len(self._embeddings.embed_query("pgvecto_rs"))
+            return PGVectoRS(
+                collection_name=collection_name,
+                config=PgvectoRSConfig(
+                    host=config.get('PGVECTO_RS_HOST'),
+                    port=config.get('PGVECTO_RS_PORT'),
+                    user=config.get('PGVECTO_RS_USER'),
+                    password=config.get('PGVECTO_RS_PASSWORD'),
+                    database=config.get('PGVECTO_RS_DATABASE'),
+                ),
+                dim=dim
+            )
         else:
         else:
             raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
             raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
 
 

+ 1 - 0
api/migrations/script.py.mako

@@ -6,6 +6,7 @@ Create Date: ${create_date}
 
 
 """
 """
 from alembic import op
 from alembic import op
+import models as models
 import sqlalchemy as sa
 import sqlalchemy as sa
 ${imports if imports else ""}
 ${imports if imports else ""}
 
 

+ 27 - 0
api/models/__init__.py

@@ -1,5 +1,8 @@
 from enum import Enum
 from enum import Enum
 
 
+from sqlalchemy import CHAR, TypeDecorator
+from sqlalchemy.dialects.postgresql import UUID
+
 
 
 class CreatedByRole(Enum):
 class CreatedByRole(Enum):
     """
     """
@@ -42,3 +45,27 @@ class CreatedFrom(Enum):
             if role.value == value:
             if role.value == value:
                 return role
                 return role
         raise ValueError(f'invalid createdFrom value {value}')
         raise ValueError(f'invalid createdFrom value {value}')
+
+
+class StringUUID(TypeDecorator):
+    impl = CHAR
+    cache_ok = True
+
+    def process_bind_param(self, value, dialect):
+        if value is None:
+            return value
+        elif dialect.name == 'postgresql':
+            return str(value)
+        else:
+            return value.hex
+
+    def load_dialect_impl(self, dialect):
+        if dialect.name == 'postgresql':
+            return dialect.type_descriptor(UUID())
+        else:
+            return dialect.type_descriptor(CHAR(36))
+
+    def process_result_value(self, value, dialect):
+        if value is None:
+            return value
+        return str(value)

+ 11 - 11
api/models/account.py

@@ -2,9 +2,9 @@ import enum
 import json
 import json
 
 
 from flask_login import UserMixin
 from flask_login import UserMixin
-from sqlalchemy.dialects.postgresql import UUID
 
 
 from extensions.ext_database import db
 from extensions.ext_database import db
+from models import StringUUID
 
 
 
 
 class AccountStatus(str, enum.Enum):
 class AccountStatus(str, enum.Enum):
@@ -22,7 +22,7 @@ class Account(UserMixin, db.Model):
         db.Index('account_email_idx', 'email')
         db.Index('account_email_idx', 'email')
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
     name = db.Column(db.String(255), nullable=False)
     name = db.Column(db.String(255), nullable=False)
     email = db.Column(db.String(255), nullable=False)
     email = db.Column(db.String(255), nullable=False)
     password = db.Column(db.String(255), nullable=True)
     password = db.Column(db.String(255), nullable=True)
@@ -128,7 +128,7 @@ class Tenant(db.Model):
         db.PrimaryKeyConstraint('id', name='tenant_pkey'),
         db.PrimaryKeyConstraint('id', name='tenant_pkey'),
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
     name = db.Column(db.String(255), nullable=False)
     name = db.Column(db.String(255), nullable=False)
     encrypt_public_key = db.Column(db.Text)
     encrypt_public_key = db.Column(db.Text)
     plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying"))
     plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying"))
@@ -168,12 +168,12 @@ class TenantAccountJoin(db.Model):
         db.UniqueConstraint('tenant_id', 'account_id', name='unique_tenant_account_join')
         db.UniqueConstraint('tenant_id', 'account_id', name='unique_tenant_account_join')
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    tenant_id = db.Column(UUID, nullable=False)
-    account_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(StringUUID, nullable=False)
+    account_id = db.Column(StringUUID, nullable=False)
     current = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
     current = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
     role = db.Column(db.String(16), nullable=False, server_default='normal')
     role = db.Column(db.String(16), nullable=False, server_default='normal')
-    invited_by = db.Column(UUID, nullable=True)
+    invited_by = db.Column(StringUUID, nullable=True)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
 
 
@@ -186,8 +186,8 @@ class AccountIntegrate(db.Model):
         db.UniqueConstraint('provider', 'open_id', name='unique_provider_open_id')
         db.UniqueConstraint('provider', 'open_id', name='unique_provider_open_id')
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    account_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    account_id = db.Column(StringUUID, nullable=False)
     provider = db.Column(db.String(16), nullable=False)
     provider = db.Column(db.String(16), nullable=False)
     open_id = db.Column(db.String(255), nullable=False)
     open_id = db.Column(db.String(255), nullable=False)
     encrypted_token = db.Column(db.String(255), nullable=False)
     encrypted_token = db.Column(db.String(255), nullable=False)
@@ -208,7 +208,7 @@ class InvitationCode(db.Model):
     code = db.Column(db.String(32), nullable=False)
     code = db.Column(db.String(32), nullable=False)
     status = db.Column(db.String(16), nullable=False, server_default=db.text("'unused'::character varying"))
     status = db.Column(db.String(16), nullable=False, server_default=db.text("'unused'::character varying"))
     used_at = db.Column(db.DateTime)
     used_at = db.Column(db.DateTime)
-    used_by_tenant_id = db.Column(UUID)
-    used_by_account_id = db.Column(UUID)
+    used_by_tenant_id = db.Column(StringUUID)
+    used_by_account_id = db.Column(StringUUID)
     deprecated_at = db.Column(db.DateTime)
     deprecated_at = db.Column(db.DateTime)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

+ 3 - 4
api/models/api_based_extension.py

@@ -1,8 +1,7 @@
 import enum
 import enum
 
 
-from sqlalchemy.dialects.postgresql import UUID
-
 from extensions.ext_database import db
 from extensions.ext_database import db
+from models import StringUUID
 
 
 
 
 class APIBasedExtensionPoint(enum.Enum):
 class APIBasedExtensionPoint(enum.Enum):
@@ -19,8 +18,8 @@ class APIBasedExtension(db.Model):
         db.Index('api_based_extension_tenant_idx', 'tenant_id'),
         db.Index('api_based_extension_tenant_idx', 'tenant_id'),
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    tenant_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(StringUUID, nullable=False)
     name = db.Column(db.String(255), nullable=False)
     name = db.Column(db.String(255), nullable=False)
     api_endpoint = db.Column(db.String(255), nullable=False)
     api_endpoint = db.Column(db.String(255), nullable=False)
     api_key = db.Column(db.Text, nullable=False)
     api_key = db.Column(db.Text, nullable=False)

+ 37 - 36
api/models/dataset.py

@@ -4,10 +4,11 @@ import pickle
 from json import JSONDecodeError
 from json import JSONDecodeError
 
 
 from sqlalchemy import func
 from sqlalchemy import func
-from sqlalchemy.dialects.postgresql import JSONB, UUID
+from sqlalchemy.dialects.postgresql import JSONB
 
 
 from extensions.ext_database import db
 from extensions.ext_database import db
 from extensions.ext_storage import storage
 from extensions.ext_storage import storage
+from models import StringUUID
 from models.account import Account
 from models.account import Account
 from models.model import App, Tag, TagBinding, UploadFile
 from models.model import App, Tag, TagBinding, UploadFile
 
 
@@ -22,8 +23,8 @@ class Dataset(db.Model):
 
 
     INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy', None]
     INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy', None]
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    tenant_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(StringUUID, nullable=False)
     name = db.Column(db.String(255), nullable=False)
     name = db.Column(db.String(255), nullable=False)
     description = db.Column(db.Text, nullable=True)
     description = db.Column(db.Text, nullable=True)
     provider = db.Column(db.String(255), nullable=False,
     provider = db.Column(db.String(255), nullable=False,
@@ -33,15 +34,15 @@ class Dataset(db.Model):
     data_source_type = db.Column(db.String(255))
     data_source_type = db.Column(db.String(255))
     indexing_technique = db.Column(db.String(255), nullable=True)
     indexing_technique = db.Column(db.String(255), nullable=True)
     index_struct = db.Column(db.Text, nullable=True)
     index_struct = db.Column(db.Text, nullable=True)
-    created_by = db.Column(UUID, nullable=False)
+    created_by = db.Column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False,
     created_at = db.Column(db.DateTime, nullable=False,
                            server_default=db.text('CURRENT_TIMESTAMP(0)'))
                            server_default=db.text('CURRENT_TIMESTAMP(0)'))
-    updated_by = db.Column(UUID, nullable=True)
+    updated_by = db.Column(StringUUID, nullable=True)
     updated_at = db.Column(db.DateTime, nullable=False,
     updated_at = db.Column(db.DateTime, nullable=False,
                            server_default=db.text('CURRENT_TIMESTAMP(0)'))
                            server_default=db.text('CURRENT_TIMESTAMP(0)'))
     embedding_model = db.Column(db.String(255), nullable=True)
     embedding_model = db.Column(db.String(255), nullable=True)
     embedding_model_provider = db.Column(db.String(255), nullable=True)
     embedding_model_provider = db.Column(db.String(255), nullable=True)
-    collection_binding_id = db.Column(UUID, nullable=True)
+    collection_binding_id = db.Column(StringUUID, nullable=True)
     retrieval_model = db.Column(JSONB, nullable=True)
     retrieval_model = db.Column(JSONB, nullable=True)
 
 
     @property
     @property
@@ -145,13 +146,13 @@ class DatasetProcessRule(db.Model):
         db.Index('dataset_process_rule_dataset_id_idx', 'dataset_id'),
         db.Index('dataset_process_rule_dataset_id_idx', 'dataset_id'),
     )
     )
 
 
-    id = db.Column(UUID, nullable=False,
+    id = db.Column(StringUUID, nullable=False,
                    server_default=db.text('uuid_generate_v4()'))
                    server_default=db.text('uuid_generate_v4()'))
-    dataset_id = db.Column(UUID, nullable=False)
+    dataset_id = db.Column(StringUUID, nullable=False)
     mode = db.Column(db.String(255), nullable=False,
     mode = db.Column(db.String(255), nullable=False,
                      server_default=db.text("'automatic'::character varying"))
                      server_default=db.text("'automatic'::character varying"))
     rules = db.Column(db.Text, nullable=True)
     rules = db.Column(db.Text, nullable=True)
-    created_by = db.Column(UUID, nullable=False)
+    created_by = db.Column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False,
     created_at = db.Column(db.DateTime, nullable=False,
                            server_default=db.text('CURRENT_TIMESTAMP(0)'))
                            server_default=db.text('CURRENT_TIMESTAMP(0)'))
 
 
@@ -197,19 +198,19 @@ class Document(db.Model):
     )
     )
 
 
     # initial fields
     # initial fields
-    id = db.Column(UUID, nullable=False,
+    id = db.Column(StringUUID, nullable=False,
                    server_default=db.text('uuid_generate_v4()'))
                    server_default=db.text('uuid_generate_v4()'))
-    tenant_id = db.Column(UUID, nullable=False)
-    dataset_id = db.Column(UUID, nullable=False)
+    tenant_id = db.Column(StringUUID, nullable=False)
+    dataset_id = db.Column(StringUUID, nullable=False)
     position = db.Column(db.Integer, nullable=False)
     position = db.Column(db.Integer, nullable=False)
     data_source_type = db.Column(db.String(255), nullable=False)
     data_source_type = db.Column(db.String(255), nullable=False)
     data_source_info = db.Column(db.Text, nullable=True)
     data_source_info = db.Column(db.Text, nullable=True)
-    dataset_process_rule_id = db.Column(UUID, nullable=True)
+    dataset_process_rule_id = db.Column(StringUUID, nullable=True)
     batch = db.Column(db.String(255), nullable=False)
     batch = db.Column(db.String(255), nullable=False)
     name = db.Column(db.String(255), nullable=False)
     name = db.Column(db.String(255), nullable=False)
     created_from = db.Column(db.String(255), nullable=False)
     created_from = db.Column(db.String(255), nullable=False)
-    created_by = db.Column(UUID, nullable=False)
-    created_api_request_id = db.Column(UUID, nullable=True)
+    created_by = db.Column(StringUUID, nullable=False)
+    created_api_request_id = db.Column(StringUUID, nullable=True)
     created_at = db.Column(db.DateTime, nullable=False,
     created_at = db.Column(db.DateTime, nullable=False,
                            server_default=db.text('CURRENT_TIMESTAMP(0)'))
                            server_default=db.text('CURRENT_TIMESTAMP(0)'))
 
 
@@ -234,7 +235,7 @@ class Document(db.Model):
 
 
     # pause
     # pause
     is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text('false'))
     is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text('false'))
-    paused_by = db.Column(UUID, nullable=True)
+    paused_by = db.Column(StringUUID, nullable=True)
     paused_at = db.Column(db.DateTime, nullable=True)
     paused_at = db.Column(db.DateTime, nullable=True)
 
 
     # error
     # error
@@ -247,11 +248,11 @@ class Document(db.Model):
     enabled = db.Column(db.Boolean, nullable=False,
     enabled = db.Column(db.Boolean, nullable=False,
                         server_default=db.text('true'))
                         server_default=db.text('true'))
     disabled_at = db.Column(db.DateTime, nullable=True)
     disabled_at = db.Column(db.DateTime, nullable=True)
-    disabled_by = db.Column(UUID, nullable=True)
+    disabled_by = db.Column(StringUUID, nullable=True)
     archived = db.Column(db.Boolean, nullable=False,
     archived = db.Column(db.Boolean, nullable=False,
                          server_default=db.text('false'))
                          server_default=db.text('false'))
     archived_reason = db.Column(db.String(255), nullable=True)
     archived_reason = db.Column(db.String(255), nullable=True)
-    archived_by = db.Column(UUID, nullable=True)
+    archived_by = db.Column(StringUUID, nullable=True)
     archived_at = db.Column(db.DateTime, nullable=True)
     archived_at = db.Column(db.DateTime, nullable=True)
     updated_at = db.Column(db.DateTime, nullable=False,
     updated_at = db.Column(db.DateTime, nullable=False,
                            server_default=db.text('CURRENT_TIMESTAMP(0)'))
                            server_default=db.text('CURRENT_TIMESTAMP(0)'))
@@ -356,11 +357,11 @@ class DocumentSegment(db.Model):
     )
     )
 
 
     # initial fields
     # initial fields
-    id = db.Column(UUID, nullable=False,
+    id = db.Column(StringUUID, nullable=False,
                    server_default=db.text('uuid_generate_v4()'))
                    server_default=db.text('uuid_generate_v4()'))
-    tenant_id = db.Column(UUID, nullable=False)
-    dataset_id = db.Column(UUID, nullable=False)
-    document_id = db.Column(UUID, nullable=False)
+    tenant_id = db.Column(StringUUID, nullable=False)
+    dataset_id = db.Column(StringUUID, nullable=False)
+    document_id = db.Column(StringUUID, nullable=False)
     position = db.Column(db.Integer, nullable=False)
     position = db.Column(db.Integer, nullable=False)
     content = db.Column(db.Text, nullable=False)
     content = db.Column(db.Text, nullable=False)
     answer = db.Column(db.Text, nullable=True)
     answer = db.Column(db.Text, nullable=True)
@@ -377,13 +378,13 @@ class DocumentSegment(db.Model):
     enabled = db.Column(db.Boolean, nullable=False,
     enabled = db.Column(db.Boolean, nullable=False,
                         server_default=db.text('true'))
                         server_default=db.text('true'))
     disabled_at = db.Column(db.DateTime, nullable=True)
     disabled_at = db.Column(db.DateTime, nullable=True)
-    disabled_by = db.Column(UUID, nullable=True)
+    disabled_by = db.Column(StringUUID, nullable=True)
     status = db.Column(db.String(255), nullable=False,
     status = db.Column(db.String(255), nullable=False,
                        server_default=db.text("'waiting'::character varying"))
                        server_default=db.text("'waiting'::character varying"))
-    created_by = db.Column(UUID, nullable=False)
+    created_by = db.Column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False,
     created_at = db.Column(db.DateTime, nullable=False,
                            server_default=db.text('CURRENT_TIMESTAMP(0)'))
                            server_default=db.text('CURRENT_TIMESTAMP(0)'))
-    updated_by = db.Column(UUID, nullable=True)
+    updated_by = db.Column(StringUUID, nullable=True)
     updated_at = db.Column(db.DateTime, nullable=False,
     updated_at = db.Column(db.DateTime, nullable=False,
                            server_default=db.text('CURRENT_TIMESTAMP(0)'))
                            server_default=db.text('CURRENT_TIMESTAMP(0)'))
     indexing_at = db.Column(db.DateTime, nullable=True)
     indexing_at = db.Column(db.DateTime, nullable=True)
@@ -421,9 +422,9 @@ class AppDatasetJoin(db.Model):
         db.Index('app_dataset_join_app_dataset_idx', 'dataset_id', 'app_id'),
         db.Index('app_dataset_join_app_dataset_idx', 'dataset_id', 'app_id'),
     )
     )
 
 
-    id = db.Column(UUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()'))
-    app_id = db.Column(UUID, nullable=False)
-    dataset_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()'))
+    app_id = db.Column(StringUUID, nullable=False)
+    dataset_id = db.Column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
 
 
     @property
     @property
@@ -438,13 +439,13 @@ class DatasetQuery(db.Model):
         db.Index('dataset_query_dataset_id_idx', 'dataset_id'),
         db.Index('dataset_query_dataset_id_idx', 'dataset_id'),
     )
     )
 
 
-    id = db.Column(UUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()'))
-    dataset_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()'))
+    dataset_id = db.Column(StringUUID, nullable=False)
     content = db.Column(db.Text, nullable=False)
     content = db.Column(db.Text, nullable=False)
     source = db.Column(db.String(255), nullable=False)
     source = db.Column(db.String(255), nullable=False)
-    source_app_id = db.Column(UUID, nullable=True)
+    source_app_id = db.Column(StringUUID, nullable=True)
     created_by_role = db.Column(db.String, nullable=False)
     created_by_role = db.Column(db.String, nullable=False)
-    created_by = db.Column(UUID, nullable=False)
+    created_by = db.Column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
 
 
 
 
@@ -455,8 +456,8 @@ class DatasetKeywordTable(db.Model):
         db.Index('dataset_keyword_table_dataset_id_idx', 'dataset_id'),
         db.Index('dataset_keyword_table_dataset_id_idx', 'dataset_id'),
     )
     )
 
 
-    id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
-    dataset_id = db.Column(UUID, nullable=False, unique=True)
+    id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
+    dataset_id = db.Column(StringUUID, nullable=False, unique=True)
     keyword_table = db.Column(db.Text, nullable=False)
     keyword_table = db.Column(db.Text, nullable=False)
     data_source_type = db.Column(db.String(255), nullable=False,
     data_source_type = db.Column(db.String(255), nullable=False,
                                  server_default=db.text("'database'::character varying"))
                                  server_default=db.text("'database'::character varying"))
@@ -501,7 +502,7 @@ class Embedding(db.Model):
         db.UniqueConstraint('model_name', 'hash', 'provider_name', name='embedding_hash_idx')
         db.UniqueConstraint('model_name', 'hash', 'provider_name', name='embedding_hash_idx')
     )
     )
 
 
-    id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
+    id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
     model_name = db.Column(db.String(40), nullable=False,
     model_name = db.Column(db.String(40), nullable=False,
                            server_default=db.text("'text-embedding-ada-002'::character varying"))
                            server_default=db.text("'text-embedding-ada-002'::character varying"))
     hash = db.Column(db.String(64), nullable=False)
     hash = db.Column(db.String(64), nullable=False)
@@ -525,7 +526,7 @@ class DatasetCollectionBinding(db.Model):
 
 
     )
     )
 
 
-    id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
+    id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
     provider_name = db.Column(db.String(40), nullable=False)
     provider_name = db.Column(db.String(40), nullable=False)
     model_name = db.Column(db.String(40), nullable=False)
     model_name = db.Column(db.String(40), nullable=False)
     type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
     type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)

+ 89 - 89
api/models/model.py

@@ -7,13 +7,13 @@ from typing import Optional
 from flask import current_app, request
 from flask import current_app, request
 from flask_login import UserMixin
 from flask_login import UserMixin
 from sqlalchemy import Float, text
 from sqlalchemy import Float, text
-from sqlalchemy.dialects.postgresql import UUID
 
 
 from core.file.tool_file_parser import ToolFileParser
 from core.file.tool_file_parser import ToolFileParser
 from core.file.upload_file_parser import UploadFileParser
 from core.file.upload_file_parser import UploadFileParser
 from extensions.ext_database import db
 from extensions.ext_database import db
 from libs.helper import generate_string
 from libs.helper import generate_string
 
 
+from . import StringUUID
 from .account import Account, Tenant
 from .account import Account, Tenant
 
 
 
 
@@ -56,15 +56,15 @@ class App(db.Model):
         db.Index('app_tenant_id_idx', 'tenant_id')
         db.Index('app_tenant_id_idx', 'tenant_id')
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    tenant_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(StringUUID, nullable=False)
     name = db.Column(db.String(255), nullable=False)
     name = db.Column(db.String(255), nullable=False)
     description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying"))
     description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying"))
     mode = db.Column(db.String(255), nullable=False)
     mode = db.Column(db.String(255), nullable=False)
     icon = db.Column(db.String(255))
     icon = db.Column(db.String(255))
     icon_background = db.Column(db.String(255))
     icon_background = db.Column(db.String(255))
-    app_model_config_id = db.Column(UUID, nullable=True)
-    workflow_id = db.Column(UUID, nullable=True)
+    app_model_config_id = db.Column(StringUUID, nullable=True)
+    workflow_id = db.Column(StringUUID, nullable=True)
     status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
     status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
     enable_site = db.Column(db.Boolean, nullable=False)
     enable_site = db.Column(db.Boolean, nullable=False)
     enable_api = db.Column(db.Boolean, nullable=False)
     enable_api = db.Column(db.Boolean, nullable=False)
@@ -207,8 +207,8 @@ class AppModelConfig(db.Model):
         db.Index('app_app_id_idx', 'app_id')
         db.Index('app_app_id_idx', 'app_id')
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    app_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    app_id = db.Column(StringUUID, nullable=False)
     provider = db.Column(db.String(255), nullable=True)
     provider = db.Column(db.String(255), nullable=True)
     model_id = db.Column(db.String(255), nullable=True)
     model_id = db.Column(db.String(255), nullable=True)
     configs = db.Column(db.JSON, nullable=True)
     configs = db.Column(db.JSON, nullable=True)
@@ -430,8 +430,8 @@ class RecommendedApp(db.Model):
         db.Index('recommended_app_is_listed_idx', 'is_listed', 'language')
         db.Index('recommended_app_is_listed_idx', 'is_listed', 'language')
     )
     )
 
 
-    id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
-    app_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
+    app_id = db.Column(StringUUID, nullable=False)
     description = db.Column(db.JSON, nullable=False)
     description = db.Column(db.JSON, nullable=False)
     copyright = db.Column(db.String(255), nullable=False)
     copyright = db.Column(db.String(255), nullable=False)
     privacy_policy = db.Column(db.String(255), nullable=False)
     privacy_policy = db.Column(db.String(255), nullable=False)
@@ -458,10 +458,10 @@ class InstalledApp(db.Model):
         db.UniqueConstraint('tenant_id', 'app_id', name='unique_tenant_app')
         db.UniqueConstraint('tenant_id', 'app_id', name='unique_tenant_app')
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    tenant_id = db.Column(UUID, nullable=False)
-    app_id = db.Column(UUID, nullable=False)
-    app_owner_tenant_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(StringUUID, nullable=False)
+    app_id = db.Column(StringUUID, nullable=False)
+    app_owner_tenant_id = db.Column(StringUUID, nullable=False)
     position = db.Column(db.Integer, nullable=False, default=0)
     position = db.Column(db.Integer, nullable=False, default=0)
     is_pinned = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
     is_pinned = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
     last_used_at = db.Column(db.DateTime, nullable=True)
     last_used_at = db.Column(db.DateTime, nullable=True)
@@ -486,9 +486,9 @@ class Conversation(db.Model):
         db.Index('conversation_app_from_user_idx', 'app_id', 'from_source', 'from_end_user_id')
         db.Index('conversation_app_from_user_idx', 'app_id', 'from_source', 'from_end_user_id')
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    app_id = db.Column(UUID, nullable=False)
-    app_model_config_id = db.Column(UUID, nullable=True)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    app_id = db.Column(StringUUID, nullable=False)
+    app_model_config_id = db.Column(StringUUID, nullable=True)
     model_provider = db.Column(db.String(255), nullable=True)
     model_provider = db.Column(db.String(255), nullable=True)
     override_model_configs = db.Column(db.Text)
     override_model_configs = db.Column(db.Text)
     model_id = db.Column(db.String(255), nullable=True)
     model_id = db.Column(db.String(255), nullable=True)
@@ -502,10 +502,10 @@ class Conversation(db.Model):
     status = db.Column(db.String(255), nullable=False)
     status = db.Column(db.String(255), nullable=False)
     invoke_from = db.Column(db.String(255), nullable=True)
     invoke_from = db.Column(db.String(255), nullable=True)
     from_source = db.Column(db.String(255), nullable=False)
     from_source = db.Column(db.String(255), nullable=False)
-    from_end_user_id = db.Column(UUID)
-    from_account_id = db.Column(UUID)
+    from_end_user_id = db.Column(StringUUID)
+    from_account_id = db.Column(StringUUID)
     read_at = db.Column(db.DateTime)
     read_at = db.Column(db.DateTime)
-    read_account_id = db.Column(UUID)
+    read_account_id = db.Column(StringUUID)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
 
 
@@ -626,12 +626,12 @@ class Message(db.Model):
         db.Index('message_account_idx', 'app_id', 'from_source', 'from_account_id'),
         db.Index('message_account_idx', 'app_id', 'from_source', 'from_account_id'),
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    app_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    app_id = db.Column(StringUUID, nullable=False)
     model_provider = db.Column(db.String(255), nullable=True)
     model_provider = db.Column(db.String(255), nullable=True)
     model_id = db.Column(db.String(255), nullable=True)
     model_id = db.Column(db.String(255), nullable=True)
     override_model_configs = db.Column(db.Text)
     override_model_configs = db.Column(db.Text)
-    conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=False)
+    conversation_id = db.Column(StringUUID, db.ForeignKey('conversations.id'), nullable=False)
     inputs = db.Column(db.JSON)
     inputs = db.Column(db.JSON)
     query = db.Column(db.Text, nullable=False)
     query = db.Column(db.Text, nullable=False)
     message = db.Column(db.JSON, nullable=False)
     message = db.Column(db.JSON, nullable=False)
@@ -650,12 +650,12 @@ class Message(db.Model):
     message_metadata = db.Column(db.Text)
     message_metadata = db.Column(db.Text)
     invoke_from = db.Column(db.String(255), nullable=True)
     invoke_from = db.Column(db.String(255), nullable=True)
     from_source = db.Column(db.String(255), nullable=False)
     from_source = db.Column(db.String(255), nullable=False)
-    from_end_user_id = db.Column(UUID)
-    from_account_id = db.Column(UUID)
+    from_end_user_id = db.Column(StringUUID)
+    from_account_id = db.Column(StringUUID)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
     agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
-    workflow_run_id = db.Column(UUID)
+    workflow_run_id = db.Column(StringUUID)
 
 
     @property
     @property
     def re_sign_file_url_answer(self) -> str:
     def re_sign_file_url_answer(self) -> str:
@@ -846,15 +846,15 @@ class MessageFeedback(db.Model):
         db.Index('message_feedback_conversation_idx', 'conversation_id', 'from_source', 'rating')
         db.Index('message_feedback_conversation_idx', 'conversation_id', 'from_source', 'rating')
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    app_id = db.Column(UUID, nullable=False)
-    conversation_id = db.Column(UUID, nullable=False)
-    message_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    app_id = db.Column(StringUUID, nullable=False)
+    conversation_id = db.Column(StringUUID, nullable=False)
+    message_id = db.Column(StringUUID, nullable=False)
     rating = db.Column(db.String(255), nullable=False)
     rating = db.Column(db.String(255), nullable=False)
     content = db.Column(db.Text)
     content = db.Column(db.Text)
     from_source = db.Column(db.String(255), nullable=False)
     from_source = db.Column(db.String(255), nullable=False)
-    from_end_user_id = db.Column(UUID)
-    from_account_id = db.Column(UUID)
+    from_end_user_id = db.Column(StringUUID)
+    from_account_id = db.Column(StringUUID)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
 
 
@@ -872,15 +872,15 @@ class MessageFile(db.Model):
         db.Index('message_file_created_by_idx', 'created_by')
         db.Index('message_file_created_by_idx', 'created_by')
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    message_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    message_id = db.Column(StringUUID, nullable=False)
     type = db.Column(db.String(255), nullable=False)
     type = db.Column(db.String(255), nullable=False)
     transfer_method = db.Column(db.String(255), nullable=False)
     transfer_method = db.Column(db.String(255), nullable=False)
     url = db.Column(db.Text, nullable=True)
     url = db.Column(db.Text, nullable=True)
     belongs_to = db.Column(db.String(255), nullable=True)
     belongs_to = db.Column(db.String(255), nullable=True)
-    upload_file_id = db.Column(UUID, nullable=True)
+    upload_file_id = db.Column(StringUUID, nullable=True)
     created_by_role = db.Column(db.String(255), nullable=False)
     created_by_role = db.Column(db.String(255), nullable=False)
-    created_by = db.Column(UUID, nullable=False)
+    created_by = db.Column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
 
 
 
 
@@ -893,14 +893,14 @@ class MessageAnnotation(db.Model):
         db.Index('message_annotation_message_idx', 'message_id')
         db.Index('message_annotation_message_idx', 'message_id')
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    app_id = db.Column(UUID, nullable=False)
-    conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=True)
-    message_id = db.Column(UUID, nullable=True)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    app_id = db.Column(StringUUID, nullable=False)
+    conversation_id = db.Column(StringUUID, db.ForeignKey('conversations.id'), nullable=True)
+    message_id = db.Column(StringUUID, nullable=True)
     question = db.Column(db.Text, nullable=True)
     question = db.Column(db.Text, nullable=True)
     content = db.Column(db.Text, nullable=False)
     content = db.Column(db.Text, nullable=False)
     hit_count = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
     hit_count = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
-    account_id = db.Column(UUID, nullable=False)
+    account_id = db.Column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
 
 
@@ -925,15 +925,15 @@ class AppAnnotationHitHistory(db.Model):
         db.Index('app_annotation_hit_histories_message_idx', 'message_id'),
         db.Index('app_annotation_hit_histories_message_idx', 'message_id'),
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    app_id = db.Column(UUID, nullable=False)
-    annotation_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    app_id = db.Column(StringUUID, nullable=False)
+    annotation_id = db.Column(StringUUID, nullable=False)
     source = db.Column(db.Text, nullable=False)
     source = db.Column(db.Text, nullable=False)
     question = db.Column(db.Text, nullable=False)
     question = db.Column(db.Text, nullable=False)
-    account_id = db.Column(UUID, nullable=False)
+    account_id = db.Column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     score = db.Column(Float, nullable=False, server_default=db.text('0'))
     score = db.Column(Float, nullable=False, server_default=db.text('0'))
-    message_id = db.Column(UUID, nullable=False)
+    message_id = db.Column(StringUUID, nullable=False)
     annotation_question = db.Column(db.Text, nullable=False)
     annotation_question = db.Column(db.Text, nullable=False)
     annotation_content = db.Column(db.Text, nullable=False)
     annotation_content = db.Column(db.Text, nullable=False)
 
 
@@ -957,13 +957,13 @@ class AppAnnotationSetting(db.Model):
         db.Index('app_annotation_settings_app_idx', 'app_id')
         db.Index('app_annotation_settings_app_idx', 'app_id')
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    app_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    app_id = db.Column(StringUUID, nullable=False)
     score_threshold = db.Column(Float, nullable=False, server_default=db.text('0'))
     score_threshold = db.Column(Float, nullable=False, server_default=db.text('0'))
-    collection_binding_id = db.Column(UUID, nullable=False)
-    created_user_id = db.Column(UUID, nullable=False)
+    collection_binding_id = db.Column(StringUUID, nullable=False)
+    created_user_id = db.Column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
-    updated_user_id = db.Column(UUID, nullable=False)
+    updated_user_id = db.Column(StringUUID, nullable=False)
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
 
 
     @property
     @property
@@ -995,9 +995,9 @@ class OperationLog(db.Model):
         db.Index('operation_log_account_action_idx', 'tenant_id', 'account_id', 'action')
         db.Index('operation_log_account_action_idx', 'tenant_id', 'account_id', 'action')
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    tenant_id = db.Column(UUID, nullable=False)
-    account_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(StringUUID, nullable=False)
+    account_id = db.Column(StringUUID, nullable=False)
     action = db.Column(db.String(255), nullable=False)
     action = db.Column(db.String(255), nullable=False)
     content = db.Column(db.JSON)
     content = db.Column(db.JSON)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@@ -1013,9 +1013,9 @@ class EndUser(UserMixin, db.Model):
         db.Index('end_user_tenant_session_id_idx', 'tenant_id', 'session_id', 'type'),
         db.Index('end_user_tenant_session_id_idx', 'tenant_id', 'session_id', 'type'),
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    tenant_id = db.Column(UUID, nullable=False)
-    app_id = db.Column(UUID, nullable=True)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(StringUUID, nullable=False)
+    app_id = db.Column(StringUUID, nullable=True)
     type = db.Column(db.String(255), nullable=False)
     type = db.Column(db.String(255), nullable=False)
     external_user_id = db.Column(db.String(255), nullable=True)
     external_user_id = db.Column(db.String(255), nullable=True)
     name = db.Column(db.String(255))
     name = db.Column(db.String(255))
@@ -1033,8 +1033,8 @@ class Site(db.Model):
         db.Index('site_code_idx', 'code', 'status')
         db.Index('site_code_idx', 'code', 'status')
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    app_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    app_id = db.Column(StringUUID, nullable=False)
     title = db.Column(db.String(255), nullable=False)
     title = db.Column(db.String(255), nullable=False)
     icon = db.Column(db.String(255))
     icon = db.Column(db.String(255))
     icon_background = db.Column(db.String(255))
     icon_background = db.Column(db.String(255))
@@ -1074,9 +1074,9 @@ class ApiToken(db.Model):
         db.Index('api_token_tenant_idx', 'tenant_id', 'type')
         db.Index('api_token_tenant_idx', 'tenant_id', 'type')
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    app_id = db.Column(UUID, nullable=True)
-    tenant_id = db.Column(UUID, nullable=True)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    app_id = db.Column(StringUUID, nullable=True)
+    tenant_id = db.Column(StringUUID, nullable=True)
     type = db.Column(db.String(16), nullable=False)
     type = db.Column(db.String(16), nullable=False)
     token = db.Column(db.String(255), nullable=False)
     token = db.Column(db.String(255), nullable=False)
     last_used_at = db.Column(db.DateTime, nullable=True)
     last_used_at = db.Column(db.DateTime, nullable=True)
@@ -1099,8 +1099,8 @@ class UploadFile(db.Model):
         db.Index('upload_file_tenant_idx', 'tenant_id')
         db.Index('upload_file_tenant_idx', 'tenant_id')
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    tenant_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(StringUUID, nullable=False)
     storage_type = db.Column(db.String(255), nullable=False)
     storage_type = db.Column(db.String(255), nullable=False)
     key = db.Column(db.String(255), nullable=False)
     key = db.Column(db.String(255), nullable=False)
     name = db.Column(db.String(255), nullable=False)
     name = db.Column(db.String(255), nullable=False)
@@ -1108,10 +1108,10 @@ class UploadFile(db.Model):
     extension = db.Column(db.String(255), nullable=False)
     extension = db.Column(db.String(255), nullable=False)
     mime_type = db.Column(db.String(255), nullable=True)
     mime_type = db.Column(db.String(255), nullable=True)
     created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'account'::character varying"))
     created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'account'::character varying"))
-    created_by = db.Column(UUID, nullable=False)
+    created_by = db.Column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     used = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
     used = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
-    used_by = db.Column(UUID, nullable=True)
+    used_by = db.Column(StringUUID, nullable=True)
     used_at = db.Column(db.DateTime, nullable=True)
     used_at = db.Column(db.DateTime, nullable=True)
     hash = db.Column(db.String(255), nullable=True)
     hash = db.Column(db.String(255), nullable=True)
 
 
@@ -1123,9 +1123,9 @@ class ApiRequest(db.Model):
         db.Index('api_request_token_idx', 'tenant_id', 'api_token_id')
         db.Index('api_request_token_idx', 'tenant_id', 'api_token_id')
     )
     )
 
 
-    id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
-    tenant_id = db.Column(UUID, nullable=False)
-    api_token_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(StringUUID, nullable=False)
+    api_token_id = db.Column(StringUUID, nullable=False)
     path = db.Column(db.String(255), nullable=False)
     path = db.Column(db.String(255), nullable=False)
     request = db.Column(db.Text, nullable=True)
     request = db.Column(db.Text, nullable=True)
     response = db.Column(db.Text, nullable=True)
     response = db.Column(db.Text, nullable=True)
@@ -1140,8 +1140,8 @@ class MessageChain(db.Model):
         db.Index('message_chain_message_id_idx', 'message_id')
         db.Index('message_chain_message_id_idx', 'message_id')
     )
     )
 
 
-    id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
-    message_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
+    message_id = db.Column(StringUUID, nullable=False)
     type = db.Column(db.String(255), nullable=False)
     type = db.Column(db.String(255), nullable=False)
     input = db.Column(db.Text, nullable=True)
     input = db.Column(db.Text, nullable=True)
     output = db.Column(db.Text, nullable=True)
     output = db.Column(db.Text, nullable=True)
@@ -1156,9 +1156,9 @@ class MessageAgentThought(db.Model):
         db.Index('message_agent_thought_message_chain_id_idx', 'message_chain_id'),
         db.Index('message_agent_thought_message_chain_id_idx', 'message_chain_id'),
     )
     )
 
 
-    id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
-    message_id = db.Column(UUID, nullable=False)
-    message_chain_id = db.Column(UUID, nullable=True)
+    id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
+    message_id = db.Column(StringUUID, nullable=False)
+    message_chain_id = db.Column(StringUUID, nullable=True)
     position = db.Column(db.Integer, nullable=False)
     position = db.Column(db.Integer, nullable=False)
     thought = db.Column(db.Text, nullable=True)
     thought = db.Column(db.Text, nullable=True)
     tool = db.Column(db.Text, nullable=True)
     tool = db.Column(db.Text, nullable=True)
@@ -1166,7 +1166,7 @@ class MessageAgentThought(db.Model):
     tool_meta_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
     tool_meta_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
     tool_input = db.Column(db.Text, nullable=True)
     tool_input = db.Column(db.Text, nullable=True)
     observation = db.Column(db.Text, nullable=True)
     observation = db.Column(db.Text, nullable=True)
-    # plugin_id = db.Column(UUID, nullable=True)  ## for future design
+    # plugin_id = db.Column(StringUUID, nullable=True)  ## for future design
     tool_process_data = db.Column(db.Text, nullable=True)
     tool_process_data = db.Column(db.Text, nullable=True)
     message = db.Column(db.Text, nullable=True)
     message = db.Column(db.Text, nullable=True)
     message_token = db.Column(db.Integer, nullable=True)
     message_token = db.Column(db.Integer, nullable=True)
@@ -1182,7 +1182,7 @@ class MessageAgentThought(db.Model):
     currency = db.Column(db.String, nullable=True)
     currency = db.Column(db.String, nullable=True)
     latency = db.Column(db.Float, nullable=True)
     latency = db.Column(db.Float, nullable=True)
     created_by_role = db.Column(db.String, nullable=False)
     created_by_role = db.Column(db.String, nullable=False)
-    created_by = db.Column(UUID, nullable=False)
+    created_by = db.Column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
 
 
     @property
     @property
@@ -1273,15 +1273,15 @@ class DatasetRetrieverResource(db.Model):
         db.Index('dataset_retriever_resource_message_id_idx', 'message_id'),
         db.Index('dataset_retriever_resource_message_id_idx', 'message_id'),
     )
     )
 
 
-    id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
-    message_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
+    message_id = db.Column(StringUUID, nullable=False)
     position = db.Column(db.Integer, nullable=False)
     position = db.Column(db.Integer, nullable=False)
-    dataset_id = db.Column(UUID, nullable=False)
+    dataset_id = db.Column(StringUUID, nullable=False)
     dataset_name = db.Column(db.Text, nullable=False)
     dataset_name = db.Column(db.Text, nullable=False)
-    document_id = db.Column(UUID, nullable=False)
+    document_id = db.Column(StringUUID, nullable=False)
     document_name = db.Column(db.Text, nullable=False)
     document_name = db.Column(db.Text, nullable=False)
     data_source_type = db.Column(db.Text, nullable=False)
     data_source_type = db.Column(db.Text, nullable=False)
-    segment_id = db.Column(UUID, nullable=False)
+    segment_id = db.Column(StringUUID, nullable=False)
     score = db.Column(db.Float, nullable=True)
     score = db.Column(db.Float, nullable=True)
     content = db.Column(db.Text, nullable=False)
     content = db.Column(db.Text, nullable=False)
     hit_count = db.Column(db.Integer, nullable=True)
     hit_count = db.Column(db.Integer, nullable=True)
@@ -1289,7 +1289,7 @@ class DatasetRetrieverResource(db.Model):
     segment_position = db.Column(db.Integer, nullable=True)
     segment_position = db.Column(db.Integer, nullable=True)
     index_node_hash = db.Column(db.Text, nullable=True)
     index_node_hash = db.Column(db.Text, nullable=True)
     retriever_from = db.Column(db.Text, nullable=False)
     retriever_from = db.Column(db.Text, nullable=False)
-    created_by = db.Column(UUID, nullable=False)
+    created_by = db.Column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
 
 
 
 
@@ -1303,11 +1303,11 @@ class Tag(db.Model):
 
 
     TAG_TYPE_LIST = ['knowledge', 'app']
     TAG_TYPE_LIST = ['knowledge', 'app']
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    tenant_id = db.Column(UUID, nullable=True)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(StringUUID, nullable=True)
     type = db.Column(db.String(16), nullable=False)
     type = db.Column(db.String(16), nullable=False)
     name = db.Column(db.String(255), nullable=False)
     name = db.Column(db.String(255), nullable=False)
-    created_by = db.Column(UUID, nullable=False)
+    created_by = db.Column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
 
 
 
 
@@ -1319,9 +1319,9 @@ class TagBinding(db.Model):
         db.Index('tag_bind_tag_id_idx', 'tag_id'),
         db.Index('tag_bind_tag_id_idx', 'tag_id'),
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    tenant_id = db.Column(UUID, nullable=True)
-    tag_id = db.Column(UUID, nullable=True)
-    target_id = db.Column(UUID, nullable=True)
-    created_by = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(StringUUID, nullable=True)
+    tag_id = db.Column(StringUUID, nullable=True)
+    target_id = db.Column(StringUUID, nullable=True)
+    created_by = db.Column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

+ 12 - 13
api/models/provider.py

@@ -1,8 +1,7 @@
 from enum import Enum
 from enum import Enum
 
 
-from sqlalchemy.dialects.postgresql import UUID
-
 from extensions.ext_database import db
 from extensions.ext_database import db
+from models import StringUUID
 
 
 
 
 class ProviderType(Enum):
 class ProviderType(Enum):
@@ -46,8 +45,8 @@ class Provider(db.Model):
         db.UniqueConstraint('tenant_id', 'provider_name', 'provider_type', 'quota_type', name='unique_provider_name_type_quota')
         db.UniqueConstraint('tenant_id', 'provider_name', 'provider_type', 'quota_type', name='unique_provider_name_type_quota')
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    tenant_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(StringUUID, nullable=False)
     provider_name = db.Column(db.String(40), nullable=False)
     provider_name = db.Column(db.String(40), nullable=False)
     provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying"))
     provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying"))
     encrypted_config = db.Column(db.Text, nullable=True)
     encrypted_config = db.Column(db.Text, nullable=True)
@@ -93,8 +92,8 @@ class ProviderModel(db.Model):
         db.UniqueConstraint('tenant_id', 'provider_name', 'model_name', 'model_type', name='unique_provider_model_name')
         db.UniqueConstraint('tenant_id', 'provider_name', 'model_name', 'model_type', name='unique_provider_model_name')
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    tenant_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(StringUUID, nullable=False)
     provider_name = db.Column(db.String(40), nullable=False)
     provider_name = db.Column(db.String(40), nullable=False)
     model_name = db.Column(db.String(255), nullable=False)
     model_name = db.Column(db.String(255), nullable=False)
     model_type = db.Column(db.String(40), nullable=False)
     model_type = db.Column(db.String(40), nullable=False)
@@ -111,8 +110,8 @@ class TenantDefaultModel(db.Model):
         db.Index('tenant_default_model_tenant_id_provider_type_idx', 'tenant_id', 'provider_name', 'model_type'),
         db.Index('tenant_default_model_tenant_id_provider_type_idx', 'tenant_id', 'provider_name', 'model_type'),
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    tenant_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(StringUUID, nullable=False)
     provider_name = db.Column(db.String(40), nullable=False)
     provider_name = db.Column(db.String(40), nullable=False)
     model_name = db.Column(db.String(40), nullable=False)
     model_name = db.Column(db.String(40), nullable=False)
     model_type = db.Column(db.String(40), nullable=False)
     model_type = db.Column(db.String(40), nullable=False)
@@ -127,8 +126,8 @@ class TenantPreferredModelProvider(db.Model):
         db.Index('tenant_preferred_model_provider_tenant_provider_idx', 'tenant_id', 'provider_name'),
         db.Index('tenant_preferred_model_provider_tenant_provider_idx', 'tenant_id', 'provider_name'),
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    tenant_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(StringUUID, nullable=False)
     provider_name = db.Column(db.String(40), nullable=False)
     provider_name = db.Column(db.String(40), nullable=False)
     preferred_provider_type = db.Column(db.String(40), nullable=False)
     preferred_provider_type = db.Column(db.String(40), nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@@ -142,10 +141,10 @@ class ProviderOrder(db.Model):
         db.Index('provider_order_tenant_provider_idx', 'tenant_id', 'provider_name'),
         db.Index('provider_order_tenant_provider_idx', 'tenant_id', 'provider_name'),
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    tenant_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(StringUUID, nullable=False)
     provider_name = db.Column(db.String(40), nullable=False)
     provider_name = db.Column(db.String(40), nullable=False)
-    account_id = db.Column(UUID, nullable=False)
+    account_id = db.Column(StringUUID, nullable=False)
     payment_product_id = db.Column(db.String(191), nullable=False)
     payment_product_id = db.Column(db.String(191), nullable=False)
     payment_id = db.Column(db.String(191))
     payment_id = db.Column(db.String(191))
     transaction_id = db.Column(db.String(191))
     transaction_id = db.Column(db.String(191))

+ 4 - 3
api/models/source.py

@@ -1,6 +1,7 @@
-from sqlalchemy.dialects.postgresql import JSONB, UUID
+from sqlalchemy.dialects.postgresql import JSONB
 
 
 from extensions.ext_database import db
 from extensions.ext_database import db
+from models import StringUUID
 
 
 
 
 class DataSourceBinding(db.Model):
 class DataSourceBinding(db.Model):
@@ -11,8 +12,8 @@ class DataSourceBinding(db.Model):
         db.Index('source_info_idx', "source_info", postgresql_using='gin')
         db.Index('source_info_idx', "source_info", postgresql_using='gin')
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    tenant_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(StringUUID, nullable=False)
     access_token = db.Column(db.String(255), nullable=False)
     access_token = db.Column(db.String(255), nullable=False)
     provider = db.Column(db.String(255), nullable=False)
     provider = db.Column(db.String(255), nullable=False)
     source_info = db.Column(JSONB, nullable=False)
     source_info = db.Column(JSONB, nullable=False)

+ 3 - 4
api/models/tool.py

@@ -1,9 +1,8 @@
 import json
 import json
 from enum import Enum
 from enum import Enum
 
 
-from sqlalchemy.dialects.postgresql import UUID
-
 from extensions.ext_database import db
 from extensions.ext_database import db
+from models import StringUUID
 
 
 
 
 class ToolProviderName(Enum):
 class ToolProviderName(Enum):
@@ -24,8 +23,8 @@ class ToolProvider(db.Model):
         db.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
         db.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    tenant_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(StringUUID, nullable=False)
     tool_name = db.Column(db.String(40), nullable=False)
     tool_name = db.Column(db.String(40), nullable=False)
     encrypted_credentials = db.Column(db.Text, nullable=True)
     encrypted_credentials = db.Column(db.Text, nullable=True)
     is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
     is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))

+ 21 - 21
api/models/tools.py

@@ -1,12 +1,12 @@
 import json
 import json
 
 
 from sqlalchemy import ForeignKey
 from sqlalchemy import ForeignKey
-from sqlalchemy.dialects.postgresql import UUID
 
 
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_bundle import ApiBasedToolBundle
 from core.tools.entities.tool_bundle import ApiBasedToolBundle
 from core.tools.entities.tool_entities import ApiProviderSchemaType
 from core.tools.entities.tool_entities import ApiProviderSchemaType
 from extensions.ext_database import db
 from extensions.ext_database import db
+from models import StringUUID
 from models.model import Account, App, Tenant
 from models.model import Account, App, Tenant
 
 
 
 
@@ -22,11 +22,11 @@ class BuiltinToolProvider(db.Model):
     )
     )
 
 
     # id of the tool provider
     # id of the tool provider
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
     # id of the tenant
     # id of the tenant
-    tenant_id = db.Column(UUID, nullable=True)
+    tenant_id = db.Column(StringUUID, nullable=True)
     # who created this tool provider
     # who created this tool provider
-    user_id = db.Column(UUID, nullable=False)
+    user_id = db.Column(StringUUID, nullable=False)
     # name of the tool provider
     # name of the tool provider
     provider = db.Column(db.String(40), nullable=False)
     provider = db.Column(db.String(40), nullable=False)
     # credential of the tool provider
     # credential of the tool provider
@@ -49,11 +49,11 @@ class PublishedAppTool(db.Model):
     )
     )
 
 
     # id of the tool provider
     # id of the tool provider
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
     # id of the app
     # id of the app
-    app_id = db.Column(UUID, ForeignKey('apps.id'), nullable=False)
+    app_id = db.Column(StringUUID, ForeignKey('apps.id'), nullable=False)
     # who published this tool
     # who published this tool
-    user_id = db.Column(UUID, nullable=False)
+    user_id = db.Column(StringUUID, nullable=False)
     # description of the tool, stored in i18n format, for human
     # description of the tool, stored in i18n format, for human
     description = db.Column(db.Text, nullable=False)
     description = db.Column(db.Text, nullable=False)
     # llm_description of the tool, for LLM
     # llm_description of the tool, for LLM
@@ -87,7 +87,7 @@ class ApiToolProvider(db.Model):
         db.UniqueConstraint('name', 'tenant_id', name='unique_api_tool_provider')
         db.UniqueConstraint('name', 'tenant_id', name='unique_api_tool_provider')
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
     # name of the api provider
     # name of the api provider
     name = db.Column(db.String(40), nullable=False)
     name = db.Column(db.String(40), nullable=False)
     # icon
     # icon
@@ -96,9 +96,9 @@ class ApiToolProvider(db.Model):
     schema = db.Column(db.Text, nullable=False)
     schema = db.Column(db.Text, nullable=False)
     schema_type_str = db.Column(db.String(40), nullable=False)
     schema_type_str = db.Column(db.String(40), nullable=False)
     # who created this tool
     # who created this tool
-    user_id = db.Column(UUID, nullable=False)
+    user_id = db.Column(StringUUID, nullable=False)
     # tenant id
     # tenant id
-    tenant_id = db.Column(UUID, nullable=False)
+    tenant_id = db.Column(StringUUID, nullable=False)
     # description of the provider
     # description of the provider
     description = db.Column(db.Text, nullable=False)
     description = db.Column(db.Text, nullable=False)
     # json format tools
     # json format tools
@@ -140,11 +140,11 @@ class ToolModelInvoke(db.Model):
         db.PrimaryKeyConstraint('id', name='tool_model_invoke_pkey'),
         db.PrimaryKeyConstraint('id', name='tool_model_invoke_pkey'),
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
     # who invoke this tool
     # who invoke this tool
-    user_id = db.Column(UUID, nullable=False)
+    user_id = db.Column(StringUUID, nullable=False)
     # tenant id
     # tenant id
-    tenant_id = db.Column(UUID, nullable=False)
+    tenant_id = db.Column(StringUUID, nullable=False)
     # provider
     # provider
     provider = db.Column(db.String(40), nullable=False)
     provider = db.Column(db.String(40), nullable=False)
     # type
     # type
@@ -180,13 +180,13 @@ class ToolConversationVariables(db.Model):
         db.Index('conversation_id_idx', 'conversation_id'),
         db.Index('conversation_id_idx', 'conversation_id'),
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
     # conversation user id
     # conversation user id
-    user_id = db.Column(UUID, nullable=False)
+    user_id = db.Column(StringUUID, nullable=False)
     # tenant id
     # tenant id
-    tenant_id = db.Column(UUID, nullable=False)
+    tenant_id = db.Column(StringUUID, nullable=False)
     # conversation id
     # conversation id
-    conversation_id = db.Column(UUID, nullable=False)
+    conversation_id = db.Column(StringUUID, nullable=False)
     # variables pool
     # variables pool
     variables_str = db.Column(db.Text, nullable=False)
     variables_str = db.Column(db.Text, nullable=False)
 
 
@@ -208,13 +208,13 @@ class ToolFile(db.Model):
         db.Index('tool_file_conversation_id_idx', 'conversation_id'),
         db.Index('tool_file_conversation_id_idx', 'conversation_id'),
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
     # conversation user id
     # conversation user id
-    user_id = db.Column(UUID, nullable=False)
+    user_id = db.Column(StringUUID, nullable=False)
     # tenant id
     # tenant id
-    tenant_id = db.Column(UUID, nullable=False)
+    tenant_id = db.Column(StringUUID, nullable=False)
     # conversation id
     # conversation id
-    conversation_id = db.Column(UUID, nullable=True)
+    conversation_id = db.Column(StringUUID, nullable=True)
     # file key
     # file key
     file_key = db.Column(db.String(255), nullable=False)
     file_key = db.Column(db.String(255), nullable=False)
     # mime type
     # mime type

+ 9 - 9
api/models/web.py

@@ -1,6 +1,6 @@
-from sqlalchemy.dialects.postgresql import UUID
 
 
 from extensions.ext_database import db
 from extensions.ext_database import db
+from models import StringUUID
 from models.model import Message
 from models.model import Message
 
 
 
 
@@ -11,11 +11,11 @@ class SavedMessage(db.Model):
         db.Index('saved_message_message_idx', 'app_id', 'message_id', 'created_by_role', 'created_by'),
         db.Index('saved_message_message_idx', 'app_id', 'message_id', 'created_by_role', 'created_by'),
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    app_id = db.Column(UUID, nullable=False)
-    message_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    app_id = db.Column(StringUUID, nullable=False)
+    message_id = db.Column(StringUUID, nullable=False)
     created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying"))
     created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying"))
-    created_by = db.Column(UUID, nullable=False)
+    created_by = db.Column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
 
 
     @property
     @property
@@ -30,9 +30,9 @@ class PinnedConversation(db.Model):
         db.Index('pinned_conversation_conversation_idx', 'app_id', 'conversation_id', 'created_by_role', 'created_by'),
         db.Index('pinned_conversation_conversation_idx', 'app_id', 'conversation_id', 'created_by_role', 'created_by'),
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    app_id = db.Column(UUID, nullable=False)
-    conversation_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    app_id = db.Column(StringUUID, nullable=False)
+    conversation_id = db.Column(StringUUID, nullable=False)
     created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying"))
     created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying"))
-    created_by = db.Column(UUID, nullable=False)
+    created_by = db.Column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

+ 23 - 24
api/models/workflow.py

@@ -2,10 +2,9 @@ import json
 from enum import Enum
 from enum import Enum
 from typing import Optional, Union
 from typing import Optional, Union
 
 
-from sqlalchemy.dialects.postgresql import UUID
-
 from core.tools.tool_manager import ToolManager
 from core.tools.tool_manager import ToolManager
 from extensions.ext_database import db
 from extensions.ext_database import db
+from models import StringUUID
 from models.account import Account
 from models.account import Account
 
 
 
 
@@ -102,16 +101,16 @@ class Workflow(db.Model):
         db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'version'),
         db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'version'),
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    tenant_id = db.Column(UUID, nullable=False)
-    app_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(StringUUID, nullable=False)
+    app_id = db.Column(StringUUID, nullable=False)
     type = db.Column(db.String(255), nullable=False)
     type = db.Column(db.String(255), nullable=False)
     version = db.Column(db.String(255), nullable=False)
     version = db.Column(db.String(255), nullable=False)
     graph = db.Column(db.Text)
     graph = db.Column(db.Text)
     features = db.Column(db.Text)
     features = db.Column(db.Text)
-    created_by = db.Column(UUID, nullable=False)
+    created_by = db.Column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
-    updated_by = db.Column(UUID)
+    updated_by = db.Column(StringUUID)
     updated_at = db.Column(db.DateTime)
     updated_at = db.Column(db.DateTime)
 
 
     @property
     @property
@@ -245,11 +244,11 @@ class WorkflowRun(db.Model):
         db.Index('workflow_run_triggerd_from_idx', 'tenant_id', 'app_id', 'triggered_from'),
         db.Index('workflow_run_triggerd_from_idx', 'tenant_id', 'app_id', 'triggered_from'),
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    tenant_id = db.Column(UUID, nullable=False)
-    app_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(StringUUID, nullable=False)
+    app_id = db.Column(StringUUID, nullable=False)
     sequence_number = db.Column(db.Integer, nullable=False)
     sequence_number = db.Column(db.Integer, nullable=False)
-    workflow_id = db.Column(UUID, nullable=False)
+    workflow_id = db.Column(StringUUID, nullable=False)
     type = db.Column(db.String(255), nullable=False)
     type = db.Column(db.String(255), nullable=False)
     triggered_from = db.Column(db.String(255), nullable=False)
     triggered_from = db.Column(db.String(255), nullable=False)
     version = db.Column(db.String(255), nullable=False)
     version = db.Column(db.String(255), nullable=False)
@@ -262,7 +261,7 @@ class WorkflowRun(db.Model):
     total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
     total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
     total_steps = db.Column(db.Integer, server_default=db.text('0'))
     total_steps = db.Column(db.Integer, server_default=db.text('0'))
     created_by_role = db.Column(db.String(255), nullable=False)
     created_by_role = db.Column(db.String(255), nullable=False)
-    created_by = db.Column(UUID, nullable=False)
+    created_by = db.Column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     finished_at = db.Column(db.DateTime)
     finished_at = db.Column(db.DateTime)
 
 
@@ -404,12 +403,12 @@ class WorkflowNodeExecution(db.Model):
                  'triggered_from', 'node_id'),
                  'triggered_from', 'node_id'),
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    tenant_id = db.Column(UUID, nullable=False)
-    app_id = db.Column(UUID, nullable=False)
-    workflow_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(StringUUID, nullable=False)
+    app_id = db.Column(StringUUID, nullable=False)
+    workflow_id = db.Column(StringUUID, nullable=False)
     triggered_from = db.Column(db.String(255), nullable=False)
     triggered_from = db.Column(db.String(255), nullable=False)
-    workflow_run_id = db.Column(UUID)
+    workflow_run_id = db.Column(StringUUID)
     index = db.Column(db.Integer, nullable=False)
     index = db.Column(db.Integer, nullable=False)
     predecessor_node_id = db.Column(db.String(255))
     predecessor_node_id = db.Column(db.String(255))
     node_id = db.Column(db.String(255), nullable=False)
     node_id = db.Column(db.String(255), nullable=False)
@@ -424,7 +423,7 @@ class WorkflowNodeExecution(db.Model):
     execution_metadata = db.Column(db.Text)
     execution_metadata = db.Column(db.Text)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     created_by_role = db.Column(db.String(255), nullable=False)
     created_by_role = db.Column(db.String(255), nullable=False)
-    created_by = db.Column(UUID, nullable=False)
+    created_by = db.Column(StringUUID, nullable=False)
     finished_at = db.Column(db.DateTime)
     finished_at = db.Column(db.DateTime)
 
 
     @property
     @property
@@ -529,14 +528,14 @@ class WorkflowAppLog(db.Model):
         db.Index('workflow_app_log_app_idx', 'tenant_id', 'app_id'),
         db.Index('workflow_app_log_app_idx', 'tenant_id', 'app_id'),
     )
     )
 
 
-    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
-    tenant_id = db.Column(UUID, nullable=False)
-    app_id = db.Column(UUID, nullable=False)
-    workflow_id = db.Column(UUID, nullable=False)
-    workflow_run_id = db.Column(UUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(StringUUID, nullable=False)
+    app_id = db.Column(StringUUID, nullable=False)
+    workflow_id = db.Column(StringUUID, nullable=False)
+    workflow_run_id = db.Column(StringUUID, nullable=False)
     created_from = db.Column(db.String(255), nullable=False)
     created_from = db.Column(db.String(255), nullable=False)
     created_by_role = db.Column(db.String(255), nullable=False)
     created_by_role = db.Column(db.String(255), nullable=False)
-    created_by = db.Column(UUID, nullable=False)
+    created_by = db.Column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
 
 
     @property
     @property

+ 1 - 1
api/requirements.txt

@@ -1,7 +1,7 @@
 beautifulsoup4==4.12.2
 beautifulsoup4==4.12.2
 flask~=3.0.1
 flask~=3.0.1
 Flask-SQLAlchemy~=3.0.5
 Flask-SQLAlchemy~=3.0.5
-SQLAlchemy~=1.4.28
+SQLAlchemy~=2.0.29
 Flask-Compress~=1.14
 Flask-Compress~=1.14
 flask-login~=0.6.3
 flask-login~=0.6.3
 flask-migrate~=4.0.5
 flask-migrate~=4.0.5

+ 0 - 0
api/tests/integration_tests/vdb/pgvecto_rs/__init__.py


+ 37 - 0
api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py

@@ -0,0 +1,37 @@
+from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig
+from tests.integration_tests.vdb.test_vector_store import (
+    AbstractVectorTest,
+    get_example_text,
+    setup_mock_redis,
+)
+
+
+class TestPgvectoRSVector(AbstractVectorTest):
+    def __init__(self):
+        super().__init__()
+        self.vector = PGVectoRS(
+            collection_name=self.collection_name.lower(),
+            config=PgvectoRSConfig(
+                host='localhost',
+                port=5431,
+                user='postgres',
+                password='difyai123456',
+                database='dify',
+            ),
+            dim=128
+        )
+
+    def search_by_full_text(self):
+        # pgvecto rs only support english text search, So it’s not open for now
+        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
+        assert len(hits_by_full_text) == 0
+
+    def delete_by_document_id(self):
+        self.vector.delete_by_document_id(document_id=self.example_doc_id)
+
+    def get_ids_by_metadata_field(self):
+        ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
+        assert len(ids) == 1
+
+def test_pgvecot_rs(setup_mock_redis):
+    TestPgvectoRSVector().run_all_tests()

+ 1 - 1
api/tests/integration_tests/vdb/test_vector_store.py

@@ -45,7 +45,7 @@ class AbstractVectorTest:
     def __init__(self):
     def __init__(self):
         self.vector = None
         self.vector = None
         self.dataset_id = str(uuid.uuid4())
         self.dataset_id = str(uuid.uuid4())
-        self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id)
+        self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + '_test'
         self.example_doc_id = str(uuid.uuid4())
         self.example_doc_id = str(uuid.uuid4())
         self.example_embedding = [1.001 * i for i in range(128)]
         self.example_embedding = [1.001 * i for i in range(128)]
 
 

+ 24 - 0
docker/docker-compose.pgvecto-rs.yaml

@@ -0,0 +1,24 @@
+version: '3'
+services:
+  # The pgvecto—rs database.
+  pgvecto-rs:
+    image: tensorchord/pgvecto-rs:pg16-v0.2.0
+    restart: always
+    environment:
+      PGUSER: postgres
+      # The password for the default postgres user.
+      POSTGRES_PASSWORD: difyai123456
+      # The name of the default postgres database.
+      POSTGRES_DB: dify
+      # postgres data directory
+      PGDATA: /var/lib/postgresql/data/pgdata
+    volumes:
+      - ./volumes/pgvectors/data:/var/lib/postgresql/data
+    # uncomment to expose db(postgresql) port to host
+    ports:
+      - "5431:5432"
+    healthcheck:
+      test: [ "CMD", "pg_isready" ]
+      interval: 1s
+      timeout: 3s
+      retries: 30