Krasus.Chen 1 year ago
parent
commit
fd0fc8f4fe
22 changed files with 288 additions and 230 deletions
  1. api/core/conversation_message_task.py (+11 -20)
  2. api/core/indexing_runner.py (+3 -3)
  3. api/core/model_providers/models/embedding/azure_openai_embedding.py (+9 -10)
  4. api/core/model_providers/models/embedding/base.py (+69 -5)
  5. api/core/model_providers/models/embedding/minimax_embedding.py (+0 -3)
  6. api/core/model_providers/models/embedding/openai_embedding.py (+0 -10)
  7. api/core/model_providers/models/embedding/replicate_embedding.py (+0 -7)
  8. api/core/model_providers/models/llm/anthropic_model.py (+0 -26)
  9. api/core/model_providers/models/llm/azure_openai_model.py (+9 -40)
  10. api/core/model_providers/models/llm/base.py (+66 -7)
  11. api/core/model_providers/models/llm/chatglm_model.py (+0 -3)
  12. api/core/model_providers/models/llm/huggingface_hub_model.py (+0 -7)
  13. api/core/model_providers/models/llm/minimax_model.py (+0 -3)
  14. api/core/model_providers/models/llm/openai_model.py (+2 -39)
  15. api/core/model_providers/models/llm/replicate_model.py (+0 -7)
  16. api/core/model_providers/models/llm/spark_model.py (+0 -3)
  17. api/core/model_providers/models/llm/tongyi_model.py (+0 -3)
  18. api/core/model_providers/models/llm/wenxin_model.py (+1 -30)
  19. api/core/model_providers/rules/anthropic.json (+15 -1)
  20. api/core/model_providers/rules/azure_openai.json (+44 -1)
  21. api/core/model_providers/rules/openai.json (+38 -1)
  22. api/core/model_providers/rules/wenxin.json (+21 -1)

+ 11 - 20
api/core/conversation_message_task.py

@@ -140,10 +140,13 @@ class ConversationMessageTask:
     def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
         message_tokens = llm_message.prompt_tokens
         answer_tokens = llm_message.completion_tokens
-        message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN)
-        answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT)
 
-        total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price)
+        message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
+        answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
+
+        message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
+        answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
+        total_price = message_total_price + answer_total_price
 
         self.message.message = llm_message.prompt
         self.message.message_tokens = message_tokens
@@ -206,18 +209,15 @@ class ConversationMessageTask:
 
     def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
                      agent_loop: AgentLoop):
-        agent_message_unit_price = agent_model_instant.get_token_price(1, MessageType.HUMAN)
-        agent_answer_unit_price = agent_model_instant.get_token_price(1, MessageType.ASSISTANT)
+        agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
+        agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)
 
         loop_message_tokens = agent_loop.prompt_tokens
         loop_answer_tokens = agent_loop.completion_tokens
 
-        loop_total_price = self.calc_total_price(
-            loop_message_tokens,
-            agent_message_unit_price,
-            loop_answer_tokens,
-            agent_answer_unit_price
-        )
+        loop_message_total_price = agent_model_instant.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
+        loop_answer_total_price = agent_model_instant.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
+        loop_total_price = loop_message_total_price + loop_answer_total_price
 
         message_agent_thought.observation = agent_loop.tool_output
         message_agent_thought.tool_process_data = ''  # currently not support
@@ -243,15 +243,6 @@ class ConversationMessageTask:
 
         db.session.add(dataset_query)
 
-    def calc_total_price(self, message_tokens, message_unit_price, answer_tokens, answer_unit_price):
-        message_tokens_per_1k = (decimal.Decimal(message_tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                                  rounding=decimal.ROUND_HALF_UP)
-        answer_tokens_per_1k = (decimal.Decimal(answer_tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                                rounding=decimal.ROUND_HALF_UP)
-
-        total_price = message_tokens_per_1k * message_unit_price + answer_tokens_per_1k * answer_unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
     def end(self):
         self._pub_handler.pub_end()
 

+ 3 - 3
api/core/indexing_runner.py

@@ -278,7 +278,7 @@ class IndexingRunner:
                     "total_segments": total_segments * 20,
                     "tokens": total_segments * 2000,
                     "total_price": '{:f}'.format(
-                        text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
+                        text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
                     "currency": embedding_model.get_currency(),
                     "qa_preview": document_qa_list,
                     "preview": preview_texts
@@ -286,7 +286,7 @@ class IndexingRunner:
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
+            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
             "currency": embedding_model.get_currency(),
             "preview": preview_texts
         }
@@ -371,7 +371,7 @@ class IndexingRunner:
                     "total_segments": total_segments * 20,
                     "tokens": total_segments * 2000,
                     "total_price": '{:f}'.format(
-                        text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
+                        text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
                     "currency": embedding_model.get_currency(),
                     "qa_preview": document_qa_list,
                     "preview": preview_texts

+ 9 - 10
api/core/model_providers/models/embedding/azure_openai_embedding.py

@@ -31,6 +31,15 @@ class AzureOpenAIEmbedding(BaseEmbedding):
         )
 
         super().__init__(model_provider, client, name)
+    
+    @property
+    def base_model_name(self) -> str:
+        """
+        get base model name (not deployment)
+        
+        :return: str
+        """
+        return self.credentials.get("base_model_name")
 
     def get_num_tokens(self, text: str) -> int:
         """
@@ -49,16 +58,6 @@ class AzureOpenAIEmbedding(BaseEmbedding):
         # calculate the number of tokens in the encoded text
         return len(tokenized_text)
 
-    def get_token_price(self, tokens: int):
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-
-        total_price = tokens_per_1k * decimal.Decimal('0.0001')
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
     def handle_exceptions(self, ex: Exception) -> Exception:
         if isinstance(ex, openai.error.InvalidRequestError):
             logging.warning("Invalid request to Azure OpenAI API.")

+ 69 - 5
api/core/model_providers/models/embedding/base.py

@@ -1,5 +1,6 @@
 from abc import abstractmethod
 from typing import Any
+import decimal
 
 import tiktoken
 from langchain.schema.language_model import _get_token_ids_default_method
@@ -7,7 +8,8 @@ from langchain.schema.language_model import _get_token_ids_default_method
 from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.entity.model_params import ModelType
 from core.model_providers.providers.base import BaseModelProvider
-
+import logging
+logger = logging.getLogger(__name__)
 
 class BaseEmbedding(BaseProviderModel):
     name: str
@@ -17,6 +19,65 @@ class BaseEmbedding(BaseProviderModel):
         super().__init__(model_provider, client)
         self.name = name
 
+    @property
+    def base_model_name(self) -> str:
+        """
+        get base model name
+        
+        :return: str
+        """
+        return self.name
+
+    @property
+    def price_config(self) -> dict:
+        def get_or_default():
+            default_price_config = {
+                    'prompt': decimal.Decimal('0'),
+                    'completion': decimal.Decimal('0'),
+                    'unit': decimal.Decimal('0'),
+                    'currency': 'USD'
+                }
+            rules = self.model_provider.get_rules()
+            price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
+            price_config = {
+                'prompt': decimal.Decimal(price_config['prompt']),
+                'completion': decimal.Decimal(price_config['completion']),
+                'unit': decimal.Decimal(price_config['unit']),
+                'currency': price_config['currency']
+            }
+            return price_config
+        
+        self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
+
+        logger.debug(f"model: {self.name} price_config: {self._price_config}")
+        return self._price_config
+
+    def calc_tokens_price(self, tokens:int) -> decimal.Decimal:
+        """
+        calc tokens total price.
+
+        :param tokens:
+        :return: decimal.Decimal('0.0000001')
+        """
+        unit_price = self._price_config['completion']
+        unit = self._price_config['unit']
+        total_price = tokens * unit_price * unit
+        total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
+        logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
+        return total_price
+
+    def get_tokens_unit_price(self) -> decimal.Decimal:
+        """
+        get token price.
+
+        :return: decimal.Decimal('0.0001')
+        
+        """
+        unit_price = self._price_config['completion']
+        unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
+        logger.debug(f'unit_price:{unit_price}')
+        return unit_price
+
     def get_num_tokens(self, text: str) -> int:
         """
         get num tokens of text.
@@ -29,11 +90,14 @@ class BaseEmbedding(BaseProviderModel):
 
         return len(_get_token_ids_default_method(text))
 
-    def get_token_price(self, tokens: int):
-        return 0
-
     def get_currency(self):
-        return 'USD'
+        """
+        get token currency.
+
+        :return: get from price config, default 'USD'
+        """
+        currency = self._price_config['currency']
+        return currency
 
     @abstractmethod
     def handle_exceptions(self, ex: Exception) -> Exception:
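Note: embedding pricing is no longer hardcoded per model. BaseEmbedding now reads the provider's price_config (keyed by base_model_name) and uses its 'completion' entry, which is why the embedding entries in the rules files further down only declare a completion price. A minimal sketch of the new math, not part of the commit, using the text-embedding-ada-002 values from rules/openai.json:

    import decimal

    # price_config entry for text-embedding-ada-002: completion "0.0001", unit "0.001" (per 1K tokens)
    unit_price = decimal.Decimal("0.0001")
    unit = decimal.Decimal("0.001")
    tokens = 1536

    total_price = (tokens * unit_price * unit).quantize(
        decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP)
    print(total_price)  # 0.0001536 (USD)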

+ 0 - 3
api/core/model_providers/models/embedding/minimax_embedding.py

@@ -22,9 +22,6 @@ class MinimaxEmbedding(BaseEmbedding):
 
         super().__init__(model_provider, client, name)
 
-    def get_token_price(self, tokens: int):
-        return decimal.Decimal('0')
-
     def get_currency(self):
         return 'RMB'
 

+ 0 - 10
api/core/model_providers/models/embedding/openai_embedding.py

@@ -42,16 +42,6 @@ class OpenAIEmbedding(BaseEmbedding):
         # calculate the number of tokens in the encoded text
         return len(tokenized_text)
 
-    def get_token_price(self, tokens: int):
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-
-        total_price = tokens_per_1k * decimal.Decimal('0.0001')
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
     def handle_exceptions(self, ex: Exception) -> Exception:
         if isinstance(ex, openai.error.InvalidRequestError):
             logging.warning("Invalid request to OpenAI API.")

+ 0 - 7
api/core/model_providers/models/embedding/replicate_embedding.py

@@ -22,13 +22,6 @@ class ReplicateEmbedding(BaseEmbedding):
 
         super().__init__(model_provider, client, name)
 
-    def get_token_price(self, tokens: int):
-        # replicate only pay for prediction seconds
-        return decimal.Decimal('0')
-
-    def get_currency(self):
-        return 'USD'
-
     def handle_exceptions(self, ex: Exception) -> Exception:
         if isinstance(ex, (ModelError, ReplicateError)):
             return LLMBadRequestError(f"Replicate: {str(ex)}")

+ 0 - 26
api/core/model_providers/models/llm/anthropic_model.py

@@ -54,32 +54,6 @@ class AnthropicModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
 
-    def get_token_price(self, tokens: int, message_type: MessageType):
-        model_unit_prices = {
-            'claude-instant-1': {
-                'prompt': decimal.Decimal('1.63'),
-                'completion': decimal.Decimal('5.51'),
-            },
-            'claude-2': {
-                'prompt': decimal.Decimal('11.02'),
-                'completion': decimal.Decimal('32.68'),
-            },
-        }
-
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
-            unit_price = model_unit_prices[self.name]['prompt']
-        else:
-            unit_price = model_unit_prices[self.name]['completion']
-
-        tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'),
-                                                                     rounding=decimal.ROUND_HALF_UP)
-
-        total_price = tokens_per_1m * unit_price
-        return total_price.quantize(decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         for k, v in provider_model_kwargs.items():

+ 9 - 40
api/core/model_providers/models/llm/azure_openai_model.py

@@ -29,7 +29,6 @@ class AzureOpenAIModel(BaseLLM):
             self.model_mode = ModelMode.COMPLETION
         else:
             self.model_mode = ModelMode.CHAT
-
         super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
 
     def _init_client(self) -> Any:
@@ -83,6 +82,15 @@ class AzureOpenAIModel(BaseLLM):
         """
         prompts = self._get_prompt_from_messages(messages)
         return self._client.generate([prompts], stop, callbacks)
+    
+    @property
+    def base_model_name(self) -> str:
+        """
+        get base model name (not deployment)
+        
+        :return: str
+        """
+        return self.credentials.get("base_model_name")
 
     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """
@@ -97,45 +105,6 @@ class AzureOpenAIModel(BaseLLM):
         else:
             return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
 
-    def get_token_price(self, tokens: int, message_type: MessageType):
-        model_unit_prices = {
-            'gpt-4': {
-                'prompt': decimal.Decimal('0.03'),
-                'completion': decimal.Decimal('0.06'),
-            },
-            'gpt-4-32k': {
-                'prompt': decimal.Decimal('0.06'),
-                'completion': decimal.Decimal('0.12')
-            },
-            'gpt-35-turbo': {
-                'prompt': decimal.Decimal('0.0015'),
-                'completion': decimal.Decimal('0.002')
-            },
-            'gpt-35-turbo-16k': {
-                'prompt': decimal.Decimal('0.003'),
-                'completion': decimal.Decimal('0.004')
-            },
-            'text-davinci-003': {
-                'prompt': decimal.Decimal('0.02'),
-                'completion': decimal.Decimal('0.02')
-            },
-        }
-
-        base_model_name = self.credentials.get("base_model_name")
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
-            unit_price = model_unit_prices[base_model_name]['prompt']
-        else:
-            unit_price = model_unit_prices[base_model_name]['completion']
-
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-
-        total_price = tokens_per_1k * unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         if self.name == 'text-davinci-003':

+ 66 - 7
api/core/model_providers/models/llm/base.py

@@ -1,5 +1,6 @@
 from abc import abstractmethod
 from typing import List, Optional, Any, Union
+import decimal
 
 from langchain.callbacks.manager import Callbacks
 from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
@@ -10,6 +11,8 @@ from core.model_providers.models.entity.message import PromptMessage, MessageTyp
 from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
 from core.model_providers.providers.base import BaseModelProvider
 from core.third_party.langchain.llms.fake import FakeLLM
+import logging
+logger = logging.getLogger(__name__)
 
 
 class BaseLLM(BaseProviderModel):
@@ -60,6 +63,39 @@ class BaseLLM(BaseProviderModel):
     def _init_client(self) -> Any:
         raise NotImplementedError
 
+    @property
+    def base_model_name(self) -> str:
+        """
+        get llm base model name
+
+        :return: str
+        """
+        return self.name
+
+    @property
+    def price_config(self) -> dict:
+        def get_or_default():
+            default_price_config = {
+                    'prompt': decimal.Decimal('0'),
+                    'completion': decimal.Decimal('0'),
+                    'unit': decimal.Decimal('0'),
+                    'currency': 'USD'
+                }
+            rules = self.model_provider.get_rules()
+            price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
+            price_config = {
+                'prompt': decimal.Decimal(price_config['prompt']),
+                'completion': decimal.Decimal(price_config['completion']),
+                'unit': decimal.Decimal(price_config['unit']),
+                'currency': price_config['currency']
+            }
+            return price_config
+        
+        self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
+
+        logger.debug(f"model: {self.name} price_config: {self._price_config}")
+        return self._price_config
+
     def run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
@@ -161,25 +197,48 @@ class BaseLLM(BaseProviderModel):
         """
         raise NotImplementedError
 
-    @abstractmethod
-    def get_token_price(self, tokens: int, message_type: MessageType):
+    def calc_tokens_price(self, tokens:int, message_type: MessageType):
         """
-        get token price.
+        calc tokens total price.
 
         :param tokens:
         :param message_type:
         :return:
         """
-        raise NotImplementedError
+        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
+            unit_price = self.price_config['prompt']
+        else:
+            unit_price = self.price_config['completion']
+        unit = self.price_config['unit']
+
+        total_price = tokens * unit_price * unit
+        total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
+        logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
+        return total_price
+
+    def get_tokens_unit_price(self, message_type: MessageType):
+        """
+        get token price.
+
+        :param message_type:
+        :return: decimal.Decimal('0.0001')
+        """
+        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
+            unit_price = self.price_config['prompt']
+        else:
+            unit_price = self.price_config['completion']
+        unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
+        logging.debug(f"unit_price={unit_price}")
+        return unit_price
 
-    @abstractmethod
     def get_currency(self):
         """
         get token currency.
 
-        :return:
+        :return: get from price config, default 'USD'
         """
-        raise NotImplementedError
+        currency = self.price_config['currency']
+        return currency
 
     def get_model_kwargs(self):
         return self.model_kwargs
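For context, this replaces the per-model price tables that each LLM class used to carry: calc_tokens_price picks the 'prompt' or 'completion' rate from price_config by message type and multiplies by the unit. A minimal sketch of the same arithmetic, not part of the commit, using the gpt-3.5-turbo values from rules/openai.json:

    import decimal

    # gpt-3.5-turbo: prompt "0.0015", completion "0.002", unit "0.001" (per 1K tokens)
    def tokens_price(tokens: int, unit_price: decimal.Decimal, unit: decimal.Decimal) -> decimal.Decimal:
        total = tokens * unit_price * unit
        return total.quantize(decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP)

    prompt_cost = tokens_price(1200, decimal.Decimal("0.0015"), decimal.Decimal("0.001"))
    answer_cost = tokens_price(300, decimal.Decimal("0.002"), decimal.Decimal("0.001"))
    print(prompt_cost + answer_cost)  # 0.0024000 (USD)

This is the same split that save_message and on_agent_end in conversation_message_task.py now rely on: total price = calc_tokens_price(prompt tokens, HUMAN) + calc_tokens_price(completion tokens, ASSISTANT).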

+ 0 - 3
api/core/model_providers/models/llm/chatglm_model.py

@@ -47,9 +47,6 @@ class ChatGLMModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)
 
-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')
-
     def get_currency(self):
         return 'RMB'
 

+ 0 - 7
api/core/model_providers/models/llm/huggingface_hub_model.py

@@ -62,13 +62,6 @@ class HuggingfaceHubModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return self._client.get_num_tokens(prompts)
 
-    def get_token_price(self, tokens: int, message_type: MessageType):
-        # not support calc price
-        return decimal.Decimal('0')
-
-    def get_currency(self):
-        return 'USD'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         self.client.model_kwargs = provider_model_kwargs

+ 0 - 3
api/core/model_providers/models/llm/minimax_model.py

@@ -51,9 +51,6 @@ class MinimaxModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)
 
-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')
-
     def get_currency(self):
         return 'RMB'
 

+ 2 - 39
api/core/model_providers/models/llm/openai_model.py

@@ -46,7 +46,8 @@ class OpenAIModel(BaseLLM):
             self.model_mode = ModelMode.COMPLETION
         else:
             self.model_mode = ModelMode.CHAT
-
+        
+        # TODO load price config from configs(db)
         super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
 
     def _init_client(self) -> Any:
@@ -117,44 +118,6 @@ class OpenAIModel(BaseLLM):
         else:
             return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
 
-    def get_token_price(self, tokens: int, message_type: MessageType):
-        model_unit_prices = {
-            'gpt-4': {
-                'prompt': decimal.Decimal('0.03'),
-                'completion': decimal.Decimal('0.06'),
-            },
-            'gpt-4-32k': {
-                'prompt': decimal.Decimal('0.06'),
-                'completion': decimal.Decimal('0.12')
-            },
-            'gpt-3.5-turbo': {
-                'prompt': decimal.Decimal('0.0015'),
-                'completion': decimal.Decimal('0.002')
-            },
-            'gpt-3.5-turbo-16k': {
-                'prompt': decimal.Decimal('0.003'),
-                'completion': decimal.Decimal('0.004')
-            },
-            'text-davinci-003': {
-                'prompt': decimal.Decimal('0.02'),
-                'completion': decimal.Decimal('0.02')
-            },
-        }
-
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
-            unit_price = model_unit_prices[self.name]['prompt']
-        else:
-            unit_price = model_unit_prices[self.name]['completion']
-
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-
-        total_price = tokens_per_1k * unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         if self.name in COMPLETION_MODELS:

+ 0 - 7
api/core/model_providers/models/llm/replicate_model.py

@@ -81,13 +81,6 @@ class ReplicateModel(BaseLLM):
 
         return self._client.get_num_tokens(prompts)
 
-    def get_token_price(self, tokens: int, message_type: MessageType):
-        # replicate only pay for prediction seconds
-        return decimal.Decimal('0')
-
-    def get_currency(self):
-        return 'USD'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         self.client.input = provider_model_kwargs

+ 0 - 3
api/core/model_providers/models/llm/spark_model.py

@@ -50,9 +50,6 @@ class SparkModel(BaseLLM):
         contents = [message.content for message in messages]
         return max(self._client.get_num_tokens("".join(contents)), 0)
 
-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')
-
     def get_currency(self):
         return 'RMB'
 

+ 0 - 3
api/core/model_providers/models/llm/tongyi_model.py

@@ -53,9 +53,6 @@ class TongyiModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)
 
-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')
-
     def get_currency(self):
         return 'RMB'
 

+ 1 - 30
api/core/model_providers/models/llm/wenxin_model.py

@@ -16,6 +16,7 @@ class WenxinModel(BaseLLM):
 
     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
+        # TODO load price_config from configs(db)
         return Wenxin(
             streaming=self.streaming,
             callbacks=self.callbacks,
@@ -48,36 +49,6 @@ class WenxinModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)
 
-    def get_token_price(self, tokens: int, message_type: MessageType):
-        model_unit_prices = {
-            'ernie-bot': {
-                'prompt': decimal.Decimal('0.012'),
-                'completion': decimal.Decimal('0.012'),
-            },
-            'ernie-bot-turbo': {
-                'prompt': decimal.Decimal('0.008'),
-                'completion': decimal.Decimal('0.008')
-            },
-            'bloomz-7b': {
-                'prompt': decimal.Decimal('0.006'),
-                'completion': decimal.Decimal('0.006')
-            }
-        }
-
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
-            unit_price = model_unit_prices[self.name]['prompt']
-        else:
-            unit_price = model_unit_prices[self.name]['completion']
-
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-
-        total_price = tokens_per_1k * unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'RMB'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         for k, v in provider_model_kwargs.items():

+ 15 - 1
api/core/model_providers/rules/anthropic.json

@@ -11,5 +11,19 @@
         "quota_unit": "tokens",
         "quota_limit": 600000
     },
-    "model_flexibility": "fixed"
+    "model_flexibility": "fixed",
+    "price_config": {
+        "claude-instant-1": {
+            "prompt": "1.63",
+            "completion": "5.51",
+            "unit": "0.000001",
+            "currency": "USD"
+        },
+        "claude-2": {
+            "prompt": "11.02",
+            "completion": "32.68",
+            "unit": "0.000001",
+            "currency": "USD"
+        }
+    }
 }
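Anthropic keeps its per-million-token quoting: with unit 0.000001, a prompt price of 11.02 for claude-2 still means 11.02 USD per 1M tokens. A quick sanity check, not part of the commit:

    import decimal

    # claude-2 prompt tokens priced from rules/anthropic.json
    cost = (10_000 * decimal.Decimal("11.02") * decimal.Decimal("0.000001")).quantize(
        decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP)
    print(cost)  # 0.1102000 (USD) for 10K prompt tokens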

+ 44 - 1
api/core/model_providers/rules/azure_openai.json

@@ -3,5 +3,48 @@
         "custom"
     ],
     "system_config": null,
-    "model_flexibility": "configurable"
+    "model_flexibility": "configurable",
+    "price_config":{
+        "gpt-4": {
+            "prompt": "0.03",
+            "completion": "0.06",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "gpt-4-32k": {
+            "prompt": "0.06",
+            "completion": "0.12",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "gpt-35-turbo": {
+            "prompt": "0.0015",
+            "completion": "0.002",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "gpt-35-turbo-16k": {
+            "prompt": "0.003",
+            "completion": "0.004",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "text-davinci-002": {
+            "prompt": "0.02",
+            "completion": "0.02",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "text-davinci-003": {
+            "prompt": "0.02",
+            "completion": "0.02",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "text-embedding-ada-002":{
+            "completion": "0.0001",
+            "unit": "0.001",
+            "currency": "USD"
+        }
+    }
 }
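Azure deployments have user-chosen names, so AzureOpenAIModel and AzureOpenAIEmbedding override base_model_name to read it from the credentials, and that value is the key looked up in this price_config. A sketch of the lookup, not part of the commit; the credentials dict here is illustrative only:

    import decimal

    credentials = {"base_model_name": "gpt-35-turbo"}  # illustrative deployment credentials
    price_config = {  # excerpt of rules/azure_openai.json
        "gpt-35-turbo": {"prompt": "0.0015", "completion": "0.002", "unit": "0.001", "currency": "USD"}
    }

    entry = price_config[credentials["base_model_name"]]
    unit_price = decimal.Decimal(entry["prompt"]).quantize(decimal.Decimal("0.0001"))
    print(unit_price, entry["currency"])  # 0.0015 USD per 1K prompt tokens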

+ 38 - 1
api/core/model_providers/rules/openai.json

@@ -10,5 +10,42 @@
         "quota_unit": "times",
         "quota_limit": 200
     },
-    "model_flexibility": "fixed"
+    "model_flexibility": "fixed",
+    "price_config": {
+        "gpt-4": {
+            "prompt": "0.03",
+            "completion": "0.06",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "gpt-4-32k": {
+            "prompt": "0.06",
+            "completion": "0.12",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "gpt-3.5-turbo": {
+            "prompt": "0.0015",
+            "completion": "0.002",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "gpt-3.5-turbo-16k": {
+            "prompt": "0.003",
+            "completion": "0.004",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "text-davinci-003": {
+            "prompt": "0.02",
+            "completion": "0.02",
+            "unit": "0.001",
+            "currency": "USD"
+        },
+        "text-embedding-ada-002":{
+            "completion": "0.0001",
+            "unit": "0.001",
+            "currency": "USD"
+        }
+    }
 }
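The unit field sets the quoting granularity: OpenAI prices here are per 1K tokens (unit 0.001), so the effective per-token rate is prompt * unit. A small check, not part of the commit:

    import decimal

    # gpt-4 from rules/openai.json: prompt 0.03 USD per 1K tokens
    per_token = decimal.Decimal("0.03") * decimal.Decimal("0.001")
    print(per_token)         # 0.00003 USD per prompt token
    print(1000 * per_token)  # 0.03000 USD per 1K prompt tokens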

+ 21 - 1
api/core/model_providers/rules/wenxin.json

@@ -3,5 +3,25 @@
         "custom"
     ],
     "system_config": null,
-    "model_flexibility": "fixed"
+    "model_flexibility": "fixed",
+    "price_config": {
+        "ernie-bot": {
+            "prompt": "0.012",
+            "completion": "0.012",
+            "unit": "0.001",
+            "currency": "RMB"
+        },
+        "ernie-bot-turbo": {
+            "prompt": "0.008",
+            "completion": "0.008",
+            "unit": "0.001",
+            "currency": "RMB"
+        },
+        "bloomz-7b": {
+            "prompt": "0.006",
+            "completion": "0.006",
+            "unit": "0.001",
+            "currency": "RMB"
+        }
+    }
 }
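With get_token_price and get_currency removed from WenxinModel, both the rate and the RMB currency now come from this file. A minimal check, not part of the commit:

    import decimal

    # ernie-bot from rules/wenxin.json: 0.012 RMB per 1K tokens
    cost = (2000 * decimal.Decimal("0.012") * decimal.Decimal("0.001")).quantize(
        decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP)
    print(cost, "RMB")  # 0.0240000 RMB for 2K tokens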