Commit Diff


commit - 52f323db90c8dc59438f2226f6025994b67095b4
commit + 835a3fd9c5fd049ebf5d9bb70e8cbb01ebbe8ef2
blob - f447ef2b67315091e44029b7aabe911e4608becb
blob + db0735180179361254559910548c5f38f57e7c4c
--- dulwich/index.py
+++ dulwich/index.py
@@ -25,6 +25,7 @@ import os
 import stat
 import struct
 import sys
+from enum import Enum
 from typing import (
     Any,
     BinaryIO,
@@ -90,19 +91,20 @@ EXTENDED_FLAG_INTEND_TO_ADD = 0x2000
 DEFAULT_VERSION = 2


+class Stage(int, Enum):  # int-valued so stages sort, bit-shift and equal legacy ints
+    NORMAL = 0  # normal (unconflicted) entry
+    MERGE_CONFLICT_ANCESTOR = 1  # merge conflict 'ancestor' entry
+    MERGE_CONFLICT_THIS = 2  # merge conflict 'this' entry
+    MERGE_CONFLICT_OTHER = 3  # merge conflict 'other' entry
+
+
 class UnmergedEntriesInIndexEx(Exception):
     def __init__(self, message):
         super().__init__(message)


-def read_stage(entry: IndexEntry) -> int:
-    """Stage of an Entry
-       0 - normal
-       1 - merge conflict 'ancestor' entry
-       2 - merge conflict 'this' entry
-       3 - merge conflict 'other' entry
-     """
-    return (entry.flags & FLAG_STAGEMASK) >> FLAG_STAGESHIFT
+def read_stage(entry: IndexEntry) -> Stage:
+    return Stage((entry.flags & FLAG_STAGEMASK) >> FLAG_STAGESHIFT)  # 2-bit field: always 0-3


 def pathsplit(path: bytes) -> Tuple[bytes, bytes]:
@@ -155,7 +157,7 @@ def write_cache_time(f, t):
     f.write(struct.pack(">LL", *t))


-def read_cache_entry(f, version: int) -> Tuple[str, IndexEntry]:
+def read_cache_entry(f, version: int) -> Tuple[bytes, IndexEntry]:
     """Read an entry from a cache file.

     Args:
@@ -260,12 +262,12 @@ def read_index(f: BinaryIO):
         yield read_cache_entry(f, version)


-def read_index_dict(f) -> Dict[Tuple[bytes, int], IndexEntry]:
+def read_index_dict(f) -> Dict[Tuple[bytes, Stage], IndexEntry]:
     """Read an index file and return it as a dictionary.
        Dict Key is tuple of path and stage number, as
             path alone is not unique
     Args:
-      f: File object to read fromls
+      f: File object to read from.
     """
     ret = {}
     for name, entry in read_index(f):
@@ -292,11 +294,11 @@ def write_index(f: BinaryIO, entries: List[Tuple[bytes

 def write_index_dict(
     f: BinaryIO,
-    entries: Dict[Tuple[bytes, int], IndexEntry],
+    entries: Dict[Tuple[bytes, Stage], IndexEntry],
     version: Optional[int] = None,
 ) -> None:
     """Write an index file based on the contents of a dictionary.
-       being careful to sort by path and then by stage
+    Entries are written sorted by path and then by numeric stage value.
     """
     entries_list = []
-    for key in sorted(entries):
+    for key in sorted(entries, key=lambda k: (k[0], Stage(k[1]).value) if isinstance(k, tuple) else (k, 0)):
@@ -304,7 +306,7 @@ def write_index_dict(
             name, stage = key
         else:
             name = key
-            stage = 0
+            stage = Stage.NORMAL
-        entries_list.append((name, entries[(name, stage)]))
+        entries_list.append((name, entries[key]))  # key may be a bare path
     write_index(f, entries_list, version=version)

@@ -335,6 +337,8 @@ def cleanup_mode(mode: int) -> int:
 class Index:
     """A Git Index file."""

+    _bynamestage: Dict[Tuple[bytes, Stage], IndexEntry]  # keyed by (path, stage): path alone is not unique
+
     def __init__(self, filename: Union[bytes, str], read=True) -> None:
         """Create an index object associated with the given filename.

@@ -385,7 +389,7 @@ class Index:
         """Number of entries in this index file."""
         return len(self._bynamestage)

-    def __getitem__(self, key: Union[Tuple[bytes, int], bytes]) -> IndexEntry:
+    def __getitem__(self, key: Union[Tuple[bytes, Stage], bytes]) -> IndexEntry:
         """Retrieve entry by relative path and stage.

         Returns: tuple with (ctime, mtime, dev, ino, mode, uid, gid, size, sha,
@@ -393,32 +397,32 @@ class Index:
         """
         if isinstance(key, tuple):
             return self._bynamestage[key]
-        if (key, 0) in self._bynamestage:
-            return self._bynamestage[(key, 0)]
+        if (key, Stage.NORMAL) in self._bynamestage:
+            return self._bynamestage[(key, Stage.NORMAL)]
         # there is a conflict return 'this' entry
-        return self._bynamestage[(key, 2)]
+        return self._bynamestage[(key, Stage.MERGE_CONFLICT_THIS)]

     def __iter__(self) -> Iterator[bytes]:
         """Iterate over the paths and stages in this index."""
         for (name, stage) in self._bynamestage:
-            if stage == 1 or stage == 3:
+            if stage == Stage.MERGE_CONFLICT_ANCESTOR or stage == Stage.MERGE_CONFLICT_OTHER:
                 continue
             yield name

     def __contains__(self, key):
         if isinstance(key, tuple):
             return key in self._bynamestage
-        if (key, 0) in self._bynamestage:
+        if (key, Stage.NORMAL) in self._bynamestage:
             return True
-        if (key, 2) in self._bynamestage:
+        if (key, Stage.MERGE_CONFLICT_THIS) in self._bynamestage:
             return True
         return False
-    
-    def get_sha1(self, path: bytes, stage: int = 0) -> bytes:
+
+    def get_sha1(self, path: bytes, stage: Stage = Stage.NORMAL) -> bytes:
         """Return the (git object) SHA1 for the object at a path."""
         return self[(path, stage)].sha

-    def get_mode(self, path: bytes, stage: int = 0) -> int:
+    def get_mode(self, path: bytes, stage: Stage = Stage.NORMAL) -> int:
         """Return the POSIX file mode for the object at a path."""
         return self[(path, stage)].mode

@@ -428,15 +432,15 @@ class Index:
             entry = self[path]
             yield path, entry.sha, cleanup_mode(entry.mode)

-    def iterconflicts(self) -> Iterable[Tuple[int, bytes, int, bytes]]:
+    def iterconflicts(self) -> Iterable[Tuple[int, bytes, Stage, bytes]]:  # yields (mode, sha, stage, name)
         """Iterate over path, sha, mode tuples for use with commit_tree."""
         for (name, stage), entry in self._bynamestage.items():
-            if stage > 0:
+            if stage != Stage.NORMAL:  # any merge-conflict stage
                 yield cleanup_mode(entry.mode), entry.sha, stage, name

     def has_conflicts(self):
         for (name, stage) in self._bynamestage.keys():
-            if stage > 0:
+            if stage != Stage.NORMAL:  # any merge-conflict stage
                 return True
         return False

@@ -452,66 +456,66 @@ class Index:
                            sha,
-                           stage << FLAG_STAGESHIFT,
+                           Stage(stage).value << FLAG_STAGESHIFT,  # accept int or Stage
                            0)
-        if (apath, 0) in self._bynamestage:
-            del self._bynamestage[(apath, 0)]
+        if (apath, Stage.NORMAL) in self._bynamestage:
+            del self._bynamestage[(apath, Stage.NORMAL)]
-        self._bynamestage[(apath, stage)] = entry
+        self._bynamestage[(apath, Stage(stage))] = entry  # normalize so int callers match Stage keys

     def clear(self):
         """Remove all contents from this index."""
         self._bynamestage = {}

-    def __setitem__(self, key: Union[Tuple[bytes, int], bytes], x: IndexEntry) -> None:
+    def __setitem__(self, key: Union[Tuple[bytes, Stage], bytes], x: IndexEntry) -> None:
         assert len(x) == len(IndexEntry._fields)
         if isinstance(key, tuple):
             name, stage = key
         else:
             name = key
-            stage = 0  # default when stage not explicitly specified
+            stage = Stage.NORMAL  # default when stage not explicitly specified
         assert isinstance(name, bytes)
         # Remove merge conflict entries if new entry is stage 0
-        # Remove stage 0 entry if new entry has conflicts (stage > 0)
-        if stage == 0:
-            if (name, 1) in self._bynamestage:
-                del self._bynamestage[(name, 1)]
-            if (name, 2) in self._bynamestage:
-                del self._bynamestage[(name, 2)]
-            if (name, 3) in self._bynamestage:
-                del self._bynamestage[(name, 3)]
-        if stage > 0 and (name, 0) in self._bynamestage:
-            del self._bynamestage[(name, 0)]
+        # Remove normal stage entry if new entry has a conflict stage (1-3)
+        if stage == Stage.NORMAL:
+            if (name, Stage.MERGE_CONFLICT_ANCESTOR) in self._bynamestage:
+                del self._bynamestage[(name, Stage.MERGE_CONFLICT_ANCESTOR)]
+            if (name, Stage.MERGE_CONFLICT_THIS) in self._bynamestage:
+                del self._bynamestage[(name, Stage.MERGE_CONFLICT_THIS)]
+            if (name, Stage.MERGE_CONFLICT_OTHER) in self._bynamestage:
+                del self._bynamestage[(name, Stage.MERGE_CONFLICT_OTHER)]
+        if stage != Stage.NORMAL and (name, Stage.NORMAL) in self._bynamestage:
+            del self._bynamestage[(name, Stage.NORMAL)]
         self._bynamestage[(name, stage)] = IndexEntry(*x)

-    def __delitem__(self, key: Union[Tuple[bytes, int], bytes]) -> None:
+    def __delitem__(self, key: Union[Tuple[bytes, Stage], bytes]) -> None:
         if isinstance(key, tuple):
             del self._bynamestage[key]
             return
         name = key
         assert isinstance(name, bytes)
-        if (name, 0) in self._bynamestage:
-            del self._bynamestage[(name, 0)]
-        if (name, 1) in self._bynamestage:
-            del self._bynamestage[(name, 1)]
-        if (name, 2) in self._bynamestage:
-            del self._bynamestage[(name, 2)]
-        if (name, 3) in self._bynamestage:
-            del self._bynamestage[(name, 3)]
+        if (name, Stage.NORMAL) in self._bynamestage:
+            del self._bynamestage[(name, Stage.NORMAL)]
+        if (name, Stage.MERGE_CONFLICT_ANCESTOR) in self._bynamestage:
+            del self._bynamestage[(name, Stage.MERGE_CONFLICT_ANCESTOR)]
+        if (name, Stage.MERGE_CONFLICT_THIS) in self._bynamestage:
+            del self._bynamestage[(name, Stage.MERGE_CONFLICT_THIS)]
+        if (name, Stage.MERGE_CONFLICT_OTHER) in self._bynamestage:
+            del self._bynamestage[(name, Stage.MERGE_CONFLICT_OTHER)]

     def iteritems(self) -> Iterator[Tuple[bytes, IndexEntry]]:
         for (name, stage), entry in self._bynamestage.items():
             yield name, entry

-    def items(self) -> Iterator[Tuple[Tuple[bytes, int], IndexEntry]]:
-        return self._bynamestage.items()
+    def items(self) -> Iterator[Tuple[Tuple[bytes, Stage], IndexEntry]]:
+        return iter(self._bynamestage.items())

-    def update(self, entries: Dict[Tuple[bytes, int], IndexEntry]):
+    def update(self, entries: Dict[Tuple[bytes, Stage], IndexEntry]):
         for key, value in entries.items():
             self[key] = value

     def paths(self):
         for (name, stage) in self._bynamestage.keys():
-            if stage == 0 or stage == 2:  # normal or conflict 'this'
+            if stage == Stage.NORMAL or stage == Stage.MERGE_CONFLICT_THIS:
                 yield name
-    
+
     def changes_from_tree(
             self, object_store, tree: ObjectID, want_unchanged: bool = False):
         """Find the differences between the contents of this index and a tree.
@@ -561,7 +565,6 @@ def commit_tree(
     Returns:
       SHA1 of the created tree.
     """
-
     trees: Dict[bytes, Any] = {b"": {}}

     def add_tree(path):
@@ -855,7 +858,7 @@ def build_index_from_tree(
             st = st.__class__(st_tuple)
             # default to a stage 0 index entry (normal)
             # when reading from the filesystem
-        index[(entry.path, 0)] = index_entry_from_stat(st, entry.sha, 0)
+        index[(entry.path, Stage.NORMAL)] = index_entry_from_stat(st, entry.sha, 0)

         index.write()

@@ -960,7 +963,7 @@ def get_unstaged_changes(
     for tree_path, entry in index.iteritems():
         full_path = _tree_to_fs_path(root_path, tree_path)
         stage = read_stage(entry)
-        if stage == 1 or stage == 3:
+        if stage == Stage.MERGE_CONFLICT_ANCESTOR or stage == Stage.MERGE_CONFLICT_OTHER:  # skip 'ancestor'/'other'
             continue
         try:
             st = os.lstat(full_path)
blob - 0db1c8d98220b0626d043f88fc283d18e07e7dc4
blob + a174addc22fd3dd804f2ac9ce10f8bfee56aef37
--- dulwich/tests/test_index.py
+++ dulwich/tests/test_index.py
@@ -34,6 +34,7 @@ from dulwich.tests import TestCase, skipIf
 from ..index import (
     Index,
     IndexEntry,
+    Stage,
     _fs_to_tree_path,
     _tree_to_fs_path,
     build_index_from_tree,
@@ -168,7 +169,7 @@ class ReadIndexDictTests(IndexTestCase):

     def test_simple_write(self):
         entries = {
-            (b"barbla", 0): IndexEntry(
+            (b"barbla", Stage.NORMAL): IndexEntry(  # keys are (path, Stage) pairs
                 (1230680220, 0),
                 (1230680220, 0),
                 2050,