api_entities.py 2.3 KB

(line numbers 1–71)
  1. from typing import Literal, Optional
  2. from pydantic import BaseModel, Field, field_validator
  3. from core.model_runtime.utils.encoders import jsonable_encoder
  4. from core.tools.entities.common_entities import I18nObject
  5. from core.tools.entities.tool_entities import ToolProviderCredentials, ToolProviderType
  6. from core.tools.tool.tool import ToolParameter
  class UserTool(BaseModel):
      """A tool entry as listed under a ``UserToolProvider``."""

      author: str
      # Machine identifier of the tool (the display name lives in ``label``).
      name: str
      # Display name as an I18nObject.
      label: I18nObject
      description: I18nObject
      # ToolParameter definitions; None when not supplied.
      parameters: list[ToolParameter] | None = None
      # Optional list of label strings (distinct from the I18n ``label`` above).
      labels: list[str] | None = None
  # Type alias: one of "builtin", "api", "workflow", or None.
  UserToolProviderTypeLiteral = Literal["builtin", "api", "workflow"] | None
  class UserToolProvider(BaseModel):
      """A tool provider together with the tools it exposes.

      ``to_dict`` produces the plain-dict form of this model. It emits
      ``masked_credentials`` under the key ``"team_credentials"`` and does not
      include ``original_credentials`` at all.
      """

      id: str
      author: str
      # Machine identifier of the provider (the display name lives in ``label``).
      name: str
      description: I18nObject
      icon: str
      # Display name as an I18nObject.
      label: I18nObject
      type: ToolProviderType
      # Emitted by to_dict as "team_credentials".
      masked_credentials: dict | None = None
      # Presumably the credentials before masking; not part of to_dict output.
      original_credentials: dict | None = None
      is_team_authorization: bool = False
      allow_delete: bool = True
      tools: list[UserTool] = Field(default_factory=list)
      labels: list[str] | None = None

      @field_validator("tools", mode="before")
      @classmethod
      def convert_none_to_empty_list(cls, v):
          """Turn an explicit ``None`` for ``tools`` into ``[]`` before validation."""
          return [] if v is None else v

      def to_dict(self) -> dict:
          """Return this provider as a plain dict.

          Tools are run through ``jsonable_encoder``; in that encoded output,
          any parameter whose ``type`` equals the SYSTEM_FILES parameter type
          is rewritten to ``"files"``.
          """
          system_files_type = ToolParameter.ToolParameterType.SYSTEM_FILES.value

          # Temporary fix carried over from the original code: report
          # SYSTEM_FILES parameters with the type "files" in the output.
          encoded_tools = jsonable_encoder(self.tools)
          for encoded_tool in encoded_tools:
              # A missing/None/empty "parameters" entry yields no iterations.
              for param in encoded_tool.get("parameters") or []:
                  if param.get("type") == system_files_type:
                      param["type"] = "files"

          return {
              "id": self.id,
              "author": self.author,
              "name": self.name,
              "description": self.description.to_dict(),
              "icon": self.icon,
              "label": self.label.to_dict(),
              "type": self.type.value,
              "team_credentials": self.masked_credentials,
              "is_team_authorization": self.is_team_authorization,
              "allow_delete": self.allow_delete,
              "tools": encoded_tools,
              "labels": self.labels,
          }
  class UserToolProviderCredentials(BaseModel):
      """Mapping of string keys to ``ToolProviderCredentials`` entries."""

      # Keys are presumably credential field names — confirm against callers.
      credentials: dict[str, ToolProviderCredentials]