@@ -5,21 +5,40 @@ from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
     _format_intermediate_steps
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
+from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
+from langchain.memory.prompt import SUMMARY_PROMPT
 from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, SystemMessage
-from langchain.schema.language_model import BaseLanguageModel
+from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage, HumanMessage, BaseMessage, \
+    get_buffer_string
 from langchain.tools import BaseTool
+from pydantic import root_validator
 
-from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
-from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
+from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
+from core.chain.llm_chain import LLMChain
+from core.model_providers.models.entity.message import to_prompt_messages
+from core.model_providers.models.llm.base import BaseLLM
+from core.third_party.langchain.llms.fake import FakeLLM
 
 
-class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
+class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
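+    # Rolling-summary state: `moving_summary_buffer` holds the running summary
+    # text and `moving_summary_index` marks how many intermediate messages have
+    # already been folded into it (see summarize_messages_if_needed below).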
+    moving_summary_buffer: str = ""
+    moving_summary_index: int = 0
+    summary_model_instance: BaseLLM = None
+    model_instance: BaseLLM
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
+
+    @root_validator
+    def validate_llm(cls, values: dict) -> dict:
+        # No-op override of the parent's validator (which only accepts a
+        # ChatOpenAI llm); model calls go through model_instance instead.
+        return values
 
     @classmethod
     def from_llm_and_tools(
         cls,
-        llm: BaseLanguageModel,
+        model_instance: BaseLLM,
         tools: Sequence[BaseTool],
         callback_manager: Optional[BaseCallbackManager] = None,
         extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
@@ -28,12 +47,16 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
         ),
         **kwargs: Any,
     ) -> BaseSingleActionAgent:
-        return super().from_llm_and_tools(
-            llm=llm,
+        prompt = cls.create_prompt(
+            extra_prompt_messages=extra_prompt_messages,
+            system_message=system_message,
+        )
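+        # A FakeLLM placeholder presumably keeps the parent class's required
+        # `llm` field populated; actual completions go through `model_instance`.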
+        return cls(
+            model_instance=model_instance,
+            llm=FakeLLM(response=''),
+            prompt=prompt,
             tools=tools,
             callback_manager=callback_manager,
-            extra_prompt_messages=extra_prompt_messages,
-            system_message=cls.get_system_message(),
             **kwargs,
         )
@@ -44,23 +67,26 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
         :param query:
         :return:
         """
-        original_max_tokens = self.llm.max_tokens
-        self.llm.max_tokens = 40
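+        # Cap the completion at 40 tokens for this probe call: only whether the
+        # model answers with a function call matters, not the response text.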
+        original_max_tokens = self.model_instance.model_kwargs.max_tokens
+        self.model_instance.model_kwargs.max_tokens = 40
 
         prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
         messages = prompt.to_messages()
 
         try:
-            predicted_message = self.llm.predict_messages(
-                messages, functions=self.functions, callbacks=None
+            prompt_messages = to_prompt_messages(messages)
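+            # Run through the provider-agnostic model instance rather than
+            # calling the LangChain chat model directly.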
+            result = self.model_instance.run(
+                messages=prompt_messages,
+                functions=self.functions,
+                callbacks=None
             )
         except Exception as e:
             new_exception = self.model_instance.handle_exceptions(e)
             raise new_exception
 
-        function_call = predicted_message.additional_kwargs.get("function_call", {})
+        function_call = result.function_call
 
-        self.llm.max_tokens = original_max_tokens
+        self.model_instance.model_kwargs.max_tokens = original_max_tokens
 
         return True if function_call else False
@@ -93,10 +119,19 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
         except ExceededLLMTokensLimitError as e:
             return AgentFinish(return_values={"output": str(e)}, log=str(e))
 
-        predicted_message = self.llm.predict_messages(
-            messages, functions=self.functions, callbacks=callbacks
+        prompt_messages = to_prompt_messages(messages)
+        result = self.model_instance.run(
+            messages=prompt_messages,
+            functions=self.functions,
+        )
+
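+        # Re-wrap the provider result as a LangChain AIMessage so the stock
+        # _parse_ai_message helper can parse it into an agent decision.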
+        ai_message = AIMessage(
+            content=result.content,
+            additional_kwargs={
+                'function_call': result.function_call
+            }
         )
-        agent_decision = _parse_ai_message(predicted_message)
+        agent_decision = _parse_ai_message(ai_message)
 
         if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
             tool_inputs = agent_decision.tool_input
@@ -122,3 +157,142 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
             return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
         except ValueError:
             return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
+
+    def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
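+        # Compute the remaining token budget, with a 20-token safety margin for
+        # counting inaccuracies; if the prompt still fits, leave it untouched.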
+        rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
+        rest_tokens = rest_tokens - 20
+        if rest_tokens >= 0:
+            return messages
+
+        system_message = None
+        human_message = None
+        should_summary_messages = []
+        for message in messages:
+            if isinstance(message, SystemMessage):
+                system_message = message
+            elif isinstance(message, HumanMessage):
+                human_message = message
+            else:
+                should_summary_messages.append(message)
+
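+        # Keep the latest AI/function-message pair verbatim; everything older
+        # (since the last summarization point) gets folded into the summary.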
+        if len(should_summary_messages) > 2:
+            ai_message = should_summary_messages[-2]
+            function_message = should_summary_messages[-1]
+            should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
+            self.moving_summary_index = len(should_summary_messages)
+        else:
+            error_msg = "Exceeded LLM tokens limit, stopped."
+            raise ExceededLLMTokensLimitError(error_msg)
+
+        new_messages = [system_message, human_message]
+
+        if self.moving_summary_index == 0:
+            should_summary_messages.insert(0, human_message)
+
+        self.moving_summary_buffer = self.predict_new_summary(
+            messages=should_summary_messages,
+            existing_summary=self.moving_summary_buffer
+        )
+
+        new_messages.append(AIMessage(content=self.moving_summary_buffer))
+        new_messages.append(ai_message)
+        new_messages.append(function_message)
+
+        return new_messages
+
+    def predict_new_summary(
+        self, messages: List[BaseMessage], existing_summary: str
+    ) -> str:
+        new_lines = get_buffer_string(
+            messages,
+            human_prefix="Human",
+            ai_prefix="AI",
+        )
+
+        chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
+        return chain.predict(summary=existing_summary, new_lines=new_lines)
+
+    def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
+        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
+
+        Official documentation: https://github.com/openai/openai-cookbook/blob/
+        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
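+        # Azure deployments report model names like "gpt-35-turbo"; normalize
+        # them to "gpt-3.5-*" so tiktoken can resolve the encoding.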
+        if model_instance.model_provider.provider_name == 'azure_openai':
+            model = model_instance.base_model_name
+            model = model.replace("gpt-35", "gpt-3.5")
+        else:
+            model = model_instance.base_model_name
+
+        tiktoken_ = _import_tiktoken()
+        try:
+            encoding = tiktoken_.encoding_for_model(model)
+        except KeyError:
+            model = "cl100k_base"
+            encoding = tiktoken_.get_encoding(model)
+
+ if model.startswith("gpt-3.5-turbo"):
|
|
|
+
|
|
|
+ tokens_per_message = 4
|
|
|
+
|
|
|
+ tokens_per_name = -1
|
|
|
+ elif model.startswith("gpt-4"):
|
|
|
+ tokens_per_message = 3
|
|
|
+ tokens_per_name = 1
|
|
|
+ else:
|
|
|
+ raise NotImplementedError(
|
|
|
+ f"get_num_tokens_from_messages() is not presently implemented "
|
|
|
+ f"for model {model}."
|
|
|
+ "See https://github.com/openai/openai-python/blob/main/chatml.md for "
|
|
|
+ "information on how messages are converted to tokens."
|
|
|
+ )
|
|
|
+ num_tokens = 0
|
|
|
+ for m in messages:
|
|
|
+ message = _convert_message_to_dict(m)
|
|
|
+ num_tokens += tokens_per_message
|
|
|
+ for key, value in message.items():
|
|
|
+ if key == "function_call":
|
|
|
+ for f_key, f_value in value.items():
|
|
|
+ num_tokens += len(encoding.encode(f_key))
|
|
|
+ num_tokens += len(encoding.encode(f_value))
|
|
|
+ else:
|
|
|
+ num_tokens += len(encoding.encode(value))
|
|
|
+
|
|
|
+ if key == "name":
|
|
|
+ num_tokens += tokens_per_name
|
|
|
+
|
|
|
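+        # every reply is primed with <im_start>assistant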
+        num_tokens += 3
+
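+        # Function (tool) definitions are serialized into the prompt as well,
+        # so count their names, descriptions, and JSON-schema parameters too.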
+        if kwargs.get('functions'):
+            for function in kwargs.get('functions'):
+                num_tokens += len(encoding.encode('name'))
+                num_tokens += len(encoding.encode(function.get("name")))
+                num_tokens += len(encoding.encode('description'))
+                num_tokens += len(encoding.encode(function.get("description")))
+                parameters = function.get("parameters")
+                num_tokens += len(encoding.encode('parameters'))
+                if 'title' in parameters:
+                    num_tokens += len(encoding.encode('title'))
+                    num_tokens += len(encoding.encode(parameters.get("title")))
+                num_tokens += len(encoding.encode('type'))
+                num_tokens += len(encoding.encode(parameters.get("type")))
+                if 'properties' in parameters:
+                    num_tokens += len(encoding.encode('properties'))
+                    for key, value in parameters.get('properties').items():
+                        num_tokens += len(encoding.encode(key))
+                        for field_key, field_value in value.items():
+                            num_tokens += len(encoding.encode(field_key))
+                            if field_key == 'enum':
+                                for enum_field in field_value:
+                                    num_tokens += 3
+                                    num_tokens += len(encoding.encode(enum_field))
+                            else:
+                                num_tokens += len(encoding.encode(field_key))
+                                num_tokens += len(encoding.encode(str(field_value)))
+                if 'required' in parameters:
+                    num_tokens += len(encoding.encode('required'))
+                    for required_field in parameters['required']:
+                        num_tokens += 3
+                        num_tokens += len(encoding.encode(required_field))
+
+        return num_tokens