Ver Fonte

feat: add gte rerank for tongyi (#9153)

Fei He há 6 meses atrás
pai
commit
5c76131d3d

+ 0 - 0
api/core/model_runtime/model_providers/tongyi/rerank/__init__.py


+ 1 - 0
api/core/model_runtime/model_providers/tongyi/rerank/_position.yaml

@@ -0,0 +1 @@
+- gte-rerank

+ 4 - 0
api/core/model_runtime/model_providers/tongyi/rerank/gte-rerank.yaml

@@ -0,0 +1,4 @@
+model: gte-rerank
+model_type: rerank
+model_properties:
+  context_size: 4000

+ 136 - 0
api/core/model_runtime/model_providers/tongyi/rerank/rerank.py

@@ -0,0 +1,136 @@
+from typing import Optional
+
+import dashscope
+from dashscope.common.error import (
+    AuthenticationError,
+    InvalidParameter,
+    RequestFailure,
+    ServiceUnavailableError,
+    UnsupportedHTTPMethod,
+    UnsupportedModel,
+)
+
+from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.rerank_model import RerankModel
+
+
+class GTERerankModel(RerankModel):
+    """
+    Model class for GTE rerank model.
+    """
+
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        query: str,
+        docs: list[str],
+        score_threshold: Optional[float] = None,
+        top_n: Optional[int] = None,
+        user: Optional[str] = None,
+    ) -> RerankResult:
+        """
+        Invoke rerank model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param query: search query
+        :param docs: docs for reranking
+        :param score_threshold: score threshold
+        :param top_n: top n
+        :param user: unique user id
+        :return: rerank result
+        """
+        if len(docs) == 0:
+            return RerankResult(model=model, docs=docs)
+
+        # initialize client
+        dashscope.api_key = credentials["dashscope_api_key"]
+
+        response = dashscope.TextReRank.call(
+            query=query,
+            documents=docs,
+            model=model,
+            top_n=top_n,
+            return_documents=True,
+        )
+
+        rerank_documents = []
+        for _, result in enumerate(response.output.results):
+            # format document
+            rerank_document = RerankDocument(
+                index=result.index,
+                score=result.relevance_score,
+                text=result["document"]["text"],
+            )
+
+            # score threshold check
+            if score_threshold is not None:
+                if result.relevance_score >= score_threshold:
+                    rerank_documents.append(rerank_document)
+            else:
+                rerank_documents.append(rerank_document)
+
+        return RerankResult(model=model, docs=rerank_documents)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            self.invoke(
+                model=model,
+                credentials=credentials,
+                query="What is the capital of the United States?",
+                docs=[
+                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
+                    "Census, Carson City had a population of 55,274.",
+                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
+                    "are a political division controlled by the United States. Its capital is Saipan.",
+                ],
+                score_threshold=0.8,
+            )
+        except Exception as ex:
+            print(ex)
+            raise CredentialsValidateFailedError(str(ex))
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: [
+                RequestFailure,
+            ],
+            InvokeServerUnavailableError: [
+                ServiceUnavailableError,
+            ],
+            InvokeRateLimitError: [],
+            InvokeAuthorizationError: [
+                AuthenticationError,
+            ],
+            InvokeBadRequestError: [
+                InvalidParameter,
+                UnsupportedModel,
+                UnsupportedHTTPMethod,
+            ],
+        }

+ 1 - 0
api/core/model_runtime/model_providers/tongyi/tongyi.yaml

@@ -18,6 +18,7 @@ supported_model_types:
   - llm
   - tts
   - text-embedding
+  - rerank
 configurate_methods:
   - predefined-model
   - customizable-model

+ 40 - 0
api/tests/integration_tests/model_runtime/tongyi/test_rerank.py

@@ -0,0 +1,40 @@
+import os
+
+import dashscope
+import pytest
+
+from core.model_runtime.entities.rerank_entities import RerankResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.tongyi.rerank.rerank import GTERerankModel
+
+
+def test_validate_credentials():
+    model = GTERerankModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(model="get-rank", credentials={"dashscope_api_key": "invalid_key"})
+
+    model.validate_credentials(
+        model="get-rank", credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}
+    )
+
+
+def test_invoke_model():
+    model = GTERerankModel()
+
+    result = model.invoke(
+        model=dashscope.TextReRank.Models.gte_rerank,
+        credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")},
+        query="什么是文本排序模型",
+        docs=[
+            "文本排序模型广泛用于搜索引擎和推荐系统中,它们根据文本相关性对候选文本进行排序",
+            "量子计算是计算科学的一个前沿领域",
+            "预训练语言模型的发展给文本排序模型带来了新的进展",
+        ],
+        score_threshold=0.7,
+    )
+
+    assert isinstance(result, RerankResult)
+    assert len(result.docs) == 1
+    assert result.docs[0].index == 0
+    assert result.docs[0].score >= 0.7