Ver código fonte

fix: Unable to display images generated by Dall-E 3 (#6155)

Weishan-0 9 meses atrás
pai
commit
7b45a5d452
1 arquivos alterados com 43 adições e 4 exclusões
  1. 43 4
      api/core/tools/provider/builtin/dalle/tools/dalle3.py

+ 43 - 4
api/core/tools/provider/builtin/dalle/tools/dalle3.py

@@ -1,5 +1,5 @@
+import base64
 import random
-from base64 import b64decode
 from typing import Any, Union
 
 from openai import OpenAI
@@ -69,11 +69,50 @@ class DallE3Tool(BuiltinTool):
         result = []
 
         for image in response.data:
-            result.append(self.create_blob_message(blob=b64decode(image.b64_json),
-                                                   meta={'mime_type': 'image/png'},
-                                                   save_as=self.VARIABLE_KEY.IMAGE.value))
+            mime_type, blob_image = DallE3Tool._decode_image(image.b64_json)
+            blob_message = self.create_blob_message(blob=blob_image,
+                                                    meta={'mime_type': mime_type},
+                                                    save_as=self.VARIABLE_KEY.IMAGE.value)
+            result.append(blob_message)
         return result
 
+    @staticmethod
+    def _decode_image(base64_image: str) -> tuple[str, bytes]:
+        """
+        Decode a base64 encoded image. If the image is not prefixed with a MIME type,
+        it assumes 'image/png' as the default.
+
+        :param base64_image: Base64 encoded image string
+        :return: A tuple containing the MIME type and the decoded image bytes
+        """
+        if DallE3Tool._is_plain_base64(base64_image):
+            return 'image/png', base64.b64decode(base64_image)
+        else:
+            return DallE3Tool._extract_mime_and_data(base64_image)
+
+    @staticmethod
+    def _is_plain_base64(encoded_str: str) -> bool:
+        """
+        Check if the given encoded string is plain base64 without a MIME type prefix.
+
+        :param encoded_str: Base64 encoded image string
+        :return: True if the string is plain base64, False otherwise
+        """
+        return not encoded_str.startswith('data:image')
+
+    @staticmethod
+    def _extract_mime_and_data(encoded_str: str) -> tuple[str, bytes]:
+        """
+        Extract MIME type and image data from a base64 encoded string with a MIME type prefix.
+
+        :param encoded_str: Base64 encoded image string with MIME type prefix
+        :return: A tuple containing the MIME type and the decoded image bytes
+        """
+        mime_type = encoded_str.split(';')[0].split(':')[1]
+        image_data_base64 = encoded_str.split(',')[1]
+        decoded_data = base64.b64decode(image_data_base64)
+        return mime_type, decoded_data
+
     @staticmethod
     def _generate_random_id(length=8):
         characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'