Parcourir la source

feat: add new params to huggingface hub llm. (#2014)

Garfield Dai il y a 1 an
Parent
commit
cb7be3767c

+ 49 - 1
api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py

@@ -134,7 +134,55 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
             precision=0,
         )
 
-        return [temperature_rule, top_k_rule, top_p_rule]
+        max_new_tokens = ParameterRule(
+            name='max_new_tokens',
+            label={
+                'en_US': 'Max New Tokens',
+                'zh_Hans': '最大新标记',
+            },
+            type='int',
+            help={
+                'en_US': 'Maximum number of generated tokens.',
+                'zh_Hans': '生成的标记的最大数量。',
+            },
+            required=False,
+            default=20,
+            min=1,
+            max=4096,
+            precision=0,
+        )
+
+        seed = ParameterRule(
+            name='seed',
+            label={
+                'en_US': 'Random sampling seed',
+                'zh_Hans': '随机采样种子',
+            },
+            type='int',
+            help={
+                'en_US': 'Random sampling seed.',
+                'zh_Hans': '随机采样种子。',
+            },
+            required=False,
+            precision=0,
+        )
+
+        repetition_penalty = ParameterRule(
+            name='repetition_penalty',
+            label={
+                'en_US': 'Repetition Penalty',
+                'zh_Hans': '重复惩罚',
+            },
+            type='float',
+            help={
+                'en_US': 'The parameter for repetition penalty. 1.0 means no penalty.',
+                'zh_Hans': '重复惩罚的参数。1.0 表示没有惩罚。',
+            },
+            required=False,
+            precision=1,
+        )
+
+        return [temperature_rule, top_k_rule, top_p_rule, max_new_tokens, seed, repetition_penalty]
 
     def _handle_generate_stream_response(self,
                                          model: str,