瀏覽代碼

Update Gitlab query field, add query by path (#8244)

Leo.Wang 7 月之前
父節點
當前提交
75c1a82556

+ 90 - 74
api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py

@@ -1,4 +1,5 @@
 import json
+import urllib.parse
 from datetime import datetime, timedelta
 from typing import Any, Union
 
@@ -13,13 +14,14 @@ class GitlabCommitsTool(BuiltinTool):
         self, user_id: str, tool_parameters: dict[str, Any]
     ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         project = tool_parameters.get("project", "")
+        repository = tool_parameters.get("repository", "")
         employee = tool_parameters.get("employee", "")
         start_time = tool_parameters.get("start_time", "")
         end_time = tool_parameters.get("end_time", "")
         change_type = tool_parameters.get("change_type", "all")
 
-        if not project:
-            return self.create_text_message("Project is required")
+        if not project and not repository:
+            return self.create_text_message("Either project or repository is required")
 
         if not start_time:
             start_time = (datetime.utcnow() - timedelta(days=1)).isoformat()
@@ -35,91 +37,105 @@ class GitlabCommitsTool(BuiltinTool):
             site_url = "https://gitlab.com"
 
         # Get commit content
-        result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time, change_type)
+        if repository:
+            result = self.fetch_commits(
+                site_url, access_token, repository, employee, start_time, end_time, change_type, is_repository=True
+            )
+        else:
+            result = self.fetch_commits(
+                site_url, access_token, project, employee, start_time, end_time, change_type, is_repository=False
+            )
 
         return [self.create_json_message(item) for item in result]
 
-    def fetch(
+    def fetch_commits(
         self,
-        user_id: str,
         site_url: str,
         access_token: str,
-        project: str,
-        employee: str = None,
-        start_time: str = "",
-        end_time: str = "",
-        change_type: str = "",
+        identifier: str,
+        employee: str,
+        start_time: str,
+        end_time: str,
+        change_type: str,
+        is_repository: bool,
     ) -> list[dict[str, Any]]:
         domain = site_url
         headers = {"PRIVATE-TOKEN": access_token}
         results = []
 
         try:
-            # Get all of projects
-            url = f"{domain}/api/v4/projects"
-            response = requests.get(url, headers=headers)
-            response.raise_for_status()
-            projects = response.json()
-
-            filtered_projects = [p for p in projects if project == "*" or p["name"] == project]
-
-            for project in filtered_projects:
-                project_id = project["id"]
-                project_name = project["name"]
-                print(f"Project: {project_name}")
-
-                # Get all of project commits
-                commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits"
-                params = {"since": start_time, "until": end_time}
-                if employee:
-                    params["author"] = employee
-
-                commits_response = requests.get(commits_url, headers=headers, params=params)
-                commits_response.raise_for_status()
-                commits = commits_response.json()
-
-                for commit in commits:
-                    commit_sha = commit["id"]
-                    author_name = commit["author_name"]
-
+            if is_repository:
+                # URL encode the repository path
+                encoded_identifier = urllib.parse.quote(identifier, safe="")
+                commits_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/commits"
+            else:
+                # Get all projects
+                url = f"{domain}/api/v4/projects"
+                response = requests.get(url, headers=headers)
+                response.raise_for_status()
+                projects = response.json()
+
+                filtered_projects = [p for p in projects if identifier == "*" or p["name"] == identifier]
+
+                for project in filtered_projects:
+                    project_id = project["id"]
+                    project_name = project["name"]
+                    print(f"Project: {project_name}")
+
+                    commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits"
+
+            params = {"since": start_time, "until": end_time}
+            if employee:
+                params["author"] = employee
+
+            commits_response = requests.get(commits_url, headers=headers, params=params)
+            commits_response.raise_for_status()
+            commits = commits_response.json()
+
+            for commit in commits:
+                commit_sha = commit["id"]
+                author_name = commit["author_name"]
+
+                if is_repository:
+                    diff_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/commits/{commit_sha}/diff"
+                else:
                     diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff"
-                    diff_response = requests.get(diff_url, headers=headers)
-                    diff_response.raise_for_status()
-                    diffs = diff_response.json()
-
-                    for diff in diffs:
-                        # Calculate code lines of changed
-                        added_lines = diff["diff"].count("\n+")
-                        removed_lines = diff["diff"].count("\n-")
-                        total_changes = added_lines + removed_lines
-
-                        if change_type == "new":
-                            if added_lines > 1:
-                                final_code = "".join(
-                                    [
-                                        line[1:]
-                                        for line in diff["diff"].split("\n")
-                                        if line.startswith("+") and not line.startswith("+++")
-                                    ]
-                                )
-                                results.append(
-                                    {"commit_sha": commit_sha, "author_name": author_name, "diff": final_code}
-                                )
-                        else:
-                            if total_changes > 1:
-                                final_code = "".join(
-                                    [
-                                        line[1:]
-                                        for line in diff["diff"].split("\n")
-                                        if (line.startswith("+") or line.startswith("-"))
-                                        and not line.startswith("+++")
-                                        and not line.startswith("---")
-                                    ]
-                                )
-                                final_code_escaped = json.dumps(final_code)[1:-1]  # Escape the final code
-                                results.append(
-                                    {"commit_sha": commit_sha, "author_name": author_name, "diff": final_code_escaped}
-                                )
+
+                diff_response = requests.get(diff_url, headers=headers)
+                diff_response.raise_for_status()
+                diffs = diff_response.json()
+
+                for diff in diffs:
+                    # Calculate code lines of changes
+                    added_lines = diff["diff"].count("\n+")
+                    removed_lines = diff["diff"].count("\n-")
+                    total_changes = added_lines + removed_lines
+
+                    if change_type == "new":
+                        if added_lines > 1:
+                            final_code = "".join(
+                                [
+                                    line[1:]
+                                    for line in diff["diff"].split("\n")
+                                    if line.startswith("+") and not line.startswith("+++")
+                                ]
+                            )
+                            results.append({"commit_sha": commit_sha, "author_name": author_name, "diff": final_code})
+                    else:
+                        if total_changes > 1:
+                            final_code = "".join(
+                                [
+                                    line[1:]
+                                    for line in diff["diff"].split("\n")
+                                    if (line.startswith("+") or line.startswith("-"))
+                                    and not line.startswith("+++")
+                                    and not line.startswith("---")
+                                ]
+                            )
+                            final_code_escaped = json.dumps(final_code)[1:-1]  # Escape the final code
+                            results.append(
+                                {"commit_sha": commit_sha, "author_name": author_name, "diff": final_code_escaped}
+                            )
         except requests.RequestException as e:
             print(f"Error fetching data from GitLab: {e}")
 

+ 12 - 1
api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml

@@ -21,9 +21,20 @@ parameters:
       zh_Hans: 员工用户名
     llm_description: User name for GitLab
     form: llm
+  - name: repository
+    type: string
+    required: false
+    label:
+      en_US: repository
+      zh_Hans: 仓库路径
+    human_description:
+      en_US: repository
+      zh_Hans: 仓库路径,以namespace/project_name的形式。
+    llm_description: Repository path for GitLab, like namespace/project_name.
+    form: llm
   - name: project
     type: string
-    required: true
+    required: false
     label:
       en_US: project
       zh_Hans: 项目名

+ 47 - 37
api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py

@@ -1,3 +1,4 @@
+import urllib.parse
 from typing import Any, Union
 
 import requests
@@ -11,14 +12,14 @@ class GitlabFilesTool(BuiltinTool):
         self, user_id: str, tool_parameters: dict[str, Any]
     ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         project = tool_parameters.get("project", "")
+        repository = tool_parameters.get("repository", "")
         branch = tool_parameters.get("branch", "")
         path = tool_parameters.get("path", "")
 
-        if not project:
-            return self.create_text_message("Project is required")
+        if not project and not repository:
+            return self.create_text_message("Either project or repository is required")
         if not branch:
             return self.create_text_message("Branch is required")
-
         if not path:
             return self.create_text_message("Path is required")
 
@@ -30,56 +31,51 @@ class GitlabFilesTool(BuiltinTool):
         if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"):
             site_url = "https://gitlab.com"
 
-        # Get project ID from project name
-        project_id = self.get_project_id(site_url, access_token, project)
-        if not project_id:
-            return self.create_text_message(f"Project '{project}' not found.")
-
-        # Get commit content
-        result = self.fetch(user_id, project_id, site_url, access_token, branch, path)
+        # Get file content
+        if repository:
+            result = self.fetch_files(site_url, access_token, repository, branch, path, is_repository=True)
+        else:
+            result = self.fetch_files(site_url, access_token, project, branch, path, is_repository=False)
 
         return [self.create_json_message(item) for item in result]
 
-    def extract_project_name_and_path(self, path: str) -> tuple[str, str]:
-        parts = path.split("/", 1)
-        if len(parts) < 2:
-            return None, None
-        return parts[0], parts[1]
-
-    def get_project_id(self, site_url: str, access_token: str, project_name: str) -> Union[str, None]:
-        headers = {"PRIVATE-TOKEN": access_token}
-        try:
-            url = f"{site_url}/api/v4/projects?search={project_name}"
-            response = requests.get(url, headers=headers)
-            response.raise_for_status()
-            projects = response.json()
-            for project in projects:
-                if project["name"] == project_name:
-                    return project["id"]
-        except requests.RequestException as e:
-            print(f"Error fetching project ID from GitLab: {e}")
-        return None
-
-    def fetch(
-        self, user_id: str, project_id: str, site_url: str, access_token: str, branch: str, path: str = None
+    def fetch_files(
+        self, site_url: str, access_token: str, identifier: str, branch: str, path: str, is_repository: bool
     ) -> list[dict[str, Any]]:
         domain = site_url
         headers = {"PRIVATE-TOKEN": access_token}
         results = []
 
         try:
-            # List files and directories in the given path
-            url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}"
-            response = requests.get(url, headers=headers)
+            if is_repository:
+                # URL encode the repository path
+                encoded_identifier = urllib.parse.quote(identifier, safe="")
+                tree_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/tree?path={path}&ref={branch}"
+            else:
+                # Get project ID from project name
+                project_id = self.get_project_id(site_url, access_token, identifier)
+                if not project_id:
+                    return self.create_text_message(f"Project '{identifier}' not found.")
+                tree_url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}"
+
+            response = requests.get(tree_url, headers=headers)
             response.raise_for_status()
             items = response.json()
 
             for item in items:
                 item_path = item["path"]
                 if item["type"] == "tree":  # It's a directory
-                    results.extend(self.fetch(project_id, site_url, access_token, branch, item_path))
+                    results.extend(
+                        self.fetch_files(site_url, access_token, identifier, branch, item_path, is_repository)
+                    )
                 else:  # It's a file
-                    file_url = f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}"
+                    if is_repository:
+                        file_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/files/{item_path}/raw?ref={branch}"
+                    else:
+                        file_url = (
+                            f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}"
+                        )
+
                     file_response = requests.get(file_url, headers=headers)
                     file_response.raise_for_status()
                     file_content = file_response.text
@@ -88,3 +84,17 @@ class GitlabFilesTool(BuiltinTool):
             print(f"Error fetching data from GitLab: {e}")
 
         return results
+
+    def get_project_id(self, site_url: str, access_token: str, project_name: str) -> Union[str, None]:
+        headers = {"PRIVATE-TOKEN": access_token}
+        try:
+            url = f"{site_url}/api/v4/projects?search={project_name}"
+            response = requests.get(url, headers=headers)
+            response.raise_for_status()
+            projects = response.json()
+            for project in projects:
+                if project["name"] == project_name:
+                    return project["id"]
+        except requests.RequestException as e:
+            print(f"Error fetching project ID from GitLab: {e}")
+        return None

+ 12 - 1
api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml

@@ -10,9 +10,20 @@ description:
     zh_Hans: 一个用于查询 GitLab 文件的工具,输入的内容应该是分支和一个已存在文件或者文件夹路径。
   llm: A tool for query GitLab files, Input should be a exists file or directory path.
 parameters:
+  - name: repository
+    type: string
+    required: false
+    label:
+      en_US: repository
+      zh_Hans: 仓库路径
+    human_description:
+      en_US: repository
+      zh_Hans: 仓库路径,以namespace/project_name的形式。
+    llm_description: Repository path for GitLab, like namespace/project_name.
+    form: llm
   - name: project
     type: string
-    required: true
+    required: false
     label:
       en_US: project
       zh_Hans: 项目