Browse Source

feat: xinference rerank model support (#1615)

takatost 1 year ago
parent
commit
0e627c920f

+ 3 - 3
api/controllers/console/workspace/model_providers.py

@@ -115,7 +115,7 @@ class ModelProviderModelValidateApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
         parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=['text-generation', 'embeddings', 'speech2text'], location='json')
+                            choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json')
         parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
         args = parser.parse_args()
 
@@ -155,7 +155,7 @@ class ModelProviderModelUpdateApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
         parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=['text-generation', 'embeddings', 'speech2text'], location='json')
+                            choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json')
         parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
         args = parser.parse_args()
 
@@ -184,7 +184,7 @@ class ModelProviderModelUpdateApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
         parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=['text-generation', 'embeddings', 'speech2text'], location='args')
+                            choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
         args = parser.parse_args()
 
         provider_service = ProviderService()

+ 58 - 0
api/core/model_providers/models/reranking/xinference_reranking.py

@@ -0,0 +1,58 @@
+import logging
+from typing import Optional, List
+
+from langchain.schema import Document
+from xinference_client.client.restful.restful_client import Client
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.models.reranking.base import BaseReranking
+from core.model_providers.providers.base import BaseModelProvider
+
+
+class XinferenceReranking(BaseReranking):
+
+    def __init__(self, model_provider: BaseModelProvider, name: str):
+        self.credentials = model_provider.get_model_credentials(
+            model_name=name,
+            model_type=self.type
+        )
+
+        client = Client(self.credentials['server_url'])
+
+        super().__init__(model_provider, client, name)
+
+    def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
+        docs = []
+        doc_id = []
+        for document in documents:
+            if document.metadata['doc_id'] not in doc_id:
+                doc_id.append(document.metadata['doc_id'])
+                docs.append(document.page_content)
+
+        model = self.client.get_model(self.credentials['model_uid'])
+        response = model.rerank(query=query, documents=docs, top_n=top_k)
+        rerank_documents = []
+
+        for idx, result in enumerate(response['results']):
+            # format document
+            index = result['index']
+            rerank_document = Document(
+                page_content=result['document'],
+                metadata={
+                    "doc_id": documents[index].metadata['doc_id'],
+                    "doc_hash": documents[index].metadata['doc_hash'],
+                    "document_id": documents[index].metadata['document_id'],
+                    "dataset_id": documents[index].metadata['dataset_id'],
+                    'score': result['relevance_score']
+                }
+            )
+            # score threshold check
+            if score_threshold is not None:
+                if result.relevance_score >= score_threshold:
+                    rerank_documents.append(rerank_document)
+            else:
+                rerank_documents.append(rerank_document)
+        return rerank_documents
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        return LLMBadRequestError(f"Xinference rerank: {str(ex)}")

+ 8 - 0
api/core/model_providers/providers/xinference_provider.py

@@ -2,11 +2,13 @@ import json
 from typing import Type
 
 import requests
+from xinference_client.client.restful.restful_client import Client
 
 from core.helper import encrypter
 from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
 from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
 from core.model_providers.models.llm.xinference_model import XinferenceModel
+from core.model_providers.models.reranking.xinference_reranking import XinferenceReranking
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 
 from core.model_providers.models.base import BaseProviderModel
@@ -40,6 +42,8 @@ class XinferenceProvider(BaseModelProvider):
             model_class = XinferenceModel
         elif model_type == ModelType.EMBEDDINGS:
             model_class = XinferenceEmbedding
+        elif model_type == ModelType.RERANKING:
+            model_class = XinferenceReranking
         else:
             raise NotImplementedError
 
@@ -113,6 +117,10 @@ class XinferenceProvider(BaseModelProvider):
                 )
 
                 embedding.embed_query("ping")
+            elif model_type == ModelType.RERANKING:
+                rerank_client = Client(credential_kwargs['server_url'])
+                model = rerank_client.get_model(credential_kwargs['model_uid'])
+                model.rerank(query="ping", documents=["ping", "pong"], top_n=2)
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 

+ 2 - 1
api/core/model_providers/rules/xinference.json

@@ -6,6 +6,7 @@
     "model_flexibility": "configurable",
     "supported_model_types": [
         "text-generation",
-        "embeddings"
+        "embeddings",
+        "reranking"
     ]
 }

+ 1 - 1
api/requirements.txt

@@ -48,7 +48,7 @@ huggingface_hub~=0.16.4
 transformers~=4.31.0
 stripe~=5.5.0
 pandas==1.5.3
-xinference-client~=0.5.4
+xinference-client~=0.6.4
 safetensors==0.3.2
 zhipuai==1.0.7
 werkzeug==2.3.7

+ 4 - 1
api/tests/integration_tests/.env.example

@@ -50,4 +50,7 @@ XINFERENCE_MODEL_UID=
 OPENLLM_SERVER_URL=
 
 # LocalAI Credentials
-LOCALAI_SERVER_URL=
+LOCALAI_SERVER_URL=
+
+# Cohere Credentials
+COHERE_API_KEY=

+ 0 - 0
api/tests/integration_tests/models/reranking/__init__.py


+ 61 - 0
api/tests/integration_tests/models/reranking/test_cohere_reranking.py

@@ -0,0 +1,61 @@
+import json
+import os
+from unittest.mock import patch
+
+from langchain.schema import Document
+
+from core.model_providers.models.reranking.cohere_reranking import CohereReranking
+from core.model_providers.providers.cohere_provider import CohereProvider
+from models.provider import Provider, ProviderType
+
+
+def get_mock_provider(valid_api_key):
+    return Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name='cohere',
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=json.dumps({'api_key': valid_api_key}),
+        is_valid=True,
+    )
+
+
+def get_mock_model():
+    valid_api_key = os.environ['COHERE_API_KEY']
+    provider = CohereProvider(provider=get_mock_provider(valid_api_key))
+    return CohereReranking(
+        model_provider=provider,
+        name='rerank-english-v2.0'
+    )
+
+
+def decrypt_side_effect(tenant_id, encrypted_api_key):
+    return encrypted_api_key
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_run(mock_decrypt):
+    model = get_mock_model()
+
+    docs = []
+    docs.append(Document(
+        page_content='bye',
+        metadata={
+            "doc_id": 'a',
+            "doc_hash": 'doc_hash',
+            "document_id": 'document_id',
+            "dataset_id": 'dataset_id',
+        }
+    ))
+    docs.append(Document(
+        page_content='hello',
+        metadata={
+            "doc_id": 'b',
+            "doc_hash": 'doc_hash',
+            "document_id": 'document_id',
+            "dataset_id": 'dataset_id',
+        }
+    ))
+    rst = model.rerank('hello', docs, None, 2)
+
+    assert rst[0].page_content == 'hello'

+ 78 - 0
api/tests/integration_tests/models/reranking/test_xinference_reranking.py

@@ -0,0 +1,78 @@
+import json
+import os
+from unittest.mock import patch, MagicMock
+
+from langchain.schema import Document
+
+from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.models.reranking.xinference_reranking import XinferenceReranking
+from core.model_providers.providers.xinference_provider import XinferenceProvider
+from models.provider import Provider, ProviderType, ProviderModel
+
+
+def get_mock_provider(valid_server_url, valid_model_uid):
+    return Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name='xinference',
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=json.dumps({'server_url': valid_server_url, 'model_uid': valid_model_uid}),
+        is_valid=True,
+    )
+
+
+def get_mock_model(mocker):
+    valid_server_url = os.environ['XINFERENCE_SERVER_URL']
+    valid_model_uid = os.environ['XINFERENCE_MODEL_UID']
+    model_name = 'bge-reranker-base'
+    provider = XinferenceProvider(provider=get_mock_provider(valid_server_url, valid_model_uid))
+
+    mock_query = MagicMock()
+    mock_query.filter.return_value.first.return_value = ProviderModel(
+        provider_name='xinference',
+        model_name=model_name,
+        model_type=ModelType.RERANKING.value,
+        encrypted_config=json.dumps({
+            'server_url': valid_server_url,
+            'model_uid': valid_model_uid
+        }),
+        is_valid=True,
+    )
+    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
+
+    return XinferenceReranking(
+        model_provider=provider,
+        name=model_name
+    )
+
+
+def decrypt_side_effect(tenant_id, encrypted_api_key):
+    return encrypted_api_key
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_run(mock_decrypt, mocker):
+    model = get_mock_model(mocker)
+
+    docs = []
+    docs.append(Document(
+        page_content='bye',
+        metadata={
+            "doc_id": 'a',
+            "doc_hash": 'doc_hash',
+            "document_id": 'document_id',
+            "dataset_id": 'dataset_id',
+        }
+    ))
+    docs.append(Document(
+        page_content='hello',
+        metadata={
+            "doc_id": 'b',
+            "doc_hash": 'doc_hash',
+            "document_id": 'document_id',
+            "dataset_id": 'dataset_id',
+        }
+    ))
+    rst = model.rerank('hello', docs, None, 2)
+
+    assert rst[0].page_content == 'hello'