Commit Diff


commit - 99a377e60c1e81c9453e463022c3b207d4b66116
commit + fc15ca82fd66c12f4956b971eab943f15f861a83
blob - e71c3e2b132e7b47fd7fe7043c02fad7492d4448
blob + 438c60c4bff448b4949c1c989f0e8f1f777289d1
--- NEWS
+++ NEWS
@@ -2,6 +2,8 @@
 
  * Fix handling of symrefs with protocol v2.
    (Jelmer Vernooij, #1389)
+
+ * Add ``ObjectStore.iter_prefix``.  (Jelmer Vernooij)
 
 0.22.3	2024-10-15
 
blob - a28e141764774f3ac0f23d144c49588a87891fdb
blob + 0e7886f680e54e6e0e0534c1062ee3ec2198671e
--- dulwich/object_store.py
+++ dulwich/object_store.py
@@ -22,6 +22,7 @@
 
 """Git object store interfaces and implementation."""
 
+import binascii
 import os
 import stat
 import sys
@@ -358,7 +359,18 @@ class BaseObjectStore:
         """Close any files opened by this object store."""
         # Default implementation is a NO-OP
 
+    def iter_prefix(self, prefix: bytes) -> Iterator[ObjectID]:
+        """Yield every object ID in this store that starts with ``prefix``.
 
+        This fallback filters a full iteration over the store; subclasses
+        are free to override it with something cheaper.
+        """
+        for object_id in iter(self):
+            if not object_id.startswith(prefix):
+                continue
+            yield object_id
+
+
 class PackBasedObjectStore(BaseObjectStore):
     def __init__(self, pack_compression_level=-1) -> None:
         self._pack_cache: Dict[str, Pack] = {}
@@ -1026,6 +1038,59 @@ class DiskObjectStore(PackBasedObjectStore):
         os.mkdir(os.path.join(path, "info"))
         os.mkdir(os.path.join(path, PACKDIR))
         return cls(path)
+
+    def iter_prefix(self, prefix):
+        """Iterate over all SHA1s in this store that start with ``prefix``.
+
+        Loose objects, packs and alternates are all consulted; each SHA1 is
+        yielded at most once.
+
+        Args:
+          prefix: Hex SHA1 prefix, as bytes
+        Returns: Iterator over matching hex SHA1s (bytes)
+        Raises:
+          binascii.Error: if the even-length head of a prefix of two or more
+            characters is not valid hex
+        """
+        if len(prefix) < 2:
+            # Too short to select a loose object directory: use the naive
+            # scan.  This has to be ``yield from``; since this method is a
+            # generator, ``return <iterator>`` would silently yield nothing.
+            yield from super().iter_prefix(prefix)
+            return
+        # Pack indexes store binary SHA1s, so only whole bytes of the prefix
+        # can be used there (an odd trailing hex digit is re-checked below).
+        # Converting up front also rejects non-hex input before any of it is
+        # used to build a filesystem path.
+        bin_prefix = binascii.unhexlify(prefix[: len(prefix) - len(prefix) % 2])
+        seen = set()
+        subdir = prefix[:2].decode()
+        rest = prefix[2:].decode()
+        try:
+            names = os.listdir(os.path.join(self.path, subdir))
+        except FileNotFoundError:
+            # No fan-out directory means no loose objects with this prefix.
+            names = []
+        for name in names:
+            # Loose SHA1 objects are named by the remaining 38 hex digits;
+            # anything else (e.g. a temporary file) is not an object.
+            if len(name) == 38 and name.startswith(rest):
+                sha = os.fsencode(subdir + name)
+                if sha not in seen:
+                    seen.add(sha)
+                    yield sha
+
+        for p in self.packs:
+            for bin_sha in p.index.iter_prefix(bin_prefix):
+                sha = sha_to_hex(bin_sha)
+                if sha.startswith(prefix) and sha not in seen:
+                    seen.add(sha)
+                    yield sha
+        for alternate in self.alternates:
+            for sha in alternate.iter_prefix(prefix):
+                if sha not in seen:
+                    seen.add(sha)
+                    yield sha
 
 
 class MemoryObjectStore(BaseObjectStore):
blob - 988f7a9dfe3f6f60c119ef13436627f88deef18d
blob + e815f4f6285efa1425603d0c01b3d64f6f80826c
--- dulwich/pack.py
+++ dulwich/pack.py
@@ -745,6 +745,39 @@ class FilePackIndex(PackIndex):
         if i is None:
             raise KeyError(sha)
         return self._unpack_offset(i)
+
+    def iter_prefix(self, prefix: bytes) -> Iterator[bytes]:
+        """Iterate over all SHA1s in this index that start with ``prefix``.
+
+        Args:
+          prefix: Binary (not hex) SHA1 prefix; an empty prefix matches
+            every entry
+        Returns: Iterator over matching binary SHA1s, in index order
+        """
+        if not prefix:
+            for i in range(len(self)):
+                yield self._unpack_name(i)
+            return
+        # The fan-out table records, for each first byte b, how many entries
+        # have a first byte <= b, so all candidates lie in the half-open
+        # range [fan_out[b - 1], fan_out[b]).
+        first = prefix[0]
+        lo = self._fan_out_table[first - 1] if first > 0 else 0
+        end = self._fan_out_table[first]
+        # Entries are sorted, so binary-search for the first name >= prefix;
+        # the matches then form one contiguous run starting there.
+        hi = end
+        while lo < hi:
+            mid = (lo + hi) // 2
+            if self._unpack_name(mid) < prefix:
+                lo = mid + 1
+            else:
+                hi = mid
+        for i in range(lo, end):
+            name = self._unpack_name(i)
+            if not name.startswith(prefix):
+                break
+            yield name
 
 
 class PackIndex1(FilePackIndex):
blob - 313be5d38b7bc35de706d798614e55a88d7e96f6
blob + 47f4fcfc2f0fba13b71a0d6caaa2998ead2784fb
--- dulwich/tests/test_object_store.py
+++ dulwich/tests/test_object_store.py
@@ -236,7 +236,22 @@ class ObjectStoreTests:
         self.store.add_object(testobject)
         self.store.close()
 
+    def test_iter_prefix(self):
+        self.store.add_object(testobject)
+        # Cover the short-prefix fallback, even and odd lengths and the
+        # full SHA.
+        for length in (0, 1, 2, 3, 10, 11, 40):
+            with self.subTest(length=length):
+                self.assertEqual(
+                    [testobject.id],
+                    list(self.store.iter_prefix(testobject.id[:length])),
+                )
+        # A prefix that matches nothing (for on-disk stores its loose-object
+        # fan-out directory does not exist either).
+        other = b"ff" if testobject.id.startswith(b"00") else b"00"
+        self.assertEqual([], list(self.store.iter_prefix(other)))
 
+
 class PackBasedObjectStoreTests(ObjectStoreTests):
     def tearDown(self):
         for pack in self.store.packs:
blob - 811e2e1c2d83f656a40945b5ec409ea43169c167
blob + 2bf0c89d42dbbc7a068f35beb5e31be8f8e059e9
--- tests/test_pack.py
+++ tests/test_pack.py
@@ -123,6 +123,21 @@ class PackIndexTests(PackTests):
         self.assertEqual(p.object_sha1(178), hex_to_sha(a_sha))
         self.assertEqual(p.object_sha1(138), hex_to_sha(tree_sha))
         self.assertEqual(p.object_sha1(12), hex_to_sha(commit_sha))
+
+    def test_iter_prefix(self):
+        p = self.get_pack_index(pack1_sha)
+        a_bin = hex_to_sha(a_sha)
+        for length in (20, 5, 2):
+            with self.subTest(length=length):
+                self.assertEqual(
+                    [p.object_sha1(178)], list(p.iter_prefix(a_bin[:length]))
+                )
+        # An empty prefix matches every entry.
+        self.assertEqual(len(p), len(list(p.iter_prefix(b""))))
+        # A SHA absent from the index that differs from a_sha only in its
+        # last byte.
+        missing = a_bin[:-1] + bytes([a_bin[-1] ^ 1])
+        self.assertEqual([], list(p.iter_prefix(missing)))
 
     def test_index_len(self):
         p = self.get_pack_index(pack1_sha)