Ver Fonte

feat: add tool labels (#2178)

Yeuoly há 1 ano atrás
pai
commit
7cb75cb2e7

+ 1 - 0
api/controllers/service_api/app/message.py

@@ -44,6 +44,7 @@ class MessageListApi(AppApiResource):
         'position': fields.Integer,
         'thought': fields.String,
         'tool': fields.String,
+        'tool_labels': fields.Raw,
         'tool_input': fields.String,
         'created_at': TimestampField,
         'observation': fields.String,

+ 3 - 1
api/core/app_runner/generate_task_pipeline.py

@@ -18,6 +18,7 @@ from core.model_runtime.entities.message_entities import (AssistantPromptMessage
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.tools.tool_file_manager import ToolFileManager
+from core.tools.tool_manager import ToolManager
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.prompt_template import PromptTemplateParser
 from events.message_event import message_was_created
@@ -281,7 +282,7 @@ class GenerateTaskPipeline:
 
                     self._task_state.llm_result.message.content = annotation.content
             elif isinstance(event, QueueAgentThoughtEvent):
-                agent_thought = (
+                agent_thought: MessageAgentThought = (
                     db.session.query(MessageAgentThought)
                     .filter(MessageAgentThought.id == event.agent_thought_id)
                     .first()
@@ -298,6 +299,7 @@ class GenerateTaskPipeline:
                         'thought': agent_thought.thought,
                         'observation': agent_thought.observation,
                         'tool': agent_thought.tool,
+                        'tool_labels': agent_thought.tool_labels,
                         'tool_input': agent_thought.tool_input,
                         'created_at': int(self._message.created_at.timestamp()),
                         'message_files': agent_thought.files

+ 16 - 0
api/core/features/assistant_base_runner.py

@@ -396,6 +396,7 @@ class BaseAssistantApplicationRunner(AppRunner):
             message_chain_id=None,
             thought='',
             tool=tool_name,
+            tool_labels_str='{}',
             tool_input=tool_input,
             message=message,
             message_token=0,
@@ -469,6 +470,21 @@ class BaseAssistantApplicationRunner(AppRunner):
             agent_thought.tokens = llm_usage.total_tokens
             agent_thought.total_price = llm_usage.total_price
 
+        # check if tool labels is not empty
+        labels = agent_thought.tool_labels or {}
+        tools = agent_thought.tool.split(';') if agent_thought.tool else []
+        for tool in tools:
+            if not tool:
+                continue
+            if tool not in labels:
+                tool_label = ToolManager.get_tool_label(tool)
+                if tool_label:
+                    labels[tool] = tool_label.to_dict()
+                else:
+                    labels[tool] = {'en_US': tool, 'zh_Hans': tool}
+
+        agent_thought.tool_labels_str = json.dumps(labels)
+
         db.session.commit()
 
     def get_history_prompt_messages(self) -> List[PromptMessage]:

+ 24 - 1
api/core/tools/tool_manager.py

@@ -31,6 +31,7 @@ import mimetypes
 logger = logging.getLogger(__name__)
 
 _builtin_providers = {}
+_builtin_tools_labels = {}
 
 class ToolManager:
     @staticmethod
@@ -233,7 +234,7 @@ class ToolManager:
         if len(_builtin_providers) > 0:
             return list(_builtin_providers.values())
         
-        builtin_providers = []
+        builtin_providers: List[BuiltinToolProviderController] = []
         for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
             if provider.startswith('__'):
                 continue
@@ -264,8 +265,30 @@ class ToolManager:
         # cache the builtin providers
         for provider in builtin_providers:
             _builtin_providers[provider.identity.name] = provider
+            for tool in provider.get_tools():
+                _builtin_tools_labels[tool.identity.name] = tool.identity.label
+
         return builtin_providers
     
+    @staticmethod
+    def get_tool_label(tool_name: str) -> Union[I18nObject, None]:
+        """
+            get the tool label
+
+            :param tool_name: the name of the tool
+
+            :return: the label of the tool
+        """
+        global _builtin_tools_labels
+        if len(_builtin_tools_labels) == 0:
+            # init the builtin providers
+            ToolManager.list_builtin_providers()
+
+        if tool_name not in _builtin_tools_labels:
+            return None
+        
+        return _builtin_tools_labels[tool_name]
+    
     @staticmethod
     def user_list_providers(
         user_id: str,

+ 2 - 1
api/fields/conversation_fields.py

@@ -49,10 +49,11 @@ agent_thought_fields = {
     'position': fields.Integer,
     'thought': fields.String,
     'tool': fields.String,
+    'tool_labels': fields.Raw,
     'tool_input': fields.String,
     'created_at': TimestampField,
     'observation': fields.String,
-    'files': fields.List(fields.String)
+    'files': fields.List(fields.String),
 }
 
 message_detail_fields = {

+ 1 - 0
api/fields/message_fields.py

@@ -36,6 +36,7 @@ agent_thought_fields = {
     'position': fields.Integer,
     'thought': fields.String,
     'tool': fields.String,
+    'tool_labels': fields.Raw,
     'tool_input': fields.String,
     'created_at': TimestampField,
     'observation': fields.String,

+ 32 - 0
api/migrations/versions/380c6aa5a70d_add_tool_labels_to_agent_thought.py

@@ -0,0 +1,32 @@
+"""add tool labels to agent thought
+
+Revision ID: 380c6aa5a70d
+Revises: dfb3b7f477da
+Create Date: 2024-01-24 10:58:15.644445
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '380c6aa5a70d'
+down_revision = 'dfb3b7f477da'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('tool_labels_str', sa.Text(), server_default=sa.text("'{}'::text"), nullable=False))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
+        batch_op.drop_column('tool_labels_str')
+
+    # ### end Alembic commands ###

+ 11 - 0
api/models/model.py

@@ -1003,6 +1003,7 @@ class MessageAgentThought(db.Model):
     position = db.Column(db.Integer, nullable=False)
     thought = db.Column(db.Text, nullable=True)
     tool = db.Column(db.Text, nullable=True)
+    tool_labels_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
     tool_input = db.Column(db.Text, nullable=True)
     observation = db.Column(db.Text, nullable=True)
     # plugin_id = db.Column(UUID, nullable=True)  ## for future design
@@ -1030,6 +1031,16 @@ class MessageAgentThought(db.Model):
             return json.loads(self.message_files)
         else:
             return []
+        
+    @property
+    def tool_labels(self) -> dict:
+        try:
+            if self.tool_labels_str:
+                return json.loads(self.tool_labels_str)
+            else:
+                return {}
+        except Exception as e:
+            return {}
 
 class DatasetRetrieverResource(db.Model):
     __tablename__ = 'dataset_retriever_resources'