
Fix/localai (#2840)

Yeuoly, 1 year ago
parent
commit 742be06ea9

+ 10 - 4
api/core/model_runtime/model_providers/localai/llm/llm.py

@@ -1,6 +1,5 @@
 from collections.abc import Generator
 from typing import cast
-from urllib.parse import urljoin
 
 from httpx import Timeout
 from openai import (
@@ -19,6 +18,7 @@ from openai import (
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
 from openai.types.chat.chat_completion_message import FunctionCall
 from openai.types.completion import Completion
+from yarl import URL
 
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
@@ -181,7 +181,7 @@ class LocalAILarguageModel(LargeLanguageModel):
                 UserPromptMessage(content='ping')
             ], model_parameters={
                 'max_tokens': 10,
-            }, stop=[])
+            }, stop=[], stream=False)
         except Exception as ex:
             raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}')
 
@@ -227,6 +227,12 @@ class LocalAILarguageModel(LargeLanguageModel):
             )
         ]
 
+        model_properties = { 
+            ModelPropertyKey.MODE: completion_model,
+        } if completion_model else {}
+
+        model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get('context_size', '2048'))
+
         entity = AIModelEntity(
             model=model,
             label=I18nObject(
@@ -234,7 +240,7 @@ class LocalAILarguageModel(LargeLanguageModel):
             ),
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.LLM,
-            model_properties={ ModelPropertyKey.MODE: completion_model } if completion_model else {},
+            model_properties=model_properties,
             parameter_rules=rules
         )
 
@@ -319,7 +325,7 @@ class LocalAILarguageModel(LargeLanguageModel):
         client_kwargs = {
             "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
             "api_key": "1",
-            "base_url": urljoin(credentials['server_url'], 'v1'),
+            "base_url": str(URL(credentials['server_url']) / 'v1'),
         }
 
         return client_kwargs
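
The base_url change above replaces urllib.parse.urljoin with yarl.URL. A minimal sketch of the behavioural difference that presumably motivated it (the example server_url is hypothetical): urljoin drops the last path segment of a base URL without a trailing slash, while yarl's `/` operator always appends the new segment.

    from urllib.parse import urljoin
    from yarl import URL

    # Hypothetical LocalAI server URL that includes a path prefix.
    server_url = 'http://192.168.1.100:8080/localai'

    print(urljoin(server_url, 'v1'))    # http://192.168.1.100:8080/v1 -- path prefix lost
    print(str(URL(server_url) / 'v1'))  # http://192.168.1.100:8080/localai/v1 -- prefix kept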

+ 9 - 0
api/core/model_runtime/model_providers/localai/localai.yaml

@@ -56,3 +56,12 @@ model_credential_schema:
       placeholder:
         zh_Hans: 在此输入LocalAI的服务器地址,如 http://192.168.1.100:8080
         en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080
+    - variable: context_size
+      label:
+        zh_Hans: 上下文大小
+        en_US: Context size
+      placeholder:
+        zh_Hans: 输入上下文大小
+        en_US: Enter context size
+      required: false
+      type: text-input
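
The optional context_size credential added above is consumed by the Python changes in this commit. A minimal sketch of the defaulting behaviour, using the fallback values visible in the diffs (2048 for the LLM, 512 for text embedding); the credentials dict is a made-up example:

    # context_size left blank by the user, so the per-model defaults apply.
    credentials = {'server_url': 'http://192.168.1.100:8080'}

    llm_context_size = int(credentials.get('context_size', '2048'))       # -> 2048
    embedding_context_size = int(credentials.get('context_size', '512'))  # -> 512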

+ 25 - 3
api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py

@@ -1,11 +1,12 @@
 import time
 from json import JSONDecodeError, dumps
-from os.path import join
 from typing import Optional
 
 from requests import post
+from yarl import URL
 
-from core.model_runtime.entities.model_entities import PriceType
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import (
     InvokeAuthorizationError,
@@ -57,7 +58,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
         }
 
         try:
-            response = post(join(url, 'embeddings'), headers=headers, data=dumps(data), timeout=10)
+            response = post(str(URL(url) / 'embeddings'), headers=headers, data=dumps(data), timeout=10)
         except Exception as e:
             raise InvokeConnectionError(str(e))
         
@@ -113,6 +114,27 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
             # use GPT2Tokenizer to get num tokens
             num_tokens += self._get_num_tokens_by_gpt2(text)
         return num_tokens
+    
+    def _get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+        """
+        Get customizable model schema
+
+        :param model: model name
+        :param credentials: model credentials
+        :return: model schema
+        """
+        return AIModelEntity(
+            model=model,
+            label=I18nObject(zh_Hans=model, en_US=model),
+            model_type=ModelType.TEXT_EMBEDDING,
+            features=[],
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', '512')),
+                ModelPropertyKey.MAX_CHUNKS: 1,
+            },
+            parameter_rules=[]
+        )
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """