commit 52328a607b0de317a4d50f1763c91b054df3c1b1 from: Jelmer Vernooij date: Fri Jul 21 00:38:17 2023 UTC Split out a separate ConflictedIndexEntry class This should make the changes transparent to existing API users (so long as they don't work in trees with conflicts). It also prevents repeated dictionary access. commit - 98538380976f8964720f2eefc9018b50ecb3832a commit + 52328a607b0de317a4d50f1763c91b054df3c1b1 blob - db0735180179361254559910548c5f38f57e7c4c blob + e9266b05d801c6d50cc6a8b1149abd5749b773ce --- dulwich/index.py +++ dulwich/index.py @@ -69,6 +69,27 @@ IndexEntry = collections.namedtuple( "extended_flags", ], ) + + +class ConflictedIndexEntry: + """Index entry that represents a conflict.""" + + ancestor: Optional[IndexEntry] + this: Optional[IndexEntry] + other: Optional[IndexEntry] + + def __init__(self): + self.ancestor = None + self.this = None + self.other = None + + def entries(self) -> Iterable[IndexEntry]: + if self.ancestor: + yield self.ancestor + if self.this: + yield self.this + if self.other: + yield self.other # 2-bit stage (during merge) @@ -90,7 +111,6 @@ EXTENDED_FLAG_INTEND_TO_ADD = 0x2000 DEFAULT_VERSION = 2 - class Stage(Enum): NORMAL = 0 MERGE_CONFLICT_ANCESTOR = 1 @@ -98,9 +118,8 @@ class Stage(Enum): MERGE_CONFLICT_OTHER = 3 -class UnmergedEntriesInIndexEx(Exception): - def __init__(self, message): - super().__init__(message) +class UnmergedEntries(Exception): + """Unmerged entries exist in the index""" def read_stage(entry: IndexEntry) -> Stage: @@ -250,7 +269,7 @@ class UnsupportedIndexFormat(Exception): self.index_format_version = version -def read_index(f: BinaryIO): +def read_index(f: BinaryIO) -> Iterator[Tuple[bytes, IndexEntry]]: """Read an index file, yielding the individual entries.""" header = f.read(4) if header != b"DIRC": @@ -262,20 +281,6 @@ def read_index(f: BinaryIO): yield read_cache_entry(f, version) -def read_index_dict(f) -> Dict[Tuple[bytes, Stage], IndexEntry]: - """Read an index file and return it as a dictionary. - Dict Key is tuple of path and stage number, as - path alone is not unique - Args: - f: File object to read fromls. - """ - ret = {} - for name, entry in read_index(f): - stage = read_stage(entry) - ret[(name, stage)] = entry - return ret - - def write_index(f: BinaryIO, entries: List[Tuple[bytes, IndexEntry]], version: Optional[int] = None): """Write an index file. @@ -294,7 +299,7 @@ def write_index(f: BinaryIO, entries: List[Tuple[bytes def write_index_dict( f: BinaryIO, - entries: Dict[Tuple[bytes, Stage], IndexEntry], + entries: Dict[bytes, IndexEntry | ConflictedIndexEntry], version: Optional[int] = None, ) -> None: """Write an index file based on the contents of a dictionary. @@ -302,12 +307,16 @@ def write_index_dict( """ entries_list = [] for key in sorted(entries): - if isinstance(key, tuple): - name, stage = key + value = entries[key] + if isinstance(value, ConflictedIndexEntry): + if value.ancestor is not None: + entries_list.append((key, value.ancestor)) + if value.this is not None: + entries_list.append((key, value.this)) + if value.other is not None: + entries_list.append((key, value.other)) else: - name = key - stage = Stage.NORMAL - entries_list.append((name, entries[(name, stage)])) + entries_list.append((key, value)) write_index(f, entries_list, version=version) @@ -337,7 +346,7 @@ def cleanup_mode(mode: int) -> int: class Index: """A Git Index file.""" - _bynamestage: Dict[Tuple[bytes, Stage], IndexEntry] + _byname: Dict[bytes, IndexEntry | ConflictedIndexEntry] def __init__(self, filename: Union[bytes, str], read=True) -> None: """Create an index object associated with the given filename. @@ -365,7 +374,7 @@ class Index: f = GitFile(self._filename, "wb") try: f = SHA1Writer(f) - write_index_dict(f, self._bynamestage, version=self._version) + write_index_dict(f, self._byname, version=self._version) finally: f.close() @@ -378,7 +387,19 @@ class Index: f = SHA1Reader(f) for name, entry in read_index(f): stage = read_stage(entry) - self[(name, stage)] = entry + if stage == Stage.NORMAL: + self[name] = entry + else: + import pdb; pdb.set_trace() + existing = self._byname.setdefault(name, ConflictedIndexEntry()) + if isinstance(existing, IndexEntry): + raise AssertionError("Non-conflicted entry for %r exists" % name) + if stage == Stage.MERGE_CONFLICT_ANCESTOR: + existing.ancestor = entry + elif stage == Stage.MERGE_CONFLICT_THIS: + existing.this = entry + elif stage == Stage.MERGE_CONFLICT_OTHER: + existing.other = entry # FIXME: Additional data? f.read(os.path.getsize(self._filename) - f.tell() - 20) f.check_sha() @@ -387,60 +408,42 @@ class Index: def __len__(self) -> int: """Number of entries in this index file.""" - return len(self._bynamestage) + return len(self._byname) - def __getitem__(self, key: Union[Tuple[bytes, Stage], bytes]) -> IndexEntry: + def __getitem__(self, key: bytes) -> IndexEntry | ConflictedIndexEntry: """Retrieve entry by relative path and stage. Returns: tuple with (ctime, mtime, dev, ino, mode, uid, gid, size, sha, flags) """ - if isinstance(key, tuple): - return self._bynamestage[key] - if (key, Stage.NORMAL) in self._bynamestage: - return self._bynamestage[(key, Stage.NORMAL)] - # there is a conflict return 'this' entry - return self._bynamestage[(key, Stage.MERGE_CONFLICT_THIS)] + return self._byname[key] def __iter__(self) -> Iterator[bytes]: """Iterate over the paths and stages in this index.""" - for (name, stage) in self._bynamestage: - if stage == Stage.MERGE_CONFLICT_ANCESTOR or stage == Stage.MERGE_CONFLICT_OTHER: - continue - yield name + return iter(self._byname) def __contains__(self, key): - if isinstance(key, tuple): - return key in self._bynamestage - if (key, Stage.NORMAL) in self._bynamestage: - return True - if (key, Stage.MERGE_CONFLICT_THIS) in self._bynamestage: - return True - return False + return key in self._byname - def get_sha1(self, path: bytes, stage: Stage = Stage.NORMAL) -> bytes: + def get_sha1(self, path: bytes) -> bytes: """Return the (git object) SHA1 for the object at a path.""" - return self[(path, stage)].sha + return self[path].sha - def get_mode(self, path: bytes, stage: Stage = Stage.NORMAL) -> int: + def get_mode(self, path: bytes) -> int: """Return the POSIX file mode for the object at a path.""" - return self[(path, stage)].mode + return self[path].mode def iterobjects(self) -> Iterable[Tuple[bytes, bytes, int]]: """Iterate over path, sha, mode tuples for use with commit_tree.""" for path in self: entry = self[path] + if isinstance(entry, ConflictedIndexEntry): + raise UnmergedEntries() yield path, entry.sha, cleanup_mode(entry.mode) - def iterconflicts(self) -> Iterable[Tuple[int, bytes, Stage, bytes]]: - """Iterate over path, sha, mode tuples for use with commit_tree.""" - for (name, stage), entry in self._bynamestage.items(): - if stage != Stage.NORMAL: - yield cleanup_mode(entry.mode), entry.sha, stage, name - - def has_conflicts(self): - for (name, stage) in self._bynamestage.keys(): - if stage != Stage.NORMAL: + def has_conflicts(self) -> bool: + for value in self._byname.values(): + if isinstance(value, ConflictedIndexEntry): return True return False @@ -456,65 +459,36 @@ class Index: sha, stage << FLAG_STAGESHIFT, 0) - if (apath, Stage.NORMAL) in self._bynamestage: - del self._bynamestage[(apath, Stage.NORMAL)] - self._bynamestage[(apath, stage)] = entry + if (apath, Stage.NORMAL) in self._byname: + del self._byname[(apath, Stage.NORMAL)] + self._byname[(apath, stage)] = entry def clear(self): """Remove all contents from this index.""" - self._bynamestage = {} + self._byname = {} - def __setitem__(self, key: Union[Tuple[bytes, Stage], bytes], x: IndexEntry) -> None: - assert len(x) == len(IndexEntry._fields) - if isinstance(key, tuple): - name, stage = key - else: - name = key - stage = Stage.NORMAL # default when stage not explicitly specified + def __setitem__(self, name: bytes, value: IndexEntry | ConflictedIndexEntry) -> None: assert isinstance(name, bytes) - # Remove merge conflict entries if new entry is stage 0 - # Remove normal stage entry if new entry has conflicts (stage > 0) - if stage == Stage.NORMAL: - if (name, Stage.MERGE_CONFLICT_ANCESTOR) in self._bynamestage: - del self._bynamestage[(name, Stage.MERGE_CONFLICT_ANCESTOR)] - if (name, Stage.MERGE_CONFLICT_THIS) in self._bynamestage: - del self._bynamestage[(name, Stage.MERGE_CONFLICT_THIS)] - if (name, Stage.MERGE_CONFLICT_OTHER) in self._bynamestage: - del self._bynamestage[(name, Stage.MERGE_CONFLICT_OTHER)] - if stage != Stage.NORMAL and (name, Stage.NORMAL) in self._bynamestage: - del self._bynamestage[(name, Stage.NORMAL)] - self._bynamestage[(name, stage)] = IndexEntry(*x) + if not isinstance(value, (IndexEntry, ConflictedIndexEntry)): + value = IndexEntry(*value) + self._byname[name] = value - def __delitem__(self, key: Union[Tuple[bytes, Stage], bytes]) -> None: - if isinstance(key, tuple): - del self._bynamestage[key] - return - name = key - assert isinstance(name, bytes) - if (name, Stage.NORMAL) in self._bynamestage: - del self._bynamestage[(name, Stage.NORMAL)] - if (name, Stage.MERGE_CONFLICT_ANCESTOR) in self._bynamestage: - del self._bynamestage[(name, Stage.MERGE_CONFLICT_ANCESTOR)] - if (name, Stage.MERGE_CONFLICT_THIS) in self._bynamestage: - del self._bynamestage[(name, Stage.MERGE_CONFLICT_THIS)] - if (name, Stage.MERGE_CONFLICT_OTHER) in self._bynamestage: - del self._bynamestage[(name, Stage.MERGE_CONFLICT_OTHER)] + def __delitem__(self, name: bytes) -> None: + del self._byname[name] - def iteritems(self) -> Iterator[Tuple[bytes, IndexEntry]]: - for (name, stage), entry in self._bynamestage.items(): - yield name, entry + def iteritems(self) -> Iterator[Tuple[bytes, IndexEntry | ConflictedIndexEntry]]: + return iter(self._byname.items()) - def items(self) -> Iterator[Tuple[Tuple[bytes, Stage], IndexEntry]]: - return iter(self._bynamestage.items()) + def items(self) -> Iterator[Tuple[bytes, IndexEntry | ConflictedIndexEntry]]: + return iter(self._byname.items()) - def update(self, entries: Dict[Tuple[bytes, Stage], IndexEntry]): + def update(self, entries: Dict[bytes, IndexEntry]): for key, value in entries.items(): self[key] = value def paths(self): - for (name, stage) in self._bynamestage.keys(): - if stage == Stage.NORMAL or stage == Stage.MERGE_CONFLICT_THIS: - yield name + for name in self._byname.keys(): + yield name def changes_from_tree( self, object_store, tree: ObjectID, want_unchanged: bool = False): @@ -548,9 +522,6 @@ class Index: Returns: Root tree SHA """ - # as with git check for unmerged entries in the index and fail if found - if self.has_conflicts(): - raise UnmergedEntriesInIndexEx('Unmerged entries exist in index these need to be handled first') return commit_tree(object_store, self.iterobjects()) @@ -858,7 +829,7 @@ def build_index_from_tree( st = st.__class__(st_tuple) # default to a stage 0 index entry (normal) # when reading from the filesystem - index[(entry.path, Stage.NORMAL)] = index_entry_from_stat(st, entry.sha, 0) + index[entry.path] = index_entry_from_stat(st, entry.sha, 0) index.write() @@ -962,9 +933,11 @@ def get_unstaged_changes( for tree_path, entry in index.iteritems(): full_path = _tree_to_fs_path(root_path, tree_path) - stage = read_stage(entry) - if stage == Stage.MERGE_CONFLICT_ANCESTOR or stage == Stage.MERGE_CONFLICT_OTHER: + if isinstance(entry, ConflictedIndexEntry): + # Conflicted files are always unstaged + yield tree_path continue + try: st = os.lstat(full_path) if stat.S_ISDIR(st.st_mode): @@ -1143,7 +1116,7 @@ class locked_index: return try: f = SHA1Writer(self._file) - write_index_dict(f, self._index._bynamestage) + write_index_dict(f, self._index._byname) except BaseException: self._file.abort() else: blob - a174addc22fd3dd804f2ac9ce10f8bfee56aef37 blob + 411fb6485f956dec8b810720f0a1172b9bf80eed --- dulwich/tests/test_index.py +++ dulwich/tests/test_index.py @@ -43,7 +43,6 @@ from ..index import ( get_unstaged_changes, index_entry_from_stat, read_index, - read_index_dict, validate_path_element_default, validate_path_element_ntfs, write_cache_time, @@ -156,41 +155,8 @@ class SimpleIndexWriterTestCase(IndexTestCase): with open(filename, "rb") as x: self.assertEqual(entries, list(read_index(x))) - - -class ReadIndexDictTests(IndexTestCase): - def setUp(self): - IndexTestCase.setUp(self) - self.tempdir = tempfile.mkdtemp() - - def tearDown(self): - IndexTestCase.tearDown(self) - shutil.rmtree(self.tempdir) - - def test_simple_write(self): - entries = { - (b"barbla", Stage.NORMAL): IndexEntry( - (1230680220, 0), - (1230680220, 0), - 2050, - 3761020, - 33188, - 1000, - 1000, - 0, - b"e69de29bb2d1d6434b8b29ae775ad8c2e48c5391", - 0, - 0, - ) - } - filename = os.path.join(self.tempdir, "test-simple-write-index") - with open(filename, "wb+") as x: - write_index_dict(x, entries) - with open(filename, "rb") as x: - self.assertEqual(entries, read_index_dict(x)) - class CommitTreeTests(TestCase): def setUp(self): super().setUp()