|
@@ -1,5 +1,6 @@
|
|
from abc import abstractmethod
|
|
from abc import abstractmethod
|
|
from typing import List, Optional, Any, Union
|
|
from typing import List, Optional, Any, Union
|
|
|
|
+import decimal
|
|
|
|
|
|
from langchain.callbacks.manager import Callbacks
|
|
from langchain.callbacks.manager import Callbacks
|
|
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
|
|
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
|
|
@@ -10,6 +11,8 @@ from core.model_providers.models.entity.message import PromptMessage, MessageTyp
|
|
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
|
|
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
|
|
from core.model_providers.providers.base import BaseModelProvider
|
|
from core.model_providers.providers.base import BaseModelProvider
|
|
from core.third_party.langchain.llms.fake import FakeLLM
|
|
from core.third_party.langchain.llms.fake import FakeLLM
|
|
|
|
+import logging
|
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
class BaseLLM(BaseProviderModel):
|
|
class BaseLLM(BaseProviderModel):
|
|
@@ -60,6 +63,39 @@ class BaseLLM(BaseProviderModel):
|
|
def _init_client(self) -> Any:
|
|
def _init_client(self) -> Any:
|
|
raise NotImplementedError
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
+ @property
|
|
|
|
+ def base_model_name(self) -> str:
|
|
|
|
+ """
|
|
|
|
+ get llm base model name
|
|
|
|
+
|
|
|
|
+ :return: str
|
|
|
|
+ """
|
|
|
|
+ return self.name
|
|
|
|
+
|
|
|
|
+ @property
|
|
|
|
+ def price_config(self) -> dict:
|
|
|
|
+ def get_or_default():
|
|
|
|
+ default_price_config = {
|
|
|
|
+ 'prompt': decimal.Decimal('0'),
|
|
|
|
+ 'completion': decimal.Decimal('0'),
|
|
|
|
+ 'unit': decimal.Decimal('0'),
|
|
|
|
+ 'currency': 'USD'
|
|
|
|
+ }
|
|
|
|
+ rules = self.model_provider.get_rules()
|
|
|
|
+ price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
|
|
|
|
+ price_config = {
|
|
|
|
+ 'prompt': decimal.Decimal(price_config['prompt']),
|
|
|
|
+ 'completion': decimal.Decimal(price_config['completion']),
|
|
|
|
+ 'unit': decimal.Decimal(price_config['unit']),
|
|
|
|
+ 'currency': price_config['currency']
|
|
|
|
+ }
|
|
|
|
+ return price_config
|
|
|
|
+
|
|
|
|
+ self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
|
|
|
|
+
|
|
|
|
+ logger.debug(f"model: {self.name} price_config: {self._price_config}")
|
|
|
|
+ return self._price_config
|
|
|
|
+
|
|
def run(self, messages: List[PromptMessage],
|
|
def run(self, messages: List[PromptMessage],
|
|
stop: Optional[List[str]] = None,
|
|
stop: Optional[List[str]] = None,
|
|
callbacks: Callbacks = None,
|
|
callbacks: Callbacks = None,
|
|
@@ -161,25 +197,48 @@ class BaseLLM(BaseProviderModel):
|
|
"""
|
|
"""
|
|
raise NotImplementedError
|
|
raise NotImplementedError
|
|
|
|
|
|
- @abstractmethod
|
|
|
|
- def get_token_price(self, tokens: int, message_type: MessageType):
|
|
|
|
|
|
+ def calc_tokens_price(self, tokens:int, message_type: MessageType):
|
|
"""
|
|
"""
|
|
- get token price.
|
|
|
|
|
|
+ calc tokens total price.
|
|
|
|
|
|
:param tokens:
|
|
:param tokens:
|
|
:param message_type:
|
|
:param message_type:
|
|
:return:
|
|
:return:
|
|
"""
|
|
"""
|
|
- raise NotImplementedError
|
|
|
|
|
|
+ if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
|
|
|
+ unit_price = self.price_config['prompt']
|
|
|
|
+ else:
|
|
|
|
+ unit_price = self.price_config['completion']
|
|
|
|
+ unit = self.price_config['unit']
|
|
|
|
+
|
|
|
|
+ total_price = tokens * unit_price * unit
|
|
|
|
+ total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
|
|
|
+ logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
|
|
|
|
+ return total_price
|
|
|
|
+
|
|
|
|
+ def get_tokens_unit_price(self, message_type: MessageType):
|
|
|
|
+ """
|
|
|
|
+ get token price.
|
|
|
|
+
|
|
|
|
+ :param message_type:
|
|
|
|
+ :return: decimal.Decimal('0.0001')
|
|
|
|
+ """
|
|
|
|
+ if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
|
|
|
+ unit_price = self.price_config['prompt']
|
|
|
|
+ else:
|
|
|
|
+ unit_price = self.price_config['completion']
|
|
|
|
+ unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
|
|
|
|
+ logging.debug(f"unit_price={unit_price}")
|
|
|
|
+ return unit_price
|
|
|
|
|
|
- @abstractmethod
|
|
|
|
def get_currency(self):
|
|
def get_currency(self):
|
|
"""
|
|
"""
|
|
get token currency.
|
|
get token currency.
|
|
|
|
|
|
- :return:
|
|
|
|
|
|
+ :return: get from price config, default 'USD'
|
|
"""
|
|
"""
|
|
- raise NotImplementedError
|
|
|
|
|
|
+ currency = self.price_config['currency']
|
|
|
|
+ return currency
|
|
|
|
|
|
def get_model_kwargs(self):
|
|
def get_model_kwargs(self):
|
|
return self.model_kwargs
|
|
return self.model_kwargs
|