Przeglądaj źródła

fix: gemini system prompt with variable raise error (#11946)

非法操作 4 miesięcy temu
rodzic
commit
366857cd26

+ 13 - 3
api/core/model_runtime/model_providers/google/llm/llm.py

@@ -21,6 +21,7 @@ from core.model_runtime.entities.message_entities import (
     PromptMessageContentType,
     PromptMessageTool,
     SystemPromptMessage,
+    TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
 )
@@ -143,7 +144,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         """
 
         try:
-            ping_message = SystemPromptMessage(content="ping")
+            ping_message = UserPromptMessage(content="ping")
             self._generate(model, credentials, [ping_message], {"max_output_tokens": 5})
 
         except Exception as ex:
@@ -187,17 +188,23 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             config_kwargs["stop_sequences"] = stop
 
         genai.configure(api_key=credentials["google_api_key"])
-        google_model = genai.GenerativeModel(model_name=model)
 
         history = []
+        system_instruction = None
 
         for msg in prompt_messages:  # makes message roles strictly alternating
             content = self._format_message_to_glm_content(msg)
             if history and history[-1]["role"] == content["role"]:
                 history[-1]["parts"].extend(content["parts"])
+            elif content["role"] == "system":
+                system_instruction = content["parts"][0]
             else:
                 history.append(content)
 
+        if not history:
+            raise InvokeError("The user prompt message is required. You only add a system prompt message.")
+
+        google_model = genai.GenerativeModel(model_name=model, system_instruction=system_instruction)
         response = google_model.generate_content(
             contents=history,
             generation_config=genai.types.GenerationConfig(**config_kwargs),
@@ -404,7 +411,10 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
                 )
             return glm_content
         elif isinstance(message, SystemPromptMessage):
-            return {"role": "user", "parts": [to_part(message.content)]}
+            if isinstance(message.content, list):
+                text_contents = filter(lambda c: isinstance(c, TextPromptMessageContent), message.content)
+                message.content = "".join(c.data for c in text_contents)
+            return {"role": "system", "parts": [to_part(message.content)]}
         elif isinstance(message, ToolPromptMessage):
             return {
                 "role": "function",