provider_configuration.py 43 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092
  1. import datetime
  2. import json
  3. import logging
  4. from collections import defaultdict
  5. from collections.abc import Iterator, Sequence
  6. from json import JSONDecodeError
  7. from typing import Optional
  8. from pydantic import BaseModel, ConfigDict
  9. from sqlalchemy import or_
  10. from constants import HIDDEN_VALUE
  11. from core.entities import DEFAULT_PLUGIN_ID
  12. from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
  13. from core.entities.provider_entities import (
  14. CustomConfiguration,
  15. ModelSettings,
  16. SystemConfiguration,
  17. SystemConfigurationStatus,
  18. )
  19. from core.helper import encrypter
  20. from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
  21. from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
  22. from core.model_runtime.entities.provider_entities import (
  23. ConfigurateMethod,
  24. CredentialFormSchema,
  25. FormType,
  26. ProviderEntity,
  27. )
  28. from core.model_runtime.model_providers.__base.ai_model import AIModel
  29. from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
  30. from core.plugin.entities.plugin import ModelProviderID
  31. from extensions.ext_database import db
  32. from models.provider import (
  33. LoadBalancingModelConfig,
  34. Provider,
  35. ProviderModel,
  36. ProviderModelSetting,
  37. ProviderType,
  38. TenantPreferredModelProvider,
  39. )
  40. logger = logging.getLogger(__name__)
  41. original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {}
  42. class ProviderConfiguration(BaseModel):
  43. """
  44. Model class for provider configuration.
  45. """
  46. tenant_id: str
  47. provider: ProviderEntity
  48. preferred_provider_type: ProviderType
  49. using_provider_type: ProviderType
  50. system_configuration: SystemConfiguration
  51. custom_configuration: CustomConfiguration
  52. model_settings: list[ModelSettings]
  53. # pydantic configs
  54. model_config = ConfigDict(protected_namespaces=())
  55. def __init__(self, **data):
  56. super().__init__(**data)
  57. if self.provider.provider not in original_provider_configurate_methods:
  58. original_provider_configurate_methods[self.provider.provider] = []
  59. for configurate_method in self.provider.configurate_methods:
  60. original_provider_configurate_methods[self.provider.provider].append(configurate_method)
  61. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  62. if (
  63. any(
  64. len(quota_configuration.restrict_models) > 0
  65. for quota_configuration in self.system_configuration.quota_configurations
  66. )
  67. and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
  68. ):
  69. self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
  70. def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
  71. """
  72. Get current credentials.
  73. :param model_type: model type
  74. :param model: model name
  75. :return:
  76. """
  77. if self.model_settings:
  78. # check if model is disabled by admin
  79. for model_setting in self.model_settings:
  80. if model_setting.model_type == model_type and model_setting.model == model:
  81. if not model_setting.enabled:
  82. raise ValueError(f"Model {model} is disabled.")
  83. if self.using_provider_type == ProviderType.SYSTEM:
  84. restrict_models = []
  85. for quota_configuration in self.system_configuration.quota_configurations:
  86. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  87. continue
  88. restrict_models = quota_configuration.restrict_models
  89. copy_credentials = (
  90. self.system_configuration.credentials.copy() if self.system_configuration.credentials else {}
  91. )
  92. if restrict_models:
  93. for restrict_model in restrict_models:
  94. if (
  95. restrict_model.model_type == model_type
  96. and restrict_model.model == model
  97. and restrict_model.base_model_name
  98. ):
  99. copy_credentials["base_model_name"] = restrict_model.base_model_name
  100. return copy_credentials
  101. else:
  102. credentials = None
  103. if self.custom_configuration.models:
  104. for model_configuration in self.custom_configuration.models:
  105. if model_configuration.model_type == model_type and model_configuration.model == model:
  106. credentials = model_configuration.credentials
  107. break
  108. if not credentials and self.custom_configuration.provider:
  109. credentials = self.custom_configuration.provider.credentials
  110. return credentials
  111. def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]:
  112. """
  113. Get system configuration status.
  114. :return:
  115. """
  116. if self.system_configuration.enabled is False:
  117. return SystemConfigurationStatus.UNSUPPORTED
  118. current_quota_type = self.system_configuration.current_quota_type
  119. current_quota_configuration = next(
  120. (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
  121. )
  122. if current_quota_configuration is None:
  123. return None
  124. if not current_quota_configuration:
  125. return SystemConfigurationStatus.UNSUPPORTED
  126. return (
  127. SystemConfigurationStatus.ACTIVE
  128. if current_quota_configuration.is_valid
  129. else SystemConfigurationStatus.QUOTA_EXCEEDED
  130. )
  131. def is_custom_configuration_available(self) -> bool:
  132. """
  133. Check custom configuration available.
  134. :return:
  135. """
  136. return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0
  137. def get_custom_credentials(self, obfuscated: bool = False) -> dict | None:
  138. """
  139. Get custom credentials.
  140. :param obfuscated: obfuscated secret data in credentials
  141. :return:
  142. """
  143. if self.custom_configuration.provider is None:
  144. return None
  145. credentials = self.custom_configuration.provider.credentials
  146. if not obfuscated:
  147. return credentials
  148. # Obfuscate credentials
  149. return self.obfuscated_credentials(
  150. credentials=credentials,
  151. credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
  152. if self.provider.provider_credential_schema
  153. else [],
  154. )
  155. def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
  156. """
  157. Validate custom credentials.
  158. :param credentials: provider credentials
  159. :return:
  160. """
  161. # get provider
  162. provider_record = (
  163. db.session.query(Provider)
  164. .filter(
  165. Provider.tenant_id == self.tenant_id,
  166. Provider.provider_type == ProviderType.CUSTOM.value,
  167. or_(
  168. Provider.provider_name == ModelProviderID(self.provider.provider).plugin_name,
  169. Provider.provider_name == self.provider.provider,
  170. ),
  171. )
  172. .first()
  173. )
  174. # Get provider credential secret variables
  175. provider_credential_secret_variables = self.extract_secret_variables(
  176. self.provider.provider_credential_schema.credential_form_schemas
  177. if self.provider.provider_credential_schema
  178. else []
  179. )
  180. if provider_record:
  181. try:
  182. # fix origin data
  183. if provider_record.encrypted_config:
  184. if not provider_record.encrypted_config.startswith("{"):
  185. original_credentials = {"openai_api_key": provider_record.encrypted_config}
  186. else:
  187. original_credentials = json.loads(provider_record.encrypted_config)
  188. else:
  189. original_credentials = {}
  190. except JSONDecodeError:
  191. original_credentials = {}
  192. # encrypt credentials
  193. for key, value in credentials.items():
  194. if key in provider_credential_secret_variables:
  195. # if send [__HIDDEN__] in secret input, it will be same as original value
  196. if value == HIDDEN_VALUE and key in original_credentials:
  197. credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
  198. model_provider_factory = ModelProviderFactory(self.tenant_id)
  199. credentials = model_provider_factory.provider_credentials_validate(
  200. provider=self.provider.provider, credentials=credentials
  201. )
  202. for key, value in credentials.items():
  203. if key in provider_credential_secret_variables:
  204. credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  205. return provider_record, credentials
  206. def add_or_update_custom_credentials(self, credentials: dict) -> None:
  207. """
  208. Add or update custom provider credentials.
  209. :param credentials:
  210. :return:
  211. """
  212. # validate custom provider config
  213. provider_record, credentials = self.custom_credentials_validate(credentials)
  214. # save provider
  215. # Note: Do not switch the preferred provider, which allows users to use quotas first
  216. if provider_record:
  217. provider_record.encrypted_config = json.dumps(credentials)
  218. provider_record.is_valid = True
  219. provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
  220. db.session.commit()
  221. else:
  222. provider_record = Provider()
  223. provider_record.tenant_id = self.tenant_id
  224. provider_record.provider_name = self.provider.provider
  225. provider_record.provider_type = ProviderType.CUSTOM.value
  226. provider_record.encrypted_config = json.dumps(credentials)
  227. provider_record.is_valid = True
  228. db.session.add(provider_record)
  229. db.session.commit()
  230. provider_model_credentials_cache = ProviderCredentialsCache(
  231. tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER
  232. )
  233. provider_model_credentials_cache.delete()
  234. self.switch_preferred_provider_type(ProviderType.CUSTOM)
  235. def delete_custom_credentials(self) -> None:
  236. """
  237. Delete custom provider credentials.
  238. :return:
  239. """
  240. # get provider
  241. provider_record = (
  242. db.session.query(Provider)
  243. .filter(
  244. Provider.tenant_id == self.tenant_id,
  245. or_(
  246. Provider.provider_name == ModelProviderID(self.provider.provider).plugin_name,
  247. Provider.provider_name == self.provider.provider,
  248. ),
  249. Provider.provider_type == ProviderType.CUSTOM.value,
  250. )
  251. .first()
  252. )
  253. # delete provider
  254. if provider_record:
  255. self.switch_preferred_provider_type(ProviderType.SYSTEM)
  256. db.session.delete(provider_record)
  257. db.session.commit()
  258. provider_model_credentials_cache = ProviderCredentialsCache(
  259. tenant_id=self.tenant_id,
  260. identity_id=provider_record.id,
  261. cache_type=ProviderCredentialsCacheType.PROVIDER,
  262. )
  263. provider_model_credentials_cache.delete()
  264. def get_custom_model_credentials(
  265. self, model_type: ModelType, model: str, obfuscated: bool = False
  266. ) -> Optional[dict]:
  267. """
  268. Get custom model credentials.
  269. :param model_type: model type
  270. :param model: model name
  271. :param obfuscated: obfuscated secret data in credentials
  272. :return:
  273. """
  274. if not self.custom_configuration.models:
  275. return None
  276. for model_configuration in self.custom_configuration.models:
  277. if model_configuration.model_type == model_type and model_configuration.model == model:
  278. credentials = model_configuration.credentials
  279. if not obfuscated:
  280. return credentials
  281. # Obfuscate credentials
  282. return self.obfuscated_credentials(
  283. credentials=credentials,
  284. credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
  285. if self.provider.model_credential_schema
  286. else [],
  287. )
  288. return None
  289. def custom_model_credentials_validate(
  290. self, model_type: ModelType, model: str, credentials: dict
  291. ) -> tuple[ProviderModel | None, dict]:
  292. """
  293. Validate custom model credentials.
  294. :param model_type: model type
  295. :param model: model name
  296. :param credentials: model credentials
  297. :return:
  298. """
  299. # get provider model
  300. provider_model_record = (
  301. db.session.query(ProviderModel)
  302. .filter(
  303. ProviderModel.tenant_id == self.tenant_id,
  304. ProviderModel.provider_name == self.provider.provider,
  305. ProviderModel.model_name == model,
  306. ProviderModel.model_type == model_type.to_origin_model_type(),
  307. )
  308. .first()
  309. )
  310. # Get provider credential secret variables
  311. provider_credential_secret_variables = self.extract_secret_variables(
  312. self.provider.model_credential_schema.credential_form_schemas
  313. if self.provider.model_credential_schema
  314. else []
  315. )
  316. if provider_model_record:
  317. try:
  318. original_credentials = (
  319. json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
  320. )
  321. except JSONDecodeError:
  322. original_credentials = {}
  323. # decrypt credentials
  324. for key, value in credentials.items():
  325. if key in provider_credential_secret_variables:
  326. # if send [__HIDDEN__] in secret input, it will be same as original value
  327. if value == HIDDEN_VALUE and key in original_credentials:
  328. credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
  329. model_provider_factory = ModelProviderFactory(self.tenant_id)
  330. credentials = model_provider_factory.model_credentials_validate(
  331. provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
  332. )
  333. for key, value in credentials.items():
  334. if key in provider_credential_secret_variables:
  335. credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  336. return provider_model_record, credentials
  337. def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
  338. """
  339. Add or update custom model credentials.
  340. :param model_type: model type
  341. :param model: model name
  342. :param credentials: model credentials
  343. :return:
  344. """
  345. # validate custom model config
  346. provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)
  347. # save provider model
  348. # Note: Do not switch the preferred provider, which allows users to use quotas first
  349. if provider_model_record:
  350. provider_model_record.encrypted_config = json.dumps(credentials)
  351. provider_model_record.is_valid = True
  352. provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
  353. db.session.commit()
  354. else:
  355. provider_model_record = ProviderModel()
  356. provider_model_record.tenant_id = self.tenant_id
  357. provider_model_record.provider_name = self.provider.provider
  358. provider_model_record.model_name = model
  359. provider_model_record.model_type = model_type.to_origin_model_type()
  360. provider_model_record.encrypted_config = json.dumps(credentials)
  361. provider_model_record.is_valid = True
  362. db.session.add(provider_model_record)
  363. db.session.commit()
  364. provider_model_credentials_cache = ProviderCredentialsCache(
  365. tenant_id=self.tenant_id,
  366. identity_id=provider_model_record.id,
  367. cache_type=ProviderCredentialsCacheType.MODEL,
  368. )
  369. provider_model_credentials_cache.delete()
  370. def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
  371. """
  372. Delete custom model credentials.
  373. :param model_type: model type
  374. :param model: model name
  375. :return:
  376. """
  377. # get provider model
  378. provider_model_record = (
  379. db.session.query(ProviderModel)
  380. .filter(
  381. ProviderModel.tenant_id == self.tenant_id,
  382. ProviderModel.provider_name == self.provider.provider,
  383. ProviderModel.model_name == model,
  384. ProviderModel.model_type == model_type.to_origin_model_type(),
  385. )
  386. .first()
  387. )
  388. # delete provider model
  389. if provider_model_record:
  390. db.session.delete(provider_model_record)
  391. db.session.commit()
  392. provider_model_credentials_cache = ProviderCredentialsCache(
  393. tenant_id=self.tenant_id,
  394. identity_id=provider_model_record.id,
  395. cache_type=ProviderCredentialsCacheType.MODEL,
  396. )
  397. provider_model_credentials_cache.delete()
  398. def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  399. """
  400. Enable model.
  401. :param model_type: model type
  402. :param model: model name
  403. :return:
  404. """
  405. model_setting = (
  406. db.session.query(ProviderModelSetting)
  407. .filter(
  408. ProviderModelSetting.tenant_id == self.tenant_id,
  409. ProviderModelSetting.provider_name == self.provider.provider,
  410. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  411. ProviderModelSetting.model_name == model,
  412. )
  413. .first()
  414. )
  415. if model_setting:
  416. model_setting.enabled = True
  417. model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
  418. db.session.commit()
  419. else:
  420. model_setting = ProviderModelSetting()
  421. model_setting.tenant_id = self.tenant_id
  422. model_setting.provider_name = self.provider.provider
  423. model_setting.model_type = model_type.to_origin_model_type()
  424. model_setting.model_name = model
  425. model_setting.enabled = True
  426. db.session.add(model_setting)
  427. db.session.commit()
  428. return model_setting
  429. def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  430. """
  431. Disable model.
  432. :param model_type: model type
  433. :param model: model name
  434. :return:
  435. """
  436. model_setting = (
  437. db.session.query(ProviderModelSetting)
  438. .filter(
  439. ProviderModelSetting.tenant_id == self.tenant_id,
  440. ProviderModelSetting.provider_name == self.provider.provider,
  441. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  442. ProviderModelSetting.model_name == model,
  443. )
  444. .first()
  445. )
  446. if model_setting:
  447. model_setting.enabled = False
  448. model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
  449. db.session.commit()
  450. else:
  451. model_setting = ProviderModelSetting()
  452. model_setting.tenant_id = self.tenant_id
  453. model_setting.provider_name = self.provider.provider
  454. model_setting.model_type = model_type.to_origin_model_type()
  455. model_setting.model_name = model
  456. model_setting.enabled = False
  457. db.session.add(model_setting)
  458. db.session.commit()
  459. return model_setting
  460. def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
  461. """
  462. Get provider model setting.
  463. :param model_type: model type
  464. :param model: model name
  465. :return:
  466. """
  467. return (
  468. db.session.query(ProviderModelSetting)
  469. .filter(
  470. ProviderModelSetting.tenant_id == self.tenant_id,
  471. ProviderModelSetting.provider_name == self.provider.provider,
  472. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  473. ProviderModelSetting.model_name == model,
  474. )
  475. .first()
  476. )
  477. def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  478. """
  479. Enable model load balancing.
  480. :param model_type: model type
  481. :param model: model name
  482. :return:
  483. """
  484. load_balancing_config_count = (
  485. db.session.query(LoadBalancingModelConfig)
  486. .filter(
  487. LoadBalancingModelConfig.tenant_id == self.tenant_id,
  488. LoadBalancingModelConfig.provider_name == self.provider.provider,
  489. LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
  490. LoadBalancingModelConfig.model_name == model,
  491. )
  492. .count()
  493. )
  494. if load_balancing_config_count <= 1:
  495. raise ValueError("Model load balancing configuration must be more than 1.")
  496. model_setting = (
  497. db.session.query(ProviderModelSetting)
  498. .filter(
  499. ProviderModelSetting.tenant_id == self.tenant_id,
  500. ProviderModelSetting.provider_name == self.provider.provider,
  501. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  502. ProviderModelSetting.model_name == model,
  503. )
  504. .first()
  505. )
  506. if model_setting:
  507. model_setting.load_balancing_enabled = True
  508. model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
  509. db.session.commit()
  510. else:
  511. model_setting = ProviderModelSetting()
  512. model_setting.tenant_id = self.tenant_id
  513. model_setting.provider_name = self.provider.provider
  514. model_setting.model_type = model_type.to_origin_model_type()
  515. model_setting.model_name = model
  516. model_setting.load_balancing_enabled = True
  517. db.session.add(model_setting)
  518. db.session.commit()
  519. return model_setting
  520. def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  521. """
  522. Disable model load balancing.
  523. :param model_type: model type
  524. :param model: model name
  525. :return:
  526. """
  527. model_setting = (
  528. db.session.query(ProviderModelSetting)
  529. .filter(
  530. ProviderModelSetting.tenant_id == self.tenant_id,
  531. ProviderModelSetting.provider_name == self.provider.provider,
  532. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  533. ProviderModelSetting.model_name == model,
  534. )
  535. .first()
  536. )
  537. if model_setting:
  538. model_setting.load_balancing_enabled = False
  539. model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
  540. db.session.commit()
  541. else:
  542. model_setting = ProviderModelSetting()
  543. model_setting.tenant_id = self.tenant_id
  544. model_setting.provider_name = self.provider.provider
  545. model_setting.model_type = model_type.to_origin_model_type()
  546. model_setting.model_name = model
  547. model_setting.load_balancing_enabled = False
  548. db.session.add(model_setting)
  549. db.session.commit()
  550. return model_setting
  551. def get_model_type_instance(self, model_type: ModelType) -> AIModel:
  552. """
  553. Get current model type instance.
  554. :param model_type: model type
  555. :return:
  556. """
  557. model_provider_factory = ModelProviderFactory(self.tenant_id)
  558. # Get model instance of LLM
  559. return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
  560. def get_model_schema(self, model_type: ModelType, model: str, credentials: dict) -> AIModelEntity | None:
  561. """
  562. Get model schema
  563. """
  564. model_provider_factory = ModelProviderFactory(self.tenant_id)
  565. return model_provider_factory.get_model_schema(
  566. provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
  567. )
  568. def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
  569. """
  570. Switch preferred provider type.
  571. :param provider_type:
  572. :return:
  573. """
  574. if provider_type == self.preferred_provider_type:
  575. return
  576. if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
  577. return
  578. # get preferred provider
  579. preferred_model_provider = (
  580. db.session.query(TenantPreferredModelProvider)
  581. .filter(
  582. TenantPreferredModelProvider.tenant_id == self.tenant_id,
  583. TenantPreferredModelProvider.provider_name == self.provider.provider,
  584. )
  585. .first()
  586. )
  587. if preferred_model_provider:
  588. preferred_model_provider.preferred_provider_type = provider_type.value
  589. else:
  590. preferred_model_provider = TenantPreferredModelProvider()
  591. preferred_model_provider.tenant_id = self.tenant_id
  592. preferred_model_provider.provider_name = self.provider.provider
  593. preferred_model_provider.preferred_provider_type = provider_type.value
  594. db.session.add(preferred_model_provider)
  595. db.session.commit()
  596. def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
  597. """
  598. Extract secret input form variables.
  599. :param credential_form_schemas:
  600. :return:
  601. """
  602. secret_input_form_variables = []
  603. for credential_form_schema in credential_form_schemas:
  604. if credential_form_schema.type == FormType.SECRET_INPUT:
  605. secret_input_form_variables.append(credential_form_schema.variable)
  606. return secret_input_form_variables
  607. def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
  608. """
  609. Obfuscated credentials.
  610. :param credentials: credentials
  611. :param credential_form_schemas: credential form schemas
  612. :return:
  613. """
  614. # Get provider credential secret variables
  615. credential_secret_variables = self.extract_secret_variables(credential_form_schemas)
  616. # Obfuscate provider credentials
  617. copy_credentials = credentials.copy()
  618. for key, value in copy_credentials.items():
  619. if key in credential_secret_variables:
  620. copy_credentials[key] = encrypter.obfuscated_token(value)
  621. return copy_credentials
  622. def get_provider_model(
  623. self, model_type: ModelType, model: str, only_active: bool = False
  624. ) -> Optional[ModelWithProviderEntity]:
  625. """
  626. Get provider model.
  627. :param model_type: model type
  628. :param model: model name
  629. :param only_active: return active model only
  630. :return:
  631. """
  632. provider_models = self.get_provider_models(model_type, only_active)
  633. for provider_model in provider_models:
  634. if provider_model.model == model:
  635. return provider_model
  636. return None
  637. def get_provider_models(
  638. self, model_type: Optional[ModelType] = None, only_active: bool = False
  639. ) -> list[ModelWithProviderEntity]:
  640. """
  641. Get provider models.
  642. :param model_type: model type
  643. :param only_active: only active models
  644. :return:
  645. """
  646. model_provider_factory = ModelProviderFactory(self.tenant_id)
  647. provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)
  648. model_types: list[ModelType] = []
  649. if model_type:
  650. model_types.append(model_type)
  651. else:
  652. model_types = list(provider_schema.supported_model_types)
  653. # Group model settings by model type and model
  654. model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
  655. for model_setting in self.model_settings:
  656. model_setting_map[model_setting.model_type][model_setting.model] = model_setting
  657. if self.using_provider_type == ProviderType.SYSTEM:
  658. provider_models = self._get_system_provider_models(
  659. model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
  660. )
  661. else:
  662. provider_models = self._get_custom_provider_models(
  663. model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
  664. )
  665. if only_active:
  666. provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
  667. # resort provider_models
  668. return sorted(provider_models, key=lambda x: x.model_type.value)
  669. def _get_system_provider_models(
  670. self,
  671. model_types: Sequence[ModelType],
  672. provider_schema: ProviderEntity,
  673. model_setting_map: dict[ModelType, dict[str, ModelSettings]],
  674. ) -> list[ModelWithProviderEntity]:
  675. """
  676. Get system provider models.
  677. :param model_types: model types
  678. :param provider_schema: provider schema
  679. :param model_setting_map: model setting map
  680. :return:
  681. """
  682. provider_models = []
  683. for model_type in model_types:
  684. for m in provider_schema.models:
  685. if m.model_type != model_type:
  686. continue
  687. status = ModelStatus.ACTIVE
  688. if m.model in model_setting_map:
  689. model_setting = model_setting_map[m.model_type][m.model]
  690. if model_setting.enabled is False:
  691. status = ModelStatus.DISABLED
  692. provider_models.append(
  693. ModelWithProviderEntity(
  694. model=m.model,
  695. label=m.label,
  696. model_type=m.model_type,
  697. features=m.features,
  698. fetch_from=m.fetch_from,
  699. model_properties=m.model_properties,
  700. deprecated=m.deprecated,
  701. provider=SimpleModelProviderEntity(self.provider),
  702. status=status,
  703. )
  704. )
  705. if self.provider.provider not in original_provider_configurate_methods:
  706. original_provider_configurate_methods[self.provider.provider] = []
  707. for configurate_method in provider_schema.configurate_methods:
  708. original_provider_configurate_methods[self.provider.provider].append(configurate_method)
  709. should_use_custom_model = False
  710. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  711. should_use_custom_model = True
  712. for quota_configuration in self.system_configuration.quota_configurations:
  713. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  714. continue
  715. restrict_models = quota_configuration.restrict_models
  716. if len(restrict_models) == 0:
  717. break
  718. if should_use_custom_model:
  719. if original_provider_configurate_methods[self.provider.provider] == [
  720. ConfigurateMethod.CUSTOMIZABLE_MODEL
  721. ]:
  722. # only customizable model
  723. for restrict_model in restrict_models:
  724. copy_credentials = (
  725. self.system_configuration.credentials.copy()
  726. if self.system_configuration.credentials
  727. else {}
  728. )
  729. if restrict_model.base_model_name:
  730. copy_credentials["base_model_name"] = restrict_model.base_model_name
  731. try:
  732. custom_model_schema = self.get_model_schema(
  733. model_type=restrict_model.model_type,
  734. model=restrict_model.model,
  735. credentials=copy_credentials,
  736. )
  737. except Exception as ex:
  738. logger.warning(f"get custom model schema failed, {ex}")
  739. if not custom_model_schema:
  740. continue
  741. if custom_model_schema.model_type not in model_types:
  742. continue
  743. status = ModelStatus.ACTIVE
  744. if (
  745. custom_model_schema.model_type in model_setting_map
  746. and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
  747. ):
  748. model_setting = model_setting_map[custom_model_schema.model_type][
  749. custom_model_schema.model
  750. ]
  751. if model_setting.enabled is False:
  752. status = ModelStatus.DISABLED
  753. provider_models.append(
  754. ModelWithProviderEntity(
  755. model=custom_model_schema.model,
  756. label=custom_model_schema.label,
  757. model_type=custom_model_schema.model_type,
  758. features=custom_model_schema.features,
  759. fetch_from=FetchFrom.PREDEFINED_MODEL,
  760. model_properties=custom_model_schema.model_properties,
  761. deprecated=custom_model_schema.deprecated,
  762. provider=SimpleModelProviderEntity(self.provider),
  763. status=status,
  764. )
  765. )
  766. # if llm name not in restricted llm list, remove it
  767. restrict_model_names = [rm.model for rm in restrict_models]
  768. for model in provider_models:
  769. if model.model_type == ModelType.LLM and model.model not in restrict_model_names:
  770. model.status = ModelStatus.NO_PERMISSION
  771. elif not quota_configuration.is_valid:
  772. model.status = ModelStatus.QUOTA_EXCEEDED
  773. return provider_models
  774. def _get_custom_provider_models(
  775. self,
  776. model_types: Sequence[ModelType],
  777. provider_schema: ProviderEntity,
  778. model_setting_map: dict[ModelType, dict[str, ModelSettings]],
  779. ) -> list[ModelWithProviderEntity]:
  780. """
  781. Get custom provider models.
  782. :param model_types: model types
  783. :param provider_schema: provider schema
  784. :param model_setting_map: model setting map
  785. :return:
  786. """
  787. provider_models = []
  788. credentials = None
  789. if self.custom_configuration.provider:
  790. credentials = self.custom_configuration.provider.credentials
  791. for model_type in model_types:
  792. if model_type not in self.provider.supported_model_types:
  793. continue
  794. for m in provider_schema.models:
  795. if m.model_type != model_type:
  796. continue
  797. status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
  798. load_balancing_enabled = False
  799. if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
  800. model_setting = model_setting_map[m.model_type][m.model]
  801. if model_setting.enabled is False:
  802. status = ModelStatus.DISABLED
  803. if len(model_setting.load_balancing_configs) > 1:
  804. load_balancing_enabled = True
  805. provider_models.append(
  806. ModelWithProviderEntity(
  807. model=m.model,
  808. label=m.label,
  809. model_type=m.model_type,
  810. features=m.features,
  811. fetch_from=m.fetch_from,
  812. model_properties=m.model_properties,
  813. deprecated=m.deprecated,
  814. provider=SimpleModelProviderEntity(self.provider),
  815. status=status,
  816. load_balancing_enabled=load_balancing_enabled,
  817. )
  818. )
  819. # custom models
  820. for model_configuration in self.custom_configuration.models:
  821. if model_configuration.model_type not in model_types:
  822. continue
  823. try:
  824. custom_model_schema = self.get_model_schema(
  825. model_type=model_configuration.model_type,
  826. model=model_configuration.model,
  827. credentials=model_configuration.credentials,
  828. )
  829. except Exception as ex:
  830. logger.warning(f"get custom model schema failed, {ex}")
  831. continue
  832. if not custom_model_schema:
  833. continue
  834. status = ModelStatus.ACTIVE
  835. load_balancing_enabled = False
  836. if (
  837. custom_model_schema.model_type in model_setting_map
  838. and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
  839. ):
  840. model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
  841. if model_setting.enabled is False:
  842. status = ModelStatus.DISABLED
  843. if len(model_setting.load_balancing_configs) > 1:
  844. load_balancing_enabled = True
  845. provider_models.append(
  846. ModelWithProviderEntity(
  847. model=custom_model_schema.model,
  848. label=custom_model_schema.label,
  849. model_type=custom_model_schema.model_type,
  850. features=custom_model_schema.features,
  851. fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
  852. model_properties=custom_model_schema.model_properties,
  853. deprecated=custom_model_schema.deprecated,
  854. provider=SimpleModelProviderEntity(self.provider),
  855. status=status,
  856. load_balancing_enabled=load_balancing_enabled,
  857. )
  858. )
  859. return provider_models
  860. class ProviderConfigurations(BaseModel):
  861. """
  862. Model class for provider configuration dict.
  863. """
  864. tenant_id: str
  865. configurations: dict[str, ProviderConfiguration] = {}
  866. def __init__(self, tenant_id: str):
  867. super().__init__(tenant_id=tenant_id)
  868. def get_models(
  869. self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
  870. ) -> list[ModelWithProviderEntity]:
  871. """
  872. Get available models.
  873. If preferred provider type is `system`:
  874. Get the current **system mode** if provider supported,
  875. if all system modes are not available (no quota), it is considered to be the **custom credential mode**.
  876. If there is no model configured in custom mode, it is treated as no_configure.
  877. system > custom > no_configure
  878. If preferred provider type is `custom`:
  879. If custom credentials are configured, it is treated as custom mode.
  880. Otherwise, get the current **system mode** if supported,
  881. If all system modes are not available (no quota), it is treated as no_configure.
  882. custom > system > no_configure
  883. If real mode is `system`, use system credentials to get models,
  884. paid quotas > provider free quotas > system free quotas
  885. include pre-defined models (exclude GPT-4, status marked as `no_permission`).
  886. If real mode is `custom`, use workspace custom credentials to get models,
  887. include pre-defined models, custom models(manual append).
  888. If real mode is `no_configure`, only return pre-defined models from `model runtime`.
  889. (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`)
  890. model status marked as `active` is available.
  891. :param provider: provider name
  892. :param model_type: model type
  893. :param only_active: only active models
  894. :return:
  895. """
  896. all_models = []
  897. for provider_configuration in self.values():
  898. if provider and provider_configuration.provider.provider != provider:
  899. continue
  900. all_models.extend(provider_configuration.get_provider_models(model_type, only_active))
  901. return all_models
  902. def to_list(self) -> list[ProviderConfiguration]:
  903. """
  904. Convert to list.
  905. :return:
  906. """
  907. return list(self.values())
  908. def __getitem__(self, key):
  909. if "/" not in key:
  910. key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
  911. return self.configurations[key]
  912. def __setitem__(self, key, value):
  913. self.configurations[key] = value
  914. def __iter__(self):
  915. return iter(self.configurations)
  916. def values(self) -> Iterator[ProviderConfiguration]:
  917. return iter(self.configurations.values())
  918. def get(self, key, default=None) -> ProviderConfiguration | None:
  919. if "/" not in key:
  920. key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
  921. return self.configurations.get(key, default) # type: ignore
  922. class ProviderModelBundle(BaseModel):
  923. """
  924. Provider model bundle.
  925. """
  926. configuration: ProviderConfiguration
  927. model_type_instance: AIModel
  928. # pydantic configs
  929. model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())