@@ -37,6 +37,12 @@ class BaichuanModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return self._client.generate([prompts], stop, callbacks)
 
+    def prompt_file_name(self, mode: str) -> str:
+        if mode == 'completion':
+            return 'baichuan_completion'
+        else:
+            return 'baichuan_chat'
+
     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """
         get num tokens of prompt messages.
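For quick verification, here is a minimal standalone sketch of the mode-to-template-name mapping the new `prompt_file_name` override encodes. The free-function form and the assertions are illustrative assumptions for testing the logic in isolation, not part of this diff:

```python
# Standalone sketch mirroring the override's logic outside the class so it
# can be exercised directly (assumption: callers pass 'completion' for
# completion mode and anything else for chat-style modes).
def prompt_file_name(mode: str) -> str:
    if mode == 'completion':
        return 'baichuan_completion'
    else:
        return 'baichuan_chat'

# Any mode other than 'completion' falls back to the chat template name.
assert prompt_file_name('completion') == 'baichuan_completion'
assert prompt_file_name('chat') == 'baichuan_chat'
```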