Browse Source

chore(db): use a better way to export models and remove unused table (#11838)

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN- 4 tháng trước cách đây
mục cha
commit
3599751f93

+ 1 - 2
api/controllers/console/app/wraps.py

@@ -5,8 +5,7 @@ from typing import Optional, Union
 from controllers.console.app.error import AppNotFoundError
 from extensions.ext_database import db
 from libs.login import current_user
-from models import App
-from models.model import AppMode
+from models import App, AppMode
 
 
 def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):

+ 1 - 1
api/core/helper/encrypter.py

@@ -1,6 +1,5 @@
 import base64
 
-from extensions.ext_database import db
 from libs import rsa
 
 
@@ -14,6 +13,7 @@ def obfuscated_token(token: str):
 
 def encrypt_token(tenant_id: str, token: str):
     from models.account import Tenant
+    from models.engine import db
 
     if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()):
         raise ValueError(f"Tenant with id {tenant_id} not found")

+ 1 - 14
api/extensions/ext_database.py

@@ -1,18 +1,5 @@
-from flask_sqlalchemy import SQLAlchemy
-from sqlalchemy import MetaData
-
 from dify_app import DifyApp
-
-POSTGRES_INDEXES_NAMING_CONVENTION = {
-    "ix": "%(column_0_label)s_idx",
-    "uq": "%(table_name)s_%(column_0_name)s_key",
-    "ck": "%(table_name)s_%(constraint_name)s_check",
-    "fk": "%(table_name)s_%(column_0_name)s_fkey",
-    "pk": "%(table_name)s_pkey",
-}
-
-metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION)
-db = SQLAlchemy(metadata=metadata)
+from models import db
 
 
 def init_app(app: DifyApp):

+ 0 - 1
api/extensions/ext_import_modules.py

@@ -3,4 +3,3 @@ from dify_app import DifyApp
 
 def init_app(app: DifyApp):
     from events import event_handlers  # noqa: F401
-    from models import account, dataset, model, source, task, tool, tools, web  # noqa: F401

+ 1 - 1
api/libs/helper.py

@@ -13,7 +13,7 @@ from typing import Any, Optional, Union, cast
 from zoneinfo import available_timezones
 
 from flask import Response, stream_with_context
-from flask_restful import fields  # type: ignore
+from flask_restful import fields
 
 from configs import dify_config
 from core.app.features.rate_limiting.rate_limit import RateLimitGenerator

+ 39 - 0
api/migrations/versions/2024_12_19_1746-11b07f66c737_remove_unused_tool_providers.py

@@ -0,0 +1,39 @@
+"""remove unused tool_providers
+
+Revision ID: 11b07f66c737
+Revises: cf8f4fc45278
+Create Date: 2024-12-19 17:46:25.780116
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '11b07f66c737'
+down_revision = 'cf8f4fc45278'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_table('tool_providers')
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('tool_providers',
+    sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False),
+    sa.Column('tenant_id', sa.UUID(), autoincrement=False, nullable=False),
+    sa.Column('tool_name', sa.VARCHAR(length=40), autoincrement=False, nullable=False),
+    sa.Column('encrypted_credentials', sa.TEXT(), autoincrement=False, nullable=True),
+    sa.Column('is_enabled', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False),
+    sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False),
+    sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False),
+    sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
+    sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
+    )
+    # ### end Alembic commands ###

+ 138 - 4
api/models/__init__.py

@@ -1,53 +1,187 @@
-from .account import Account, AccountIntegrate, InvitationCode, Tenant
-from .dataset import Dataset, DatasetProcessRule, Document, DocumentSegment
+from .account import (
+    Account,
+    AccountIntegrate,
+    AccountStatus,
+    InvitationCode,
+    Tenant,
+    TenantAccountJoin,
+    TenantAccountJoinRole,
+    TenantAccountRole,
+    TenantStatus,
+)
+from .api_based_extension import APIBasedExtension, APIBasedExtensionPoint
+from .dataset import (
+    AppDatasetJoin,
+    Dataset,
+    DatasetCollectionBinding,
+    DatasetKeywordTable,
+    DatasetPermission,
+    DatasetPermissionEnum,
+    DatasetProcessRule,
+    DatasetQuery,
+    Document,
+    DocumentSegment,
+    Embedding,
+    ExternalKnowledgeApis,
+    ExternalKnowledgeBindings,
+    TidbAuthBinding,
+    Whitelist,
+)
+from .engine import db
+from .enums import CreatedByRole, UserFrom, WorkflowRunTriggeredFrom
 from .model import (
+    ApiRequest,
     ApiToken,
     App,
+    AppAnnotationHitHistory,
+    AppAnnotationSetting,
     AppMode,
+    AppModelConfig,
     Conversation,
+    DatasetRetrieverResource,
+    DifySetup,
     EndUser,
+    IconType,
     InstalledApp,
     Message,
+    MessageAgentThought,
     MessageAnnotation,
+    MessageChain,
+    MessageFeedback,
     MessageFile,
+    OperationLog,
     RecommendedApp,
     Site,
+    Tag,
+    TagBinding,
+    TraceAppConfig,
     UploadFile,
 )
-from .source import DataSourceOauthBinding
-from .tools import ToolFile
+from .provider import (
+    LoadBalancingModelConfig,
+    Provider,
+    ProviderModel,
+    ProviderModelSetting,
+    ProviderOrder,
+    ProviderQuotaType,
+    ProviderType,
+    TenantDefaultModel,
+    TenantPreferredModelProvider,
+)
+from .source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
+from .task import CeleryTask, CeleryTaskSet
+from .tools import (
+    ApiToolProvider,
+    BuiltinToolProvider,
+    PublishedAppTool,
+    ToolConversationVariables,
+    ToolFile,
+    ToolLabelBinding,
+    ToolModelInvoke,
+    WorkflowToolProvider,
+)
+from .web import PinnedConversation, SavedMessage
 from .workflow import (
     ConversationVariable,
     Workflow,
     WorkflowAppLog,
+    WorkflowAppLogCreatedFrom,
+    WorkflowNodeExecution,
+    WorkflowNodeExecutionStatus,
+    WorkflowNodeExecutionTriggeredFrom,
     WorkflowRun,
+    WorkflowRunStatus,
+    WorkflowType,
 )
 
 __all__ = [
+    "APIBasedExtension",
+    "APIBasedExtensionPoint",
     "Account",
     "AccountIntegrate",
+    "AccountStatus",
+    "ApiRequest",
     "ApiToken",
+    "ApiToolProvider",  # Added
     "App",
+    "AppAnnotationHitHistory",
+    "AppAnnotationSetting",
+    "AppDatasetJoin",
     "AppMode",
+    "AppModelConfig",
+    "BuiltinToolProvider",  # Added
+    "CeleryTask",
+    "CeleryTaskSet",
     "Conversation",
     "ConversationVariable",
+    "CreatedByRole",
+    "DataSourceApiKeyAuthBinding",
     "DataSourceOauthBinding",
     "Dataset",
+    "DatasetCollectionBinding",
+    "DatasetKeywordTable",
+    "DatasetPermission",
+    "DatasetPermissionEnum",
     "DatasetProcessRule",
+    "DatasetQuery",
+    "DatasetRetrieverResource",
+    "DifySetup",
     "Document",
     "DocumentSegment",
+    "Embedding",
     "EndUser",
+    "ExternalKnowledgeApis",
+    "ExternalKnowledgeBindings",
+    "IconType",
     "InstalledApp",
     "InvitationCode",
+    "LoadBalancingModelConfig",
     "Message",
+    "MessageAgentThought",
     "MessageAnnotation",
+    "MessageChain",
+    "MessageFeedback",
     "MessageFile",
+    "OperationLog",
+    "PinnedConversation",
+    "Provider",
+    "ProviderModel",
+    "ProviderModelSetting",
+    "ProviderOrder",
+    "ProviderQuotaType",
+    "ProviderType",
+    "PublishedAppTool",
     "RecommendedApp",
+    "SavedMessage",
     "Site",
+    "Tag",
+    "TagBinding",
     "Tenant",
+    "TenantAccountJoin",
+    "TenantAccountJoinRole",
+    "TenantAccountRole",
+    "TenantDefaultModel",
+    "TenantPreferredModelProvider",
+    "TenantStatus",
+    "TidbAuthBinding",
+    "ToolConversationVariables",
     "ToolFile",
+    "ToolLabelBinding",
+    "ToolModelInvoke",
+    "TraceAppConfig",
     "UploadFile",
+    "UserFrom",
+    "Whitelist",
     "Workflow",
     "WorkflowAppLog",
+    "WorkflowAppLogCreatedFrom",
+    "WorkflowNodeExecution",
+    "WorkflowNodeExecutionStatus",
+    "WorkflowNodeExecutionTriggeredFrom",
     "WorkflowRun",
+    "WorkflowRunStatus",
+    "WorkflowRunTriggeredFrom",
+    "WorkflowToolProvider",
+    "WorkflowType",
+    "db",
 ]

+ 1 - 2
api/models/account.py

@@ -3,8 +3,7 @@ import json
 
 from flask_login import UserMixin
 
-from extensions.ext_database import db
-
+from .engine import db
 from .types import StringUUID
 
 

+ 1 - 2
api/models/api_based_extension.py

@@ -1,7 +1,6 @@
 import enum
 
-from extensions.ext_database import db
-
+from .engine import db
 from .types import StringUUID
 
 

+ 1 - 1
api/models/dataset.py

@@ -15,10 +15,10 @@ from sqlalchemy.dialects.postgresql import JSONB
 
 from configs import dify_config
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
-from extensions.ext_database import db
 from extensions.ext_storage import storage
 
 from .account import Account
+from .engine import db
 from .model import App, Tag, TagBinding, UploadFile
 from .types import StringUUID
 

+ 13 - 0
api/models/engine.py

@@ -0,0 +1,13 @@
+from flask_sqlalchemy import SQLAlchemy
+from sqlalchemy import MetaData
+
+POSTGRES_INDEXES_NAMING_CONVENTION = {
+    "ix": "%(column_0_label)s_idx",
+    "uq": "%(table_name)s_%(column_0_name)s_key",
+    "ck": "%(table_name)s_%(constraint_name)s_check",
+    "fk": "%(table_name)s_%(column_0_name)s_fkey",
+    "pk": "%(table_name)s_pkey",
+}
+
+metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION)
+db = SQLAlchemy(metadata=metadata)

+ 1 - 1
api/models/model.py

@@ -16,11 +16,11 @@ from configs import dify_config
 from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
 from core.file import helpers as file_helpers
 from core.file.tool_file_parser import ToolFileParser
-from extensions.ext_database import db
 from libs.helper import generate_string
 from models.enums import CreatedByRole
 
 from .account import Account, Tenant
+from .engine import db
 from .types import StringUUID
 
 

+ 1 - 2
api/models/provider.py

@@ -1,7 +1,6 @@
 from enum import Enum
 
-from extensions.ext_database import db
-
+from .engine import db
 from .types import StringUUID
 
 

+ 1 - 2
api/models/source.py

@@ -2,8 +2,7 @@ import json
 
 from sqlalchemy.dialects.postgresql import JSONB
 
-from extensions.ext_database import db
-
+from .engine import db
 from .types import StringUUID
 
 

+ 1 - 1
api/models/task.py

@@ -2,7 +2,7 @@ from datetime import UTC, datetime
 
 from celery import states
 
-from extensions.ext_database import db
+from .engine import db
 
 
 class CeleryTask(db.Model):

+ 0 - 47
api/models/tool.py

@@ -1,47 +0,0 @@
-import json
-from enum import Enum
-
-from extensions.ext_database import db
-
-from .types import StringUUID
-
-
-class ToolProviderName(Enum):
-    SERPAPI = "serpapi"
-
-    @staticmethod
-    def value_of(value):
-        for member in ToolProviderName:
-            if member.value == value:
-                return member
-        raise ValueError(f"No matching enum found for value '{value}'")
-
-
-class ToolProvider(db.Model):
-    __tablename__ = "tool_providers"
-    __table_args__ = (
-        db.PrimaryKeyConstraint("id", name="tool_provider_pkey"),
-        db.UniqueConstraint("tenant_id", "tool_name", name="unique_tool_provider_tool_name"),
-    )
-
-    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
-    tenant_id = db.Column(StringUUID, nullable=False)
-    tool_name = db.Column(db.String(40), nullable=False)
-    encrypted_credentials = db.Column(db.Text, nullable=True)
-    is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
-    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
-    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
-
-    @property
-    def credentials_is_set(self):
-        """
-        Returns True if the encrypted_config is not None, indicating that the token is set.
-        """
-        return self.encrypted_credentials is not None
-
-    @property
-    def credentials(self):
-        """
-        Returns the decrypted config.
-        """
-        return json.loads(self.encrypted_credentials) if self.encrypted_credentials is not None else None

+ 2 - 6
api/models/tools.py

@@ -8,8 +8,8 @@ from sqlalchemy.orm import Mapped, mapped_column
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_bundle import ApiToolBundle
 from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
-from extensions.ext_database import db
 
+from .engine import db
 from .model import Account, App, Tenant
 from .types import StringUUID
 
@@ -82,7 +82,7 @@ class PublishedAppTool(db.Model):
         return I18nObject(**json.loads(self.description))
 
     @property
-    def app(self) -> App:
+    def app(self):
         return db.session.query(App).filter(App.id == self.app_id).first()
 
 
@@ -201,10 +201,6 @@ class WorkflowToolProvider(db.Model):
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
 
-    @property
-    def schema_type(self) -> ApiProviderSchemaType:
-        return ApiProviderSchemaType.value_of(self.schema_type_str)
-
     @property
     def user(self) -> Account | None:
         return db.session.query(Account).filter(Account.id == self.user_id).first()

+ 1 - 2
api/models/web.py

@@ -1,5 +1,4 @@
-from extensions.ext_database import db
-
+from .engine import db
 from .model import Message
 from .types import StringUUID
 

+ 2 - 2
api/models/workflow.py

@@ -12,12 +12,12 @@ import contexts
 from constants import HIDDEN_VALUE
 from core.helper import encrypter
 from core.variables import SecretVariable, Variable
-from extensions.ext_database import db
 from factories import variable_factory
 from libs import helper
 from models.enums import CreatedByRole
 
 from .account import Account
+from .engine import db
 from .types import StringUUID
 
 
@@ -399,7 +399,7 @@ class WorkflowRun(db.Model):
     graph = db.Column(db.Text)
     inputs = db.Column(db.Text)
     status = db.Column(db.String(255), nullable=False)  # running, succeeded, failed, stopped, partial-succeeded
-    outputs: Mapped[str] = mapped_column(sa.Text, default="{}")
+    outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
     error = db.Column(db.Text)
     elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0"))
     total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))