
feat: support o1 series models for openrouter (#8358)

sino, 7 months ago
Commit 6d56d5c1f6

+ 2 - 0
api/core/model_runtime/model_providers/openrouter/llm/_position.yaml

@@ -1,3 +1,5 @@
+- openai/o1-preview
+- openai/o1-mini
 - openai/gpt-4o
 - openai/gpt-4o-mini
 - openai/gpt-4
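
Note: _position.yaml controls the order in which OpenRouter models are listed, so prepending these two entries surfaces the o1 models at the top. A minimal sketch of how such a position file can be consumed, assuming a hypothetical loader (Dify's actual ordering logic lives elsewhere in the model runtime):

import yaml  # requires PyYAML

def load_model_positions(path: str) -> dict[str, int]:
    # Map each model name to its index in the file; earlier entries sort first.
    with open(path) as f:
        names: list[str] = yaml.safe_load(f)
    return {name: i for i, name in enumerate(names)}

positions = load_model_positions("_position.yaml")
models = ["openai/gpt-4o", "openai/o1-mini", "openai/o1-preview"]
models.sort(key=lambda m: positions.get(m, len(positions)))
# -> ["openai/o1-preview", "openai/o1-mini", "openai/gpt-4o"]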

+ 44 - 3
api/core/model_runtime/model_providers/openrouter/llm/llm.py

@@ -1,7 +1,7 @@
 from collections.abc import Generator
 from typing import Optional, Union
 
-from core.model_runtime.entities.llm_entities import LLMResult
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
 from core.model_runtime.entities.model_entities import AIModelEntity
 from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
@@ -26,7 +26,7 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
     ) -> Union[LLMResult, Generator]:
         self._update_credential(model, credentials)
 
-        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+        return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         self._update_credential(model, credentials)
@@ -46,7 +46,48 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
     ) -> Union[LLMResult, Generator]:
         self._update_credential(model, credentials)
 
-        return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+        block_as_stream = False
+        if model.startswith("openai/o1"):
+            block_as_stream = True
+            stop = None
+
+        # invoke block as stream
+        if stream and block_as_stream:
+            return self._generate_block_as_stream(
+                model, credentials, prompt_messages, model_parameters, tools, stop, user
+            )
+        else:
+            return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+
+    def _generate_block_as_stream(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
+        stop: Optional[list[str]] = None,
+        user: Optional[str] = None,
+    ) -> Generator:
+        resp: LLMResult = super()._generate(
+            model, credentials, prompt_messages, model_parameters, tools, stop, False, user
+        )
+
+        yield LLMResultChunk(
+            model=model,
+            prompt_messages=prompt_messages,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=resp.message,
+                usage=self._calc_response_usage(
+                    model=model,
+                    credentials=credentials,
+                    prompt_tokens=resp.usage.prompt_tokens,
+                    completion_tokens=resp.usage.completion_tokens,
+                ),
+                finish_reason="stop",
+            ),
+        )
 
     def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
         self._update_credential(model, credentials)
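
This is the core of the change: at the time of this commit the o1 models supported neither streaming responses nor stop sequences, so _generate clears `stop` for any model starting with "openai/o1" and, when the caller requested a stream, issues one blocking request and replays the full result as a single-chunk generator (_generate_block_as_stream). A standalone sketch of that block-as-stream pattern, using simplified stand-in types rather than Dify's real LLMResult/LLMResultChunk entities:

from collections.abc import Generator
from dataclasses import dataclass

# Simplified stand-ins for LLMResult / LLMResultChunk; the real entities
# also carry prompt messages, token usage, and more.
@dataclass
class Result:
    text: str

@dataclass
class Chunk:
    index: int
    text: str
    finish_reason: str | None = None

def blocking_call() -> Result:
    return Result(text="full answer")

def block_as_stream() -> Generator[Chunk, None, None]:
    # One blocking request, then replay the whole result as a single
    # chunk so downstream streaming consumers keep working unchanged.
    resp = blocking_call()
    yield Chunk(index=0, text=resp.text, finish_reason="stop")

for chunk in block_as_stream():
    print(chunk.index, chunk.text, chunk.finish_reason)  # 0 full answer stop

In the actual patch, usage is recomputed via _calc_response_usage from the blocking response's token counts before being attached to the single chunk.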

+ 40 - 0
api/core/model_runtime/model_providers/openrouter/llm/o1-mini.yaml

@@ -0,0 +1,40 @@
+model: openai/o1-mini
+label:
+  en_US: o1-mini
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_tokens
+    use_template: max_tokens
+    default: 512
+    min: 1
+    max: 65536
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response_format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
+pricing:
+  input: "3.00"
+  output: "12.00"
+  unit: "0.000001"
+  currency: USD
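
The pricing block follows the runtime's usual convention: unit "0.000001" scales a per-million-token price down to a per-token price, so cost = tokens × unit × price. A quick check of the o1-mini figures:

from decimal import Decimal

unit = Decimal("0.000001")  # prices are quoted per 1M tokens
input_price, output_price = Decimal("3.00"), Decimal("12.00")  # o1-mini, USD

prompt_tokens, completion_tokens = 1_000_000, 250_000
cost = (prompt_tokens * unit * input_price
        + completion_tokens * unit * output_price)
print(cost)  # 6.00000000 -> $3 for 1M input tokens + $3 for 250k output tokens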

+ 40 - 0
api/core/model_runtime/model_providers/openrouter/llm/o1-preview.yaml

@@ -0,0 +1,40 @@
+model: openai/o1-preview
+label:
+  en_US: o1-preview
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_tokens
+    use_template: max_tokens
+    default: 512
+    min: 1
+    max: 32768
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response_format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
+pricing:
+  input: "15.00"
+  output: "60.00"
+  unit: "0.000001"
+  currency: USD
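
Taken together: once these YAMLs are registered, callers can select the o1 models through the OpenRouter provider with stream=True, and the override transparently downgrades to a blocking call. An illustrative invocation, with the credentials abbreviated and the constructor details hedged to Dify's model runtime API at the time:

from core.model_runtime.entities.message_entities import UserPromptMessage
from core.model_runtime.model_providers.openrouter.llm.llm import OpenRouterLargeLanguageModel

llm = OpenRouterLargeLanguageModel()
chunks = llm.invoke(
    model="openai/o1-preview",
    credentials={"api_key": "sk-or-..."},  # placeholder key
    prompt_messages=[UserPromptMessage(content="Hello")],
    model_parameters={"max_tokens": 1024},
    stream=True,  # executed as one blocking request, replayed as a single chunk
)
for chunk in chunks:
    print(chunk.delta.message.content)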