provider_configuration.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025
  1. import datetime
  2. import json
  3. import logging
  4. from collections import defaultdict
  5. from collections.abc import Iterator
  6. from json import JSONDecodeError
  7. from typing import Optional
  8. from pydantic import BaseModel
  9. from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
  10. from core.entities.provider_entities import (
  11. CustomConfiguration,
  12. ModelSettings,
  13. SystemConfiguration,
  14. SystemConfigurationStatus,
  15. )
  16. from core.helper import encrypter
  17. from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
  18. from core.model_runtime.entities.model_entities import FetchFrom, ModelType
  19. from core.model_runtime.entities.provider_entities import (
  20. ConfigurateMethod,
  21. CredentialFormSchema,
  22. FormType,
  23. ProviderEntity,
  24. )
  25. from core.model_runtime.model_providers import model_provider_factory
  26. from core.model_runtime.model_providers.__base.ai_model import AIModel
  27. from core.model_runtime.model_providers.__base.model_provider import ModelProvider
  28. from extensions.ext_database import db
  29. from models.provider import (
  30. LoadBalancingModelConfig,
  31. Provider,
  32. ProviderModel,
  33. ProviderModelSetting,
  34. ProviderType,
  35. TenantPreferredModelProvider,
  36. )
  37. logger = logging.getLogger(__name__)
  38. original_provider_configurate_methods = {}
class ProviderConfiguration(BaseModel):
    """
    Model class for provider configuration.

    Bundles, for one tenant and one model provider, the provider schema with
    its system configuration and the tenant's custom configuration. Exposes
    helpers that resolve credentials and list models, and helpers that persist
    custom credentials, per-model enable/disable settings, load-balancing flags
    and the preferred provider type.
    """
    # id of the tenant this configuration belongs to
    tenant_id: str
    # provider schema entity; __init__ may append PREDEFINED_MODEL to its configurate_methods
    provider: ProviderEntity
    # provider type stored as the tenant's preference (see switch_preferred_provider_type)
    preferred_provider_type: ProviderType
    # provider type actually used when resolving credentials and listing models
    using_provider_type: ProviderType
    # system configuration: enabled flag, credentials, quota configurations, current quota type
    system_configuration: SystemConfiguration
    # tenant-supplied provider-level credentials and per-model custom configurations
    custom_configuration: CustomConfiguration
    # per-model settings (model type, model name, enabled flag)
    model_settings: list[ModelSettings]
  50. def __init__(self, **data):
  51. super().__init__(**data)
  52. if self.provider.provider not in original_provider_configurate_methods:
  53. original_provider_configurate_methods[self.provider.provider] = []
  54. for configurate_method in self.provider.configurate_methods:
  55. original_provider_configurate_methods[self.provider.provider].append(configurate_method)
  56. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  57. if (any([len(quota_configuration.restrict_models) > 0
  58. for quota_configuration in self.system_configuration.quota_configurations])
  59. and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods):
  60. self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
  61. def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
  62. """
  63. Get current credentials.
  64. :param model_type: model type
  65. :param model: model name
  66. :return:
  67. """
  68. if self.model_settings:
  69. # check if model is disabled by admin
  70. for model_setting in self.model_settings:
  71. if (model_setting.model_type == model_type
  72. and model_setting.model == model):
  73. if not model_setting.enabled:
  74. raise ValueError(f'Model {model} is disabled.')
  75. if self.using_provider_type == ProviderType.SYSTEM:
  76. restrict_models = []
  77. for quota_configuration in self.system_configuration.quota_configurations:
  78. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  79. continue
  80. restrict_models = quota_configuration.restrict_models
  81. copy_credentials = self.system_configuration.credentials.copy()
  82. if restrict_models:
  83. for restrict_model in restrict_models:
  84. if (restrict_model.model_type == model_type
  85. and restrict_model.model == model
  86. and restrict_model.base_model_name):
  87. copy_credentials['base_model_name'] = restrict_model.base_model_name
  88. return copy_credentials
  89. else:
  90. credentials = None
  91. if self.custom_configuration.models:
  92. for model_configuration in self.custom_configuration.models:
  93. if model_configuration.model_type == model_type and model_configuration.model == model:
  94. credentials = model_configuration.credentials
  95. break
  96. if self.custom_configuration.provider:
  97. credentials = self.custom_configuration.provider.credentials
  98. return credentials
  99. def get_system_configuration_status(self) -> SystemConfigurationStatus:
  100. """
  101. Get system configuration status.
  102. :return:
  103. """
  104. if self.system_configuration.enabled is False:
  105. return SystemConfigurationStatus.UNSUPPORTED
  106. current_quota_type = self.system_configuration.current_quota_type
  107. current_quota_configuration = next(
  108. (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type),
  109. None
  110. )
  111. return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \
  112. SystemConfigurationStatus.QUOTA_EXCEEDED
  113. def is_custom_configuration_available(self) -> bool:
  114. """
  115. Check custom configuration available.
  116. :return:
  117. """
  118. return (self.custom_configuration.provider is not None
  119. or len(self.custom_configuration.models) > 0)
  120. def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
  121. """
  122. Get custom credentials.
  123. :param obfuscated: obfuscated secret data in credentials
  124. :return:
  125. """
  126. if self.custom_configuration.provider is None:
  127. return None
  128. credentials = self.custom_configuration.provider.credentials
  129. if not obfuscated:
  130. return credentials
  131. # Obfuscate credentials
  132. return self.obfuscated_credentials(
  133. credentials=credentials,
  134. credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
  135. if self.provider.provider_credential_schema else []
  136. )
  137. def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
  138. """
  139. Validate custom credentials.
  140. :param credentials: provider credentials
  141. :return:
  142. """
  143. # get provider
  144. provider_record = db.session.query(Provider) \
  145. .filter(
  146. Provider.tenant_id == self.tenant_id,
  147. Provider.provider_name == self.provider.provider,
  148. Provider.provider_type == ProviderType.CUSTOM.value
  149. ).first()
  150. # Get provider credential secret variables
  151. provider_credential_secret_variables = self.extract_secret_variables(
  152. self.provider.provider_credential_schema.credential_form_schemas
  153. if self.provider.provider_credential_schema else []
  154. )
  155. if provider_record:
  156. try:
  157. # fix origin data
  158. if provider_record.encrypted_config:
  159. if not provider_record.encrypted_config.startswith("{"):
  160. original_credentials = {
  161. "openai_api_key": provider_record.encrypted_config
  162. }
  163. else:
  164. original_credentials = json.loads(provider_record.encrypted_config)
  165. else:
  166. original_credentials = {}
  167. except JSONDecodeError:
  168. original_credentials = {}
  169. # encrypt credentials
  170. for key, value in credentials.items():
  171. if key in provider_credential_secret_variables:
  172. # if send [__HIDDEN__] in secret input, it will be same as original value
  173. if value == '[__HIDDEN__]' and key in original_credentials:
  174. credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
  175. credentials = model_provider_factory.provider_credentials_validate(
  176. self.provider.provider,
  177. credentials
  178. )
  179. for key, value in credentials.items():
  180. if key in provider_credential_secret_variables:
  181. credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  182. return provider_record, credentials
  183. def add_or_update_custom_credentials(self, credentials: dict) -> None:
  184. """
  185. Add or update custom provider credentials.
  186. :param credentials:
  187. :return:
  188. """
  189. # validate custom provider config
  190. provider_record, credentials = self.custom_credentials_validate(credentials)
  191. # save provider
  192. # Note: Do not switch the preferred provider, which allows users to use quotas first
  193. if provider_record:
  194. provider_record.encrypted_config = json.dumps(credentials)
  195. provider_record.is_valid = True
  196. provider_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
  197. db.session.commit()
  198. else:
  199. provider_record = Provider(
  200. tenant_id=self.tenant_id,
  201. provider_name=self.provider.provider,
  202. provider_type=ProviderType.CUSTOM.value,
  203. encrypted_config=json.dumps(credentials),
  204. is_valid=True
  205. )
  206. db.session.add(provider_record)
  207. db.session.commit()
  208. provider_model_credentials_cache = ProviderCredentialsCache(
  209. tenant_id=self.tenant_id,
  210. identity_id=provider_record.id,
  211. cache_type=ProviderCredentialsCacheType.PROVIDER
  212. )
  213. provider_model_credentials_cache.delete()
  214. self.switch_preferred_provider_type(ProviderType.CUSTOM)
  215. def delete_custom_credentials(self) -> None:
  216. """
  217. Delete custom provider credentials.
  218. :return:
  219. """
  220. # get provider
  221. provider_record = db.session.query(Provider) \
  222. .filter(
  223. Provider.tenant_id == self.tenant_id,
  224. Provider.provider_name == self.provider.provider,
  225. Provider.provider_type == ProviderType.CUSTOM.value
  226. ).first()
  227. # delete provider
  228. if provider_record:
  229. self.switch_preferred_provider_type(ProviderType.SYSTEM)
  230. db.session.delete(provider_record)
  231. db.session.commit()
  232. provider_model_credentials_cache = ProviderCredentialsCache(
  233. tenant_id=self.tenant_id,
  234. identity_id=provider_record.id,
  235. cache_type=ProviderCredentialsCacheType.PROVIDER
  236. )
  237. provider_model_credentials_cache.delete()
  238. def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
  239. -> Optional[dict]:
  240. """
  241. Get custom model credentials.
  242. :param model_type: model type
  243. :param model: model name
  244. :param obfuscated: obfuscated secret data in credentials
  245. :return:
  246. """
  247. if not self.custom_configuration.models:
  248. return None
  249. for model_configuration in self.custom_configuration.models:
  250. if model_configuration.model_type == model_type and model_configuration.model == model:
  251. credentials = model_configuration.credentials
  252. if not obfuscated:
  253. return credentials
  254. # Obfuscate credentials
  255. return self.obfuscated_credentials(
  256. credentials=credentials,
  257. credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
  258. if self.provider.model_credential_schema else []
  259. )
  260. return None
  261. def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
  262. -> tuple[ProviderModel, dict]:
  263. """
  264. Validate custom model credentials.
  265. :param model_type: model type
  266. :param model: model name
  267. :param credentials: model credentials
  268. :return:
  269. """
  270. # get provider model
  271. provider_model_record = db.session.query(ProviderModel) \
  272. .filter(
  273. ProviderModel.tenant_id == self.tenant_id,
  274. ProviderModel.provider_name == self.provider.provider,
  275. ProviderModel.model_name == model,
  276. ProviderModel.model_type == model_type.to_origin_model_type()
  277. ).first()
  278. # Get provider credential secret variables
  279. provider_credential_secret_variables = self.extract_secret_variables(
  280. self.provider.model_credential_schema.credential_form_schemas
  281. if self.provider.model_credential_schema else []
  282. )
  283. if provider_model_record:
  284. try:
  285. original_credentials = json.loads(
  286. provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
  287. except JSONDecodeError:
  288. original_credentials = {}
  289. # decrypt credentials
  290. for key, value in credentials.items():
  291. if key in provider_credential_secret_variables:
  292. # if send [__HIDDEN__] in secret input, it will be same as original value
  293. if value == '[__HIDDEN__]' and key in original_credentials:
  294. credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
  295. credentials = model_provider_factory.model_credentials_validate(
  296. provider=self.provider.provider,
  297. model_type=model_type,
  298. model=model,
  299. credentials=credentials
  300. )
  301. for key, value in credentials.items():
  302. if key in provider_credential_secret_variables:
  303. credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  304. return provider_model_record, credentials
  305. def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
  306. """
  307. Add or update custom model credentials.
  308. :param model_type: model type
  309. :param model: model name
  310. :param credentials: model credentials
  311. :return:
  312. """
  313. # validate custom model config
  314. provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)
  315. # save provider model
  316. # Note: Do not switch the preferred provider, which allows users to use quotas first
  317. if provider_model_record:
  318. provider_model_record.encrypted_config = json.dumps(credentials)
  319. provider_model_record.is_valid = True
  320. provider_model_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
  321. db.session.commit()
  322. else:
  323. provider_model_record = ProviderModel(
  324. tenant_id=self.tenant_id,
  325. provider_name=self.provider.provider,
  326. model_name=model,
  327. model_type=model_type.to_origin_model_type(),
  328. encrypted_config=json.dumps(credentials),
  329. is_valid=True
  330. )
  331. db.session.add(provider_model_record)
  332. db.session.commit()
  333. provider_model_credentials_cache = ProviderCredentialsCache(
  334. tenant_id=self.tenant_id,
  335. identity_id=provider_model_record.id,
  336. cache_type=ProviderCredentialsCacheType.MODEL
  337. )
  338. provider_model_credentials_cache.delete()
  339. def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
  340. """
  341. Delete custom model credentials.
  342. :param model_type: model type
  343. :param model: model name
  344. :return:
  345. """
  346. # get provider model
  347. provider_model_record = db.session.query(ProviderModel) \
  348. .filter(
  349. ProviderModel.tenant_id == self.tenant_id,
  350. ProviderModel.provider_name == self.provider.provider,
  351. ProviderModel.model_name == model,
  352. ProviderModel.model_type == model_type.to_origin_model_type()
  353. ).first()
  354. # delete provider model
  355. if provider_model_record:
  356. db.session.delete(provider_model_record)
  357. db.session.commit()
  358. provider_model_credentials_cache = ProviderCredentialsCache(
  359. tenant_id=self.tenant_id,
  360. identity_id=provider_model_record.id,
  361. cache_type=ProviderCredentialsCacheType.MODEL
  362. )
  363. provider_model_credentials_cache.delete()
  364. def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  365. """
  366. Enable model.
  367. :param model_type: model type
  368. :param model: model name
  369. :return:
  370. """
  371. model_setting = db.session.query(ProviderModelSetting) \
  372. .filter(
  373. ProviderModelSetting.tenant_id == self.tenant_id,
  374. ProviderModelSetting.provider_name == self.provider.provider,
  375. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  376. ProviderModelSetting.model_name == model
  377. ).first()
  378. if model_setting:
  379. model_setting.enabled = True
  380. model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
  381. db.session.commit()
  382. else:
  383. model_setting = ProviderModelSetting(
  384. tenant_id=self.tenant_id,
  385. provider_name=self.provider.provider,
  386. model_type=model_type.to_origin_model_type(),
  387. model_name=model,
  388. enabled=True
  389. )
  390. db.session.add(model_setting)
  391. db.session.commit()
  392. return model_setting
  393. def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  394. """
  395. Disable model.
  396. :param model_type: model type
  397. :param model: model name
  398. :return:
  399. """
  400. model_setting = db.session.query(ProviderModelSetting) \
  401. .filter(
  402. ProviderModelSetting.tenant_id == self.tenant_id,
  403. ProviderModelSetting.provider_name == self.provider.provider,
  404. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  405. ProviderModelSetting.model_name == model
  406. ).first()
  407. if model_setting:
  408. model_setting.enabled = False
  409. model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
  410. db.session.commit()
  411. else:
  412. model_setting = ProviderModelSetting(
  413. tenant_id=self.tenant_id,
  414. provider_name=self.provider.provider,
  415. model_type=model_type.to_origin_model_type(),
  416. model_name=model,
  417. enabled=False
  418. )
  419. db.session.add(model_setting)
  420. db.session.commit()
  421. return model_setting
  422. def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
  423. """
  424. Get provider model setting.
  425. :param model_type: model type
  426. :param model: model name
  427. :return:
  428. """
  429. return db.session.query(ProviderModelSetting) \
  430. .filter(
  431. ProviderModelSetting.tenant_id == self.tenant_id,
  432. ProviderModelSetting.provider_name == self.provider.provider,
  433. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  434. ProviderModelSetting.model_name == model
  435. ).first()
  436. def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  437. """
  438. Enable model load balancing.
  439. :param model_type: model type
  440. :param model: model name
  441. :return:
  442. """
  443. load_balancing_config_count = db.session.query(LoadBalancingModelConfig) \
  444. .filter(
  445. LoadBalancingModelConfig.tenant_id == self.tenant_id,
  446. LoadBalancingModelConfig.provider_name == self.provider.provider,
  447. LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
  448. LoadBalancingModelConfig.model_name == model
  449. ).count()
  450. if load_balancing_config_count <= 1:
  451. raise ValueError('Model load balancing configuration must be more than 1.')
  452. model_setting = db.session.query(ProviderModelSetting) \
  453. .filter(
  454. ProviderModelSetting.tenant_id == self.tenant_id,
  455. ProviderModelSetting.provider_name == self.provider.provider,
  456. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  457. ProviderModelSetting.model_name == model
  458. ).first()
  459. if model_setting:
  460. model_setting.load_balancing_enabled = True
  461. model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
  462. db.session.commit()
  463. else:
  464. model_setting = ProviderModelSetting(
  465. tenant_id=self.tenant_id,
  466. provider_name=self.provider.provider,
  467. model_type=model_type.to_origin_model_type(),
  468. model_name=model,
  469. load_balancing_enabled=True
  470. )
  471. db.session.add(model_setting)
  472. db.session.commit()
  473. return model_setting
  474. def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  475. """
  476. Disable model load balancing.
  477. :param model_type: model type
  478. :param model: model name
  479. :return:
  480. """
  481. model_setting = db.session.query(ProviderModelSetting) \
  482. .filter(
  483. ProviderModelSetting.tenant_id == self.tenant_id,
  484. ProviderModelSetting.provider_name == self.provider.provider,
  485. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  486. ProviderModelSetting.model_name == model
  487. ).first()
  488. if model_setting:
  489. model_setting.load_balancing_enabled = False
  490. model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
  491. db.session.commit()
  492. else:
  493. model_setting = ProviderModelSetting(
  494. tenant_id=self.tenant_id,
  495. provider_name=self.provider.provider,
  496. model_type=model_type.to_origin_model_type(),
  497. model_name=model,
  498. load_balancing_enabled=False
  499. )
  500. db.session.add(model_setting)
  501. db.session.commit()
  502. return model_setting
  503. def get_provider_instance(self) -> ModelProvider:
  504. """
  505. Get provider instance.
  506. :return:
  507. """
  508. return model_provider_factory.get_provider_instance(self.provider.provider)
  509. def get_model_type_instance(self, model_type: ModelType) -> AIModel:
  510. """
  511. Get current model type instance.
  512. :param model_type: model type
  513. :return:
  514. """
  515. # Get provider instance
  516. provider_instance = self.get_provider_instance()
  517. # Get model instance of LLM
  518. return provider_instance.get_model_instance(model_type)
  519. def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
  520. """
  521. Switch preferred provider type.
  522. :param provider_type:
  523. :return:
  524. """
  525. if provider_type == self.preferred_provider_type:
  526. return
  527. if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
  528. return
  529. # get preferred provider
  530. preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
  531. .filter(
  532. TenantPreferredModelProvider.tenant_id == self.tenant_id,
  533. TenantPreferredModelProvider.provider_name == self.provider.provider
  534. ).first()
  535. if preferred_model_provider:
  536. preferred_model_provider.preferred_provider_type = provider_type.value
  537. else:
  538. preferred_model_provider = TenantPreferredModelProvider(
  539. tenant_id=self.tenant_id,
  540. provider_name=self.provider.provider,
  541. preferred_provider_type=provider_type.value
  542. )
  543. db.session.add(preferred_model_provider)
  544. db.session.commit()
  545. def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
  546. """
  547. Extract secret input form variables.
  548. :param credential_form_schemas:
  549. :return:
  550. """
  551. secret_input_form_variables = []
  552. for credential_form_schema in credential_form_schemas:
  553. if credential_form_schema.type == FormType.SECRET_INPUT:
  554. secret_input_form_variables.append(credential_form_schema.variable)
  555. return secret_input_form_variables
  556. def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
  557. """
  558. Obfuscated credentials.
  559. :param credentials: credentials
  560. :param credential_form_schemas: credential form schemas
  561. :return:
  562. """
  563. # Get provider credential secret variables
  564. credential_secret_variables = self.extract_secret_variables(
  565. credential_form_schemas
  566. )
  567. # Obfuscate provider credentials
  568. copy_credentials = credentials.copy()
  569. for key, value in copy_credentials.items():
  570. if key in credential_secret_variables:
  571. copy_credentials[key] = encrypter.obfuscated_token(value)
  572. return copy_credentials
  573. def get_provider_model(self, model_type: ModelType,
  574. model: str,
  575. only_active: bool = False) -> Optional[ModelWithProviderEntity]:
  576. """
  577. Get provider model.
  578. :param model_type: model type
  579. :param model: model name
  580. :param only_active: return active model only
  581. :return:
  582. """
  583. provider_models = self.get_provider_models(model_type, only_active)
  584. for provider_model in provider_models:
  585. if provider_model.model == model:
  586. return provider_model
  587. return None
  588. def get_provider_models(self, model_type: Optional[ModelType] = None,
  589. only_active: bool = False) -> list[ModelWithProviderEntity]:
  590. """
  591. Get provider models.
  592. :param model_type: model type
  593. :param only_active: only active models
  594. :return:
  595. """
  596. provider_instance = self.get_provider_instance()
  597. model_types = []
  598. if model_type:
  599. model_types.append(model_type)
  600. else:
  601. model_types = provider_instance.get_provider_schema().supported_model_types
  602. # Group model settings by model type and model
  603. model_setting_map = defaultdict(dict)
  604. for model_setting in self.model_settings:
  605. model_setting_map[model_setting.model_type][model_setting.model] = model_setting
  606. if self.using_provider_type == ProviderType.SYSTEM:
  607. provider_models = self._get_system_provider_models(
  608. model_types=model_types,
  609. provider_instance=provider_instance,
  610. model_setting_map=model_setting_map
  611. )
  612. else:
  613. provider_models = self._get_custom_provider_models(
  614. model_types=model_types,
  615. provider_instance=provider_instance,
  616. model_setting_map=model_setting_map
  617. )
  618. if only_active:
  619. provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
  620. # resort provider_models
  621. return sorted(provider_models, key=lambda x: x.model_type.value)
    def _get_system_provider_models(self,
                                    model_types: list[ModelType],
                                    provider_instance: ModelProvider,
                                    model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
            -> list[ModelWithProviderEntity]:
        """
        Get system provider models.

        Steps:
        1. List the provider's predefined models for each requested model type.
           A model is DISABLED when its setting has ``enabled is False``;
           otherwise it is ACTIVE.
        2. Find the quota configuration whose quota type equals
           ``system_configuration.current_quota_type`` and read its
           ``restrict_models``.
        3. For customizable-model-only providers, also build entries for each
           restricted model from a schema derived from the system credentials.
        4. Mark LLMs missing from the restricted list as NO_PERMISSION. Mark
           every other model as QUOTA_EXCEEDED when the quota is not valid.

        :param model_types: model types
        :param provider_instance: provider instance
        :param model_setting_map: model setting map (model type -> model name -> settings)
        :return: list of models with their computed status
        """
        provider_models = []
        for model_type in model_types:
            for m in provider_instance.models(model_type):
                status = ModelStatus.ACTIVE
                if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
                    model_setting = model_setting_map[m.model_type][m.model]
                    if model_setting.enabled is False:
                        status = ModelStatus.DISABLED

                provider_models.append(
                    ModelWithProviderEntity(
                        model=m.model,
                        label=m.label,
                        model_type=m.model_type,
                        features=m.features,
                        fetch_from=m.fetch_from,
                        model_properties=m.model_properties,
                        deprecated=m.deprecated,
                        provider=SimpleModelProviderEntity(self.provider),
                        status=status
                    )
                )

        # Snapshot the provider's declared configurate methods if no snapshot exists yet.
        if self.provider.provider not in original_provider_configurate_methods:
            original_provider_configurate_methods[self.provider.provider] = []
            for configurate_method in provider_instance.get_provider_schema().configurate_methods:
                original_provider_configurate_methods[self.provider.provider].append(configurate_method)

        should_use_custom_model = False
        if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
            should_use_custom_model = True

        for quota_configuration in self.system_configuration.quota_configurations:
            # only the quota configuration matching the current quota type is applied
            if self.system_configuration.current_quota_type != quota_configuration.quota_type:
                continue

            restrict_models = quota_configuration.restrict_models
            # NOTE(review): with no restrict_models the loop exits here, before the
            # QUOTA_EXCEEDED marking below, so statuses are left unchanged even when
            # quota_configuration.is_valid is False — confirm this is intended.
            if len(restrict_models) == 0:
                break

            if should_use_custom_model:
                # NOTE(review): this repeats the condition that set should_use_custom_model
                if original_provider_configurate_methods[self.provider.provider] == [
                        ConfigurateMethod.CUSTOMIZABLE_MODEL]:
                    # only customizable model
                    for restrict_model in restrict_models:
                        copy_credentials = self.system_configuration.credentials.copy()
                        if restrict_model.base_model_name:
                            copy_credentials['base_model_name'] = restrict_model.base_model_name

                        # failures building the schema are logged and the model is skipped
                        try:
                            custom_model_schema = (
                                provider_instance.get_model_instance(restrict_model.model_type)
                                .get_customizable_model_schema_from_credentials(
                                    restrict_model.model,
                                    copy_credentials
                                )
                            )
                        except Exception as ex:
                            logger.warning(f'get custom model schema failed, {ex}')
                            continue

                        if not custom_model_schema:
                            continue

                        if custom_model_schema.model_type not in model_types:
                            continue

                        status = ModelStatus.ACTIVE
                        if (custom_model_schema.model_type in model_setting_map
                                and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
                            model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
                            if model_setting.enabled is False:
                                status = ModelStatus.DISABLED

                        # restricted customizable models are reported as predefined
                        provider_models.append(
                            ModelWithProviderEntity(
                                model=custom_model_schema.model,
                                label=custom_model_schema.label,
                                model_type=custom_model_schema.model_type,
                                features=custom_model_schema.features,
                                fetch_from=FetchFrom.PREDEFINED_MODEL,
                                model_properties=custom_model_schema.model_properties,
                                deprecated=custom_model_schema.deprecated,
                                provider=SimpleModelProviderEntity(self.provider),
                                status=status
                            )
                        )

            # if llm name not in restricted llm list, remove it
            # (it is kept in the list but marked NO_PERMISSION; any other model,
            # including non-LLMs, is marked QUOTA_EXCEEDED when the quota is invalid)
            restrict_model_names = [rm.model for rm in restrict_models]
            for m in provider_models:
                if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
                    m.status = ModelStatus.NO_PERMISSION
                elif not quota_configuration.is_valid:
                    m.status = ModelStatus.QUOTA_EXCEEDED

        return provider_models
  def _get_custom_provider_models(self,
                                  model_types: list[ModelType],
                                  provider_instance: ModelProvider,
                                  model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
          -> list[ModelWithProviderEntity]:
      """
      Get custom provider models.

      Builds the model list for the workspace's custom-credential mode from two sources:

      1. the provider's predefined models (``provider_instance.models``) for every requested
         model type contained in ``self.provider.supported_model_types``; their base status is
         ``ACTIVE`` when the custom provider record carries credentials, else ``NO_CONFIGURE``;
      2. the workspace's manually configured custom models (``self.custom_configuration.models``)
         of a requested type, resolved through their own credentials; their base status is
         ``ACTIVE``. A custom model whose schema lookup raises, or yields nothing, is skipped
         (a raised error is logged as a warning).

      For both sources, a matching entry in ``model_setting_map`` downgrades the status to
      ``DISABLED`` when its ``enabled`` flag is ``False``, and turns ``load_balancing_enabled``
      on when it holds more than one load-balancing config.

      :param model_types: model types to include
      :param provider_instance: provider instance used to list predefined models and to
          resolve custom model schemas
      :param model_setting_map: model settings keyed by model type, then by model name
      :return: predefined models first, followed by the resolved custom models
      """
      def _settings_outcome(schema, initial_status):
          # Overlay the per-model setting (if any) on the initial status and report
          # whether load balancing is switched on for this model.
          if schema.model_type not in model_setting_map:
              return initial_status, False
          settings_for_type = model_setting_map[schema.model_type]
          if schema.model not in settings_for_type:
              return initial_status, False
          setting = settings_for_type[schema.model]
          # `is False` on purpose: only an explicit False disables the model.
          outcome = ModelStatus.DISABLED if setting.enabled is False else initial_status
          return outcome, len(setting.load_balancing_configs) > 1

      def _as_entity(schema, status, load_balancing_enabled):
          # Predefined models and custom schemas expose the same attributes, so one
          # builder serves both sources.
          return ModelWithProviderEntity(
              model=schema.model,
              label=schema.label,
              model_type=schema.model_type,
              features=schema.features,
              fetch_from=schema.fetch_from,
              model_properties=schema.model_properties,
              deprecated=schema.deprecated,
              provider=SimpleModelProviderEntity(self.provider),
              status=status,
              load_balancing_enabled=load_balancing_enabled
          )

      custom_provider = self.custom_configuration.provider
      credentials = custom_provider.credentials if custom_provider else None
      # Same for every predefined model, so decide it once.
      predefined_status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE

      collected = []

      # 1. Predefined models for each requested, supported model type.
      for model_type in model_types:
          if model_type not in self.provider.supported_model_types:
              continue
          for predefined in provider_instance.models(model_type):
              status, lb_enabled = _settings_outcome(predefined, predefined_status)
              collected.append(_as_entity(predefined, status, lb_enabled))

      # 2. Manually configured custom models.
      for configured in self.custom_configuration.models:
          if configured.model_type not in model_types:
              continue
          try:
              schema = (
                  provider_instance.get_model_instance(configured.model_type)
                  .get_customizable_model_schema_from_credentials(
                      configured.model,
                      configured.credentials
                  )
              )
          except Exception as ex:
              # Best effort: an unresolvable custom model is skipped rather than
              # failing the whole listing.
              logger.warning(f'get custom model schema failed, {ex}')
              continue
          if not schema:
              continue
          status, lb_enabled = _settings_outcome(schema, ModelStatus.ACTIVE)
          collected.append(_as_entity(schema, status, lb_enabled))

      return collected
  class ProviderConfigurations(BaseModel):
      """
      Tenant-scoped, dict-like collection of ProviderConfiguration objects keyed by
      provider name.
      """
      tenant_id: str
      configurations: dict[str, ProviderConfiguration] = {}

      def __init__(self, tenant_id: str):
          # Only the tenant id is accepted here; `configurations` starts from its
          # field default and can be populated through item assignment.
          super().__init__(tenant_id=tenant_id)

      def get_models(self,
                     provider: Optional[str] = None,
                     model_type: Optional[ModelType] = None,
                     only_active: bool = False) \
              -> list[ModelWithProviderEntity]:
          """
          Get available models.

          How the effective ("real") mode of a provider is chosen:

          * Preferred provider type ``system``:
            use the current **system mode** if the provider supports it. If no system
            mode is available (no quota), it is considered to be the **custom credential
            mode**. If no model is configured in custom mode either, it is treated as
            ``no_configure``.
            Priority: system > custom > no_configure.
          * Preferred provider type ``custom``:
            if custom credentials are configured, it is treated as custom mode.
            Otherwise the current **system mode** is used if supported. If no system
            mode is available (no quota), it is treated as ``no_configure``.
            Priority: custom > system > no_configure.

          What each real mode yields:

          * ``system``: models fetched with the system credentials, with quota priority
            paid quotas > provider free quotas > system free quotas. Pre-defined models
            are included, except GPT-4, whose status is marked ``no_permission``.
          * ``custom``: models fetched with the workspace's custom credentials,
            including pre-defined models and manually appended custom models.
          * ``no_configure``: only the pre-defined models from ``model runtime``. Their
            status is ``no_configure`` if the preferred provider type is ``custom``,
            otherwise ``quota_exceeded``.

          Only models with status ``active`` are available.

          :param provider: when truthy, only the configuration with this provider name is queried
          :param model_type: model type, forwarded to each configuration
          :param only_active: forwarded to each configuration to keep only active models
          :return: the models of every matching configuration, in configuration order
          """
          # De Morgan of "skip if provider is set and the name differs".
          matching = (
              configuration
              for configuration in self.values()
              if not provider or configuration.provider.provider == provider
          )
          return [
              model
              for configuration in matching
              for model in configuration.get_provider_models(model_type, only_active)
          ]

      def to_list(self) -> list[ProviderConfiguration]:
          """
          Return the stored provider configurations as a new list.

          :return: list of ProviderConfiguration, in insertion order
          """
          return [*self.values()]

      def __getitem__(self, key):
          # Look up a configuration by provider name; a missing key raises KeyError.
          return self.configurations[key]

      def __setitem__(self, key, value):
          # Register or replace the configuration stored under a provider name.
          self.configurations[key] = value

      def __iter__(self):
          # Iterates provider names (the dict keys), like iterating a plain dict.
          return iter(self.configurations.keys())

      def values(self) -> Iterator[ProviderConfiguration]:
          # NOTE(review): this returns the dict's values *view* (re-iterable, supports
          # len()), not a one-shot iterator as the annotation suggests.
          return self.configurations.values()

      def get(self, key, default=None):
          # dict.get semantics: `default` is returned when the provider name is absent.
          return self.configurations.get(key, default)
  class ProviderModelBundle(BaseModel):
      """
      Groups a provider configuration together with the provider instance and the
      model-type instance.
      """
      configuration: ProviderConfiguration
      provider_instance: ModelProvider
      model_type_instance: AIModel

      class Config:
          """Pydantic settings for this model."""
          # Lets pydantic accept field types it has no built-in validator for
          # (presumably ModelProvider / AIModel; confirm they are not pydantic models).
          arbitrary_types_allowed = True