|
@@ -17,7 +17,13 @@ class WenxinRerank(_CommonWenxin):
|
|
|
def rerank(self, model: str, query: str, docs: list[str], top_n: Optional[int] = None):
|
|
|
access_token = self._get_access_token()
|
|
|
url = f"{self.api_bases[model]}?access_token={access_token}"
|
|
|
-
|
|
|
+ # For issue #11252
|
|
|
+ # for wenxin Rerank model top_n length should be equal or less than docs length
|
|
|
+ if top_n is not None and top_n > len(docs):
|
|
|
+ top_n = len(docs)
|
|
|
+ # for wenxin Rerank model, query should not be an empty string
|
|
|
+ if query == "":
|
|
|
+ query = " " # FIXME: this is a workaround for wenxin rerank model for better user experience.
|
|
|
try:
|
|
|
response = httpx.post(
|
|
|
url,
|
|
@@ -25,7 +31,11 @@ class WenxinRerank(_CommonWenxin):
|
|
|
headers={"Content-Type": "application/json"},
|
|
|
)
|
|
|
response.raise_for_status()
|
|
|
- return response.json()
|
|
|
+ data = response.json()
|
|
|
+ # wenxin error handling
|
|
|
+ if "error_code" in data:
|
|
|
+ raise InternalServerError(data["error_msg"])
|
|
|
+ return data
|
|
|
except httpx.HTTPStatusError as e:
|
|
|
raise InternalServerError(str(e))
|
|
|
|
|
@@ -69,6 +79,9 @@ class WenxinRerankModel(RerankModel):
|
|
|
results = wenxin_rerank.rerank(model, query, docs, top_n)
|
|
|
|
|
|
rerank_documents = []
|
|
|
+ if "results" not in results:
|
|
|
+ raise ValueError("results key not found in response")
|
|
|
+
|
|
|
for result in results["results"]:
|
|
|
index = result["index"]
|
|
|
if "document" in result:
|