|
@@ -1,5 +1,6 @@
|
|
|
import json
|
|
|
import os
|
|
|
+from typing import Optional
|
|
|
from unittest.mock import MagicMock
|
|
|
|
|
|
import pytest
|
|
@@ -7,6 +8,7 @@ import pytest
|
|
|
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
|
|
|
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
|
|
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
|
|
|
+from core.memory.token_buffer_memory import TokenBufferMemory
|
|
|
from core.model_manager import ModelInstance
|
|
|
from core.model_runtime.entities.model_entities import ModelType
|
|
|
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
|
@@ -61,6 +63,16 @@ def get_mocked_fetch_model_config(
|
|
|
|
|
|
return MagicMock(return_value=(model_instance, model_config))
|
|
|
|
|
|
+def get_mocked_fetch_memory(memory_text: str):
|
|
|
+ class MemoryMock:
|
|
|
+ def get_history_prompt_text(self, human_prefix: str = "Human",
|
|
|
+ ai_prefix: str = "Assistant",
|
|
|
+ max_token_limit: int = 2000,
|
|
|
+ message_limit: Optional[int] = None):
|
|
|
+ return memory_text
|
|
|
+
|
|
|
+ return MagicMock(return_value=MemoryMock())
|
|
|
+
|
|
|
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
|
|
|
def test_function_calling_parameter_extractor(setup_openai_mock):
|
|
|
"""
|
|
@@ -354,4 +366,83 @@ def test_extract_json_response():
|
|
|
hello world.
|
|
|
""")
|
|
|
|
|
|
- assert result['location'] == 'kawaii'
|
|
|
+ assert result['location'] == 'kawaii'
|
|
|
+
|
|
|
+@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
|
|
|
+def test_chat_parameter_extractor_with_memory(setup_anthropic_mock):
|
|
|
+ """
|
|
|
+ Test chat parameter extractor with memory.
|
|
|
+ """
|
|
|
+ node = ParameterExtractorNode(
|
|
|
+ tenant_id='1',
|
|
|
+ app_id='1',
|
|
|
+ workflow_id='1',
|
|
|
+ user_id='1',
|
|
|
+ invoke_from=InvokeFrom.WEB_APP,
|
|
|
+ user_from=UserFrom.ACCOUNT,
|
|
|
+ config={
|
|
|
+ 'id': 'llm',
|
|
|
+ 'data': {
|
|
|
+ 'title': '123',
|
|
|
+ 'type': 'parameter-extractor',
|
|
|
+ 'model': {
|
|
|
+ 'provider': 'anthropic',
|
|
|
+ 'name': 'claude-2',
|
|
|
+ 'mode': 'chat',
|
|
|
+ 'completion_params': {}
|
|
|
+ },
|
|
|
+ 'query': ['sys', 'query'],
|
|
|
+ 'parameters': [{
|
|
|
+ 'name': 'location',
|
|
|
+ 'type': 'string',
|
|
|
+ 'description': 'location',
|
|
|
+ 'required': True
|
|
|
+ }],
|
|
|
+ 'reasoning_mode': 'prompt',
|
|
|
+ 'instruction': '',
|
|
|
+ 'memory': {
|
|
|
+ 'window': {
|
|
|
+ 'enabled': True,
|
|
|
+ 'size': 50
|
|
|
+ }
|
|
|
+ },
|
|
|
+ }
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ node._fetch_model_config = get_mocked_fetch_model_config(
|
|
|
+ provider='anthropic', model='claude-2', mode='chat', credentials={
|
|
|
+ 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
|
|
|
+ }
|
|
|
+ )
|
|
|
+ node._fetch_memory = get_mocked_fetch_memory('customized memory')
|
|
|
+ db.session.close = MagicMock()
|
|
|
+
|
|
|
+ # construct variable pool
|
|
|
+ pool = VariablePool(system_variables={
|
|
|
+ SystemVariable.QUERY: 'what\'s the weather in SF',
|
|
|
+ SystemVariable.FILES: [],
|
|
|
+ SystemVariable.CONVERSATION_ID: 'abababa',
|
|
|
+ SystemVariable.USER_ID: 'aaa'
|
|
|
+ }, user_inputs={})
|
|
|
+
|
|
|
+ result = node.run(pool)
|
|
|
+
|
|
|
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
|
|
+ assert result.outputs.get('location') == ''
|
|
|
+ assert result.outputs.get('__reason') == 'Failed to extract result from function call or text response, using empty result.'
|
|
|
+ prompts = result.process_data.get('prompts')
|
|
|
+
|
|
|
+ latest_role = None
|
|
|
+ for prompt in prompts:
|
|
|
+ if prompt.get('role') == 'user':
|
|
|
+ if '<structure>' in prompt.get('text'):
|
|
|
+ assert '<structure>\n{"type": "object"' in prompt.get('text')
|
|
|
+ elif prompt.get('role') == 'system':
|
|
|
+ assert 'customized memory' in prompt.get('text')
|
|
|
+
|
|
|
+ if latest_role is not None:
|
|
|
+ assert latest_role != prompt.get('role')
|
|
|
+
|
|
|
+ if prompt.get('role') in ['user', 'assistant']:
|
|
|
+ latest_role = prompt.get('role')
|