model_provider_service.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481
  1. import logging
  2. from typing import Optional
  3. from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, ProviderModelWithStatusEntity
  4. from core.model_runtime.entities.model_entities import ModelType, ParameterRule
  5. from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
  6. from core.provider_manager import ProviderManager
  7. from models.provider import ProviderType
  8. from services.entities.model_provider_entities import (
  9. CustomConfigurationResponse,
  10. CustomConfigurationStatus,
  11. DefaultModelResponse,
  12. ModelWithProviderEntityResponse,
  13. ProviderResponse,
  14. ProviderWithModelsResponse,
  15. SimpleProviderEntityResponse,
  16. SystemConfigurationResponse,
  17. )
  18. logger = logging.getLogger(__name__)
  19. class ModelProviderService:
  20. """
  21. Model Provider Service
  22. """
  def __init__(self) -> None:
      """
      Build the service around a freshly created ProviderManager.

      The manager is stored on ``self.provider_manager`` and is the object
      the other methods of this service query for provider configurations.
      """
      manager = ProviderManager()
      self.provider_manager = manager
  def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list[ProviderResponse]:
      """
      List the providers configured for a workspace.

      :param tenant_id: workspace id
      :param model_type: optional model type; when given, providers whose
          supported model types do not include it are left out
      :return: one ProviderResponse per remaining provider configuration
      """
      configurations = self.provider_manager.get_configurations(tenant_id)

      responses: list[ProviderResponse] = []
      for configuration in configurations.values():
          entity = configuration.provider

          # Drop providers that do not support the requested model type.
          if model_type and ModelType.value_of(model_type) not in entity.supported_model_types:
              continue

          # The custom configuration is reported as ACTIVE only when the
          # provider configuration says it is available.
          if configuration.is_custom_configuration_available():
              custom_status = CustomConfigurationStatus.ACTIVE
          else:
              custom_status = CustomConfigurationStatus.NO_CONFIGURE

          system = configuration.system_configuration
          responses.append(
              ProviderResponse(
                  tenant_id=tenant_id,
                  provider=entity.provider,
                  label=entity.label,
                  description=entity.description,
                  icon_small=entity.icon_small,
                  icon_large=entity.icon_large,
                  background=entity.background,
                  help=entity.help,
                  supported_model_types=entity.supported_model_types,
                  configurate_methods=entity.configurate_methods,
                  provider_credential_schema=entity.provider_credential_schema,
                  model_credential_schema=entity.model_credential_schema,
                  preferred_provider_type=configuration.preferred_provider_type,
                  custom_configuration=CustomConfigurationResponse(status=custom_status),
                  system_configuration=SystemConfigurationResponse(
                      enabled=system.enabled,
                      current_quota_type=system.current_quota_type,
                      quota_configurations=system.quota_configurations,
                  ),
              )
          )

      return responses
  def get_models_by_provider(self, tenant_id: str, provider: str) -> list[ModelWithProviderEntityResponse]:
      """
      List the models of a single provider.

      Used by the model provider page, which queries one provider at a time.

      :param tenant_id: workspace id
      :param provider: provider name
      :return: one ModelWithProviderEntityResponse per model returned for the provider
      """
      configurations = self.provider_manager.get_configurations(tenant_id)

      # Wrap every model reported for this provider in a response object.
      responses = []
      for provider_model in configurations.get_models(provider=provider):
          responses.append(ModelWithProviderEntityResponse(tenant_id=tenant_id, model=provider_model))
      return responses
  def get_provider_credentials(self, tenant_id: str, provider: str) -> Optional[dict]:
      """
      Fetch the custom credentials stored for a provider.

      :param tenant_id: workspace id
      :param provider: provider name
      :return: whatever ``get_custom_credentials(obfuscated=True)`` yields for the provider
      :raises ValueError: if the workspace has no configuration for the provider
      """
      configuration = self.provider_manager.get_configurations(tenant_id).get(provider)
      if not configuration:
          raise ValueError(f"Provider {provider} does not exist.")

      # Always ask for the obfuscated form of the credentials.
      return configuration.get_custom_credentials(obfuscated=True)
  def provider_credentials_validate(self, tenant_id: str, provider: str, credentials: dict) -> None:
      """
      Validate credentials against a provider.

      :param tenant_id: workspace id
      :param provider: provider name
      :param credentials: provider credentials to validate
      :raises ValueError: if the workspace has no configuration for the provider
      """
      configuration = self.provider_manager.get_configurations(tenant_id).get(provider)
      if not configuration:
          raise ValueError(f"Provider {provider} does not exist.")

      # The actual validation is delegated to the provider configuration.
      configuration.custom_credentials_validate(credentials)
  def save_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None:
      """
      Store custom credentials for a provider.

      :param tenant_id: workspace id
      :param provider: provider name
      :param credentials: provider credentials
      :raises ValueError: if the workspace has no configuration for the provider
      """
      configuration = self.provider_manager.get_configurations(tenant_id).get(provider)
      if not configuration:
          raise ValueError(f"Provider {provider} does not exist.")

      # One call covers both first-time creation and later updates.
      configuration.add_or_update_custom_credentials(credentials)
  def remove_provider_credentials(self, tenant_id: str, provider: str) -> None:
      """
      Delete the custom credentials of a provider.

      :param tenant_id: workspace id
      :param provider: provider name
      :raises ValueError: if the workspace has no configuration for the provider
      """
      configuration = self.provider_manager.get_configurations(tenant_id).get(provider)
      if not configuration:
          raise ValueError(f"Provider {provider} does not exist.")

      configuration.delete_custom_credentials()
  def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> Optional[dict]:
      """
      Fetch the custom credentials stored for one model of a provider.

      :param tenant_id: workspace id
      :param provider: provider name
      :param model_type: model type
      :param model: model name
      :return: whatever ``get_custom_model_credentials(..., obfuscated=True)`` yields
      :raises ValueError: if the workspace has no configuration for the provider
      """
      configuration = self.provider_manager.get_configurations(tenant_id).get(provider)
      if not configuration:
          raise ValueError(f"Provider {provider} does not exist.")

      # Convert the raw model type string only after the provider is known to exist.
      requested_type = ModelType.value_of(model_type)
      return configuration.get_custom_model_credentials(model_type=requested_type, model=model, obfuscated=True)
  def model_credentials_validate(
      self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict
  ) -> None:
      """
      Validate credentials against one model of a provider.

      :param tenant_id: workspace id
      :param provider: provider name
      :param model_type: model type
      :param model: model name
      :param credentials: model credentials to validate
      :raises ValueError: if the workspace has no configuration for the provider
      """
      configuration = self.provider_manager.get_configurations(tenant_id).get(provider)
      if not configuration:
          raise ValueError(f"Provider {provider} does not exist.")

      # The actual validation is delegated to the provider configuration.
      requested_type = ModelType.value_of(model_type)
      configuration.custom_model_credentials_validate(
          model_type=requested_type, model=model, credentials=credentials
      )
  def save_model_credentials(
      self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict
  ) -> None:
      """
      Store custom credentials for one model of a provider.

      :param tenant_id: workspace id
      :param provider: provider name
      :param model_type: model type
      :param model: model name
      :param credentials: model credentials
      :raises ValueError: if the workspace has no configuration for the provider
      """
      configuration = self.provider_manager.get_configurations(tenant_id).get(provider)
      if not configuration:
          raise ValueError(f"Provider {provider} does not exist.")

      # One call covers both first-time creation and later updates.
      requested_type = ModelType.value_of(model_type)
      configuration.add_or_update_custom_model_credentials(
          model_type=requested_type, model=model, credentials=credentials
      )
  def remove_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> None:
      """
      Delete the custom credentials of one model of a provider.

      :param tenant_id: workspace id
      :param provider: provider name
      :param model_type: model type
      :param model: model name
      :raises ValueError: if the workspace has no configuration for the provider
      """
      configuration = self.provider_manager.get_configurations(tenant_id).get(provider)
      if not configuration:
          raise ValueError(f"Provider {provider} does not exist.")

      requested_type = ModelType.value_of(model_type)
      configuration.delete_custom_model_credentials(model_type=requested_type, model=model)
  def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]:
      """
      List the usable models of a given type, grouped by provider.

      Deprecated models and models whose status is not ``ModelStatus.ACTIVE``
      are dropped; providers left without any model are omitted.

      :param tenant_id: workspace id
      :param model_type: model type
      :return: one ProviderWithModelsResponse per provider that still has models
      """
      configurations = self.provider_manager.get_configurations(tenant_id)
      candidates = configurations.get_models(model_type=ModelType.value_of(model_type))

      # Bucket models under their provider name. The bucket is registered
      # before filtering so that providers keep the order in which they were
      # first encountered, even if their first model is filtered out.
      grouped: dict[str, list[ModelWithProviderEntity]] = {}
      for candidate in candidates:
          bucket = grouped.setdefault(candidate.provider.provider, [])
          if candidate.deprecated or candidate.status != ModelStatus.ACTIVE:
              continue
          bucket.append(candidate)

      result: list[ProviderWithModelsResponse] = []
      for provider_name, usable_models in grouped.items():
          if not usable_models:
              continue

          # Provider-level display fields are taken from the first model's provider.
          head = usable_models[0]
          entries = [
              ProviderModelWithStatusEntity(
                  model=item.model,
                  label=item.label,
                  model_type=item.model_type,
                  features=item.features,
                  fetch_from=item.fetch_from,
                  model_properties=item.model_properties,
                  status=item.status,
                  load_balancing_enabled=item.load_balancing_enabled,
              )
              for item in usable_models
          ]
          result.append(
              ProviderWithModelsResponse(
                  tenant_id=tenant_id,
                  provider=provider_name,
                  label=head.provider.label,
                  icon_small=head.provider.icon_small,
                  icon_large=head.provider.icon_large,
                  status=CustomConfigurationStatus.ACTIVE,
                  models=entries,
              )
          )

      return result
  def get_model_parameter_rules(self, tenant_id: str, provider: str, model: str) -> list[ParameterRule]:
      """
      Return the parameter rules of a model.

      The lookup is always performed with ``ModelType.LLM``.

      :param tenant_id: workspace id
      :param provider: provider name
      :param model: model name
      :return: the schema's parameter rules, or an empty list when there are
          no current credentials or no schema for the model
      :raises ValueError: if the workspace has no configuration for the provider
      """
      configuration = self.provider_manager.get_configurations(tenant_id).get(provider)
      if not configuration:
          raise ValueError(f"Provider {provider} does not exist.")

      # Without credentials the schema is not looked up at all.
      llm_credentials = configuration.get_current_credentials(model_type=ModelType.LLM, model=model)
      if not llm_credentials:
          return []

      schema = configuration.get_model_schema(model_type=ModelType.LLM, model=model, credentials=llm_credentials)
      if not schema:
          return []
      return schema.parameter_rules
  def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]:
      """
      Return the default model configured for a model type.

      This is a best-effort lookup: any exception raised while fetching the
      default model or building the response is logged at debug level and
      reported to the caller as ``None``. Converting ``model_type`` happens
      outside that guard, so an error raised by ``ModelType.value_of``
      propagates.

      :param tenant_id: workspace id
      :param model_type: model type
      :return: the default model description, or None when there is no
          default model or the lookup failed
      """
      model_type_enum = ModelType.value_of(model_type)

      try:
          result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum)
          if not result:
              return None

          provider_entity = result.provider
          return DefaultModelResponse(
              model=result.model,
              model_type=result.model_type,
              provider=SimpleProviderEntityResponse(
                  tenant_id=tenant_id,
                  provider=provider_entity.provider,
                  label=provider_entity.label,
                  icon_small=provider_entity.icon_small,
                  icon_large=provider_entity.icon_large,
                  supported_model_types=provider_entity.supported_model_types,
              ),
          )
      except Exception as e:
          # Pass the exception as a lazy %-argument so the message is only
          # formatted when debug logging is actually enabled.
          logger.debug("get_default_model_of_model_type error: %s", e)
          return None
  def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
      """
      Set the default model for a model type.

      :param tenant_id: workspace id
      :param model_type: model type
      :param provider: provider name
      :param model: model name
      """
      requested_type = ModelType.value_of(model_type)
      manager = self.provider_manager
      manager.update_default_model_record(
          tenant_id=tenant_id,
          model_type=requested_type,
          provider=provider,
          model=model,
      )
  def get_model_provider_icon(
      self, tenant_id: str, provider: str, icon_type: str, lang: str
  ) -> tuple[Optional[bytes], Optional[str]]:
      """
      Fetch an icon of a model provider.

      :param tenant_id: workspace id
      :param provider: provider name
      :param icon_type: icon type (icon_small or icon_large)
      :param lang: language (zh_Hans or en_US)
      :return: a ``(byte_data, mime_type)`` pair as produced by the provider factory
      """
      factory = ModelProviderFactory(tenant_id)
      icon = factory.get_provider_icon(provider, icon_type, lang)

      # Unpack and re-pack so the caller always receives a two-item tuple.
      icon_bytes, icon_mime = icon
      return icon_bytes, icon_mime
  def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str) -> None:
      """
      Change the preferred provider type of a provider.

      :param tenant_id: workspace id
      :param provider: provider name
      :param preferred_provider_type: preferred provider type
      :raises ValueError: if the workspace has no configuration for the provider
      """
      configurations = self.provider_manager.get_configurations(tenant_id)

      # The provider type string is converted before the provider lookup, so
      # a conversion error surfaces ahead of the "does not exist" check.
      target_type = ProviderType.value_of(preferred_provider_type)

      configuration = configurations.get(provider)
      if not configuration:
          raise ValueError(f"Provider {provider} does not exist.")

      configuration.switch_preferred_provider_type(target_type)
  def enable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
      """
      Enable a model of a provider.

      :param tenant_id: workspace id
      :param provider: provider name
      :param model: model name
      :param model_type: model type
      :raises ValueError: if the workspace has no configuration for the provider
      """
      configuration = self.provider_manager.get_configurations(tenant_id).get(provider)
      if not configuration:
          raise ValueError(f"Provider {provider} does not exist.")

      requested_type = ModelType.value_of(model_type)
      configuration.enable_model(model=model, model_type=requested_type)
  def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
      """
      Disable a model of a provider.

      :param tenant_id: workspace id
      :param provider: provider name
      :param model: model name
      :param model_type: model type
      :raises ValueError: if the workspace has no configuration for the provider
      """
      configuration = self.provider_manager.get_configurations(tenant_id).get(provider)
      if not configuration:
          raise ValueError(f"Provider {provider} does not exist.")

      # Disable the model on the provider configuration.
      requested_type = ModelType.value_of(model_type)
      configuration.disable_model(model=model, model_type=requested_type)