builtin_tool.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
from enum import Enum
from typing import List, Optional

from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.tools.model.tool_model_manager import ToolModelManager
from core.tools.tool.tool import Tool
from core.tools.utils.web_reader_tool import get_url
# System prompt used by BuiltinTool.summary() for each chunk sent to the LLM.
# NOTE: this text is transmitted to the model at runtime; it is reproduced
# verbatim (including its original wording) so model behavior is unchanged.
_SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language
and you can quickly aimed at the main point of an webpage and reproduce it in your own words but
retain the original meaning and keep the key points.
however, the text you got is too long, what you got is possible a part of the text.
Please summarize the text you got.
"""
  15. class BuiltinTool(Tool):
  16. """
  17. Builtin tool
  18. :param meta: the meta data of a tool call processing
  19. """
  20. def invoke_model(
  21. self, user_id: str, prompt_messages: List[PromptMessage], stop: List[str]
  22. ) -> LLMResult:
  23. """
  24. invoke model
  25. :param model_config: the model config
  26. :param prompt_messages: the prompt messages
  27. :param stop: the stop words
  28. :return: the model result
  29. """
  30. # invoke model
  31. return ToolModelManager.invoke(
  32. user_id=user_id,
  33. tenant_id=self.runtime.tenant_id,
  34. tool_type='builtin',
  35. tool_name=self.identity.name,
  36. prompt_messages=prompt_messages,
  37. )
  38. def get_max_tokens(self) -> int:
  39. """
  40. get max tokens
  41. :param model_config: the model config
  42. :return: the max tokens
  43. """
  44. return ToolModelManager.get_max_llm_context_tokens(
  45. tenant_id=self.runtime.tenant_id,
  46. )
  47. def get_prompt_tokens(self, prompt_messages: List[PromptMessage]) -> int:
  48. """
  49. get prompt tokens
  50. :param prompt_messages: the prompt messages
  51. :return: the tokens
  52. """
  53. return ToolModelManager.calculate_tokens(
  54. tenant_id=self.runtime.tenant_id,
  55. prompt_messages=prompt_messages
  56. )
  57. def summary(self, user_id: str, content: str) -> str:
  58. max_tokens = self.get_max_tokens()
  59. if self.get_prompt_tokens(prompt_messages=[
  60. UserPromptMessage(content=content)
  61. ]) < max_tokens * 0.6:
  62. return content
  63. def get_prompt_tokens(content: str) -> int:
  64. return self.get_prompt_tokens(prompt_messages=[
  65. SystemPromptMessage(content=_SUMMARY_PROMPT),
  66. UserPromptMessage(content=content)
  67. ])
  68. def summarize(content: str) -> str:
  69. summary = self.invoke_model(user_id=user_id, prompt_messages=[
  70. SystemPromptMessage(content=_SUMMARY_PROMPT),
  71. UserPromptMessage(content=content)
  72. ], stop=[])
  73. return summary.message.content
  74. lines = content.split('\n')
  75. new_lines = []
  76. # split long line into multiple lines
  77. for i in range(len(lines)):
  78. line = lines[i]
  79. if not line.strip():
  80. continue
  81. if len(line) < max_tokens * 0.5:
  82. new_lines.append(line)
  83. elif get_prompt_tokens(line) > max_tokens * 0.7:
  84. while get_prompt_tokens(line) > max_tokens * 0.7:
  85. new_lines.append(line[:int(max_tokens * 0.5)])
  86. line = line[int(max_tokens * 0.5):]
  87. new_lines.append(line)
  88. else:
  89. new_lines.append(line)
  90. # merge lines into messages with max tokens
  91. messages: List[str] = []
  92. for i in new_lines:
  93. if len(messages) == 0:
  94. messages.append(i)
  95. else:
  96. if len(messages[-1]) + len(i) < max_tokens * 0.5:
  97. messages[-1] += i
  98. if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7:
  99. messages.append(i)
  100. else:
  101. messages[-1] += i
  102. summaries = []
  103. for i in range(len(messages)):
  104. message = messages[i]
  105. summary = summarize(message)
  106. summaries.append(summary)
  107. result = '\n'.join(summaries)
  108. if self.get_prompt_tokens(prompt_messages=[
  109. UserPromptMessage(content=result)
  110. ]) > max_tokens * 0.7:
  111. return self.summary(user_id=user_id, content=result)
  112. return result
  113. def get_url(self, url: str, user_agent: str = None) -> str:
  114. """
  115. get url
  116. """
  117. return get_url(url, user_agent=user_agent)