provider_configuration.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084
  1. import datetime
  2. import json
  3. import logging
  4. from collections import defaultdict
  5. from collections.abc import Iterator, Sequence
  6. from json import JSONDecodeError
  7. from typing import Optional
  8. from pydantic import BaseModel, ConfigDict
  9. from constants import HIDDEN_VALUE
  10. from core.entities import DEFAULT_PLUGIN_ID
  11. from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
  12. from core.entities.provider_entities import (
  13. CustomConfiguration,
  14. ModelSettings,
  15. SystemConfiguration,
  16. SystemConfigurationStatus,
  17. )
  18. from core.helper import encrypter
  19. from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
  20. from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
  21. from core.model_runtime.entities.provider_entities import (
  22. ConfigurateMethod,
  23. CredentialFormSchema,
  24. FormType,
  25. ProviderEntity,
  26. )
  27. from core.model_runtime.model_providers.__base.ai_model import AIModel
  28. from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
  29. from extensions.ext_database import db
  30. from models.provider import (
  31. LoadBalancingModelConfig,
  32. Provider,
  33. ProviderModel,
  34. ProviderModelSetting,
  35. ProviderType,
  36. TenantPreferredModelProvider,
  37. )
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Process-wide record of each provider's configurate methods, keyed by provider name.
# ProviderConfiguration.__init__ fills an entry the first time it sees a provider, before
# it may append PREDEFINED_MODEL to that provider's own configurate_methods list.
original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {}
  40. class ProviderConfiguration(BaseModel):
  41. """
  42. Model class for provider configuration.
  43. """
  44. tenant_id: str
  45. provider: ProviderEntity
  46. preferred_provider_type: ProviderType
  47. using_provider_type: ProviderType
  48. system_configuration: SystemConfiguration
  49. custom_configuration: CustomConfiguration
  50. model_settings: list[ModelSettings]
  51. # pydantic configs
  52. model_config = ConfigDict(protected_namespaces=())
  53. def __init__(self, **data):
  54. super().__init__(**data)
  55. if self.provider.provider not in original_provider_configurate_methods:
  56. original_provider_configurate_methods[self.provider.provider] = []
  57. for configurate_method in self.provider.configurate_methods:
  58. original_provider_configurate_methods[self.provider.provider].append(configurate_method)
  59. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  60. if (
  61. any(
  62. len(quota_configuration.restrict_models) > 0
  63. for quota_configuration in self.system_configuration.quota_configurations
  64. )
  65. and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
  66. ):
  67. self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
  68. def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
  69. """
  70. Get current credentials.
  71. :param model_type: model type
  72. :param model: model name
  73. :return:
  74. """
  75. if self.model_settings:
  76. # check if model is disabled by admin
  77. for model_setting in self.model_settings:
  78. if model_setting.model_type == model_type and model_setting.model == model:
  79. if not model_setting.enabled:
  80. raise ValueError(f"Model {model} is disabled.")
  81. if self.using_provider_type == ProviderType.SYSTEM:
  82. restrict_models = []
  83. for quota_configuration in self.system_configuration.quota_configurations:
  84. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  85. continue
  86. restrict_models = quota_configuration.restrict_models
  87. copy_credentials = (
  88. self.system_configuration.credentials.copy() if self.system_configuration.credentials else {}
  89. )
  90. if restrict_models:
  91. for restrict_model in restrict_models:
  92. if (
  93. restrict_model.model_type == model_type
  94. and restrict_model.model == model
  95. and restrict_model.base_model_name
  96. ):
  97. copy_credentials["base_model_name"] = restrict_model.base_model_name
  98. return copy_credentials
  99. else:
  100. credentials = None
  101. if self.custom_configuration.models:
  102. for model_configuration in self.custom_configuration.models:
  103. if model_configuration.model_type == model_type and model_configuration.model == model:
  104. credentials = model_configuration.credentials
  105. break
  106. if not credentials and self.custom_configuration.provider:
  107. credentials = self.custom_configuration.provider.credentials
  108. return credentials
  109. def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]:
  110. """
  111. Get system configuration status.
  112. :return:
  113. """
  114. if self.system_configuration.enabled is False:
  115. return SystemConfigurationStatus.UNSUPPORTED
  116. current_quota_type = self.system_configuration.current_quota_type
  117. current_quota_configuration = next(
  118. (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
  119. )
  120. if current_quota_configuration is None:
  121. return None
  122. if not current_quota_configuration:
  123. return SystemConfigurationStatus.UNSUPPORTED
  124. return (
  125. SystemConfigurationStatus.ACTIVE
  126. if current_quota_configuration.is_valid
  127. else SystemConfigurationStatus.QUOTA_EXCEEDED
  128. )
  129. def is_custom_configuration_available(self) -> bool:
  130. """
  131. Check custom configuration available.
  132. :return:
  133. """
  134. return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0
  135. def get_custom_credentials(self, obfuscated: bool = False) -> dict | None:
  136. """
  137. Get custom credentials.
  138. :param obfuscated: obfuscated secret data in credentials
  139. :return:
  140. """
  141. if self.custom_configuration.provider is None:
  142. return None
  143. credentials = self.custom_configuration.provider.credentials
  144. if not obfuscated:
  145. return credentials
  146. # Obfuscate credentials
  147. return self.obfuscated_credentials(
  148. credentials=credentials,
  149. credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
  150. if self.provider.provider_credential_schema
  151. else [],
  152. )
  153. def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
  154. """
  155. Validate custom credentials.
  156. :param credentials: provider credentials
  157. :return:
  158. """
  159. # get provider
  160. provider_record = (
  161. db.session.query(Provider)
  162. .filter(
  163. Provider.tenant_id == self.tenant_id,
  164. Provider.provider_name == self.provider.provider,
  165. Provider.provider_type == ProviderType.CUSTOM.value,
  166. )
  167. .first()
  168. )
  169. # Get provider credential secret variables
  170. provider_credential_secret_variables = self.extract_secret_variables(
  171. self.provider.provider_credential_schema.credential_form_schemas
  172. if self.provider.provider_credential_schema
  173. else []
  174. )
  175. if provider_record:
  176. try:
  177. # fix origin data
  178. if provider_record.encrypted_config:
  179. if not provider_record.encrypted_config.startswith("{"):
  180. original_credentials = {"openai_api_key": provider_record.encrypted_config}
  181. else:
  182. original_credentials = json.loads(provider_record.encrypted_config)
  183. else:
  184. original_credentials = {}
  185. except JSONDecodeError:
  186. original_credentials = {}
  187. # encrypt credentials
  188. for key, value in credentials.items():
  189. if key in provider_credential_secret_variables:
  190. # if send [__HIDDEN__] in secret input, it will be same as original value
  191. if value == HIDDEN_VALUE and key in original_credentials:
  192. credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
  193. model_provider_factory = ModelProviderFactory(self.tenant_id)
  194. credentials = model_provider_factory.provider_credentials_validate(
  195. provider=self.provider.provider, credentials=credentials
  196. )
  197. for key, value in credentials.items():
  198. if key in provider_credential_secret_variables:
  199. credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  200. return provider_record, credentials
  201. def add_or_update_custom_credentials(self, credentials: dict) -> None:
  202. """
  203. Add or update custom provider credentials.
  204. :param credentials:
  205. :return:
  206. """
  207. # validate custom provider config
  208. provider_record, credentials = self.custom_credentials_validate(credentials)
  209. # save provider
  210. # Note: Do not switch the preferred provider, which allows users to use quotas first
  211. if provider_record:
  212. provider_record.encrypted_config = json.dumps(credentials)
  213. provider_record.is_valid = True
  214. provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
  215. db.session.commit()
  216. else:
  217. provider_record = Provider()
  218. provider_record.tenant_id = self.tenant_id
  219. provider_record.provider_name = self.provider.provider
  220. provider_record.provider_type = ProviderType.CUSTOM.value
  221. provider_record.encrypted_config = json.dumps(credentials)
  222. provider_record.is_valid = True
  223. db.session.add(provider_record)
  224. db.session.commit()
  225. provider_model_credentials_cache = ProviderCredentialsCache(
  226. tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER
  227. )
  228. provider_model_credentials_cache.delete()
  229. self.switch_preferred_provider_type(ProviderType.CUSTOM)
  230. def delete_custom_credentials(self) -> None:
  231. """
  232. Delete custom provider credentials.
  233. :return:
  234. """
  235. # get provider
  236. provider_record = (
  237. db.session.query(Provider)
  238. .filter(
  239. Provider.tenant_id == self.tenant_id,
  240. Provider.provider_name == self.provider.provider,
  241. Provider.provider_type == ProviderType.CUSTOM.value,
  242. )
  243. .first()
  244. )
  245. # delete provider
  246. if provider_record:
  247. self.switch_preferred_provider_type(ProviderType.SYSTEM)
  248. db.session.delete(provider_record)
  249. db.session.commit()
  250. provider_model_credentials_cache = ProviderCredentialsCache(
  251. tenant_id=self.tenant_id,
  252. identity_id=provider_record.id,
  253. cache_type=ProviderCredentialsCacheType.PROVIDER,
  254. )
  255. provider_model_credentials_cache.delete()
  256. def get_custom_model_credentials(
  257. self, model_type: ModelType, model: str, obfuscated: bool = False
  258. ) -> Optional[dict]:
  259. """
  260. Get custom model credentials.
  261. :param model_type: model type
  262. :param model: model name
  263. :param obfuscated: obfuscated secret data in credentials
  264. :return:
  265. """
  266. if not self.custom_configuration.models:
  267. return None
  268. for model_configuration in self.custom_configuration.models:
  269. if model_configuration.model_type == model_type and model_configuration.model == model:
  270. credentials = model_configuration.credentials
  271. if not obfuscated:
  272. return credentials
  273. # Obfuscate credentials
  274. return self.obfuscated_credentials(
  275. credentials=credentials,
  276. credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
  277. if self.provider.model_credential_schema
  278. else [],
  279. )
  280. return None
  281. def custom_model_credentials_validate(
  282. self, model_type: ModelType, model: str, credentials: dict
  283. ) -> tuple[ProviderModel | None, dict]:
  284. """
  285. Validate custom model credentials.
  286. :param model_type: model type
  287. :param model: model name
  288. :param credentials: model credentials
  289. :return:
  290. """
  291. # get provider model
  292. provider_model_record = (
  293. db.session.query(ProviderModel)
  294. .filter(
  295. ProviderModel.tenant_id == self.tenant_id,
  296. ProviderModel.provider_name == self.provider.provider,
  297. ProviderModel.model_name == model,
  298. ProviderModel.model_type == model_type.to_origin_model_type(),
  299. )
  300. .first()
  301. )
  302. # Get provider credential secret variables
  303. provider_credential_secret_variables = self.extract_secret_variables(
  304. self.provider.model_credential_schema.credential_form_schemas
  305. if self.provider.model_credential_schema
  306. else []
  307. )
  308. if provider_model_record:
  309. try:
  310. original_credentials = (
  311. json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
  312. )
  313. except JSONDecodeError:
  314. original_credentials = {}
  315. # decrypt credentials
  316. for key, value in credentials.items():
  317. if key in provider_credential_secret_variables:
  318. # if send [__HIDDEN__] in secret input, it will be same as original value
  319. if value == HIDDEN_VALUE and key in original_credentials:
  320. credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
  321. model_provider_factory = ModelProviderFactory(self.tenant_id)
  322. credentials = model_provider_factory.model_credentials_validate(
  323. provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
  324. )
  325. for key, value in credentials.items():
  326. if key in provider_credential_secret_variables:
  327. credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  328. return provider_model_record, credentials
  329. def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
  330. """
  331. Add or update custom model credentials.
  332. :param model_type: model type
  333. :param model: model name
  334. :param credentials: model credentials
  335. :return:
  336. """
  337. # validate custom model config
  338. provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)
  339. # save provider model
  340. # Note: Do not switch the preferred provider, which allows users to use quotas first
  341. if provider_model_record:
  342. provider_model_record.encrypted_config = json.dumps(credentials)
  343. provider_model_record.is_valid = True
  344. provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
  345. db.session.commit()
  346. else:
  347. provider_model_record = ProviderModel()
  348. provider_model_record.tenant_id = self.tenant_id
  349. provider_model_record.provider_name = self.provider.provider
  350. provider_model_record.model_name = model
  351. provider_model_record.model_type = model_type.to_origin_model_type()
  352. provider_model_record.encrypted_config = json.dumps(credentials)
  353. provider_model_record.is_valid = True
  354. db.session.add(provider_model_record)
  355. db.session.commit()
  356. provider_model_credentials_cache = ProviderCredentialsCache(
  357. tenant_id=self.tenant_id,
  358. identity_id=provider_model_record.id,
  359. cache_type=ProviderCredentialsCacheType.MODEL,
  360. )
  361. provider_model_credentials_cache.delete()
  362. def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
  363. """
  364. Delete custom model credentials.
  365. :param model_type: model type
  366. :param model: model name
  367. :return:
  368. """
  369. # get provider model
  370. provider_model_record = (
  371. db.session.query(ProviderModel)
  372. .filter(
  373. ProviderModel.tenant_id == self.tenant_id,
  374. ProviderModel.provider_name == self.provider.provider,
  375. ProviderModel.model_name == model,
  376. ProviderModel.model_type == model_type.to_origin_model_type(),
  377. )
  378. .first()
  379. )
  380. # delete provider model
  381. if provider_model_record:
  382. db.session.delete(provider_model_record)
  383. db.session.commit()
  384. provider_model_credentials_cache = ProviderCredentialsCache(
  385. tenant_id=self.tenant_id,
  386. identity_id=provider_model_record.id,
  387. cache_type=ProviderCredentialsCacheType.MODEL,
  388. )
  389. provider_model_credentials_cache.delete()
  390. def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  391. """
  392. Enable model.
  393. :param model_type: model type
  394. :param model: model name
  395. :return:
  396. """
  397. model_setting = (
  398. db.session.query(ProviderModelSetting)
  399. .filter(
  400. ProviderModelSetting.tenant_id == self.tenant_id,
  401. ProviderModelSetting.provider_name == self.provider.provider,
  402. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  403. ProviderModelSetting.model_name == model,
  404. )
  405. .first()
  406. )
  407. if model_setting:
  408. model_setting.enabled = True
  409. model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
  410. db.session.commit()
  411. else:
  412. model_setting = ProviderModelSetting()
  413. model_setting.tenant_id = self.tenant_id
  414. model_setting.provider_name = self.provider.provider
  415. model_setting.model_type = model_type.to_origin_model_type()
  416. model_setting.model_name = model
  417. model_setting.enabled = True
  418. db.session.add(model_setting)
  419. db.session.commit()
  420. return model_setting
  421. def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  422. """
  423. Disable model.
  424. :param model_type: model type
  425. :param model: model name
  426. :return:
  427. """
  428. model_setting = (
  429. db.session.query(ProviderModelSetting)
  430. .filter(
  431. ProviderModelSetting.tenant_id == self.tenant_id,
  432. ProviderModelSetting.provider_name == self.provider.provider,
  433. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  434. ProviderModelSetting.model_name == model,
  435. )
  436. .first()
  437. )
  438. if model_setting:
  439. model_setting.enabled = False
  440. model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
  441. db.session.commit()
  442. else:
  443. model_setting = ProviderModelSetting()
  444. model_setting.tenant_id = self.tenant_id
  445. model_setting.provider_name = self.provider.provider
  446. model_setting.model_type = model_type.to_origin_model_type()
  447. model_setting.model_name = model
  448. model_setting.enabled = False
  449. db.session.add(model_setting)
  450. db.session.commit()
  451. return model_setting
  452. def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
  453. """
  454. Get provider model setting.
  455. :param model_type: model type
  456. :param model: model name
  457. :return:
  458. """
  459. return (
  460. db.session.query(ProviderModelSetting)
  461. .filter(
  462. ProviderModelSetting.tenant_id == self.tenant_id,
  463. ProviderModelSetting.provider_name == self.provider.provider,
  464. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  465. ProviderModelSetting.model_name == model,
  466. )
  467. .first()
  468. )
  469. def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  470. """
  471. Enable model load balancing.
  472. :param model_type: model type
  473. :param model: model name
  474. :return:
  475. """
  476. load_balancing_config_count = (
  477. db.session.query(LoadBalancingModelConfig)
  478. .filter(
  479. LoadBalancingModelConfig.tenant_id == self.tenant_id,
  480. LoadBalancingModelConfig.provider_name == self.provider.provider,
  481. LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
  482. LoadBalancingModelConfig.model_name == model,
  483. )
  484. .count()
  485. )
  486. if load_balancing_config_count <= 1:
  487. raise ValueError("Model load balancing configuration must be more than 1.")
  488. model_setting = (
  489. db.session.query(ProviderModelSetting)
  490. .filter(
  491. ProviderModelSetting.tenant_id == self.tenant_id,
  492. ProviderModelSetting.provider_name == self.provider.provider,
  493. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  494. ProviderModelSetting.model_name == model,
  495. )
  496. .first()
  497. )
  498. if model_setting:
  499. model_setting.load_balancing_enabled = True
  500. model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
  501. db.session.commit()
  502. else:
  503. model_setting = ProviderModelSetting()
  504. model_setting.tenant_id = self.tenant_id
  505. model_setting.provider_name = self.provider.provider
  506. model_setting.model_type = model_type.to_origin_model_type()
  507. model_setting.model_name = model
  508. model_setting.load_balancing_enabled = True
  509. db.session.add(model_setting)
  510. db.session.commit()
  511. return model_setting
  512. def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  513. """
  514. Disable model load balancing.
  515. :param model_type: model type
  516. :param model: model name
  517. :return:
  518. """
  519. model_setting = (
  520. db.session.query(ProviderModelSetting)
  521. .filter(
  522. ProviderModelSetting.tenant_id == self.tenant_id,
  523. ProviderModelSetting.provider_name == self.provider.provider,
  524. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  525. ProviderModelSetting.model_name == model,
  526. )
  527. .first()
  528. )
  529. if model_setting:
  530. model_setting.load_balancing_enabled = False
  531. model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
  532. db.session.commit()
  533. else:
  534. model_setting = ProviderModelSetting()
  535. model_setting.tenant_id = self.tenant_id
  536. model_setting.provider_name = self.provider.provider
  537. model_setting.model_type = model_type.to_origin_model_type()
  538. model_setting.model_name = model
  539. model_setting.load_balancing_enabled = False
  540. db.session.add(model_setting)
  541. db.session.commit()
  542. return model_setting
  543. def get_model_type_instance(self, model_type: ModelType) -> AIModel:
  544. """
  545. Get current model type instance.
  546. :param model_type: model type
  547. :return:
  548. """
  549. model_provider_factory = ModelProviderFactory(self.tenant_id)
  550. # Get model instance of LLM
  551. return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
  552. def get_model_schema(self, model_type: ModelType, model: str, credentials: dict) -> AIModelEntity | None:
  553. """
  554. Get model schema
  555. """
  556. model_provider_factory = ModelProviderFactory(self.tenant_id)
  557. return model_provider_factory.get_model_schema(
  558. provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
  559. )
  560. def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
  561. """
  562. Switch preferred provider type.
  563. :param provider_type:
  564. :return:
  565. """
  566. if provider_type == self.preferred_provider_type:
  567. return
  568. if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
  569. return
  570. # get preferred provider
  571. preferred_model_provider = (
  572. db.session.query(TenantPreferredModelProvider)
  573. .filter(
  574. TenantPreferredModelProvider.tenant_id == self.tenant_id,
  575. TenantPreferredModelProvider.provider_name == self.provider.provider,
  576. )
  577. .first()
  578. )
  579. if preferred_model_provider:
  580. preferred_model_provider.preferred_provider_type = provider_type.value
  581. else:
  582. preferred_model_provider = TenantPreferredModelProvider()
  583. preferred_model_provider.tenant_id = self.tenant_id
  584. preferred_model_provider.provider_name = self.provider.provider
  585. preferred_model_provider.preferred_provider_type = provider_type.value
  586. db.session.add(preferred_model_provider)
  587. db.session.commit()
  588. def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
  589. """
  590. Extract secret input form variables.
  591. :param credential_form_schemas:
  592. :return:
  593. """
  594. secret_input_form_variables = []
  595. for credential_form_schema in credential_form_schemas:
  596. if credential_form_schema.type == FormType.SECRET_INPUT:
  597. secret_input_form_variables.append(credential_form_schema.variable)
  598. return secret_input_form_variables
  599. def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
  600. """
  601. Obfuscated credentials.
  602. :param credentials: credentials
  603. :param credential_form_schemas: credential form schemas
  604. :return:
  605. """
  606. # Get provider credential secret variables
  607. credential_secret_variables = self.extract_secret_variables(credential_form_schemas)
  608. # Obfuscate provider credentials
  609. copy_credentials = credentials.copy()
  610. for key, value in copy_credentials.items():
  611. if key in credential_secret_variables:
  612. copy_credentials[key] = encrypter.obfuscated_token(value)
  613. return copy_credentials
  614. def get_provider_model(
  615. self, model_type: ModelType, model: str, only_active: bool = False
  616. ) -> Optional[ModelWithProviderEntity]:
  617. """
  618. Get provider model.
  619. :param model_type: model type
  620. :param model: model name
  621. :param only_active: return active model only
  622. :return:
  623. """
  624. provider_models = self.get_provider_models(model_type, only_active)
  625. for provider_model in provider_models:
  626. if provider_model.model == model:
  627. return provider_model
  628. return None
  629. def get_provider_models(
  630. self, model_type: Optional[ModelType] = None, only_active: bool = False
  631. ) -> list[ModelWithProviderEntity]:
  632. """
  633. Get provider models.
  634. :param model_type: model type
  635. :param only_active: only active models
  636. :return:
  637. """
  638. model_provider_factory = ModelProviderFactory(self.tenant_id)
  639. provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)
  640. model_types: list[ModelType] = []
  641. if model_type:
  642. model_types.append(model_type)
  643. else:
  644. model_types = list(provider_schema.supported_model_types)
  645. # Group model settings by model type and model
  646. model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
  647. for model_setting in self.model_settings:
  648. model_setting_map[model_setting.model_type][model_setting.model] = model_setting
  649. if self.using_provider_type == ProviderType.SYSTEM:
  650. provider_models = self._get_system_provider_models(
  651. model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
  652. )
  653. else:
  654. provider_models = self._get_custom_provider_models(
  655. model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
  656. )
  657. if only_active:
  658. provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
  659. # resort provider_models
  660. return sorted(provider_models, key=lambda x: x.model_type.value)
  def _get_system_provider_models(
      self,
      model_types: Sequence[ModelType],
      provider_schema: ProviderEntity,
      model_setting_map: dict[ModelType, dict[str, ModelSettings]],
  ) -> list[ModelWithProviderEntity]:
      """
      Get system provider models.

      Builds one entry per predefined model in ``provider_schema`` whose type is in
      ``model_types``. When the provider's recorded configurate methods are exactly
      ``[ConfigurateMethod.CUSTOMIZABLE_MODEL]``, the restricted models of the current
      quota configuration are additionally resolved through ``self.get_model_schema``
      and appended. Finally, model statuses are adjusted against the current quota
      configuration (``NO_PERMISSION`` / ``QUOTA_EXCEEDED``).

      :param model_types: model types to include
      :param provider_schema: provider schema supplying predefined models and configurate methods
      :param model_setting_map: model settings grouped as ``{model_type: {model_name: setting}}``
      :return: provider models with their computed status
      """
      provider_models = []
      for model_type in model_types:
          for m in provider_schema.models:
              if m.model_type != model_type:
                  continue

              # The map is keyed by model type first, then by model name, so the
              # lookup has to go through both levels to find the setting.
              model_setting = model_setting_map.get(m.model_type, {}).get(m.model)
              status = ModelStatus.ACTIVE
              if model_setting is not None and model_setting.enabled is False:
                  status = ModelStatus.DISABLED

              provider_models.append(
                  ModelWithProviderEntity(
                      model=m.model,
                      label=m.label,
                      model_type=m.model_type,
                      features=m.features,
                      fetch_from=m.fetch_from,
                      model_properties=m.model_properties,
                      deprecated=m.deprecated,
                      provider=SimpleModelProviderEntity(self.provider),
                      status=status,
                  )
              )

      # Record the provider's configurate methods the first time this provider is
      # seen; later calls reuse the recorded list instead of the schema's value.
      provider_name = self.provider.provider
      if provider_name not in original_provider_configurate_methods:
          original_provider_configurate_methods[provider_name] = list(provider_schema.configurate_methods)

      should_use_custom_model = original_provider_configurate_methods[provider_name] == [
          ConfigurateMethod.CUSTOMIZABLE_MODEL
      ]

      for quota_configuration in self.system_configuration.quota_configurations:
          # Only the quota configuration of the current quota type is relevant.
          if self.system_configuration.current_quota_type != quota_configuration.quota_type:
              continue

          restrict_models = quota_configuration.restrict_models
          if not restrict_models:
              break

          if should_use_custom_model:
              # only customizable model
              for restrict_model in restrict_models:
                  copy_credentials = (
                      self.system_configuration.credentials.copy()
                      if self.system_configuration.credentials
                      else {}
                  )
                  if restrict_model.base_model_name:
                      copy_credentials["base_model_name"] = restrict_model.base_model_name

                  try:
                      custom_model_schema = self.get_model_schema(
                          model_type=restrict_model.model_type,
                          model=restrict_model.model,
                          credentials=copy_credentials,
                      )
                  except Exception as ex:
                      logger.warning(f"get custom model schema failed, {ex}")
                      # Skip this model: without a fresh schema the name below would be
                      # unbound or still hold the previous iteration's schema.
                      continue

                  if not custom_model_schema:
                      continue

                  if custom_model_schema.model_type not in model_types:
                      continue

                  model_setting = model_setting_map.get(custom_model_schema.model_type, {}).get(
                      custom_model_schema.model
                  )
                  status = ModelStatus.ACTIVE
                  if model_setting is not None and model_setting.enabled is False:
                      status = ModelStatus.DISABLED

                  provider_models.append(
                      ModelWithProviderEntity(
                          model=custom_model_schema.model,
                          label=custom_model_schema.label,
                          model_type=custom_model_schema.model_type,
                          features=custom_model_schema.features,
                          fetch_from=FetchFrom.PREDEFINED_MODEL,
                          model_properties=custom_model_schema.model_properties,
                          deprecated=custom_model_schema.deprecated,
                          provider=SimpleModelProviderEntity(self.provider),
                          status=status,
                      )
                  )

          # An LLM that is not in the restricted list is marked as not permitted;
          # every other model is marked quota-exceeded when the quota is invalid.
          restrict_model_names = {rm.model for rm in restrict_models}
          for model in provider_models:
              if model.model_type == ModelType.LLM and model.model not in restrict_model_names:
                  model.status = ModelStatus.NO_PERMISSION
              elif not quota_configuration.is_valid:
                  model.status = ModelStatus.QUOTA_EXCEEDED

      return provider_models
  def _get_custom_provider_models(
      self,
      model_types: Sequence[ModelType],
      provider_schema: ProviderEntity,
      model_setting_map: dict[ModelType, dict[str, ModelSettings]],
  ) -> list[ModelWithProviderEntity]:
      """
      Get custom provider models.

      Returns the predefined models of ``provider_schema`` (for each requested type
      the provider supports) followed by the models from
      ``self.custom_configuration.models`` whose schema can be resolved.

      :param model_types: model types to include
      :param provider_schema: provider schema supplying the predefined models
      :param model_setting_map: model settings grouped as ``{model_type: {model_name: setting}}``
      :return: provider models with status and load-balancing flag
      """

      def resolve(kind, name, initial_status):
          # Apply the stored setting (if any) on top of the initial status and
          # report whether more than one load-balancing config is present.
          balanced = False
          outcome = initial_status
          if kind in model_setting_map and name in model_setting_map[kind]:
              setting = model_setting_map[kind][name]
              if setting.enabled is False:
                  outcome = ModelStatus.DISABLED
              if len(setting.load_balancing_configs) > 1:
                  balanced = True
          return outcome, balanced

      provider_credentials = None
      if self.custom_configuration.provider:
          provider_credentials = self.custom_configuration.provider.credentials

      # Predefined models are usable only when provider credentials exist.
      predefined_status = ModelStatus.ACTIVE if provider_credentials else ModelStatus.NO_CONFIGURE

      result = []
      for wanted_type in model_types:
          if wanted_type not in self.provider.supported_model_types:
              continue

          for predefined in provider_schema.models:
              if predefined.model_type == wanted_type:
                  status, load_balancing_enabled = resolve(
                      predefined.model_type, predefined.model, predefined_status
                  )
                  entity = ModelWithProviderEntity(
                      model=predefined.model,
                      label=predefined.label,
                      model_type=predefined.model_type,
                      features=predefined.features,
                      fetch_from=predefined.fetch_from,
                      model_properties=predefined.model_properties,
                      deprecated=predefined.deprecated,
                      provider=SimpleModelProviderEntity(self.provider),
                      status=status,
                      load_balancing_enabled=load_balancing_enabled,
                  )
                  result.append(entity)

      # custom models
      for configured in self.custom_configuration.models:
          if configured.model_type not in model_types:
              continue

          try:
              schema = self.get_model_schema(
                  model_type=configured.model_type,
                  model=configured.model,
                  credentials=configured.credentials,
              )
          except Exception as ex:
              logger.warning(f"get custom model schema failed, {ex}")
              continue

          if not schema:
              continue

          status, load_balancing_enabled = resolve(schema.model_type, schema.model, ModelStatus.ACTIVE)
          entity = ModelWithProviderEntity(
              model=schema.model,
              label=schema.label,
              model_type=schema.model_type,
              features=schema.features,
              fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
              model_properties=schema.model_properties,
              deprecated=schema.deprecated,
              provider=SimpleModelProviderEntity(self.provider),
              status=status,
              load_balancing_enabled=load_balancing_enabled,
          )
          result.append(entity)

      return result
  class ProviderConfigurations(BaseModel):
      """
      Dict-like model holding the provider configurations of one tenant.

      ``[]`` and ``get`` expand a key that contains no "/" to
      ``"{DEFAULT_PLUGIN_ID}/{key}/{key}"`` before the lookup; item assignment
      stores the key exactly as given.
      """

      tenant_id: str
      configurations: dict[str, ProviderConfiguration] = {}

      def __init__(self, tenant_id: str):
          super().__init__(tenant_id=tenant_id)

      def get_models(
          self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
      ) -> list[ModelWithProviderEntity]:
          """
          Get available models.

          Preferred provider type `system`:
            use the current **system mode** if the provider supports it; if no system
            mode is available (no quota), fall back to the **custom credential mode**.
            If no model is configured in custom mode, it is treated as no_configure.
            system > custom > no_configure

          Preferred provider type `custom`:
            if custom credentials are configured, it is treated as custom mode;
            otherwise use the current **system mode** if supported. If no system mode
            is available (no quota), it is treated as no_configure.
            custom > system > no_configure

          Real mode `system`: system credentials are used to get models,
            paid quotas > provider free quotas > system free quotas,
            including pre-defined models (exclude GPT-4, status marked as `no_permission`).

          Real mode `custom`: workspace custom credentials are used to get models,
            including pre-defined models and custom models (manually appended).

          Real mode `no_configure`: only pre-defined models from `model runtime` are returned
            (model status marked as `no_configure` if preferred provider type is `custom`,
            otherwise `quota_exceeded`).

          A model whose status is `active` is available.

          :param provider: provider name; when empty, models of every provider are returned
          :param model_type: model type
          :param only_active: only active models
          :return: models collected from the matching provider configurations
          """
          return [
              model
              for configuration in self.values()
              if not provider or configuration.provider.provider == provider
              for model in configuration.get_provider_models(model_type, only_active)
          ]

      def to_list(self) -> list[ProviderConfiguration]:
          """
          Convert to list.

          :return: the stored provider configurations, in mapping order
          """
          return [*self.values()]

      def __getitem__(self, key):
          # A key without "/" is expanded to the default plugin's provider key.
          lookup_key = key if "/" in key else f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
          return self.configurations[lookup_key]

      def __setitem__(self, key, value):
          self.configurations[key] = value

      def __iter__(self):
          # Iterates over the provider keys, like a plain dict.
          return iter(self.configurations.keys())

      def values(self) -> Iterator[ProviderConfiguration]:
          return iter(self.configurations.values())

      def get(self, key, default=None) -> ProviderConfiguration | None:
          # Same key expansion as ``__getitem__``, but returns ``default`` on a miss.
          lookup_key = key if "/" in key else f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
          return self.configurations.get(lookup_key, default)  # type: ignore
  class ProviderModelBundle(BaseModel):
      """
      Bundle pairing a provider configuration with a model type instance.
      """

      # pydantic configs: allow non-pydantic field types and disable the
      # protected-namespace check (field names here start with ``model_``).
      model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())

      configuration: ProviderConfiguration
      model_type_instance: AIModel