Bläddra i källkod

feat: advanced prompt backend (#1301)

Co-authored-by: takatost <takatost@gmail.com>
Garfield Dai 1 år sedan
förälder
incheckning
42a5b3ec17
61 ändrade filer med 762 tillägg och 576 borttagningar
  1. 22 2
      api/constants/model_template.py
  2. 1 1
      api/controllers/console/__init__.py
  3. 26 0
      api/controllers/console/app/advanced_prompt_template.py
  4. 0 30
      api/controllers/console/app/generator.py
  5. 1 1
      api/controllers/console/app/message.py
  6. 1 1
      api/controllers/web/message.py
  7. 26 64
      api/core/completion.py
  8. 11 11
      api/core/conversation_message_task.py
  9. 25 83
      api/core/generator/llm_generator.py
  10. 2 2
      api/core/indexing_runner.py
  11. 1 1
      api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py
  12. 5 5
      api/core/model_providers/models/entity/message.py
  13. 102 17
      api/core/model_providers/models/llm/base.py
  14. 6 1
      api/core/model_providers/providers/anthropic_provider.py
  15. 16 1
      api/core/model_providers/providers/azure_openai_provider.py
  16. 5 1
      api/core/model_providers/providers/baichuan_provider.py
  17. 23 4
      api/core/model_providers/providers/base.py
  18. 6 1
      api/core/model_providers/providers/chatglm_provider.py
  19. 4 1
      api/core/model_providers/providers/huggingface_hub_provider.py
  20. 8 1
      api/core/model_providers/providers/localai_provider.py
  21. 6 1
      api/core/model_providers/providers/minimax_provider.py
  22. 14 2
      api/core/model_providers/providers/openai_provider.py
  23. 4 1
      api/core/model_providers/providers/openllm_provider.py
  24. 5 1
      api/core/model_providers/providers/replicate_provider.py
  25. 6 1
      api/core/model_providers/providers/spark_provider.py
  26. 6 1
      api/core/model_providers/providers/tongyi_provider.py
  27. 7 1
      api/core/model_providers/providers/wenxin_provider.py
  28. 4 1
      api/core/model_providers/providers/xinference_provider.py
  29. 8 1
      api/core/model_providers/providers/zhipuai_provider.py
  30. 37 42
      api/core/orchestrator_rule_parser.py
  31. 79 0
      api/core/prompt/advanced_prompt_templates.py
  32. 12 26
      api/core/prompt/prompt_builder.py
  33. 28 68
      api/core/prompt/prompt_template.py
  34. 2 32
      api/core/prompt/prompts.py
  35. 8 6
      api/core/tool/dataset_retriever_tool.py
  36. 0 1
      api/events/event_handlers/__init__.py
  37. 0 14
      api/events/event_handlers/generate_conversation_summary_when_few_message_created.py
  38. 4 0
      api/fields/app_fields.py
  39. 1 0
      api/fields/conversation_fields.py
  40. 37 0
      api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py
  41. 33 2
      api/models/model.py
  42. 56 0
      api/services/advanced_prompt_template_service.py
  43. 70 22
      api/services/app_model_config_service.py
  44. 17 49
      api/services/completion_service.py
  45. 3 0
      api/services/provider_service.py
  46. 0 55
      api/tasks/generate_conversation_summary_task.py
  47. 1 1
      api/tests/integration_tests/models/llm/test_anthropic_model.py
  48. 1 1
      api/tests/integration_tests/models/llm/test_azure_openai_model.py
  49. 3 3
      api/tests/integration_tests/models/llm/test_baichuan_model.py
  50. 2 2
      api/tests/integration_tests/models/llm/test_huggingface_hub_model.py
  51. 1 1
      api/tests/integration_tests/models/llm/test_minimax_model.py
  52. 1 1
      api/tests/integration_tests/models/llm/test_openai_model.py
  53. 1 1
      api/tests/integration_tests/models/llm/test_openllm_model.py
  54. 1 1
      api/tests/integration_tests/models/llm/test_replicate_model.py
  55. 1 1
      api/tests/integration_tests/models/llm/test_spark_model.py
  56. 1 1
      api/tests/integration_tests/models/llm/test_tongyi_model.py
  57. 1 1
      api/tests/integration_tests/models/llm/test_wenxin_model.py
  58. 1 1
      api/tests/integration_tests/models/llm/test_xinference_model.py
  59. 3 3
      api/tests/integration_tests/models/llm/test_zhipuai_model.py
  60. 5 2
      api/tests/unit_tests/model_providers/fake_model_provider.py
  61. 1 1
      api/tests/unit_tests/model_providers/test_base_model_provider.py

+ 22 - 2
api/constants/model_template.py

@@ -31,6 +31,7 @@ model_templates = {
             'model': json.dumps({
                 "provider": "openai",
                 "name": "gpt-3.5-turbo-instruct",
+                "mode": "completion",
                 "completion_params": {
                     "max_tokens": 512,
                     "temperature": 1,
@@ -81,6 +82,7 @@ model_templates = {
             'model': json.dumps({
                 "provider": "openai",
                 "name": "gpt-3.5-turbo",
+                "mode": "chat",
                 "completion_params": {
                     "max_tokens": 512,
                     "temperature": 1,
@@ -137,10 +139,11 @@ demo_model_templates = {
                 },
                 opening_statement='',
                 suggested_questions=None,
-                pre_prompt="Please translate the following text into {{target_language}}:\n",
+                pre_prompt="Please translate the following text into {{target_language}}:\n{{query}}\ntranslate:",
                 model=json.dumps({
                     "provider": "openai",
                     "name": "gpt-3.5-turbo-instruct",
+                    "mode": "completion",
                     "completion_params": {
                         "max_tokens": 1000,
                         "temperature": 0,
@@ -169,6 +172,13 @@ demo_model_templates = {
                                 'Italian',
                             ]
                         }
+                    },{
+                        "paragraph": {
+                            "label": "Query",
+                            "variable": "query",
+                            "required": True,
+                            "default": ""
+                        }
                     }
                 ])
             )
@@ -200,6 +210,7 @@ demo_model_templates = {
                 model=json.dumps({
                     "provider": "openai",
                     "name": "gpt-3.5-turbo",
+                    "mode": "chat",
                     "completion_params": {
                         "max_tokens": 300,
                         "temperature": 0.8,
@@ -255,10 +266,11 @@ demo_model_templates = {
                 },
                 opening_statement='',
                 suggested_questions=None,
-                pre_prompt="请将以下文本翻译为{{target_language}}:\n",
+                pre_prompt="请将以下文本翻译为{{target_language}}:\n{{query}}\n翻译:",
                 model=json.dumps({
                     "provider": "openai",
                     "name": "gpt-3.5-turbo-instruct",
+                    "mode": "completion",
                     "completion_params": {
                         "max_tokens": 1000,
                         "temperature": 0,
@@ -287,6 +299,13 @@ demo_model_templates = {
                                 "意大利语",
                             ]
                         }
+                    },{
+                        "paragraph": {
+                            "label": "文本内容",
+                            "variable": "query",
+                            "required": True,
+                            "default": ""
+                        }
                     }
                 ])
             )
@@ -318,6 +337,7 @@ demo_model_templates = {
                 model=json.dumps({
                     "provider": "openai",
                     "name": "gpt-3.5-turbo",
+                    "mode": "chat",
                     "completion_params": {
                         "max_tokens": 300,
                         "temperature": 0.8,

+ 1 - 1
api/controllers/console/__init__.py

@@ -9,7 +9,7 @@ api = ExternalApi(bp)
 from . import setup, version, apikey, admin
 
 # Import app controllers
-from .app import app, site, completion, model_config, statistic, conversation, message, generator, audio
+from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio
 
 # Import auth controllers
 from .auth import login, oauth, data_source_oauth, activate

+ 26 - 0
api/controllers/console/app/advanced_prompt_template.py

@@ -0,0 +1,26 @@
+from flask_restful import Resource, reqparse
+
+from controllers.console import api
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from libs.login import login_required
+from services.advanced_prompt_template_service import AdvancedPromptTemplateService
+
+class AdvancedPromptTemplateList(Resource):
+    
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+         
+        parser = reqparse.RequestParser()
+        parser.add_argument('app_mode', type=str, required=True, location='args')
+        parser.add_argument('model_mode', type=str, required=True, location='args')
+        parser.add_argument('has_context', type=str, required=False, default='true', location='args')
+        parser.add_argument('model_name', type=str, required=True, location='args')
+        args = parser.parse_args()
+
+        service = AdvancedPromptTemplateService()
+        return service.get_prompt(args)
+
+api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates')

+ 0 - 30
api/controllers/console/app/generator.py

@@ -12,35 +12,6 @@ from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededE
     LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
 
 
-class IntroductionGenerateApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def post(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument('prompt_template', type=str, required=True, location='json')
-        args = parser.parse_args()
-
-        account = current_user
-
-        try:
-            answer = LLMGenerator.generate_introduction(
-                account.current_tenant_id,
-                args['prompt_template']
-            )
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
-        except QuotaExceededError:
-            raise ProviderQuotaExceededError()
-        except ModelCurrentlyNotSupportError:
-            raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
-            raise CompletionRequestError(str(e))
-
-        return {'introduction': answer}
-
-
 class RuleGenerateApi(Resource):
     @setup_required
     @login_required
@@ -72,5 +43,4 @@ class RuleGenerateApi(Resource):
         return rules
 
 
-api.add_resource(IntroductionGenerateApi, '/introduction-generate')
 api.add_resource(RuleGenerateApi, '/rule-generate')

+ 1 - 1
api/controllers/console/app/message.py

@@ -329,7 +329,7 @@ class MessageApi(Resource):
         message_id = str(message_id)
 
         # get app info
-        app_model = _get_app(app_id, 'chat')
+        app_model = _get_app(app_id)
 
         message = db.session.query(Message).filter(
             Message.id == message_id,

+ 1 - 1
api/controllers/web/message.py

@@ -115,7 +115,7 @@ class MessageMoreLikeThisApi(WebApiResource):
         streaming = args['response_mode'] == 'streaming'
 
         try:
-            response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming)
+            response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming, 'web_app')
             return compact_response(response)
         except MessageNotExistsError:
             raise NotFound("Message Not Exists.")

+ 26 - 64
api/core/completion.py

@@ -1,4 +1,3 @@
-import json
 import logging
 from typing import Optional, List, Union
 
@@ -16,10 +15,8 @@ from core.model_providers.model_factory import ModelFactory
 from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.llm.base import BaseLLM
 from core.orchestrator_rule_parser import OrchestratorRuleParser
-from core.prompt.prompt_builder import PromptBuilder
-from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
-from models.dataset import DocumentSegment, Dataset, Document
-from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
+from core.prompt.prompt_template import PromptTemplateParser
+from models.model import App, AppModelConfig, Account, Conversation, EndUser
 
 
 class Completion:
@@ -30,7 +27,7 @@ class Completion:
         """
         errors: ProviderTokenNotInitError
         """
-        query = PromptBuilder.process_template(query)
+        query = PromptTemplateParser.remove_template_variables(query)
 
         memory = None
         if conversation:
@@ -160,14 +157,28 @@ class Completion:
                       memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
                       fake_response: Optional[str]):
         # get llm prompt
-        prompt_messages, stop_words = model_instance.get_prompt(
-            mode=mode,
-            pre_prompt=app_model_config.pre_prompt,
-            inputs=inputs,
-            query=query,
-            context=agent_execute_result.output if agent_execute_result else None,
-            memory=memory
-        )
+        if app_model_config.prompt_type == 'simple':
+            prompt_messages, stop_words = model_instance.get_prompt(
+                mode=mode,
+                pre_prompt=app_model_config.pre_prompt,
+                inputs=inputs,
+                query=query,
+                context=agent_execute_result.output if agent_execute_result else None,
+                memory=memory
+            )
+        else:
+            prompt_messages = model_instance.get_advanced_prompt(
+                app_mode=mode,
+                app_model_config=app_model_config,
+                inputs=inputs,
+                query=query,
+                context=agent_execute_result.output if agent_execute_result else None,
+                memory=memory
+            )
+
+            model_config = app_model_config.model_dict
+            completion_params = model_config.get("completion_params", {})
+            stop_words = completion_params.get("stop", [])
 
         cls.recale_llm_max_tokens(
             model_instance=model_instance,
@@ -176,7 +187,7 @@ class Completion:
 
         response = model_instance.run(
             messages=prompt_messages,
-            stop=stop_words,
+            stop=stop_words if stop_words else None,
             callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
             fake_response=fake_response
         )
@@ -266,52 +277,3 @@ class Completion:
             model_kwargs = model_instance.get_model_kwargs()
             model_kwargs.max_tokens = max_tokens
             model_instance.set_model_kwargs(model_kwargs)
-
-    @classmethod
-    def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
-                                app_model_config: AppModelConfig, user: Account, streaming: bool):
-
-        final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
-            tenant_id=app.tenant_id,
-            model_config=app_model_config.model_dict,
-            streaming=streaming
-        )
-
-        # get llm prompt
-        old_prompt_messages, _ = final_model_instance.get_prompt(
-            mode='completion',
-            pre_prompt=pre_prompt,
-            inputs=message.inputs,
-            query=message.query,
-            context=None,
-            memory=None
-        )
-
-        original_completion = message.answer.strip()
-
-        prompt = MORE_LIKE_THIS_GENERATE_PROMPT
-        prompt = prompt.format(prompt=old_prompt_messages[0].content, original_completion=original_completion)
-
-        prompt_messages = [PromptMessage(content=prompt)]
-
-        conversation_message_task = ConversationMessageTask(
-            task_id=task_id,
-            app=app,
-            app_model_config=app_model_config,
-            user=user,
-            inputs=message.inputs,
-            query=message.query,
-            is_override=True if message.override_model_configs else False,
-            streaming=streaming,
-            model_instance=final_model_instance
-        )
-
-        cls.recale_llm_max_tokens(
-            model_instance=final_model_instance,
-            prompt_messages=prompt_messages
-        )
-
-        final_model_instance.run(
-            messages=prompt_messages,
-            callbacks=[LLMCallbackHandler(final_model_instance, conversation_message_task)]
-        )

+ 11 - 11
api/core/conversation_message_task.py

@@ -10,7 +10,7 @@ from core.model_providers.model_factory import ModelFactory
 from core.model_providers.models.entity.message import to_prompt_messages, MessageType
 from core.model_providers.models.llm.base import BaseLLM
 from core.prompt.prompt_builder import PromptBuilder
-from core.prompt.prompt_template import JinjaPromptTemplate
+from core.prompt.prompt_template import PromptTemplateParser
 from events.message_event import message_was_created
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
@@ -74,10 +74,10 @@ class ConversationMessageTask:
         if self.mode == 'chat':
             introduction = self.app_model_config.opening_statement
             if introduction:
-                prompt_template = JinjaPromptTemplate.from_template(template=introduction)
-                prompt_inputs = {k: self.inputs[k] for k in prompt_template.input_variables if k in self.inputs}
+                prompt_template = PromptTemplateParser(template=introduction)
+                prompt_inputs = {k: self.inputs[k] for k in prompt_template.variable_keys if k in self.inputs}
                 try:
-                    introduction = prompt_template.format(**prompt_inputs)
+                    introduction = prompt_template.format(prompt_inputs)
                 except KeyError:
                     pass
 
@@ -150,12 +150,12 @@ class ConversationMessageTask:
         message_tokens = llm_message.prompt_tokens
         answer_tokens = llm_message.completion_tokens
 
-        message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
-        message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN)
+        message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.USER)
+        message_price_unit = self.model_instance.get_price_unit(MessageType.USER)
         answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
         answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)
 
-        message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
+        message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.USER)
         answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
         total_price = message_total_price + answer_total_price
 
@@ -163,7 +163,7 @@ class ConversationMessageTask:
         self.message.message_tokens = message_tokens
         self.message.message_unit_price = message_unit_price
         self.message.message_price_unit = message_price_unit
-        self.message.answer = PromptBuilder.process_template(
+        self.message.answer = PromptTemplateParser.remove_template_variables(
             llm_message.completion.strip()) if llm_message.completion else ''
         self.message.answer_tokens = answer_tokens
         self.message.answer_unit_price = answer_unit_price
@@ -226,15 +226,15 @@ class ConversationMessageTask:
 
     def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
                      agent_loop: AgentLoop):
-        agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN)
-        agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN)
+        agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.USER)
+        agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.USER)
         agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
         agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)
 
         loop_message_tokens = agent_loop.prompt_tokens
         loop_answer_tokens = agent_loop.completion_tokens
 
-        loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
+        loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.USER)
         loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
         loop_total_price = loop_message_total_price + loop_answer_total_price
 

+ 25 - 83
api/core/generator/llm_generator.py

@@ -10,9 +10,8 @@ from core.model_providers.models.entity.model_params import ModelKwargs
 from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
 
 from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
-from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate
-from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
-    GENERATOR_QA_PROMPT
+from core.prompt.prompt_template import PromptTemplateParser
+from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT
 
 
 class LLMGenerator:
@@ -44,78 +43,19 @@ class LLMGenerator:
 
         return answer.strip()
 
-    @classmethod
-    def generate_conversation_summary(cls, tenant_id: str, messages):
-        max_tokens = 200
-
-        model_instance = ModelFactory.get_text_generation_model(
-            tenant_id=tenant_id,
-            model_kwargs=ModelKwargs(
-                max_tokens=max_tokens
-            )
-        )
-
-        prompt = CONVERSATION_SUMMARY_PROMPT
-        prompt_with_empty_context = prompt.format(context='')
-        prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
-        max_context_token_length = model_instance.model_rules.max_tokens.max
-        max_context_token_length = max_context_token_length if max_context_token_length else 1500
-        rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1
-
-        context = ''
-        for message in messages:
-            if not message.answer:
-                continue
-
-            if len(message.query) > 2000:
-                query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
-            else:
-                query = message.query
-
-            if len(message.answer) > 2000:
-                answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
-            else:
-                answer = message.answer
-
-            message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
-            if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0:
-                context += message_qa_text
-
-        if not context:
-            return '[message too long, no summary]'
-
-        prompt = prompt.format(context=context)
-        prompts = [PromptMessage(content=prompt)]
-        response = model_instance.run(prompts)
-        answer = response.content
-        return answer.strip()
-
-    @classmethod
-    def generate_introduction(cls, tenant_id: str, pre_prompt: str):
-        prompt = INTRODUCTION_GENERATE_PROMPT
-        prompt = prompt.format(prompt=pre_prompt)
-
-        model_instance = ModelFactory.get_text_generation_model(
-            tenant_id=tenant_id
-        )
-
-        prompts = [PromptMessage(content=prompt)]
-        response = model_instance.run(prompts)
-        answer = response.content
-        return answer.strip()
-
     @classmethod
     def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
         output_parser = SuggestedQuestionsAfterAnswerOutputParser()
         format_instructions = output_parser.get_format_instructions()
 
-        prompt = JinjaPromptTemplate(
-            template="{{histories}}\n{{format_instructions}}\nquestions:\n",
-            input_variables=["histories"],
-            partial_variables={"format_instructions": format_instructions}
+        prompt_template = PromptTemplateParser(
+            template="{{histories}}\n{{format_instructions}}\nquestions:\n"
         )
 
-        _input = prompt.format_prompt(histories=histories)
+        prompt = prompt_template.format({
+            "histories": histories,
+            "format_instructions": format_instructions
+        })
 
         try:
             model_instance = ModelFactory.get_text_generation_model(
@@ -128,10 +68,10 @@ class LLMGenerator:
         except ProviderTokenNotInitError:
             return []
 
-        prompts = [PromptMessage(content=_input.to_string())]
+        prompt_messages = [PromptMessage(content=prompt)]
 
         try:
-            output = model_instance.run(prompts)
+            output = model_instance.run(prompt_messages)
             questions = output_parser.parse(output.content)
         except LLMError:
             questions = []
@@ -145,19 +85,21 @@ class LLMGenerator:
     def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict:
         output_parser = RuleConfigGeneratorOutputParser()
 
-        prompt = OutLinePromptTemplate(
-            template=output_parser.get_format_instructions(),
-            input_variables=["audiences", "hoping_to_solve"],
-            partial_variables={
-                "variable": '{variable}',
-                "lanA": '{lanA}',
-                "lanB": '{lanB}',
-                "topic": '{topic}'
-            },
-            validate_template=False
+        prompt_template = PromptTemplateParser(
+            template=output_parser.get_format_instructions()
         )
 
-        _input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)
+        prompt = prompt_template.format(
+            inputs={
+                "audiences": audiences,
+                "hoping_to_solve": hoping_to_solve,
+                "variable": "{{variable}}",
+                "lanA": "{{lanA}}",
+                "lanB": "{{lanB}}",
+                "topic": "{{topic}}"
+            },
+            remove_template_variables=False
+        )
 
         model_instance = ModelFactory.get_text_generation_model(
             tenant_id=tenant_id,
@@ -167,10 +109,10 @@ class LLMGenerator:
             )
         )
 
-        prompts = [PromptMessage(content=_input.to_string())]
+        prompt_messages = [PromptMessage(content=prompt)]
 
         try:
-            output = model_instance.run(prompts)
+            output = model_instance.run(prompt_messages)
             rule_config = output_parser.parse(output.content)
         except LLMError as e:
             raise e

+ 2 - 2
api/core/indexing_runner.py

@@ -286,7 +286,7 @@ class IndexingRunner:
                     "total_segments": total_segments * 20,
                     "tokens": total_segments * 2000,
                     "total_price": '{:f}'.format(
-                        text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
+                        text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
                     "currency": embedding_model.get_currency(),
                     "qa_preview": document_qa_list,
                     "preview": preview_texts
@@ -383,7 +383,7 @@ class IndexingRunner:
                     "total_segments": total_segments * 20,
                     "tokens": total_segments * 2000,
                     "total_price": '{:f}'.format(
-                        text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
+                        text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
                     "currency": embedding_model.get_currency(),
                     "qa_preview": document_qa_list,
                     "preview": preview_texts

+ 1 - 1
api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py

@@ -31,7 +31,7 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
 
         chat_messages: List[PromptMessage] = []
         for message in messages:
-            chat_messages.append(PromptMessage(content=message.query, type=MessageType.HUMAN))
+            chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))
             chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))
 
         if not chat_messages:

+ 5 - 5
api/core/model_providers/models/entity/message.py

@@ -13,13 +13,13 @@ class LLMRunResult(BaseModel):
 
 
 class MessageType(enum.Enum):
-    HUMAN = 'human'
+    USER = 'user'
     ASSISTANT = 'assistant'
     SYSTEM = 'system'
 
 
 class PromptMessage(BaseModel):
-    type: MessageType = MessageType.HUMAN
+    type: MessageType = MessageType.USER
     content: str = ''
     function_call: dict = None
 
@@ -27,7 +27,7 @@ class PromptMessage(BaseModel):
 def to_lc_messages(messages: list[PromptMessage]):
     lc_messages = []
     for message in messages:
-        if message.type == MessageType.HUMAN:
+        if message.type == MessageType.USER:
             lc_messages.append(HumanMessage(content=message.content))
         elif message.type == MessageType.ASSISTANT:
             additional_kwargs = {}
@@ -44,7 +44,7 @@ def to_prompt_messages(messages: list[BaseMessage]):
     prompt_messages = []
     for message in messages:
         if isinstance(message, HumanMessage):
-            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
+            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
         elif isinstance(message, AIMessage):
             message_kwargs = {
                 'content': message.content,
@@ -58,7 +58,7 @@ def to_prompt_messages(messages: list[BaseMessage]):
         elif isinstance(message, SystemMessage):
             prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
         elif isinstance(message, FunctionMessage):
-            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
+            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
     return prompt_messages
 
 

+ 102 - 17
api/core/model_providers/models/llm/base.py

@@ -18,7 +18,7 @@ from core.model_providers.models.entity.message import PromptMessage, MessageTyp
 from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
 from core.model_providers.providers.base import BaseModelProvider
 from core.prompt.prompt_builder import PromptBuilder
-from core.prompt.prompt_template import JinjaPromptTemplate
+from core.prompt.prompt_template import PromptTemplateParser
 from core.third_party.langchain.llms.fake import FakeLLM
 import logging
 
@@ -232,7 +232,7 @@ class BaseLLM(BaseProviderModel):
         :param message_type:
         :return:
         """
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
+        if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
             unit_price = self.price_config['prompt']
         else:
             unit_price = self.price_config['completion']
@@ -250,7 +250,7 @@ class BaseLLM(BaseProviderModel):
         :param message_type:
         :return: decimal.Decimal('0.0001')
         """
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
+        if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
             unit_price = self.price_config['prompt']
         else:
             unit_price = self.price_config['completion']
@@ -265,7 +265,7 @@ class BaseLLM(BaseProviderModel):
         :param message_type:
         :return: decimal.Decimal('0.000001')
         """
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
+        if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
             price_unit = self.price_config['unit']
         else:
             price_unit = self.price_config['unit']
@@ -330,6 +330,85 @@ class BaseLLM(BaseProviderModel):
         prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
         return [PromptMessage(content=prompt)], stops
 
+    def get_advanced_prompt(self, app_mode: str,
+                   app_model_config: str, inputs: dict,
+                   query: str,
+                   context: Optional[str],
+                   memory: Optional[BaseChatMemory]) -> List[PromptMessage]:
+
+        model_mode = app_model_config.model_dict['mode']
+        conversation_histories_role = {}
+
+        raw_prompt_list = []
+        prompt_messages = []
+
+        if app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
+            prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
+            raw_prompt_list = [{
+                'role': MessageType.USER.value,
+                'text': prompt_text
+            }]
+            conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']
+        elif app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
+            raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
+        elif app_mode == 'completion' and model_mode == ModelMode.CHAT.value:
+            raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
+        elif app_mode == 'completion' and model_mode == ModelMode.COMPLETION.value:
+            prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
+            raw_prompt_list = [{
+                'role': MessageType.USER.value,
+                'text': prompt_text
+            }]
+        else:
+            raise Exception("app_mode or model_mode not support")
+
+        for prompt_item in raw_prompt_list:
+            prompt = prompt_item['text']
+
+            # set prompt template variables
+            prompt_template = PromptTemplateParser(template=prompt)
+            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+
+            if '#context#' in prompt:
+                if context:
+                    prompt_inputs['#context#'] = context
+                else:
+                    prompt_inputs['#context#'] = ''
+
+            if '#query#' in prompt:
+                if query:
+                    prompt_inputs['#query#'] = query
+                else:
+                    prompt_inputs['#query#'] = ''
+
+            if '#histories#' in prompt:
+                if memory and app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
+                    memory.human_prefix = conversation_histories_role['user_prefix']
+                    memory.ai_prefix = conversation_histories_role['assistant_prefix']
+                    histories = self._get_history_messages_from_memory(memory, 2000)
+                    prompt_inputs['#histories#'] = histories
+                else:
+                    prompt_inputs['#histories#'] = ''
+
+            prompt = prompt_template.format(
+                prompt_inputs
+            )
+
+            prompt = re.sub(r'<\|.*?\|>', '', prompt)
+
+            prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))
+
+        if memory and app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
+            memory.human_prefix = MessageType.USER.value
+            memory.ai_prefix = MessageType.ASSISTANT.value
+            histories = self._get_history_messages_list_from_memory(memory, 2000)
+            prompt_messages.extend(histories)
+
+        if app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
+            prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query))
+
+        return prompt_messages
+
     def prompt_file_name(self, mode: str) -> str:
         if mode == 'completion':
             return 'common_completion'
@@ -342,17 +421,17 @@ class BaseLLM(BaseProviderModel):
                              memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
         context_prompt_content = ''
         if context and 'context_prompt' in prompt_rules:
-            prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt'])
+            prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
             context_prompt_content = prompt_template.format(
-                context=context
+                {'context': context}
             )
 
         pre_prompt_content = ''
         if pre_prompt:
-            prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt)
-            prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
+            prompt_template = PromptTemplateParser(template=pre_prompt)
+            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
             pre_prompt_content = prompt_template.format(
-                **prompt_inputs
+                prompt_inputs
             )
 
         prompt = ''
@@ -385,10 +464,8 @@ class BaseLLM(BaseProviderModel):
             memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
 
             histories = self._get_history_messages_from_memory(memory, rest_tokens)
-            prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt'])
-            histories_prompt_content = prompt_template.format(
-                histories=histories
-            )
+            prompt_template = PromptTemplateParser(template=prompt_rules['histories_prompt'])
+            histories_prompt_content = prompt_template.format({'histories': histories})
 
             prompt = ''
             for order in prompt_rules['system_prompt_orders']:
@@ -399,10 +476,8 @@ class BaseLLM(BaseProviderModel):
                 elif order == 'histories_prompt':
                     prompt += histories_prompt_content
 
-        prompt_template = JinjaPromptTemplate.from_template(template=query_prompt)
-        query_prompt_content = prompt_template.format(
-            query=query
-        )
+        prompt_template = PromptTemplateParser(template=query_prompt)
+        query_prompt_content = prompt_template.format({'query': query})
 
         prompt += query_prompt_content
 
@@ -433,6 +508,16 @@ class BaseLLM(BaseProviderModel):
         external_context = memory.load_memory_variables({})
         return external_context[memory_key]
 
+    def _get_history_messages_list_from_memory(self, memory: BaseChatMemory,
+                                          max_token_limit: int) -> List[PromptMessage]:
+        """Get memory messages."""
+        memory.max_token_limit = max_token_limit
+        memory.return_messages = True
+        memory_key = memory.memory_variables[0]
+        external_context = memory.load_memory_variables({})
+        memory.return_messages = False
+        return to_prompt_messages(external_context[memory_key])
+
     def _get_prompt_from_messages(self, messages: List[PromptMessage],
                                   model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]:
         if not model_mode:

+ 6 - 1
api/core/model_providers/providers/anthropic_provider.py

@@ -9,7 +9,7 @@ from langchain.schema import HumanMessage
 
 from core.helper import encrypter
 from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule
+from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelMode
 from core.model_providers.models.entity.provider import ModelFeature
 from core.model_providers.models.llm.anthropic_model import AnthropicModel
 from core.model_providers.models.llm.base import ModelType
@@ -34,10 +34,12 @@ class AnthropicProvider(BaseModelProvider):
                 {
                     'id': 'claude-instant-1',
                     'name': 'claude-instant-1',
+                    'mode': ModelMode.CHAT.value,
                 },
                 {
                     'id': 'claude-2',
                     'name': 'claude-2',
+                    'mode': ModelMode.CHAT.value,
                     'features': [
                         ModelFeature.AGENT_THOUGHT.value
                     ]
@@ -46,6 +48,9 @@ class AnthropicProvider(BaseModelProvider):
         else:
             return []
 
+    def _get_text_generation_model_mode(self, model_name) -> str:
+        return ModelMode.CHAT.value
+
     def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
         """
         Returns the model class.

+ 16 - 1
api/core/model_providers/providers/azure_openai_provider.py

@@ -12,7 +12,7 @@ from core.helper import encrypter
 from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.embedding.azure_openai_embedding import AzureOpenAIEmbedding, \
     AZURE_OPENAI_API_VERSION
-from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule
+from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule, ModelMode
 from core.model_providers.models.entity.provider import ModelFeature
 from core.model_providers.models.llm.azure_openai_model import AzureOpenAIModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -61,6 +61,10 @@ class AzureOpenAIProvider(BaseModelProvider):
                 }
 
                 credentials = json.loads(provider_model.encrypted_config)
+
+                if provider_model.model_type == ModelType.TEXT_GENERATION.value:
+                    model_dict['mode'] = self._get_text_generation_model_mode(credentials['base_model_name'])
+
                 if credentials['base_model_name'] in [
                     'gpt-4',
                     'gpt-4-32k',
@@ -77,12 +81,19 @@ class AzureOpenAIProvider(BaseModelProvider):
 
         return model_list
 
+    def _get_text_generation_model_mode(self, model_name) -> str:
+        if model_name == 'text-davinci-003':
+            return ModelMode.COMPLETION.value
+        else:
+            return ModelMode.CHAT.value
+
     def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
         if model_type == ModelType.TEXT_GENERATION:
             models = [
                 {
                     'id': 'gpt-3.5-turbo',
                     'name': 'gpt-3.5-turbo',
+                    'mode': ModelMode.CHAT.value,
                     'features': [
                         ModelFeature.AGENT_THOUGHT.value
                     ]
@@ -90,6 +101,7 @@ class AzureOpenAIProvider(BaseModelProvider):
                 {
                     'id': 'gpt-3.5-turbo-16k',
                     'name': 'gpt-3.5-turbo-16k',
+                    'mode': ModelMode.CHAT.value,
                     'features': [
                         ModelFeature.AGENT_THOUGHT.value
                     ]
@@ -97,6 +109,7 @@ class AzureOpenAIProvider(BaseModelProvider):
                 {
                     'id': 'gpt-4',
                     'name': 'gpt-4',
+                    'mode': ModelMode.CHAT.value,
                     'features': [
                         ModelFeature.AGENT_THOUGHT.value
                     ]
@@ -104,6 +117,7 @@ class AzureOpenAIProvider(BaseModelProvider):
                 {
                     'id': 'gpt-4-32k',
                     'name': 'gpt-4-32k',
+                    'mode': ModelMode.CHAT.value,
                     'features': [
                         ModelFeature.AGENT_THOUGHT.value
                     ]
@@ -111,6 +125,7 @@ class AzureOpenAIProvider(BaseModelProvider):
                 {
                     'id': 'text-davinci-003',
                     'name': 'text-davinci-003',
+                    'mode': ModelMode.COMPLETION.value,
                 }
             ]
 

+ 5 - 1
api/core/model_providers/providers/baichuan_provider.py

@@ -6,7 +6,7 @@ from langchain.schema import HumanMessage
 
 from core.helper import encrypter
 from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
+from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
 from core.model_providers.models.llm.baichuan_model import BaichuanModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
@@ -21,6 +21,9 @@ class BaichuanProvider(BaseModelProvider):
         Returns the name of a provider.
         """
         return 'baichuan'
+    
+    def _get_text_generation_model_mode(self, model_name) -> str:
+        return ModelMode.CHAT.value
 
     def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
         if model_type == ModelType.TEXT_GENERATION:
@@ -28,6 +31,7 @@ class BaichuanProvider(BaseModelProvider):
                 {
                     'id': 'baichuan2-53b',
                     'name': 'Baichuan2-53B',
+                    'mode': ModelMode.CHAT.value,
                 }
             ]
         else:

+ 23 - 4
api/core/model_providers/providers/base.py

@@ -61,10 +61,19 @@ class BaseModelProvider(BaseModel, ABC):
             ProviderModel.is_valid == True
         ).order_by(ProviderModel.created_at.asc()).all()
 
-        return [{
-            'id': provider_model.model_name,
-            'name': provider_model.model_name
-        } for provider_model in provider_models]
+        provider_model_list = []
+        for provider_model in provider_models:
+            provider_model_dict = {
+                'id': provider_model.model_name,
+                'name': provider_model.model_name
+            }
+
+            if model_type == ModelType.TEXT_GENERATION:
+                provider_model_dict['mode'] = self._get_text_generation_model_mode(provider_model.model_name)
+
+            provider_model_list.append(provider_model_dict)
+
+        return provider_model_list
 
     @abstractmethod
     def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
@@ -76,6 +85,16 @@ class BaseModelProvider(BaseModel, ABC):
         """
         raise NotImplementedError
 
+    @abstractmethod
+    def _get_text_generation_model_mode(self, model_name) -> str:
+        """
+        get text generation model mode.
+
+        :param model_name:
+        :return:
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def get_model_class(self, model_type: ModelType) -> Type:
         """

+ 6 - 1
api/core/model_providers/providers/chatglm_provider.py

@@ -6,7 +6,7 @@ from langchain.llms import ChatGLM
 
 from core.helper import encrypter
 from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
+from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
 from core.model_providers.models.llm.chatglm_model import ChatGLMModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 from models.provider import ProviderType
@@ -27,15 +27,20 @@ class ChatGLMProvider(BaseModelProvider):
                 {
                     'id': 'chatglm2-6b',
                     'name': 'ChatGLM2-6B',
+                    'mode': ModelMode.COMPLETION.value,
                 },
                 {
                     'id': 'chatglm-6b',
                     'name': 'ChatGLM-6B',
+                    'mode': ModelMode.COMPLETION.value,
                 }
             ]
         else:
             return []
 
+    def _get_text_generation_model_mode(self, model_name) -> str:
+        return ModelMode.COMPLETION.value
+
     def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
         """
         Returns the model class.

+ 4 - 1
api/core/model_providers/providers/huggingface_hub_provider.py

@@ -5,7 +5,7 @@ import requests
 from huggingface_hub import HfApi
 
 from core.helper import encrypter
-from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
+from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
 from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 
@@ -29,6 +29,9 @@ class HuggingfaceHubProvider(BaseModelProvider):
     def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
         return []
 
+    def _get_text_generation_model_mode(self, model_name) -> str:
+        return ModelMode.COMPLETION.value
+
     def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
         """
         Returns the model class.

+ 8 - 1
api/core/model_providers/providers/localai_provider.py

@@ -6,7 +6,7 @@ from langchain.schema import HumanMessage
 
 from core.helper import encrypter
 from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
-from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule
+from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule, ModelMode
 from core.model_providers.models.llm.localai_model import LocalAIModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 
@@ -27,6 +27,13 @@ class LocalAIProvider(BaseModelProvider):
     def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
         return []
 
+    def _get_text_generation_model_mode(self, model_name) -> str:
+        credentials = self.get_model_credentials(model_name, ModelType.TEXT_GENERATION)
+        if credentials['completion_type'] == 'chat_completion':
+            return ModelMode.CHAT.value
+        else:
+            return ModelMode.COMPLETION.value
+
     def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
         """
         Returns the model class.

+ 6 - 1
api/core/model_providers/providers/minimax_provider.py

@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage
 from core.helper import encrypter
 from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbedding
-from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
+from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
 from core.model_providers.models.llm.minimax_model import MinimaxModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM
@@ -29,10 +29,12 @@ class MinimaxProvider(BaseModelProvider):
                 {
                     'id': 'abab5.5-chat',
                     'name': 'abab5.5-chat',
+                    'mode': ModelMode.COMPLETION.value,
                 },
                 {
                     'id': 'abab5-chat',
                     'name': 'abab5-chat',
+                    'mode': ModelMode.COMPLETION.value,
                 }
             ]
         elif model_type == ModelType.EMBEDDINGS:
@@ -45,6 +47,9 @@ class MinimaxProvider(BaseModelProvider):
         else:
             return []
 
+    def _get_text_generation_model_mode(self, model_name) -> str:
+        return ModelMode.COMPLETION.value
+
     def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
         """
         Returns the model class.

+ 14 - 2
api/core/model_providers/providers/openai_provider.py

@@ -13,8 +13,8 @@ from core.model_providers.models.entity.provider import ModelFeature
 from core.model_providers.models.speech2text.openai_whisper import OpenAIWhisper
 from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
-from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
-from core.model_providers.models.llm.openai_model import OpenAIModel
+from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
+from core.model_providers.models.llm.openai_model import OpenAIModel, COMPLETION_MODELS
 from core.model_providers.models.moderation.openai_moderation import OpenAIModeration
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 from core.model_providers.providers.hosted import hosted_model_providers
@@ -36,6 +36,7 @@ class OpenAIProvider(BaseModelProvider):
                 {
                     'id': 'gpt-3.5-turbo',
                     'name': 'gpt-3.5-turbo',
+                    'mode': ModelMode.CHAT.value,
                     'features': [
                         ModelFeature.AGENT_THOUGHT.value
                     ]
@@ -43,10 +44,12 @@ class OpenAIProvider(BaseModelProvider):
                 {
                     'id': 'gpt-3.5-turbo-instruct',
                     'name': 'GPT-3.5-Turbo-Instruct',
+                    'mode': ModelMode.COMPLETION.value,
                 },
                 {
                     'id': 'gpt-3.5-turbo-16k',
                     'name': 'gpt-3.5-turbo-16k',
+                    'mode': ModelMode.CHAT.value,
                     'features': [
                         ModelFeature.AGENT_THOUGHT.value
                     ]
@@ -54,6 +57,7 @@ class OpenAIProvider(BaseModelProvider):
                 {
                     'id': 'gpt-4',
                     'name': 'gpt-4',
+                    'mode': ModelMode.CHAT.value,
                     'features': [
                         ModelFeature.AGENT_THOUGHT.value
                     ]
@@ -61,6 +65,7 @@ class OpenAIProvider(BaseModelProvider):
                 {
                     'id': 'gpt-4-32k',
                     'name': 'gpt-4-32k',
+                    'mode': ModelMode.CHAT.value,
                     'features': [
                         ModelFeature.AGENT_THOUGHT.value
                     ]
@@ -68,6 +73,7 @@ class OpenAIProvider(BaseModelProvider):
                 {
                     'id': 'text-davinci-003',
                     'name': 'text-davinci-003',
+                    'mode': ModelMode.COMPLETION.value,
                 }
             ]
 
@@ -100,6 +106,12 @@ class OpenAIProvider(BaseModelProvider):
         else:
             return []
 
+    def _get_text_generation_model_mode(self, model_name) -> str:
+        if model_name in COMPLETION_MODELS:
+            return ModelMode.COMPLETION.value
+        else:
+            return ModelMode.CHAT.value
+
     def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
         """
         Returns the model class.

+ 4 - 1
api/core/model_providers/providers/openllm_provider.py

@@ -3,7 +3,7 @@ from typing import Type
 
 from core.helper import encrypter
 from core.model_providers.models.embedding.openllm_embedding import OpenLLMEmbedding
-from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
+from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
 from core.model_providers.models.llm.openllm_model import OpenLLMModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 
@@ -24,6 +24,9 @@ class OpenLLMProvider(BaseModelProvider):
     def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
         return []
 
+    def _get_text_generation_model_mode(self, model_name) -> str:
+        return ModelMode.COMPLETION.value
+
     def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
         """
         Returns the model class.

+ 5 - 1
api/core/model_providers/providers/replicate_provider.py

@@ -6,7 +6,8 @@ import replicate
 from replicate.exceptions import ReplicateError
 
 from core.helper import encrypter
-from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType
+from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType, \
+    ModelMode
 from core.model_providers.models.llm.replicate_model import ReplicateModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 
@@ -26,6 +27,9 @@ class ReplicateProvider(BaseModelProvider):
     def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
         return []
 
+    def _get_text_generation_model_mode(self, model_name) -> str:
+        return ModelMode.CHAT.value if model_name.endswith('-chat') else ModelMode.COMPLETION.value
+
     def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
         """
         Returns the model class.

+ 6 - 1
api/core/model_providers/providers/spark_provider.py

@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage
 
 from core.helper import encrypter
 from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
+from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
 from core.model_providers.models.llm.spark_model import SparkModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 from core.third_party.langchain.llms.spark import ChatSpark
@@ -30,15 +30,20 @@ class SparkProvider(BaseModelProvider):
                 {
                     'id': 'spark',
                     'name': 'Spark V1.5',
+                    'mode': ModelMode.CHAT.value,
                 },
                 {
                     'id': 'spark-v2',
                     'name': 'Spark V2.0',
+                    'mode': ModelMode.CHAT.value,
                 }
             ]
         else:
             return []
 
+    def _get_text_generation_model_mode(self, model_name) -> str:
+        return ModelMode.CHAT.value
+
     def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
         """
         Returns the model class.

+ 6 - 1
api/core/model_providers/providers/tongyi_provider.py

@@ -4,7 +4,7 @@ from typing import Type
 
 from core.helper import encrypter
 from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
+from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
 from core.model_providers.models.llm.tongyi_model import TongyiModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 from core.third_party.langchain.llms.tongyi_llm import EnhanceTongyi
@@ -26,15 +26,20 @@ class TongyiProvider(BaseModelProvider):
                 {
                     'id': 'qwen-turbo',
                     'name': 'qwen-turbo',
+                    'mode': ModelMode.COMPLETION.value,
                 },
                 {
                     'id': 'qwen-plus',
                     'name': 'qwen-plus',
+                    'mode': ModelMode.COMPLETION.value,
                 }
             ]
         else:
             return []
 
+    def _get_text_generation_model_mode(self, model_name) -> str:
+        return ModelMode.COMPLETION.value
+
     def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
         """
         Returns the model class.

+ 7 - 1
api/core/model_providers/providers/wenxin_provider.py

@@ -4,7 +4,7 @@ from typing import Type
 
 from core.helper import encrypter
 from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
+from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
 from core.model_providers.models.llm.wenxin_model import WenxinModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 from core.third_party.langchain.llms.wenxin import Wenxin
@@ -26,19 +26,25 @@ class WenxinProvider(BaseModelProvider):
                 {
                     'id': 'ernie-bot',
                     'name': 'ERNIE-Bot',
+                    'mode': ModelMode.COMPLETION.value,
                 },
                 {
                     'id': 'ernie-bot-turbo',
                     'name': 'ERNIE-Bot-turbo',
+                    'mode': ModelMode.COMPLETION.value,
                 },
                 {
                     'id': 'bloomz-7b',
                     'name': 'BLOOMZ-7B',
+                    'mode': ModelMode.COMPLETION.value,
                 }
             ]
         else:
             return []
 
+    def _get_text_generation_model_mode(self, model_name) -> str:
+        return ModelMode.COMPLETION.value
+
     def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
         """
         Returns the model class.

+ 4 - 1
api/core/model_providers/providers/xinference_provider.py

@@ -6,7 +6,7 @@ from langchain.embeddings import XinferenceEmbeddings
 
 from core.helper import encrypter
 from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
-from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
+from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
 from core.model_providers.models.llm.xinference_model import XinferenceModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 
@@ -26,6 +26,9 @@ class XinferenceProvider(BaseModelProvider):
     def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
         return []
 
+    def _get_text_generation_model_mode(self, model_name) -> str:
+        return ModelMode.COMPLETION.value
+
     def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
         """
         Returns the model class.

+ 8 - 1
api/core/model_providers/providers/zhipuai_provider.py

@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage
 from core.helper import encrypter
 from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
-from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
+from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
 from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
@@ -29,18 +29,22 @@ class ZhipuAIProvider(BaseModelProvider):
                 {
                     'id': 'chatglm_pro',
                     'name': 'chatglm_pro',
+                    'mode': ModelMode.CHAT.value,
                 },
                 {
                     'id': 'chatglm_std',
                     'name': 'chatglm_std',
+                    'mode': ModelMode.CHAT.value,
                 },
                 {
                     'id': 'chatglm_lite',
                     'name': 'chatglm_lite',
+                    'mode': ModelMode.CHAT.value,
                 },
                 {
                     'id': 'chatglm_lite_32k',
                     'name': 'chatglm_lite_32k',
+                    'mode': ModelMode.CHAT.value,
                 }
             ]
         elif model_type == ModelType.EMBEDDINGS:
@@ -53,6 +57,9 @@ class ZhipuAIProvider(BaseModelProvider):
         else:
             return []
 
+    def _get_text_generation_model_mode(self, model_name) -> str:
+        return ModelMode.CHAT.value
+
     def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
         """
         Returns the model class.

+ 37 - 42
api/core/orchestrator_rule_parser.py

@@ -1,4 +1,3 @@
-import math
 from typing import Optional
 
 from langchain import WikipediaAPIWrapper
@@ -50,6 +49,7 @@ class OrchestratorRuleParser:
             tool_configs = agent_mode_config.get('tools', [])
             agent_provider_name = model_dict.get('provider', 'openai')
             agent_model_name = model_dict.get('name', 'gpt-4')
+            dataset_configs = self.app_model_config.dataset_configs_dict
 
             agent_model_instance = ModelFactory.get_text_generation_model(
                 tenant_id=self.tenant_id,
@@ -96,13 +96,14 @@ class OrchestratorRuleParser:
                 summary_model_instance = None
 
             tools = self.to_tools(
-                agent_model_instance=agent_model_instance,
                 tool_configs=tool_configs,
+                callbacks=[agent_callback, DifyStdOutCallbackHandler()],
+                agent_model_instance=agent_model_instance,
                 conversation_message_task=conversation_message_task,
                 rest_tokens=rest_tokens,
-                callbacks=[agent_callback, DifyStdOutCallbackHandler()],
                 return_resource=return_resource,
-                retriever_from=retriever_from
+                retriever_from=retriever_from,
+                dataset_configs=dataset_configs
             )
 
             if len(tools) == 0:
@@ -170,20 +171,12 @@ class OrchestratorRuleParser:
 
         return None
 
-    def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list,
-                 conversation_message_task: ConversationMessageTask,
-                 rest_tokens: int, callbacks: Callbacks = None, return_resource: bool = False,
-                 retriever_from: str = 'dev') -> list[BaseTool]:
+    def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]:
         """
         Convert app agent tool configs to tools
 
-        :param agent_model_instance:
-        :param rest_tokens:
         :param tool_configs: app agent tool configs
-        :param conversation_message_task:
         :param callbacks:
-        :param return_resource:
-        :param retriever_from:
         :return:
         """
         tools = []
@@ -195,15 +188,15 @@ class OrchestratorRuleParser:
 
             tool = None
             if tool_type == "dataset":
-                tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens, return_resource, retriever_from)
+                tool = self.to_dataset_retriever_tool(tool_config=tool_val, **kwargs)
             elif tool_type == "web_reader":
-                tool = self.to_web_reader_tool(agent_model_instance)
+                tool = self.to_web_reader_tool(tool_config=tool_val, **kwargs)
             elif tool_type == "google_search":
-                tool = self.to_google_search_tool()
+                tool = self.to_google_search_tool(tool_config=tool_val, **kwargs)
             elif tool_type == "wikipedia":
-                tool = self.to_wikipedia_tool()
+                tool = self.to_wikipedia_tool(tool_config=tool_val, **kwargs)
             elif tool_type == "current_datetime":
-                tool = self.to_current_datetime_tool()
+                tool = self.to_current_datetime_tool(tool_config=tool_val, **kwargs)
 
             if tool:
                 if tool.callbacks is not None:
@@ -215,12 +208,15 @@ class OrchestratorRuleParser:
         return tools
 
     def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
-                                  rest_tokens: int, return_resource: bool = False, retriever_from: str = 'dev') \
+                                  dataset_configs: dict, rest_tokens: int,
+                                  return_resource: bool = False, retriever_from: str = 'dev',
+                                  **kwargs) \
             -> Optional[BaseTool]:
         """
         A dataset tool is a tool that can be used to retrieve information from a dataset
         :param rest_tokens:
         :param tool_config:
+        :param dataset_configs:
         :param conversation_message_task:
         :param return_resource:
         :param retriever_from:
@@ -238,10 +234,20 @@ class OrchestratorRuleParser:
         if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
             return None
 
-        k = self._dynamic_calc_retrieve_k(dataset, rest_tokens)
+        top_k = dataset_configs.get("top_k", 2)
+
+        # dynamically adjust top_k when the remaining token number is not enough to support top_k
+        top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)
+
+        score_threshold = None
+        score_threshold_config = dataset_configs.get("score_threshold")
+        if score_threshold_config and score_threshold_config.get("enable"):
+            score_threshold = score_threshold_config.get("value")
+
         tool = DatasetRetrieverTool.from_dataset(
             dataset=dataset,
-            k=k,
+            top_k=top_k,
+            score_threshold=score_threshold,
             callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
             conversation_message_task=conversation_message_task,
             return_resource=return_resource,
@@ -250,7 +256,7 @@ class OrchestratorRuleParser:
 
         return tool
 
-    def to_web_reader_tool(self, agent_model_instance: BaseLLM) -> Optional[BaseTool]:
+    def to_web_reader_tool(self, tool_config: dict, agent_model_instance: BaseLLM, **kwargs) -> Optional[BaseTool]:
         """
         A tool for reading web pages
 
@@ -278,7 +284,7 @@ class OrchestratorRuleParser:
 
         return tool
 
-    def to_google_search_tool(self) -> Optional[BaseTool]:
+    def to_google_search_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
         tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
         func_kwargs = tool_provider.credentials_to_func_kwargs()
         if not func_kwargs:
@@ -296,12 +302,12 @@ class OrchestratorRuleParser:
 
         return tool
 
-    def to_current_datetime_tool(self) -> Optional[BaseTool]:
+    def to_current_datetime_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
         tool = DatetimeTool()
 
         return tool
 
-    def to_wikipedia_tool(self) -> Optional[BaseTool]:
+    def to_wikipedia_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
         class WikipediaInput(BaseModel):
             query: str = Field(..., description="search query.")
 
@@ -312,22 +318,18 @@ class OrchestratorRuleParser:
         )
 
     @classmethod
-    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
-        DEFAULT_K = 2
-        CONTEXT_TOKENS_PERCENT = 0.3
-        MAX_K = 10
-
+    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, top_k: int, rest_tokens: int) -> int:
         if rest_tokens == -1:
-            return DEFAULT_K
+            return top_k
 
         processing_rule = dataset.latest_process_rule
         if not processing_rule:
-            return DEFAULT_K
+            return top_k
 
         if processing_rule.mode == "custom":
             rules = processing_rule.rules_dict
             if not rules:
-                return DEFAULT_K
+                return top_k
 
             segmentation = rules["segmentation"]
             segment_max_tokens = segmentation["max_tokens"]
@@ -335,14 +337,7 @@ class OrchestratorRuleParser:
             segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
 
         # when rest_tokens is less than default context tokens
-        if rest_tokens < segment_max_tokens * DEFAULT_K:
+        if rest_tokens < segment_max_tokens * top_k:
             return rest_tokens // segment_max_tokens
 
-        context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
-
-        # when context_limit_tokens is less than default context tokens, use default_k
-        if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
-            return DEFAULT_K
-
-        # Expand the k value when there's still some room left in the 30% rest tokens space, but less than the MAX_K
-        return min(context_limit_tokens // segment_max_tokens, MAX_K)
+        return min(top_k, 10)

+ 79 - 0
api/core/prompt/advanced_prompt_templates.py

@@ -0,0 +1,79 @@
+CONTEXT = "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n"
+
+BAICHUAN_CONTEXT = "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n"
+
+CHAT_APP_COMPLETION_PROMPT_CONFIG = {
+    "completion_prompt_config": {
+        "prompt": {
+            "text": "{{#pre_prompt#}}\nHere is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant: "
+        },
+        "conversation_histories_role": {
+            "user_prefix": "Human",
+            "assistant_prefix": "Assistant"
+        }
+    }
+}
+
+CHAT_APP_CHAT_PROMPT_CONFIG = { 
+    "chat_prompt_config": {
+        "prompt": [{
+            "role": "system",
+            "text": "{{#pre_prompt#}}"
+        }]
+    }
+}
+
+COMPLETION_APP_CHAT_PROMPT_CONFIG = {
+    "chat_prompt_config": {
+        "prompt": [{
+            "role": "user",
+            "text": "{{#pre_prompt#}}"
+        }]
+    }
+}
+
+COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
+    "completion_prompt_config": {
+        "prompt": {
+            "text": "{{#pre_prompt#}}"
+        }
+    }
+}
+
+BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = {
+    "completion_prompt_config": {
+        "prompt": {
+            "text": "{{#pre_prompt#}}\n\n用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n\n\n用户:{{#query#}}"
+        },
+        "conversation_histories_role": {
+            "user_prefix": "用户",
+            "assistant_prefix": "助手"
+        }
+    }
+}
+
+BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = { 
+    "chat_prompt_config": {
+        "prompt": [{
+            "role": "system",
+            "text": "{{#pre_prompt#}}"
+        }]
+    }
+}
+
+BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG = {
+    "chat_prompt_config": {
+        "prompt": [{
+            "role": "user",
+            "text": "{{#pre_prompt#}}"
+        }]
+    }
+}
+
+BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
+    "completion_prompt_config": {
+        "prompt": {
+            "text": "{{#pre_prompt#}}"
+        }
+    }
+}

+ 12 - 26
api/core/prompt/prompt_builder.py

@@ -1,38 +1,24 @@
-import re
+from langchain.schema import BaseMessage, SystemMessage, AIMessage, HumanMessage
 
-from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate
-from langchain.schema import BaseMessage
-
-from core.prompt.prompt_template import JinjaPromptTemplate
+from core.prompt.prompt_template import PromptTemplateParser
 
 
 class PromptBuilder:
+    @classmethod
+    def parse_prompt(cls, prompt: str, inputs: dict) -> str:
+        prompt_template = PromptTemplateParser(prompt)
+        prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+        prompt = prompt_template.format(prompt_inputs)
+        return prompt
+
     @classmethod
     def to_system_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
-        prompt_template = JinjaPromptTemplate.from_template(prompt_content)
-        system_prompt_template = SystemMessagePromptTemplate(prompt=prompt_template)
-        prompt_inputs = {k: inputs[k] for k in system_prompt_template.input_variables if k in inputs}
-        system_message = system_prompt_template.format(**prompt_inputs)
-        return system_message
+        return SystemMessage(content=cls.parse_prompt(prompt_content, inputs))
 
     @classmethod
     def to_ai_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
-        prompt_template = JinjaPromptTemplate.from_template(prompt_content)
-        ai_prompt_template = AIMessagePromptTemplate(prompt=prompt_template)
-        prompt_inputs = {k: inputs[k] for k in ai_prompt_template.input_variables if k in inputs}
-        ai_message = ai_prompt_template.format(**prompt_inputs)
-        return ai_message
+        return AIMessage(content=cls.parse_prompt(prompt_content, inputs))
 
     @classmethod
     def to_human_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
-        prompt_template = JinjaPromptTemplate.from_template(prompt_content)
-        human_prompt_template = HumanMessagePromptTemplate(prompt=prompt_template)
-        human_message = human_prompt_template.format(**inputs)
-        return human_message
-
-    @classmethod
-    def process_template(cls, template: str):
-        processed_template = re.sub(r'\{{2}(.+)\}{2}', r'{\1}', template)
-        # processed_template = re.sub(r'\{([a-zA-Z_]\w+?)\}', r'\1', template)
-        # processed_template = re.sub(r'\{\{([a-zA-Z_]\w+?)\}\}', r'{\1}', processed_template)
-        return processed_template
+        return HumanMessage(content=cls.parse_prompt(prompt_content, inputs))

+ 28 - 68
api/core/prompt/prompt_template.py

@@ -1,79 +1,39 @@
 import re
-from typing import Any
 
-from jinja2 import Environment, meta
-from langchain import PromptTemplate
-from langchain.formatting import StrictFormatter
+REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{1,29}|#histories#|#query#|#context#)\}\}")
 
 
-class JinjaPromptTemplate(PromptTemplate):
-    template_format: str = "jinja2"
-    """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
+class PromptTemplateParser:
+    """
+    Rules:
 
-    @classmethod
-    def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
-        """Load a prompt template from a template."""
-        env = Environment()
-        template = template.replace("{{}}", "{}")
-        ast = env.parse(template)
-        input_variables = meta.find_undeclared_variables(ast)
-
-        if "partial_variables" in kwargs:
-            partial_variables = kwargs["partial_variables"]
-            input_variables = {
-                var for var in input_variables if var not in partial_variables
-            }
-
-        return cls(
-            input_variables=list(sorted(input_variables)), template=template, **kwargs
-        )
-
-
-class OutLinePromptTemplate(PromptTemplate):
-    @classmethod
-    def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
-        """Load a prompt template from a template."""
-        input_variables = {
-            v for _, v, _, _ in OneLineFormatter().parse(template) if v is not None
-        }
-        return cls(
-            input_variables=list(sorted(input_variables)), template=template, **kwargs
-        )
-
-    def format(self, **kwargs: Any) -> str:
-        """Format the prompt with the inputs.
+    1. Template variables must be enclosed in `{{}}`.
+    2. The template variable Key can only be: letters + numbers + underscore, with a maximum length of 16 characters,
+       and can only start with letters and underscores.
+    3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2.
+    4. In addition to the above, 3 types of special template variable Keys are accepted:
+       `{{#histories#}}` `{{#query#}}` `{{#context#}}`. No other `{{##}}` template variables are allowed.
+    """
 
-        Args:
-            kwargs: Any arguments to be passed to the prompt template.
+    def __init__(self, template: str):
+        self.template = template
+        self.variable_keys = self.extract()
 
-        Returns:
-            A formatted string.
+    def extract(self) -> list:
+        # Regular expression to match the template rules
+        return re.findall(REGEX, self.template)
 
-        Example:
+    def format(self, inputs: dict, remove_template_variables: bool = True) -> str:
+        def replacer(match):
+            key = match.group(1)
+            value = inputs.get(key, match.group(0))  # return original matched string if key not found
 
-        .. code-block:: python
+            if remove_template_variables:
+                return PromptTemplateParser.remove_template_variables(value)
+            return value
 
-            prompt.format(variable1="foo")
-        """
-        kwargs = self._merge_partial_and_user_variables(**kwargs)
-        return OneLineFormatter().format(self.template, **kwargs)
+        return re.sub(REGEX, replacer, self.template)
 
-
-class OneLineFormatter(StrictFormatter):
-    def parse(self, format_string):
-        last_end = 0
-        results = []
-        for match in re.finditer(r"{([a-zA-Z_]\w*)}", format_string):
-            field_name = match.group(1)
-            start, end = match.span()
-
-            literal_text = format_string[last_end:start]
-            last_end = end
-
-            results.append((literal_text, field_name, '', None))
-
-        remaining_literal_text = format_string[last_end:]
-        if remaining_literal_text:
-            results.append((remaining_literal_text, None, None, None))
-
-        return results
+    @classmethod
+    def remove_template_variables(cls, text: str):
+        return re.sub(REGEX, r'{\1}', text)

+ 2 - 32
api/core/prompt/prompts.py

@@ -61,36 +61,6 @@ User Input: yo, 你今天咋样?
 User Input: 
 """
 
-CONVERSATION_SUMMARY_PROMPT = (
-    "Please generate a short summary of the following conversation.\n"
-    "If the following conversation communicating in English, you should only return an English summary.\n"
-    "If the following conversation communicating in Chinese, you should only return a Chinese summary.\n"
-    "[Conversation Start]\n"
-    "{context}\n"
-    "[Conversation End]\n\n"
-    "summary:"
-)
-
-INTRODUCTION_GENERATE_PROMPT = (
-    "I am designing a product for users to interact with an AI through dialogue. "
-    "The Prompt given to the AI before the conversation is:\n\n"
-    "```\n{prompt}\n```\n\n"
-    "Please generate a brief introduction of no more than 50 words that greets the user, based on this Prompt. "
-    "Do not reveal the developer's motivation or deep logic behind the Prompt, "
-    "but focus on building a relationship with the user:\n"
-)
-
-MORE_LIKE_THIS_GENERATE_PROMPT = (
-    "-----\n"
-    "{original_completion}\n"
-    "-----\n\n"
-    "Please use the above content as a sample for generating the result, "
-    "and include key information points related to the original sample in the result. "
-    "Try to rephrase this information in different ways and predict according to the rules below.\n\n"
-    "-----\n"
-    "{prompt}\n"
-)
-
 SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
     "Please help me predict the three most likely questions that human would ask, "
     "and keeping each question under 20 characters.\n"
@@ -157,10 +127,10 @@ and fill in variables, with a welcome sentence, and keep TLDR.
 ```
 
 << MY INTENDED AUDIENCES >>
-{audiences}
+{{audiences}}
 
 << HOPING TO SOLVE >>
-{hoping_to_solve}
+{{hoping_to_solve}}
 
 << OUTPUT >>
 """

+ 8 - 6
api/core/tool/dataset_retriever_tool.py

@@ -1,5 +1,5 @@
 import json
-from typing import Type
+from typing import Type, Optional
 
 from flask import current_app
 from langchain.tools import BaseTool
@@ -28,7 +28,8 @@ class DatasetRetrieverTool(BaseTool):
 
     tenant_id: str
     dataset_id: str
-    k: int = 3
+    top_k: int = 2
+    score_threshold: Optional[float] = None
     conversation_message_task: ConversationMessageTask
     return_resource: bool
     retriever_from: str
@@ -66,7 +67,7 @@ class DatasetRetrieverTool(BaseTool):
                 )
             )
 
-            documents = kw_table_index.search(query, search_kwargs={'k': self.k})
+            documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
             return str("\n".join([document.page_content for document in documents]))
         else:
 
@@ -80,20 +81,21 @@ class DatasetRetrieverTool(BaseTool):
                 return ''
             except ProviderTokenNotInitError:
                 return ''
-            embeddings = CacheEmbedding(embedding_model)
 
+            embeddings = CacheEmbedding(embedding_model)
             vector_index = VectorIndex(
                 dataset=dataset,
                 config=current_app.config,
                 embeddings=embeddings
             )
 
-            if self.k > 0:
+            if self.top_k > 0:
                 documents = vector_index.search(
                     query,
                     search_type='similarity_score_threshold',
                     search_kwargs={
-                        'k': self.k,
+                        'k': self.top_k,
+                        'score_threshold': self.score_threshold,
                         'filter': {
                             'group_id': [dataset.id]
                         }

+ 0 - 1
api/events/event_handlers/__init__.py

@@ -4,5 +4,4 @@ from .clean_when_document_deleted import handle
 from .clean_when_dataset_deleted import handle
 from .update_app_dataset_join_when_app_model_config_updated import handle
 from .generate_conversation_name_when_first_message_created import handle
-from .generate_conversation_summary_when_few_message_created import handle
 from .create_document_index import handle

+ 0 - 14
api/events/event_handlers/generate_conversation_summary_when_few_message_created.py

@@ -1,14 +0,0 @@
-from events.message_event import message_was_created
-from tasks.generate_conversation_summary_task import generate_conversation_summary_task
-
-
-@message_was_created.connect
-def handle(sender, **kwargs):
-    message = sender
-    conversation = kwargs.get('conversation')
-    is_first_message = kwargs.get('is_first_message')
-
-    if not is_first_message and conversation.mode == 'chat' and not conversation.summary:
-        history_message_count = conversation.message_count
-        if history_message_count >= 5:
-            generate_conversation_summary_task.delay(conversation.id)

+ 4 - 0
api/fields/app_fields.py

@@ -28,6 +28,10 @@ model_config_fields = {
     'dataset_query_variable': fields.String,
     'pre_prompt': fields.String,
     'agent_mode': fields.Raw(attribute='agent_mode_dict'),
+    'prompt_type': fields.String,
+    'chat_prompt_config': fields.Raw(attribute='chat_prompt_config_dict'),
+    'completion_prompt_config': fields.Raw(attribute='completion_prompt_config_dict'),
+    'dataset_configs': fields.Raw(attribute='dataset_configs_dict')
 }
 
 app_detail_fields = {

+ 1 - 0
api/fields/conversation_fields.py

@@ -123,6 +123,7 @@ conversation_with_summary_fields = {
     'from_end_user_id': fields.String,
     'from_end_user_session_id': fields.String,
     'from_account_id': fields.String,
+    'name': fields.String,
     'summary': fields.String(attribute='summary_or_query'),
     'read_at': TimestampField,
     'created_at': TimestampField,

+ 37 - 0
api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py

@@ -0,0 +1,37 @@
+"""add advanced prompt templates
+
+Revision ID: b3a09c049e8e
+Revises: 2e9819ca5b28
+Create Date: 2023-10-10 15:23:23.395420
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision = 'b3a09c049e8e'
+down_revision = '2e9819ca5b28'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple'))
+        batch_op.add_column(sa.Column('chat_prompt_config', sa.Text(), nullable=True))
+        batch_op.add_column(sa.Column('completion_prompt_config', sa.Text(), nullable=True))
+        batch_op.add_column(sa.Column('dataset_configs', sa.Text(), nullable=True))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+        batch_op.drop_column('dataset_configs')
+        batch_op.drop_column('completion_prompt_config')
+        batch_op.drop_column('chat_prompt_config')
+        batch_op.drop_column('prompt_type')
+
+    # ### end Alembic commands ###

+ 33 - 2
api/models/model.py

@@ -93,6 +93,10 @@ class AppModelConfig(db.Model):
     agent_mode = db.Column(db.Text)
     sensitive_word_avoidance = db.Column(db.Text)
     retriever_resource = db.Column(db.Text)
+    prompt_type = db.Column(db.String(255), nullable=False, default='simple')
+    chat_prompt_config = db.Column(db.Text)
+    completion_prompt_config = db.Column(db.Text)
+    dataset_configs = db.Column(db.Text)
 
     @property
     def app(self):
@@ -139,6 +143,18 @@ class AppModelConfig(db.Model):
     def agent_mode_dict(self) -> dict:
         return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "strategy": None, "tools": []}
 
+    @property
+    def chat_prompt_config_dict(self) -> dict:
+        return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {}
+
+    @property
+    def completion_prompt_config_dict(self) -> dict:
+        return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {}
+
+    @property
+    def dataset_configs_dict(self) -> dict:
+        return json.loads(self.dataset_configs) if self.dataset_configs else {"top_k": 2, "score_threshold": {"enable": False}}
+
     def to_dict(self) -> dict:
         return {
             "provider": "",
@@ -155,7 +171,11 @@ class AppModelConfig(db.Model):
             "user_input_form": self.user_input_form_list,
             "dataset_query_variable": self.dataset_query_variable,
             "pre_prompt": self.pre_prompt,
-            "agent_mode": self.agent_mode_dict
+            "agent_mode": self.agent_mode_dict,
+            "prompt_type": self.prompt_type,
+            "chat_prompt_config": self.chat_prompt_config_dict,
+            "completion_prompt_config": self.completion_prompt_config_dict,
+            "dataset_configs": self.dataset_configs_dict
         }
 
     def from_model_config_dict(self, model_config: dict):
@@ -177,6 +197,13 @@ class AppModelConfig(db.Model):
         self.agent_mode = json.dumps(model_config['agent_mode'])
         self.retriever_resource = json.dumps(model_config['retriever_resource']) \
             if model_config.get('retriever_resource') else None
+        self.prompt_type = model_config.get('prompt_type', 'simple')
+        self.chat_prompt_config = json.dumps(model_config.get('chat_prompt_config')) \
+            if model_config.get('chat_prompt_config') else None
+        self.completion_prompt_config = json.dumps(model_config.get('completion_prompt_config')) \
+            if model_config.get('completion_prompt_config') else None
+        self.dataset_configs = json.dumps(model_config.get('dataset_configs')) \
+            if model_config.get('dataset_configs') else None
         return self
 
     def copy(self):
@@ -197,7 +224,11 @@ class AppModelConfig(db.Model):
             dataset_query_variable=self.dataset_query_variable,
             pre_prompt=self.pre_prompt,
             agent_mode=self.agent_mode,
-            retriever_resource=self.retriever_resource
+            retriever_resource=self.retriever_resource,
+            prompt_type=self.prompt_type,
+            chat_prompt_config=self.chat_prompt_config,
+            completion_prompt_config=self.completion_prompt_config,
+            dataset_configs=self.dataset_configs
         )
 
         return new_app_model_config

+ 56 - 0
api/services/advanced_prompt_template_service.py

@@ -0,0 +1,56 @@
+
+import copy
+
+from core.prompt.advanced_prompt_templates import CHAT_APP_COMPLETION_PROMPT_CONFIG, CHAT_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_COMPLETION_PROMPT_CONFIG, \
+    BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, CONTEXT, BAICHUAN_CONTEXT
+
+class AdvancedPromptTemplateService:
+
+    def get_prompt(self, args: dict) -> dict:
+        app_mode = args['app_mode']
+        model_mode = args['model_mode']
+        model_name = args['model_name']
+        has_context = args['has_context']
+
+        if 'baichuan' in model_name:
+            return self.get_baichuan_prompt(app_mode, model_mode, has_context)
+        else:
+            return self.get_common_prompt(app_mode, model_mode, has_context)
+
+    def get_common_prompt(self, app_mode: str, model_mode:str, has_context: bool) -> dict:
+        if app_mode == 'chat':
+            if model_mode == 'completion':
+                return self.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, CONTEXT)
+            elif model_mode == 'chat':
+                return self.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, CONTEXT)
+        elif app_mode == 'completion':
+            if model_mode == 'completion':
+                return self.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, CONTEXT)
+            elif model_mode == 'chat':
+                return self.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, CONTEXT)
+            
+    def get_completion_prompt(self, prompt_template: str, has_context: bool, context: str) -> dict:
+        if has_context == 'true':
+            prompt_template['completion_prompt_config']['prompt']['text'] = context + prompt_template['completion_prompt_config']['prompt']['text']
+        
+        return prompt_template
+
+
+    def get_chat_prompt(self, prompt_template: str, has_context: bool, context: str) -> dict:
+        if has_context == 'true':
+            prompt_template['chat_prompt_config']['prompt'][0]['text'] = context + prompt_template['chat_prompt_config']['prompt'][0]['text']
+        
+        return prompt_template
+
+
+    def get_baichuan_prompt(self, app_mode: str, model_mode:str, has_context: bool) -> dict:
+        if app_mode == 'chat':
+            if model_mode == 'completion':
+                return self.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)
+            elif model_mode == 'chat':
+                return self.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)
+        elif app_mode == 'completion':
+            if model_mode == 'completion':
+                return self.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)
+            elif model_mode == 'chat':
+                return self.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)

+ 70 - 22
api/services/app_model_config_service.py

@@ -3,7 +3,7 @@ import uuid
 
 from core.agent.agent_executor import PlanningStrategy
 from core.model_providers.model_provider_factory import ModelProviderFactory
-from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.models.entity.model_params import ModelType, ModelMode
 from models.account import Account
 from services.dataset_service import DatasetService
 
@@ -34,40 +34,28 @@ class AppModelConfigService:
         # max_tokens
         if 'max_tokens' not in cp:
             cp["max_tokens"] = 512
-        #
-        # if not isinstance(cp["max_tokens"], int) or cp["max_tokens"] <= 0 or cp["max_tokens"] > \
-        #         llm_constant.max_context_token_length[model_name]:
-        #     raise ValueError(
-        #         "max_tokens must be an integer greater than 0 "
-        #         "and not exceeding the maximum value of the corresponding model")
-        #
+
         # temperature
         if 'temperature' not in cp:
             cp["temperature"] = 1
-        #
-        # if not isinstance(cp["temperature"], (float, int)) or cp["temperature"] < 0 or cp["temperature"] > 2:
-        #     raise ValueError("temperature must be a float between 0 and 2")
-        #
+
         # top_p
         if 'top_p' not in cp:
             cp["top_p"] = 1
 
-        # if not isinstance(cp["top_p"], (float, int)) or cp["top_p"] < 0 or cp["top_p"] > 2:
-        #     raise ValueError("top_p must be a float between 0 and 2")
-        #
         # presence_penalty
         if 'presence_penalty' not in cp:
             cp["presence_penalty"] = 0
 
-        # if not isinstance(cp["presence_penalty"], (float, int)) or cp["presence_penalty"] < -2 or cp["presence_penalty"] > 2:
-        #     raise ValueError("presence_penalty must be a float between -2 and 2")
-        #
         # presence_penalty
         if 'frequency_penalty' not in cp:
             cp["frequency_penalty"] = 0
 
-        # if not isinstance(cp["frequency_penalty"], (float, int)) or cp["frequency_penalty"] < -2 or cp["frequency_penalty"] > 2:
-        #     raise ValueError("frequency_penalty must be a float between -2 and 2")
+        # stop
+        if 'stop' not in cp:
+            cp["stop"] = []
+        elif not isinstance(cp["stop"], list):
+            raise ValueError("stop in model.completion_params must be of list type")
 
         # Filter out extra parameters
         filtered_cp = {
@@ -75,7 +63,8 @@ class AppModelConfigService:
             "temperature": cp["temperature"],
             "top_p": cp["top_p"],
             "presence_penalty": cp["presence_penalty"],
-            "frequency_penalty": cp["frequency_penalty"]
+            "frequency_penalty": cp["frequency_penalty"],
+            "stop": cp["stop"]
         }
 
         return filtered_cp
@@ -211,6 +200,10 @@ class AppModelConfigService:
         model_ids = [m['id'] for m in model_list]
         if config["model"]["name"] not in model_ids:
             raise ValueError("model.name must be in the specified model list")
+        
+        # model.mode
+        if 'mode' not in config['model'] or not config['model']["mode"]:
+            config['model']["mode"] = ""
 
         # model.completion_params
         if 'completion_params' not in config["model"]:
@@ -339,6 +332,9 @@ class AppModelConfigService:
         # dataset_query_variable
         AppModelConfigService.is_dataset_query_variable_valid(config, mode)
 
+        # advanced prompt validation
+        AppModelConfigService.is_advanced_prompt_valid(config, mode)
+
         # Filter out extra parameters
         filtered_config = {
             "opening_statement": config["opening_statement"],
@@ -351,12 +347,17 @@ class AppModelConfigService:
             "model": {
                 "provider": config["model"]["provider"],
                 "name": config["model"]["name"],
+                "mode": config['model']["mode"],
                 "completion_params": config["model"]["completion_params"]
             },
             "user_input_form": config["user_input_form"],
             "dataset_query_variable": config.get('dataset_query_variable'),
             "pre_prompt": config["pre_prompt"],
-            "agent_mode": config["agent_mode"]
+            "agent_mode": config["agent_mode"],
+            "prompt_type": config["prompt_type"],
+            "chat_prompt_config": config["chat_prompt_config"],
+            "completion_prompt_config": config["completion_prompt_config"],
+            "dataset_configs": config["dataset_configs"]
         }
 
         return filtered_config
@@ -375,4 +376,51 @@ class AppModelConfigService:
 
         if dataset_exists and not dataset_query_variable:
             raise ValueError("Dataset query variable is required when dataset is exist")
+        
+
+    @staticmethod
+    def is_advanced_prompt_valid(config: dict, app_mode: str) -> None:
+        # prompt_type
+        if 'prompt_type' not in config or not config["prompt_type"]:
+            config["prompt_type"] = "simple"
+
+        if config['prompt_type'] not in ['simple', 'advanced']:
+            raise ValueError("prompt_type must be in ['simple', 'advanced']")
+        
+        # chat_prompt_config
+        if 'chat_prompt_config' not in config or not config["chat_prompt_config"]:
+            config["chat_prompt_config"] = {}
 
+        if not isinstance(config["chat_prompt_config"], dict):
+            raise ValueError("chat_prompt_config must be of object type")
+
+        # completion_prompt_config
+        if 'completion_prompt_config' not in config or not config["completion_prompt_config"]:
+            config["completion_prompt_config"] = {}
+
+        if not isinstance(config["completion_prompt_config"], dict):
+            raise ValueError("completion_prompt_config must be of object type")
+        
+        # dataset_configs
+        if 'dataset_configs' not in config or not config["dataset_configs"]:
+            config["dataset_configs"] = {"top_k": 2, "score_threshold": {"enable": False}}
+
+        if not isinstance(config["dataset_configs"], dict):
+            raise ValueError("dataset_configs must be of object type")
+
+        if config['prompt_type'] == 'advanced':
+            if not config['chat_prompt_config'] and not config['completion_prompt_config']:
+                raise ValueError("chat_prompt_config or completion_prompt_config is required when prompt_type is advanced")
+            
+            if config['model']["mode"] not in ['chat', 'completion']:
+                raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced")
+            
+            if app_mode == 'chat' and config['model']["mode"] == ModelMode.COMPLETION.value:
+                user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
+                assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
+
+                if not user_prefix:
+                    config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human'
+
+                if not assistant_prefix:
+                    config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'

+ 17 - 49
api/services/completion_service.py

@@ -244,7 +244,8 @@ class CompletionService:
 
     @classmethod
     def generate_more_like_this(cls, app_model: App, user: Union[Account | EndUser],
-                                message_id: str, streaming: bool = True) -> Union[dict | Generator]:
+                                message_id: str, streaming: bool = True,
+                                retriever_from: str = 'dev') -> Union[dict | Generator]:
         if not user:
             raise ValueError('user cannot be None')
 
@@ -266,14 +267,11 @@ class CompletionService:
             raise MoreLikeThisDisabledError()
 
         app_model_config = message.app_model_config
-
-        if message.override_model_configs:
-            override_model_configs = json.loads(message.override_model_configs)
-            pre_prompt = override_model_configs.get("pre_prompt", '')
-        elif app_model_config:
-            pre_prompt = app_model_config.pre_prompt
-        else:
-            raise AppModelConfigBrokenError()
+        model_dict = app_model_config.model_dict
+        completion_params = model_dict.get('completion_params')
+        completion_params['temperature'] = 0.9
+        model_dict['completion_params'] = completion_params
+        app_model_config.model = json.dumps(model_dict)
 
         generate_task_id = str(uuid.uuid4())
 
@@ -282,58 +280,28 @@ class CompletionService:
 
         user = cls.get_real_user_instead_of_proxy_obj(user)
 
-        generate_worker_thread = threading.Thread(target=cls.generate_more_like_this_worker, kwargs={
+        generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
             'flask_app': current_app._get_current_object(),
             'generate_task_id': generate_task_id,
             'detached_app_model': app_model,
             'app_model_config': app_model_config,
-            'detached_message': message,
-            'pre_prompt': pre_prompt,
+            'query': message.query,
+            'inputs': message.inputs,
             'detached_user': user,
-            'streaming': streaming
+            'detached_conversation': None,
+            'streaming': streaming,
+            'is_model_config_override': True,
+            'retriever_from': retriever_from
         })
 
         generate_worker_thread.start()
 
-        cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
+        # wait for 10 minutes to close the thread
+        cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user,
+                                generate_task_id)
 
         return cls.compact_response(pubsub, streaming)
 
-    @classmethod
-    def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App,
-                                       app_model_config: AppModelConfig, detached_message: Message, pre_prompt: str,
-                                       detached_user: Union[Account, EndUser], streaming: bool):
-        with flask_app.app_context():
-            # fixed the state of the model object when it detached from the original session
-            user = db.session.merge(detached_user)
-            app_model = db.session.merge(detached_app_model)
-            message = db.session.merge(detached_message)
-
-            try:
-                # run
-                Completion.generate_more_like_this(
-                    task_id=generate_task_id,
-                    app=app_model,
-                    user=user,
-                    message=message,
-                    pre_prompt=pre_prompt,
-                    app_model_config=app_model_config,
-                    streaming=streaming
-                )
-            except ConversationTaskStoppedException:
-                pass
-            except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
-                    ModelCurrentlyNotSupportError) as e:
-                PubHandler.pub_error(user, generate_task_id, e)
-            except LLMAuthorizationError:
-                PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
-            except Exception as e:
-                logging.exception("Unknown Error in completion")
-                PubHandler.pub_error(user, generate_task_id, e)
-            finally:
-                db.session.commit()
-
     @classmethod
     def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
         if user_inputs is None:

+ 3 - 0
api/services/provider_service.py

@@ -482,6 +482,9 @@ class ProviderService:
                     'features': []
                 }
 
+                if 'mode' in model:
+                    valid_model_dict['model_mode'] = model['mode']
+
                 if 'features' in model:
                     valid_model_dict['features'] = model['features']
 

+ 0 - 55
api/tasks/generate_conversation_summary_task.py

@@ -1,55 +0,0 @@
-import logging
-import time
-
-import click
-from celery import shared_task
-from werkzeug.exceptions import NotFound
-
-from core.generator.llm_generator import LLMGenerator
-from core.model_providers.error import LLMError, ProviderTokenNotInitError
-from extensions.ext_database import db
-from models.model import Conversation, Message
-
-
-@shared_task(queue='generation')
-def generate_conversation_summary_task(conversation_id: str):
-    """
-    Async Generate conversation summary
-    :param conversation_id:
-
-    Usage: generate_conversation_summary_task.delay(conversation_id)
-    """
-    logging.info(click.style('Start generate conversation summary: {}'.format(conversation_id), fg='green'))
-    start_at = time.perf_counter()
-
-    conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
-    if not conversation:
-        raise NotFound('Conversation not found')
-
-    try:
-        # get conversation messages count
-        history_message_count = conversation.message_count
-        if history_message_count >= 5 and not conversation.summary:
-            app_model = conversation.app
-            if not app_model:
-                return
-
-            history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \
-                .order_by(Message.created_at.asc()).all()
-
-            conversation.summary = LLMGenerator.generate_conversation_summary(app_model.tenant_id, history_messages)
-            db.session.add(conversation)
-            db.session.commit()
-    except (LLMError, ProviderTokenNotInitError):
-        conversation.summary = '[No Summary]'
-        db.session.commit()
-        pass
-    except Exception as e:
-        conversation.summary = '[No Summary]'
-        db.session.commit()
-        logging.exception(e)
-
-    end_at = time.perf_counter()
-    logging.info(
-        click.style('Conversation summary generated: {} latency: {}'.format(conversation_id, end_at - start_at),
-                    fg='green'))

+ 1 - 1
api/tests/integration_tests/models/llm/test_anthropic_model.py

@@ -44,7 +44,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
 def test_get_num_tokens(mock_decrypt):
     model = get_mock_model('claude-2')
     rst = model.get_num_tokens([
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst == 6
 

+ 1 - 1
api/tests/integration_tests/models/llm/test_azure_openai_model.py

@@ -69,7 +69,7 @@ def test_chat_get_num_tokens(mock_decrypt, mocker):
     openai_model = get_mock_azure_openai_model('gpt-35-turbo', mocker)
     rst = openai_model.get_num_tokens([
         PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst == 22
 

+ 3 - 3
api/tests/integration_tests/models/llm/test_baichuan_model.py

@@ -48,7 +48,7 @@ def test_chat_get_num_tokens(mock_decrypt):
     model = get_mock_model('baichuan2-53b')
     rst = model.get_num_tokens([
         PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst > 0
 
@@ -59,7 +59,7 @@ def test_chat_run(mock_decrypt, mocker):
 
     model = get_mock_model('baichuan2-53b')
     messages = [
-        PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
+        PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
     ]
     rst = model.run(
         messages,
@@ -73,7 +73,7 @@ def test_chat_stream_run(mock_decrypt, mocker):
 
     model = get_mock_model('baichuan2-53b', streaming=True)
     messages = [
-        PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
+        PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
     ]
     rst = model.run(
         messages

+ 2 - 2
api/tests/integration_tests/models/llm/test_huggingface_hub_model.py

@@ -71,7 +71,7 @@ def test_hosted_inference_api_get_num_tokens(mock_decrypt, mock_model_info, mock
         mocker
     )
     rst = model.get_num_tokens([
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst == 5
 
@@ -88,7 +88,7 @@ def test_inference_endpoints_get_num_tokens(mock_decrypt, mock_model_info, mocke
         mocker
     )
     rst = model.get_num_tokens([
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst == 5
 

+ 1 - 1
api/tests/integration_tests/models/llm/test_minimax_model.py

@@ -48,7 +48,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
 def test_get_num_tokens(mock_decrypt):
     model = get_mock_model('abab5.5-chat')
     rst = model.get_num_tokens([
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst == 5
 

+ 1 - 1
api/tests/integration_tests/models/llm/test_openai_model.py

@@ -52,7 +52,7 @@ def test_chat_get_num_tokens(mock_decrypt):
     openai_model = get_mock_openai_model('gpt-3.5-turbo')
     rst = openai_model.get_num_tokens([
         PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst == 22
 

+ 1 - 1
api/tests/integration_tests/models/llm/test_openllm_model.py

@@ -55,7 +55,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
 def test_get_num_tokens(mock_decrypt, mocker):
     model = get_mock_model('facebook/opt-125m', mocker)
     rst = model.get_num_tokens([
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst == 5
 

+ 1 - 1
api/tests/integration_tests/models/llm/test_replicate_model.py

@@ -58,7 +58,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
 def test_get_num_tokens(mock_decrypt, mocker):
     model = get_mock_model('a16z-infra/llama-2-13b-chat', '2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', mocker)
     rst = model.get_num_tokens([
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst == 7
 

+ 1 - 1
api/tests/integration_tests/models/llm/test_spark_model.py

@@ -52,7 +52,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
 def test_get_num_tokens(mock_decrypt):
     model = get_mock_model('spark')
     rst = model.get_num_tokens([
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst == 6
 

+ 1 - 1
api/tests/integration_tests/models/llm/test_tongyi_model.py

@@ -46,7 +46,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
 def test_get_num_tokens(mock_decrypt):
     model = get_mock_model('qwen-turbo')
     rst = model.get_num_tokens([
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst == 5
 

+ 1 - 1
api/tests/integration_tests/models/llm/test_wenxin_model.py

@@ -46,7 +46,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
 def test_get_num_tokens(mock_decrypt):
     model = get_mock_model('ernie-bot')
     rst = model.get_num_tokens([
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst == 5
 

+ 1 - 1
api/tests/integration_tests/models/llm/test_xinference_model.py

@@ -57,7 +57,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
 def test_get_num_tokens(mock_decrypt, mocker):
     model = get_mock_model('llama-2-chat', mocker)
     rst = model.get_num_tokens([
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst == 5
 

+ 3 - 3
api/tests/integration_tests/models/llm/test_zhipuai_model.py

@@ -46,7 +46,7 @@ def test_chat_get_num_tokens(mock_decrypt):
     model = get_mock_model('chatglm_lite')
     rst = model.get_num_tokens([
         PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
-        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+        PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
     ])
     assert rst > 0
 
@@ -57,7 +57,7 @@ def test_chat_run(mock_decrypt, mocker):
 
     model = get_mock_model('chatglm_lite')
     messages = [
-        PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
+        PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
     ]
     rst = model.run(
         messages,
@@ -71,7 +71,7 @@ def test_chat_stream_run(mock_decrypt, mocker):
 
     model = get_mock_model('chatglm_lite', streaming=True)
     messages = [
-        PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
+        PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
     ]
     rst = model.run(
         messages

+ 5 - 2
api/tests/unit_tests/model_providers/fake_model_provider.py

@@ -1,7 +1,7 @@
 from typing import Type
 
 from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
+from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, ModelMode
 from core.model_providers.models.llm.openai_model import OpenAIModel
 from core.model_providers.providers.base import BaseModelProvider
 
@@ -12,7 +12,10 @@ class FakeModelProvider(BaseModelProvider):
         return 'fake'
 
     def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
-        return [{'id': 'test_model', 'name': 'Test Model'}]
+        return [{'id': 'test_model', 'name': 'Test Model', 'mode': 'completion'}]
+
+    def _get_text_generation_model_mode(self, model_name) -> str:
+        return ModelMode.COMPLETION.value
 
     def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
         return OpenAIModel

+ 1 - 1
api/tests/unit_tests/model_providers/test_base_model_provider.py

@@ -24,7 +24,7 @@ def test_get_supported_model_list(mocker):
     provider = FakeModelProvider(provider=Provider())
     result = provider.get_supported_model_list(ModelType.TEXT_GENERATION)
 
-    assert result == [{'id': 'test_model', 'name': 'test_model'}]
+    assert result == [{'id': 'test_model', 'name': 'test_model', 'mode': 'completion'}]
 
 
 def test_check_quota_over_limit(mocker):