|
@@ -1,9 +1,9 @@
|
|
|
from abc import ABC
|
|
|
from collections.abc import Sequence
|
|
|
from enum import Enum, StrEnum
|
|
|
-from typing import Literal, Optional
|
|
|
+from typing import Optional
|
|
|
|
|
|
-from pydantic import BaseModel, Field, field_validator
|
|
|
+from pydantic import BaseModel, Field, computed_field, field_validator
|
|
|
|
|
|
|
|
|
class PromptMessageRole(Enum):
|
|
@@ -67,7 +67,6 @@ class PromptMessageContent(BaseModel):
|
|
|
"""
|
|
|
|
|
|
type: PromptMessageContentType
|
|
|
- data: str
|
|
|
|
|
|
|
|
|
class TextPromptMessageContent(PromptMessageContent):
|
|
@@ -76,21 +75,35 @@ class TextPromptMessageContent(PromptMessageContent):
|
|
|
"""
|
|
|
|
|
|
type: PromptMessageContentType = PromptMessageContentType.TEXT
|
|
|
+ data: str
|
|
|
+
|
|
|
+
|
|
|
+class MultiModalPromptMessageContent(PromptMessageContent):
|
|
|
+ """
|
|
|
+ Model class for multi-modal prompt message content.
|
|
|
+ """
|
|
|
+
|
|
|
+ type: PromptMessageContentType
|
|
|
+ format: str = Field(..., description="the format of multi-modal file")
|
|
|
+ base64_data: str = Field("", description="the base64 data of multi-modal file")
|
|
|
+ url: str = Field("", description="the url of multi-modal file")
|
|
|
+ mime_type: str = Field(..., description="the mime type of multi-modal file")
|
|
|
|
|
|
+ @computed_field(return_type=str)
|
|
|
+ @property
|
|
|
+ def data(self):
|
|
|
+ return self.url or f"data:{self.mime_type};base64,{self.base64_data}"
|
|
|
|
|
|
-class VideoPromptMessageContent(PromptMessageContent):
|
|
|
+
|
|
|
+class VideoPromptMessageContent(MultiModalPromptMessageContent):
|
|
|
type: PromptMessageContentType = PromptMessageContentType.VIDEO
|
|
|
- data: str = Field(..., description="Base64 encoded video data")
|
|
|
- format: str = Field(..., description="Video format")
|
|
|
|
|
|
|
|
|
-class AudioPromptMessageContent(PromptMessageContent):
|
|
|
+class AudioPromptMessageContent(MultiModalPromptMessageContent):
|
|
|
type: PromptMessageContentType = PromptMessageContentType.AUDIO
|
|
|
- data: str = Field(..., description="Base64 encoded audio data")
|
|
|
- format: str = Field(..., description="Audio format")
|
|
|
|
|
|
|
|
|
-class ImagePromptMessageContent(PromptMessageContent):
|
|
|
+class ImagePromptMessageContent(MultiModalPromptMessageContent):
|
|
|
"""
|
|
|
Model class for image prompt message content.
|
|
|
"""
|
|
@@ -101,14 +114,10 @@ class ImagePromptMessageContent(PromptMessageContent):
|
|
|
|
|
|
type: PromptMessageContentType = PromptMessageContentType.IMAGE
|
|
|
detail: DETAIL = DETAIL.LOW
|
|
|
- format: str = Field("jpg", description="Image format")
|
|
|
|
|
|
|
|
|
-class DocumentPromptMessageContent(PromptMessageContent):
|
|
|
+class DocumentPromptMessageContent(MultiModalPromptMessageContent):
|
|
|
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
|
|
|
- encode_format: Literal["base64"]
|
|
|
- data: str
|
|
|
- format: str = Field(..., description="Document format")
|
|
|
|
|
|
|
|
|
class PromptMessage(ABC, BaseModel):
|