Browse Source

feat: optimize override app model config convert (#874)

takatost 1 year ago
parent
commit
cc2d71c253

+ 27 - 46
api/controllers/console/app/app.py

@@ -124,12 +124,29 @@ class AppListApi(Resource):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
 
+        default_model = ModelFactory.get_default_model(
+            tenant_id=current_user.current_tenant_id,
+            model_type=ModelType.TEXT_GENERATION
+        )
+
+        if default_model:
+            default_model_provider = default_model.provider_name
+            default_model_name = default_model.model_name
+        else:
+            raise ProviderNotInitializeError(
+                f"No Text Generation Model available. Please configure a valid provider "
+                f"in the Settings -> Model Provider.")
+
         if args['model_config'] is not None:
             # validate config
+            model_config_dict = args['model_config']
+            model_config_dict["model"]["provider"] = default_model_provider
+            model_config_dict["model"]["name"] = default_model_name
+
             model_configuration = AppModelConfigService.validate_configuration(
                 tenant_id=current_user.current_tenant_id,
                 account=current_user,
-                config=args['model_config']
+                config=model_config_dict
             )
 
             app = App(
@@ -141,21 +158,8 @@ class AppListApi(Resource):
                 status='normal'
             )
 
-            app_model_config = AppModelConfig(
-                provider="",
-                model_id="",
-                configs={},
-                opening_statement=model_configuration['opening_statement'],
-                suggested_questions=json.dumps(model_configuration['suggested_questions']),
-                suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
-                speech_to_text=json.dumps(model_configuration['speech_to_text']),
-                more_like_this=json.dumps(model_configuration['more_like_this']),
-                sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
-                model=json.dumps(model_configuration['model']),
-                user_input_form=json.dumps(model_configuration['user_input_form']),
-                pre_prompt=model_configuration['pre_prompt'],
-                agent_mode=json.dumps(model_configuration['agent_mode']),
-            )
+            app_model_config = AppModelConfig()
+            app_model_config = app_model_config.from_model_config_dict(model_configuration)
         else:
             if 'mode' not in args or args['mode'] is None:
                 abort(400, message="mode is required")
@@ -165,20 +169,10 @@ class AppListApi(Resource):
             app = App(**model_config_template['app'])
             app_model_config = AppModelConfig(**model_config_template['model_config'])
 
-            default_model = ModelFactory.get_default_model(
-                tenant_id=current_user.current_tenant_id,
-                model_type=ModelType.TEXT_GENERATION
-            )
-
-            if default_model:
-                model_dict = app_model_config.model_dict
-                model_dict['provider'] = default_model.provider_name
-                model_dict['name'] = default_model.model_name
-                app_model_config.model = json.dumps(model_dict)
-            else:
-                raise ProviderNotInitializeError(
-                    f"No Text Generation Model available. Please configure a valid provider "
-                    f"in the Settings -> Model Provider.")
+            model_dict = app_model_config.model_dict
+            model_dict['provider'] = default_model_provider
+            model_dict['name'] = default_model_name
+            app_model_config.model = json.dumps(model_dict)
 
         app.name = args['name']
         app.mode = args['mode']
@@ -416,22 +410,9 @@ class AppCopy(Resource):
 
     @staticmethod
     def create_app_model_config_copy(app_config, copy_app_id):
-        copy_app_model_config = AppModelConfig(
-            app_id=copy_app_id,
-            provider=app_config.provider,
-            model_id=app_config.model_id,
-            configs=app_config.configs,
-            opening_statement=app_config.opening_statement,
-            suggested_questions=app_config.suggested_questions,
-            suggested_questions_after_answer=app_config.suggested_questions_after_answer,
-            speech_to_text=app_config.speech_to_text,
-            more_like_this=app_config.more_like_this,
-            sensitive_word_avoidance=app_config.sensitive_word_avoidance,
-            model=app_config.model,
-            user_input_form=app_config.user_input_form,
-            pre_prompt=app_config.pre_prompt,
-            agent_mode=app_config.agent_mode
-        )
+        copy_app_model_config = app_config.copy()
+        copy_app_model_config.app_id = copy_app_id
+
         return copy_app_model_config
 
     @setup_required

+ 1 - 13
api/controllers/console/app/model_config.py

@@ -35,20 +35,8 @@ class ModelConfigResource(Resource):
 
         new_app_model_config = AppModelConfig(
             app_id=app_model.id,
-            provider="",
-            model_id="",
-            configs={},
-            opening_statement=model_configuration['opening_statement'],
-            suggested_questions=json.dumps(model_configuration['suggested_questions']),
-            suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
-            speech_to_text=json.dumps(model_configuration['speech_to_text']),
-            more_like_this=json.dumps(model_configuration['more_like_this']),
-            sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
-            model=json.dumps(model_configuration['model']),
-            user_input_form=json.dumps(model_configuration['user_input_form']),
-            pre_prompt=model_configuration['pre_prompt'],
-            agent_mode=json.dumps(model_configuration['agent_mode']),
         )
+        new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
 
         db.session.add(new_app_model_config)
         db.session.flush()

+ 1 - 1
api/core/agent/agent/structured_chat.py

@@ -112,7 +112,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
                                           "I don't know how to respond to that."}, "")
 
     def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
-        if len(intermediate_steps) >= 2:
+        if len(intermediate_steps) >= 2 and self.summary_llm:
             should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
             should_summary_messages = [AIMessage(content=observation)
                                        for _, observation in should_summary_intermediate_steps]

+ 6 - 3
api/core/agent/agent_executor.py

@@ -65,7 +65,8 @@ class AgentExecutor:
                 llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 output_parser=StructuredChatOutputParser(),
-                summary_llm=self.configuration.summary_model_instance.client,
+                summary_llm=self.configuration.summary_model_instance.client
+                if self.configuration.summary_model_instance else None,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
@@ -74,7 +75,8 @@ class AgentExecutor:
                 llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used for read chat histories memory
-                summary_llm=self.configuration.summary_model_instance.client,
+                summary_llm=self.configuration.summary_model_instance.client
+                if self.configuration.summary_model_instance else None,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
@@ -83,7 +85,8 @@ class AgentExecutor:
                 llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used for read chat histories memory
-                summary_llm=self.configuration.summary_model_instance.client,
+                summary_llm=self.configuration.summary_model_instance.client
+                if self.configuration.summary_model_instance else None,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.ROUTER:

+ 1 - 11
api/core/conversation_message_task.py

@@ -60,17 +60,7 @@ class ConversationMessageTask:
     def init(self):
         override_model_configs = None
         if self.is_override:
-            override_model_configs = {
-                "model": self.app_model_config.model_dict,
-                "pre_prompt": self.app_model_config.pre_prompt,
-                "agent_mode": self.app_model_config.agent_mode_dict,
-                "opening_statement": self.app_model_config.opening_statement,
-                "suggested_questions": self.app_model_config.suggested_questions_list,
-                "suggested_questions_after_answer": self.app_model_config.suggested_questions_after_answer_dict,
-                "more_like_this": self.app_model_config.more_like_this_dict,
-                "sensitive_word_avoidance": self.app_model_config.sensitive_word_avoidance_dict,
-                "user_input_form": self.app_model_config.user_input_form_list,
-            }
+            override_model_configs = self.app_model_config.to_dict()
 
         introduction = ''
         system_instruction = ''

+ 10 - 7
api/core/generator/llm_generator.py

@@ -2,7 +2,7 @@ import logging
 
 from langchain.schema import OutputParserException
 
-from core.model_providers.error import LLMError
+from core.model_providers.error import LLMError, ProviderTokenNotInitError
 from core.model_providers.model_factory import ModelFactory
 from core.model_providers.models.entity.message import PromptMessage, MessageType
 from core.model_providers.models.entity.model_params import ModelKwargs
@@ -108,13 +108,16 @@ class LLMGenerator:
 
         _input = prompt.format_prompt(histories=histories)
 
-        model_instance = ModelFactory.get_text_generation_model(
-            tenant_id=tenant_id,
-            model_kwargs=ModelKwargs(
-                max_tokens=256,
-                temperature=0
+        try:
+            model_instance = ModelFactory.get_text_generation_model(
+                tenant_id=tenant_id,
+                model_kwargs=ModelKwargs(
+                    max_tokens=256,
+                    temperature=0
+                )
             )
-        )
+        except ProviderTokenNotInitError:
+            return []
 
         prompts = [PromptMessage(content=_input.to_string())]
 

+ 10 - 6
api/core/orchestrator_rule_parser.py

@@ -14,6 +14,7 @@ from core.callback_handler.main_chain_gather_callback_handler import MainChainGa
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
 from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
 from core.conversation_message_task import ConversationMessageTask
+from core.model_providers.error import ProviderTokenNotInitError
 from core.model_providers.model_factory import ModelFactory
 from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
@@ -78,13 +79,16 @@ class OrchestratorRuleParser:
                 elif planning_strategy == PlanningStrategy.ROUTER:
                     planning_strategy = PlanningStrategy.REACT_ROUTER
 
-            summary_model_instance = ModelFactory.get_text_generation_model(
-                tenant_id=self.tenant_id,
-                model_kwargs=ModelKwargs(
-                    temperature=0,
-                    max_tokens=500
+            try:
+                summary_model_instance = ModelFactory.get_text_generation_model(
+                    tenant_id=self.tenant_id,
+                    model_kwargs=ModelKwargs(
+                        temperature=0,
+                        max_tokens=500
+                    )
                 )
-            )
+            except ProviderTokenNotInitError as e:
+                summary_model_instance = None
 
             tools = self.to_tools(
                 tool_configs=tool_configs,

+ 9 - 1
api/core/prompt/output_parser/suggested_questions_after_answer.py

@@ -1,7 +1,10 @@
 import json
+import re
 from typing import Any
 
 from langchain.schema import BaseOutputParser
+
+from core.model_providers.error import LLMError
 from core.prompt.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
 
 
@@ -12,5 +15,10 @@ class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser):
 
     def parse(self, text: str) -> Any:
         json_string = text.strip()
-        json_obj = json.loads(json_string)
+        action_match = re.search(r".*(\[\".+\"\]).*", json_string, re.DOTALL)
+        if action_match is not None:
+            json_obj = json.loads(action_match.group(1).strip(), strict=False)
+        else:
+            raise LLMError("Could not parse LLM output: {text}")
+
         return json_obj

+ 46 - 4
api/models/model.py

@@ -108,7 +108,7 @@ class AppModelConfig(db.Model):
     def suggested_questions_after_answer_dict(self) -> dict:
         return json.loads(self.suggested_questions_after_answer) if self.suggested_questions_after_answer \
             else {"enabled": False}
-    
+
     @property
     def speech_to_text_dict(self) -> dict:
         return json.loads(self.speech_to_text) if self.speech_to_text \
@@ -148,6 +148,46 @@ class AppModelConfig(db.Model):
             "agent_mode": self.agent_mode_dict
         }
 
+    def from_model_config_dict(self, model_config: dict):
+        self.provider = ""
+        self.model_id = ""
+        self.configs = {}
+        self.opening_statement = model_config['opening_statement']
+        self.suggested_questions = json.dumps(model_config['suggested_questions'])
+        self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer'])
+        self.speech_to_text = json.dumps(model_config['speech_to_text']) \
+            if model_config.get('speech_to_text') else None
+        self.more_like_this = json.dumps(model_config['more_like_this'])
+        self.sensitive_word_avoidance = json.dumps(model_config['sensitive_word_avoidance'])
+        self.model = json.dumps(model_config['model'])
+        self.user_input_form = json.dumps(model_config['user_input_form'])
+        self.pre_prompt = model_config['pre_prompt']
+        self.agent_mode = json.dumps(model_config['agent_mode'])
+
+        return self
+
+    def copy(self):
+        new_app_model_config = AppModelConfig(
+            id=self.id,
+            app_id=self.app_id,
+            provider="",
+            model_id="",
+            configs={},
+            opening_statement=self.opening_statement,
+            suggested_questions=self.suggested_questions,
+            suggested_questions_after_answer=self.suggested_questions_after_answer,
+            speech_to_text=self.speech_to_text,
+            more_like_this=self.more_like_this,
+            sensitive_word_avoidance=self.sensitive_word_avoidance,
+            model=self.model,
+            user_input_form=self.user_input_form,
+            pre_prompt=self.pre_prompt,
+            agent_mode=self.agent_mode
+        )
+
+        return new_app_model_config
+
+
 class RecommendedApp(db.Model):
     __tablename__ = 'recommended_apps'
     __table_args__ = (
@@ -234,7 +274,8 @@ class Conversation(db.Model):
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
 
     messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all")
-    message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all")
+    message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select',
+                                          passive_deletes="all")
 
     is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
 
@@ -429,7 +470,7 @@ class Message(db.Model):
 
     @property
     def agent_thoughts(self):
-        return db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == self.id)\
+        return db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == self.id) \
             .order_by(MessageAgentThought.position.asc()).all()
 
 
@@ -557,7 +598,8 @@ class Site(db.Model):
 
     @property
     def app_base_url(self):
-        return (current_app.config['APP_WEB_URL'] if current_app.config['APP_WEB_URL'] else request.host_url.rstrip('/'))
+        return (
+            current_app.config['APP_WEB_URL'] if current_app.config['APP_WEB_URL'] else request.host_url.rstrip('/'))
 
 
 class ApiToken(db.Model):

+ 11 - 35
api/services/completion_service.py

@@ -63,26 +63,23 @@ class CompletionService:
                 raise ConversationCompletedError()
 
             if not conversation.override_model_configs:
-                app_model_config = db.session.query(AppModelConfig).get(conversation.app_model_config_id)
+                app_model_config = db.session.query(AppModelConfig).filter(
+                    AppModelConfig.id == conversation.app_model_config_id,
+                    AppModelConfig.app_id == app_model.id
+                ).first()
 
                 if not app_model_config:
                     raise AppModelConfigBrokenError()
             else:
                 conversation_override_model_configs = json.loads(conversation.override_model_configs)
+
                 app_model_config = AppModelConfig(
                     id=conversation.app_model_config_id,
                     app_id=app_model.id,
-                    provider="",
-                    model_id="",
-                    configs="",
-                    opening_statement=conversation_override_model_configs['opening_statement'],
-                    suggested_questions=json.dumps(conversation_override_model_configs['suggested_questions']),
-                    model=json.dumps(conversation_override_model_configs['model']),
-                    user_input_form=json.dumps(conversation_override_model_configs['user_input_form']),
-                    pre_prompt=conversation_override_model_configs['pre_prompt'],
-                    agent_mode=json.dumps(conversation_override_model_configs['agent_mode']),
                 )
 
+                app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
+
             if is_model_config_override:
                 # build new app model config
                 if 'model' not in args['model_config']:
@@ -99,19 +96,8 @@ class CompletionService:
                 app_model_config_model = app_model_config.model_dict
                 app_model_config_model['completion_params'] = completion_params
 
-                app_model_config = AppModelConfig(
-                    id=app_model_config.id,
-                    app_id=app_model.id,
-                    provider="",
-                    model_id="",
-                    configs="",
-                    opening_statement=app_model_config.opening_statement,
-                    suggested_questions=app_model_config.suggested_questions,
-                    model=json.dumps(app_model_config_model),
-                    user_input_form=app_model_config.user_input_form,
-                    pre_prompt=app_model_config.pre_prompt,
-                    agent_mode=app_model_config.agent_mode,
-                )
+                app_model_config = app_model_config.copy()
+                app_model_config.model = json.dumps(app_model_config_model)
         else:
             if app_model.app_model_config_id is None:
                 raise AppModelConfigBrokenError()
@@ -135,20 +121,10 @@ class CompletionService:
                 app_model_config = AppModelConfig(
                     id=app_model_config.id,
                     app_id=app_model.id,
-                    provider="",
-                    model_id="",
-                    configs="",
-                    opening_statement=model_config['opening_statement'],
-                    suggested_questions=json.dumps(model_config['suggested_questions']),
-                    suggested_questions_after_answer=json.dumps(model_config['suggested_questions_after_answer']),
-                    more_like_this=json.dumps(model_config['more_like_this']),
-                    sensitive_word_avoidance=json.dumps(model_config['sensitive_word_avoidance']),
-                    model=json.dumps(model_config['model']),
-                    user_input_form=json.dumps(model_config['user_input_form']),
-                    pre_prompt=model_config['pre_prompt'],
-                    agent_mode=json.dumps(model_config['agent_mode']),
                 )
 
+                app_model_config = app_model_config.from_model_config_dict(model_config)
+
         # clean input by app_model_config form rules
         inputs = cls.get_cleaned_inputs(inputs, app_model_config)
 

+ 33 - 9
api/services/message_service.py

@@ -1,3 +1,4 @@
+import json
 from typing import Optional, Union, List
 
 from core.completion import Completion
@@ -5,8 +6,10 @@ from core.generator.llm_generator import LLMGenerator
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from extensions.ext_database import db
 from models.account import Account
-from models.model import App, EndUser, Message, MessageFeedback
+from models.model import App, EndUser, Message, MessageFeedback, AppModelConfig
 from services.conversation_service import ConversationService
+from services.errors.app_model_config import AppModelConfigBrokenError
+from services.errors.conversation import ConversationNotExistsError, ConversationCompletedError
 from services.errors.message import FirstMessageNotExistsError, MessageNotExistsError, LastMessageNotExistsError, \
     SuggestedQuestionsAfterAnswerDisabledError
 
@@ -172,12 +175,6 @@ class MessageService:
         if not user:
             raise ValueError('user cannot be None')
 
-        app_model_config = app_model.app_model_config
-        suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict
-
-        if check_enabled and suggested_questions_after_answer.get("enabled", False) is False:
-            raise SuggestedQuestionsAfterAnswerDisabledError()
-
         message = cls.get_message(
             app_model=app_model,
             user=user,
@@ -190,10 +187,38 @@ class MessageService:
             user=user
         )
 
+        if not conversation:
+            raise ConversationNotExistsError()
+
+        if conversation.status != 'normal':
+            raise ConversationCompletedError()
+
+        if not conversation.override_model_configs:
+            app_model_config = db.session.query(AppModelConfig).filter(
+                AppModelConfig.id == conversation.app_model_config_id,
+                AppModelConfig.app_id == app_model.id
+            ).first()
+
+            if not app_model_config:
+                raise AppModelConfigBrokenError()
+        else:
+            conversation_override_model_configs = json.loads(conversation.override_model_configs)
+            app_model_config = AppModelConfig(
+                id=conversation.app_model_config_id,
+                app_id=app_model.id,
+            )
+
+            app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
+
+        suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict
+
+        if check_enabled and suggested_questions_after_answer.get("enabled", False) is False:
+            raise SuggestedQuestionsAfterAnswerDisabledError()
+
         # get memory of conversation (read-only)
         memory = Completion.get_memory_from_conversation(
             tenant_id=app_model.tenant_id,
-            app_model_config=app_model.app_model_config,
+            app_model_config=app_model_config,
             conversation=conversation,
             max_token_limit=3000,
             message_limit=3,
@@ -209,4 +234,3 @@ class MessageService:
         )
 
         return questions
-

+ 11 - 5
api/tasks/generate_conversation_summary_task.py

@@ -6,7 +6,7 @@ from celery import shared_task
 from werkzeug.exceptions import NotFound
 
 from core.generator.llm_generator import LLMGenerator
-from core.model_providers.error import LLMError
+from core.model_providers.error import LLMError, ProviderTokenNotInitError
 from extensions.ext_database import db
 from models.model import Conversation, Message
 
@@ -40,10 +40,16 @@ def generate_conversation_summary_task(conversation_id: str):
             conversation.summary = LLMGenerator.generate_conversation_summary(app_model.tenant_id, history_messages)
             db.session.add(conversation)
             db.session.commit()
-
-        end_at = time.perf_counter()
-        logging.info(click.style('Conversation summary generated: {} latency: {}'.format(conversation_id, end_at - start_at), fg='green'))
-    except LLMError:
+    except (LLMError, ProviderTokenNotInitError):
+        conversation.summary = '[No Summary]'
+        db.session.commit()
         pass
     except Exception as e:
+        conversation.summary = '[No Summary]'
+        db.session.commit()
         logging.exception(e)
+
+    end_at = time.perf_counter()
+    logging.info(
+        click.style('Conversation summary generated: {} latency: {}'.format(conversation_id, end_at - start_at),
+                    fg='green'))