Pārlūkot izejas kodu

fix: miss usage of os.path.join for URL assembly and add tests on yarl (#4224)

Bowen Liang 11 mēneši atpakaļ
vecāks
revīzija
228de1f12a

+ 2 - 2
api/core/model_runtime/model_providers/chatglm/llm/llm.py

@@ -1,6 +1,5 @@
 import logging
 from collections.abc import Generator
-from os.path import join
 from typing import Optional, cast
 
 from httpx import Timeout
@@ -19,6 +18,7 @@ from openai import (
 )
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
 from openai.types.chat.chat_completion_message import FunctionCall
+from yarl import URL
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
@@ -265,7 +265,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
         client_kwargs = {
             "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
             "api_key": "1",
-            "base_url": join(credentials['api_base'], 'v1')
+            "base_url": str(URL(credentials['api_base']) / 'v1')
         }
 
         return client_kwargs

+ 2 - 2
api/core/tools/provider/builtin/dalle/tools/dalle2.py

@@ -1,8 +1,8 @@
 from base64 import b64decode
-from os.path import join
 from typing import Any, Union
 
 from openai import OpenAI
+from yarl import URL
 
 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool
@@ -23,7 +23,7 @@ class DallE2Tool(BuiltinTool):
         if not openai_base_url:
             openai_base_url = None
         else:
-            openai_base_url = join(openai_base_url, 'v1')
+            openai_base_url = str(URL(openai_base_url) / 'v1')
 
         client = OpenAI(
             api_key=self.runtime.credentials['openai_api_key'],

+ 2 - 2
api/core/tools/provider/builtin/dalle/tools/dalle3.py

@@ -1,8 +1,8 @@
 from base64 import b64decode
-from os.path import join
 from typing import Any, Union
 
 from openai import OpenAI
+from yarl import URL
 
 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool
@@ -23,7 +23,7 @@ class DallE3Tool(BuiltinTool):
         if not openai_base_url:
             openai_base_url = None
         else:
-            openai_base_url = join(openai_base_url, 'v1')
+            openai_base_url = str(URL(openai_base_url) / 'v1')
 
         client = OpenAI(
             api_key=self.runtime.credentials['openai_api_key'],

+ 23 - 0
api/tests/unit_tests/libs/test_yarl.py

@@ -0,0 +1,23 @@
+import pytest
+from yarl import URL
+
+
+def test_yarl_urls():
+    expected_1 = 'https://dify.ai/api'
+    assert str(URL('https://dify.ai') / 'api') == expected_1
+    assert str(URL('https://dify.ai/') / 'api') == expected_1
+
+    expected_2 = 'http://dify.ai:12345/api'
+    assert str(URL('http://dify.ai:12345') / 'api') == expected_2
+    assert str(URL('http://dify.ai:12345/') / 'api') == expected_2
+
+    expected_3 = 'https://dify.ai/api/v1'
+    assert str(URL('https://dify.ai') / 'api' / 'v1') == expected_3
+    assert str(URL('https://dify.ai') / 'api/v1') == expected_3
+    assert str(URL('https://dify.ai/') / 'api/v1') == expected_3
+    assert str(URL('https://dify.ai/api') / 'v1') == expected_3
+    assert str(URL('https://dify.ai/api/') / 'v1') == expected_3
+
+    with pytest.raises(ValueError) as e1:
+        str(URL('https://dify.ai') / '/api')
+    assert str(e1.value) == "Appending path '/api' starting from slash is forbidden"