commit c6b2b577a7ff49c11c5d1e46162c0acba48305d5 from: Nicolas Dandrimont date: Wed Jan 24 16:16:18 2024 UTC dumb loader: add support for extra requests kwargs This is useful to override the default settings of the requests Session, e.g. certificate verification or connect/read timeouts. commit - f51d542ff43af954555c33e192f9a496f4fe11d6 commit + c6b2b577a7ff49c11c5d1e46162c0acba48305d5 blob - 120f91c16bcfbaf4a6231a9d90cc33c1d6dc044d blob + 36fb900810aa976c924adaaa9923f6a2f091acfb --- swh/loader/git/dumb.py +++ swh/loader/git/dumb.py @@ -6,11 +6,12 @@ from __future__ import annotations from collections import defaultdict +import copy import logging import stat import struct from tempfile import SpooledTemporaryFile -from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Set, cast +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Set, cast import urllib.parse from dulwich.errors import NotGitRepository @@ -28,17 +29,26 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -HEADERS = {"User-Agent": "Software Heritage dumb Git loader"} +def requests_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]: + """Inject User-Agent header in the requests kwargs""" + ret = copy.deepcopy(kwargs) + ret.setdefault("headers", {}).update( + {"User-Agent": "Software Heritage dumb Git loader"} + ) + ret.setdefault("timeout", (120, 60)) + return ret @http_retry( before_sleep=before_sleep_log(logger, logging.WARNING), ) -def check_protocol(repo_url: str) -> bool: +def check_protocol(repo_url: str, requests_extra_kwargs: Dict[str, Any] = {}) -> bool: """Checks if a git repository can be cloned using the dumb protocol. Args: repo_url: Base URL of a git repository + requests_extra_kwargs: extra keyword arguments to be passed to requests, + e.g. `timeout`, `verify`. Returns: Whether the dumb protocol is supported. 
@@ -50,7 +60,7 @@ def check_protocol(repo_url: str) -> bool: repo_url.rstrip("/") + "/", "info/refs?service=git-upload-pack/" ) logger.debug("Fetching %s", url) - response = requests.get(url, headers=HEADERS) + response = requests.get(url, **requests_kwargs(requests_extra_kwargs)) response.raise_for_status() content_type = response.headers.get("Content-Type") return ( @@ -73,10 +83,18 @@ class GitObjectsFetcher: Args: repo_url: Base URL of a git repository base_repo: State of repository archived by Software Heritage + requests_extra_kwargs: extra keyword arguments to be passed to requests, + e.g. `timeout`, `verify`. """ - def __init__(self, repo_url: str, base_repo: RepoRepresentation): + def __init__( + self, + repo_url: str, + base_repo: RepoRepresentation, + requests_extra_kwargs: Dict[str, Any] = {}, + ): self._session = requests.Session() + self.requests_extra_kwargs = requests_extra_kwargs self.repo_url = repo_url self.base_repo = base_repo self.objects: Dict[bytes, Set[bytes]] = defaultdict(set) @@ -130,7 +148,7 @@ class GitObjectsFetcher: def _http_get(self, path: str) -> SpooledTemporaryFile: url = urllib.parse.urljoin(self.repo_url.rstrip("/") + "/", path) logger.debug("Fetching %s", url) - response = self._session.get(url, headers=HEADERS) + response = self._session.get(url, **requests_kwargs(self.requests_extra_kwargs)) response.raise_for_status() buffer = SpooledTemporaryFile(max_size=100 * 1024 * 1024) for chunk in response.iter_content(chunk_size=10 * 1024 * 1024): blob - 2122fb173037e5cc07309f99a841c36c17bd849e blob + d03f399da60b5916b96eaa891f278e65d989c977 --- swh/loader/git/loader.py +++ swh/loader/git/loader.py @@ -180,6 +180,7 @@ class GitLoader(BaseGitLoader): pack_size_bytes: int = 4 * 1024 * 1024 * 1024, temp_file_cutoff: int = 100 * 1024 * 1024, urllib3_extra_kwargs: Dict[str, Any] = {}, + requests_extra_kwargs: Dict[str, Any] = {}, **kwargs: Any, ): """Initialize the bulk updater. 
@@ -206,6 +207,7 @@ class GitLoader(BaseGitLoader): self.ext_refs: Dict[bytes, Optional[Tuple[int, bytes]]] = {} self.repo_pack_size_bytes = 0 self.urllib3_extra_kwargs = urllib3_extra_kwargs + self.requests_extra_kwargs = requests_extra_kwargs def fetch_pack_from_origin( self, @@ -371,7 +373,7 @@ class GitLoader(BaseGitLoader): # by the fetch_pack operation when encountering a repository with # dumb transfer protocol so we check if the repository supports it # here to continue the loading if it is the case - self.dumb = dumb.check_protocol(self.origin.url) + self.dumb = dumb.check_protocol(self.origin.url, self.requests_extra_kwargs) if not self.dumb: raise @@ -379,7 +381,11 @@ class GitLoader(BaseGitLoader): "Protocol used for communication: %s", "dumb" if self.dumb else "smart" ) if self.dumb: - self.dumb_fetcher = dumb.GitObjectsFetcher(self.origin.url, base_repo) + self.dumb_fetcher = dumb.GitObjectsFetcher( + self.origin.url, + base_repo, + requests_extra_kwargs=self.requests_extra_kwargs, + ) self.dumb_fetcher.fetch_object_ids() self.remote_refs = utils.filter_refs(self.dumb_fetcher.refs) self.symbolic_refs = utils.filter_refs(self.dumb_fetcher.head)