commands.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554
  1. import datetime
  2. import json
  3. import math
  4. import random
  5. import string
  6. import time
  7. import click
  8. from tqdm import tqdm
  9. from flask import current_app
  10. from langchain.embeddings import OpenAIEmbeddings
  11. from werkzeug.exceptions import NotFound
  12. from core.embedding.cached_embedding import CacheEmbedding
  13. from core.index.index import IndexBuilder
  14. from core.model_providers.model_factory import ModelFactory
  15. from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
  16. from core.model_providers.models.entity.model_params import ModelType
  17. from core.model_providers.providers.hosted import hosted_model_providers
  18. from core.model_providers.providers.openai_provider import OpenAIProvider
  19. from libs.password import password_pattern, valid_password, hash_password
  20. from libs.helper import email as email_validate
  21. from extensions.ext_database import db
  22. from libs.rsa import generate_key_pair
  23. from models.account import InvitationCode, Tenant, TenantAccountJoin
  24. from models.dataset import Dataset, DatasetQuery, Document
  25. from models.model import Account, AppModelConfig, App
  26. import secrets
  27. import base64
  28. from models.provider import Provider, ProviderType, ProviderQuotaType, ProviderModel
  29. @click.command('reset-password', help='Reset the account password.')
  30. @click.option('--email', prompt=True, help='The email address of the account whose password you need to reset')
  31. @click.option('--new-password', prompt=True, help='the new password.')
  32. @click.option('--password-confirm', prompt=True, help='the new password confirm.')
  33. def reset_password(email, new_password, password_confirm):
  34. if str(new_password).strip() != str(password_confirm).strip():
  35. click.echo(click.style('sorry. The two passwords do not match.', fg='red'))
  36. return
  37. account = db.session.query(Account). \
  38. filter(Account.email == email). \
  39. one_or_none()
  40. if not account:
  41. click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
  42. return
  43. try:
  44. valid_password(new_password)
  45. except:
  46. click.echo(
  47. click.style('sorry. The passwords must match {} '.format(password_pattern), fg='red'))
  48. return
  49. # generate password salt
  50. salt = secrets.token_bytes(16)
  51. base64_salt = base64.b64encode(salt).decode()
  52. # encrypt password with salt
  53. password_hashed = hash_password(new_password, salt)
  54. base64_password_hashed = base64.b64encode(password_hashed).decode()
  55. account.password = base64_password_hashed
  56. account.password_salt = base64_salt
  57. db.session.commit()
  58. click.echo(click.style('Congratulations!, password has been reset.', fg='green'))
  59. @click.command('reset-email', help='Reset the account email.')
  60. @click.option('--email', prompt=True, help='The old email address of the account whose email you need to reset')
  61. @click.option('--new-email', prompt=True, help='the new email.')
  62. @click.option('--email-confirm', prompt=True, help='the new email confirm.')
  63. def reset_email(email, new_email, email_confirm):
  64. if str(new_email).strip() != str(email_confirm).strip():
  65. click.echo(click.style('Sorry, new email and confirm email do not match.', fg='red'))
  66. return
  67. account = db.session.query(Account). \
  68. filter(Account.email == email). \
  69. one_or_none()
  70. if not account:
  71. click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
  72. return
  73. try:
  74. email_validate(new_email)
  75. except:
  76. click.echo(
  77. click.style('sorry. {} is not a valid email. '.format(email), fg='red'))
  78. return
  79. account.email = new_email
  80. db.session.commit()
  81. click.echo(click.style('Congratulations!, email has been reset.', fg='green'))
  82. @click.command('reset-encrypt-key-pair', help='Reset the asymmetric key pair of workspace for encrypt LLM credentials. '
  83. 'After the reset, all LLM credentials will become invalid, '
  84. 'requiring re-entry.'
  85. 'Only support SELF_HOSTED mode.')
  86. @click.confirmation_option(prompt=click.style('Are you sure you want to reset encrypt key pair?'
  87. ' this operation cannot be rolled back!', fg='red'))
  88. def reset_encrypt_key_pair():
  89. if current_app.config['EDITION'] != 'SELF_HOSTED':
  90. click.echo(click.style('Sorry, only support SELF_HOSTED mode.', fg='red'))
  91. return
  92. tenant = db.session.query(Tenant).first()
  93. if not tenant:
  94. click.echo(click.style('Sorry, no workspace found. Please enter /install to initialize.', fg='red'))
  95. return
  96. tenant.encrypt_public_key = generate_key_pair(tenant.id)
  97. db.session.query(Provider).filter(Provider.provider_type == 'custom').delete()
  98. db.session.query(ProviderModel).delete()
  99. db.session.commit()
  100. click.echo(click.style('Congratulations! '
  101. 'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))
  102. @click.command('generate-invitation-codes', help='Generate invitation codes.')
  103. @click.option('--batch', help='The batch of invitation codes.')
  104. @click.option('--count', prompt=True, help='Invitation codes count.')
  105. def generate_invitation_codes(batch, count):
  106. if not batch:
  107. now = datetime.datetime.now()
  108. batch = now.strftime('%Y%m%d%H%M%S')
  109. if not count or int(count) <= 0:
  110. click.echo(click.style('sorry. the count must be greater than 0.', fg='red'))
  111. return
  112. count = int(count)
  113. click.echo('Start generate {} invitation codes for batch {}.'.format(count, batch))
  114. codes = ''
  115. for i in range(count):
  116. code = generate_invitation_code()
  117. invitation_code = InvitationCode(
  118. code=code,
  119. batch=batch
  120. )
  121. db.session.add(invitation_code)
  122. click.echo(code)
  123. codes += code + "\n"
  124. db.session.commit()
  125. filename = 'storage/invitation-codes-{}.txt'.format(batch)
  126. with open(filename, 'w') as f:
  127. f.write(codes)
  128. click.echo(click.style(
  129. 'Congratulations! Generated {} invitation codes for batch {} and saved to the file \'{}\''.format(count, batch,
  130. filename),
  131. fg='green'))
  132. def generate_invitation_code():
  133. code = generate_upper_string()
  134. while db.session.query(InvitationCode).filter(InvitationCode.code == code).count() > 0:
  135. code = generate_upper_string()
  136. return code
  137. def generate_upper_string():
  138. letters_digits = string.ascii_uppercase + string.digits
  139. result = ""
  140. for i in range(8):
  141. result += random.choice(letters_digits)
  142. return result
  143. @click.command('recreate-all-dataset-indexes', help='Recreate all dataset indexes.')
  144. def recreate_all_dataset_indexes():
  145. click.echo(click.style('Start recreate all dataset indexes.', fg='green'))
  146. recreate_count = 0
  147. page = 1
  148. while True:
  149. try:
  150. datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
  151. .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
  152. except NotFound:
  153. break
  154. page += 1
  155. for dataset in datasets:
  156. try:
  157. click.echo('Recreating dataset index: {}'.format(dataset.id))
  158. index = IndexBuilder.get_index(dataset, 'high_quality')
  159. if index and index._is_origin():
  160. index.recreate_dataset(dataset)
  161. recreate_count += 1
  162. else:
  163. click.echo('passed.')
  164. except Exception as e:
  165. click.echo(
  166. click.style('Recreate dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
  167. continue
  168. click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))
  169. @click.command('clean-unused-dataset-indexes', help='Clean unused dataset indexes.')
  170. def clean_unused_dataset_indexes():
  171. click.echo(click.style('Start clean unused dataset indexes.', fg='green'))
  172. clean_days = int(current_app.config.get('CLEAN_DAY_SETTING'))
  173. start_at = time.perf_counter()
  174. thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days)
  175. page = 1
  176. while True:
  177. try:
  178. datasets = db.session.query(Dataset).filter(Dataset.created_at < thirty_days_ago) \
  179. .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
  180. except NotFound:
  181. break
  182. page += 1
  183. for dataset in datasets:
  184. dataset_query = db.session.query(DatasetQuery).filter(
  185. DatasetQuery.created_at > thirty_days_ago,
  186. DatasetQuery.dataset_id == dataset.id
  187. ).all()
  188. if not dataset_query or len(dataset_query) == 0:
  189. documents = db.session.query(Document).filter(
  190. Document.dataset_id == dataset.id,
  191. Document.indexing_status == 'completed',
  192. Document.enabled == True,
  193. Document.archived == False,
  194. Document.updated_at > thirty_days_ago
  195. ).all()
  196. if not documents or len(documents) == 0:
  197. try:
  198. # remove index
  199. vector_index = IndexBuilder.get_index(dataset, 'high_quality')
  200. kw_index = IndexBuilder.get_index(dataset, 'economy')
  201. # delete from vector index
  202. if vector_index:
  203. vector_index.delete()
  204. kw_index.delete()
  205. # update document
  206. update_params = {
  207. Document.enabled: False
  208. }
  209. Document.query.filter_by(dataset_id=dataset.id).update(update_params)
  210. db.session.commit()
  211. click.echo(click.style('Cleaned unused dataset {} from db success!'.format(dataset.id),
  212. fg='green'))
  213. except Exception as e:
  214. click.echo(
  215. click.style('clean dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
  216. fg='red'))
  217. end_at = time.perf_counter()
  218. click.echo(click.style('Cleaned unused dataset from db success latency: {}'.format(end_at - start_at), fg='green'))
  219. @click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
  220. def sync_anthropic_hosted_providers():
  221. if not hosted_model_providers.anthropic:
  222. click.echo(click.style('Anthropic hosted provider is not configured.', fg='red'))
  223. return
  224. click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
  225. count = 0
  226. new_quota_limit = hosted_model_providers.anthropic.quota_limit
  227. page = 1
  228. while True:
  229. try:
  230. providers = db.session.query(Provider).filter(
  231. Provider.provider_name == 'anthropic',
  232. Provider.provider_type == ProviderType.SYSTEM.value,
  233. Provider.quota_type == ProviderQuotaType.TRIAL.value,
  234. Provider.quota_limit != new_quota_limit
  235. ).order_by(Provider.created_at.desc()).paginate(page=page, per_page=100)
  236. except NotFound:
  237. break
  238. page += 1
  239. for provider in providers:
  240. try:
  241. click.echo('Syncing tenant anthropic hosted provider: {}, origin: limit {}, used {}'
  242. .format(provider.tenant_id, provider.quota_limit, provider.quota_used))
  243. original_quota_limit = provider.quota_limit
  244. division = math.ceil(new_quota_limit / 1000)
  245. provider.quota_limit = new_quota_limit if original_quota_limit == 1000 \
  246. else original_quota_limit * division
  247. provider.quota_used = division * provider.quota_used
  248. db.session.commit()
  249. count += 1
  250. except Exception as e:
  251. click.echo(click.style(
  252. 'Sync tenant anthropic hosted provider error: {} {}'.format(e.__class__.__name__, str(e)),
  253. fg='red'))
  254. continue
  255. click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))
  256. @click.command('create-qdrant-indexes', help='Create qdrant indexes.')
  257. def create_qdrant_indexes():
  258. click.echo(click.style('Start create qdrant indexes.', fg='green'))
  259. create_count = 0
  260. page = 1
  261. while True:
  262. try:
  263. datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
  264. .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
  265. except NotFound:
  266. break
  267. page += 1
  268. for dataset in datasets:
  269. if dataset.index_struct_dict:
  270. if dataset.index_struct_dict['type'] != 'qdrant':
  271. try:
  272. click.echo('Create dataset qdrant index: {}'.format(dataset.id))
  273. try:
  274. embedding_model = ModelFactory.get_embedding_model(
  275. tenant_id=dataset.tenant_id,
  276. model_provider_name=dataset.embedding_model_provider,
  277. model_name=dataset.embedding_model
  278. )
  279. except Exception:
  280. try:
  281. embedding_model = ModelFactory.get_embedding_model(
  282. tenant_id=dataset.tenant_id
  283. )
  284. dataset.embedding_model = embedding_model.name
  285. dataset.embedding_model_provider = embedding_model.model_provider.provider_name
  286. except Exception:
  287. provider = Provider(
  288. id='provider_id',
  289. tenant_id=dataset.tenant_id,
  290. provider_name='openai',
  291. provider_type=ProviderType.SYSTEM.value,
  292. encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
  293. is_valid=True,
  294. )
  295. model_provider = OpenAIProvider(provider=provider)
  296. embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
  297. embeddings = CacheEmbedding(embedding_model)
  298. from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
  299. index = QdrantVectorIndex(
  300. dataset=dataset,
  301. config=QdrantConfig(
  302. endpoint=current_app.config.get('QDRANT_URL'),
  303. api_key=current_app.config.get('QDRANT_API_KEY'),
  304. root_path=current_app.root_path
  305. ),
  306. embeddings=embeddings
  307. )
  308. if index:
  309. index.create_qdrant_dataset(dataset)
  310. index_struct = {
  311. "type": 'qdrant',
  312. "vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
  313. }
  314. dataset.index_struct = json.dumps(index_struct)
  315. db.session.commit()
  316. create_count += 1
  317. else:
  318. click.echo('passed.')
  319. except Exception as e:
  320. click.echo(
  321. click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
  322. continue
  323. click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
  324. @click.command('update-qdrant-indexes', help='Update qdrant indexes.')
  325. def update_qdrant_indexes():
  326. click.echo(click.style('Start Update qdrant indexes.', fg='green'))
  327. create_count = 0
  328. page = 1
  329. while True:
  330. try:
  331. datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
  332. .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
  333. except NotFound:
  334. break
  335. page += 1
  336. for dataset in datasets:
  337. if dataset.index_struct_dict:
  338. if dataset.index_struct_dict['type'] != 'qdrant':
  339. try:
  340. click.echo('Update dataset qdrant index: {}'.format(dataset.id))
  341. try:
  342. embedding_model = ModelFactory.get_embedding_model(
  343. tenant_id=dataset.tenant_id,
  344. model_provider_name=dataset.embedding_model_provider,
  345. model_name=dataset.embedding_model
  346. )
  347. except Exception:
  348. provider = Provider(
  349. id='provider_id',
  350. tenant_id=dataset.tenant_id,
  351. provider_name='openai',
  352. provider_type=ProviderType.CUSTOM.value,
  353. encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
  354. is_valid=True,
  355. )
  356. model_provider = OpenAIProvider(provider=provider)
  357. embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
  358. embeddings = CacheEmbedding(embedding_model)
  359. from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
  360. index = QdrantVectorIndex(
  361. dataset=dataset,
  362. config=QdrantConfig(
  363. endpoint=current_app.config.get('QDRANT_URL'),
  364. api_key=current_app.config.get('QDRANT_API_KEY'),
  365. root_path=current_app.root_path
  366. ),
  367. embeddings=embeddings
  368. )
  369. if index:
  370. index.update_qdrant_dataset(dataset)
  371. create_count += 1
  372. else:
  373. click.echo('passed.')
  374. except Exception as e:
  375. click.echo(
  376. click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
  377. continue
  378. click.echo(click.style('Congratulations! Update {} dataset indexes.'.format(create_count), fg='green'))
  379. @click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
  380. @click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
  381. def update_app_model_configs(batch_size):
  382. pre_prompt_template = '{{default_input}}'
  383. user_input_form_template = {
  384. "en-US": [
  385. {
  386. "paragraph": {
  387. "label": "Query",
  388. "variable": "default_input",
  389. "required": False,
  390. "default": ""
  391. }
  392. }
  393. ],
  394. "zh-Hans": [
  395. {
  396. "paragraph": {
  397. "label": "查询内容",
  398. "variable": "default_input",
  399. "required": False,
  400. "default": ""
  401. }
  402. }
  403. ]
  404. }
  405. click.secho("Start migrate old data that the text generator can support paragraph variable.", fg='green')
  406. total_records = db.session.query(AppModelConfig) \
  407. .join(App, App.app_model_config_id == AppModelConfig.id) \
  408. .filter(App.mode == 'completion') \
  409. .count()
  410. if total_records == 0:
  411. click.secho("No data to migrate.", fg='green')
  412. return
  413. num_batches = (total_records + batch_size - 1) // batch_size
  414. with tqdm(total=total_records, desc="Migrating Data") as pbar:
  415. for i in range(num_batches):
  416. offset = i * batch_size
  417. limit = min(batch_size, total_records - offset)
  418. click.secho(f"Fetching batch {i+1}/{num_batches} from source database...", fg='green')
  419. data_batch = db.session.query(AppModelConfig) \
  420. .join(App, App.app_model_config_id == AppModelConfig.id) \
  421. .filter(App.mode == 'completion') \
  422. .order_by(App.created_at) \
  423. .offset(offset).limit(limit).all()
  424. if not data_batch:
  425. click.secho("No more data to migrate.", fg='green')
  426. break
  427. try:
  428. click.secho(f"Migrating {len(data_batch)} records...", fg='green')
  429. for data in data_batch:
  430. # click.secho(f"Migrating data {data.id}, pre_prompt: {data.pre_prompt}, user_input_form: {data.user_input_form}", fg='green')
  431. if data.pre_prompt is None:
  432. data.pre_prompt = pre_prompt_template
  433. else:
  434. if pre_prompt_template in data.pre_prompt:
  435. continue
  436. data.pre_prompt += pre_prompt_template
  437. app_data = db.session.query(App) \
  438. .filter(App.id == data.app_id) \
  439. .one()
  440. account_data = db.session.query(Account) \
  441. .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) \
  442. .filter(TenantAccountJoin.role == 'owner') \
  443. .filter(TenantAccountJoin.tenant_id == app_data.tenant_id) \
  444. .one_or_none()
  445. if not account_data:
  446. continue
  447. if data.user_input_form is None or data.user_input_form == 'null':
  448. data.user_input_form = json.dumps(user_input_form_template[account_data.interface_language])
  449. else:
  450. raw_json_data = json.loads(data.user_input_form)
  451. raw_json_data.append(user_input_form_template[account_data.interface_language][0])
  452. data.user_input_form = json.dumps(raw_json_data)
  453. # click.secho(f"Updated data {data.id}, pre_prompt: {data.pre_prompt}, user_input_form: {data.user_input_form}", fg='green')
  454. db.session.commit()
  455. except Exception as e:
  456. click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}", fg='red')
  457. continue
  458. click.secho(f"Successfully migrated batch {i+1}/{num_batches}.", fg='green')
  459. pbar.update(len(data_batch))
  460. def register_commands(app):
  461. app.cli.add_command(reset_password)
  462. app.cli.add_command(reset_email)
  463. app.cli.add_command(generate_invitation_codes)
  464. app.cli.add_command(reset_encrypt_key_pair)
  465. app.cli.add_command(recreate_all_dataset_indexes)
  466. app.cli.add_command(sync_anthropic_hosted_providers)
  467. app.cli.add_command(clean_unused_dataset_indexes)
  468. app.cli.add_command(create_qdrant_indexes)
  469. app.cli.add_command(update_qdrant_indexes)
  470. app.cli.add_command(update_app_model_configs)