streamable_chat_anthropic.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. from typing import List, Optional, Any, Dict
  2. from httpx import Timeout
  3. from langchain.callbacks.manager import Callbacks
  4. from langchain.chat_models import ChatAnthropic
  5. from langchain.schema import BaseMessage, LLMResult, SystemMessage, AIMessage, HumanMessage, ChatMessage
  6. from pydantic import root_validator
  7. from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions
  8. class StreamableChatAnthropic(ChatAnthropic):
  9. """
  10. Wrapper around Anthropic's large language model.
  11. """
  12. default_request_timeout: Optional[float] = Timeout(timeout=300.0, connect=5.0)
  13. @root_validator()
  14. def prepare_params(cls, values: Dict) -> Dict:
  15. values['model_name'] = values.get('model')
  16. values['max_tokens'] = values.get('max_tokens_to_sample')
  17. return values
  18. @handle_anthropic_exceptions
  19. def generate(
  20. self,
  21. messages: List[List[BaseMessage]],
  22. stop: Optional[List[str]] = None,
  23. callbacks: Callbacks = None,
  24. *,
  25. tags: Optional[List[str]] = None,
  26. metadata: Optional[Dict[str, Any]] = None,
  27. **kwargs: Any,
  28. ) -> LLMResult:
  29. return super().generate(messages, stop, callbacks, tags=tags, metadata=metadata, **kwargs)
  30. @classmethod
  31. def get_kwargs_from_model_params(cls, params: dict):
  32. params['model'] = params.get('model_name')
  33. del params['model_name']
  34. params['max_tokens_to_sample'] = params.get('max_tokens')
  35. del params['max_tokens']
  36. del params['frequency_penalty']
  37. del params['presence_penalty']
  38. return params
  39. def _convert_one_message_to_text(self, message: BaseMessage) -> str:
  40. if isinstance(message, ChatMessage):
  41. message_text = f"\n\n{message.role.capitalize()}: {message.content}"
  42. elif isinstance(message, HumanMessage):
  43. message_text = f"{self.HUMAN_PROMPT} {message.content}"
  44. elif isinstance(message, AIMessage):
  45. message_text = f"{self.AI_PROMPT} {message.content}"
  46. elif isinstance(message, SystemMessage):
  47. message_text = f"<admin>{message.content}</admin>"
  48. else:
  49. raise ValueError(f"Got unknown type {message}")
  50. return message_text