Bladeren bron

Fix/human in answer (#174)

John Wang 1 jaar geleden
bovenliggende
commit
4350bb9a00
1 gewijzigde bestanden met toevoegingen van 8 en 8 verwijderingen
  1. 8 8
      api/core/completion.py

+ 8 - 8
api/core/completion.py

@@ -1,4 +1,4 @@
-from typing import Optional, List, Union
+from typing import Optional, List, Union, Tuple
 
 from langchain.callbacks import CallbackManager
 from langchain.chat_models.base import BaseChatModel
@@ -97,7 +97,7 @@ class Completion:
         )
 
         # get llm prompt
-        prompt = cls.get_main_llm_prompt(
+        prompt, stop_words = cls.get_main_llm_prompt(
             mode=mode,
             llm=final_llm,
             pre_prompt=app_model_config.pre_prompt,
@@ -115,7 +115,7 @@ class Completion:
             mode=mode
         )
 
-        response = final_llm.generate([prompt])
+        response = final_llm.generate([prompt], stop_words)
 
         return response
 
@@ -123,7 +123,7 @@ class Completion:
     def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
                             chain_output: Optional[str],
                             memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
-            Union[str | List[BaseMessage]]:
+            Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
         # disable template string in query
         query_params = OutLinePromptTemplate.from_template(template=query).input_variables
         if query_params:
@@ -165,9 +165,9 @@ And answer according to the language of the user's question.
 
             if isinstance(llm, BaseChatModel):
                 # use chat llm as completion model
-                return [HumanMessage(content=prompt_content)]
+                return [HumanMessage(content=prompt_content)], None
             else:
-                return prompt_content
+                return prompt_content, None
         else:
             messages: List[BaseMessage] = []
 
@@ -236,7 +236,7 @@ And answer according to the language of the user's question.
 
             messages.append(human_message)
 
-            return messages
+            return messages, ['\nHuman:']
 
     @classmethod
     def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
@@ -323,7 +323,7 @@ And answer according to the language of the user's question.
         )
 
         # get llm prompt
-        original_prompt = cls.get_main_llm_prompt(
+        original_prompt, _ = cls.get_main_llm_prompt(
             mode="completion",
             llm=llm,
             pre_prompt=pre_prompt,