Kaynağa Gözat

Fix/incorrect parameter extractor memory (#6038)

Yeuoly 9 ay önce
ebeveyn
işleme
a877d4831d

+ 1 - 1
api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py

@@ -365,7 +365,7 @@ class ParameterExtractorNode(LLMNode):
             files=[],
             context='',
             memory_config=node_data.memory,
-            memory=memory,
+            memory=None,
             model_config=model_config
         )
 

+ 92 - 1
api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py

@@ -1,5 +1,6 @@
 import json
 import os
+from typing import Optional
 from unittest.mock import MagicMock
 
 import pytest
@@ -7,6 +8,7 @@ import pytest
 from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
 from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
 from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
+from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
@@ -61,6 +63,16 @@ def get_mocked_fetch_model_config(
 
     return MagicMock(return_value=(model_instance, model_config))
 
+def get_mocked_fetch_memory(memory_text: str):
+    class MemoryMock:
+        def get_history_prompt_text(self, human_prefix: str = "Human",
+                                ai_prefix: str = "Assistant",
+                                max_token_limit: int = 2000,
+                                message_limit: Optional[int] = None):
+            return memory_text
+
+    return MagicMock(return_value=MemoryMock())
+
 @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
 def test_function_calling_parameter_extractor(setup_openai_mock):
     """
@@ -354,4 +366,83 @@ def test_extract_json_response():
         hello world.                          
     """)
 
-    assert result['location'] == 'kawaii'
+    assert result['location'] == 'kawaii'
+
+@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
+def test_chat_parameter_extractor_with_memory(setup_anthropic_mock):
+    """
+    Test chat parameter extractor with memory.
+    """
+    node = ParameterExtractorNode(
+        tenant_id='1',
+        app_id='1',
+        workflow_id='1',
+        user_id='1',
+        invoke_from=InvokeFrom.WEB_APP,
+        user_from=UserFrom.ACCOUNT,
+        config={
+            'id': 'llm',
+            'data': {
+                'title': '123',
+                'type': 'parameter-extractor',
+                'model': {
+                    'provider': 'anthropic',
+                    'name': 'claude-2',
+                    'mode': 'chat',
+                    'completion_params': {}
+                },
+                'query': ['sys', 'query'],
+                'parameters': [{
+                    'name': 'location',
+                    'type': 'string',
+                    'description': 'location',
+                    'required': True
+                }],
+                'reasoning_mode': 'prompt',
+                'instruction': '',
+                'memory': {
+                    'window': {
+                        'enabled': True,
+                        'size': 50
+                    }
+                },
+            }
+        }
+    )
+
+    node._fetch_model_config = get_mocked_fetch_model_config(
+        provider='anthropic', model='claude-2', mode='chat', credentials={
+            'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
+        }
+    )
+    node._fetch_memory = get_mocked_fetch_memory('customized memory')
+    db.session.close = MagicMock()
+
+    # construct variable pool
+    pool = VariablePool(system_variables={
+        SystemVariable.QUERY: 'what\'s the weather in SF',
+        SystemVariable.FILES: [],
+        SystemVariable.CONVERSATION_ID: 'abababa',
+        SystemVariable.USER_ID: 'aaa'
+    }, user_inputs={})
+
+    result = node.run(pool)
+
+    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+    assert result.outputs.get('location') == ''
+    assert result.outputs.get('__reason') == 'Failed to extract result from function call or text response, using empty result.'
+    prompts = result.process_data.get('prompts')
+
+    latest_role = None
+    for prompt in prompts:
+        if prompt.get('role') == 'user':
+            if '<structure>' in prompt.get('text'):
+                assert '<structure>\n{"type": "object"' in prompt.get('text')
+        elif prompt.get('role') == 'system':
+            assert 'customized memory' in prompt.get('text')
+
+        if latest_role is not None:
+            assert latest_role != prompt.get('role')
+
+        if prompt.get('role') in ['user', 'assistant']:
+            latest_role = prompt.get('role')