Sfoglia il codice sorgente

Chore: optimize the code of PromptTransform (#16143)

Yongtao Huang 1 mese fa
parent
commit
d339403e89

+ 8 - 8
api/core/prompt/simple_prompt_transform.py

@@ -93,7 +93,7 @@ class SimplePromptTransform(PromptTransform):
 
         return prompt_messages, stops
 
-    def get_prompt_str_and_rules(
+    def _get_prompt_str_and_rules(
         self,
         app_mode: AppMode,
         model_config: ModelConfigWithCredentialsEntity,
@@ -184,7 +184,7 @@ class SimplePromptTransform(PromptTransform):
         prompt_messages: list[PromptMessage] = []
 
         # get prompt
-        prompt, _ = self.get_prompt_str_and_rules(
+        prompt, _ = self._get_prompt_str_and_rules(
             app_mode=app_mode,
             model_config=model_config,
             pre_prompt=pre_prompt,
@@ -209,9 +209,9 @@ class SimplePromptTransform(PromptTransform):
             )
 
         if query:
-            prompt_messages.append(self.get_last_user_message(query, files, image_detail_config))
+            prompt_messages.append(self._get_last_user_message(query, files, image_detail_config))
         else:
-            prompt_messages.append(self.get_last_user_message(prompt, files, image_detail_config))
+            prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config))
 
         return prompt_messages, None
 
@@ -228,7 +228,7 @@ class SimplePromptTransform(PromptTransform):
         image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
     ) -> tuple[list[PromptMessage], Optional[list[str]]]:
         # get prompt
-        prompt, prompt_rules = self.get_prompt_str_and_rules(
+        prompt, prompt_rules = self._get_prompt_str_and_rules(
             app_mode=app_mode,
             model_config=model_config,
             pre_prompt=pre_prompt,
@@ -254,7 +254,7 @@ class SimplePromptTransform(PromptTransform):
             )
 
             # get prompt
-            prompt, prompt_rules = self.get_prompt_str_and_rules(
+            prompt, prompt_rules = self._get_prompt_str_and_rules(
                 app_mode=app_mode,
                 model_config=model_config,
                 pre_prompt=pre_prompt,
@@ -268,9 +268,9 @@ class SimplePromptTransform(PromptTransform):
         if stops is not None and len(stops) == 0:
             stops = None
 
-        return [self.get_last_user_message(prompt, files, image_detail_config)], stops
+        return [self._get_last_user_message(prompt, files, image_detail_config)], stops
 
-    def get_last_user_message(
+    def _get_last_user_message(
         self,
         prompt: str,
         files: Sequence["File"],

+ 0 - 2
api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py

@@ -64,12 +64,10 @@ def test_get_prompt():
     transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
     result = transform.get_prompt()
 
-    assert len(result) <= max_token_limit
     assert len(result) == 4
 
     max_token_limit = 20
     transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
     result = transform.get_prompt()
 
-    assert len(result) <= max_token_limit
     assert len(result) == 12

+ 0 - 1
api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py

@@ -84,7 +84,6 @@ def test_get_baichuan_completion_app_prompt_template_with_pcq():
         query_in_prompt=True,
         with_memory_prompt=False,
     )
-    print(prompt_template["prompt_template"].template)
     prompt_rules = prompt_template["prompt_rules"]
     assert prompt_template["prompt_template"].template == (
         prompt_rules["context_prompt"] + pre_prompt + "\n" + prompt_rules["query_prompt"]