Browse Source

fix(api/services/workflow/workflow_converter.py): Add NoneType checkers & format file. (#7446)

-LAN- 8 months ago
parent
commit
5e42e90abc
1 changed file with 127 additions and 192 deletions
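The substance of the change is a set of guard clauses: values the converter previously assumed to be non-None (the app model config, the simple prompt template, the API-based-extension id and record) are now checked up front and either skipped or rejected with a ValueError, alongside a black-style reformat. A minimal, self-contained sketch of the guard-clause pattern, using a hypothetical stand-in dataclass rather than the real App ORM model:

from dataclasses import dataclass
from typing import Optional

@dataclass
class AppStandIn:
    # Stand-in for the real App model; only the field relevant to the
    # guard clause is modelled here.
    app_model_config: Optional[dict] = None

def convert_to_workflow_sketch(app_model: AppStandIn) -> dict:
    # Fail fast with a clear message instead of letting a None config
    # surface later as an AttributeError deep inside the conversion.
    if not app_model.app_model_config:
        raise ValueError("App model config is required")
    return {"graph": {"nodes": [], "edges": []}, "config": app_model.app_model_config}

Calling convert_to_workflow_sketch(AppStandIn()) raises immediately, which mirrors the new behaviour of convert_to_workflow in the first hunk below.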

+ 127 - 192
api/services/workflow/workflow_converter.py

@@ -32,12 +32,9 @@ class WorkflowConverter:
     App Convert to Workflow Mode
     """
 
-    def convert_to_workflow(self, app_model: App,
-                            account: Account,
-                            name: str,
-                            icon_type: str,
-                            icon: str,
-                            icon_background: str) -> App:
+    def convert_to_workflow(
+        self, app_model: App, account: Account, name: str, icon_type: str, icon: str, icon_background: str
+    ):
         """
         Convert app to workflow
 
@@ -56,18 +53,18 @@ class WorkflowConverter:
         :return: new App instance
         """
         # convert app model config
+        if not app_model.app_model_config:
+            raise ValueError("App model config is required")
+
         workflow = self.convert_app_model_config_to_workflow(
-            app_model=app_model,
-            app_model_config=app_model.app_model_config,
-            account_id=account.id
+            app_model=app_model, app_model_config=app_model.app_model_config, account_id=account.id
         )
 
         # create new app
         new_app = App()
         new_app.tenant_id = app_model.tenant_id
-        new_app.name = name if name else app_model.name + '(workflow)'
-        new_app.mode = AppMode.ADVANCED_CHAT.value \
-            if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value
+        new_app.name = name if name else app_model.name + "(workflow)"
+        new_app.mode = AppMode.ADVANCED_CHAT.value if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value
         new_app.icon_type = icon_type if icon_type else app_model.icon_type
         new_app.icon = icon if icon else app_model.icon
         new_app.icon_background = icon_background if icon_background else app_model.icon_background
@@ -88,30 +85,21 @@ class WorkflowConverter:
 
         return new_app
 
-    def convert_app_model_config_to_workflow(self, app_model: App,
-                                             app_model_config: AppModelConfig,
-                                             account_id: str) -> Workflow:
+    def convert_app_model_config_to_workflow(self, app_model: App, app_model_config: AppModelConfig, account_id: str):
         """
         Convert app model config to workflow mode
         :param app_model: App instance
         :param app_model_config: AppModelConfig instance
         :param account_id: Account ID
-        :return:
         """
         # get new app mode
         new_app_mode = self._get_new_app_mode(app_model)
 
         # convert app model config
-        app_config = self._convert_to_app_config(
-            app_model=app_model,
-            app_model_config=app_model_config
-        )
+        app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config)
 
         # init workflow graph
-        graph = {
-            "nodes": [],
-            "edges": []
-        }
+        graph = {"nodes": [], "edges": []}
 
         # Convert list:
         # - variables -> start
@@ -123,11 +111,9 @@ class WorkflowConverter:
         # - show_retrieve_source -> knowledge-retrieval
 
         # convert to start node
-        start_node = self._convert_to_start_node(
-            variables=app_config.variables
-        )
+        start_node = self._convert_to_start_node(variables=app_config.variables)
 
-        graph['nodes'].append(start_node)
+        graph["nodes"].append(start_node)
 
         # convert to http request node
         external_data_variable_node_mapping = {}
@@ -135,7 +121,7 @@ class WorkflowConverter:
             http_request_nodes, external_data_variable_node_mapping = self._convert_to_http_request_node(
                 app_model=app_model,
                 variables=app_config.variables,
-                external_data_variables=app_config.external_data_variables
+                external_data_variables=app_config.external_data_variables,
             )
 
             for http_request_node in http_request_nodes:
@@ -144,9 +130,7 @@ class WorkflowConverter:
         # convert to knowledge retrieval node
         if app_config.dataset:
             knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node(
-                new_app_mode=new_app_mode,
-                dataset_config=app_config.dataset,
-                model_config=app_config.model
+                new_app_mode=new_app_mode, dataset_config=app_config.dataset, model_config=app_config.model
             )
 
             if knowledge_retrieval_node:
@@ -160,7 +144,7 @@ class WorkflowConverter:
             model_config=app_config.model,
             prompt_template=app_config.prompt_template,
             file_upload=app_config.additional_features.file_upload,
-            external_data_variable_node_mapping=external_data_variable_node_mapping
+            external_data_variable_node_mapping=external_data_variable_node_mapping,
         )
 
         graph = self._append_node(graph, llm_node)
@@ -199,7 +183,7 @@ class WorkflowConverter:
             tenant_id=app_model.tenant_id,
             app_id=app_model.id,
             type=WorkflowType.from_app_mode(new_app_mode).value,
-            version='draft',
+            version="draft",
             graph=json.dumps(graph),
             features=json.dumps(features),
             created_by=account_id,
@@ -212,24 +196,18 @@ class WorkflowConverter:
 
         return workflow
 
-    def _convert_to_app_config(self, app_model: App,
-                               app_model_config: AppModelConfig) -> EasyUIBasedAppConfig:
+    def _convert_to_app_config(self, app_model: App, app_model_config: AppModelConfig) -> EasyUIBasedAppConfig:
         app_mode = AppMode.value_of(app_model.mode)
         if app_mode == AppMode.AGENT_CHAT or app_model.is_agent:
             app_model.mode = AppMode.AGENT_CHAT.value
             app_config = AgentChatAppConfigManager.get_app_config(
-                app_model=app_model,
-                app_model_config=app_model_config
+                app_model=app_model, app_model_config=app_model_config
             )
         elif app_mode == AppMode.CHAT:
-            app_config = ChatAppConfigManager.get_app_config(
-                app_model=app_model,
-                app_model_config=app_model_config
-            )
+            app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config)
         elif app_mode == AppMode.COMPLETION:
             app_config = CompletionAppConfigManager.get_app_config(
-                app_model=app_model,
-                app_model_config=app_model_config
+                app_model=app_model, app_model_config=app_model_config
             )
         else:
             raise ValueError("Invalid app mode")
@@ -248,14 +226,13 @@ class WorkflowConverter:
             "data": {
                 "title": "START",
                 "type": NodeType.START.value,
-                "variables": [jsonable_encoder(v) for v in variables]
-            }
+                "variables": [jsonable_encoder(v) for v in variables],
+            },
         }
 
-    def _convert_to_http_request_node(self, app_model: App,
-                                      variables: list[VariableEntity],
-                                      external_data_variables: list[ExternalDataVariableEntity]) \
-            -> tuple[list[dict], dict[str, str]]:
+    def _convert_to_http_request_node(
+        self, app_model: App, variables: list[VariableEntity], external_data_variables: list[ExternalDataVariableEntity]
+    ) -> tuple[list[dict], dict[str, str]]:
         """
         Convert API Based Extension to HTTP Request Node
         :param app_model: App instance
@@ -277,40 +254,33 @@ class WorkflowConverter:
 
             # get params from config
             api_based_extension_id = tool_config.get("api_based_extension_id")
+            if not api_based_extension_id:
+                continue
 
             # get api_based_extension
             api_based_extension = self._get_api_based_extension(
-                tenant_id=tenant_id,
-                api_based_extension_id=api_based_extension_id
+                tenant_id=tenant_id, api_based_extension_id=api_based_extension_id
             )
 
-            if not api_based_extension:
-                raise ValueError("[External data tool] API query failed, variable: {}, "
-                                 "error: api_based_extension_id is invalid"
-                                 .format(tool_variable))
-
             # decrypt api_key
-            api_key = encrypter.decrypt_token(
-                tenant_id=tenant_id,
-                token=api_based_extension.api_key
-            )
+            api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=api_based_extension.api_key)
 
             inputs = {}
             for v in variables:
-                inputs[v.variable] = '{{#start.' + v.variable + '#}}'
+                inputs[v.variable] = "{{#start." + v.variable + "#}}"
 
             request_body = {
-                'point': APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value,
-                'params': {
-                    'app_id': app_model.id,
-                    'tool_variable': tool_variable,
-                    'inputs': inputs,
-                    'query': '{{#sys.query#}}' if app_model.mode == AppMode.CHAT.value else ''
-                }
+                "point": APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value,
+                "params": {
+                    "app_id": app_model.id,
+                    "tool_variable": tool_variable,
+                    "inputs": inputs,
+                    "query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT.value else "",
+                },
             }
 
             request_body_json = json.dumps(request_body)
-            request_body_json = request_body_json.replace(r'\{\{', '{{').replace(r'\}\}', '}}')
+            request_body_json = request_body_json.replace(r"\{\{", "{{").replace(r"\}\}", "}}")
 
             http_request_node = {
                 "id": f"http_request_{index}",
@@ -320,20 +290,11 @@ class WorkflowConverter:
                     "type": NodeType.HTTP_REQUEST.value,
                     "method": "post",
                     "url": api_based_extension.api_endpoint,
-                    "authorization": {
-                        "type": "api-key",
-                        "config": {
-                            "type": "bearer",
-                            "api_key": api_key
-                        }
-                    },
+                    "authorization": {"type": "api-key", "config": {"type": "bearer", "api_key": api_key}},
                     "headers": "",
                     "params": "",
-                    "body": {
-                        "type": "json",
-                        "data": request_body_json
-                    }
-                }
+                    "body": {"type": "json", "data": request_body_json},
+                },
             }
 
             nodes.append(http_request_node)
@@ -345,32 +306,24 @@ class WorkflowConverter:
                 "data": {
                     "title": f"Parse {api_based_extension.name} Response",
                     "type": NodeType.CODE.value,
-                    "variables": [{
-                        "variable": "response_json",
-                        "value_selector": [http_request_node['id'], "body"]
-                    }],
+                    "variables": [{"variable": "response_json", "value_selector": [http_request_node["id"], "body"]}],
                     "code_language": "python3",
                     "code": "import json\n\ndef main(response_json: str) -> str:\n    response_body = json.loads("
-                            "response_json)\n    return {\n        \"result\": response_body[\"result\"]\n    }",
-                    "outputs": {
-                        "result": {
-                            "type": "string"
-                        }
-                    }
-                }
+                    'response_json)\n    return {\n        "result": response_body["result"]\n    }',
+                    "outputs": {"result": {"type": "string"}},
+                },
             }
 
             nodes.append(code_node)
 
-            external_data_variable_node_mapping[external_data_variable.variable] = code_node['id']
+            external_data_variable_node_mapping[external_data_variable.variable] = code_node["id"]
             index += 1
 
         return nodes, external_data_variable_node_mapping
 
-    def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode,
-                                             dataset_config: DatasetEntity,
-                                             model_config: ModelConfigEntity) \
-            -> Optional[dict]:
+    def _convert_to_knowledge_retrieval_node(
+        self, new_app_mode: AppMode, dataset_config: DatasetEntity, model_config: ModelConfigEntity
+    ) -> Optional[dict]:
         """
         Convert datasets to Knowledge Retrieval Node
         :param new_app_mode: new app mode
@@ -404,7 +357,7 @@ class WorkflowConverter:
                         "completion_params": {
                             **model_config.parameters,
                             "stop": model_config.stop,
-                        }
+                        },
                     }
                 }
                 if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
@@ -412,20 +365,23 @@ class WorkflowConverter:
                 "multiple_retrieval_config": {
                     "top_k": retrieve_config.top_k,
                     "score_threshold": retrieve_config.score_threshold,
-                    "reranking_model": retrieve_config.reranking_model
+                    "reranking_model": retrieve_config.reranking_model,
                 }
                 if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE
                 else None,
-            }
+            },
         }
 
-    def _convert_to_llm_node(self, original_app_mode: AppMode,
-                             new_app_mode: AppMode,
-                             graph: dict,
-                             model_config: ModelConfigEntity,
-                             prompt_template: PromptTemplateEntity,
-                             file_upload: Optional[FileExtraConfig] = None,
-                             external_data_variable_node_mapping: dict[str, str] = None) -> dict:
+    def _convert_to_llm_node(
+        self,
+        original_app_mode: AppMode,
+        new_app_mode: AppMode,
+        graph: dict,
+        model_config: ModelConfigEntity,
+        prompt_template: PromptTemplateEntity,
+        file_upload: Optional[FileExtraConfig] = None,
+        external_data_variable_node_mapping: dict[str, str] | None = None,
+    ) -> dict:
         """
         Convert to LLM Node
         :param original_app_mode: original app mode
@@ -437,17 +393,18 @@ class WorkflowConverter:
         :param external_data_variable_node_mapping: external data variable node mapping
         """
         # fetch start and knowledge retrieval node
-        start_node = next(filter(lambda n: n['data']['type'] == NodeType.START.value, graph['nodes']))
-        knowledge_retrieval_node = next(filter(
-            lambda n: n['data']['type'] == NodeType.KNOWLEDGE_RETRIEVAL.value,
-            graph['nodes']
-        ), None)
+        start_node = next(filter(lambda n: n["data"]["type"] == NodeType.START.value, graph["nodes"]))
+        knowledge_retrieval_node = next(
+            filter(lambda n: n["data"]["type"] == NodeType.KNOWLEDGE_RETRIEVAL.value, graph["nodes"]), None
+        )
 
         role_prefix = None
 
         # Chat Model
         if model_config.mode == LLMMode.CHAT.value:
             if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
+                if not prompt_template.simple_prompt_template:
+                    raise ValueError("Simple prompt template is required")
                 # get prompt template
                 prompt_transform = SimplePromptTransform()
                 prompt_template_config = prompt_transform.get_prompt_template(
@@ -456,45 +413,35 @@ class WorkflowConverter:
                     model=model_config.model,
                     pre_prompt=prompt_template.simple_prompt_template,
                     has_context=knowledge_retrieval_node is not None,
-                    query_in_prompt=False
+                    query_in_prompt=False,
                 )
 
-                template = prompt_template_config['prompt_template'].template
+                template = prompt_template_config["prompt_template"].template
                 if not template:
                     prompts = []
                 else:
                     template = self._replace_template_variables(
-                        template,
-                        start_node['data']['variables'],
-                        external_data_variable_node_mapping
+                        template, start_node["data"]["variables"], external_data_variable_node_mapping
                     )
 
-                    prompts = [
-                        {
-                            "role": 'user',
-                            "text": template
-                        }
-                    ]
+                    prompts = [{"role": "user", "text": template}]
             else:
                 advanced_chat_prompt_template = prompt_template.advanced_chat_prompt_template
 
                 prompts = []
-                for m in advanced_chat_prompt_template.messages:
-                    if advanced_chat_prompt_template:
+                if advanced_chat_prompt_template:
+                    for m in advanced_chat_prompt_template.messages:
                         text = m.text
                         text = self._replace_template_variables(
-                            text,
-                            start_node['data']['variables'],
-                            external_data_variable_node_mapping
+                            text, start_node["data"]["variables"], external_data_variable_node_mapping
                         )
 
-                        prompts.append({
-                            "role": m.role.value,
-                            "text": text
-                        })
+                        prompts.append({"role": m.role.value, "text": text})
         # Completion Model
         else:
             if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
+                if not prompt_template.simple_prompt_template:
+                    raise ValueError("Simple prompt template is required")
                 # get prompt template
                 prompt_transform = SimplePromptTransform()
                 prompt_template_config = prompt_transform.get_prompt_template(
@@ -503,57 +450,50 @@ class WorkflowConverter:
                     model=model_config.model,
                     pre_prompt=prompt_template.simple_prompt_template,
                     has_context=knowledge_retrieval_node is not None,
-                    query_in_prompt=False
+                    query_in_prompt=False,
                 )
 
-                template = prompt_template_config['prompt_template'].template
+                template = prompt_template_config["prompt_template"].template
                 template = self._replace_template_variables(
-                    template,
-                    start_node['data']['variables'],
-                    external_data_variable_node_mapping
+                    template=template,
+                    variables=start_node["data"]["variables"],
+                    external_data_variable_node_mapping=external_data_variable_node_mapping,
                 )
 
-                prompts = {
-                    "text": template
-                }
+                prompts = {"text": template}
 
-                prompt_rules = prompt_template_config['prompt_rules']
+                prompt_rules = prompt_template_config["prompt_rules"]
                 role_prefix = {
-                    "user": prompt_rules.get('human_prefix', 'Human'),
-                    "assistant": prompt_rules.get('assistant_prefix', 'Assistant')
+                    "user": prompt_rules.get("human_prefix", "Human"),
+                    "assistant": prompt_rules.get("assistant_prefix", "Assistant"),
                 }
             else:
                 advanced_completion_prompt_template = prompt_template.advanced_completion_prompt_template
                 if advanced_completion_prompt_template:
                     text = advanced_completion_prompt_template.prompt
                     text = self._replace_template_variables(
-                        text,
-                        start_node['data']['variables'],
-                        external_data_variable_node_mapping
+                        template=text,
+                        variables=start_node["data"]["variables"],
+                        external_data_variable_node_mapping=external_data_variable_node_mapping,
                     )
                 else:
                     text = ""
 
-                text = text.replace('{{#query#}}', '{{#sys.query#}}')
+                text = text.replace("{{#query#}}", "{{#sys.query#}}")
 
                 prompts = {
                     "text": text,
                 }
 
-                if advanced_completion_prompt_template.role_prefix:
+                if advanced_completion_prompt_template and advanced_completion_prompt_template.role_prefix:
                     role_prefix = {
                         "user": advanced_completion_prompt_template.role_prefix.user,
-                        "assistant": advanced_completion_prompt_template.role_prefix.assistant
+                        "assistant": advanced_completion_prompt_template.role_prefix.assistant,
                     }
 
         memory = None
         if new_app_mode == AppMode.ADVANCED_CHAT:
-            memory = {
-                "role_prefix": role_prefix,
-                "window": {
-                    "enabled": False
-                }
-            }
+            memory = {"role_prefix": role_prefix, "window": {"enabled": False}}
 
         completion_params = model_config.parameters
         completion_params.update({"stop": model_config.stop})
@@ -567,28 +507,29 @@ class WorkflowConverter:
                     "provider": model_config.provider,
                     "name": model_config.model,
                     "mode": model_config.mode,
-                    "completion_params": completion_params
+                    "completion_params": completion_params,
                 },
                 "prompt_template": prompts,
                 "memory": memory,
                 "context": {
                     "enabled": knowledge_retrieval_node is not None,
                     "variable_selector": ["knowledge_retrieval", "result"]
-                    if knowledge_retrieval_node is not None else None
+                    if knowledge_retrieval_node is not None
+                    else None,
                 },
                 "vision": {
                     "enabled": file_upload is not None,
                     "variable_selector": ["sys", "files"] if file_upload is not None else None,
-                    "configs": {
-                        "detail": file_upload.image_config['detail']
-                    } if file_upload is not None else None
-                }
-            }
+                    "configs": {"detail": file_upload.image_config["detail"]}
+                    if file_upload is not None and file_upload.image_config is not None
+                    else None,
+                },
+            },
         }
 
-    def _replace_template_variables(self, template: str,
-                                    variables: list[dict],
-                                    external_data_variable_node_mapping: dict[str, str] = None) -> str:
+    def _replace_template_variables(
+        self, template: str, variables: list[dict], external_data_variable_node_mapping: dict[str, str] | None = None
+    ) -> str:
         """
         Replace Template Variables
         :param template: template
@@ -597,12 +538,11 @@ class WorkflowConverter:
         :return:
         """
         for v in variables:
-            template = template.replace('{{' + v['variable'] + '}}', '{{#start.' + v['variable'] + '#}}')
+            template = template.replace("{{" + v["variable"] + "}}", "{{#start." + v["variable"] + "#}}")
 
         if external_data_variable_node_mapping:
             for variable, code_node_id in external_data_variable_node_mapping.items():
-                template = template.replace('{{' + variable + '}}',
-                                            '{{#' + code_node_id + '.result#}}')
+                template = template.replace("{{" + variable + "}}", "{{#" + code_node_id + ".result#}}")
 
         return template
 
@@ -618,11 +558,8 @@ class WorkflowConverter:
             "data": {
                 "title": "END",
                 "type": NodeType.END.value,
-                "outputs": [{
-                    "variable": "result",
-                    "value_selector": ["llm", "text"]
-                }]
-            }
+                "outputs": [{"variable": "result", "value_selector": ["llm", "text"]}],
+            },
         }
 
     def _convert_to_answer_node(self) -> dict:
@@ -634,11 +571,7 @@ class WorkflowConverter:
         return {
             "id": "answer",
             "position": None,
-            "data": {
-                "title": "ANSWER",
-                "type": NodeType.ANSWER.value,
-                "answer": "{{#llm.text#}}"
-            }
+            "data": {"title": "ANSWER", "type": NodeType.ANSWER.value, "answer": "{{#llm.text#}}"},
         }
 
     def _create_edge(self, source: str, target: str) -> dict:
@@ -648,11 +581,7 @@ class WorkflowConverter:
         :param target: target node id
         :return:
         """
-        return {
-            "id": f"{source}-{target}",
-            "source": source,
-            "target": target
-        }
+        return {"id": f"{source}-{target}", "source": source, "target": target}
 
     def _append_node(self, graph: dict, node: dict) -> dict:
         """
@@ -662,9 +591,9 @@ class WorkflowConverter:
         :param node: Node to append
         :return:
         """
-        previous_node = graph['nodes'][-1]
-        graph['nodes'].append(node)
-        graph['edges'].append(self._create_edge(previous_node['id'], node['id']))
+        previous_node = graph["nodes"][-1]
+        graph["nodes"].append(node)
+        graph["edges"].append(self._create_edge(previous_node["id"], node["id"]))
         return graph
 
     def _get_new_app_mode(self, app_model: App) -> AppMode:
@@ -678,14 +607,20 @@ class WorkflowConverter:
         else:
             return AppMode.ADVANCED_CHAT
 
-    def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
+    def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str):
         """
         Get API Based Extension
         :param tenant_id: tenant id
         :param api_based_extension_id: api based extension id
         :return:
         """
-        return db.session.query(APIBasedExtension).filter(
-            APIBasedExtension.tenant_id == tenant_id,
-            APIBasedExtension.id == api_based_extension_id
-        ).first()
+        api_based_extension = (
+            db.session.query(APIBasedExtension)
+            .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
+            .first()
+        )
+
+        if not api_based_extension:
+            raise ValueError(f"API Based Extension not found, id: {api_based_extension_id}")
+
+        return api_based_extension
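The last hunk moves the missing-record check out of the caller and into _get_api_based_extension itself, so the helper raises instead of returning None and its call sites no longer need their own checks. A rough sketch of that shape, with a hypothetical in-memory dict standing in for the db.session query:

# Hypothetical in-memory store standing in for the APIBasedExtension table.
_EXTENSIONS: dict[str, dict] = {
    "ext-1": {"id": "ext-1", "name": "demo", "api_endpoint": "https://example.com/hook"},
}

def get_api_based_extension_sketch(tenant_id: str, api_based_extension_id: str) -> dict:
    # tenant_id is kept only to mirror the real signature; the lookup here
    # is by id alone. Raising on a missing row keeps the return type
    # non-Optional, so callers can use the result directly.
    record = _EXTENSIONS.get(api_based_extension_id)
    if not record:
        raise ValueError(f"API Based Extension not found, id: {api_based_extension_id}")
    return record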