fix: prompt for baichuan text generation models (#1299)

takatost 1 year ago
parent
current commit
8480b0197b
1 changed file with 6 additions and 0 deletions
+ 6 - 0
api/core/model_providers/models/llm/baichuan_model.py

@@ -37,6 +37,12 @@ class BaichuanModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return self._client.generate([prompts], stop, callbacks)
 
+    def prompt_file_name(self, mode: str) -> str:
+        if mode == 'completion':
+            return 'baichuan_completion'
+        else:
+            return 'baichuan_chat'
+
     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """
         get num tokens of prompt messages.
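
The added `prompt_file_name` method can be sketched in isolation as a plain function to show the dispatch it performs: `'completion'` mode maps to the completion-style prompt file name, and any other mode (e.g. chat) falls back to the chat-style one. This is a standalone sketch of the branching in the diff above; how the returned name is resolved to an actual prompt file is part of the surrounding `BaseLLM` machinery and is not shown here.

```python
def prompt_file_name(mode: str) -> str:
    # Mirror the dispatch added in the commit: 'completion' mode selects
    # the completion prompt file; every other mode uses the chat one.
    if mode == 'completion':
        return 'baichuan_completion'
    else:
        return 'baichuan_chat'

print(prompt_file_name('completion'))  # baichuan_completion
print(prompt_file_name('chat'))        # baichuan_chat
```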