provider_configuration.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657
  1. import datetime
  2. import json
  3. import time
  4. from json import JSONDecodeError
  5. from typing import Optional, List, Dict, Tuple, Iterator
  6. from pydantic import BaseModel
  7. from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, SimpleModelProviderEntity
  8. from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus
  9. from core.helper import encrypter
  10. from core.model_runtime.entities.model_entities import ModelType
  11. from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
  12. from core.model_runtime.model_providers import model_provider_factory
  13. from core.model_runtime.model_providers.__base.ai_model import AIModel
  14. from core.model_runtime.model_providers.__base.model_provider import ModelProvider
  15. from core.model_runtime.utils import encoders
  16. from extensions.ext_database import db
  17. from models.provider import ProviderType, Provider, ProviderModel, TenantPreferredModelProvider
  18. class ProviderConfiguration(BaseModel):
  19. """
  20. Model class for provider configuration.
  21. """
  22. tenant_id: str
  23. provider: ProviderEntity
  24. preferred_provider_type: ProviderType
  25. using_provider_type: ProviderType
  26. system_configuration: SystemConfiguration
  27. custom_configuration: CustomConfiguration
  28. def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
  29. """
  30. Get current credentials.
  31. :param model_type: model type
  32. :param model: model name
  33. :return:
  34. """
  35. if self.using_provider_type == ProviderType.SYSTEM:
  36. return self.system_configuration.credentials
  37. else:
  38. if self.custom_configuration.models:
  39. for model_configuration in self.custom_configuration.models:
  40. if model_configuration.model_type == model_type and model_configuration.model == model:
  41. return model_configuration.credentials
  42. if self.custom_configuration.provider:
  43. return self.custom_configuration.provider.credentials
  44. else:
  45. return None
  46. def get_system_configuration_status(self) -> SystemConfigurationStatus:
  47. """
  48. Get system configuration status.
  49. :return:
  50. """
  51. if self.system_configuration.enabled is False:
  52. return SystemConfigurationStatus.UNSUPPORTED
  53. current_quota_type = self.system_configuration.current_quota_type
  54. current_quota_configuration = next(
  55. (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type),
  56. None
  57. )
  58. return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \
  59. SystemConfigurationStatus.QUOTA_EXCEEDED
  60. def is_custom_configuration_available(self) -> bool:
  61. """
  62. Check custom configuration available.
  63. :return:
  64. """
  65. return (self.custom_configuration.provider is not None
  66. or len(self.custom_configuration.models) > 0)
  67. def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
  68. """
  69. Get custom credentials.
  70. :param obfuscated: obfuscated secret data in credentials
  71. :return:
  72. """
  73. if self.custom_configuration.provider is None:
  74. return None
  75. credentials = self.custom_configuration.provider.credentials
  76. if not obfuscated:
  77. return credentials
  78. # Obfuscate credentials
  79. return self._obfuscated_credentials(
  80. credentials=credentials,
  81. credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
  82. if self.provider.provider_credential_schema else []
  83. )
  84. def custom_credentials_validate(self, credentials: dict) -> Tuple[Provider, dict]:
  85. """
  86. Validate custom credentials.
  87. :param credentials: provider credentials
  88. :return:
  89. """
  90. # get provider
  91. provider_record = db.session.query(Provider) \
  92. .filter(
  93. Provider.tenant_id == self.tenant_id,
  94. Provider.provider_name == self.provider.provider,
  95. Provider.provider_type == ProviderType.CUSTOM.value
  96. ).first()
  97. # Get provider credential secret variables
  98. provider_credential_secret_variables = self._extract_secret_variables(
  99. self.provider.provider_credential_schema.credential_form_schemas
  100. if self.provider.provider_credential_schema else []
  101. )
  102. if provider_record:
  103. try:
  104. original_credentials = json.loads(provider_record.encrypted_config) if provider_record.encrypted_config else {}
  105. except JSONDecodeError:
  106. original_credentials = {}
  107. # encrypt credentials
  108. for key, value in credentials.items():
  109. if key in provider_credential_secret_variables:
  110. # if send [__HIDDEN__] in secret input, it will be same as original value
  111. if value == '[__HIDDEN__]' and key in original_credentials:
  112. credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
  113. model_provider_factory.provider_credentials_validate(
  114. self.provider.provider,
  115. credentials
  116. )
  117. for key, value in credentials.items():
  118. if key in provider_credential_secret_variables:
  119. credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  120. return provider_record, credentials
  121. def add_or_update_custom_credentials(self, credentials: dict) -> None:
  122. """
  123. Add or update custom provider credentials.
  124. :param credentials:
  125. :return:
  126. """
  127. # validate custom provider config
  128. provider_record, credentials = self.custom_credentials_validate(credentials)
  129. # save provider
  130. # Note: Do not switch the preferred provider, which allows users to use quotas first
  131. if provider_record:
  132. provider_record.encrypted_config = json.dumps(credentials)
  133. provider_record.is_valid = True
  134. provider_record.updated_at = datetime.datetime.utcnow()
  135. db.session.commit()
  136. else:
  137. provider_record = Provider(
  138. tenant_id=self.tenant_id,
  139. provider_name=self.provider.provider,
  140. provider_type=ProviderType.CUSTOM.value,
  141. encrypted_config=json.dumps(credentials),
  142. is_valid=True
  143. )
  144. db.session.add(provider_record)
  145. db.session.commit()
  146. self.switch_preferred_provider_type(ProviderType.CUSTOM)
  147. def delete_custom_credentials(self) -> None:
  148. """
  149. Delete custom provider credentials.
  150. :return:
  151. """
  152. # get provider
  153. provider_record = db.session.query(Provider) \
  154. .filter(
  155. Provider.tenant_id == self.tenant_id,
  156. Provider.provider_name == self.provider.provider,
  157. Provider.provider_type == ProviderType.CUSTOM.value
  158. ).first()
  159. # delete provider
  160. if provider_record:
  161. self.switch_preferred_provider_type(ProviderType.SYSTEM)
  162. db.session.delete(provider_record)
  163. db.session.commit()
  164. def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
  165. -> Optional[dict]:
  166. """
  167. Get custom model credentials.
  168. :param model_type: model type
  169. :param model: model name
  170. :param obfuscated: obfuscated secret data in credentials
  171. :return:
  172. """
  173. if not self.custom_configuration.models:
  174. return None
  175. for model_configuration in self.custom_configuration.models:
  176. if model_configuration.model_type == model_type and model_configuration.model == model:
  177. credentials = model_configuration.credentials
  178. if not obfuscated:
  179. return credentials
  180. # Obfuscate credentials
  181. return self._obfuscated_credentials(
  182. credentials=credentials,
  183. credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
  184. if self.provider.model_credential_schema else []
  185. )
  186. return None
  187. def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
  188. -> Tuple[ProviderModel, dict]:
  189. """
  190. Validate custom model credentials.
  191. :param model_type: model type
  192. :param model: model name
  193. :param credentials: model credentials
  194. :return:
  195. """
  196. # get provider model
  197. provider_model_record = db.session.query(ProviderModel) \
  198. .filter(
  199. ProviderModel.tenant_id == self.tenant_id,
  200. ProviderModel.provider_name == self.provider.provider,
  201. ProviderModel.model_name == model,
  202. ProviderModel.model_type == model_type.to_origin_model_type()
  203. ).first()
  204. # Get provider credential secret variables
  205. provider_credential_secret_variables = self._extract_secret_variables(
  206. self.provider.model_credential_schema.credential_form_schemas
  207. if self.provider.model_credential_schema else []
  208. )
  209. if provider_model_record:
  210. try:
  211. original_credentials = json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
  212. except JSONDecodeError:
  213. original_credentials = {}
  214. # decrypt credentials
  215. for key, value in credentials.items():
  216. if key in provider_credential_secret_variables:
  217. # if send [__HIDDEN__] in secret input, it will be same as original value
  218. if value == '[__HIDDEN__]' and key in original_credentials:
  219. credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
  220. model_provider_factory.model_credentials_validate(
  221. provider=self.provider.provider,
  222. model_type=model_type,
  223. model=model,
  224. credentials=credentials
  225. )
  226. model_schema = (
  227. model_provider_factory.get_provider_instance(self.provider.provider)
  228. .get_model_instance(model_type)._get_customizable_model_schema(
  229. model=model,
  230. credentials=credentials
  231. )
  232. )
  233. if model_schema:
  234. credentials['schema'] = json.dumps(encoders.jsonable_encoder(model_schema))
  235. for key, value in credentials.items():
  236. if key in provider_credential_secret_variables:
  237. credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  238. return provider_model_record, credentials
  239. def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
  240. """
  241. Add or update custom model credentials.
  242. :param model_type: model type
  243. :param model: model name
  244. :param credentials: model credentials
  245. :return:
  246. """
  247. # validate custom model config
  248. provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)
  249. # save provider model
  250. # Note: Do not switch the preferred provider, which allows users to use quotas first
  251. if provider_model_record:
  252. provider_model_record.encrypted_config = json.dumps(credentials)
  253. provider_model_record.is_valid = True
  254. provider_model_record.updated_at = datetime.datetime.utcnow()
  255. db.session.commit()
  256. else:
  257. provider_model_record = ProviderModel(
  258. tenant_id=self.tenant_id,
  259. provider_name=self.provider.provider,
  260. model_name=model,
  261. model_type=model_type.to_origin_model_type(),
  262. encrypted_config=json.dumps(credentials),
  263. is_valid=True
  264. )
  265. db.session.add(provider_model_record)
  266. db.session.commit()
  267. def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
  268. """
  269. Delete custom model credentials.
  270. :param model_type: model type
  271. :param model: model name
  272. :return:
  273. """
  274. # get provider model
  275. provider_model_record = db.session.query(ProviderModel) \
  276. .filter(
  277. ProviderModel.tenant_id == self.tenant_id,
  278. ProviderModel.provider_name == self.provider.provider,
  279. ProviderModel.model_name == model,
  280. ProviderModel.model_type == model_type.to_origin_model_type()
  281. ).first()
  282. # delete provider model
  283. if provider_model_record:
  284. db.session.delete(provider_model_record)
  285. db.session.commit()
  286. def get_provider_instance(self) -> ModelProvider:
  287. """
  288. Get provider instance.
  289. :return:
  290. """
  291. return model_provider_factory.get_provider_instance(self.provider.provider)
  292. def get_model_type_instance(self, model_type: ModelType) -> AIModel:
  293. """
  294. Get current model type instance.
  295. :param model_type: model type
  296. :return:
  297. """
  298. # Get provider instance
  299. provider_instance = self.get_provider_instance()
  300. # Get model instance of LLM
  301. return provider_instance.get_model_instance(model_type)
  302. def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
  303. """
  304. Switch preferred provider type.
  305. :param provider_type:
  306. :return:
  307. """
  308. if provider_type == self.preferred_provider_type:
  309. return
  310. if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
  311. return
  312. # get preferred provider
  313. preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
  314. .filter(
  315. TenantPreferredModelProvider.tenant_id == self.tenant_id,
  316. TenantPreferredModelProvider.provider_name == self.provider.provider
  317. ).first()
  318. if preferred_model_provider:
  319. preferred_model_provider.preferred_provider_type = provider_type.value
  320. else:
  321. preferred_model_provider = TenantPreferredModelProvider(
  322. tenant_id=self.tenant_id,
  323. provider_name=self.provider.provider,
  324. preferred_provider_type=provider_type.value
  325. )
  326. db.session.add(preferred_model_provider)
  327. db.session.commit()
  328. def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
  329. """
  330. Extract secret input form variables.
  331. :param credential_form_schemas:
  332. :return:
  333. """
  334. secret_input_form_variables = []
  335. for credential_form_schema in credential_form_schemas:
  336. if credential_form_schema.type == FormType.SECRET_INPUT:
  337. secret_input_form_variables.append(credential_form_schema.variable)
  338. return secret_input_form_variables
  339. def _obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
  340. """
  341. Obfuscated credentials.
  342. :param credentials: credentials
  343. :param credential_form_schemas: credential form schemas
  344. :return:
  345. """
  346. # Get provider credential secret variables
  347. credential_secret_variables = self._extract_secret_variables(
  348. credential_form_schemas
  349. )
  350. # Obfuscate provider credentials
  351. copy_credentials = credentials.copy()
  352. for key, value in copy_credentials.items():
  353. if key in credential_secret_variables:
  354. copy_credentials[key] = encrypter.obfuscated_token(value)
  355. return copy_credentials
  356. def get_provider_model(self, model_type: ModelType,
  357. model: str,
  358. only_active: bool = False) -> Optional[ModelWithProviderEntity]:
  359. """
  360. Get provider model.
  361. :param model_type: model type
  362. :param model: model name
  363. :param only_active: return active model only
  364. :return:
  365. """
  366. provider_models = self.get_provider_models(model_type, only_active)
  367. for provider_model in provider_models:
  368. if provider_model.model == model:
  369. return provider_model
  370. return None
  371. def get_provider_models(self, model_type: Optional[ModelType] = None,
  372. only_active: bool = False) -> list[ModelWithProviderEntity]:
  373. """
  374. Get provider models.
  375. :param model_type: model type
  376. :param only_active: only active models
  377. :return:
  378. """
  379. provider_instance = self.get_provider_instance()
  380. model_types = []
  381. if model_type:
  382. model_types.append(model_type)
  383. else:
  384. model_types = provider_instance.get_provider_schema().supported_model_types
  385. if self.using_provider_type == ProviderType.SYSTEM:
  386. provider_models = self._get_system_provider_models(
  387. model_types=model_types,
  388. provider_instance=provider_instance
  389. )
  390. else:
  391. provider_models = self._get_custom_provider_models(
  392. model_types=model_types,
  393. provider_instance=provider_instance
  394. )
  395. if only_active:
  396. provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
  397. # resort provider_models
  398. return sorted(provider_models, key=lambda x: x.model_type.value)
  399. def _get_system_provider_models(self,
  400. model_types: list[ModelType],
  401. provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
  402. """
  403. Get system provider models.
  404. :param model_types: model types
  405. :param provider_instance: provider instance
  406. :return:
  407. """
  408. provider_models = []
  409. for model_type in model_types:
  410. provider_models.extend(
  411. [
  412. ModelWithProviderEntity(
  413. **m.dict(),
  414. provider=SimpleModelProviderEntity(self.provider),
  415. status=ModelStatus.ACTIVE
  416. )
  417. for m in provider_instance.models(model_type)
  418. ]
  419. )
  420. for quota_configuration in self.system_configuration.quota_configurations:
  421. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  422. continue
  423. restrict_llms = quota_configuration.restrict_llms
  424. if not restrict_llms:
  425. break
  426. # if llm name not in restricted llm list, remove it
  427. for m in provider_models:
  428. if m.model_type == ModelType.LLM and m.model not in restrict_llms:
  429. m.status = ModelStatus.NO_PERMISSION
  430. elif not quota_configuration.is_valid:
  431. m.status = ModelStatus.QUOTA_EXCEEDED
  432. return provider_models
  433. def _get_custom_provider_models(self,
  434. model_types: list[ModelType],
  435. provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
  436. """
  437. Get custom provider models.
  438. :param model_types: model types
  439. :param provider_instance: provider instance
  440. :return:
  441. """
  442. provider_models = []
  443. credentials = None
  444. if self.custom_configuration.provider:
  445. credentials = self.custom_configuration.provider.credentials
  446. for model_type in model_types:
  447. if model_type not in self.provider.supported_model_types:
  448. continue
  449. models = provider_instance.models(model_type)
  450. for m in models:
  451. provider_models.append(
  452. ModelWithProviderEntity(
  453. **m.dict(),
  454. provider=SimpleModelProviderEntity(self.provider),
  455. status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
  456. )
  457. )
  458. # custom models
  459. for model_configuration in self.custom_configuration.models:
  460. if model_configuration.model_type not in model_types:
  461. continue
  462. custom_model_schema = (
  463. provider_instance.get_model_instance(model_configuration.model_type)
  464. .get_customizable_model_schema_from_credentials(
  465. model_configuration.model,
  466. model_configuration.credentials
  467. )
  468. )
  469. if not custom_model_schema:
  470. continue
  471. provider_models.append(
  472. ModelWithProviderEntity(
  473. **custom_model_schema.dict(),
  474. provider=SimpleModelProviderEntity(self.provider),
  475. status=ModelStatus.ACTIVE
  476. )
  477. )
  478. return provider_models
  479. class ProviderConfigurations(BaseModel):
  480. """
  481. Model class for provider configuration dict.
  482. """
  483. tenant_id: str
  484. configurations: Dict[str, ProviderConfiguration] = {}
  485. def __init__(self, tenant_id: str):
  486. super().__init__(tenant_id=tenant_id)
  487. def get_models(self,
  488. provider: Optional[str] = None,
  489. model_type: Optional[ModelType] = None,
  490. only_active: bool = False) \
  491. -> list[ModelWithProviderEntity]:
  492. """
  493. Get available models.
  494. If preferred provider type is `system`:
  495. Get the current **system mode** if provider supported,
  496. if all system modes are not available (no quota), it is considered to be the **custom credential mode**.
  497. If there is no model configured in custom mode, it is treated as no_configure.
  498. system > custom > no_configure
  499. If preferred provider type is `custom`:
  500. If custom credentials are configured, it is treated as custom mode.
  501. Otherwise, get the current **system mode** if supported,
  502. If all system modes are not available (no quota), it is treated as no_configure.
  503. custom > system > no_configure
  504. If real mode is `system`, use system credentials to get models,
  505. paid quotas > provider free quotas > system free quotas
  506. include pre-defined models (exclude GPT-4, status marked as `no_permission`).
  507. If real mode is `custom`, use workspace custom credentials to get models,
  508. include pre-defined models, custom models(manual append).
  509. If real mode is `no_configure`, only return pre-defined models from `model runtime`.
  510. (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`)
  511. model status marked as `active` is available.
  512. :param provider: provider name
  513. :param model_type: model type
  514. :param only_active: only active models
  515. :return:
  516. """
  517. all_models = []
  518. for provider_configuration in self.values():
  519. if provider and provider_configuration.provider.provider != provider:
  520. continue
  521. all_models.extend(provider_configuration.get_provider_models(model_type, only_active))
  522. return all_models
  523. def to_list(self) -> List[ProviderConfiguration]:
  524. """
  525. Convert to list.
  526. :return:
  527. """
  528. return list(self.values())
  529. def __getitem__(self, key):
  530. return self.configurations[key]
  531. def __setitem__(self, key, value):
  532. self.configurations[key] = value
  533. def __iter__(self):
  534. return iter(self.configurations)
  535. def values(self) -> Iterator[ProviderConfiguration]:
  536. return self.configurations.values()
  537. def get(self, key, default=None):
  538. return self.configurations.get(key, default)
  539. class ProviderModelBundle(BaseModel):
  540. """
  541. Provider model bundle.
  542. """
  543. configuration: ProviderConfiguration
  544. provider_instance: ModelProvider
  545. model_type_instance: AIModel
  546. class Config:
  547. """Configuration for this pydantic object."""
  548. arbitrary_types_allowed = True