Commit Diff


commit - f51d542ff43af954555c33e192f9a496f4fe11d6
commit + c6b2b577a7ff49c11c5d1e46162c0acba48305d5
blob - 120f91c16bcfbaf4a6231a9d90cc33c1d6dc044d
blob + 36fb900810aa976c924adaaa9923f6a2f091acfb
--- swh/loader/git/dumb.py
+++ swh/loader/git/dumb.py
@@ -6,11 +6,12 @@
 from __future__ import annotations
 
 from collections import defaultdict
+import copy
 import logging
 import stat
 import struct
 from tempfile import SpooledTemporaryFile
-from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Set, cast
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Set, cast
 import urllib.parse
 
 from dulwich.errors import NotGitRepository
@@ -28,17 +29,26 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-HEADERS = {"User-Agent": "Software Heritage dumb Git loader"}
+def requests_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
+    """Deep-copy kwargs, set the loader User-Agent header, default the timeout"""
+    ret = copy.deepcopy(kwargs)  # work on a copy so the caller's dict is untouched
+    ret.setdefault("headers", {}).update(
+        {"User-Agent": "Software Heritage dumb Git loader"}
+    )  # update() overwrites any User-Agent already present in the headers
+    ret.setdefault("timeout", (120, 60))  # only applied when no timeout was given
+    return ret
 
 
 @http_retry(
     before_sleep=before_sleep_log(logger, logging.WARNING),
 )
-def check_protocol(repo_url: str) -> bool:
+def check_protocol(repo_url: str, requests_extra_kwargs: Dict[str, Any] = {}) -> bool:
     """Checks if a git repository can be cloned using the dumb protocol.
 
     Args:
         repo_url: Base URL of a git repository
+        requests_extra_kwargs: extra keyword arguments to be passed to requests,
+          e.g. `timeout`, `verify`.
 
     Returns:
         Whether the dumb protocol is supported.
@@ -50,7 +60,7 @@ def check_protocol(repo_url: str) -> bool:
         repo_url.rstrip("/") + "/", "info/refs?service=git-upload-pack/"
     )
     logger.debug("Fetching %s", url)
-    response = requests.get(url, headers=HEADERS)
+    response = requests.get(url, **requests_kwargs(requests_extra_kwargs))
     response.raise_for_status()
     content_type = response.headers.get("Content-Type")
     return (
@@ -73,10 +83,18 @@ class GitObjectsFetcher:
     Args:
         repo_url: Base URL of a git repository
         base_repo: State of repository archived by Software Heritage
+        requests_extra_kwargs: extra keyword arguments to be passed to requests,
+          e.g. `timeout`, `verify`.
     """
 
-    def __init__(self, repo_url: str, base_repo: RepoRepresentation):
+    def __init__(
+        self,
+        repo_url: str,
+        base_repo: RepoRepresentation,
+        requests_extra_kwargs: Dict[str, Any] = {},
+    ):
         self._session = requests.Session()
+        self.requests_extra_kwargs = requests_extra_kwargs
         self.repo_url = repo_url
         self.base_repo = base_repo
         self.objects: Dict[bytes, Set[bytes]] = defaultdict(set)
@@ -130,7 +148,7 @@ class GitObjectsFetcher:
     def _http_get(self, path: str) -> SpooledTemporaryFile:
         url = urllib.parse.urljoin(self.repo_url.rstrip("/") + "/", path)
         logger.debug("Fetching %s", url)
-        response = self._session.get(url, headers=HEADERS)
+        response = self._session.get(url, **requests_kwargs(self.requests_extra_kwargs))
         response.raise_for_status()
         buffer = SpooledTemporaryFile(max_size=100 * 1024 * 1024)
         for chunk in response.iter_content(chunk_size=10 * 1024 * 1024):
blob - 2122fb173037e5cc07309f99a841c36c17bd849e
blob + d03f399da60b5916b96eaa891f278e65d989c977
--- swh/loader/git/loader.py
+++ swh/loader/git/loader.py
@@ -180,6 +180,7 @@ class GitLoader(BaseGitLoader):
         pack_size_bytes: int = 4 * 1024 * 1024 * 1024,
         temp_file_cutoff: int = 100 * 1024 * 1024,
         urllib3_extra_kwargs: Dict[str, Any] = {},
+        requests_extra_kwargs: Dict[str, Any] = {},
         **kwargs: Any,
     ):
         """Initialize the bulk updater.
@@ -206,6 +207,7 @@ class GitLoader(BaseGitLoader):
         self.ext_refs: Dict[bytes, Optional[Tuple[int, bytes]]] = {}
         self.repo_pack_size_bytes = 0
         self.urllib3_extra_kwargs = urllib3_extra_kwargs
+        self.requests_extra_kwargs = requests_extra_kwargs
 
     def fetch_pack_from_origin(
         self,
@@ -371,7 +373,7 @@ class GitLoader(BaseGitLoader):
             # by the fetch_pack operation when encountering a repository with
             # dumb transfer protocol so we check if the repository supports it
             # here to continue the loading if it is the case
-            self.dumb = dumb.check_protocol(self.origin.url)
+            self.dumb = dumb.check_protocol(self.origin.url, self.requests_extra_kwargs)
             if not self.dumb:
                 raise
 
@@ -379,7 +381,11 @@ class GitLoader(BaseGitLoader):
             "Protocol used for communication: %s", "dumb" if self.dumb else "smart"
         )
         if self.dumb:
-            self.dumb_fetcher = dumb.GitObjectsFetcher(self.origin.url, base_repo)
+            self.dumb_fetcher = dumb.GitObjectsFetcher(
+                self.origin.url,
+                base_repo,
+                requests_extra_kwargs=self.requests_extra_kwargs,
+            )
             self.dumb_fetcher.fetch_object_ids()
             self.remote_refs = utils.filter_refs(self.dumb_fetcher.refs)
             self.symbolic_refs = utils.filter_refs(self.dumb_fetcher.head)