Browse Source

feat: record price unit in messages (#919)

takatost 1 year ago
parent
commit
0a0d63457d

+ 9 - 0
api/core/callback_handler/agent_loop_gather_callback_handler.py

@@ -10,6 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
 
 from core.callback_handler.entity.agent_loop import AgentLoop
 from core.conversation_message_task import ConversationMessageTask
+from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.llm.base import BaseLLM
 
 
@@ -68,6 +69,10 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             self._current_loop.status = 'llm_end'
             if response.llm_output:
                 self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
+            else:
+                self._current_loop.prompt_tokens = self.model_instant.get_num_tokens(
+                    [PromptMessage(content=self._current_loop.prompt)]
+                )
             completion_generation = response.generations[0][0]
             if isinstance(completion_generation, ChatGeneration):
                 completion_message = completion_generation.message
@@ -81,6 +86,10 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
 
             if response.llm_output:
                 self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
+            else:
+                self._current_loop.completion_tokens = self.model_instant.get_num_tokens(
+                    [PromptMessage(content=self._current_loop.completion)]
+                )
 
     def on_llm_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any

+ 12 - 0
api/core/conversation_message_task.py

@@ -119,9 +119,11 @@ class ConversationMessageTask:
             message="",
             message_tokens=0,
             message_unit_price=0,
+            message_price_unit=0,
             answer="",
             answer_tokens=0,
             answer_unit_price=0,
+            answer_price_unit=0,
             provider_response_latency=0,
             total_price=0,
             currency=self.model_instance.get_currency(),
@@ -142,7 +144,9 @@ class ConversationMessageTask:
         answer_tokens = llm_message.completion_tokens
 
         message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
+        message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN)
         answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
+        answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)
 
         message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
         answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
@@ -151,9 +155,11 @@ class ConversationMessageTask:
         self.message.message = llm_message.prompt
         self.message.message_tokens = message_tokens
         self.message.message_unit_price = message_unit_price
+        self.message.message_price_unit = message_price_unit
         self.message.answer = PromptBuilder.process_template(llm_message.completion.strip()) if llm_message.completion else ''
         self.message.answer_tokens = answer_tokens
         self.message.answer_unit_price = answer_unit_price
+        self.message.answer_price_unit = answer_price_unit
         self.message.provider_response_latency = llm_message.latency
         self.message.total_price = total_price
 
@@ -195,7 +201,9 @@ class ConversationMessageTask:
             tool=agent_loop.tool_name,
             tool_input=agent_loop.tool_input,
             message=agent_loop.prompt,
+            message_price_unit=0,
             answer=agent_loop.completion,
+            answer_price_unit=0,
             created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
             created_by=self.user.id
         )
@@ -210,7 +218,9 @@ class ConversationMessageTask:
     def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
                      agent_loop: AgentLoop):
         agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
+        agent_message_price_unit = agent_model_instant.get_price_unit(MessageType.HUMAN)
         agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)
+        agent_answer_price_unit = agent_model_instant.get_price_unit(MessageType.ASSISTANT)
 
         loop_message_tokens = agent_loop.prompt_tokens
         loop_answer_tokens = agent_loop.completion_tokens
@@ -223,8 +233,10 @@ class ConversationMessageTask:
         message_agent_thought.tool_process_data = ''  # currently not support
         message_agent_thought.message_token = loop_message_tokens
         message_agent_thought.message_unit_price = agent_message_unit_price
+        message_agent_thought.message_price_unit = agent_message_price_unit
         message_agent_thought.answer_token = loop_answer_tokens
         message_agent_thought.answer_unit_price = agent_answer_unit_price
+        message_agent_thought.answer_price_unit = agent_answer_price_unit
         message_agent_thought.latency = agent_loop.latency
         message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
         message_agent_thought.total_price = loop_total_price

+ 20 - 4
api/core/model_providers/models/llm/base.py

@@ -197,7 +197,7 @@ class BaseLLM(BaseProviderModel):
         """
         raise NotImplementedError
 
-    def calc_tokens_price(self, tokens:int, message_type: MessageType):
+    def calc_tokens_price(self, tokens: int, message_type: MessageType) -> decimal.Decimal:
         """
         calc tokens total price.
 
@@ -209,14 +209,14 @@ class BaseLLM(BaseProviderModel):
             unit_price = self.price_config['prompt']
         else:
             unit_price = self.price_config['completion']
-        unit = self.price_config['unit']
+        unit = self.get_price_unit(message_type)
 
         total_price = tokens * unit_price * unit
         total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
         logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
         return total_price
 
-    def get_tokens_unit_price(self, message_type: MessageType):
+    def get_tokens_unit_price(self, message_type: MessageType) -> decimal.Decimal:
         """
         get token price.
 
@@ -231,7 +231,23 @@ class BaseLLM(BaseProviderModel):
         logging.debug(f"unit_price={unit_price}")
         return unit_price
 
-    def get_currency(self):
+    def get_price_unit(self, message_type: MessageType) -> decimal.Decimal:
+        """
+        get price unit.
+
+        :param message_type:
+        :return: decimal.Decimal('0.000001')
+        """
+        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
+            price_unit = self.price_config['unit']
+        else:
+            price_unit = self.price_config['unit']
+
+        price_unit = price_unit.quantize(decimal.Decimal('0.000001'), rounding=decimal.ROUND_HALF_UP)
+        logging.debug(f"price_unit={price_unit}")
+        return price_unit
+
+    def get_currency(self) -> str:
         """
         get token currency.
 

+ 43 - 0
api/migrations/versions/853f9b9cd3b6_add_message_price_unit.py

@@ -0,0 +1,43 @@
+"""add message price unit
+
+Revision ID: 853f9b9cd3b6
+Revises: e8883b0148c9
+Create Date: 2023-08-19 17:01:57.471562
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '853f9b9cd3b6'
+down_revision = 'e8883b0148c9'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+
+    with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('message_price_unit', sa.Numeric(precision=10, scale=7), server_default=sa.text('0.001'), nullable=False))
+        batch_op.add_column(sa.Column('answer_price_unit', sa.Numeric(precision=10, scale=7), server_default=sa.text('0.001'), nullable=False))
+
+    with op.batch_alter_table('messages', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('message_price_unit', sa.Numeric(precision=10, scale=7), server_default=sa.text('0.001'), nullable=False))
+        batch_op.add_column(sa.Column('answer_price_unit', sa.Numeric(precision=10, scale=7), server_default=sa.text('0.001'), nullable=False))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('messages', schema=None) as batch_op:
+        batch_op.drop_column('answer_price_unit')
+        batch_op.drop_column('message_price_unit')
+
+    with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
+        batch_op.drop_column('answer_price_unit')
+        batch_op.drop_column('message_price_unit')
+
+    # ### end Alembic commands ###

+ 4 - 0
api/models/model.py

@@ -421,9 +421,11 @@ class Message(db.Model):
     message = db.Column(db.JSON, nullable=False)
     message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
     message_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
+    message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001'))
     answer = db.Column(db.Text, nullable=False)
     answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
     answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
+    answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001'))
     provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text('0'))
     total_price = db.Column(db.Numeric(10, 7))
     currency = db.Column(db.String(255), nullable=False)
@@ -705,9 +707,11 @@ class MessageAgentThought(db.Model):
     message = db.Column(db.Text, nullable=True)
     message_token = db.Column(db.Integer, nullable=True)
     message_unit_price = db.Column(db.Numeric, nullable=True)
+    message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001'))
     answer = db.Column(db.Text, nullable=True)
     answer_token = db.Column(db.Integer, nullable=True)
     answer_unit_price = db.Column(db.Numeric, nullable=True)
+    answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001'))
     tokens = db.Column(db.Integer, nullable=True)
     total_price = db.Column(db.Numeric, nullable=True)
     currency = db.Column(db.String, nullable=True)