Explorar el Código

feat: auto rule generator (#273)

John Wang hace 1 año
padre
commit
490858a4d5

+ 1 - 1
api/controllers/console/__init__.py

@@ -9,7 +9,7 @@ api = ExternalApi(bp)
 from . import setup, version, apikey, admin
 
 # Import app controllers
-from .app import app, site, completion, model_config, statistic, conversation, message
+from .app import app, site, completion, model_config, statistic, conversation, message, generator
 
 # Import auth controllers
 from .auth import login, oauth

+ 2 - 37
api/controllers/console/app/app.py

@@ -9,18 +9,13 @@ from werkzeug.exceptions import Unauthorized, Forbidden
 
 from constants.model_template import model_templates, demo_model_templates
 from controllers.console import api
-from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError, ProviderQuotaExceededError, \
-    CompletionRequestError, ProviderModelCurrentlyNotSupportError
+from controllers.console.app.error import AppNotFoundError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.generator.llm_generator import LLMGenerator
-from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
-    LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
 from events.app_event import app_was_created, app_was_deleted
 from libs.helper import TimestampField
 from extensions.ext_database import db
-from models.model import App, AppModelConfig, Site, InstalledApp
-from services.account_service import TenantService
+from models.model import App, AppModelConfig, Site
 from services.app_model_config_service import AppModelConfigService
 
 model_config_fields = {
@@ -478,35 +473,6 @@ class AppExport(Resource):
         pass
 
 
-class IntroductionGenerateApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def post(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument('prompt_template', type=str, required=True, location='json')
-        args = parser.parse_args()
-
-        account = current_user
-
-        try:
-            answer = LLMGenerator.generate_introduction(
-                account.current_tenant_id,
-                args['prompt_template']
-            )
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
-        except QuotaExceededError:
-            raise ProviderQuotaExceededError()
-        except ModelCurrentlyNotSupportError:
-            raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
-            raise CompletionRequestError(str(e))
-
-        return {'introduction': answer}
-
-
 api.add_resource(AppListApi, '/apps')
 api.add_resource(AppTemplateApi, '/app-templates')
 api.add_resource(AppApi, '/apps/<uuid:app_id>')
@@ -515,4 +481,3 @@ api.add_resource(AppNameApi, '/apps/<uuid:app_id>/name')
 api.add_resource(AppSiteStatus, '/apps/<uuid:app_id>/site-enable')
 api.add_resource(AppApiStatus, '/apps/<uuid:app_id>/api-enable')
 api.add_resource(AppRateLimit, '/apps/<uuid:app_id>/rate-limit')
-api.add_resource(IntroductionGenerateApi, '/introduction-generate')

+ 75 - 0
api/controllers/console/app/generator.py

@@ -0,0 +1,75 @@
+from flask_login import login_required, current_user
+from flask_restful import Resource, reqparse
+
+from controllers.console import api
+from controllers.console.app.error import ProviderNotInitializeError, ProviderQuotaExceededError, \
+    CompletionRequestError, ProviderModelCurrentlyNotSupportError
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from core.generator.llm_generator import LLMGenerator
+from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
+    LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
+
+
+class IntroductionGenerateApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('prompt_template', type=str, required=True, location='json')
+        args = parser.parse_args()
+
+        account = current_user
+
+        try:
+            answer = LLMGenerator.generate_introduction(
+                account.current_tenant_id,
+                args['prompt_template']
+            )
+        except ProviderTokenNotInitError:
+            raise ProviderNotInitializeError()
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
+                LLMRateLimitError, LLMAuthorizationError) as e:
+            raise CompletionRequestError(str(e))
+
+        return {'introduction': answer}
+
+
+class RuleGenerateApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('audiences', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('hoping_to_solve', type=str, required=True, nullable=False, location='json')
+        args = parser.parse_args()
+
+        account = current_user
+
+        try:
+            rules = LLMGenerator.generate_rule_config(
+                account.current_tenant_id,
+                args['audiences'],
+                args['hoping_to_solve']
+            )
+        except ProviderTokenNotInitError:
+            raise ProviderNotInitializeError()
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
+                LLMRateLimitError, LLMAuthorizationError) as e:
+            raise CompletionRequestError(str(e))
+
+        return rules
+
+
+api.add_resource(IntroductionGenerateApi, '/introduction-generate')
+api.add_resource(RuleGenerateApi, '/rule-generate')

+ 4 - 34
api/core/chain/llm_router_chain.py

@@ -11,6 +11,8 @@ from langchain.chains import LLMChain
 from langchain.prompts import BasePromptTemplate
 from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel
 
+from libs.json_in_md_parser import parse_and_check_json_markdown
+
 
 class Route(NamedTuple):
     destination: Optional[str]
@@ -82,42 +84,10 @@ class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
     next_inputs_type: Type = str
     next_inputs_inner_key: str = "input"
 
-    def parse_json_markdown(self, json_string: str) -> dict:
-        # Remove the triple backticks if present
-        json_string = json_string.strip()
-        start_index = json_string.find("```json")
-        end_index = json_string.find("```", start_index + len("```json"))
-
-        if start_index != -1 and end_index != -1:
-            extracted_content = json_string[start_index + len("```json"):end_index].strip()
-
-            # Parse the JSON string into a Python dictionary
-            parsed = json.loads(extracted_content)
-        elif json_string.startswith("{"):
-            # Parse the JSON string into a Python dictionary
-            parsed = json.loads(json_string)
-        else:
-            raise Exception("Could not find JSON block in the output.")
-
-        return parsed
-
-    def parse_and_check_json_markdown(self, text: str, expected_keys: List[str]) -> dict:
-        try:
-            json_obj = self.parse_json_markdown(text)
-        except json.JSONDecodeError as e:
-            raise OutputParserException(f"Got invalid JSON object. Error: {e}")
-        for key in expected_keys:
-            if key not in json_obj:
-                raise OutputParserException(
-                    f"Got invalid return object. Expected key `{key}` "
-                    f"to be present, but got {json_obj}"
-                )
-        return json_obj
-
     def parse(self, text: str) -> Dict[str, Any]:
         try:
             expected_keys = ["destination", "next_inputs"]
-            parsed = self.parse_and_check_json_markdown(text, expected_keys)
+            parsed = parse_and_check_json_markdown(text, expected_keys)
             if not isinstance(parsed["destination"], str):
                 raise ValueError("Expected 'destination' to be a string.")
             if not isinstance(parsed["next_inputs"], self.next_inputs_type):
@@ -135,5 +105,5 @@ class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
             return parsed
         except Exception as e:
             raise OutputParserException(
-                f"Parsing text\n{text}\n raised following error:\n{e}"
+                f"Parsing text\n{text}\n of llm router raised following error:\n{e}"
             )

+ 2 - 1
api/core/chain/multi_dataset_router_chain.py

@@ -23,7 +23,8 @@ think that revising it will ultimately lead to a better response from the langua
 model.
 
 << FORMATTING >>
-Return a markdown code snippet with a JSON object formatted to look like:
+Return a markdown code snippet with a JSON object formatted to look like, \
+no any other string out of markdown code snippet:
 ```json
 {{{{
     "destination": string \\ name of the prompt to use or "DEFAULT"

+ 44 - 0
api/core/generator/llm_generator.py

@@ -7,6 +7,7 @@ from core.constant import llm_constant
 from core.llm.llm_builder import LLMBuilder
 from core.llm.streamable_open_ai import StreamableOpenAI
 from core.llm.token_calculator import TokenCalculator
+from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
 
 from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
 from core.prompt.prompt_template import OutLinePromptTemplate
@@ -118,3 +119,46 @@ class LLMGenerator:
             questions = []
 
         return questions
+
+    @classmethod
+    def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict:
+        output_parser = RuleConfigGeneratorOutputParser()
+
+        prompt = OutLinePromptTemplate(
+            template=output_parser.get_format_instructions(),
+            input_variables=["audiences", "hoping_to_solve"],
+            partial_variables={
+                "variable": '{variable}',
+                "lanA": '{lanA}',
+                "lanB": '{lanB}',
+                "topic": '{topic}'
+            },
+            validate_template=False
+        )
+
+        _input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)
+
+        llm: StreamableOpenAI = LLMBuilder.to_llm(
+            tenant_id=tenant_id,
+            model_name=generate_base_model,
+            temperature=0,
+            max_tokens=512
+        )
+
+        if isinstance(llm, BaseChatModel):
+            query = [HumanMessage(content=_input.to_string())]
+        else:
+            query = _input.to_string()
+
+        try:
+            output = llm(query)
+            rule_config = output_parser.parse(output)
+        except Exception:
+            logging.exception("Error generating prompt")
+            rule_config = {
+                "prompt": "",
+                "variables": [],
+                "opening_statement": ""
+            }
+
+        return rule_config

+ 32 - 0
api/core/prompt/output_parser/rule_config_generator.py

@@ -0,0 +1,32 @@
+from typing import Any
+
+from langchain.schema import BaseOutputParser, OutputParserException
+from core.prompt.prompts import RULE_CONFIG_GENERATE_TEMPLATE
+from libs.json_in_md_parser import parse_and_check_json_markdown
+
+
+class RuleConfigGeneratorOutputParser(BaseOutputParser):
+
+    def get_format_instructions(self) -> str:
+        return RULE_CONFIG_GENERATE_TEMPLATE
+
+    def parse(self, text: str) -> Any:
+        try:
+            expected_keys = ["prompt", "variables", "opening_statement"]
+            parsed = parse_and_check_json_markdown(text, expected_keys)
+            if not isinstance(parsed["prompt"], str):
+                raise ValueError("Expected 'prompt' to be a string.")
+            if not isinstance(parsed["variables"], list):
+                raise ValueError(
+                    f"Expected 'variables' to be a list."
+                )
+            if not isinstance(parsed["opening_statement"], str):
+                raise ValueError(
+                    f"Expected 'opening_statement' to be a str."
+                )
+            return parsed
+        except Exception as e:
+            raise OutputParserException(
+                f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}"
+            )
+

+ 57 - 0
api/core/prompt/prompts.py

@@ -61,3 +61,60 @@ QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL = (
 QUERY_KEYWORD_EXTRACT_TEMPLATE = QueryKeywordExtractPrompt(
     QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL
 )
+
+RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
+the model prompt that best suits the input. 
+You will be provided with the prompt, variables, and an opening statement. 
+Only the content enclosed in double curly braces, such as {{variable}}, in the prompt can be considered as a variable; \
+otherwise, it cannot exist as a variable in the variables.
+If you believe revising the original input will result in a better response from the language model, you may \
+suggest revisions.
+
+<< FORMATTING >>
+Return a markdown code snippet with a JSON object formatted to look like, \
+no any other string out of markdown code snippet:
+```json
+{{{{
+    "prompt": string \\ generated prompt
+    "variables": list of string \\ variables
+    "opening_statement": string \\ an opening statement to guide users on how to ask questions with generated prompt \
+and fill in variables, with a welcome sentence, and keep TLDR.
+}}}}
+```
+
+<< EXAMPLES >>
+[EXAMPLE A]
+```json
+{
+  "prompt": "Write a letter about love",
+  "variables": [],
+  "opening_statement": "Hi! I'm your love letter writer AI."
+}
+```
+
+[EXAMPLE B]
+```json
+{
+  "prompt": "Translate from {{lanA}} to {{lanB}}",
+  "variables": ["lanA", "lanB"],
+  "opening_statement": "Welcome to use translate app"
+}
+```
+
+[EXAMPLE C]
+```json
+{
+  "prompt": "Write a story about {{topic}}",
+  "variables": ["topic"],
+  "opening_statement": "I'm your story writer"
+}
+```
+
+<< MY INTENDED AUDIENCES >>
+{audiences}
+
+<< HOPING TO SOLVE >>
+{hoping_to_solve}
+
+<< OUTPUT >>
+"""

+ 38 - 0
api/libs/json_in_md_parser.py

@@ -0,0 +1,38 @@
+import json
+from typing import List
+
+from langchain.schema import OutputParserException
+
+
+def parse_json_markdown(json_string: str) -> dict:
+    # Remove the triple backticks if present
+    json_string = json_string.strip()
+    start_index = json_string.find("```json")
+    end_index = json_string.find("```", start_index + len("```json"))
+
+    if start_index != -1 and end_index != -1:
+        extracted_content = json_string[start_index + len("```json"):end_index].strip()
+
+        # Parse the JSON string into a Python dictionary
+        parsed = json.loads(extracted_content)
+    elif json_string.startswith("{"):
+        # Parse the JSON string into a Python dictionary
+        parsed = json.loads(json_string)
+    else:
+        raise Exception("Could not find JSON block in the output.")
+
+    return parsed
+
+
+def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
+    try:
+        json_obj = parse_json_markdown(text)
+    except json.JSONDecodeError as e:
+        raise OutputParserException(f"Got invalid JSON object. Error: {e}")
+    for key in expected_keys:
+        if key not in json_obj:
+            raise OutputParserException(
+                f"Got invalid return object. Expected key `{key}` "
+                f"to be present, but got {json_obj}"
+            )
+    return json_obj