瀏覽代碼

fix streaming (#1944)

Chenhe Gu 1 年之前
父節點
當前提交
de584807e1

+ 3 - 9
api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py

@@ -337,9 +337,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 )
             )
 
-        for chunk in response.iter_content(chunk_size=2048):
+        for chunk in response.iter_lines(decode_unicode=True, delimiter='\n\n'):
             if chunk:
-                decoded_chunk = chunk.decode('utf-8').strip().lstrip('data: ').lstrip()
+                decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
 
                 chunk_json = None
                 try:
@@ -356,7 +356,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                     continue
 
                 choice = chunk_json['choices'][0]
-                chunk_index = choice['index'] if 'index' in choice else chunk_index
+                chunk_index += 1
 
                 if 'delta' in choice:
                     delta = choice['delta']
@@ -408,12 +408,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                             message=assistant_prompt_message,
                         )
                     )
-            else:
-                yield create_final_llm_result_chunk(
-                    index=chunk_index + 1,
-                    message=AssistantPromptMessage(content=""),
-                    finish_reason="End of stream."
-                )
 
             chunk_index += 1
 

+ 6 - 6
api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py

@@ -22,7 +22,7 @@ def test_validate_credentials():
             model='mistralai/Mixtral-8x7B-Instruct-v0.1',
             credentials={
                 'api_key': 'invalid_key',
-                'endpoint_url': 'https://api.together.xyz/v1/chat/completions',
+                'endpoint_url': 'https://api.together.xyz/v1/',
                 'mode': 'chat'
             }
         )
@@ -31,7 +31,7 @@ def test_validate_credentials():
         model='mistralai/Mixtral-8x7B-Instruct-v0.1',
         credentials={
             'api_key': os.environ.get('TOGETHER_API_KEY'),
-            'endpoint_url': 'https://api.together.xyz/v1/chat/completions',
+            'endpoint_url': 'https://api.together.xyz/v1/',
             'mode': 'chat'
         }
     )
@@ -43,7 +43,7 @@ def test_invoke_model():
         model='mistralai/Mixtral-8x7B-Instruct-v0.1',
         credentials={
             'api_key': os.environ.get('TOGETHER_API_KEY'),
-            'endpoint_url': 'https://api.together.xyz/v1/completions',
+            'endpoint_url': 'https://api.together.xyz/v1/',
             'mode': 'completion'
         },
         prompt_messages=[
@@ -74,7 +74,7 @@ def test_invoke_stream_model():
         model='mistralai/Mixtral-8x7B-Instruct-v0.1',
         credentials={
             'api_key': os.environ.get('TOGETHER_API_KEY'),
-            'endpoint_url': 'https://api.together.xyz/v1/chat/completions',
+            'endpoint_url': 'https://api.together.xyz/v1/',
             'mode': 'chat'
         },
         prompt_messages=[
@@ -110,7 +110,7 @@ def test_invoke_chat_model_with_tools():
         model='gpt-3.5-turbo',
         credentials={
             'api_key': os.environ.get('OPENAI_API_KEY'),
-            'endpoint_url': 'https://api.openai.com/v1/chat/completions',
+            'endpoint_url': 'https://api.openai.com/v1/',
             'mode': 'chat'
         },
         prompt_messages=[
@@ -165,7 +165,7 @@ def test_get_num_tokens():
         model='mistralai/Mixtral-8x7B-Instruct-v0.1',
         credentials={
             'api_key': os.environ.get('OPENAI_API_KEY'),
-            'endpoint_url': 'https://api.openai.com/v1/chat/completions'
+            'endpoint_url': 'https://api.openai.com/v1/'
         },
         prompt_messages=[
             SystemPromptMessage(

+ 7 - 11
api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py

@@ -18,9 +18,8 @@ def test_validate_credentials():
             model='text-embedding-ada-002',
             credentials={
                 'api_key': 'invalid_key',
-                'endpoint_url': 'https://api.openai.com/v1/embeddings',
-                'context_size': 8184,
-                'max_chunks': 32
+                'endpoint_url': 'https://api.openai.com/v1/',
+                'context_size': 8184
                 
             }
         )
@@ -29,9 +28,8 @@ def test_validate_credentials():
         model='text-embedding-ada-002',
         credentials={
             'api_key': os.environ.get('OPENAI_API_KEY'),
-            'endpoint_url': 'https://api.openai.com/v1/embeddings',
-            'context_size': 8184,
-            'max_chunks': 32
+            'endpoint_url': 'https://api.openai.com/v1/',
+            'context_size': 8184
         }
     )
 
@@ -43,9 +41,8 @@ def test_invoke_model():
         model='text-embedding-ada-002',
         credentials={
             'api_key': os.environ.get('OPENAI_API_KEY'),
-            'endpoint_url': 'https://api.openai.com/v1/embeddings',
-            'context_size': 8184,
-            'max_chunks': 32
+            'endpoint_url': 'https://api.openai.com/v1/',
+            'context_size': 8184
         },
         texts=[
             "hello",
@@ -67,8 +64,7 @@ def test_get_num_tokens():
         credentials={
             'api_key': os.environ.get('OPENAI_API_KEY'),
             'endpoint_url': 'https://api.openai.com/v1/embeddings',
-            'context_size': 8184,
-            'max_chunks': 32
+            'context_size': 8184
         },
         texts=[
             "hello",