|
@@ -17,12 +17,13 @@ from core.conversation_message_task import ConversationMessageTask
|
|
|
from core.model_providers.error import ProviderTokenNotInitError
|
|
|
from core.model_providers.model_factory import ModelFactory
|
|
|
from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
|
|
|
+from core.model_providers.models.llm.base import BaseLLM
|
|
|
+from core.tool.current_datetime_tool import DatetimeTool
|
|
|
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
|
|
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
|
|
|
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
|
|
|
from core.tool.web_reader_tool import WebReaderTool
|
|
|
from extensions.ext_database import db
|
|
|
-from libs import helper
|
|
|
from models.dataset import Dataset, DatasetProcessRule
|
|
|
from models.model import AppModelConfig
|
|
|
|
|
@@ -82,15 +83,19 @@ class OrchestratorRuleParser:
|
|
|
try:
|
|
|
summary_model_instance = ModelFactory.get_text_generation_model(
|
|
|
tenant_id=self.tenant_id,
|
|
|
+ model_provider_name=agent_provider_name,
|
|
|
+ model_name=agent_model_name,
|
|
|
model_kwargs=ModelKwargs(
|
|
|
temperature=0,
|
|
|
max_tokens=500
|
|
|
- )
|
|
|
+ ),
|
|
|
+ deduct_quota=False
|
|
|
)
|
|
|
except ProviderTokenNotInitError as e:
|
|
|
summary_model_instance = None
|
|
|
|
|
|
tools = self.to_tools(
|
|
|
+ agent_model_instance=agent_model_instance,
|
|
|
tool_configs=tool_configs,
|
|
|
conversation_message_task=conversation_message_task,
|
|
|
rest_tokens=rest_tokens,
|
|
@@ -140,11 +145,12 @@ class OrchestratorRuleParser:
|
|
|
|
|
|
return None
|
|
|
|
|
|
- def to_tools(self, tool_configs: list, conversation_message_task: ConversationMessageTask,
|
|
|
+ def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list, conversation_message_task: ConversationMessageTask,
|
|
|
rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]:
|
|
|
"""
|
|
|
Convert app agent tool configs to tools
|
|
|
|
|
|
+ :param agent_model_instance:
|
|
|
:param rest_tokens:
|
|
|
:param tool_configs: app agent tool configs
|
|
|
:param conversation_message_task:
|
|
@@ -162,7 +168,7 @@ class OrchestratorRuleParser:
|
|
|
if tool_type == "dataset":
|
|
|
tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens)
|
|
|
elif tool_type == "web_reader":
|
|
|
- tool = self.to_web_reader_tool()
|
|
|
+ tool = self.to_web_reader_tool(agent_model_instance)
|
|
|
elif tool_type == "google_search":
|
|
|
tool = self.to_google_search_tool()
|
|
|
elif tool_type == "wikipedia":
|
|
@@ -207,24 +213,28 @@ class OrchestratorRuleParser:
|
|
|
|
|
|
return tool
|
|
|
|
|
|
- def to_web_reader_tool(self) -> Optional[BaseTool]:
|
|
|
+ def to_web_reader_tool(self, agent_model_instance: BaseLLM) -> Optional[BaseTool]:
|
|
|
"""
|
|
|
A tool for reading web pages
|
|
|
|
|
|
:return:
|
|
|
"""
|
|
|
- summary_model_instance = ModelFactory.get_text_generation_model(
|
|
|
- tenant_id=self.tenant_id,
|
|
|
- model_kwargs=ModelKwargs(
|
|
|
- temperature=0,
|
|
|
- max_tokens=500
|
|
|
+ try:
|
|
|
+ summary_model_instance = ModelFactory.get_text_generation_model(
|
|
|
+ tenant_id=self.tenant_id,
|
|
|
+ model_provider_name=agent_model_instance.model_provider.provider_name,
|
|
|
+ model_name=agent_model_instance.name,
|
|
|
+ model_kwargs=ModelKwargs(
|
|
|
+ temperature=0,
|
|
|
+ max_tokens=500
|
|
|
+ ),
|
|
|
+ deduct_quota=False
|
|
|
)
|
|
|
- )
|
|
|
-
|
|
|
- summary_llm = summary_model_instance.client
|
|
|
+ except ProviderTokenNotInitError:
|
|
|
+ summary_model_instance = None
|
|
|
|
|
|
tool = WebReaderTool(
|
|
|
- llm=summary_llm,
|
|
|
+ llm=summary_model_instance.client if summary_model_instance else None,
|
|
|
max_chunk_length=4000,
|
|
|
continue_reading=True,
|
|
|
callbacks=[DifyStdOutCallbackHandler()]
|
|
@@ -252,11 +262,7 @@ class OrchestratorRuleParser:
|
|
|
return tool
|
|
|
|
|
|
def to_current_datetime_tool(self) -> Optional[BaseTool]:
|
|
|
- tool = Tool(
|
|
|
- name="current_datetime",
|
|
|
- description="A tool when you want to get the current date, time, week, month or year, "
|
|
|
- "and the time zone is UTC. Result is \"<date> <time> <timezone> <week>\".",
|
|
|
- func=helper.get_current_datetime,
|
|
|
+ tool = DatetimeTool(
|
|
|
callbacks=[DifyStdOutCallbackHandler()]
|
|
|
)
|
|
|
|