@@ -2,11 +2,13 @@ import json
 from typing import Type
 
 from core.helper import encrypter
+from core.model_providers.models.embedding.openllm_embedding import OpenLLMEmbedding
 from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
 from core.model_providers.models.llm.openllm_model import OpenLLMModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 
 from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.embeddings.openllm_embedding import OpenLLMEmbeddings
 from core.third_party.langchain.llms.openllm import OpenLLM
 from models.provider import ProviderType
 
@@ -31,6 +33,8 @@ class OpenLLMProvider(BaseModelProvider):
         """
         if model_type == ModelType.TEXT_GENERATION:
             model_class = OpenLLMModel
+        elif model_type == ModelType.EMBEDDINGS:
+            model_class = OpenLLMEmbedding
         else:
             raise NotImplementedError
 
@@ -69,14 +73,21 @@ class OpenLLMProvider(BaseModelProvider):
                 'server_url': credentials['server_url']
             }
 
-            llm = OpenLLM(
-                llm_kwargs={
-                    'max_new_tokens': 10
-                },
-                **credential_kwargs
-            )
-
-            llm("ping")
+            if model_type == ModelType.TEXT_GENERATION:
+                llm = OpenLLM(
+                    llm_kwargs={
+                        'max_new_tokens': 10
+                    },
+                    **credential_kwargs
+                )
+
+                llm("ping")
+            elif model_type == ModelType.EMBEDDINGS:
+                embedding = OpenLLMEmbeddings(
+                    **credential_kwargs
+                )
+
+                embedding.embed_query("ping")
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
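Note: the `OpenLLMEmbeddings` wrapper imported above lives in core/third_party/langchain/embeddings/openllm_embedding.py and is not shown in this diff. For orientation, a minimal sketch of what such a LangChain-style wrapper could look like follows. Only the import path, the `server_url` keyword (via `credential_kwargs`), and the `embed_query` call are confirmed by the diff; the endpoint path, response shape, and class internals are assumptions.

# Hypothetical sketch only -- not the actual implementation from this change.
from typing import List

import requests
from langchain.embeddings.base import Embeddings


class OpenLLMEmbeddingsSketch(Embeddings):
    """Minimal LangChain-style embeddings wrapper around an OpenLLM server.

    The /v1/embeddings endpoint and the response shape are assumptions,
    not something this diff confirms.
    """

    def __init__(self, server_url: str):
        self.server_url = server_url

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        # POST the list of strings and expect a list of float vectors back.
        response = requests.post(f"{self.server_url}/v1/embeddings", json=texts)
        response.raise_for_status()
        return response.json()

    def embed_query(self, text: str) -> List[float]:
        # This is the call the credential-validation code above exercises with "ping".
        return self.embed_documents([text])[0]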