Browse Source

refactor: Add an enumeration type and use the factory pattern to obtain the corresponding class (#9356)

zhuhao 6 months ago
parent
commit
cd7ab6231f

+ 14 - 11
api/core/rag/datasource/keyword/keyword_factory.py

@@ -1,8 +1,8 @@
 from typing import Any
 
 from configs import dify_config
-from core.rag.datasource.keyword.jieba.jieba import Jieba
 from core.rag.datasource.keyword.keyword_base import BaseKeyword
+from core.rag.datasource.keyword.keyword_type import KeyWordType
 from core.rag.models.document import Document
 from models.dataset import Dataset
 
@@ -13,16 +13,19 @@ class Keyword:
         self._keyword_processor = self._init_keyword()
 
     def _init_keyword(self) -> BaseKeyword:
-        config = dify_config
-        keyword_type = config.KEYWORD_STORE
-
-        if not keyword_type:
-            raise ValueError("Keyword store must be specified.")
-
-        if keyword_type == "jieba":
-            return Jieba(dataset=self._dataset)
-        else:
-            raise ValueError(f"Keyword store {keyword_type} is not supported.")
+        keyword_type = dify_config.KEYWORD_STORE
+        keyword_factory = self.get_keyword_factory(keyword_type)
+        return keyword_factory(self._dataset)
+
+    @staticmethod
+    def get_keyword_factory(keyword_type: str) -> type[BaseKeyword]:
+        match keyword_type:
+            case KeyWordType.JIEBA:
+                from core.rag.datasource.keyword.jieba.jieba import Jieba
+
+                return Jieba
+            case _:
+                raise ValueError(f"Keyword store {keyword_type} is not supported.")
 
     def create(self, texts: list[Document], **kwargs):
         self._keyword_processor.create(texts, **kwargs)

+ 5 - 0
api/core/rag/datasource/keyword/keyword_type.py

@@ -0,0 +1,5 @@
+from enum import Enum
+
+
+class KeyWordType(str, Enum):
+    JIEBA = "jieba"

+ 18 - 8
api/services/auth/api_key_auth_factory.py

@@ -1,15 +1,25 @@
-from services.auth.firecrawl import FirecrawlAuth
-from services.auth.jina import JinaAuth
+from services.auth.api_key_auth_base import ApiKeyAuthBase
+from services.auth.auth_type import AuthType
 
 
 class ApiKeyAuthFactory:
     def __init__(self, provider: str, credentials: dict):
-        if provider == "firecrawl":
-            self.auth = FirecrawlAuth(credentials)
-        elif provider == "jinareader":
-            self.auth = JinaAuth(credentials)
-        else:
-            raise ValueError("Invalid provider")
+        auth_factory = self.get_apikey_auth_factory(provider)
+        self.auth = auth_factory(credentials)
 
     def validate_credentials(self):
         return self.auth.validate_credentials()
+
+    @staticmethod
+    def get_apikey_auth_factory(provider: str) -> type[ApiKeyAuthBase]:
+        match provider:
+            case AuthType.FIRECRAWL:
+                from services.auth.firecrawl.firecrawl import FirecrawlAuth
+
+                return FirecrawlAuth
+            case AuthType.JINA:
+                from services.auth.jina.jina import JinaAuth
+
+                return JinaAuth
+            case _:
+                raise ValueError("Invalid provider")

+ 6 - 0
api/services/auth/auth_type.py

@@ -0,0 +1,6 @@
+from enum import Enum
+
+
+class AuthType(str, Enum):
+    FIRECRAWL = "firecrawl"
+    JINA = "jinareader"

+ 0 - 0
api/services/auth/firecrawl/__init__.py


+ 0 - 0
api/services/auth/firecrawl.py → api/services/auth/firecrawl/firecrawl.py


+ 0 - 0
api/services/auth/jina/__init__.py


+ 0 - 0
api/services/auth/jina.py → api/services/auth/jina/jina.py