Browse Source

feat:add think tag display for xinference deepseek r1 (#13291)

呆萌闷油瓶 2 months ago
parent
commit
f7e7a399d9
1 changed files with 16 additions and 6 deletions
  1. 16 6
      api/core/model_runtime/model_providers/xinference/llm/llm.py

+ 16 - 6
api/core/model_runtime/model_providers/xinference/llm/llm.py

@@ -1,3 +1,4 @@
+import re
 from collections.abc import Generator, Iterator
 from typing import Optional, cast
 
@@ -635,16 +636,16 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         handle stream chat generate response
         """
         full_response = ""
-
+        is_reasoning_started_tag = False
         for chunk in resp:
             if len(chunk.choices) == 0:
                 continue
-
             delta = chunk.choices[0]
-
             if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ""):
                 continue
-
+            delta_content = delta.delta.content
+            if not delta_content:
+                delta_content = ""
             # check if there is a tool call in the response
             function_call = None
             tool_calls = []
@@ -657,9 +658,18 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             if function_call:
                 assistant_message_tool_calls += [self._extract_response_function_call(function_call)]
 
+            if not is_reasoning_started_tag and "<think>" in delta_content:
+                is_reasoning_started_tag = True
+                delta_content = "> 💭 " + delta_content.replace("<think>", "")
+            elif is_reasoning_started_tag and "</think>" in delta_content:
+                delta_content = delta_content.replace("</think>", "") + "\n\n"
+                is_reasoning_started_tag = False
+            elif is_reasoning_started_tag:
+                if "\n" in delta_content:
+                    delta_content = re.sub(r"\n(?!(>|\n))", "\n> ", delta_content)
             # transform assistant message to prompt message
             assistant_prompt_message = AssistantPromptMessage(
-                content=delta.delta.content or "", tool_calls=assistant_message_tool_calls
+                content=delta_content or "", tool_calls=assistant_message_tool_calls
             )
 
             if delta.finish_reason is not None:
@@ -697,7 +707,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                     ),
                 )
 
-                full_response += delta.delta.content
+                full_response += delta_content
 
     def _handle_completion_generate_response(
         self,