|
@@ -1,6 +1,5 @@
|
|
from collections.abc import Generator
|
|
from collections.abc import Generator
|
|
from typing import cast
|
|
from typing import cast
|
|
-from urllib.parse import urljoin
|
|
|
|
|
|
|
|
from httpx import Timeout
|
|
from httpx import Timeout
|
|
from openai import (
|
|
from openai import (
|
|
@@ -19,6 +18,7 @@ from openai import (
|
|
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
|
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
|
from openai.types.chat.chat_completion_message import FunctionCall
|
|
from openai.types.chat.chat_completion_message import FunctionCall
|
|
from openai.types.completion import Completion
|
|
from openai.types.completion import Completion
|
|
|
|
+from yarl import URL
|
|
|
|
|
|
from core.model_runtime.entities.common_entities import I18nObject
|
|
from core.model_runtime.entities.common_entities import I18nObject
|
|
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
|
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
|
@@ -181,7 +181,7 @@ class LocalAILarguageModel(LargeLanguageModel):
|
|
UserPromptMessage(content='ping')
|
|
UserPromptMessage(content='ping')
|
|
], model_parameters={
|
|
], model_parameters={
|
|
'max_tokens': 10,
|
|
'max_tokens': 10,
|
|
- }, stop=[])
|
|
|
|
|
|
+ }, stop=[], stream=False)
|
|
except Exception as ex:
|
|
except Exception as ex:
|
|
raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}')
|
|
raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}')
|
|
|
|
|
|
@@ -227,6 +227,12 @@ class LocalAILarguageModel(LargeLanguageModel):
|
|
)
|
|
)
|
|
]
|
|
]
|
|
|
|
|
|
|
|
+ model_properties = {
|
|
|
|
+ ModelPropertyKey.MODE: completion_model,
|
|
|
|
+ } if completion_model else {}
|
|
|
|
+
|
|
|
|
+ model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get('context_size', '2048'))
|
|
|
|
+
|
|
entity = AIModelEntity(
|
|
entity = AIModelEntity(
|
|
model=model,
|
|
model=model,
|
|
label=I18nObject(
|
|
label=I18nObject(
|
|
@@ -234,7 +240,7 @@ class LocalAILarguageModel(LargeLanguageModel):
|
|
),
|
|
),
|
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
|
model_type=ModelType.LLM,
|
|
model_type=ModelType.LLM,
|
|
- model_properties={ ModelPropertyKey.MODE: completion_model } if completion_model else {},
|
|
|
|
|
|
+ model_properties=model_properties,
|
|
parameter_rules=rules
|
|
parameter_rules=rules
|
|
)
|
|
)
|
|
|
|
|
|
@@ -319,7 +325,7 @@ class LocalAILarguageModel(LargeLanguageModel):
|
|
client_kwargs = {
|
|
client_kwargs = {
|
|
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
|
|
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
|
|
"api_key": "1",
|
|
"api_key": "1",
|
|
- "base_url": urljoin(credentials['server_url'], 'v1'),
|
|
|
|
|
|
+ "base_url": str(URL(credentials['server_url']) / 'v1'),
|
|
}
|
|
}
|
|
|
|
|
|
return client_kwargs
|
|
return client_kwargs
|