# tool_entities.py
import base64
import enum
from collections.abc import Mapping
from enum import Enum
from typing import Any, Optional, Union

from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator

from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.parameters import (
    PluginParameter,
    PluginParameterOption,
    PluginParameterType,
    as_normal_type,
    cast_parameter_value,
    init_frontend_parameter,
)
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY
  18. class ToolLabelEnum(Enum):
  19. SEARCH = "search"
  20. IMAGE = "image"
  21. VIDEOS = "videos"
  22. WEATHER = "weather"
  23. FINANCE = "finance"
  24. DESIGN = "design"
  25. TRAVEL = "travel"
  26. SOCIAL = "social"
  27. NEWS = "news"
  28. MEDICAL = "medical"
  29. PRODUCTIVITY = "productivity"
  30. EDUCATION = "education"
  31. BUSINESS = "business"
  32. ENTERTAINMENT = "entertainment"
  33. UTILITIES = "utilities"
  34. OTHER = "other"
  35. class ToolProviderType(enum.StrEnum):
  36. """
  37. Enum class for tool provider
  38. """
  39. PLUGIN = "plugin"
  40. BUILT_IN = "builtin"
  41. WORKFLOW = "workflow"
  42. API = "api"
  43. APP = "app"
  44. DATASET_RETRIEVAL = "dataset-retrieval"
  45. @classmethod
  46. def value_of(cls, value: str) -> "ToolProviderType":
  47. """
  48. Get value of given mode.
  49. :param value: mode value
  50. :return: mode
  51. """
  52. for mode in cls:
  53. if mode.value == value:
  54. return mode
  55. raise ValueError(f"invalid mode value {value}")
  56. class ApiProviderSchemaType(Enum):
  57. """
  58. Enum class for api provider schema type.
  59. """
  60. OPENAPI = "openapi"
  61. SWAGGER = "swagger"
  62. OPENAI_PLUGIN = "openai_plugin"
  63. OPENAI_ACTIONS = "openai_actions"
  64. @classmethod
  65. def value_of(cls, value: str) -> "ApiProviderSchemaType":
  66. """
  67. Get value of given mode.
  68. :param value: mode value
  69. :return: mode
  70. """
  71. for mode in cls:
  72. if mode.value == value:
  73. return mode
  74. raise ValueError(f"invalid mode value {value}")
  75. class ApiProviderAuthType(Enum):
  76. """
  77. Enum class for api provider auth type.
  78. """
  79. NONE = "none"
  80. API_KEY = "api_key"
  81. @classmethod
  82. def value_of(cls, value: str) -> "ApiProviderAuthType":
  83. """
  84. Get value of given mode.
  85. :param value: mode value
  86. :return: mode
  87. """
  88. for mode in cls:
  89. if mode.value == value:
  90. return mode
  91. raise ValueError(f"invalid mode value {value}")
  92. class ToolInvokeMessage(BaseModel):
  93. class TextMessage(BaseModel):
  94. text: str
  95. class JsonMessage(BaseModel):
  96. json_object: dict
  97. class BlobMessage(BaseModel):
  98. blob: bytes
  99. class FileMessage(BaseModel):
  100. pass
  101. class VariableMessage(BaseModel):
  102. variable_name: str = Field(..., description="The name of the variable")
  103. variable_value: Any = Field(..., description="The value of the variable")
  104. stream: bool = Field(default=False, description="Whether the variable is streamed")
  105. @model_validator(mode="before")
  106. @classmethod
  107. def transform_variable_value(cls, values) -> Any:
  108. """
  109. Only basic types and lists are allowed.
  110. """
  111. value = values.get("variable_value")
  112. if not isinstance(value, dict | list | str | int | float | bool):
  113. raise ValueError("Only basic types and lists are allowed.")
  114. # if stream is true, the value must be a string
  115. if values.get("stream"):
  116. if not isinstance(value, str):
  117. raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
  118. return values
  119. @field_validator("variable_name", mode="before")
  120. @classmethod
  121. def transform_variable_name(cls, value: str) -> str:
  122. """
  123. The variable name must be a string.
  124. """
  125. if value in {"json", "text", "files"}:
  126. raise ValueError(f"The variable name '{value}' is reserved.")
  127. return value
  128. class LogMessage(BaseModel):
  129. class LogStatus(Enum):
  130. START = "start"
  131. ERROR = "error"
  132. SUCCESS = "success"
  133. id: str
  134. label: str = Field(..., description="The label of the log")
  135. parent_id: Optional[str] = Field(default=None, description="Leave empty for root log")
  136. error: Optional[str] = Field(default=None, description="The error message")
  137. status: LogStatus = Field(..., description="The status of the log")
  138. data: Mapping[str, Any] = Field(..., description="Detailed log data")
  139. metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log")
  140. class MessageType(Enum):
  141. TEXT = "text"
  142. IMAGE = "image"
  143. LINK = "link"
  144. BLOB = "blob"
  145. JSON = "json"
  146. IMAGE_LINK = "image_link"
  147. BINARY_LINK = "binary_link"
  148. VARIABLE = "variable"
  149. FILE = "file"
  150. LOG = "log"
  151. type: MessageType = MessageType.TEXT
  152. """
  153. plain text, image url or link url
  154. """
  155. message: JsonMessage | TextMessage | BlobMessage | LogMessage | FileMessage | None | VariableMessage
  156. meta: dict[str, Any] | None = None
  157. @field_validator("message", mode="before")
  158. @classmethod
  159. def decode_blob_message(cls, v):
  160. if isinstance(v, dict) and "blob" in v:
  161. try:
  162. v["blob"] = base64.b64decode(v["blob"])
  163. except Exception:
  164. pass
  165. return v
  166. @field_serializer("message")
  167. def serialize_message(self, v):
  168. if isinstance(v, self.BlobMessage):
  169. return {"blob": base64.b64encode(v.blob).decode("utf-8")}
  170. return v
  171. class ToolInvokeMessageBinary(BaseModel):
  172. mimetype: str = Field(..., description="The mimetype of the binary")
  173. url: str = Field(..., description="The url of the binary")
  174. file_var: Optional[dict[str, Any]] = None
  175. class ToolParameter(PluginParameter):
  176. """
  177. Overrides type
  178. """
  179. class ToolParameterType(enum.StrEnum):
  180. """
  181. removes TOOLS_SELECTOR from PluginParameterType
  182. """
  183. STRING = PluginParameterType.STRING.value
  184. NUMBER = PluginParameterType.NUMBER.value
  185. BOOLEAN = PluginParameterType.BOOLEAN.value
  186. SELECT = PluginParameterType.SELECT.value
  187. SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
  188. FILE = PluginParameterType.FILE.value
  189. FILES = PluginParameterType.FILES.value
  190. APP_SELECTOR = PluginParameterType.APP_SELECTOR.value
  191. MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
  192. # deprecated, should not use.
  193. SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
  194. def as_normal_type(self):
  195. return as_normal_type(self)
  196. def cast_value(self, value: Any):
  197. return cast_parameter_value(self, value)
  198. class ToolParameterForm(Enum):
  199. SCHEMA = "schema" # should be set while adding tool
  200. FORM = "form" # should be set before invoking tool
  201. LLM = "llm" # will be set by LLM
  202. type: ToolParameterType = Field(..., description="The type of the parameter")
  203. human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
  204. form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
  205. llm_description: Optional[str] = None
  206. @classmethod
  207. def get_simple_instance(
  208. cls,
  209. name: str,
  210. llm_description: str,
  211. typ: ToolParameterType,
  212. required: bool,
  213. options: Optional[list[str]] = None,
  214. ) -> "ToolParameter":
  215. """
  216. get a simple tool parameter
  217. :param name: the name of the parameter
  218. :param llm_description: the description presented to the LLM
  219. :param type: the type of the parameter
  220. :param required: if the parameter is required
  221. :param options: the options of the parameter
  222. """
  223. # convert options to ToolParameterOption
  224. # FIXME fix the type error
  225. if options:
  226. option_objs = [
  227. PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
  228. for option in options
  229. ]
  230. else:
  231. option_objs = []
  232. return cls(
  233. name=name,
  234. label=I18nObject(en_US="", zh_Hans=""),
  235. placeholder=None,
  236. human_description=I18nObject(en_US="", zh_Hans=""),
  237. type=typ,
  238. form=cls.ToolParameterForm.LLM,
  239. llm_description=llm_description,
  240. required=required,
  241. options=option_objs,
  242. )
  243. def init_frontend_parameter(self, value: Any):
  244. return init_frontend_parameter(self, self.type, value)
  245. class ToolProviderIdentity(BaseModel):
  246. author: str = Field(..., description="The author of the tool")
  247. name: str = Field(..., description="The name of the tool")
  248. description: I18nObject = Field(..., description="The description of the tool")
  249. icon: str = Field(..., description="The icon of the tool")
  250. label: I18nObject = Field(..., description="The label of the tool")
  251. tags: Optional[list[ToolLabelEnum]] = Field(
  252. default=[],
  253. description="The tags of the tool",
  254. )
  255. class ToolIdentity(BaseModel):
  256. author: str = Field(..., description="The author of the tool")
  257. name: str = Field(..., description="The name of the tool")
  258. label: I18nObject = Field(..., description="The label of the tool")
  259. provider: str = Field(..., description="The provider of the tool")
  260. icon: Optional[str] = None
  261. class ToolDescription(BaseModel):
  262. human: I18nObject = Field(..., description="The description presented to the user")
  263. llm: str = Field(..., description="The description presented to the LLM")
  264. class ToolEntity(BaseModel):
  265. identity: ToolIdentity
  266. parameters: list[ToolParameter] = Field(default_factory=list)
  267. description: Optional[ToolDescription] = None
  268. output_schema: Optional[dict] = None
  269. has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters")
  270. # pydantic configs
  271. model_config = ConfigDict(protected_namespaces=())
  272. @field_validator("parameters", mode="before")
  273. @classmethod
  274. def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
  275. return v or []
  276. class ToolProviderEntity(BaseModel):
  277. identity: ToolProviderIdentity
  278. plugin_id: Optional[str] = None
  279. credentials_schema: list[ProviderConfig] = Field(default_factory=list)
  280. class ToolProviderEntityWithPlugin(ToolProviderEntity):
  281. tools: list[ToolEntity] = Field(default_factory=list)
  282. class WorkflowToolParameterConfiguration(BaseModel):
  283. """
  284. Workflow tool configuration
  285. """
  286. name: str = Field(..., description="The name of the parameter")
  287. description: str = Field(..., description="The description of the parameter")
  288. form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter")
  289. class ToolInvokeMeta(BaseModel):
  290. """
  291. Tool invoke meta
  292. """
  293. time_cost: float = Field(..., description="The time cost of the tool invoke")
  294. error: Optional[str] = None
  295. tool_config: Optional[dict] = None
  296. @classmethod
  297. def empty(cls) -> "ToolInvokeMeta":
  298. """
  299. Get an empty instance of ToolInvokeMeta
  300. """
  301. return cls(time_cost=0.0, error=None, tool_config={})
  302. @classmethod
  303. def error_instance(cls, error: str) -> "ToolInvokeMeta":
  304. """
  305. Get an instance of ToolInvokeMeta with error
  306. """
  307. return cls(time_cost=0.0, error=error, tool_config={})
  308. def to_dict(self) -> dict:
  309. return {
  310. "time_cost": self.time_cost,
  311. "error": self.error,
  312. "tool_config": self.tool_config,
  313. }
  314. class ToolLabel(BaseModel):
  315. """
  316. Tool label
  317. """
  318. name: str = Field(..., description="The name of the tool")
  319. label: I18nObject = Field(..., description="The label of the tool")
  320. icon: str = Field(..., description="The icon of the tool")
  321. class ToolInvokeFrom(Enum):
  322. """
  323. Enum class for tool invoke
  324. """
  325. WORKFLOW = "workflow"
  326. AGENT = "agent"
  327. PLUGIN = "plugin"
  328. class ToolSelector(BaseModel):
  329. dify_model_identity: str = TOOL_SELECTOR_MODEL_IDENTITY
  330. class Parameter(BaseModel):
  331. name: str = Field(..., description="The name of the parameter")
  332. type: ToolParameter.ToolParameterType = Field(..., description="The type of the parameter")
  333. required: bool = Field(..., description="Whether the parameter is required")
  334. description: str = Field(..., description="The description of the parameter")
  335. default: Optional[Union[int, float, str]] = None
  336. options: Optional[list[PluginParameterOption]] = None
  337. provider_id: str = Field(..., description="The id of the provider")
  338. tool_name: str = Field(..., description="The name of the tool")
  339. tool_description: str = Field(..., description="The description of the tool")
  340. tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")
  341. tool_parameters: Mapping[str, Parameter] = Field(..., description="Parameters, type llm")
  342. def to_plugin_parameter(self) -> dict[str, Any]:
  343. return self.model_dump()