Prechádzať zdrojové kódy

enhance: claude stream tool call (#4469)

Yeuoly pred 11 mesiacmi
rodič
commit
091fba74cb

+ 1 - 0
api/core/model_runtime/model_providers/anthropic/llm/claude-3-haiku-20240307.yaml

@@ -6,6 +6,7 @@ features:
   - agent-thought
   - vision
   - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 200000

+ 1 - 0
api/core/model_runtime/model_providers/anthropic/llm/claude-3-opus-20240229.yaml

@@ -6,6 +6,7 @@ features:
   - agent-thought
   - vision
   - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 200000

+ 1 - 0
api/core/model_runtime/model_providers/anthropic/llm/claude-3-sonnet-20240229.yaml

@@ -6,6 +6,7 @@ features:
   - agent-thought
   - vision
   - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 200000

+ 31 - 3
api/core/model_runtime/model_providers/anthropic/llm/llm.py

@@ -324,10 +324,32 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         output_tokens = 0
         finish_reason = None
         index = 0
+
+        tool_calls: list[AssistantPromptMessage.ToolCall] = []
+
         for chunk in response:
             if isinstance(chunk, MessageStartEvent):
-                return_model = chunk.message.model
-                input_tokens = chunk.message.usage.input_tokens
+                if hasattr(chunk, 'content_block'):
+                    content_block = chunk.content_block
+                    if isinstance(content_block, dict):
+                        if content_block.get('type') == 'tool_use':
+                            tool_call = AssistantPromptMessage.ToolCall(
+                                id=content_block.get('id'),
+                                type='function',
+                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                    name=content_block.get('name'),
+                                    arguments=''
+                                )
+                            )
+                            tool_calls.append(tool_call)
+                elif hasattr(chunk, 'delta'):
+                    delta = chunk.delta
+                    if isinstance(delta, dict) and len(tool_calls) > 0:
+                        if delta.get('type') == 'input_json_delta':
+                            tool_calls[-1].function.arguments += delta.get('partial_json', '')
+                elif chunk.message:
+                    return_model = chunk.message.model
+                    input_tokens = chunk.message.usage.input_tokens
             elif isinstance(chunk, MessageDeltaEvent):
                 output_tokens = chunk.usage.output_tokens
                 finish_reason = chunk.delta.stop_reason
@@ -335,13 +357,19 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                 # transform usage
                 usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
 
+                # transform empty tool call arguments to {}
+                for tool_call in tool_calls:
+                    if not tool_call.function.arguments:
+                        tool_call.function.arguments = '{}'
+
                 yield LLMResultChunk(
                     model=return_model,
                     prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
                         index=index + 1,
                         message=AssistantPromptMessage(
-                            content=''
+                            content='',
+                            tool_calls=tool_calls
                         ),
                         finish_reason=finish_reason,
                         usage=usage