@@ -8,12 +8,15 @@ from typing import (
     Any,
     Dict,
     List,
-    Optional, Iterator,
+    Optional, Iterator, Tuple,
 )
 
 import requests
+from langchain.chat_models.base import BaseChatModel
 from langchain.llms.utils import enforce_stop_tokens
-from langchain.schema.output import GenerationChunk
+from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
+from langchain.schema.messages import AIMessageChunk
+from langchain.schema.output import GenerationChunk, ChatResult, ChatGenerationChunk, ChatGeneration
 from pydantic import BaseModel, Extra, Field, PrivateAttr, root_validator
 
 from langchain.callbacks.manager import (
@@ -61,6 +64,7 @@ class _WenxinEndpointClient(BaseModel):
             raise ValueError(f"Wenxin Model name is required")
 
         model_url_map = {
+            'ernie-bot-4': 'completions_pro',
             'ernie-bot': 'completions',
             'ernie-bot-turbo': 'eb-instant',
             'bloomz-7b': 'bloomz_7b1',
@@ -70,6 +74,7 @@ class _WenxinEndpointClient(BaseModel):
 
         access_token = self.get_access_token()
         api_url = f"{self.base_url}{model_url_map[request['model']]}?access_token={access_token}"
+        del request['model']
 
         headers = {"Content-Type": "application/json"}
         response = requests.post(api_url,
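
Note on the two hunks above: the endpoint is selected by appending the mapped suffix to the base URL, so the new `ernie-bot-4` entry routes to `completions_pro`, and `del request['model']` keeps the model name out of the JSON body once the URL is built. A minimal sketch of that flow (the base URL and payload values here are illustrative assumptions, not taken from this diff):

# Illustrative sketch: endpoint chosen from the model name, then the
# "model" key dropped from the body because the URL already encodes it.
model_url_map = {
    'ernie-bot-4': 'completions_pro',
    'ernie-bot': 'completions',
    'ernie-bot-turbo': 'eb-instant',
}

base_url = "https://example-wenxin-endpoint/chat/"   # placeholder, not the real base URL
access_token = "<access-token>"                      # placeholder
request = {"model": "ernie-bot-4",
           "messages": [{"role": "user", "content": "Hello"}]}

api_url = f"{base_url}{model_url_map[request['model']]}?access_token={access_token}"
del request['model']
# api_url -> ".../chat/completions_pro?access_token=<access-token>"
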
@@ -86,22 +91,21 @@ class _WenxinEndpointClient(BaseModel):
                     f"Wenxin API {json_response['error_code']}"
                     f" error: {json_response['error_msg']}"
                 )
-            return json_response["result"]
+            return json_response
         else:
             return response
 
 
-class Wenxin(LLM):
-    """Wrapper around Wenxin large language models.
-    To use, you should have the environment variable
-    ``WENXIN_API_KEY`` and ``WENXIN_SECRET_KEY`` set with your API key,
-    or pass them as a named parameter to the constructor.
-    Example:
-        .. code-block:: python
-            from langchain.llms.wenxin import Wenxin
-            wenxin = Wenxin(model="<model_name>", api_key="my-api-key",
-                secret_key="my-group-id")
-    """
+class Wenxin(BaseChatModel):
+    """Wrapper around Wenxin large language models."""
+
+    @property
+    def lc_secrets(self) -> Dict[str, str]:
+        return {"api_key": "API_KEY", "secret_key": "SECRET_KEY"}
+
+    @property
+    def lc_serializable(self) -> bool:
+        return True
 
     _client: _WenxinEndpointClient = PrivateAttr()
     model: str = "ernie-bot"
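
The class docstring's usage example is dropped in this hunk; for reference, a hedged sketch of how the refactored chat model would now be called, assuming the existing `model`, `api_key` and `secret_key` fields are kept and the standard `BaseChatModel` call interface applies:

# Hypothetical usage after the switch to BaseChatModel; key values are placeholders.
from langchain.schema import HumanMessage, SystemMessage

chat = Wenxin(model="ernie-bot", api_key="my-api-key", secret_key="my-secret-key")
result = chat([
    SystemMessage(content="You are a concise assistant."),
    HumanMessage(content="Tell me a joke."),
])
# result is expected to be an AIMessage produced via _generate / _create_chat_result
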
@@ -161,64 +165,89 @@ class Wenxin(LLM):
             secret_key=self.secret_key,
         )
 
-    def _call(
+    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
+        if isinstance(message, ChatMessage):
+            message_dict = {"role": message.role, "content": message.content}
+        elif isinstance(message, HumanMessage):
+            message_dict = {"role": "user", "content": message.content}
+        elif isinstance(message, AIMessage):
+            message_dict = {"role": "assistant", "content": message.content}
+        elif isinstance(message, SystemMessage):
+            message_dict = {"role": "system", "content": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+        return message_dict
+
+    def _create_message_dicts(
+        self, messages: List[BaseMessage]
+    ) -> Tuple[List[Dict[str, Any]], str]:
+        dict_messages = []
+        system = None
+        for m in messages:
+            message = self._convert_message_to_dict(m)
+            if message['role'] == 'system':
+                if not system:
+                    system = message['content']
+                else:
+                    system += f"\n{message['content']}"
+                continue
+
+            if dict_messages:
+                previous_message = dict_messages[-1]
+                if previous_message['role'] == message['role']:
+                    dict_messages[-1]['content'] += f"\n{message['content']}"
+                else:
+                    dict_messages.append(message)
+            else:
+                dict_messages.append(message)
+
+        return dict_messages, system
+
+    def _generate(
         self,
-        prompt: str,
+        messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
-        r"""Call out to Wenxin's completion endpoint to chat
-        Args:
-            prompt: The prompt to pass into the model.
-        Returns:
-            The string generated by the model.
-        Example:
-            .. code-block:: python
-                response = wenxin("Tell me a joke.")
-        """
+    ) -> ChatResult:
         if self.streaming:
-            completion = ""
+            generation: Optional[ChatGenerationChunk] = None
+            llm_output: Optional[Dict] = None
             for chunk in self._stream(
-                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+                messages=messages, stop=stop, run_manager=run_manager, **kwargs
             ):
-                completion += chunk.text
+                if chunk.generation_info is not None \
+                        and 'token_usage' in chunk.generation_info:
+                    llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}
+
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+            assert generation is not None
+            return ChatResult(generations=[generation], llm_output=llm_output)
         else:
+            message_dicts, system = self._create_message_dicts(messages)
             request = self._default_params
-            request["messages"] = [{"role": "user", "content": prompt}]
+            request["messages"] = message_dicts
+            if system:
+                request["system"] = system
             request.update(kwargs)
-            completion = self._client.post(request)
-
-        if stop is not None:
-            completion = enforce_stop_tokens(completion, stop)
-
-        return completion
+            response = self._client.post(request)
+            return self._create_chat_result(response)
 
     def _stream(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> Iterator[GenerationChunk]:
-        r"""Call wenxin completion_stream and return the resulting generator.
-
-        Args:
-            prompt: The prompt to pass into the model.
-            stop: Optional list of stop words to use when generating.
-        Returns:
-            A generator representing the stream of tokens from Wenxin.
-        Example:
-            .. code-block:: python
-
-                prompt = "Write a poem about a stream."
-                prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
-                generator = wenxin.stream(prompt)
-                for token in generator:
-                    yield token
-        """
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        message_dicts, system = self._create_message_dicts(messages)
         request = self._default_params
-        request["messages"] = [{"role": "user", "content": prompt}]
+        request["messages"] = message_dicts
+        if system:
+            request["system"] = system
         request.update(kwargs)
 
         for token in self._client.post(request).iter_lines():
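
To make the merging logic in `_create_message_dicts` concrete, a small worked example of the output it should produce, assuming the conversion rules above (system messages are collected into one `system` string, and consecutive messages with the same role are folded together, presumably because the API expects alternating turns):

from langchain.schema import AIMessage, HumanMessage, SystemMessage

messages = [
    SystemMessage(content="Be concise."),
    HumanMessage(content="Hi"),
    HumanMessage(content="What is Wenxin?"),
    AIMessage(content="A family of ERNIE models."),
]
# _create_message_dicts(messages) should return:
# dict_messages = [
#     {"role": "user", "content": "Hi\nWhat is Wenxin?"},
#     {"role": "assistant", "content": "A family of ERNIE models."},
# ]
# system = "Be concise."
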
@@ -228,12 +257,18 @@ class Wenxin(LLM):
             if token.startswith('data:'):
                 completion = json.loads(token[5:])
 
-                yield GenerationChunk(text=completion['result'])
-                if run_manager:
-                    run_manager.on_llm_new_token(completion['result'])
+                chunk_dict = {
+                    'message': AIMessageChunk(content=completion['result']),
+                }
 
                 if completion['is_end']:
-                    break
+                    token_usage = completion['usage']
+                    token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']
+                    chunk_dict['generation_info'] = dict({'token_usage': token_usage})
+
+                yield ChatGenerationChunk(**chunk_dict)
+                if run_manager:
+                    run_manager.on_llm_new_token(completion['result'])
             else:
                 try:
                     json_response = json.loads(token)
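
The streaming branch now yields `ChatGenerationChunk`s and attaches `token_usage` only on the final event, deriving `completion_tokens` from the totals. A rough sketch of handling one `data:` line the way this loop expects it (the payload values are made up; only the field names come from the code above):

import json

# Example of a final streamed event in the shape this loop expects (illustrative values).
token = 'data: {"result": " world", "is_end": true, "usage": {"prompt_tokens": 5, "total_tokens": 12}}'
completion = json.loads(token[5:])          # strip the "data:" prefix
usage = completion['usage']
usage['completion_tokens'] = usage['total_tokens'] - usage['prompt_tokens']  # 12 - 5 = 7
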
@@ -245,3 +280,40 @@ class Wenxin(LLM):
                         f" error: {json_response['error_msg']}, "
                         f"please confirm if the model you have chosen is already paid for."
                     )
+
+    def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
+        generations = [ChatGeneration(
+            message=AIMessage(content=response['result']),
+        )]
+        token_usage = response.get("usage")
+        token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']
+
+        llm_output = {"token_usage": token_usage, "model_name": self.model}
+        return ChatResult(generations=generations, llm_output=llm_output)
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        """Get the number of tokens in the messages.
+
+        Useful for checking if an input will fit in a model's context window.
+
+        Args:
+            messages: The message inputs to tokenize.
+
+        Returns:
+            The sum of the number of tokens across the messages.
+        """
+        return sum([self.get_num_tokens(m.content) for m in messages])
+
+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        overall_token_usage: dict = {}
+        for output in llm_outputs:
+            if output is None:
+                # Happens in streaming
+                continue
+            token_usage = output["token_usage"]
+            for k, v in token_usage.items():
+                if k in overall_token_usage:
+                    overall_token_usage[k] += v
+                else:
+                    overall_token_usage[k] = v
+        return {"token_usage": overall_token_usage, "model_name": self.model}
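
`_combine_llm_outputs` simply sums the per-generation token counts; a quick illustration of the aggregation it should produce (the numbers are made up):

outputs = [
    {"token_usage": {"prompt_tokens": 10, "completion_tokens": 4, "total_tokens": 14}},
    {"token_usage": {"prompt_tokens": 8, "completion_tokens": 6, "total_tokens": 14}},
]
overall = {}
for o in outputs:
    for k, v in o["token_usage"].items():
        overall[k] = overall.get(k, 0) + v
# overall -> {"prompt_tokens": 18, "completion_tokens": 10, "total_tokens": 28}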