@@ -1,7 +1,7 @@
 from collections.abc import Generator
 from typing import Optional, Union
 
-from core.model_runtime.entities.llm_entities import LLMResult
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
 from core.model_runtime.entities.model_entities import AIModelEntity
 from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
@@ -26,7 +26,8 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
     ) -> Union[LLMResult, Generator]:
         self._update_credential(model, credentials)
 
-        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+        # route through our _generate override so the o1 handling below applies here too
+        return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         self._update_credential(model, credentials)
@@ -46,7 +46,52 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
     ) -> Union[LLMResult, Generator]:
         self._update_credential(model, credentials)
 
-        return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+        # OpenAI o1 models served through OpenRouter do not support streaming
+        # or stop sequences, so handle them as a blocking call instead.
+        block_as_stream = False
+        if model.startswith("openai/o1"):
+            block_as_stream = True
+            stop = None
+
+        # invoke the blocking call and re-emit the result as a stream
+        if stream and block_as_stream:
+            return self._generate_block_as_stream(
+                model, credentials, prompt_messages, model_parameters, tools, stop, user
+            )
+        else:
+            return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+
+    def _generate_block_as_stream(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
+        stop: Optional[list[str]] = None,
+        user: Optional[str] = None,
+    ) -> Generator:
+        # perform a normal blocking request (stream=False)
+        resp: LLMResult = super()._generate(
+            model, credentials, prompt_messages, model_parameters, tools, stop, False, user
+        )
+
+        # yield the complete response as a single stream chunk
+        yield LLMResultChunk(
+            model=model,
+            prompt_messages=prompt_messages,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=resp.message,
+                usage=self._calc_response_usage(
+                    model=model,
+                    credentials=credentials,
+                    prompt_tokens=resp.usage.prompt_tokens,
+                    completion_tokens=resp.usage.completion_tokens,
+                ),
+                finish_reason="stop",
+            ),
+        )
 
     def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
         self._update_credential(model, credentials)
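
For reference, a minimal usage sketch (not part of the diff) of how the o1 path behaves from a caller's point of view. The instantiation and credential dict below are illustrative placeholders; in Dify the model instance is normally obtained through the provider factory and invoked via the public invoke() wrapper rather than _invoke() directly:

    from core.model_runtime.entities.message_entities import UserPromptMessage
    from core.model_runtime.model_providers.openrouter.llm.llm import OpenRouterLargeLanguageModel

    model = OpenRouterLargeLanguageModel()  # placeholder; usually built by the provider factory
    chunks = model._invoke(
        model="openai/o1-preview",
        credentials={"api_key": "..."},  # placeholder credentials
        prompt_messages=[UserPromptMessage(content="hello")],
        model_parameters={},
        stream=True,
    )
    # Although the upstream request was blocking, the caller still receives a
    # generator; for openai/o1* it yields exactly one LLMResultChunk carrying
    # the full message, usage, and finish_reason="stop".
    for chunk in chunks:
        print(chunk.delta.message.content)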