# request.py

from typing import Any, Literal, Optional

from pydantic import BaseModel, ConfigDict, Field, field_validator

from core.entities.provider_entities import BasicProviderConfig
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageRole,
    PromptMessageTool,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelType
from core.workflow.nodes.parameter_extractor.entities import (
    ModelConfig as ParameterExtractorModelConfig,
)
from core.workflow.nodes.parameter_extractor.entities import (
    ParameterConfig,
)
from core.workflow.nodes.question_classifier.entities import (
    ClassConfig,
)
from core.workflow.nodes.question_classifier.entities import (
    ModelConfig as QuestionClassifierModelConfig,
)
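
# The entities below model requests for tool, model, workflow-node, and app
# invocations, plus auxiliary operations (encryption, summarization, file upload).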


class RequestInvokeTool(BaseModel):
    """
    Request to invoke a tool
    """

    tool_type: Literal["builtin", "workflow", "api"]
    provider: str
    tool: str
    tool_parameters: dict


class BaseRequestInvokeModel(BaseModel):
    provider: str
    model: str
    model_type: ModelType
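
    # pydantic v2 warns about field names that start with the protected "model_"
    # prefix (here `model_type`); clearing protected_namespaces disables that check.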
    model_config = ConfigDict(protected_namespaces=())


class RequestInvokeLLM(BaseRequestInvokeModel):
    """
    Request to invoke LLM
    """

    model_type: ModelType = ModelType.LLM
    mode: str
    completion_params: dict[str, Any] = Field(default_factory=dict)
    prompt_messages: list[PromptMessage] = Field(default_factory=list)
    tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
    stop: Optional[list[str]] = Field(default_factory=list)
    stream: Optional[bool] = False

    model_config = ConfigDict(protected_namespaces=())
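
    # Prompt messages arrive over the wire as plain dicts; the validator below
    # rebuilds the concrete PromptMessage subclass from each dict's "role" field.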
    @field_validator("prompt_messages", mode="before")
    @classmethod
    def convert_prompt_messages(cls, v):
        if not isinstance(v, list):
            raise ValueError("prompt_messages must be a list")

        for i in range(len(v)):
            if v[i]["role"] == PromptMessageRole.USER.value:
                v[i] = UserPromptMessage(**v[i])
            elif v[i]["role"] == PromptMessageRole.ASSISTANT.value:
                v[i] = AssistantPromptMessage(**v[i])
            elif v[i]["role"] == PromptMessageRole.SYSTEM.value:
                v[i] = SystemPromptMessage(**v[i])
            elif v[i]["role"] == PromptMessageRole.TOOL.value:
                v[i] = ToolPromptMessage(**v[i])
            else:
                v[i] = PromptMessage(**v[i])

        return v


class RequestInvokeTextEmbedding(BaseRequestInvokeModel):
    """
    Request to invoke text embedding
    """

    model_type: ModelType = ModelType.TEXT_EMBEDDING
    texts: list[str]


class RequestInvokeRerank(BaseRequestInvokeModel):
    """
    Request to invoke rerank
    """

    model_type: ModelType = ModelType.RERANK
    query: str
    docs: list[str]
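    # `score_threshold` drops documents scored below it; `top_n` caps how many
    # reranked documents are returned.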
    score_threshold: float
    top_n: int


class RequestInvokeTTS(BaseRequestInvokeModel):
    """
    Request to invoke TTS
    """

    model_type: ModelType = ModelType.TTS
    content_text: str
    voice: str


class RequestInvokeSpeech2Text(BaseRequestInvokeModel):
    """
    Request to invoke speech2text
    """

    model_type: ModelType = ModelType.SPEECH2TEXT
    file: bytes
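
    # The audio payload is transported as a hex-encoded string and decoded to
    # bytes by the validator below.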
    @field_validator("file", mode="before")
    @classmethod
    def convert_file(cls, v):
        # hex string to bytes
        if isinstance(v, str):
            return bytes.fromhex(v)
        else:
            raise ValueError("file must be a hex string")


class RequestInvokeModeration(BaseRequestInvokeModel):
    """
    Request to invoke moderation
    """

    model_type: ModelType = ModelType.MODERATION
    text: str


class RequestInvokeParameterExtractorNode(BaseModel):
    """
    Request to invoke parameter extractor node
    """

    parameters: list[ParameterConfig]
    model: ParameterExtractorModelConfig
    instruction: str
    query: str


class RequestInvokeQuestionClassifierNode(BaseModel):
    """
    Request to invoke question classifier node
    """

    query: str
    model: QuestionClassifierModelConfig
    classes: list[ClassConfig]
    instruction: str


class RequestInvokeApp(BaseModel):
    """
    Request to invoke app
    """

    app_id: str
    inputs: dict[str, Any]
    query: Optional[str] = None
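    # "blocking" returns the complete result in one response; "streaming" streams
    # it back incrementally.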
    response_mode: Literal["blocking", "streaming"]
    conversation_id: Optional[str] = None
    user: Optional[str] = None
    files: list[dict] = Field(default_factory=list)


class RequestInvokeEncrypt(BaseModel):
    """
    Request to encrypt, decrypt, or clear stored data
    """

    opt: Literal["encrypt", "decrypt", "clear"]
    namespace: Literal["endpoint"]
    identity: str
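    # `data` carries the key/value settings to operate on; `config` (presumably the
    # schema of those settings) marks which entries are secret.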
    data: dict = Field(default_factory=dict)
    config: list[BasicProviderConfig] = Field(default_factory=list)


class RequestInvokeSummary(BaseModel):
    """
    Request to summarize text
    """

    text: str
    instruction: str


class RequestRequestUploadFile(BaseModel):
    """
    Request to upload file
    """

    filename: str
    mimetype: str
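

# Illustrative example (not part of the module): a raw payload such as
#     {
#         "provider": "openai",
#         "model": "gpt-4o",          # placeholder model name
#         "mode": "chat",
#         "prompt_messages": [{"role": "user", "content": "Hello"}],
#     }
# validates into RequestInvokeLLM, with convert_prompt_messages turning the
# dict in prompt_messages into a UserPromptMessage instance and model_type
# falling back to its ModelType.LLM default.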