|
@@ -1,5 +1,5 @@
|
|
|
+import base64
|
|
|
import random
|
|
|
-from base64 import b64decode
|
|
|
from typing import Any, Union
|
|
|
|
|
|
from openai import OpenAI
|
|
@@ -69,11 +69,50 @@ class DallE3Tool(BuiltinTool):
|
|
|
result = []
|
|
|
|
|
|
for image in response.data:
|
|
|
- result.append(self.create_blob_message(blob=b64decode(image.b64_json),
|
|
|
- meta={'mime_type': 'image/png'},
|
|
|
- save_as=self.VARIABLE_KEY.IMAGE.value))
|
|
|
+ mime_type, blob_image = DallE3Tool._decode_image(image.b64_json)
|
|
|
+ blob_message = self.create_blob_message(blob=blob_image,
|
|
|
+ meta={'mime_type': mime_type},
|
|
|
+ save_as=self.VARIABLE_KEY.IMAGE.value)
|
|
|
+ result.append(blob_message)
|
|
|
return result
|
|
|
|
|
|
+ @staticmethod
|
|
|
+ def _decode_image(base64_image: str) -> tuple[str, bytes]:
|
|
|
+ """
|
|
|
+ Decode a base64 encoded image. If the image is not prefixed with a MIME type,
|
|
|
+ it assumes 'image/png' as the default.
|
|
|
+
|
|
|
+ :param base64_image: Base64 encoded image string
|
|
|
+ :return: A tuple containing the MIME type and the decoded image bytes
|
|
|
+ """
|
|
|
+ if DallE3Tool._is_plain_base64(base64_image):
|
|
|
+ return 'image/png', base64.b64decode(base64_image)
|
|
|
+ else:
|
|
|
+ return DallE3Tool._extract_mime_and_data(base64_image)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _is_plain_base64(encoded_str: str) -> bool:
|
|
|
+ """
|
|
|
+ Check if the given encoded string is plain base64 without a MIME type prefix.
|
|
|
+
|
|
|
+ :param encoded_str: Base64 encoded image string
|
|
|
+ :return: True if the string is plain base64, False otherwise
|
|
|
+ """
|
|
|
+ return not encoded_str.startswith('data:image')
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _extract_mime_and_data(encoded_str: str) -> tuple[str, bytes]:
|
|
|
+ """
|
|
|
+ Extract MIME type and image data from a base64 encoded string with a MIME type prefix.
|
|
|
+
|
|
|
+ :param encoded_str: Base64 encoded image string with MIME type prefix
|
|
|
+ :return: A tuple containing the MIME type and the decoded image bytes
|
|
|
+ """
|
|
|
+ mime_type = encoded_str.split(';')[0].split(':')[1]
|
|
|
+ image_data_base64 = encoded_str.split(',')[1]
|
|
|
+ decoded_data = base64.b64decode(image_data_base64)
|
|
|
+ return mime_type, decoded_data
|
|
|
+
|
|
|
@staticmethod
|
|
|
def _generate_random_id(length=8):
|
|
|
characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
|