message_entities.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import enum
  2. from typing import Any, cast
  3. from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent,
  4. PromptMessage, SystemPromptMessage, TextPromptMessageContent,
  5. ToolPromptMessage, UserPromptMessage)
  6. from langchain.schema import AIMessage, BaseMessage, FunctionMessage, HumanMessage, SystemMessage
  7. from pydantic import BaseModel
  8. class PromptMessageFileType(enum.Enum):
  9. IMAGE = 'image'
  10. @staticmethod
  11. def value_of(value):
  12. for member in PromptMessageFileType:
  13. if member.value == value:
  14. return member
  15. raise ValueError(f"No matching enum found for value '{value}'")
  16. class PromptMessageFile(BaseModel):
  17. type: PromptMessageFileType
  18. data: Any
  19. class ImagePromptMessageFile(PromptMessageFile):
  20. class DETAIL(enum.Enum):
  21. LOW = 'low'
  22. HIGH = 'high'
  23. type: PromptMessageFileType = PromptMessageFileType.IMAGE
  24. detail: DETAIL = DETAIL.LOW
  25. class LCHumanMessageWithFiles(HumanMessage):
  26. # content: Union[str, List[Union[str, Dict]]]
  27. content: str
  28. files: list[PromptMessageFile]
  29. def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMessage]:
  30. prompt_messages = []
  31. for message in messages:
  32. if isinstance(message, HumanMessage):
  33. if isinstance(message, LCHumanMessageWithFiles):
  34. file_prompt_message_contents = []
  35. for file in message.files:
  36. if file.type == PromptMessageFileType.IMAGE:
  37. file = cast(ImagePromptMessageFile, file)
  38. file_prompt_message_contents.append(ImagePromptMessageContent(
  39. data=file.data,
  40. detail=ImagePromptMessageContent.DETAIL.HIGH
  41. if file.detail.value == "high" else ImagePromptMessageContent.DETAIL.LOW
  42. ))
  43. prompt_message_contents = [TextPromptMessageContent(data=message.content)]
  44. prompt_message_contents.extend(file_prompt_message_contents)
  45. prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
  46. else:
  47. prompt_messages.append(UserPromptMessage(content=message.content))
  48. elif isinstance(message, AIMessage):
  49. message_kwargs = {
  50. 'content': message.content
  51. }
  52. if 'function_call' in message.additional_kwargs:
  53. message_kwargs['tool_calls'] = [
  54. AssistantPromptMessage.ToolCall(
  55. id=message.additional_kwargs['function_call']['id'],
  56. type='function',
  57. function=AssistantPromptMessage.ToolCall.ToolCallFunction(
  58. name=message.additional_kwargs['function_call']['name'],
  59. arguments=message.additional_kwargs['function_call']['arguments']
  60. )
  61. )
  62. ]
  63. prompt_messages.append(AssistantPromptMessage(**message_kwargs))
  64. elif isinstance(message, SystemMessage):
  65. prompt_messages.append(SystemPromptMessage(content=message.content))
  66. elif isinstance(message, FunctionMessage):
  67. prompt_messages.append(ToolPromptMessage(content=message.content, tool_call_id=message.name))
  68. return prompt_messages
  69. def prompt_messages_to_lc_messages(prompt_messages: list[PromptMessage]) -> list[BaseMessage]:
  70. messages = []
  71. for prompt_message in prompt_messages:
  72. if isinstance(prompt_message, UserPromptMessage):
  73. if isinstance(prompt_message.content, str):
  74. messages.append(HumanMessage(content=prompt_message.content))
  75. else:
  76. message_contents = []
  77. for content in prompt_message.content:
  78. if isinstance(content, TextPromptMessageContent):
  79. message_contents.append(content.data)
  80. elif isinstance(content, ImagePromptMessageContent):
  81. message_contents.append({
  82. 'type': 'image',
  83. 'data': content.data,
  84. 'detail': content.detail.value
  85. })
  86. messages.append(HumanMessage(content=message_contents))
  87. elif isinstance(prompt_message, AssistantPromptMessage):
  88. message_kwargs = {
  89. 'content': prompt_message.content
  90. }
  91. if prompt_message.tool_calls:
  92. message_kwargs['additional_kwargs'] = {
  93. 'function_call': {
  94. 'id': prompt_message.tool_calls[0].id,
  95. 'name': prompt_message.tool_calls[0].function.name,
  96. 'arguments': prompt_message.tool_calls[0].function.arguments
  97. }
  98. }
  99. messages.append(AIMessage(**message_kwargs))
  100. elif isinstance(prompt_message, SystemPromptMessage):
  101. messages.append(SystemMessage(content=prompt_message.content))
  102. elif isinstance(prompt_message, ToolPromptMessage):
  103. messages.append(FunctionMessage(name=prompt_message.tool_call_id, content=prompt_message.content))
  104. return messages