瀏覽代碼

Fix/type-error (#11240)

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN- 4 月之前
父節點
當前提交
c34bdb74e6

+ 4 - 1
api/core/app/app_config/easy_ui_based_app/model_config/manager.py

@@ -1,3 +1,6 @@
+from collections.abc import Mapping
+from typing import Any
+
 from core.app.app_config.entities import ModelConfigEntity
 from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from core.model_runtime.model_providers import model_provider_factory
@@ -36,7 +39,7 @@ class ModelConfigManager:
         )
 
     @classmethod
-    def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
+    def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict, list[str]]:
         """
         Validate and set defaults for model config
 

+ 12 - 34
api/core/app/apps/advanced_chat/app_generator.py

@@ -2,8 +2,8 @@ import contextvars
 import logging
 import threading
 import uuid
-from collections.abc import Generator
-from typing import Any, Literal, Optional, Union, overload
+from collections.abc import Generator, Mapping
+from typing import Any, Optional, Union
 
 from flask import Flask, current_app
 from pydantic import ValidationError
@@ -33,37 +33,15 @@ logger = logging.getLogger(__name__)
 
 
 class AdvancedChatAppGenerator(MessageBasedAppGenerator):
-    @overload
     def generate(
         self,
         app_model: App,
         workflow: Workflow,
         user: Union[Account, EndUser],
-        args: dict,
+        args: Mapping[str, Any],
         invoke_from: InvokeFrom,
-        stream: Literal[True] = True,
-    ) -> Generator[str, None, None]: ...
-
-    @overload
-    def generate(
-        self,
-        app_model: App,
-        workflow: Workflow,
-        user: Union[Account, EndUser],
-        args: dict,
-        invoke_from: InvokeFrom,
-        stream: Literal[False] = False,
-    ) -> dict: ...
-
-    def generate(
-        self,
-        app_model: App,
-        workflow: Workflow,
-        user: Union[Account, EndUser],
-        args: dict,
-        invoke_from: InvokeFrom,
-        stream: bool = True,
-    ) -> dict[str, Any] | Generator[str, Any, None]:
+        streaming: bool = True,
+    ) -> Mapping[str, Any] | Generator[str, None, None]:
         """
         Generate App response.
 
@@ -134,7 +112,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             files=file_objs,
             parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
             user_id=user.id,
-            stream=stream,
+            stream=streaming,
             invoke_from=invoke_from,
             extras=extras,
             trace_manager=trace_manager,
@@ -148,12 +126,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             invoke_from=invoke_from,
             application_generate_entity=application_generate_entity,
             conversation=conversation,
-            stream=stream,
+            stream=streaming,
         )
 
     def single_iteration_generate(
-        self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True
-    ) -> dict[str, Any] | Generator[str, Any, None]:
+        self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, streaming: bool = True
+    ) -> Mapping[str, Any] | Generator[str, None, None]:
         """
         Generate App response.
 
@@ -182,7 +160,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             query="",
             files=[],
             user_id=user.id,
-            stream=stream,
+            stream=streaming,
             invoke_from=InvokeFrom.DEBUGGER,
             extras={"auto_generate_conversation_name": False},
             single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
@@ -197,7 +175,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             invoke_from=InvokeFrom.DEBUGGER,
             application_generate_entity=application_generate_entity,
             conversation=None,
-            stream=stream,
+            stream=streaming,
         )
 
     def _generate(
@@ -209,7 +187,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         application_generate_entity: AdvancedChatAppGenerateEntity,
         conversation: Optional[Conversation] = None,
         stream: bool = True,
-    ) -> dict[str, Any] | Generator[str, Any, None]:
+    ) -> Mapping[str, Any] | Generator[str, None, None]:
         """
         Generate App response.
 

+ 3 - 2
api/core/app/apps/agent_chat/app_config_manager.py

@@ -1,5 +1,6 @@
 import uuid
-from typing import Optional
+from collections.abc import Mapping
+from typing import Any, Optional
 
 from core.agent.entities import AgentEntity
 from core.app.app_config.base_app_config_manager import BaseAppConfigManager
@@ -85,7 +86,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
         return app_config
 
     @classmethod
-    def config_validate(cls, tenant_id: str, config: dict) -> dict:
+    def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> dict:
         """
         Validate for agent chat app model config
 

+ 11 - 29
api/core/app/apps/agent_chat/app_generator.py

@@ -1,8 +1,8 @@
 import logging
 import threading
 import uuid
-from collections.abc import Generator
-from typing import Any, Literal, Union, overload
+from collections.abc import Generator, Mapping
+from typing import Any, Union
 
 from flask import Flask, current_app
 from pydantic import ValidationError
@@ -28,34 +28,15 @@ logger = logging.getLogger(__name__)
 
 
 class AgentChatAppGenerator(MessageBasedAppGenerator):
-    @overload
     def generate(
         self,
+        *,
         app_model: App,
         user: Union[Account, EndUser],
-        args: dict,
+        args: Mapping[str, Any],
         invoke_from: InvokeFrom,
-        stream: Literal[True] = True,
-    ) -> Generator[dict, None, None]: ...
-
-    @overload
-    def generate(
-        self,
-        app_model: App,
-        user: Union[Account, EndUser],
-        args: dict,
-        invoke_from: InvokeFrom,
-        stream: Literal[False] = False,
-    ) -> dict: ...
-
-    def generate(
-        self,
-        app_model: App,
-        user: Union[Account, EndUser],
-        args: Any,
-        invoke_from: InvokeFrom,
-        stream: bool = True,
-    ) -> Union[dict, Generator[dict, None, None]]:
+        streaming: bool = True,
+    ) -> Mapping[str, Any] | Generator[str, None, None]:
         """
         Generate App response.
 
@@ -65,7 +46,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
         :param invoke_from: invoke from source
         :param stream: is stream
         """
-        if not stream:
+        if not streaming:
             raise ValueError("Agent Chat App does not support blocking mode")
 
         if not args.get("query"):
@@ -96,7 +77,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
 
             # validate config
             override_model_config_dict = AgentChatAppConfigManager.config_validate(
-                tenant_id=app_model.tenant_id, config=args.get("model_config")
+                tenant_id=app_model.tenant_id,
+                config=args["model_config"],
             )
 
             # always enable retriever resource in debugger mode
@@ -141,7 +123,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             files=file_objs,
             parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
             user_id=user.id,
-            stream=stream,
+            stream=streaming,
             invoke_from=invoke_from,
             extras=extras,
             call_depth=0,
@@ -182,7 +164,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             conversation=conversation,
             message=message,
             user=user,
-            stream=stream,
+            stream=streaming,
         )
 
         return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

+ 5 - 3
api/core/app/apps/base_app_generate_response_converter.py

@@ -1,6 +1,6 @@
 import logging
 from abc import ABC, abstractmethod
-from collections.abc import Generator
+from collections.abc import Generator, Mapping
 from typing import Any, Union
 
 from core.app.entities.app_invoke_entities import InvokeFrom
@@ -14,8 +14,10 @@ class AppGenerateResponseConverter(ABC):
 
     @classmethod
     def convert(
-        cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
-    ) -> dict[str, Any] | Generator[str, Any, None]:
+        cls,
+        response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]],
+        invoke_from: InvokeFrom,
+    ) -> Mapping[str, Any] | Generator[str, None, None]:
         if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
             if isinstance(response, AppBlockingResponse):
                 return cls.convert_blocking_full_response(response)

+ 3 - 3
api/core/app/apps/chat/app_generator.py

@@ -55,7 +55,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
         user: Union[Account, EndUser],
         args: Any,
         invoke_from: InvokeFrom,
-        stream: bool = True,
+        streaming: bool = True,
     ) -> Union[dict, Generator[str, None, None]]:
         """
         Generate App response.
@@ -142,7 +142,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
             invoke_from=invoke_from,
             extras=extras,
             trace_manager=trace_manager,
-            stream=stream,
+            stream=streaming,
         )
 
         # init generate records
@@ -179,7 +179,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
             conversation=conversation,
             message=message,
             user=user,
-            stream=stream,
+            stream=streaming,
         )
 
         return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

+ 3 - 3
api/core/app/apps/completion/app_generator.py

@@ -50,7 +50,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
     ) -> dict: ...
 
     def generate(
-        self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
+        self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, streaming: bool = True
     ) -> Union[dict, Generator[str, None, None]]:
         """
         Generate App response.
@@ -119,7 +119,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
             query=query,
             files=file_objs,
             user_id=user.id,
-            stream=stream,
+            stream=streaming,
             invoke_from=invoke_from,
             extras=extras,
             trace_manager=trace_manager,
@@ -158,7 +158,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
             conversation=conversation,
             message=message,
             user=user,
-            stream=stream,
+            stream=streaming,
         )
 
         return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

+ 20 - 50
api/core/app/apps/workflow/app_generator.py

@@ -3,7 +3,7 @@ import logging
 import threading
 import uuid
 from collections.abc import Generator, Mapping, Sequence
-from typing import Any, Literal, Optional, Union, overload
+from typing import Any, Optional, Union
 
 from flask import Flask, current_app
 from pydantic import ValidationError
@@ -30,43 +30,18 @@ logger = logging.getLogger(__name__)
 
 
 class WorkflowAppGenerator(BaseAppGenerator):
-    @overload
-    def generate(
-        self,
-        app_model: App,
-        workflow: Workflow,
-        user: Union[Account, EndUser],
-        args: dict,
-        invoke_from: InvokeFrom,
-        stream: Literal[True] = True,
-        call_depth: int = 0,
-        workflow_thread_pool_id: Optional[str] = None,
-    ) -> Generator[str, None, None]: ...
-
-    @overload
-    def generate(
-        self,
-        app_model: App,
-        workflow: Workflow,
-        user: Union[Account, EndUser],
-        args: dict,
-        invoke_from: InvokeFrom,
-        stream: Literal[False] = False,
-        call_depth: int = 0,
-        workflow_thread_pool_id: Optional[str] = None,
-    ) -> dict: ...
-
     def generate(
         self,
+        *,
         app_model: App,
         workflow: Workflow,
-        user: Union[Account, EndUser],
+        user: Account | EndUser,
         args: Mapping[str, Any],
         invoke_from: InvokeFrom,
-        stream: bool = True,
+        streaming: bool = True,
         call_depth: int = 0,
         workflow_thread_pool_id: Optional[str] = None,
-    ):
+    ) -> Mapping[str, Any] | Generator[str, None, None]:
         files: Sequence[Mapping[str, Any]] = args.get("files") or []
 
         # parse files
@@ -101,7 +76,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
             ),
             files=system_files,
             user_id=user.id,
-            stream=stream,
+            stream=streaming,
             invoke_from=invoke_from,
             call_depth=call_depth,
             trace_manager=trace_manager,
@@ -115,7 +90,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
             user=user,
             application_generate_entity=application_generate_entity,
             invoke_from=invoke_from,
-            stream=stream,
+            streaming=streaming,
             workflow_thread_pool_id=workflow_thread_pool_id,
         )
 
@@ -127,20 +102,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
         user: Union[Account, EndUser],
         application_generate_entity: WorkflowAppGenerateEntity,
         invoke_from: InvokeFrom,
-        stream: bool = True,
+        streaming: bool = True,
         workflow_thread_pool_id: Optional[str] = None,
-    ) -> dict[str, Any] | Generator[str, None, None]:
-        """
-        Generate App response.
-
-        :param app_model: App
-        :param workflow: Workflow
-        :param user: account or end user
-        :param application_generate_entity: application generate entity
-        :param invoke_from: invoke from source
-        :param stream: is stream
-        :param workflow_thread_pool_id: workflow thread pool id
-        """
+    ) -> Mapping[str, Any] | Generator[str, None, None]:
         # init queue manager
         queue_manager = WorkflowAppQueueManager(
             task_id=application_generate_entity.task_id,
@@ -169,14 +133,20 @@ class WorkflowAppGenerator(BaseAppGenerator):
             workflow=workflow,
             queue_manager=queue_manager,
             user=user,
-            stream=stream,
+            stream=streaming,
         )
 
         return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
 
     def single_iteration_generate(
-        self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True
-    ) -> dict[str, Any] | Generator[str, Any, None]:
+        self,
+        app_model: App,
+        workflow: Workflow,
+        node_id: str,
+        user: Account,
+        args: Mapping[str, Any],
+        streaming: bool = True,
+    ) -> Mapping[str, Any] | Generator[str, None, None]:
         """
         Generate App response.
 
@@ -203,7 +173,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
             inputs={},
             files=[],
             user_id=user.id,
-            stream=stream,
+            stream=streaming,
             invoke_from=InvokeFrom.DEBUGGER,
             extras={"auto_generate_conversation_name": False},
             single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
@@ -218,7 +188,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
             user=user,
             invoke_from=InvokeFrom.DEBUGGER,
             application_generate_entity=application_generate_entity,
-            stream=stream,
+            streaming=streaming,
         )
 
     def _generate_worker(

+ 7 - 10
api/core/app/features/rate_limiting/rate_limit.py

@@ -1,9 +1,9 @@
 import logging
 import time
 import uuid
-from collections.abc import Generator
+from collections.abc import Generator, Mapping
 from datetime import timedelta
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 from core.errors.error import AppInvokeQuotaExceededError
 from extensions.ext_redis import redis_client
@@ -88,20 +88,17 @@ class RateLimit:
     def gen_request_key() -> str:
         return str(uuid.uuid4())
 
-    def generate(self, generator: Union[Generator, callable, dict], request_id: str):
-        if isinstance(generator, dict):
+    def generate(self, generator: Union[Generator[str, None, None], Mapping[str, Any]], request_id: str):
+        if isinstance(generator, Mapping):
             return generator
         else:
-            return RateLimitGenerator(self, generator, request_id)
+            return RateLimitGenerator(rate_limit=self, generator=generator, request_id=request_id)
 
 
 class RateLimitGenerator:
-    def __init__(self, rate_limit: RateLimit, generator: Union[Generator, callable], request_id: str):
+    def __init__(self, rate_limit: RateLimit, generator: Generator[str, None, None], request_id: str):
         self.rate_limit = rate_limit
-        if callable(generator):
-            self.generator = generator()
-        else:
-            self.generator = generator
+        self.generator = generator
         self.request_id = request_id
         self.closed = False
 

+ 4 - 2
api/libs/helper.py

@@ -6,7 +6,7 @@ import string
 import subprocess
 import time
 import uuid
-from collections.abc import Generator
+from collections.abc import Generator, Mapping
 from datetime import datetime
 from hashlib import sha256
 from typing import Any, Optional, Union
@@ -180,7 +180,9 @@ def generate_text_hash(text: str) -> str:
     return sha256(hash_text.encode()).hexdigest()
 
 
-def compact_generate_response(response: Union[dict, RateLimitGenerator]) -> Response:
+def compact_generate_response(
+    response: Union[Mapping[str, Any], RateLimitGenerator, Generator[str, None, None]],
+) -> Response:
     if isinstance(response, dict):
         return Response(response=json.dumps(response), status=200, mimetype="application/json")
     else:

+ 45 - 24
api/services/app_generate_service.py

@@ -43,50 +43,66 @@ class AppGenerateService:
             request_id = rate_limit.enter(request_id)
             if app_model.mode == AppMode.COMPLETION.value:
                 return rate_limit.generate(
-                    CompletionAppGenerator().generate(
-                        app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming
+                    generator=CompletionAppGenerator().generate(
+                        app_model=app_model,
+                        user=user,
+                        args=args,
+                        invoke_from=invoke_from,
+                        streaming=streaming,
                     ),
-                    request_id,
+                    request_id=request_id,
                 )
             elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
+                generator = AgentChatAppGenerator().generate(
+                    app_model=app_model,
+                    user=user,
+                    args=args,
+                    invoke_from=invoke_from,
+                    streaming=streaming,
+                )
                 return rate_limit.generate(
-                    AgentChatAppGenerator().generate(
-                        app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming
-                    ),
-                    request_id,
+                    generator=generator,
+                    request_id=request_id,
                 )
             elif app_model.mode == AppMode.CHAT.value:
                 return rate_limit.generate(
-                    ChatAppGenerator().generate(
-                        app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming
+                    generator=ChatAppGenerator().generate(
+                        app_model=app_model,
+                        user=user,
+                        args=args,
+                        invoke_from=invoke_from,
+                        streaming=streaming,
                     ),
-                    request_id,
+                    request_id=request_id,
                 )
             elif app_model.mode == AppMode.ADVANCED_CHAT.value:
                 workflow = cls._get_workflow(app_model, invoke_from)
                 return rate_limit.generate(
-                    AdvancedChatAppGenerator().generate(
+                    generator=AdvancedChatAppGenerator().generate(
                         app_model=app_model,
                         workflow=workflow,
                         user=user,
                         args=args,
                         invoke_from=invoke_from,
-                        stream=streaming,
+                        streaming=streaming,
                     ),
-                    request_id,
+                    request_id=request_id,
                 )
             elif app_model.mode == AppMode.WORKFLOW.value:
                 workflow = cls._get_workflow(app_model, invoke_from)
+                generator = WorkflowAppGenerator().generate(
+                    app_model=app_model,
+                    workflow=workflow,
+                    user=user,
+                    args=args,
+                    invoke_from=invoke_from,
+                    streaming=streaming,
+                    call_depth=0,
+                    workflow_thread_pool_id=None,
+                )
                 return rate_limit.generate(
-                    WorkflowAppGenerator().generate(
-                        app_model=app_model,
-                        workflow=workflow,
-                        user=user,
-                        args=args,
-                        invoke_from=invoke_from,
-                        stream=streaming,
-                    ),
-                    request_id,
+                    generator=generator,
+                    request_id=request_id,
                 )
             else:
                 raise ValueError(f"Invalid app mode {app_model.mode}")
@@ -108,12 +124,17 @@ class AppGenerateService:
         if app_model.mode == AppMode.ADVANCED_CHAT.value:
             workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
             return AdvancedChatAppGenerator().single_iteration_generate(
-                app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, stream=streaming
+                app_model=app_model,
+                workflow=workflow,
+                node_id=node_id,
+                user=user,
+                args=args,
+                streaming=streaming,
             )
         elif app_model.mode == AppMode.WORKFLOW.value:
             workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
             return WorkflowAppGenerator().single_iteration_generate(
-                app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, stream=streaming
+                app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
             )
         else:
             raise ValueError(f"Invalid app mode {app_model.mode}")