John Wang 1 year ago
parent
commit
90150a6ca9
1 changed file with 38 additions and 33 deletions

api/core/completion.py (+38, -33)

@@ -39,7 +39,8 @@ class Completion:
             memory = cls.get_memory_from_conversation(
                 tenant_id=app.tenant_id,
                 app_model_config=app_model_config,
-                conversation=conversation
+                conversation=conversation,
+                return_messages=False
             )

             inputs = conversation.inputs
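
This hunk threads return_messages=False into the shared memory. Assuming ReadOnlyConversationTokenDBBufferSharedMemory follows LangChain's BaseChatMemory convention, the flag makes the buffer load as one flat transcript string instead of a list of BaseMessage objects, which is what lets get_history_messages_from_memory return str in the last hunk. A minimal sketch of that convention using stock ConversationBufferMemory (a stand-in, not the project's memory class):

    # Stand-in for ReadOnlyConversationTokenDBBufferSharedMemory: the
    # return_messages flag is a LangChain BaseChatMemory convention.
    from langchain.memory import ConversationBufferMemory

    memory = ConversationBufferMemory(return_messages=False)
    memory.save_context({"input": "Hi"}, {"output": "Hello!"})

    # With return_messages=False the history loads as one string,
    # e.g. "Human: Hi\nAI: Hello!", ready to splice into a text prompt.
    print(memory.load_memory_variables({})["history"])

    # With return_messages=True the same history would come back as a
    # list of HumanMessage/AIMessage objects for chat-style prompts.
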
@@ -119,7 +120,8 @@ class Completion:
         return response

     @classmethod
-    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict, chain_output: Optional[str],
+    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
+                            chain_output: Optional[str],
                             memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
             Union[str | List[BaseMessage]]:
         pre_prompt = PromptBuilder.process_template(pre_prompt) if pre_prompt else pre_prompt
@@ -161,11 +163,19 @@ And answer according to the language of the user's question.
                 "query": query
                 "query": query
             }
             }
 
 
-            human_message_prompt = "{query}"
+            human_message_prompt = ""
+
+            if pre_prompt:
+                pre_prompt_inputs = {k: inputs[k] for k in
+                                     OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
+                                     if k in inputs}
+
+                if pre_prompt_inputs:
+                    human_inputs.update(pre_prompt_inputs)

             if chain_output:
                 human_inputs['context'] = chain_output
-                human_message_instruction = """Use the following CONTEXT as your learned knowledge.
+                human_message_prompt += """Use the following CONTEXT as your learned knowledge.
 [CONTEXT]
 {context}
 [END CONTEXT]
@@ -176,39 +186,33 @@ When answer to user:
 Avoid mentioning that you obtained the information from the context.
 And answer according to the language of the user's question.
 """
-                if pre_prompt:
-                    extra_inputs = {k: inputs[k] for k in
-                                    OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
-                                    if k in inputs}
-                    if extra_inputs:
-                        human_inputs.update(extra_inputs)
-                    human_message_instruction += pre_prompt + "\n"
-
-                human_message_prompt = human_message_instruction + "Q:{query}\nA:"
-            else:
-                if pre_prompt:
-                    extra_inputs = {k: inputs[k] for k in
-                                    OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
-                                    if k in inputs}
-                    if extra_inputs:
-                        human_inputs.update(extra_inputs)
-                    human_message_prompt = pre_prompt + "\n" + human_message_prompt

-            # construct main prompt
-            human_message = PromptBuilder.to_human_message(
-                prompt_content=human_message_prompt,
-                inputs=human_inputs
-            )
+            if pre_prompt:
+                human_message_prompt += pre_prompt
+
+            query_prompt = "\nHuman: {query}\nAI: "

             if memory:
                 # append chat histories
-                tmp_messages = messages.copy() + [human_message]
-                curr_message_tokens = memory.llm.get_messages_tokens(tmp_messages)
-                rest_tokens = llm_constant.max_context_token_length[
-                                  memory.llm.model_name] - memory.llm.max_tokens - curr_message_tokens
+                tmp_human_message = PromptBuilder.to_human_message(
+                    prompt_content=human_message_prompt + query_prompt,
+                    inputs=human_inputs
+                )
+
+                curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
+                rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
+                              - memory.llm.max_tokens - curr_message_tokens
                 rest_tokens = max(rest_tokens, 0)
                 history_messages = cls.get_history_messages_from_memory(memory, rest_tokens)
-                messages += history_messages
+                human_message_prompt += "\n\n" + history_messages
+
+            human_message_prompt += query_prompt
+
+            # construct main prompt
+            human_message = PromptBuilder.to_human_message(
+                prompt_content=human_message_prompt,
+                inputs=human_inputs
+            )

             messages.append(human_message)

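The bulk of the commit collapses the two near-duplicate pre_prompt branches into one linear build: the CONTEXT instruction (only when chain_output exists), then pre_prompt, then token-budgeted chat history, then a closing "\nHuman: {query}\nAI: " turn, all rendered as a single human message. The history budget is the model's context window minus the reserved completion tokens minus the tokens already spent on the prompt. A condensed sketch of that flow, with a naive word-count tokenizer and hypothetical constants standing in for llm_constant.max_context_token_length and memory.llm.max_tokens:

    from typing import Optional

    MAX_CONTEXT_TOKENS = 4096    # stand-in for llm_constant.max_context_token_length[model]
    MAX_COMPLETION_TOKENS = 512  # stand-in for memory.llm.max_tokens

    def count_tokens(text: str) -> int:
        return len(text.split())  # naive stand-in for the real tokenizer

    def trim_history(history: str, token_budget: int) -> str:
        words = history.split()
        return " ".join(words[-token_budget:]) if token_budget else ""

    def build_prompt(pre_prompt: str, chain_output: Optional[str], history: str) -> str:
        prompt = ""
        if chain_output:
            # knowledge context comes first, fenced by [CONTEXT] markers;
            # {context} and {query} are filled later, as PromptBuilder.to_human_message
            # does in the real code
            prompt += ("Use the following CONTEXT as your learned knowledge.\n"
                       "[CONTEXT]\n{context}\n[END CONTEXT]\n")
        if pre_prompt:
            prompt += pre_prompt
        query_prompt = "\nHuman: {query}\nAI: "
        # history gets whatever the window leaves after the prompt so far
        # and the tokens reserved for the model's answer
        rest_tokens = max(MAX_CONTEXT_TOKENS - MAX_COMPLETION_TOKENS
                          - count_tokens(prompt + query_prompt), 0)
        if history:
            prompt += "\n\n" + trim_history(history, rest_tokens)
        return prompt + query_prompt

    print(build_prompt("You are a concise assistant.",
                       "The docs describe feature X.",
                       "Human: earlier question\nAI: earlier answer"))
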
@@ -216,7 +220,8 @@ And answer according to the language of the user's question.

     @classmethod
     def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
-                                 streaming: bool, conversation_message_task: ConversationMessageTask) -> CallbackManager:
+                                 streaming: bool,
+                                 conversation_message_task: ConversationMessageTask) -> CallbackManager:
         llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
         if streaming:
             callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
@@ -228,7 +233,7 @@ And answer according to the language of the user's question.
     @classmethod
     def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                          max_token_limit: int) -> \
-            List[BaseMessage]:
+            str:
         """Get memory messages."""
         """Get memory messages."""
         memory.max_token_limit = max_token_limit
         memory.max_token_limit = max_token_limit
         memory_key = memory.memory_variables[0]
         memory_key = memory.memory_variables[0]
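
The diff cuts off before the method's return, but with the signature now annotated -> str and the memory constructed with return_messages=False, the remainder plausibly reads the buffer through LangChain's standard load_memory_variables accessor. A hedged sketch of how the method likely finishes (the body below is an assumption, not the committed code):

    def get_history_messages_from_memory(memory, max_token_limit: int) -> str:
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]
        # load_memory_variables() is the standard LangChain accessor; with
        # return_messages=False it yields "Human: ...\nAI: ..." as one string
        return memory.load_memory_variables({})[memory_key]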