Commit Diff


commit - 98538380976f8964720f2eefc9018b50ecb3832a
commit + 52328a607b0de317a4d50f1763c91b054df3c1b1
blob - db0735180179361254559910548c5f38f57e7c4c
blob + e9266b05d801c6d50cc6a8b1149abd5749b773ce
--- dulwich/index.py
+++ dulwich/index.py
@@ -69,6 +69,30 @@ IndexEntry = collections.namedtuple(
         "extended_flags",
     ],
 )
+
+
+class ConflictedIndexEntry:
+    """Index entry that represents a conflict.
+
+    Holds the entries recorded for one path at the non-normal merge
+    stages; a side with no entry is ``None``.
+    """
+
+    ancestor: Optional[IndexEntry]
+    this: Optional[IndexEntry]
+    other: Optional[IndexEntry]
+
+    def __init__(self) -> None:
+        self.ancestor = None
+        self.this = None
+        self.other = None
+
+    def entries(self) -> Iterable[IndexEntry]:
+        """Yield the entries that are set, in ancestor, this, other order."""
+        for entry in (self.ancestor, self.this, self.other):
+            # Test against None explicitly rather than relying on truthiness
+            if entry is not None:
+                yield entry
 
 
 # 2-bit stage (during merge)
@@ -90,7 +111,6 @@ EXTENDED_FLAG_INTEND_TO_ADD = 0x2000
 
 DEFAULT_VERSION = 2
 
-
 class Stage(Enum):
     NORMAL = 0
     MERGE_CONFLICT_ANCESTOR = 1
@@ -98,9 +118,8 @@ class Stage(Enum):
     MERGE_CONFLICT_OTHER = 3
 
 
-class UnmergedEntriesInIndexEx(Exception):
-    def __init__(self, message):
-        super().__init__(message)
+class UnmergedEntries(Exception):
+    """Unmerged entries exist in the index"""
 
 
 def read_stage(entry: IndexEntry) -> Stage:
@@ -250,7 +269,7 @@ class UnsupportedIndexFormat(Exception):
         self.index_format_version = version
 
 
-def read_index(f: BinaryIO):
+def read_index(f: BinaryIO) -> Iterator[Tuple[bytes, IndexEntry]]:
     """Read an index file, yielding the individual entries."""
     header = f.read(4)
     if header != b"DIRC":
@@ -262,20 +281,6 @@ def read_index(f: BinaryIO):
         yield read_cache_entry(f, version)
 
 
-def read_index_dict(f) -> Dict[Tuple[bytes, Stage], IndexEntry]:
-    """Read an index file and return it as a dictionary.
-       Dict Key is tuple of path and stage number, as
-            path alone is not unique
-    Args:
-      f: File object to read fromls.
-    """
-    ret = {}
-    for name, entry in read_index(f):
-        stage = read_stage(entry)
-        ret[(name, stage)] = entry
-    return ret
-
-
 def write_index(f: BinaryIO, entries: List[Tuple[bytes, IndexEntry]], version: Optional[int] = None):
     """Write an index file.
 
@@ -294,7 +299,7 @@ def write_index(f: BinaryIO, entries: List[Tuple[bytes
 
 def write_index_dict(
     f: BinaryIO,
-    entries: Dict[Tuple[bytes, Stage], IndexEntry],
+    entries: Dict[bytes, Union[IndexEntry, ConflictedIndexEntry]],
     version: Optional[int] = None,
 ) -> None:
     """Write an index file based on the contents of a dictionary.
@@ -302,12 +307,16 @@ def write_index_dict(
     """
     entries_list = []
     for key in sorted(entries):
-        if isinstance(key, tuple):
-            name, stage = key
+        value = entries[key]
+        if isinstance(value, ConflictedIndexEntry):
+            if value.ancestor is not None:
+                entries_list.append((key, value.ancestor))
+            if value.this is not None:
+                entries_list.append((key, value.this))
+            if value.other is not None:
+                entries_list.append((key, value.other))
         else:
-            name = key
-            stage = Stage.NORMAL
-        entries_list.append((name, entries[(name, stage)]))
+            entries_list.append((key, value))
     write_index(f, entries_list, version=version)
 
 
@@ -337,7 +346,7 @@ def cleanup_mode(mode: int) -> int:
 class Index:
     """A Git Index file."""
 
-    _bynamestage: Dict[Tuple[bytes, Stage], IndexEntry]
+    _byname: Dict[bytes, Union[IndexEntry, ConflictedIndexEntry]]
 
     def __init__(self, filename: Union[bytes, str], read=True) -> None:
         """Create an index object associated with the given filename.
@@ -365,7 +374,7 @@ class Index:
         f = GitFile(self._filename, "wb")
         try:
             f = SHA1Writer(f)
-            write_index_dict(f, self._bynamestage, version=self._version)
+            write_index_dict(f, self._byname, version=self._version)
         finally:
             f.close()
 
@@ -378,7 +387,19 @@ class Index:
             f = SHA1Reader(f)
             for name, entry in read_index(f):
                 stage = read_stage(entry)
-                self[(name, stage)] = entry
+                if stage == Stage.NORMAL:
+                    self[name] = entry
+                else:
+                    # Non-normal stages of one path share a ConflictedIndexEntry
+                    existing = self._byname.setdefault(name, ConflictedIndexEntry())
+                    if isinstance(existing, IndexEntry):
+                        raise AssertionError("Non-conflicted entry for %r exists" % name)
+                    if stage == Stage.MERGE_CONFLICT_ANCESTOR:
+                        existing.ancestor = entry
+                    elif stage == Stage.MERGE_CONFLICT_THIS:
+                        existing.this = entry
+                    elif stage == Stage.MERGE_CONFLICT_OTHER:
+                        existing.other = entry
             # FIXME: Additional data?
             f.read(os.path.getsize(self._filename) - f.tell() - 20)
             f.check_sha()
@@ -387,60 +408,50 @@ class Index:
 
     def __len__(self) -> int:
         """Number of entries in this index file."""
-        return len(self._bynamestage)
+        return len(self._byname)
 
-    def __getitem__(self, key: Union[Tuple[bytes, Stage], bytes]) -> IndexEntry:
+    def __getitem__(self, key: bytes) -> Union[IndexEntry, ConflictedIndexEntry]:
         """Retrieve entry by relative path and stage.
 
         Returns: tuple with (ctime, mtime, dev, ino, mode, uid, gid, size, sha,
             flags)
         """
-        if isinstance(key, tuple):
-            return self._bynamestage[key]
-        if (key, Stage.NORMAL) in self._bynamestage:
-            return self._bynamestage[(key, Stage.NORMAL)]
-        # there is a conflict return 'this' entry
-        return self._bynamestage[(key, Stage.MERGE_CONFLICT_THIS)]
+        return self._byname[key]
 
     def __iter__(self) -> Iterator[bytes]:
         """Iterate over the paths and stages in this index."""
-        for (name, stage) in self._bynamestage:
-            if stage == Stage.MERGE_CONFLICT_ANCESTOR or stage == Stage.MERGE_CONFLICT_OTHER:
-                continue
-            yield name
+        return iter(self._byname)
 
     def __contains__(self, key):
-        if isinstance(key, tuple):
-            return key in self._bynamestage
-        if (key, Stage.NORMAL) in self._bynamestage:
-            return True
-        if (key, Stage.MERGE_CONFLICT_THIS) in self._bynamestage:
-            return True
-        return False
+        return key in self._byname
 
-    def get_sha1(self, path: bytes, stage: Stage = Stage.NORMAL) -> bytes:
+    def get_sha1(self, path: bytes) -> bytes:
         """Return the (git object) SHA1 for the object at a path."""
-        return self[(path, stage)].sha
+        value = self[path]
+        if isinstance(value, ConflictedIndexEntry):
+            # A conflicted path has no single SHA to report
+            raise UnmergedEntries(path)
+        return value.sha
 
-    def get_mode(self, path: bytes, stage: Stage = Stage.NORMAL) -> int:
+    def get_mode(self, path: bytes) -> int:
         """Return the POSIX file mode for the object at a path."""
-        return self[(path, stage)].mode
+        value = self[path]
+        if isinstance(value, ConflictedIndexEntry):
+            # A conflicted path has no single mode to report
+            raise UnmergedEntries(path)
+        return value.mode
 
     def iterobjects(self) -> Iterable[Tuple[bytes, bytes, int]]:
         """Iterate over path, sha, mode tuples for use with commit_tree."""
         for path in self:
             entry = self[path]
+            if isinstance(entry, ConflictedIndexEntry):
+                raise UnmergedEntries()
             yield path, entry.sha, cleanup_mode(entry.mode)
 
-    def iterconflicts(self) -> Iterable[Tuple[int, bytes, Stage, bytes]]:
-        """Iterate over path, sha, mode tuples for use with commit_tree."""
-        for (name, stage), entry in self._bynamestage.items():
-            if stage != Stage.NORMAL:
-                yield cleanup_mode(entry.mode), entry.sha, stage, name
-
-    def has_conflicts(self):
-        for (name, stage) in self._bynamestage.keys():
-            if stage != Stage.NORMAL:
+    def has_conflicts(self) -> bool:
+        for value in self._byname.values():
+            if isinstance(value, ConflictedIndexEntry):
                 return True
         return False
 
@@ -456,65 +459,36 @@ class Index:
                            sha,
                            stage << FLAG_STAGESHIFT,
                            0)
-        if (apath, Stage.NORMAL) in self._bynamestage:
-            del self._bynamestage[(apath, Stage.NORMAL)]
-        self._bynamestage[(apath, stage)] = entry
+        if (apath, Stage.NORMAL) in self._byname:
+            del self._byname[(apath, Stage.NORMAL)]
+        self._byname[(apath, stage)] = entry
 
     def clear(self):
         """Remove all contents from this index."""
-        self._bynamestage = {}
+        self._byname = {}
 
-    def __setitem__(self, key: Union[Tuple[bytes, Stage], bytes], x: IndexEntry) -> None:
-        assert len(x) == len(IndexEntry._fields)
-        if isinstance(key, tuple):
-            name, stage = key
-        else:
-            name = key
-            stage = Stage.NORMAL  # default when stage not explicitly specified
+    def __setitem__(self, name: bytes, value: Union[IndexEntry, ConflictedIndexEntry]) -> None:
         assert isinstance(name, bytes)
-        # Remove merge conflict entries if new entry is stage 0
-        # Remove normal stage entry if new entry has conflicts (stage > 0)
-        if stage == Stage.NORMAL:
-            if (name, Stage.MERGE_CONFLICT_ANCESTOR) in self._bynamestage:
-                del self._bynamestage[(name, Stage.MERGE_CONFLICT_ANCESTOR)]
-            if (name, Stage.MERGE_CONFLICT_THIS) in self._bynamestage:
-                del self._bynamestage[(name, Stage.MERGE_CONFLICT_THIS)]
-            if (name, Stage.MERGE_CONFLICT_OTHER) in self._bynamestage:
-                del self._bynamestage[(name, Stage.MERGE_CONFLICT_OTHER)]
-        if stage != Stage.NORMAL and (name, Stage.NORMAL) in self._bynamestage:
-            del self._bynamestage[(name, Stage.NORMAL)]
-        self._bynamestage[(name, stage)] = IndexEntry(*x)
+        if not isinstance(value, (IndexEntry, ConflictedIndexEntry)):
+            value = IndexEntry(*value)
+        self._byname[name] = value
 
-    def __delitem__(self, key: Union[Tuple[bytes, Stage], bytes]) -> None:
-        if isinstance(key, tuple):
-            del self._bynamestage[key]
-            return
-        name = key
-        assert isinstance(name, bytes)
-        if (name, Stage.NORMAL) in self._bynamestage:
-            del self._bynamestage[(name, Stage.NORMAL)]
-        if (name, Stage.MERGE_CONFLICT_ANCESTOR) in self._bynamestage:
-            del self._bynamestage[(name, Stage.MERGE_CONFLICT_ANCESTOR)]
-        if (name, Stage.MERGE_CONFLICT_THIS) in self._bynamestage:
-            del self._bynamestage[(name, Stage.MERGE_CONFLICT_THIS)]
-        if (name, Stage.MERGE_CONFLICT_OTHER) in self._bynamestage:
-            del self._bynamestage[(name, Stage.MERGE_CONFLICT_OTHER)]
+    def __delitem__(self, name: bytes) -> None:
+        del self._byname[name]
 
-    def iteritems(self) -> Iterator[Tuple[bytes, IndexEntry]]:
-        for (name, stage), entry in self._bynamestage.items():
-            yield name, entry
+    def iteritems(self) -> Iterator[Tuple[bytes, Union[IndexEntry, ConflictedIndexEntry]]]:
+        return iter(self._byname.items())
 
-    def items(self) -> Iterator[Tuple[Tuple[bytes, Stage], IndexEntry]]:
-        return iter(self._bynamestage.items())
+    def items(self) -> Iterator[Tuple[bytes, Union[IndexEntry, ConflictedIndexEntry]]]:
+        return iter(self._byname.items())
 
-    def update(self, entries: Dict[Tuple[bytes, Stage], IndexEntry]):
+    def update(self, entries: Dict[bytes, IndexEntry]):
         for key, value in entries.items():
             self[key] = value
 
     def paths(self):
-        for (name, stage) in self._bynamestage.keys():
-            if stage == Stage.NORMAL or stage == Stage.MERGE_CONFLICT_THIS:
-                yield name
+        for name in self._byname.keys():
+            yield name
 
     def changes_from_tree(
             self, object_store, tree: ObjectID, want_unchanged: bool = False):
@@ -548,9 +522,6 @@ class Index:
         Returns:
           Root tree SHA
         """
-        # as with git check for unmerged entries in the index and fail if found
-        if self.has_conflicts():
-            raise UnmergedEntriesInIndexEx('Unmerged entries exist in index these need to be handled first')
         return commit_tree(object_store, self.iterobjects())
 
 
@@ -858,7 +829,7 @@ def build_index_from_tree(
             st = st.__class__(st_tuple)
             # default to a stage 0 index entry (normal)
             # when reading from the filesystem
-        index[(entry.path, Stage.NORMAL)] = index_entry_from_stat(st, entry.sha, 0)
+        index[entry.path] = index_entry_from_stat(st, entry.sha, 0)
 
     index.write()
 
@@ -962,9 +933,11 @@ def get_unstaged_changes(
 
     for tree_path, entry in index.iteritems():
         full_path = _tree_to_fs_path(root_path, tree_path)
-        stage = read_stage(entry)
-        if stage == Stage.MERGE_CONFLICT_ANCESTOR or stage == Stage.MERGE_CONFLICT_OTHER:
+        if isinstance(entry, ConflictedIndexEntry):
+            # Conflicted files are always unstaged
+            yield tree_path
             continue
+
         try:
             st = os.lstat(full_path)
             if stat.S_ISDIR(st.st_mode):
@@ -1143,7 +1116,7 @@ class locked_index:
             return
         try:
             f = SHA1Writer(self._file)
-            write_index_dict(f, self._index._bynamestage)
+            write_index_dict(f, self._index._byname)
         except BaseException:
             self._file.abort()
         else:
blob - a174addc22fd3dd804f2ac9ce10f8bfee56aef37
blob + 411fb6485f956dec8b810720f0a1172b9bf80eed
--- dulwich/tests/test_index.py
+++ dulwich/tests/test_index.py
@@ -43,7 +43,6 @@ from ..index import (
     get_unstaged_changes,
     index_entry_from_stat,
     read_index,
-    read_index_dict,
     validate_path_element_default,
     validate_path_element_ntfs,
     write_cache_time,
@@ -156,41 +155,8 @@ class SimpleIndexWriterTestCase(IndexTestCase):
 
         with open(filename, "rb") as x:
             self.assertEqual(entries, list(read_index(x)))
-
-
-class ReadIndexDictTests(IndexTestCase):
-    def setUp(self):
-        IndexTestCase.setUp(self)
-        self.tempdir = tempfile.mkdtemp()
-
-    def tearDown(self):
-        IndexTestCase.tearDown(self)
-        shutil.rmtree(self.tempdir)
-
-    def test_simple_write(self):
-        entries = {
-            (b"barbla", Stage.NORMAL): IndexEntry(
-                (1230680220, 0),
-                (1230680220, 0),
-                2050,
-                3761020,
-                33188,
-                1000,
-                1000,
-                0,
-                b"e69de29bb2d1d6434b8b29ae775ad8c2e48c5391",
-                0,
-                0,
-            )
-        }
-        filename = os.path.join(self.tempdir, "test-simple-write-index")
-        with open(filename, "wb+") as x:
-            write_index_dict(x, entries)
 
-        with open(filename, "rb") as x:
-            self.assertEqual(entries, read_index_dict(x))
 
-
 class CommitTreeTests(TestCase):
     def setUp(self):
         super().setUp()