Переглянути джерело

fix(app_generator_service): overload type hints (#11507)

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN- 4 місяців тому
батько
коміт
fd354d999d

+ 35 - 2
api/core/app/apps/advanced_chat/app_generator.py

@@ -3,7 +3,7 @@ import logging
 import threading
 import uuid
 from collections.abc import Generator, Mapping
-from typing import Any, Optional, Union
+from typing import Any, Literal, Optional, Union, overload
 
 from flask import Flask, current_app
 from pydantic import ValidationError
@@ -36,6 +36,29 @@ logger = logging.getLogger(__name__)
 class AdvancedChatAppGenerator(MessageBasedAppGenerator):
     _dialogue_count: int
 
+    @overload
+    def generate(
+        self,
+        app_model: App,
+        workflow: Workflow,
+        user: Union[Account, EndUser],
+        args: Mapping[str, Any],
+        invoke_from: InvokeFrom,
+        streaming: Literal[True],
+    ) -> Generator[str, None, None]: ...
+
+    @overload
+    def generate(
+        self,
+        app_model: App,
+        workflow: Workflow,
+        user: Union[Account, EndUser],
+        args: Mapping[str, Any],
+        invoke_from: InvokeFrom,
+        streaming: Literal[False],
+    ) -> Mapping[str, Any]: ...
+
+    @overload
     def generate(
         self,
         app_model: App,
@@ -44,7 +67,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         args: Mapping[str, Any],
         invoke_from: InvokeFrom,
         streaming: bool = True,
-    ) -> Mapping[str, Any] | Generator[str, None, None]:
+    ) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
+
+    def generate(
+        self,
+        app_model: App,
+        workflow: Workflow,
+        user: Union[Account, EndUser],
+        args: Mapping[str, Any],
+        invoke_from: InvokeFrom,
+        streaming: bool = True,
+    ):
         """
         Generate App response.
 

+ 35 - 2
api/core/app/apps/agent_chat/app_generator.py

@@ -2,7 +2,7 @@ import logging
 import threading
 import uuid
 from collections.abc import Generator, Mapping
-from typing import Any, Union
+from typing import Any, Literal, Union, overload
 
 from flask import Flask, current_app
 from pydantic import ValidationError
@@ -28,6 +28,39 @@ logger = logging.getLogger(__name__)
 
 
 class AgentChatAppGenerator(MessageBasedAppGenerator):
+    @overload
+    def generate(
+        self,
+        *,
+        app_model: App,
+        user: Union[Account, EndUser],
+        args: Mapping[str, Any],
+        invoke_from: InvokeFrom,
+        streaming: Literal[True],
+    ) -> Generator[str, None, None]: ...
+
+    @overload
+    def generate(
+        self,
+        *,
+        app_model: App,
+        user: Union[Account, EndUser],
+        args: Mapping[str, Any],
+        invoke_from: InvokeFrom,
+        streaming: Literal[False],
+    ) -> Mapping[str, Any]: ...
+
+    @overload
+    def generate(
+        self,
+        *,
+        app_model: App,
+        user: Union[Account, EndUser],
+        args: Mapping[str, Any],
+        invoke_from: InvokeFrom,
+        streaming: bool,
+    ) -> Mapping[str, Any] | Generator[str, None, None]: ...
+
     def generate(
         self,
         *,
@@ -36,7 +69,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
         args: Mapping[str, Any],
         invoke_from: InvokeFrom,
         streaming: bool = True,
-    ) -> Mapping[str, Any] | Generator[str, None, None]:
+    ):
         """
         Generate App response.
 

+ 18 - 8
api/core/app/apps/chat/app_generator.py

@@ -1,7 +1,7 @@
 import logging
 import threading
 import uuid
-from collections.abc import Generator
+from collections.abc import Generator, Mapping
 from typing import Any, Literal, Union, overload
 
 from flask import Flask, current_app
@@ -34,9 +34,9 @@ class ChatAppGenerator(MessageBasedAppGenerator):
         self,
         app_model: App,
         user: Union[Account, EndUser],
-        args: Any,
+        args: Mapping[str, Any],
         invoke_from: InvokeFrom,
-        stream: Literal[True] = True,
+        streaming: Literal[True],
     ) -> Generator[str, None, None]: ...
 
     @overload
@@ -44,19 +44,29 @@ class ChatAppGenerator(MessageBasedAppGenerator):
         self,
         app_model: App,
         user: Union[Account, EndUser],
-        args: Any,
+        args: Mapping[str, Any],
         invoke_from: InvokeFrom,
-        stream: Literal[False] = False,
-    ) -> dict: ...
+        streaming: Literal[False],
+    ) -> Mapping[str, Any]: ...
+
+    @overload
+    def generate(
+        self,
+        app_model: App,
+        user: Union[Account, EndUser],
+        args: Mapping[str, Any],
+        invoke_from: InvokeFrom,
+        streaming: bool,
+    ) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
 
     def generate(
         self,
         app_model: App,
         user: Union[Account, EndUser],
-        args: Any,
+        args: Mapping[str, Any],
         invoke_from: InvokeFrom,
         streaming: bool = True,
-    ) -> Union[dict, Generator[str, None, None]]:
+    ):
         """
         Generate App response.
 

+ 23 - 8
api/core/app/apps/completion/app_generator.py

@@ -1,7 +1,7 @@
 import logging
 import threading
 import uuid
-from collections.abc import Generator
+from collections.abc import Generator, Mapping
 from typing import Any, Literal, Union, overload
 
 from flask import Flask, current_app
@@ -34,9 +34,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
         self,
         app_model: App,
         user: Union[Account, EndUser],
-        args: dict,
+        args: Mapping[str, Any],
         invoke_from: InvokeFrom,
-        stream: Literal[True] = True,
+        streaming: Literal[True],
     ) -> Generator[str, None, None]: ...
 
     @overload
@@ -44,14 +44,29 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
         self,
         app_model: App,
         user: Union[Account, EndUser],
-        args: dict,
+        args: Mapping[str, Any],
         invoke_from: InvokeFrom,
-        stream: Literal[False] = False,
-    ) -> dict: ...
+        streaming: Literal[False],
+    ) -> Mapping[str, Any]: ...
 
+    @overload
     def generate(
-        self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, streaming: bool = True
-    ) -> Union[dict, Generator[str, None, None]]:
+        self,
+        app_model: App,
+        user: Union[Account, EndUser],
+        args: Mapping[str, Any],
+        invoke_from: InvokeFrom,
+        streaming: bool,
+    ) -> Mapping[str, Any] | Generator[str, None, None]: ...
+
+    def generate(
+        self,
+        app_model: App,
+        user: Union[Account, EndUser],
+        args: Mapping[str, Any],
+        invoke_from: InvokeFrom,
+        streaming: bool = True,
+    ):
         """
         Generate App response.
 

+ 44 - 2
api/core/app/apps/workflow/app_generator.py

@@ -3,7 +3,7 @@ import logging
 import threading
 import uuid
 from collections.abc import Generator, Mapping, Sequence
-from typing import Any, Optional, Union
+from typing import Any, Literal, Optional, Union, overload
 
 from flask import Flask, current_app
 from pydantic import ValidationError
@@ -30,6 +30,35 @@ logger = logging.getLogger(__name__)
 
 
 class WorkflowAppGenerator(BaseAppGenerator):
+    @overload
+    def generate(
+        self,
+        *,
+        app_model: App,
+        workflow: Workflow,
+        user: Account | EndUser,
+        args: Mapping[str, Any],
+        invoke_from: InvokeFrom,
+        streaming: Literal[True],
+        call_depth: int = 0,
+        workflow_thread_pool_id: Optional[str] = None,
+    ) -> Generator[str, None, None]: ...
+
+    @overload
+    def generate(
+        self,
+        *,
+        app_model: App,
+        workflow: Workflow,
+        user: Account | EndUser,
+        args: Mapping[str, Any],
+        invoke_from: InvokeFrom,
+        streaming: Literal[False],
+        call_depth: int = 0,
+        workflow_thread_pool_id: Optional[str] = None,
+    ) -> Mapping[str, Any]: ...
+
+    @overload
     def generate(
         self,
         *,
@@ -41,7 +70,20 @@ class WorkflowAppGenerator(BaseAppGenerator):
         streaming: bool = True,
         call_depth: int = 0,
         workflow_thread_pool_id: Optional[str] = None,
-    ) -> Mapping[str, Any] | Generator[str, None, None]:
+    ) -> Mapping[str, Any] | Generator[str, None, None]: ...
+
+    def generate(
+        self,
+        *,
+        app_model: App,
+        workflow: Workflow,
+        user: Account | EndUser,
+        args: Mapping[str, Any],
+        invoke_from: InvokeFrom,
+        streaming: bool = True,
+        call_depth: int = 0,
+        workflow_thread_pool_id: Optional[str] = None,
+    ):
         files: Sequence[Mapping[str, Any]] = args.get("files") or []
 
         # parse files