|
@@ -1,14 +1,13 @@
|
|
import json
|
|
import json
|
|
from typing import Type
|
|
from typing import Type
|
|
|
|
|
|
-from langchain.llms import OpenLLM
|
|
|
|
-
|
|
|
|
from core.helper import encrypter
|
|
from core.helper import encrypter
|
|
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
|
|
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
|
|
from core.model_providers.models.llm.openllm_model import OpenLLMModel
|
|
from core.model_providers.models.llm.openllm_model import OpenLLMModel
|
|
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
|
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
|
|
|
|
|
from core.model_providers.models.base import BaseProviderModel
|
|
from core.model_providers.models.base import BaseProviderModel
|
|
|
|
+from core.third_party.langchain.llms.openllm import OpenLLM
|
|
from models.provider import ProviderType
|
|
from models.provider import ProviderType
|
|
|
|
|
|
|
|
|
|
@@ -46,11 +45,11 @@ class OpenLLMProvider(BaseModelProvider):
|
|
:return:
|
|
:return:
|
|
"""
|
|
"""
|
|
return ModelKwargsRules(
|
|
return ModelKwargsRules(
|
|
- temperature=KwargRule[float](min=0, max=2, default=1),
|
|
|
|
|
|
+ temperature=KwargRule[float](min=0.01, max=2, default=1),
|
|
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
|
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
|
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
|
|
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
|
|
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
|
|
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
|
|
- max_tokens=KwargRule[int](min=10, max=4000, default=128),
|
|
|
|
|
|
+ max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128),
|
|
)
|
|
)
|
|
|
|
|
|
@classmethod
|
|
@classmethod
|
|
@@ -71,7 +70,9 @@ class OpenLLMProvider(BaseModelProvider):
|
|
}
|
|
}
|
|
|
|
|
|
llm = OpenLLM(
|
|
llm = OpenLLM(
|
|
- max_tokens=10,
|
|
|
|
|
|
+ llm_kwargs={
|
|
|
|
+ 'max_new_tokens': 10
|
|
|
|
+ },
|
|
**credential_kwargs
|
|
**credential_kwargs
|
|
)
|
|
)
|
|
|
|
|