import tempfile
from binascii import hexlify, unhexlify
from collections.abc import Generator

from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from core.plugin.entities.request import (
    RequestInvokeLLM,
    RequestInvokeModeration,
    RequestInvokeRerank,
    RequestInvokeSpeech2Text,
    RequestInvokeSummary,
    RequestInvokeTextEmbedding,
    RequestInvokeTTS,
)
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
from core.workflow.nodes.llm.node import LLMNode
from models.account import Tenant
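

# "Backwards invocation" lets a plugin call back into the host application's own
# model runtime instead of contacting model providers directly: each handler in
# the class below resolves the tenant's configured model instance for the
# requested model type and forwards the plugin's payload to it, deducting quota
# where usage is reported.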
class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
    @classmethod
    def invoke_llm(
        cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
    ) -> Generator[LLMResultChunk, None, None] | LLMResult:
        """
        Invoke an LLM on behalf of a plugin, deducting quota as usage is reported.
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model; default to streaming when the payload does not specify
        response = model_instance.invoke_llm(
            prompt_messages=payload.prompt_messages,
            model_parameters=payload.completion_params,
            tools=payload.tools,
            stop=payload.stop,
            stream=True if payload.stream is None else payload.stream,
            user=user_id,
        )

        if isinstance(response, Generator):

            def handle() -> Generator[LLMResultChunk, None, None]:
                for chunk in response:
                    if chunk.delta.usage:
                        LLMNode.deduct_llm_quota(
                            tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
                        )
                    yield chunk

            return handle()
        else:
            if response.usage:
                LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)

            # wrap the blocking result into a single chunk so callers can treat both modes as a stream
            def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
                yield LLMResultChunk(
                    model=response.model,
                    prompt_messages=response.prompt_messages,
                    system_fingerprint=response.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=response.message,
                        usage=response.usage,
                        finish_reason="",
                    ),
                )

            return handle_non_streaming(response)

    @classmethod
    def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
        """
        Invoke a text embedding model for the given texts.
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_text_embedding(
            texts=payload.texts,
            user=user_id,
        )

        return response

    @classmethod
    def invoke_rerank(cls, user_id: str, tenant: Tenant, payload: RequestInvokeRerank):
        """
        Invoke a rerank model to score the given documents against the query.
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_rerank(
            query=payload.query,
            docs=payload.docs,
            score_threshold=payload.score_threshold,
            top_n=payload.top_n,
            user=user_id,
        )

        return response

    @classmethod
    def invoke_tts(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTTS):
        """
        Invoke a TTS model and stream the audio back as hex-encoded chunks.
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_tts(
            content_text=payload.content_text,
            tenant_id=tenant.id,
            voice=payload.voice,
            user=user_id,
        )

        def handle() -> Generator[dict, None, None]:
            # hex-encode each binary audio chunk so it can be serialized as JSON
            for chunk in response:
                yield {"result": hexlify(chunk).decode("utf-8")}

        return handle()

    @classmethod
    def invoke_speech2text(cls, user_id: str, tenant: Tenant, payload: RequestInvokeSpeech2Text):
        """
        Invoke a speech-to-text model on a hex-encoded audio file.
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # decode the hex payload into a temporary audio file and invoke the model
        with tempfile.NamedTemporaryFile(suffix=".mp3", mode="wb", delete=True) as temp:
            temp.write(unhexlify(payload.file))
            temp.flush()
            temp.seek(0)

            response = model_instance.invoke_speech2text(
                file=temp,
                user=user_id,
            )

            return {
                "result": response,
            }

    @classmethod
    def invoke_moderation(cls, user_id: str, tenant: Tenant, payload: RequestInvokeModeration):
        """
        Invoke a moderation model on the given text.
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_moderation(
            text=payload.text,
            user=user_id,
        )

        return {
            "result": response,
        }

    @classmethod
    def get_system_model_max_tokens(cls, tenant_id: str) -> int:
        """
        Get the maximum context size (in tokens) of the tenant's system LLM.
        """
        return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id)

    @classmethod
    def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
        """
        Count the tokens that the given prompt messages would consume.
        """
        return ModelInvocationUtils.calculate_tokens(tenant_id=tenant_id, prompt_messages=prompt_messages)

    @classmethod
    def invoke_system_model(
        cls,
        user_id: str,
        tenant: Tenant,
        prompt_messages: list[PromptMessage],
    ) -> LLMResult:
        """
        Invoke the tenant's system model with the given prompt messages.
        """
        return ModelInvocationUtils.invoke(
            user_id=user_id,
            tenant_id=tenant.id,
            tool_type=ToolProviderType.PLUGIN,
            tool_name="plugin",
            prompt_messages=prompt_messages,
        )

    @classmethod
    def invoke_summary(cls, user_id: str, tenant: Tenant, payload: RequestInvokeSummary):
        """
        Summarize a long text with the system model, splitting it into chunks that
        fit the model's context window and recursing until the result fits.
        """
        max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id)
        content = payload.text

        SUMMARY_PROMPT = """You are a professional language researcher. You can quickly identify the main points
of a webpage and reproduce them in your own words while retaining the original meaning and key points.
However, the text you receive may be too long and may only be a part of the full text.
Please summarize the text you receive.

Here is the extra instruction you need to follow:
<extra_instruction>
{payload.instruction}
</extra_instruction>
"""

        # if the text already fits comfortably in the context window, return it unchanged
        if (
            cls.get_prompt_tokens(
                tenant_id=tenant.id,
                prompt_messages=[UserPromptMessage(content=content)],
            )
            < max_tokens * 0.6
        ):
            return content

        def get_prompt_tokens(content: str) -> int:
            return cls.get_prompt_tokens(
                tenant_id=tenant.id,
                prompt_messages=[
                    SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)),
                    UserPromptMessage(content=content),
                ],
            )

        def summarize(content: str) -> str:
            summary = cls.invoke_system_model(
                user_id=user_id,
                tenant=tenant,
                prompt_messages=[
                    SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)),
                    UserPromptMessage(content=content),
                ],
            )

            assert isinstance(summary.message.content, str)
            return summary.message.content

        lines = content.split("\n")
        new_lines: list[str] = []

        # split overly long lines into multiple shorter lines
        for line in lines:
            if not line.strip():
                continue

            if len(line) < max_tokens * 0.5:
                new_lines.append(line)
            elif get_prompt_tokens(line) > max_tokens * 0.7:
                while get_prompt_tokens(line) > max_tokens * 0.7:
                    new_lines.append(line[: int(max_tokens * 0.5)])
                    line = line[int(max_tokens * 0.5) :]
                new_lines.append(line)
            else:
                new_lines.append(line)

        # merge lines into messages that stay within the token budget
        messages: list[str] = []
        for line in new_lines:
            if not messages:
                messages.append(line)
            elif get_prompt_tokens(messages[-1] + line) > max_tokens * 0.7:
                # the current message is full, start a new one
                messages.append(line)
            else:
                messages[-1] += line

        # summarize each chunk independently
        summaries = []
        for message in messages:
            summary = summarize(message)
            summaries.append(summary)

        result = "\n".join(summaries)

        # if the combined summaries are still too long, summarize them recursively
        if (
            cls.get_prompt_tokens(
                tenant_id=tenant.id,
                prompt_messages=[UserPromptMessage(content=result)],
            )
            > max_tokens * 0.7
        ):
            return cls.invoke_summary(
                user_id=user_id,
                tenant=tenant,
                payload=RequestInvokeSummary(text=result, instruction=payload.instruction),
            )

        return result
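

# Illustrative usage sketch: it assumes a `Tenant` row and a validated
# `RequestInvokeLLM` payload have already been obtained elsewhere, and
# `_example_collect_llm_text` is a hypothetical helper name, not part of the
# plugin API. It shows how a caller might drain the streamed chunks returned
# by `invoke_llm` into a single string.
def _example_collect_llm_text(user_id: str, tenant: Tenant, payload: RequestInvokeLLM) -> str:
    response = PluginModelBackwardsInvocation.invoke_llm(user_id, tenant, payload)
    if isinstance(response, Generator):
        # streaming path: concatenate the text content of each chunk
        return "".join(
            chunk.delta.message.content
            for chunk in response
            if isinstance(chunk.delta.message.content, str)
        )
    # blocking path: a single LLMResult is returned
    return response.message.content if isinstance(response.message.content, str) else ""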