ソースを参照

feat: support json schema for gemini models (#10835)

非法操作 5 ヶ月 前
コミット
bc1013dacf
18 ファイル変更61 行追加77 行削除
  1. 3 4
      api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-001.yaml
  2. 3 4
      api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-002.yaml
  3. 3 4
      api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0827.yaml
  4. 3 4
      api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0924.yaml
  5. 3 4
      api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-exp-0827.yaml
  6. 3 4
      api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml
  7. 3 4
      api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash.yaml
  8. 3 4
      api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-001.yaml
  9. 3 4
      api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-002.yaml
  10. 3 4
      api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0801.yaml
  11. 3 4
      api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0827.yaml
  12. 3 4
      api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml
  13. 3 4
      api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro.yaml
  14. 3 4
      api/core/model_runtime/model_providers/google/llm/gemini-exp-1114.yaml
  15. 1 0
      api/core/model_runtime/model_providers/google/llm/gemini-pro-vision.yaml
  16. 1 0
      api/core/model_runtime/model_providers/google/llm/gemini-pro.yaml
  17. 10 14
      api/core/model_runtime/model_providers/google/llm/llm.py
  18. 7 7
      api/tests/integration_tests/model_runtime/google/test_llm.py

+ 3 - 4
api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-001.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+ 3 - 4
api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-002.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+ 3 - 4
api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0827.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+ 3 - 4
api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0924.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+ 3 - 4
api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-exp-0827.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+ 3 - 4
api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+ 3 - 4
api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+ 3 - 4
api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-001.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+ 3 - 4
api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-002.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+ 3 - 4
api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0801.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+ 3 - 4
api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0827.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+ 3 - 4
api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+ 3 - 4
api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+ 3 - 4
api/core/model_runtime/model_providers/google/llm/gemini-exp-1114.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+ 1 - 0
api/core/model_runtime/model_providers/google/llm/gemini-pro-vision.yaml

@@ -32,3 +32,4 @@ pricing:
   output: '0.00'
   unit: '0.000001'
   currency: USD
+deprecated: true

+ 1 - 0
api/core/model_runtime/model_providers/google/llm/gemini-pro.yaml

@@ -36,3 +36,4 @@ pricing:
   output: '0.00'
   unit: '0.000001'
   currency: USD
+deprecated: true

+ 10 - 14
api/core/model_runtime/model_providers/google/llm/llm.py

@@ -1,7 +1,6 @@
 import base64
 import io
 import json
-import logging
 from collections.abc import Generator
 from typing import Optional, Union, cast
 
@@ -36,17 +35,6 @@ from core.model_runtime.errors.invoke import (
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 
-logger = logging.getLogger(__name__)
-
-GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
-The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
-if you are not sure about the structure.
-
-<instructions>
-{{instructions}}
-</instructions>
-"""  # noqa: E501
-
 
 class GoogleLargeLanguageModel(LargeLanguageModel):
     def _invoke(
@@ -155,7 +143,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
 
         try:
             ping_message = SystemPromptMessage(content="ping")
-            self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
+            self._generate(model, credentials, [ping_message], {"max_output_tokens": 5})
 
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
@@ -184,7 +172,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         :return: full response or stream response chunk generator result
         """
         config_kwargs = model_parameters.copy()
-        config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
+        if schema := config_kwargs.pop("json_schema", None):
+            try:
+                schema = json.loads(schema)
+            except:
+                raise exceptions.InvalidArgument("Invalid JSON Schema")
+            if tools:
+                raise exceptions.InvalidArgument("gemini not support use Tools and JSON Schema at same time")
+            config_kwargs["response_schema"] = schema
+            config_kwargs["response_mime_type"] = "application/json"
 
         if stop:
             config_kwargs["stop_sequences"] = stop

+ 7 - 7
api/tests/integration_tests/model_runtime/google/test_llm.py

@@ -31,7 +31,7 @@ def test_invoke_model(setup_google_mock):
     model = GoogleLargeLanguageModel()
 
     response = model.invoke(
-        model="gemini-pro",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
@@ -48,7 +48,7 @@ def test_invoke_model(setup_google_mock):
                 ]
             ),
         ],
-        model_parameters={"temperature": 0.5, "top_p": 1.0, "max_tokens_to_sample": 2048},
+        model_parameters={"temperature": 0.5, "top_p": 1.0, "max_output_tokens": 2048},
         stop=["How"],
         stream=False,
         user="abc-123",
@@ -63,7 +63,7 @@ def test_invoke_stream_model(setup_google_mock):
     model = GoogleLargeLanguageModel()
 
     response = model.invoke(
-        model="gemini-pro",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
@@ -80,7 +80,7 @@ def test_invoke_stream_model(setup_google_mock):
                 ]
             ),
         ],
-        model_parameters={"temperature": 0.2, "top_k": 5, "max_tokens_to_sample": 2048},
+        model_parameters={"temperature": 0.2, "top_k": 5, "max_tokens": 2048},
         stream=True,
         user="abc-123",
     )
@@ -99,7 +99,7 @@ def test_invoke_chat_model_with_vision(setup_google_mock):
     model = GoogleLargeLanguageModel()
 
     result = model.invoke(
-        model="gemini-pro-vision",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
@@ -128,7 +128,7 @@ def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock):
     model = GoogleLargeLanguageModel()
 
     result = model.invoke(
-        model="gemini-pro-vision",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(content="You are a helpful AI assistant."),
@@ -164,7 +164,7 @@ def test_get_num_tokens():
     model = GoogleLargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model="gemini-pro",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(