Bladeren bron

enhance:speedup xinference embedding & rerank (#3587)

呆萌闷油瓶 1 jaar geleden
bovenliggende
commit
4365843c20

+ 17 - 12
api/core/model_runtime/model_providers/xinference/rerank/rerank.py

@@ -47,17 +47,8 @@ class XinferenceRerankModel(RerankModel):
         if credentials['server_url'].endswith('/'):
         if credentials['server_url'].endswith('/'):
             credentials['server_url'] = credentials['server_url'][:-1]
             credentials['server_url'] = credentials['server_url'][:-1]
 
 
-        # initialize client
-        client = Client(
-            base_url=credentials['server_url']
-        )
-
-        xinference_client = client.get_model(model_uid=credentials['model_uid'])
-
-        if not isinstance(xinference_client, RESTfulRerankModelHandle):
-            raise InvokeBadRequestError('please check model type, the model you want to invoke is not a rerank model')
-
-        response = xinference_client.rerank(
+        handle = RESTfulRerankModelHandle(credentials['model_uid'], credentials['server_url'],auth_headers={})
+        response = handle.rerank(
             documents=docs,
             documents=docs,
             query=query,
             query=query,
             top_n=top_n,
             top_n=top_n,
@@ -97,6 +88,20 @@ class XinferenceRerankModel(RerankModel):
         try:
         try:
             if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
             if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
                 raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
                 raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
+
+            if credentials['server_url'].endswith('/'):
+                credentials['server_url'] = credentials['server_url'][:-1]
+
+            # initialize client
+            client = Client(
+                base_url=credentials['server_url']
+            )
+
+            xinference_client = client.get_model(model_uid=credentials['model_uid'])
+
+            if not isinstance(xinference_client, RESTfulRerankModelHandle):
+                raise InvokeBadRequestError(
+                    'please check model type, the model you want to invoke is not a rerank model')
             
             
             self.invoke(
             self.invoke(
                 model=model,
                 model=model,
@@ -157,4 +162,4 @@ class XinferenceRerankModel(RerankModel):
             parameter_rules=[]
             parameter_rules=[]
         )
         )
 
 
-        return entity
+        return entity

+ 14 - 11
api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py

@@ -47,17 +47,8 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         if server_url.endswith('/'):
         if server_url.endswith('/'):
             server_url = server_url[:-1]
             server_url = server_url[:-1]
 
 
-        client = Client(base_url=server_url)
-        
-        try:
-            handle = client.get_model(model_uid=model_uid)
-        except RuntimeError as e:
-            raise InvokeAuthorizationError(e)
-
-        if not isinstance(handle, RESTfulEmbeddingModelHandle):
-            raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model')
-
         try:
         try:
+            handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers={})
             embeddings = handle.create_embedding(input=texts)
             embeddings = handle.create_embedding(input=texts)
         except RuntimeError as e:
         except RuntimeError as e:
             raise InvokeServerUnavailableError(e)
             raise InvokeServerUnavailableError(e)
@@ -122,6 +113,18 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
 
 
             if extra_args.max_tokens:
             if extra_args.max_tokens:
                 credentials['max_tokens'] = extra_args.max_tokens
                 credentials['max_tokens'] = extra_args.max_tokens
+            if server_url.endswith('/'):
+                server_url = server_url[:-1]
+
+            client = Client(base_url=server_url)
+        
+            try:
+                handle = client.get_model(model_uid=model_uid)
+            except RuntimeError as e:
+                raise InvokeAuthorizationError(e)
+
+            if not isinstance(handle, RESTfulEmbeddingModelHandle):
+                raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model')
 
 
             self._invoke(model=model, credentials=credentials, texts=['ping'])
             self._invoke(model=model, credentials=credentials, texts=['ping'])
         except InvokeAuthorizationError as e:
         except InvokeAuthorizationError as e:
@@ -198,4 +201,4 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
             parameter_rules=[]
             parameter_rules=[]
         )
         )
 
 
-        return entity
+        return entity