Commit Diff


commit - 15d6c817bc5e2628a9a8eb3d8c4326f1bd86eb24
commit + b1287d36edf7f54d38a0ed93021f2dc84f6db027
blob - 9911115f6f025c9e5ba5453912b0911dfad41b57
blob + e71c3e2b132e7b47fd7fe7043c02fad7492d4448
--- NEWS
+++ NEWS
@@ -1,3 +1,8 @@
+0.22.4	UNRELEASED
+
+ * Fix handling of symrefs with protocol v2.
+   (Jelmer Vernooij, #1389)
+
 0.22.3	2024-10-15
 
  * Improve wheel building in CI, so we can upload wheels for the next release.
blob - 58b3cadebb283d765494edf4f3523577de485b4a
blob + a7b9bdfc180d33bd96a94d605436532b8c5a4454
--- dulwich/client.py
+++ dulwich/client.py
@@ -258,7 +258,33 @@ def read_server_capabilities(pkt_seq):
     return set(server_capabilities)
 
 
-def read_pkt_refs(pkt_seq, server_capabilities=None):
+def read_pkt_refs_v2(
+    pkt_seq,
+) -> Tuple[Dict[bytes, bytes], Dict[bytes, bytes], Dict[bytes, bytes]]:
+    refs = {}
+    symrefs = {}
+    peeled = {}
+    # Receive refs from server
+    for pkt in pkt_seq:
+        parts = pkt.rstrip(b"\n").split(b" ")
+        sha = parts[0]
+        if sha == b"unborn":
+            sha = None
+        ref = parts[1]
+        for part in parts[2:]:
+            if part.startswith(b"peeled:"):
+                peeled[ref] = part[7:]
+            elif part.startswith(b"symref-target:"):
+                symrefs[ref] = part[14:]
+            else:
+                logging.warning("unknown part in pkt-ref: %s", part)
+        refs[ref] = sha
+
+    return refs, symrefs, peeled
+
+
+def read_pkt_refs_v1(pkt_seq) -> Tuple[Dict[bytes, bytes], Set[bytes]]:
+    server_capabilities = None
     refs = {}
     # Receive refs from server
     for pkt in pkt_seq:
@@ -267,24 +293,13 @@ def read_pkt_refs(pkt_seq, server_capabilities=None):
             raise GitProtocolError(ref.decode("utf-8", "replace"))
         if server_capabilities is None:
             (ref, server_capabilities) = extract_capabilities(ref)
-        else:  # Git protocol-v2:
-            try:
-                symref, target = ref.split(b" ", 1)
-            except ValueError:
-                pass
-            else:
-                if symref and target and target[:14] == b"symref-target:":
-                    server_capabilities.add(
-                        b"%s=%s:%s"
-                        % (CAPABILITY_SYMREF, symref, target.split(b":", 1)[1])
-                    )
-                    ref = symref
         refs[ref] = sha
 
     if len(refs) == 0:
         return {}, set()
     if refs == {CAPABILITIES_REF: ZERO_SHA}:
         refs = {}
+    assert server_capabilities is not None  # set from the first pkt when refs exist
     return refs, set(server_capabilities)
 
 
@@ -682,6 +697,26 @@ def _handle_upload_pack_tail(
             if data == b"":
                 break
             pack_data(data)
+
+
+def _extract_symrefs_and_agent(capabilities):
+    """Extract symrefs and agent from capabilities.
+
+    Args:
+     capabilities: Iterable of capabilities (bytes), e.g. a set
+    Returns:
+     (symrefs, agent) tuple, in that order; agent is None if not advertised
+    """
+    symrefs = {}
+    agent = None
+    for capability in capabilities:
+        k, v = parse_capability(capability)
+        if k == CAPABILITY_SYMREF:
+            (src, dst) = v.split(b":", 1)  # e.g. symref=HEAD:refs/heads/master
+            symrefs[src] = dst
+        if k == CAPABILITY_AGENT:
+            agent = v
+    return (symrefs, agent)
 
 
 # TODO(durin42): this doesn't correctly degrade if the server doesn't
@@ -1012,11 +1047,8 @@ class GitClient:
 
     def _negotiate_receive_pack_capabilities(self, server_capabilities):
         negotiated_capabilities = self._send_capabilities & server_capabilities
-        agent = None
-        for capability in server_capabilities:
-            k, v = parse_capability(capability)
-            if k == CAPABILITY_AGENT:
-                agent = v
+        # _extract_symrefs_and_agent returns (symrefs, agent), in that order.
+        (_symrefs, agent) = _extract_symrefs_and_agent(server_capabilities)
         (extract_capability_names(server_capabilities) - KNOWN_RECEIVE_CAPABILITIES)
         # TODO(jelmer): warn about unknown capabilities
         return negotiated_capabilities, agent
@@ -1069,23 +1100,16 @@ class GitClient:
     def _negotiate_upload_pack_capabilities(self, server_capabilities):
         (extract_capability_names(server_capabilities) - KNOWN_UPLOAD_CAPABILITIES)
         # TODO(jelmer): warn about unknown capabilities
-        symrefs = {}
-        agent = None
         fetch_capa = None
         for capability in server_capabilities:
             k, v = parse_capability(capability)
-            if k == CAPABILITY_SYMREF:
-                (src, dst) = v.split(b":", 1)
-                symrefs[src] = dst
-            if k == CAPABILITY_AGENT:
-                agent = v
             if self.protocol_version == 2 and k == CAPABILITY_FETCH:
                 fetch_capa = CAPABILITY_FETCH
                 fetch_features = []
-                v = v.strip()
-                if b"shallow" in v.split(b" "):
+                v = v.strip().split(b" ")  # v now holds the fetch feature names
+                if b"shallow" in v:
                     fetch_features.append(CAPABILITY_SHALLOW)
-                if b"filter" in v.split(b" "):
+                if b"filter" in v:
                     fetch_features.append(CAPABILITY_FILTER)
                 for i in range(len(fetch_features)):
                     if i == 0:
@@ -1094,6 +1118,8 @@ class GitClient:
                         fetch_capa += b" "
                     fetch_capa += fetch_features[i]
 
+        (symrefs, agent) = _extract_symrefs_and_agent(server_capabilities)
+
         negotiated_capabilities = self._fetch_capabilities & server_capabilities
         if fetch_capa:
             negotiated_capabilities.add(fetch_capa)
@@ -1196,7 +1222,8 @@ class TraditionalGitClient(GitClient):
         proto, unused_can_read, stderr = self._connect(b"receive-pack", path)
         with proto:
             try:
-                old_refs, server_capabilities = read_pkt_refs(proto.read_pkt_seq())
+                # Capabilities are taken from the first advertised ref line.
+                old_refs, server_capabilities = read_pkt_refs_v1(proto.read_pkt_seq())
             except HangupException as exc:
                 raise _remote_error_from_stderr(stderr) from exc
             (
@@ -1329,7 +1356,8 @@ class TraditionalGitClient(GitClient):
                     server_capabilities = read_server_capabilities(proto.read_pkt_seq())
                     refs = None
                 else:
-                    refs, server_capabilities = read_pkt_refs(proto.read_pkt_seq())
+                    # Capabilities are taken from the first advertised ref line.
+                    refs, server_capabilities = read_pkt_refs_v1(proto.read_pkt_seq())
             except HangupException as exc:
                 raise _remote_error_from_stderr(stderr) from exc
             (
@@ -1345,9 +1373,8 @@ class TraditionalGitClient(GitClient):
                 for prefix in ref_prefix:
                     proto.write_pkt_line(b"ref-prefix " + prefix)
                 proto.write_pkt_line(None)
-                refs, server_capabilities = read_pkt_refs(
-                    proto.read_pkt_seq(), server_capabilities
-                )
+                # Protocol v2 lists symrefs as symref-target: ref attributes.
+                refs, symrefs, _peeled = read_pkt_refs_v2(proto.read_pkt_seq())
 
             if refs is None:
                 proto.write_pkt_line(None)
@@ -1425,17 +1449,21 @@ class TraditionalGitClient(GitClient):
             proto.write(b"0001")  # delim-pkt
             proto.write_pkt_line(b"symrefs")
             proto.write_pkt_line(None)
+            with proto:
+                try:
+                    refs, _symrefs, _peeled = read_pkt_refs_v2(proto.read_pkt_seq())
+                except HangupException as exc:
+                    raise _remote_error_from_stderr(stderr) from exc
+                proto.write_pkt_line(None)
+                return refs
         else:
-            server_capabilities = None  # read_pkt_refs will find them
-        with proto:
-            try:
-                refs, server_capabilities = read_pkt_refs(
-                    proto.read_pkt_seq(), server_capabilities
-                )
-            except HangupException as exc:
-                raise _remote_error_from_stderr(stderr) from exc
-            proto.write_pkt_line(None)
-            return refs
+            with proto:
+                try:
+                    refs, server_capabilities = read_pkt_refs_v1(proto.read_pkt_seq())
+                except HangupException as exc:
+                    raise _remote_error_from_stderr(stderr) from exc
+                proto.write_pkt_line(None)
+                return refs
 
     def archive(
         self,
@@ -2384,6 +2413,8 @@ class AbstractHttpGitClient(GitClient):
                 self.protocol_version = server_protocol_version
                 if self.protocol_version == 2:
                     server_capabilities, resp, read, proto = begin_protocol_v2(proto)
+                    (refs, _symrefs, _peeled) = read_pkt_refs_v2(proto.read_pkt_seq())
+                    return refs, server_capabilities, base_url
                 else:
                     server_capabilities = None  # read_pkt_refs will find them
                     try:
@@ -2414,11 +2445,17 @@ class AbstractHttpGitClient(GitClient):
                         server_capabilities, resp, read, proto = begin_protocol_v2(
                             proto
                         )
-                (
-                    refs,
-                    server_capabilities,
-                ) = read_pkt_refs(proto.read_pkt_seq(), server_capabilities)
-                return refs, server_capabilities, base_url
+                        # Protocol v2 was negotiated here, so parse the ref listing
+                        # with the v2 parser, exactly as in the v2 branch above.
+                        (refs, _symrefs, _peeled) = read_pkt_refs_v2(
+                            proto.read_pkt_seq()
+                        )
+                        return refs, server_capabilities, base_url
+                    (
+                        refs,
+                        server_capabilities,
+                    ) = read_pkt_refs_v1(proto.read_pkt_seq())
+                    return refs, server_capabilities, base_url
             else:
                 self.protocol_version = 0  # dumb servers only support protocol v0
                 return read_info_refs(resp), set(), base_url
blob - 329ace1c3432a9f8130942a0b75f5f1f550d9f93
blob + 83c9c59c8772b4a2471a4a55c92a4bac73eb7ef8
--- tests/test_client.py
+++ tests/test_client.py
@@ -47,6 +47,7 @@ from dulwich.client import (
     SubprocessSSHVendor,
     TCPGitClient,
     TraditionalGitClient,
+    _extract_symrefs_and_agent,
     _remote_error_from_stderr,
     check_wants,
     default_urllib3_manager,
@@ -1866,4 +1867,18 @@ And this line is just random noise, too.
                     b"And this line is just random noise, too.",
                 ]
             ),
+        )
+
+
+class TestExtractSymrefsAndAgent(TestCase):
+    def test_extract_agent_and_symrefs(self):
+        (symrefs, agent) = _extract_symrefs_and_agent(
+            [b"agent=git/2.31.1", b"symref=HEAD:refs/heads/master"]
         )
+        self.assertEqual(agent, b"git/2.31.1")
+        self.assertEqual(symrefs, {b"HEAD": b"refs/heads/master"})
+
+    def test_extract_neither(self):
+        (symrefs, agent) = _extract_symrefs_and_agent([b"ofs-delta", b"side-band-64k"])
+        self.assertIsNone(agent)
+        self.assertEqual(symrefs, {})