Browse Source

fix: add hf task field (#976)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
takatost 1 year ago
parent
commit
071e7800a0

+ 53 - 0
web/app/components/header/account-setting/model-page/configs/huggingface_hub.tsx

@@ -38,6 +38,7 @@ const config: ProviderConfig = {
     defaultValue: {
       model_type: 'text-generation',
       huggingfacehub_api_type: 'hosted_inference_api',
+      task_type: 'text-generation',
     },
     validateKeys: (v?: FormValue) => {
       if (v?.huggingfacehub_api_type === 'hosted_inference_api') {
@@ -51,10 +52,36 @@ const config: ProviderConfig = {
           'huggingfacehub_api_token',
           'model_name',
           'huggingfacehub_endpoint_url',
+          'task_type',
         ]
       }
       return []
     },
+    filterValue: (v?: FormValue) => {
+      let filteredKeys: string[] = []
+      if (v?.huggingfacehub_api_type === 'hosted_inference_api') {
+        filteredKeys = [
+          'huggingfacehub_api_type',
+          'huggingfacehub_api_token',
+          'model_name',
+          'model_type',
+        ]
+      }
+      if (v?.huggingfacehub_api_type === 'inference_endpoints') {
+        filteredKeys = [
+          'huggingfacehub_api_type',
+          'huggingfacehub_api_token',
+          'model_name',
+          'huggingfacehub_endpoint_url',
+          'task_type',
+          'model_type',
+        ]
+      }
+      return filteredKeys.reduce((prev: FormValue, next: string) => {
+        prev[next] = v?.[next] || ''
+        return prev
+      }, {})
+    },
     fields: [
       {
         type: 'radio',
@@ -120,6 +147,32 @@ const config: ProviderConfig = {
           'zh-Hans': '在此输入您的端点 URL',
         },
       },
+      {
+        hidden: (value?: FormValue) => value?.huggingfacehub_api_type === 'hosted_inference_api',
+        type: 'radio',
+        key: 'task_type',
+        required: true,
+        label: {
+          'en': 'Task',
+          'zh-Hans': 'Task',
+        },
+        options: [
+          {
+            key: 'text2text-generation',
+            label: {
+              'en': 'Text-to-Text Generation',
+              'zh-Hans': 'Text-to-Text Generation',
+            },
+          },
+          {
+            key: 'text-generation',
+            label: {
+              'en': 'Text Generation',
+              'zh-Hans': 'Text Generation',
+            },
+          },
+        ],
+      },
     ],
   },
 }

+ 1 - 0
web/app/components/header/account-setting/model-page/declarations.ts

@@ -91,6 +91,7 @@ export type ProviderConfigModal = {
   icon: ReactElement
   defaultValue?: FormValue
   validateKeys?: string[] | ((v?: FormValue) => string[])
+  filterValue?: (v?: FormValue) => FormValue
   fields: Field[]
   link: {
     href: string

+ 3 - 2
web/app/components/header/account-setting/model-page/index.tsx

@@ -124,8 +124,9 @@ const ModelPage = () => {
     updateModelList(ModelType.embeddings)
     mutateProviders()
   }
-  const handleSave = async (v?: FormValue) => {
-    if (v && modelModalConfig) {
+  const handleSave = async (originValue?: FormValue) => {
+    if (originValue && modelModalConfig) {
+      const v = modelModalConfig.filterValue ? modelModalConfig.filterValue(originValue) : originValue
       let body, url
       if (ConfigurableProviders.includes(modelModalConfig.key)) {
         const { model_name, model_type, ...config } = v

+ 1 - 1
web/app/components/header/account-setting/model-page/model-modal/Form.tsx

@@ -68,7 +68,7 @@ const Form: FC<FormProps> = ({
           return true
         },
         run: () => {
-          return validateModelProviderFn(modelModal!.key, v)
+          return validateModelProviderFn(modelModal!.key, modelModal?.filterValue ? modelModal?.filterValue(v) : v)
         },
       })
     }