Kaynağa Gözat

feat: openai_api_compatible support config stream_mode_delimiter (#2190)

Co-authored-by: wanggang <wanggy01@servyou.com.cn>
Co-authored-by: Chenhe Gu <guchenhe@gmail.com>
geosmart 1 yıl önce
ebeveyn
işleme
21450b8a51

+ 23 - 13
api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py

@@ -224,7 +224,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
         else:
             raise ValueError(f"Unknown completion type {credentials['completion_type']}")
-    
+
         return entity
 
     # validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard.
@@ -343,32 +343,44 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 )
             )
 
-        for chunk in response.iter_lines(decode_unicode=True, delimiter='\n\n'):
+        # delimiter for stream response, need unicode_escape
+        import codecs
+        delimiter = credentials.get("stream_mode_delimiter", "\n\n")
+        delimiter = codecs.decode(delimiter, "unicode_escape")
+
+        for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
             if chunk:
                 decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
-
                 chunk_json = None
                 try:
                     chunk_json = json.loads(decoded_chunk)
                 # stream ended
                 except json.JSONDecodeError as e:
+                    logger.error(f"decoded_chunk error,delimiter={delimiter},decoded_chunk={decoded_chunk}")
                     yield create_final_llm_result_chunk(
                         index=chunk_index + 1,
                         message=AssistantPromptMessage(content=""),
                         finish_reason="Non-JSON encountered."
                     )
                     break
-
                 if not chunk_json or len(chunk_json['choices']) == 0:
                     continue
 
                 choice = chunk_json['choices'][0]
+                finish_reason = chunk_json['choices'][0].get('finish_reason')
                 chunk_index += 1
 
                 if 'delta' in choice:
                     delta = choice['delta']
                     if delta.get('content') is None or delta.get('content') == '':
-                        continue
+                        if finish_reason is not None:
+                            yield create_final_llm_result_chunk(
+                                index=chunk_index,
+                                message=AssistantPromptMessage(content=choice.get('text', '')),
+                                finish_reason=finish_reason
+                            )
+                        else:
+                            continue
 
                     assistant_message_tool_calls = delta.get('tool_calls', None)
                     # assistant_message_function_call = delta.delta.function_call
@@ -387,24 +399,22 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
 
                     full_assistant_content += delta.get('content', '')
                 elif 'text' in choice:
-                    if choice.get('text') is None or choice.get('text') == '':
+                    choice_text = choice.get('text', '')
+                    if choice_text == '':
                         continue
 
                     # transform assistant message to prompt message
-                    assistant_prompt_message = AssistantPromptMessage(
-                        content=choice.get('text', '')
-                    )
-
-                    full_assistant_content += choice.get('text', '')
+                    assistant_prompt_message = AssistantPromptMessage(content=choice_text)
+                    full_assistant_content += choice_text
                 else:
                     continue
 
                 # check payload indicator for completion
-                if chunk_json['choices'][0].get('finish_reason') is not None:
+                if finish_reason is not None:
                     yield create_final_llm_result_chunk(
                         index=chunk_index,
                         message=assistant_prompt_message,
-                        finish_reason=chunk_json['choices'][0]['finish_reason']
+                        finish_reason=finish_reason
                     )
                 else:
                     yield LLMResultChunk(

+ 9 - 0
api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml

@@ -75,3 +75,12 @@ model_credential_schema:
           value: llm
       default: '4096'
       type: text-input
+    - variable: stream_mode_delimiter
+      label:
+        zh_Hans: 流模式返回结果的分隔符
+        en_US: Delimiter for streaming results
+      show_on:
+        - variable: __model_type
+          value: llm
+      default: '\n\n'
+      type: text-input

+ 55 - 13
api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py

@@ -12,6 +12,7 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI
 Using Together.ai's OpenAI-compatible API as testing endpoint
 """
 
+
 def test_validate_credentials():
     model = OAIAPICompatLargeLanguageModel()
 
@@ -34,6 +35,7 @@ def test_validate_credentials():
         }
     )
 
+
 def test_invoke_model():
     model = OAIAPICompatLargeLanguageModel()
 
@@ -65,9 +67,47 @@ def test_invoke_model():
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
 
+
 def test_invoke_stream_model():
     model = OAIAPICompatLargeLanguageModel()
 
+    response = model.invoke(
+        model='mistralai/Mixtral-8x7B-Instruct-v0.1',
+        credentials={
+            'api_key': os.environ.get('TOGETHER_API_KEY'),
+            'endpoint_url': 'https://api.together.xyz/v1/',
+            'mode': 'chat',
+            'stream_mode_delimiter': '\\n\\n'
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Who are you?'
+            )
+        ],
+        model_parameters={
+            'temperature': 1.0,
+            'top_k': 2,
+            'top_p': 0.5,
+        },
+        stop=['How'],
+        stream=True,
+        user="abc-123"
+    )
+
+    assert isinstance(response, Generator)
+
+    for chunk in response:
+        assert isinstance(chunk, LLMResultChunk)
+        assert isinstance(chunk.delta, LLMResultChunkDelta)
+        assert isinstance(chunk.delta.message, AssistantPromptMessage)
+
+
+def test_invoke_stream_model_without_delimiter():
+    model = OAIAPICompatLargeLanguageModel()
+
     response = model.invoke(
         model='mistralai/Mixtral-8x7B-Instruct-v0.1',
         credentials={
@@ -100,6 +140,7 @@ def test_invoke_stream_model():
         assert isinstance(chunk.delta, LLMResultChunkDelta)
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
 
+
 # using OpenAI's ChatGPT-3.5 as testing endpoint
 def test_invoke_chat_model_with_tools():
     model = OAIAPICompatLargeLanguageModel()
@@ -126,22 +167,22 @@ def test_invoke_chat_model_with_tools():
                 parameters={
                     "type": "object",
                     "properties": {
-                      "location": {
-                        "type": "string",
-                        "description": "The city and state e.g. San Francisco, CA"
-                      },
-                      "unit": {
-                        "type": "string",
-                        "enum": [
-                          "celsius",
-                          "fahrenheit"
-                        ]
-                      }
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state e.g. San Francisco, CA"
+                        },
+                        "unit": {
+                            "type": "string",
+                            "enum": [
+                                "celsius",
+                                "fahrenheit"
+                            ]
+                        }
                     },
                     "required": [
-                      "location"
+                        "location"
                     ]
-                  }
+                }
             ),
         ],
         model_parameters={
@@ -156,6 +197,7 @@ def test_invoke_chat_model_with_tools():
     assert isinstance(result.message, AssistantPromptMessage)
     assert len(result.message.tool_calls) > 0
 
+
 def test_get_num_tokens():
     model = OAIAPICompatLargeLanguageModel()