Commit Diff


commit - 879fe6a17b694df8d6f4f91cdbf808a59d6a105d
commit + 52f323db90c8dc59438f2226f6025994b67095b4
blob - ccb4d4aa0a97b832e3ae767bd51d1912a3e39916
blob + f447ef2b67315091e44029b7aabe911e4608becb
--- dulwich/index.py
+++ dulwich/index.py
@@ -72,6 +72,8 @@ IndexEntry = collections.namedtuple(
 
 # 2-bit stage (during merge)
 FLAG_STAGEMASK = 0x3000
+FLAG_STAGESHIFT = 12  # bit position of the stage field in FLAG_STAGEMASK
+FLAG_NAMEMASK = 0x0FFF  # low 12 bits of flags: length of the entry's name
 
 # assume-valid
 FLAG_VALID = 0x8000
@@ -79,15 +81,31 @@ FLAG_VALID = 0x8000
 # extended flag (must be zero in version 2)
 FLAG_EXTENDED = 0x4000
 
-
 # used by sparse checkout
 EXTENDED_FLAG_SKIP_WORKTREE = 0x4000
 
 # used by "git add -N"
 EXTENDED_FLAG_INTEND_TO_ADD = 0x2000
 
-
 DEFAULT_VERSION = 2
+
+
+class UnmergedEntriesInIndexEx(Exception):
+    """Raised when the index still contains unmerged (stage > 0) entries."""
+
+    def __init__(self, message):
+        super().__init__(message)
+
+
+def read_stage(entry: IndexEntry) -> int:
+    """Return the merge stage stored in an index entry's flags.
+
+    0 - normal (merged) entry
+    1 - merge conflict 'ancestor' entry
+    2 - merge conflict 'this' entry
+    3 - merge conflict 'other' entry
+    """
+    return (entry.flags & FLAG_STAGEMASK) >> FLAG_STAGESHIFT
 
 
 def pathsplit(path: bytes) -> Tuple[bytes, bytes]:
@@ -168,7 +183,7 @@ def read_cache_entry(f, version: int) -> Tuple[str, In
         (extended_flags, ) = struct.unpack(">H", f.read(2))
     else:
         extended_flags = 0
-    name = f.read(flags & 0x0FFF)
+    name = f.read(flags & FLAG_NAMEMASK)  # name length is in the low bits
     # Padding:
     if version < 4:
         real_size = (f.tell() - beginoffset + 8) & ~7
@@ -185,7 +200,7 @@ def read_cache_entry(f, version: int) -> Tuple[str, In
             gid,
             size,
             sha_to_hex(sha),
-            flags & ~0x0FFF,
+            flags & ~FLAG_NAMEMASK,  # keep flag bits, drop name length
             extended_flags,
         ))
 
@@ -200,7 +215,7 @@ def write_cache_entry(f, name: bytes, entry: IndexEntr
     beginoffset = f.tell()
     write_cache_time(f, entry.ctime)
     write_cache_time(f, entry.mtime)
-    flags = len(name) | (entry.flags & ~0x0FFF)
+    flags = len(name) | (entry.flags & ~FLAG_NAMEMASK)
     if entry.extended_flags:
         flags |= FLAG_EXTENDED
     if flags & FLAG_EXTENDED and version is not None and version < 3:
@@ -245,15 +260,19 @@ def read_index(f: BinaryIO):
         yield read_cache_entry(f, version)
 
 
-def read_index_dict(f) -> Dict[bytes, IndexEntry]:
+def read_index_dict(f) -> Dict[Tuple[bytes, int], IndexEntry]:
     """Read an index file and return it as a dictionary.
 
+    Keys are ``(path, stage)`` tuples rather than bare paths, because a
+    conflicted path can have an entry for each of several merge stages.
+
     Args:
       f: File object to read from
     """
     ret = {}
     for name, entry in read_index(f):
-        ret[name] = entry
+        stage = read_stage(entry)
+        ret[(name, stage)] = entry
     return ret
 
 
@@ -275,13 +292,25 @@ def write_index(f: BinaryIO, entries: List[Tuple[bytes
 
 def write_index_dict(
     f: BinaryIO,
-    entries: Dict[bytes, IndexEntry],
+    entries: Dict[Tuple[bytes, int], IndexEntry],
     version: Optional[int] = None,
 ) -> None:
-    """Write an index file based on the contents of a dictionary."""
+    """Write an index file based on the contents of a dictionary.
+
+    Keys are normally ``(path, stage)`` tuples; a bare ``bytes`` path is
+    also accepted and treated as stage 0. Entries are written ordered by
+    path and then by stage.
+    """
+
+    def name_and_stage(key):
+        # Normalize bare-path keys so they sort together with tuple keys.
+        return key if isinstance(key, tuple) else (key, 0)
+
     entries_list = []
-    for name in sorted(entries):
-        entries_list.append((name, entries[name]))
+    for key in sorted(entries, key=name_and_stage):
+        name, _stage = name_and_stage(key)
+        # Index with the caller's own key, which may be a bare path.
+        entries_list.append((name, entries[key]))
     write_index(f, entries_list, version=version)
 
 
@@ -337,7 +361,7 @@ class Index:
         f = GitFile(self._filename, "wb")
         try:
             f = SHA1Writer(f)
-            write_index_dict(f, self._byname, version=self._version)
+            write_index_dict(f, self._bynamestage, version=self._version)
         finally:
             f.close()
 
@@ -349,7 +373,8 @@ class Index:
         try:
             f = SHA1Reader(f)
             for name, entry in read_index(f):
-                self[name] = entry
+                stage = read_stage(entry)  # key entries by (path, stage)
+                self[(name, stage)] = entry
             # FIXME: Additional data?
             f.read(os.path.getsize(self._filename) - f.tell() - 20)
             f.check_sha()
@@ -358,27 +383,56 @@ class Index:
 
     def __len__(self) -> int:
         """Number of entries in this index file."""
-        return len(self._byname)
+        return len(self._bynamestage)
 
-    def __getitem__(self, name: bytes) -> IndexEntry:
-        """Retrieve entry by relative path.
+    def __getitem__(self, key: Union[Tuple[bytes, int], bytes]) -> IndexEntry:
+        """Retrieve entry by ``(path, stage)`` key or by relative path.
+
+        A bare path returns its stage 0 entry or, if the path is conflicted,
+        its stage 2 ('this') entry.
 
         Returns: tuple with (ctime, mtime, dev, ino, mode, uid, gid, size, sha,
             flags)
         """
-        return self._byname[name]
+        if isinstance(key, tuple):
+            return self._bynamestage[key]
+        if (key, 0) in self._bynamestage:
+            return self._bynamestage[(key, 0)]
+        # No stage 0 entry: fall back to the conflict's 'this' (stage 2)
+        # entry; raises KeyError if there is none.
+        return self._bynamestage[(key, 2)]
 
     def __iter__(self) -> Iterator[bytes]:
-        """Iterate over the paths in this index."""
-        return iter(self._byname)
+        """Iterate over the paths in this index.
+
+        Yields the path of every stage 0 entry and of every stage 2 ('this')
+        conflict entry; stage 1 and stage 3 entries are skipped.
+        """
+        for name, stage in self._bynamestage:
+            if stage in (1, 3):
+                continue
+            yield name
 
-    def get_sha1(self, path: bytes) -> bytes:
+    def __contains__(self, key) -> bool:
+        """Check for a ``(path, stage)`` key, or for a bare path.
+
+        A bare path is present if it has a stage 0 or a stage 2 entry, i.e.
+        exactly the bare paths that ``__getitem__`` can resolve.
+        """
+        if isinstance(key, tuple):
+            return key in self._bynamestage
+        return (key, 0) in self._bynamestage or (key, 2) in self._bynamestage
+
+    def get_sha1(self, path: bytes, stage: Optional[int] = None) -> bytes:
         """Return the (git object) SHA1 for the object at a path."""
-        return self[path].sha
+        # stage=None uses the bare-path lookup of self[path], so this agrees
+        # with ``path in self`` for conflicted paths.
+        return self[path if stage is None else (path, stage)].sha
 
-    def get_mode(self, path: bytes) -> int:
+    def get_mode(self, path: bytes, stage: Optional[int] = None) -> int:
         """Return the POSIX file mode for the object at a path."""
-        return self[path].mode
+        # stage=None uses the bare-path lookup of self[path], as in get_sha1.
+        return self[path if stage is None else (path, stage)].mode
 
     def iterobjects(self) -> Iterable[Tuple[bytes, bytes, int]]:
         """Iterate over path, sha, mode tuples for use with commit_tree."""
@@ -386,30 +428,111 @@ class Index:
             entry = self[path]
             yield path, entry.sha, cleanup_mode(entry.mode)
 
-    def clear(self):
-        """Remove all contents from this index."""
-        self._byname = {}
-
-    def __setitem__(self, name: bytes, x: IndexEntry) -> None:
-        assert isinstance(name, bytes)
-        assert len(x) == len(IndexEntry._fields)
-        # Remove the old entry if any
-        self._byname[name] = IndexEntry(*x)
-
-    def __delitem__(self, name: bytes) -> None:
-        assert isinstance(name, bytes)
-        del self._byname[name]
+    def iterconflicts(self) -> Iterable[Tuple[int, bytes, int, bytes]]:
+        """Iterate over the unmerged (stage > 0) entries in this index.
+
+        Yields (mode, sha, stage, path) tuples, with the mode passed
+        through cleanup_mode.
+        """
+        for (name, stage), entry in self._bynamestage.items():
+            if stage > 0:
+                yield cleanup_mode(entry.mode), entry.sha, stage, name
 
-    def iteritems(self) -> Iterator[Tuple[bytes, IndexEntry]]:
-        return self._byname.items()
+    def has_conflicts(self) -> bool:
+        """Return True if any entry is a merge conflict entry (stage > 0)."""
+        return any(stage > 0 for _, stage in self._bynamestage)
 
-    def items(self) -> Iterator[Tuple[bytes, IndexEntry]]:
-        return self._byname.items()
+    def set_merge_conflict(self, apath, stage, mode, sha, time):
+        """Record one side of a merge conflict for ``apath``.
+
+        Stores an entry with ``mode`` and ``sha`` under ``(apath, stage)``,
+        using ``time`` as both ctime and mtime and encoding ``stage`` in the
+        entry's flags. Any stage 0 entry for ``apath`` is removed.
+        """
+        entry = IndexEntry(time,
+                           time,
+                           0,
+                           0,
+                           mode,
+                           0,
+                           0,
+                           0,
+                           sha,
+                           stage << FLAG_STAGESHIFT,
+                           0)
+        # A conflicted path must not also keep its normal (stage 0) entry.
+        self._bynamestage.pop((apath, 0), None)
+        self._bynamestage[(apath, stage)] = entry
 
-    def update(self, entries: Dict[bytes, IndexEntry]):
-        for name, value in entries.items():
-            self[name] = value
+    def clear(self):
+        """Remove all contents from this index."""
+        self._bynamestage = {}
+
+    def __setitem__(
+            self, key: Union[Tuple[bytes, int], bytes], x: IndexEntry) -> None:
+        """Store an entry under a ``(path, stage)`` key or a bare path.
+
+        A bare path is stored as stage 0. A path is either merged (stage 0)
+        or conflicted (stages 1-3), so storing a stage 0 entry drops any
+        conflict entries for the path, and storing a conflict entry drops
+        its stage 0 entry.
+        """
+        assert len(x) == len(IndexEntry._fields)
+        if isinstance(key, tuple):
+            name, stage = key
+        else:
+            name, stage = key, 0
+        assert isinstance(name, bytes)
+        # NOTE(review): ``stage`` is not checked against the stage bits of
+        # x.flags, and those bits are what write_cache_entry serializes.
+        if stage == 0:
+            for conflict_stage in (1, 2, 3):
+                self._bynamestage.pop((name, conflict_stage), None)
+        elif stage > 0:
+            self._bynamestage.pop((name, 0), None)
+        self._bynamestage[(name, stage)] = IndexEntry(*x)
+
+    def __delitem__(self, key: Union[Tuple[bytes, int], bytes]) -> None:
+        """Delete one ``(path, stage)`` entry, or every stage of a bare path.
+
+        Raises:
+          KeyError: if there is no matching entry
+        """
+        if isinstance(key, tuple):
+            del self._bynamestage[key]
+            return
+        assert isinstance(key, bytes)
+        stage_keys = [(key, stage) for stage in (0, 1, 2, 3)
+                      if (key, stage) in self._bynamestage]
+        if not stage_keys:
+            # Keep dict semantics: deleting an unknown path is an error.
+            raise KeyError(key)
+        for stage_key in stage_keys:
+            del self._bynamestage[stage_key]
+
+    def iteritems(self) -> Iterator[Tuple[bytes, IndexEntry]]:
+        """Iterate over (path, entry) pairs for entries of every stage.
+
+        A conflicted path is yielded once for each stage it has.
+        """
+        for (name, _stage), entry in self._bynamestage.items():
+            yield name, entry
+
+    def items(self) -> Iterator[Tuple[Tuple[bytes, int], IndexEntry]]:
+        """Return the ``((path, stage), entry)`` pairs of this index."""
+        return self._bynamestage.items()
+
+    def update(self, entries: Dict[Tuple[bytes, int], IndexEntry]):
+        """Store each entry of ``entries`` via ``self[key] = value``."""
+        for key, value in entries.items():
+            self[key] = value
 
+    def paths(self):
+        """Iterate over the paths of stage 0 and stage 2 ('this') entries."""
+        for name, stage in self._bynamestage:
+            if stage in (0, 2):
+                yield name
+
     def changes_from_tree(
             self, object_store, tree: ObjectID, want_unchanged: bool = False):
         """Find the differences between the contents of this index and a tree.
@@ -427,7 +529,7 @@ class Index:
             return entry.sha, cleanup_mode(entry.mode)
 
         yield from changes_from_tree(
-            self._byname.keys(),
+            self.paths(),  # stage 0 and stage 2 ('this') paths only
             lookup_entry,
             object_store,
             tree,
@@ -442,6 +544,12 @@ class Index:
         Returns:
           Root tree SHA
         """
+        # Like git, refuse to build a tree while the index has unmerged
+        # entries; those conflicts must be resolved first.
+        if self.has_conflicts():
+            raise UnmergedEntriesInIndexEx(
+                'Unmerged entries exist in index '
+                'these need to be handled first')
         return commit_tree(object_store, self.iterobjects())
 
 
@@ -747,7 +853,9 @@ def build_index_from_tree(
                 st.st_ctime,
             )
             st = st.__class__(st_tuple)
-        index[entry.path] = index_entry_from_stat(st, entry.sha, 0)
+        # Entries built from the filesystem are normal (stage 0) entries;
+        # storing one drops any conflict entries recorded for the path.
+        index[(entry.path, 0)] = index_entry_from_stat(st, entry.sha, 0)
 
     index.write()
 
@@ -851,6 +959,9 @@ def get_unstaged_changes(
 
     for tree_path, entry in index.iteritems():
         full_path = _tree_to_fs_path(root_path, tree_path)
+        stage = read_stage(entry)
+        if stage == 1 or stage == 3:  # skip 'ancestor'/'other' entries
+            continue
         try:
             st = os.lstat(full_path)
             if stat.S_ISDIR(st.st_mode):
@@ -1006,7 +1117,8 @@ def refresh_index(index: Index, root_path: bytes):
     """
     for path, entry in iter_fresh_entries(index, root_path):
         if entry:
-            index[path] = entry
+            stage = read_stage(entry)
+            index[(path, stage)] = entry  # stage 0 drops conflict entries
 
 
 class locked_index:
@@ -1028,7 +1140,7 @@ class locked_index:
             return
         try:
             f = SHA1Writer(self._file)
-            write_index_dict(f, self._index._byname)
+            write_index_dict(f, self._index._bynamestage)  # (path, stage) keys
         except BaseException:
             self._file.abort()
         else:
blob - f0bd4134ab622913240567d6e8e72d16855ac822
blob + 0db1c8d98220b0626d043f88fc283d18e07e7dc4
--- dulwich/tests/test_index.py
+++ dulwich/tests/test_index.py
@@ -168,7 +168,7 @@ class ReadIndexDictTests(IndexTestCase):
 
     def test_simple_write(self):
         entries = {
-            b"barbla": IndexEntry(
+            (b"barbla", 0): IndexEntry(  # (path, stage) key
                 (1230680220, 0),
                 (1230680220, 0),
                 2050,