Selaa lähdekoodia

Feat/tools/gitlab (#10407)

Leo.Wang 5 kuukautta sitten
vanhempi
commit
c9f785e00f

+ 3 - 4
api/core/rag/extractor/word_extractor.py

@@ -28,7 +28,6 @@ logger = logging.getLogger(__name__)
 class WordExtractor(BaseExtractor):
     """Load docx files.
 
-
     Args:
         file_path: Path to the file to load.
     """
@@ -51,9 +50,9 @@ class WordExtractor(BaseExtractor):
 
             self.web_path = self.file_path
             # TODO: use a better way to handle the file
-            self.temp_file = tempfile.NamedTemporaryFile()  # noqa: SIM115
-            self.temp_file.write(r.content)
-            self.file_path = self.temp_file.name
+            with tempfile.NamedTemporaryFile(delete=False) as self.temp_file:
+                self.temp_file.write(r.content)
+                self.file_path = self.temp_file.name
         elif not os.path.isfile(self.file_path):
             raise ValueError(f"File path {self.file_path} is not a valid file or url")
 

+ 29 - 37
api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py

@@ -13,15 +13,15 @@ class GitlabCommitsTool(BuiltinTool):
     def _invoke(
         self, user_id: str, tool_parameters: dict[str, Any]
     ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
-        project = tool_parameters.get("project", "")
+        branch = tool_parameters.get("branch", "")
         repository = tool_parameters.get("repository", "")
         employee = tool_parameters.get("employee", "")
         start_time = tool_parameters.get("start_time", "")
         end_time = tool_parameters.get("end_time", "")
         change_type = tool_parameters.get("change_type", "all")
 
-        if not project and not repository:
-            return self.create_text_message("Either project or repository is required")
+        if not repository:
+            return self.create_text_message("Either repository is required")
 
         if not start_time:
             start_time = (datetime.utcnow() - timedelta(days=1)).isoformat()
@@ -37,14 +37,9 @@ class GitlabCommitsTool(BuiltinTool):
             site_url = "https://gitlab.com"
 
         # Get commit content
-        if repository:
-            result = self.fetch_commits(
-                site_url, access_token, repository, employee, start_time, end_time, change_type, is_repository=True
-            )
-        else:
-            result = self.fetch_commits(
-                site_url, access_token, project, employee, start_time, end_time, change_type, is_repository=False
-            )
+        result = self.fetch_commits(
+            site_url, access_token, repository, branch, employee, start_time, end_time, change_type, is_repository=True
+        )
 
         return [self.create_json_message(item) for item in result]
 
@@ -52,7 +47,8 @@ class GitlabCommitsTool(BuiltinTool):
         self,
         site_url: str,
         access_token: str,
-        identifier: str,
+        repository: str,
+        branch: str,
         employee: str,
         start_time: str,
         end_time: str,
@@ -64,27 +60,14 @@ class GitlabCommitsTool(BuiltinTool):
         results = []
 
         try:
-            if is_repository:
-                # URL encode the repository path
-                encoded_identifier = urllib.parse.quote(identifier, safe="")
-                commits_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/commits"
-            else:
-                # Get all projects
-                url = f"{domain}/api/v4/projects"
-                response = requests.get(url, headers=headers)
-                response.raise_for_status()
-                projects = response.json()
-
-                filtered_projects = [p for p in projects if identifier == "*" or p["name"] == identifier]
-
-                for project in filtered_projects:
-                    project_id = project["id"]
-                    project_name = project["name"]
-                    print(f"Project: {project_name}")
-
-                    commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits"
+            # URL encode the repository path
+            encoded_repository = urllib.parse.quote(repository, safe="")
+            commits_url = f"{domain}/api/v4/projects/{encoded_repository}/repository/commits"
 
+            # Fetch commits for the repository
             params = {"since": start_time, "until": end_time}
+            if branch:
+                params["ref_name"] = branch
             if employee:
                 params["author"] = employee
 
@@ -96,10 +79,7 @@ class GitlabCommitsTool(BuiltinTool):
                 commit_sha = commit["id"]
                 author_name = commit["author_name"]
 
-                if is_repository:
-                    diff_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/commits/{commit_sha}/diff"
-                else:
-                    diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff"
+                diff_url = f"{domain}/api/v4/projects/{encoded_repository}/repository/commits/{commit_sha}/diff"
 
                 diff_response = requests.get(diff_url, headers=headers)
                 diff_response.raise_for_status()
@@ -120,7 +100,14 @@ class GitlabCommitsTool(BuiltinTool):
                                     if line.startswith("+") and not line.startswith("+++")
                                 ]
                             )
-                            results.append({"commit_sha": commit_sha, "author_name": author_name, "diff": final_code})
+                            results.append(
+                                {
+                                    "diff_url": diff_url,
+                                    "commit_sha": commit_sha,
+                                    "author_name": author_name,
+                                    "diff": final_code,
+                                }
+                            )
                     else:
                         if total_changes > 1:
                             final_code = "".join(
@@ -134,7 +121,12 @@ class GitlabCommitsTool(BuiltinTool):
                             )
                             final_code_escaped = json.dumps(final_code)[1:-1]  # Escape the final code
                             results.append(
-                                {"commit_sha": commit_sha, "author_name": author_name, "diff": final_code_escaped}
+                                {
+                                    "diff_url": diff_url,
+                                    "commit_sha": commit_sha,
+                                    "author_name": author_name,
+                                    "diff": final_code_escaped,
+                                }
                             )
         except requests.RequestException as e:
             print(f"Error fetching data from GitLab: {e}")

+ 7 - 7
api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml

@@ -23,7 +23,7 @@ parameters:
     form: llm
   - name: repository
     type: string
-    required: false
+    required: true
     label:
       en_US: repository
       zh_Hans: 仓库路径
@@ -32,16 +32,16 @@ parameters:
       zh_Hans: 仓库路径,以namespace/project_name的形式。
     llm_description: Repository path for GitLab, like namespace/project_name.
     form: llm
-  - name: project
+  - name: branch
     type: string
     required: false
     label:
-      en_US: project
-      zh_Hans: 项目
+      en_US: branch
+      zh_Hans: 分支
     human_description:
-      en_US: project
-      zh_Hans: 项目
-    llm_description: project for GitLab
+      en_US: branch
+      zh_Hans: 分支
+    llm_description: branch for GitLab
     form: llm
   - name: start_time
     type: string

+ 78 - 0
api/core/tools/provider/builtin/gitlab/tools/gitlab_mergerequests.py

@@ -0,0 +1,78 @@
+import urllib.parse
+from typing import Any, Union
+
+import requests
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.tool.builtin_tool import BuiltinTool
+
+
+class GitlabMergeRequestsTool(BuiltinTool):
+    def _invoke(
+        self, user_id: str, tool_parameters: dict[str, Any]
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        repository = tool_parameters.get("repository", "")
+        branch = tool_parameters.get("branch", "")
+        start_time = tool_parameters.get("start_time", "")
+        end_time = tool_parameters.get("end_time", "")
+        state = tool_parameters.get("state", "opened")  # Default to "opened"
+
+        if not repository:
+            return self.create_text_message("Repository is required")
+
+        access_token = self.runtime.credentials.get("access_tokens")
+        site_url = self.runtime.credentials.get("site_url")
+
+        if not access_token:
+            return self.create_text_message("Gitlab API Access Tokens is required.")
+        if not site_url:
+            site_url = "https://gitlab.com"
+
+        # Get merge requests
+        result = self.get_merge_requests(site_url, access_token, repository, branch, start_time, end_time, state)
+
+        return [self.create_json_message(item) for item in result]
+
+    def get_merge_requests(
+        self, site_url: str, access_token: str, repository: str, branch: str, start_time: str, end_time: str, state: str
+    ) -> list[dict[str, Any]]:
+        domain = site_url
+        headers = {"PRIVATE-TOKEN": access_token}
+        results = []
+
+        try:
+            # URL encode the repository path
+            encoded_repository = urllib.parse.quote(repository, safe="")
+            merge_requests_url = f"{domain}/api/v4/projects/{encoded_repository}/merge_requests"
+            params = {"state": state}
+
+            # Add time filters if provided
+            if start_time:
+                params["created_after"] = start_time
+            if end_time:
+                params["created_before"] = end_time
+
+            response = requests.get(merge_requests_url, headers=headers, params=params)
+            response.raise_for_status()
+            merge_requests = response.json()
+
+            for mr in merge_requests:
+                # Filter by target branch
+                if branch and mr["target_branch"] != branch:
+                    continue
+
+                results.append(
+                    {
+                        "id": mr["id"],
+                        "title": mr["title"],
+                        "author": mr["author"]["name"],
+                        "web_url": mr["web_url"],
+                        "target_branch": mr["target_branch"],
+                        "created_at": mr["created_at"],
+                        "state": mr["state"],
+                    }
+                )
+        except requests.RequestException as e:
+            print(f"Error fetching merge requests from GitLab: {e}")
+
+        return results

+ 77 - 0
api/core/tools/provider/builtin/gitlab/tools/gitlab_mergerequests.yaml

@@ -0,0 +1,77 @@
+identity:
+  name: gitlab_mergerequests
+  author: Leo.Wang
+  label:
+    en_US: GitLab Merge Requests
+    zh_Hans: GitLab 合并请求查询
+description:
+  human:
+    en_US: A tool for query GitLab merge requests, Input should be a exists reposity or branch.
+    zh_Hans: 一个用于查询 GitLab 代码合并请求的工具,输入的内容应该是一个已存在的仓库名或者分支。
+  llm: A tool for query GitLab merge requests, Input should be a exists reposity or branch.
+parameters:
+  - name: repository
+    type: string
+    required: false
+    label:
+      en_US: repository
+      zh_Hans: 仓库路径
+    human_description:
+      en_US: repository
+      zh_Hans: 仓库路径,以namespace/project_name的形式。
+    llm_description: Repository path for GitLab, like namespace/project_name.
+    form: llm
+  - name: branch
+    type: string
+    required: false
+    label:
+      en_US: branch
+      zh_Hans: 分支名
+    human_description:
+      en_US: branch
+      zh_Hans: 分支名
+    llm_description: branch for GitLab
+    form: llm
+  - name: start_time
+    type: string
+    required: false
+    label:
+      en_US: start_time
+      zh_Hans: 开始时间
+    human_description:
+      en_US: start_time
+      zh_Hans: 开始时间
+    llm_description: Start time for GitLab
+    form: llm
+  - name: end_time
+    type: string
+    required: false
+    label:
+      en_US: end_time
+      zh_Hans: 结束时间
+    human_description:
+      en_US: end_time
+      zh_Hans: 结束时间
+    llm_description: End time for GitLab
+    form: llm
+  - name: state
+    type: select
+    required: false
+    options:
+      - value: opened
+        label:
+          en_US: opened
+          zh_Hans: 打开
+      - value: closed
+        label:
+          en_US: closed
+          zh_Hans: 关闭
+    default: opened
+    label:
+      en_US: state
+      zh_Hans: 变更状态
+    human_description:
+      en_US: state
+      zh_Hans: 变更状态
+    llm_description: Merge request state type for GitLab
+    form: llm

+ 81 - 0
api/core/tools/provider/builtin/gitlab/tools/gitlab_projects.py

@@ -0,0 +1,81 @@
+import urllib.parse
+from typing import Any, Union
+
+import requests
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.tool.builtin_tool import BuiltinTool
+
+
+class GitlabProjectsTool(BuiltinTool):
+    def _invoke(
+        self, user_id: str, tool_parameters: dict[str, Any]
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        project_name = tool_parameters.get("project_name", "")
+        page = tool_parameters.get("page", 1)
+        page_size = tool_parameters.get("page_size", 20)
+
+        access_token = self.runtime.credentials.get("access_tokens")
+        site_url = self.runtime.credentials.get("site_url")
+
+        if not access_token:
+            return self.create_text_message("Gitlab API Access Tokens is required.")
+        if not site_url:
+            site_url = "https://gitlab.com"
+
+        # Get project content
+        result = self.fetch_projects(site_url, access_token, project_name, page, page_size)
+
+        return [self.create_json_message(item) for item in result]
+
+    def fetch_projects(
+        self,
+        site_url: str,
+        access_token: str,
+        project_name: str,
+        page: str,
+        page_size: str,
+    ) -> list[dict[str, Any]]:
+        domain = site_url
+        headers = {"PRIVATE-TOKEN": access_token}
+        results = []
+
+        try:
+            if project_name:
+                # URL encode the project name for the search query
+                encoded_project_name = urllib.parse.quote(project_name, safe="")
+                projects_url = (
+                    f"{domain}/api/v4/projects?search={encoded_project_name}&page={page}&per_page={page_size}"
+                )
+            else:
+                projects_url = f"{domain}/api/v4/projects?page={page}&per_page={page_size}"
+
+            response = requests.get(projects_url, headers=headers)
+            response.raise_for_status()
+            projects = response.json()
+
+            for project in projects:
+                # Filter projects by exact name match if necessary
+                if project_name and project["name"].lower() == project_name.lower():
+                    results.append(
+                        {
+                            "id": project["id"],
+                            "name": project["name"],
+                            "description": project.get("description", ""),
+                            "web_url": project["web_url"],
+                        }
+                    )
+                elif not project_name:
+                    # If no specific project name is provided, add all projects
+                    results.append(
+                        {
+                            "id": project["id"],
+                            "name": project["name"],
+                            "description": project.get("description", ""),
+                            "web_url": project["web_url"],
+                        }
+                    )
+        except requests.RequestException as e:
+            print(f"Error fetching data from GitLab: {e}")
+
+        return results

+ 45 - 0
api/core/tools/provider/builtin/gitlab/tools/gitlab_projects.yaml

@@ -0,0 +1,45 @@
+identity:
+  name: gitlab_projects
+  author: Leo.Wang
+  label:
+    en_US: GitLab Projects
+    zh_Hans: GitLab 项目列表查询
+description:
+  human:
+    en_US: A tool for query GitLab projects, Input should be a project name.
+    zh_Hans: 一个用于查询 GitLab 项目列表的工具,输入的内容应该是一个项目名称。
+  llm: A tool for query GitLab projects, Input should be a project name.
+parameters:
+  - name: project_name
+    type: string
+    required: false
+    label:
+      en_US: project_name
+      zh_Hans: 项目名称
+    human_description:
+      en_US: project_name
+      zh_Hans: 项目名称
+    llm_description: Project name for GitLab
+    form: llm
+  - name: page
+    type: string
+    required: false
+    label:
+      en_US: page
+      zh_Hans: 页码
+    human_description:
+      en_US: page
+      zh_Hans: 页码
+    llm_description: Page index for GitLab
+    form: llm
+  - name: page_size
+    type: string
+    required: false
+    label:
+      en_US: page_size
+      zh_Hans: 每页数量
+    human_description:
+      en_US: page_size
+      zh_Hans: 每页数量
+    llm_description: Page size for GitLab
+    form: llm