Sfoglia il codice sorgente

feat: fix azure completion choices return empty (#708)

takatost 1 anno fa
parent
commit
e18211ffea
1 ha cambiato i file con 58 aggiunte e 1 eliminazioni
  1. 58 1
      api/core/llm/streamable_azure_open_ai.py

+ 58 - 1
api/core/llm/streamable_azure_open_ai.py

@@ -1,5 +1,7 @@
-from langchain.callbacks.manager import Callbacks
+from langchain.callbacks.manager import Callbacks, CallbackManagerForLLMRun
 from langchain.llms import AzureOpenAI
+from langchain.llms.openai import _streaming_response_template, completion_with_retry, _update_response, \
+    update_token_usage
 from langchain.schema import LLMResult
 from typing import Optional, List, Dict, Mapping, Any, Union, Tuple
 
@@ -67,3 +69,58 @@ class StreamableAzureOpenAI(AzureOpenAI):
     @classmethod
     def get_kwargs_from_model_params(cls, params: dict):
         return params
+
+    def _generate(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> LLMResult:
+        """Call out to OpenAI's endpoint with k unique prompts.
+
+        Args:
+            prompts: The prompts to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            The full LLM output.
+
+        Example:
+            .. code-block:: python
+
+                response = openai.generate(["Tell me a joke."])
+        """
+        params = self._invocation_params
+        params = {**params, **kwargs}
+        sub_prompts = self.get_sub_prompts(params, prompts, stop)
+        choices = []
+        token_usage: Dict[str, int] = {}
+        # Get the token usage from the response.
+        # Includes prompt, completion, and total tokens used.
+        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
+        for _prompts in sub_prompts:
+            if self.streaming:
+                if len(_prompts) > 1:
+                    raise ValueError("Cannot stream results with multiple prompts.")
+                params["stream"] = True
+                response = _streaming_response_template()
+                for stream_resp in completion_with_retry(
+                    self, prompt=_prompts, **params
+                ):
+                    if len(stream_resp["choices"]) > 0:
+                        if run_manager:
+                            run_manager.on_llm_new_token(
+                                stream_resp["choices"][0]["text"],
+                                verbose=self.verbose,
+                                logprobs=stream_resp["choices"][0]["logprobs"],
+                            )
+                        _update_response(response, stream_resp)
+                choices.extend(response["choices"])
+            else:
+                response = completion_with_retry(self, prompt=_prompts, **params)
+                choices.extend(response["choices"])
+            if not self.streaming:
+                # Can't update token usage if streaming
+                update_token_usage(_keys, response, token_usage)
+        return self.create_llm_result(choices, prompts, token_usage)