Selaa lähdekoodia

fix: universal chat when default model invalid (#905)

takatost 1 vuosi sitten
vanhempi
commit
1d9cc5ca05

+ 1 - 1
api/core/agent/agent/openai_function_call_summarize_mixin.py

@@ -14,7 +14,7 @@ from core.model_providers.models.llm.base import BaseLLM
 class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
     moving_summary_buffer: str = ""
     moving_summary_index: int = 0
-    summary_llm: BaseLanguageModel
+    summary_llm: BaseLanguageModel = None
     model_instance: BaseLLM
 
     class Config:

+ 1 - 1
api/core/agent/agent/structured_chat.py

@@ -52,7 +52,7 @@ Action:
 class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
     moving_summary_buffer: str = ""
     moving_summary_index: int = 0
-    summary_llm: BaseLanguageModel
+    summary_llm: BaseLanguageModel = None
     model_instance: BaseLLM
 
     class Config:

+ 1 - 1
api/core/agent/agent_executor.py

@@ -32,7 +32,7 @@ class AgentConfiguration(BaseModel):
     strategy: PlanningStrategy
     model_instance: BaseLLM
     tools: list[BaseTool]
-    summary_model_instance: BaseLLM
+    summary_model_instance: BaseLLM = None
     memory: Optional[BaseChatMemory] = None
     callbacks: Callbacks = None
     max_iterations: int = 6

+ 4 - 2
api/core/model_providers/model_factory.py

@@ -46,7 +46,8 @@ class ModelFactory:
                                   model_name: Optional[str] = None,
                                   model_kwargs: Optional[ModelKwargs] = None,
                                   streaming: bool = False,
-                                  callbacks: Callbacks = None) -> Optional[BaseLLM]:
+                                  callbacks: Callbacks = None,
+                                  deduct_quota: bool = True) -> Optional[BaseLLM]:
         """
         get text generation model.
 
@@ -56,6 +57,7 @@ class ModelFactory:
         :param model_kwargs:
         :param streaming:
         :param callbacks:
+        :param deduct_quota:
         :return:
         """
         is_default_model = False
@@ -95,7 +97,7 @@ class ModelFactory:
             else:
                 raise e
 
-        if is_default_model:
+        if is_default_model or not deduct_quota:
             model_instance.deduct_quota = False
 
         return model_instance

+ 25 - 19
api/core/orchestrator_rule_parser.py

@@ -17,12 +17,13 @@ from core.conversation_message_task import ConversationMessageTask
 from core.model_providers.error import ProviderTokenNotInitError
 from core.model_providers.model_factory import ModelFactory
 from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
+from core.model_providers.models.llm.base import BaseLLM
+from core.tool.current_datetime_tool import DatetimeTool
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 from core.tool.provider.serpapi_provider import SerpAPIToolProvider
 from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
 from core.tool.web_reader_tool import WebReaderTool
 from extensions.ext_database import db
-from libs import helper
 from models.dataset import Dataset, DatasetProcessRule
 from models.model import AppModelConfig
 
@@ -82,15 +83,19 @@ class OrchestratorRuleParser:
             try:
                 summary_model_instance = ModelFactory.get_text_generation_model(
                     tenant_id=self.tenant_id,
+                    model_provider_name=agent_provider_name,
+                    model_name=agent_model_name,
                     model_kwargs=ModelKwargs(
                         temperature=0,
                         max_tokens=500
-                    )
+                    ),
+                    deduct_quota=False
                 )
             except ProviderTokenNotInitError as e:
                 summary_model_instance = None
 
             tools = self.to_tools(
+                agent_model_instance=agent_model_instance,
                 tool_configs=tool_configs,
                 conversation_message_task=conversation_message_task,
                 rest_tokens=rest_tokens,
@@ -140,11 +145,12 @@ class OrchestratorRuleParser:
 
         return None
 
-    def to_tools(self, tool_configs: list, conversation_message_task: ConversationMessageTask,
+    def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list, conversation_message_task: ConversationMessageTask,
                  rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]:
         """
         Convert app agent tool configs to tools
 
+        :param agent_model_instance:
         :param rest_tokens:
         :param tool_configs: app agent tool configs
         :param conversation_message_task:
@@ -162,7 +168,7 @@ class OrchestratorRuleParser:
             if tool_type == "dataset":
                 tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens)
             elif tool_type == "web_reader":
-                tool = self.to_web_reader_tool()
+                tool = self.to_web_reader_tool(agent_model_instance)
             elif tool_type == "google_search":
                 tool = self.to_google_search_tool()
             elif tool_type == "wikipedia":
@@ -207,24 +213,28 @@ class OrchestratorRuleParser:
 
         return tool
 
-    def to_web_reader_tool(self) -> Optional[BaseTool]:
+    def to_web_reader_tool(self, agent_model_instance: BaseLLM) -> Optional[BaseTool]:
         """
         A tool for reading web pages
 
         :return:
         """
-        summary_model_instance = ModelFactory.get_text_generation_model(
-            tenant_id=self.tenant_id,
-            model_kwargs=ModelKwargs(
-                temperature=0,
-                max_tokens=500
+        try:
+            summary_model_instance = ModelFactory.get_text_generation_model(
+                tenant_id=self.tenant_id,
+                model_provider_name=agent_model_instance.model_provider.provider_name,
+                model_name=agent_model_instance.name,
+                model_kwargs=ModelKwargs(
+                    temperature=0,
+                    max_tokens=500
+                ),
+                deduct_quota=False
             )
-        )
-
-        summary_llm = summary_model_instance.client
+        except ProviderTokenNotInitError:
+            summary_model_instance = None
 
         tool = WebReaderTool(
-            llm=summary_llm,
+            llm=summary_model_instance.client if summary_model_instance else None,
             max_chunk_length=4000,
             continue_reading=True,
             callbacks=[DifyStdOutCallbackHandler()]
@@ -252,11 +262,7 @@ class OrchestratorRuleParser:
         return tool
 
     def to_current_datetime_tool(self) -> Optional[BaseTool]:
-        tool = Tool(
-            name="current_datetime",
-            description="A tool when you want to get the current date, time, week, month or year, "
-                        "and the time zone is UTC. Result is \"<date> <time> <timezone> <week>\".",
-            func=helper.get_current_datetime,
+        tool = DatetimeTool(
             callbacks=[DifyStdOutCallbackHandler()]
         )
 

+ 25 - 0
api/core/tool/current_datetime_tool.py

@@ -0,0 +1,25 @@
+from datetime import datetime
+from typing import Type
+
+from langchain.tools import BaseTool
+from pydantic import Field, BaseModel
+
+
+class DatetimeToolInput(BaseModel):
+    type: str = Field(..., description="Type for current time, must be: datetime.")
+
+
+class DatetimeTool(BaseTool):
+    """Tool for querying current datetime."""
+    name: str = "current_datetime"
+    args_schema: Type[BaseModel] = DatetimeToolInput
+    description: str = "A tool when you want to get the current date, time, week, month or year, " \
+                       "and the time zone is UTC. Result is \"<date> <time> <timezone> <week>\"."
+
+    def _run(self, type: str) -> str:
+        # get current time
+        current_time = datetime.utcnow()
+        return current_time.strftime("%Y-%m-%d %H:%M:%S UTC+0000 %A")
+
+    async def _arun(self, tool_input: str) -> str:
+        raise NotImplementedError()

+ 2 - 2
api/core/tool/web_reader_tool.py

@@ -65,7 +65,7 @@ class WebReaderTool(BaseTool):
     summary_chunk_overlap: int = 0
     summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
     continue_reading: bool = True
-    llm: BaseLanguageModel
+    llm: BaseLanguageModel = None
 
     def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
         try:
@@ -78,7 +78,7 @@ class WebReaderTool(BaseTool):
         except Exception as e:
             return f'Read this website failed, caused by: {str(e)}.'
 
-        if summary:
+        if summary and self.llm:
             character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                 chunk_size=self.summary_chunk_tokens,
                 chunk_overlap=self.summary_chunk_overlap,

+ 0 - 6
api/libs/helper.py

@@ -153,9 +153,3 @@ def get_remote_ip(request):
 def generate_text_hash(text: str) -> str:
     hash_text = str(text) + 'None'
     return sha256(hash_text.encode()).hexdigest()
-
-
-def get_current_datetime(type: str) -> str:
-    # get current time
-    current_time = datetime.utcnow()
-    return current_time.strftime("%Y-%m-%d %H:%M:%S UTC+0000 %A")