Quellcode durchsuchen

fix: better memory usage from 800+ to 500+ (#11796)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>
yihong vor 4 Monaten
Ursprung
Commit
7b03a0316d

+ 18 - 8
api/core/model_runtime/model_providers/vertex_ai/llm/llm.py

@@ -4,11 +4,10 @@ import json
 import logging
 import time
 from collections.abc import Generator
-from typing import Optional, Union, cast
+from typing import TYPE_CHECKING, Optional, Union, cast
 
 import google.auth.transport.requests
 import requests
-import vertexai.generative_models as glm
 from anthropic import AnthropicVertex, Stream
 from anthropic.types import (
     ContentBlockDeltaEvent,
@@ -19,8 +18,6 @@ from anthropic.types import (
     MessageStreamEvent,
 )
 from google.api_core import exceptions
-from google.cloud import aiplatform
-from google.oauth2 import service_account
 from PIL import Image
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -47,6 +44,9 @@ from core.model_runtime.errors.invoke import (
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 
+if TYPE_CHECKING:
+    import vertexai.generative_models as glm
+
 logger = logging.getLogger(__name__)
 
 
@@ -102,6 +102,8 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
         :param stream: is stream response
         :return: full response or stream response chunk generator result
         """
+        from google.oauth2 import service_account
+
         # use Anthropic official SDK references
         # - https://github.com/anthropics/anthropic-sdk-python
         service_account_key = credentials.get("vertex_service_account_key", "")
@@ -406,13 +408,15 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
 
         return text.rstrip()
 
-    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
+    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> "glm.Tool":
         """
         Convert tool messages to glm tools
 
         :param tools: tool messages
         :return: glm tools
         """
+        import vertexai.generative_models as glm
+
         return glm.Tool(
             function_declarations=[
                 glm.FunctionDeclaration(
@@ -473,6 +477,10 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
         :param user: unique user id
         :return: full response or stream response chunk generator result
         """
+        import vertexai.generative_models as glm
+        from google.cloud import aiplatform
+        from google.oauth2 import service_account
+
         config_kwargs = model_parameters.copy()
         config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
 
@@ -522,7 +530,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
         return self._handle_generate_response(model, credentials, response, prompt_messages)
 
     def _handle_generate_response(
-        self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage]
+        self, model: str, credentials: dict, response: "glm.GenerationResponse", prompt_messages: list[PromptMessage]
     ) -> LLMResult:
         """
         Handle llm response
@@ -554,7 +562,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
         return result
 
     def _handle_generate_stream_response(
-        self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage]
+        self, model: str, credentials: dict, response: "glm.GenerationResponse", prompt_messages: list[PromptMessage]
     ) -> Generator:
         """
         Handle llm stream response
@@ -638,13 +646,15 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
 
         return message_text
 
-    def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content:
+    def _format_message_to_glm_content(self, message: PromptMessage) -> "glm.Content":
         """
         Format a single message into glm.Content for Google API
 
         :param message: one PromptMessage
         :return: glm Content representation of message
         """
+        import vertexai.generative_models as glm
+
         if isinstance(message, UserPromptMessage):
             glm_content = glm.Content(role="user", parts=[])
 

+ 14 - 4
api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py

@@ -2,12 +2,9 @@ import base64
 import json
 import time
 from decimal import Decimal
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 import tiktoken
-from google.cloud import aiplatform
-from google.oauth2 import service_account
-from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
 
 from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
@@ -24,6 +21,11 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from core.model_runtime.model_providers.vertex_ai._common import _CommonVertexAi
 
+if TYPE_CHECKING:
+    from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
+else:
+    VertexTextEmbeddingModel = None
+
 
 class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
     """
@@ -48,6 +50,10 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
         :param input_type: input type
         :return: embeddings result
         """
+        from google.cloud import aiplatform
+        from google.oauth2 import service_account
+        from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
+
         service_account_key = credentials.get("vertex_service_account_key", "")
         project_id = credentials["vertex_project_id"]
         location = credentials["vertex_location"]
@@ -100,6 +106,10 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
         :param credentials: model credentials
         :return:
         """
+        from google.cloud import aiplatform
+        from google.oauth2 import service_account
+        from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
+
         try:
             service_account_key = credentials.get("vertex_service_account_key", "")
             project_id = credentials["vertex_project_id"]

+ 9 - 6
api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py

@@ -1,18 +1,19 @@
 import re
 from typing import Optional
 
-import jieba
-from jieba.analyse import default_tfidf
-
-from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
-
 
 class JiebaKeywordTableHandler:
     def __init__(self):
-        default_tfidf.stop_words = STOPWORDS
+        import jieba.analyse
+
+        from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
+
+        jieba.analyse.default_tfidf.stop_words = STOPWORDS
 
     def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
         """Extract keywords with JIEBA tfidf."""
+        import jieba
+
         keywords = jieba.analyse.extract_tags(
             sentence=text,
             topK=max_keywords_per_chunk,
@@ -22,6 +23,8 @@ class JiebaKeywordTableHandler:
 
     def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
         """Get subtokens from a list of tokens., filtering for stopwords."""
+        from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
+
         results = set()
         for token in tokens:
             results.add(token)

+ 4 - 2
api/core/rag/datasource/vdb/oracle/oraclevector.py

@@ -6,10 +6,8 @@ from contextlib import contextmanager
 from typing import Any
 
 import jieba.posseg as pseg
-import nltk
 import numpy
 import oracledb
-from nltk.corpus import stopwords
 from pydantic import BaseModel, model_validator
 
 from configs import dify_config
@@ -202,6 +200,10 @@ class OracleVector(BaseVector):
         return docs
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        # lazy import
+        import nltk
+        from nltk.corpus import stopwords
+
         top_k = kwargs.get("top_k", 5)
         # just not implement fetch by score_threshold now, may be later
         score_threshold = float(kwargs.get("score_threshold") or 0.0)

+ 11 - 6
api/core/workflow/nodes/document_extractor/node.py

@@ -8,12 +8,6 @@ import docx
 import pandas as pd
 import pypdfium2  # type: ignore
 import yaml  # type: ignore
-from unstructured.partition.api import partition_via_api
-from unstructured.partition.email import partition_email
-from unstructured.partition.epub import partition_epub
-from unstructured.partition.msg import partition_msg
-from unstructured.partition.ppt import partition_ppt
-from unstructured.partition.pptx import partition_pptx
 
 from configs import dify_config
 from core.file import File, FileTransferMethod, file_manager
@@ -256,6 +250,8 @@ def _extract_text_from_excel(file_content: bytes) -> str:
 
 
 def _extract_text_from_ppt(file_content: bytes) -> str:
+    from unstructured.partition.ppt import partition_ppt
+
     try:
         with io.BytesIO(file_content) as file:
             elements = partition_ppt(file=file)
@@ -265,6 +261,9 @@ def _extract_text_from_ppt(file_content: bytes) -> str:
 
 
 def _extract_text_from_pptx(file_content: bytes) -> str:
+    from unstructured.partition.api import partition_via_api
+    from unstructured.partition.pptx import partition_pptx
+
     try:
         if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
             with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
@@ -287,6 +286,8 @@ def _extract_text_from_pptx(file_content: bytes) -> str:
 
 
 def _extract_text_from_epub(file_content: bytes) -> str:
+    from unstructured.partition.epub import partition_epub
+
     try:
         with io.BytesIO(file_content) as file:
             elements = partition_epub(file=file)
@@ -296,6 +297,8 @@ def _extract_text_from_epub(file_content: bytes) -> str:
 
 
 def _extract_text_from_eml(file_content: bytes) -> str:
+    from unstructured.partition.email import partition_email
+
     try:
         with io.BytesIO(file_content) as file:
             elements = partition_email(file=file)
@@ -305,6 +308,8 @@ def _extract_text_from_eml(file_content: bytes) -> str:
 
 
 def _extract_text_from_msg(file_content: bytes) -> str:
+    from unstructured.partition.msg import partition_msg
+
     try:
         with io.BytesIO(file_content) as file:
             elements = partition_msg(file=file)