Prechádzať zdrojové kódy

fix:Fix a bug that returns null when the passed path is a file. (#12775)

Co-authored-by: 刘江波 <jiangbo721@163.com>
jiangbo721 3 mesiacov pred
rodič
commit
2f41bd495d

+ 40 - 42
api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py

@@ -11,19 +11,21 @@ class GitlabFilesTool(BuiltinTool):
     def _invoke(
         self, user_id: str, tool_parameters: dict[str, Any]
     ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
-        project = tool_parameters.get("project", "")
         repository = tool_parameters.get("repository", "")
+        project = tool_parameters.get("project", "")
         branch = tool_parameters.get("branch", "")
         path = tool_parameters.get("path", "")
+        file_path = tool_parameters.get("file_path", "")
 
-        if not project and not repository:
-            return self.create_text_message("Either project or repository is required")
+        if not repository and not project:
+            return self.create_text_message("Either repository or project is required")
         if not branch:
             return self.create_text_message("Branch is required")
-        if not path:
-            return self.create_text_message("Path is required")
+        if not path and not file_path:
+            return self.create_text_message("Either path or file_path is required")
 
         access_token = self.runtime.credentials.get("access_tokens")
+        headers = {"PRIVATE-TOKEN": access_token}
         site_url = self.runtime.credentials.get("site_url")
 
         if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"):
@@ -31,33 +33,45 @@ class GitlabFilesTool(BuiltinTool):
         if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"):
             site_url = "https://gitlab.com"
 
-        # Get file content
         if repository:
-            result = self.fetch_files(site_url, access_token, repository, branch, path, is_repository=True)
+            # URL encode the repository path
+            identifier = urllib.parse.quote(repository, safe="")
         else:
-            result = self.fetch_files(site_url, access_token, project, branch, path, is_repository=False)
+            identifier = self.get_project_id(site_url, access_token, project)
+            if not identifier:
+                raise Exception(f"Project '{project}' not found.)")
 
-        return [self.create_json_message(item) for item in result]
+        # Get file content
+        if path:
+            results = self.fetch_files(site_url, headers, identifier, branch, path)
+            return [self.create_json_message(item) for item in results]
+        else:
+            result = self.fetch_file(site_url, headers, identifier, branch, file_path)
+            return [self.create_json_message(result)]
+
+    @staticmethod
+    def fetch_file(
+        site_url: str,
+        headers: dict[str, str],
+        identifier: str,
+        branch: str,
+        path: str,
+    ) -> dict[str, Any]:
+        encoded_file_path = urllib.parse.quote(path, safe="")
+        file_url = f"{site_url}/api/v4/projects/{identifier}/repository/files/{encoded_file_path}/raw?ref={branch}"
+
+        file_response = requests.get(file_url, headers=headers)
+        file_response.raise_for_status()
+        file_content = file_response.text
+        return {"path": path, "branch": branch, "content": file_content}
 
     def fetch_files(
-        self, site_url: str, access_token: str, identifier: str, branch: str, path: str, is_repository: bool
+        self, site_url: str, headers: dict[str, str], identifier: str, branch: str, path: str
     ) -> list[dict[str, Any]]:
-        domain = site_url
-        headers = {"PRIVATE-TOKEN": access_token}
         results = []
 
         try:
-            if is_repository:
-                # URL encode the repository path
-                encoded_identifier = urllib.parse.quote(identifier, safe="")
-                tree_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/tree?path={path}&ref={branch}"
-            else:
-                # Get project ID from project name
-                project_id = self.get_project_id(site_url, access_token, identifier)
-                if not project_id:
-                    return self.create_text_message(f"Project '{identifier}' not found.")
-                tree_url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}"
-
+            tree_url = f"{site_url}/api/v4/projects/{identifier}/repository/tree?path={path}&ref={branch}"
             response = requests.get(tree_url, headers=headers)
             response.raise_for_status()
             items = response.json()
@@ -65,26 +79,10 @@ class GitlabFilesTool(BuiltinTool):
             for item in items:
                 item_path = item["path"]
                 if item["type"] == "tree":  # It's a directory
-                    results.extend(
-                        self.fetch_files(site_url, access_token, identifier, branch, item_path, is_repository)
-                    )
+                    results.extend(self.fetch_files(site_url, headers, identifier, branch, item_path))
                 else:  # It's a file
-                    encoded_item_path = urllib.parse.quote(item_path, safe="")
-                    if is_repository:
-                        file_url = (
-                            f"{domain}/api/v4/projects/{encoded_identifier}/repository/files"
-                            f"/{encoded_item_path}/raw?ref={branch}"
-                        )
-                    else:
-                        file_url = (
-                            f"{domain}/api/v4/projects/{project_id}/repository/files"
-                            f"{encoded_item_path}/raw?ref={branch}"
-                        )
-
-                    file_response = requests.get(file_url, headers=headers)
-                    file_response.raise_for_status()
-                    file_content = file_response.text
-                    results.append({"path": item_path, "branch": branch, "content": file_content})
+                    result = self.fetch_file(site_url, headers, identifier, branch, item_path)
+                    results.append(result)
         except requests.RequestException as e:
             print(f"Error fetching data from GitLab: {e}")
 

+ 12 - 3
api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml

@@ -29,7 +29,7 @@ parameters:
       zh_Hans: 项目
     human_description:
       en_US: project
-      zh_Hans: 项目
+      zh_Hans: 项目(和仓库路径二选一,都填写以仓库路径优先)
     llm_description: Project for GitLab
     form: llm
   - name: branch
@@ -45,12 +45,21 @@ parameters:
     form: llm
   - name: path
     type: string
-    required: true
     label:
       en_US: path
-      zh_Hans: 文件路径
+      zh_Hans: 文件
     human_description:
       en_US: path
+      zh_Hans: 文件夹
+    llm_description: Dir path for GitLab
+    form: llm
+  - name: file_path
+    type: string
+    label:
+      en_US: file_path
       zh_Hans: 文件路径
+    human_description:
+      en_US: file_path
+      zh_Hans: 文件路径(和文件夹二选一,都填写以文件夹优先)
     llm_description: File path for GitLab
     form: llm