|
@@ -6,6 +6,7 @@ from core.model_runtime.entities.message_entities import (
|
|
|
PromptMessage,
|
|
|
PromptMessageTool,
|
|
|
)
|
|
|
+from core.model_runtime.entities.model_entities import ModelFeature
|
|
|
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
|
|
|
|
|
|
|
|
@@ -28,14 +29,13 @@ class GiteeAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
|
|
user: Optional[str] = None,
|
|
|
) -> Union[LLMResult, Generator]:
|
|
|
self._add_custom_parameters(credentials, model, model_parameters)
|
|
|
- return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
|
|
|
+ return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
|
|
|
|
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
|
|
self._add_custom_parameters(credentials, model, None)
|
|
|
super().validate_credentials(model, credentials)
|
|
|
|
|
|
- @staticmethod
|
|
|
- def _add_custom_parameters(credentials: dict, model: str, model_parameters: dict) -> None:
|
|
|
+ def _add_custom_parameters(self, credentials: dict, model: str, model_parameters: dict) -> None:
|
|
|
if model is None:
|
|
|
model = "bge-large-zh-v1.5"
|
|
|
|
|
@@ -45,3 +45,7 @@ class GiteeAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
|
|
credentials["mode"] = LLMMode.COMPLETION.value
|
|
|
else:
|
|
|
credentials["mode"] = LLMMode.CHAT.value
|
|
|
+
|
|
|
+ schema = self.get_model_schema(model, credentials)
|
|
|
+ if ModelFeature.TOOL_CALL in schema.features or ModelFeature.MULTI_TOOL_CALL in schema.features:
|
|
|
+ credentials["function_calling_type"] = "tool_call"
|