ops_trace_manager.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819
  1. import json
  2. import logging
  3. import os
  4. import queue
  5. import threading
  6. import time
  7. from datetime import timedelta
  8. from typing import Any, Optional, Union
  9. from uuid import UUID, uuid4
  10. from flask import current_app
  11. from sqlalchemy import select
  12. from sqlalchemy.orm import Session
  13. from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
  14. from core.ops.entities.config_entity import (
  15. OPS_FILE_PATH,
  16. LangfuseConfig,
  17. LangSmithConfig,
  18. OpikConfig,
  19. TracingProviderEnum,
  20. )
  21. from core.ops.entities.trace_entity import (
  22. DatasetRetrievalTraceInfo,
  23. GenerateNameTraceInfo,
  24. MessageTraceInfo,
  25. ModerationTraceInfo,
  26. SuggestedQuestionTraceInfo,
  27. TaskData,
  28. ToolTraceInfo,
  29. TraceTaskName,
  30. WorkflowTraceInfo,
  31. )
  32. from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
  33. from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
  34. from core.ops.utils import get_message_data
  35. from extensions.ext_database import db
  36. from extensions.ext_storage import storage
  37. from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
  38. from models.workflow import WorkflowAppLog, WorkflowRun
  39. from tasks.ops_trace_task import process_trace_tasks
  40. def build_opik_trace_instance(config: OpikConfig):
  41. from core.ops.opik_trace.opik_trace import OpikDataTrace
  42. return OpikDataTrace(config)
  43. provider_config_map: dict[str, dict[str, Any]] = {
  44. TracingProviderEnum.LANGFUSE.value: {
  45. "config_class": LangfuseConfig,
  46. "secret_keys": ["public_key", "secret_key"],
  47. "other_keys": ["host", "project_key"],
  48. "trace_instance": LangFuseDataTrace,
  49. },
  50. TracingProviderEnum.LANGSMITH.value: {
  51. "config_class": LangSmithConfig,
  52. "secret_keys": ["api_key"],
  53. "other_keys": ["project", "endpoint"],
  54. "trace_instance": LangSmithDataTrace,
  55. },
  56. TracingProviderEnum.OPIK.value: {
  57. "config_class": OpikConfig,
  58. "secret_keys": ["api_key"],
  59. "other_keys": ["project", "url", "workspace"],
  60. "trace_instance": lambda config: build_opik_trace_instance(config),
  61. },
  62. }
  63. class OpsTraceManager:
  64. @classmethod
  65. def encrypt_tracing_config(
  66. cls, tenant_id: str, tracing_provider: str, tracing_config: dict, current_trace_config=None
  67. ):
  68. """
  69. Encrypt tracing config.
  70. :param tenant_id: tenant id
  71. :param tracing_provider: tracing provider
  72. :param tracing_config: tracing config dictionary to be encrypted
  73. :param current_trace_config: current tracing configuration for keeping existing values
  74. :return: encrypted tracing configuration
  75. """
  76. # Get the configuration class and the keys that require encryption
  77. config_class, secret_keys, other_keys = (
  78. provider_config_map[tracing_provider]["config_class"],
  79. provider_config_map[tracing_provider]["secret_keys"],
  80. provider_config_map[tracing_provider]["other_keys"],
  81. )
  82. new_config = {}
  83. # Encrypt necessary keys
  84. for key in secret_keys:
  85. if key in tracing_config:
  86. if "*" in tracing_config[key]:
  87. # If the key contains '*', retain the original value from the current config
  88. new_config[key] = current_trace_config.get(key, tracing_config[key])
  89. else:
  90. # Otherwise, encrypt the key
  91. new_config[key] = encrypt_token(tenant_id, tracing_config[key])
  92. for key in other_keys:
  93. new_config[key] = tracing_config.get(key, "")
  94. # Create a new instance of the config class with the new configuration
  95. encrypted_config = config_class(**new_config)
  96. return encrypted_config.model_dump()
  97. @classmethod
  98. def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict):
  99. """
  100. Decrypt tracing config
  101. :param tenant_id: tenant id
  102. :param tracing_provider: tracing provider
  103. :param tracing_config: tracing config
  104. :return:
  105. """
  106. config_class, secret_keys, other_keys = (
  107. provider_config_map[tracing_provider]["config_class"],
  108. provider_config_map[tracing_provider]["secret_keys"],
  109. provider_config_map[tracing_provider]["other_keys"],
  110. )
  111. new_config = {}
  112. for key in secret_keys:
  113. if key in tracing_config:
  114. new_config[key] = decrypt_token(tenant_id, tracing_config[key])
  115. for key in other_keys:
  116. new_config[key] = tracing_config.get(key, "")
  117. return config_class(**new_config).model_dump()
  118. @classmethod
  119. def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
  120. """
  121. Decrypt tracing config
  122. :param tracing_provider: tracing provider
  123. :param decrypt_tracing_config: tracing config
  124. :return:
  125. """
  126. config_class, secret_keys, other_keys = (
  127. provider_config_map[tracing_provider]["config_class"],
  128. provider_config_map[tracing_provider]["secret_keys"],
  129. provider_config_map[tracing_provider]["other_keys"],
  130. )
  131. new_config = {}
  132. for key in secret_keys:
  133. if key in decrypt_tracing_config:
  134. new_config[key] = obfuscated_token(decrypt_tracing_config[key])
  135. for key in other_keys:
  136. new_config[key] = decrypt_tracing_config.get(key, "")
  137. return config_class(**new_config).model_dump()
  138. @classmethod
  139. def get_decrypted_tracing_config(cls, app_id: str, tracing_provider: str):
  140. """
  141. Get decrypted tracing config
  142. :param app_id: app id
  143. :param tracing_provider: tracing provider
  144. :return:
  145. """
  146. trace_config_data: Optional[TraceAppConfig] = (
  147. db.session.query(TraceAppConfig)
  148. .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
  149. .first()
  150. )
  151. if not trace_config_data:
  152. return None
  153. # decrypt_token
  154. app = db.session.query(App).filter(App.id == app_id).first()
  155. if not app:
  156. raise ValueError("App not found")
  157. tenant_id = app.tenant_id
  158. decrypt_tracing_config = cls.decrypt_tracing_config(
  159. tenant_id, tracing_provider, trace_config_data.tracing_config
  160. )
  161. return decrypt_tracing_config
  162. @classmethod
  163. def get_ops_trace_instance(
  164. cls,
  165. app_id: Optional[Union[UUID, str]] = None,
  166. ):
  167. """
  168. Get ops trace through model config
  169. :param app_id: app_id
  170. :return:
  171. """
  172. if isinstance(app_id, UUID):
  173. app_id = str(app_id)
  174. if app_id is None:
  175. return None
  176. app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
  177. if app is None:
  178. return None
  179. app_ops_trace_config = json.loads(app.tracing) if app.tracing else None
  180. if app_ops_trace_config is None:
  181. return None
  182. tracing_provider = app_ops_trace_config.get("tracing_provider")
  183. if tracing_provider is None or tracing_provider not in provider_config_map:
  184. return None
  185. # decrypt_token
  186. decrypt_trace_config = cls.get_decrypted_tracing_config(app_id, tracing_provider)
  187. if app_ops_trace_config.get("enabled"):
  188. trace_instance, config_class = (
  189. provider_config_map[tracing_provider]["trace_instance"],
  190. provider_config_map[tracing_provider]["config_class"],
  191. )
  192. if not decrypt_trace_config:
  193. return None
  194. tracing_instance = trace_instance(config_class(**decrypt_trace_config))
  195. return tracing_instance
  196. return None
  197. @classmethod
  198. def get_app_config_through_message_id(cls, message_id: str):
  199. app_model_config = None
  200. message_data = db.session.query(Message).filter(Message.id == message_id).first()
  201. if not message_data:
  202. return None
  203. conversation_id = message_data.conversation_id
  204. conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
  205. if not conversation_data:
  206. return None
  207. if conversation_data.app_model_config_id:
  208. app_model_config = (
  209. db.session.query(AppModelConfig)
  210. .filter(AppModelConfig.id == conversation_data.app_model_config_id)
  211. .first()
  212. )
  213. elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
  214. app_model_config = conversation_data.override_model_configs
  215. return app_model_config
  216. @classmethod
  217. def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider: str):
  218. """
  219. Update app tracing config
  220. :param app_id: app id
  221. :param enabled: enabled
  222. :param tracing_provider: tracing provider
  223. :return:
  224. """
  225. # auth check
  226. if tracing_provider not in provider_config_map and tracing_provider is not None:
  227. raise ValueError(f"Invalid tracing provider: {tracing_provider}")
  228. app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
  229. if not app_config:
  230. raise ValueError("App not found")
  231. app_config.tracing = json.dumps(
  232. {
  233. "enabled": enabled,
  234. "tracing_provider": tracing_provider,
  235. }
  236. )
  237. db.session.commit()
  238. @classmethod
  239. def get_app_tracing_config(cls, app_id: str):
  240. """
  241. Get app tracing config
  242. :param app_id: app id
  243. :return:
  244. """
  245. app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
  246. if not app:
  247. raise ValueError("App not found")
  248. if not app.tracing:
  249. return {"enabled": False, "tracing_provider": None}
  250. app_trace_config = json.loads(app.tracing)
  251. return app_trace_config
  252. @staticmethod
  253. def check_trace_config_is_effective(tracing_config: dict, tracing_provider: str):
  254. """
  255. Check trace config is effective
  256. :param tracing_config: tracing config
  257. :param tracing_provider: tracing provider
  258. :return:
  259. """
  260. config_type, trace_instance = (
  261. provider_config_map[tracing_provider]["config_class"],
  262. provider_config_map[tracing_provider]["trace_instance"],
  263. )
  264. tracing_config = config_type(**tracing_config)
  265. return trace_instance(tracing_config).api_check()
  266. @staticmethod
  267. def get_trace_config_project_key(tracing_config: dict, tracing_provider: str):
  268. """
  269. get trace config is project key
  270. :param tracing_config: tracing config
  271. :param tracing_provider: tracing provider
  272. :return:
  273. """
  274. config_type, trace_instance = (
  275. provider_config_map[tracing_provider]["config_class"],
  276. provider_config_map[tracing_provider]["trace_instance"],
  277. )
  278. tracing_config = config_type(**tracing_config)
  279. return trace_instance(tracing_config).get_project_key()
  280. @staticmethod
  281. def get_trace_config_project_url(tracing_config: dict, tracing_provider: str):
  282. """
  283. get trace config is project key
  284. :param tracing_config: tracing config
  285. :param tracing_provider: tracing provider
  286. :return:
  287. """
  288. config_type, trace_instance = (
  289. provider_config_map[tracing_provider]["config_class"],
  290. provider_config_map[tracing_provider]["trace_instance"],
  291. )
  292. tracing_config = config_type(**tracing_config)
  293. return trace_instance(tracing_config).get_project_url()
  294. class TraceTask:
  295. def __init__(
  296. self,
  297. trace_type: Any,
  298. message_id: Optional[str] = None,
  299. workflow_run: Optional[WorkflowRun] = None,
  300. conversation_id: Optional[str] = None,
  301. user_id: Optional[str] = None,
  302. timer: Optional[Any] = None,
  303. **kwargs,
  304. ):
  305. self.trace_type = trace_type
  306. self.message_id = message_id
  307. self.workflow_run_id = workflow_run.id if workflow_run else None
  308. self.conversation_id = conversation_id
  309. self.user_id = user_id
  310. self.timer = timer
  311. self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
  312. self.app_id = None
  313. self.kwargs = kwargs
  314. def execute(self):
  315. return self.preprocess()
  316. def preprocess(self):
  317. preprocess_map = {
  318. TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs),
  319. TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(
  320. workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id
  321. ),
  322. TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id),
  323. TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
  324. message_id=self.message_id, timer=self.timer, **self.kwargs
  325. ),
  326. TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace(
  327. message_id=self.message_id, timer=self.timer, **self.kwargs
  328. ),
  329. TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace(
  330. message_id=self.message_id, timer=self.timer, **self.kwargs
  331. ),
  332. TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(
  333. message_id=self.message_id, timer=self.timer, **self.kwargs
  334. ),
  335. TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
  336. conversation_id=self.conversation_id, timer=self.timer, **self.kwargs
  337. ),
  338. }
  339. return preprocess_map.get(self.trace_type, lambda: None)()
  340. # process methods for different trace types
  341. def conversation_trace(self, **kwargs):
  342. return kwargs
  343. def workflow_trace(
  344. self,
  345. *,
  346. workflow_run_id: str | None,
  347. conversation_id: str | None,
  348. user_id: str | None,
  349. ):
  350. if not workflow_run_id:
  351. return {}
  352. with Session(db.engine) as session:
  353. workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
  354. workflow_run = session.scalars(workflow_run_stmt).first()
  355. if not workflow_run:
  356. raise ValueError("Workflow run not found")
  357. workflow_id = workflow_run.workflow_id
  358. tenant_id = workflow_run.tenant_id
  359. workflow_run_id = workflow_run.id
  360. workflow_run_elapsed_time = workflow_run.elapsed_time
  361. workflow_run_status = workflow_run.status
  362. workflow_run_inputs = workflow_run.inputs_dict
  363. workflow_run_outputs = workflow_run.outputs_dict
  364. workflow_run_version = workflow_run.version
  365. error = workflow_run.error or ""
  366. total_tokens = workflow_run.total_tokens
  367. file_list = workflow_run_inputs.get("sys.file") or []
  368. query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
  369. # get workflow_app_log_id
  370. workflow_app_log_data_stmt = select(WorkflowAppLog.id).where(
  371. WorkflowAppLog.tenant_id == tenant_id,
  372. WorkflowAppLog.app_id == workflow_run.app_id,
  373. WorkflowAppLog.workflow_run_id == workflow_run.id,
  374. )
  375. workflow_app_log_id = session.scalar(workflow_app_log_data_stmt)
  376. # get message_id
  377. message_id = None
  378. if conversation_id:
  379. message_data_stmt = select(Message.id).where(
  380. Message.conversation_id == conversation_id,
  381. Message.workflow_run_id == workflow_run_id,
  382. )
  383. message_id = session.scalar(message_data_stmt)
  384. metadata = {
  385. "workflow_id": workflow_id,
  386. "conversation_id": conversation_id,
  387. "workflow_run_id": workflow_run_id,
  388. "tenant_id": tenant_id,
  389. "elapsed_time": workflow_run_elapsed_time,
  390. "status": workflow_run_status,
  391. "version": workflow_run_version,
  392. "total_tokens": total_tokens,
  393. "file_list": file_list,
  394. "triggered_form": workflow_run.triggered_from,
  395. "user_id": user_id,
  396. }
  397. workflow_trace_info = WorkflowTraceInfo(
  398. workflow_data=workflow_run.to_dict(),
  399. conversation_id=conversation_id,
  400. workflow_id=workflow_id,
  401. tenant_id=tenant_id,
  402. workflow_run_id=workflow_run_id,
  403. workflow_run_elapsed_time=workflow_run_elapsed_time,
  404. workflow_run_status=workflow_run_status,
  405. workflow_run_inputs=workflow_run_inputs,
  406. workflow_run_outputs=workflow_run_outputs,
  407. workflow_run_version=workflow_run_version,
  408. error=error,
  409. total_tokens=total_tokens,
  410. file_list=file_list,
  411. query=query,
  412. metadata=metadata,
  413. workflow_app_log_id=workflow_app_log_id,
  414. message_id=message_id,
  415. start_time=workflow_run.created_at,
  416. end_time=workflow_run.finished_at,
  417. )
  418. return workflow_trace_info
  419. def message_trace(self, message_id: str | None):
  420. if not message_id:
  421. return {}
  422. message_data = get_message_data(message_id)
  423. if not message_data:
  424. return {}
  425. conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id)
  426. conversation_mode = db.session.scalars(conversation_mode_stmt).all()
  427. if not conversation_mode or len(conversation_mode) == 0:
  428. return {}
  429. conversation_mode = conversation_mode[0]
  430. created_at = message_data.created_at
  431. inputs = message_data.message
  432. # get message file data
  433. message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
  434. file_list = []
  435. if message_file_data and message_file_data.url is not None:
  436. file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
  437. file_list.append(file_url)
  438. metadata = {
  439. "conversation_id": message_data.conversation_id,
  440. "ls_provider": message_data.model_provider,
  441. "ls_model_name": message_data.model_id,
  442. "status": message_data.status,
  443. "from_end_user_id": message_data.from_end_user_id,
  444. "from_account_id": message_data.from_account_id,
  445. "agent_based": message_data.agent_based,
  446. "workflow_run_id": message_data.workflow_run_id,
  447. "from_source": message_data.from_source,
  448. "message_id": message_id,
  449. }
  450. message_tokens = message_data.message_tokens
  451. message_trace_info = MessageTraceInfo(
  452. message_id=message_id,
  453. message_data=message_data.to_dict(),
  454. conversation_model=conversation_mode,
  455. message_tokens=message_tokens,
  456. answer_tokens=message_data.answer_tokens,
  457. total_tokens=message_tokens + message_data.answer_tokens,
  458. error=message_data.error or "",
  459. inputs=inputs,
  460. outputs=message_data.answer,
  461. file_list=file_list,
  462. start_time=created_at,
  463. end_time=created_at + timedelta(seconds=message_data.provider_response_latency),
  464. metadata=metadata,
  465. message_file_data=message_file_data,
  466. conversation_mode=conversation_mode,
  467. )
  468. return message_trace_info
  469. def moderation_trace(self, message_id, timer, **kwargs):
  470. moderation_result = kwargs.get("moderation_result")
  471. if not moderation_result:
  472. return {}
  473. inputs = kwargs.get("inputs")
  474. message_data = get_message_data(message_id)
  475. if not message_data:
  476. return {}
  477. metadata = {
  478. "message_id": message_id,
  479. "action": moderation_result.action,
  480. "preset_response": moderation_result.preset_response,
  481. "query": moderation_result.query,
  482. }
  483. # get workflow_app_log_id
  484. workflow_app_log_id = None
  485. if message_data.workflow_run_id:
  486. workflow_app_log_data = (
  487. db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
  488. )
  489. workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
  490. moderation_trace_info = ModerationTraceInfo(
  491. message_id=workflow_app_log_id or message_id,
  492. inputs=inputs,
  493. message_data=message_data.to_dict(),
  494. flagged=moderation_result.flagged,
  495. action=moderation_result.action,
  496. preset_response=moderation_result.preset_response,
  497. query=moderation_result.query,
  498. start_time=timer.get("start"),
  499. end_time=timer.get("end"),
  500. metadata=metadata,
  501. )
  502. return moderation_trace_info
  503. def suggested_question_trace(self, message_id, timer, **kwargs):
  504. suggested_question = kwargs.get("suggested_question", [])
  505. message_data = get_message_data(message_id)
  506. if not message_data:
  507. return {}
  508. metadata = {
  509. "message_id": message_id,
  510. "ls_provider": message_data.model_provider,
  511. "ls_model_name": message_data.model_id,
  512. "status": message_data.status,
  513. "from_end_user_id": message_data.from_end_user_id,
  514. "from_account_id": message_data.from_account_id,
  515. "agent_based": message_data.agent_based,
  516. "workflow_run_id": message_data.workflow_run_id,
  517. "from_source": message_data.from_source,
  518. }
  519. # get workflow_app_log_id
  520. workflow_app_log_id = None
  521. if message_data.workflow_run_id:
  522. workflow_app_log_data = (
  523. db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
  524. )
  525. workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
  526. suggested_question_trace_info = SuggestedQuestionTraceInfo(
  527. message_id=workflow_app_log_id or message_id,
  528. message_data=message_data.to_dict(),
  529. inputs=message_data.message,
  530. outputs=message_data.answer,
  531. start_time=timer.get("start"),
  532. end_time=timer.get("end"),
  533. metadata=metadata,
  534. total_tokens=message_data.message_tokens + message_data.answer_tokens,
  535. status=message_data.status,
  536. error=message_data.error,
  537. from_account_id=message_data.from_account_id,
  538. agent_based=message_data.agent_based,
  539. from_source=message_data.from_source,
  540. model_provider=message_data.model_provider,
  541. model_id=message_data.model_id,
  542. suggested_question=suggested_question,
  543. level=message_data.status,
  544. status_message=message_data.error,
  545. )
  546. return suggested_question_trace_info
  547. def dataset_retrieval_trace(self, message_id, timer, **kwargs):
  548. documents = kwargs.get("documents")
  549. message_data = get_message_data(message_id)
  550. if not message_data:
  551. return {}
  552. metadata = {
  553. "message_id": message_id,
  554. "ls_provider": message_data.model_provider,
  555. "ls_model_name": message_data.model_id,
  556. "status": message_data.status,
  557. "from_end_user_id": message_data.from_end_user_id,
  558. "from_account_id": message_data.from_account_id,
  559. "agent_based": message_data.agent_based,
  560. "workflow_run_id": message_data.workflow_run_id,
  561. "from_source": message_data.from_source,
  562. }
  563. dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
  564. message_id=message_id,
  565. inputs=message_data.query or message_data.inputs,
  566. documents=[doc.model_dump() for doc in documents] if documents else [],
  567. start_time=timer.get("start"),
  568. end_time=timer.get("end"),
  569. metadata=metadata,
  570. message_data=message_data.to_dict(),
  571. )
  572. return dataset_retrieval_trace_info
  573. def tool_trace(self, message_id, timer, **kwargs):
  574. tool_name = kwargs.get("tool_name", "")
  575. tool_inputs = kwargs.get("tool_inputs", {})
  576. tool_outputs = kwargs.get("tool_outputs", {})
  577. message_data = get_message_data(message_id)
  578. if not message_data:
  579. return {}
  580. tool_config = {}
  581. time_cost = 0
  582. error = None
  583. tool_parameters = {}
  584. created_time = message_data.created_at
  585. end_time = message_data.updated_at
  586. agent_thoughts = message_data.agent_thoughts
  587. for agent_thought in agent_thoughts:
  588. if tool_name in agent_thought.tools:
  589. created_time = agent_thought.created_at
  590. tool_meta_data = agent_thought.tool_meta.get(tool_name, {})
  591. tool_config = tool_meta_data.get("tool_config", {})
  592. time_cost = tool_meta_data.get("time_cost", 0)
  593. end_time = created_time + timedelta(seconds=time_cost)
  594. error = tool_meta_data.get("error", "")
  595. tool_parameters = tool_meta_data.get("tool_parameters", {})
  596. metadata = {
  597. "message_id": message_id,
  598. "tool_name": tool_name,
  599. "tool_inputs": tool_inputs,
  600. "tool_outputs": tool_outputs,
  601. "tool_config": tool_config,
  602. "time_cost": time_cost,
  603. "error": error,
  604. "tool_parameters": tool_parameters,
  605. }
  606. file_url = ""
  607. message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
  608. if message_file_data:
  609. message_file_id = message_file_data.id if message_file_data else None
  610. type = message_file_data.type
  611. created_by_role = message_file_data.created_by_role
  612. created_user_id = message_file_data.created_by
  613. file_url = f"{self.file_base_url}/{message_file_data.url}"
  614. metadata.update(
  615. {
  616. "message_file_id": message_file_id,
  617. "created_by_role": created_by_role,
  618. "created_user_id": created_user_id,
  619. "type": type,
  620. }
  621. )
  622. tool_trace_info = ToolTraceInfo(
  623. message_id=message_id,
  624. message_data=message_data.to_dict(),
  625. tool_name=tool_name,
  626. start_time=timer.get("start") if timer else created_time,
  627. end_time=timer.get("end") if timer else end_time,
  628. tool_inputs=tool_inputs,
  629. tool_outputs=tool_outputs,
  630. metadata=metadata,
  631. message_file_data=message_file_data,
  632. error=error,
  633. inputs=message_data.message,
  634. outputs=message_data.answer,
  635. tool_config=tool_config,
  636. time_cost=time_cost,
  637. tool_parameters=tool_parameters,
  638. file_url=file_url,
  639. )
  640. return tool_trace_info
  641. def generate_name_trace(self, conversation_id, timer, **kwargs):
  642. generate_conversation_name = kwargs.get("generate_conversation_name")
  643. inputs = kwargs.get("inputs")
  644. tenant_id = kwargs.get("tenant_id")
  645. if not tenant_id:
  646. return {}
  647. start_time = timer.get("start")
  648. end_time = timer.get("end")
  649. metadata = {
  650. "conversation_id": conversation_id,
  651. "tenant_id": tenant_id,
  652. }
  653. generate_name_trace_info = GenerateNameTraceInfo(
  654. conversation_id=conversation_id,
  655. inputs=inputs,
  656. outputs=generate_conversation_name,
  657. start_time=start_time,
  658. end_time=end_time,
  659. metadata=metadata,
  660. tenant_id=tenant_id,
  661. )
  662. return generate_name_trace_info
  663. trace_manager_timer: Optional[threading.Timer] = None
  664. trace_manager_queue: queue.Queue = queue.Queue()
  665. trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5))
  666. trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100))
  667. class TraceQueueManager:
  668. def __init__(self, app_id=None, user_id=None):
  669. global trace_manager_timer
  670. self.app_id = app_id
  671. self.user_id = user_id
  672. self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
  673. self.flask_app = current_app._get_current_object() # type: ignore
  674. if trace_manager_timer is None:
  675. self.start_timer()
  676. def add_trace_task(self, trace_task: TraceTask):
  677. global trace_manager_timer, trace_manager_queue
  678. try:
  679. if self.trace_instance:
  680. trace_task.app_id = self.app_id
  681. trace_manager_queue.put(trace_task)
  682. except Exception as e:
  683. logging.exception(f"Error adding trace task, trace_type {trace_task.trace_type}")
  684. finally:
  685. self.start_timer()
  686. def collect_tasks(self):
  687. global trace_manager_queue
  688. tasks: list[TraceTask] = []
  689. while len(tasks) < trace_manager_batch_size and not trace_manager_queue.empty():
  690. task = trace_manager_queue.get_nowait()
  691. tasks.append(task)
  692. trace_manager_queue.task_done()
  693. return tasks
  694. def run(self):
  695. try:
  696. tasks = self.collect_tasks()
  697. if tasks:
  698. self.send_to_celery(tasks)
  699. except Exception as e:
  700. logging.exception("Error processing trace tasks")
  701. def start_timer(self):
  702. global trace_manager_timer
  703. if trace_manager_timer is None or not trace_manager_timer.is_alive():
  704. trace_manager_timer = threading.Timer(trace_manager_interval, self.run)
  705. trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}"
  706. trace_manager_timer.daemon = False
  707. trace_manager_timer.start()
  708. def send_to_celery(self, tasks: list[TraceTask]):
  709. with self.flask_app.app_context():
  710. for task in tasks:
  711. if task.app_id is None:
  712. continue
  713. file_id = uuid4().hex
  714. trace_info = task.execute()
  715. task_data = TaskData(
  716. app_id=task.app_id,
  717. trace_info_type=type(trace_info).__name__,
  718. trace_info=trace_info.model_dump() if trace_info else None,
  719. )
  720. file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json"
  721. storage.save(file_path, task_data.model_dump_json().encode("utf-8"))
  722. file_info = {
  723. "file_id": file_id,
  724. "app_id": task.app_id,
  725. }
  726. process_trace_tasks.delay(file_info)