Commit Diff


commit - dab3fc56b073c166084672ac0834eb2c35445470
commit + f25a86e97ba199716969d136237a5e1814fb0b79
blob - 3a33d3f4cc7d716df3bc38350fad8a8a3c7bc882
blob + ede53e16ec89bbfc6320db16fe4d83f721f57e8d
--- dulwich/index.py
+++ dulwich/index.py
@@ -36,6 +36,7 @@ from typing import (
     Iterable,
     Iterator,
     Tuple,
+    Union,
 )
 
 if TYPE_CHECKING:
@@ -49,6 +50,7 @@ from dulwich.objects import (
     Tree,
     hex_to_sha,
     sha_to_hex,
+    ObjectID,
 )
 from dulwich.pack import (
     SHA1Reader,
@@ -95,7 +97,7 @@ EXTENDED_FLAG_INTEND_TO_ADD = 0x2000
 DEFAULT_VERSION = 2
 
 
-def pathsplit(path):
+def pathsplit(path: bytes) -> Tuple[bytes, bytes]:
     """Split a /-delimited path into a directory part and a basename.
 
     Args:
@@ -194,7 +196,7 @@ def read_cache_entry(f, version: int) -> Tuple[str, In
         ))
 
 
-def write_cache_entry(f, name, entry, version):
+def write_cache_entry(f, name: bytes, entry: IndexEntry, version: int) -> None:
     """Write an index entry to a file.
 
     Args:
@@ -249,7 +251,7 @@ def read_index(f: BinaryIO):
         yield read_cache_entry(f, version)
 
 
-def read_index_dict(f):
+def read_index_dict(f) -> Dict[bytes, IndexEntry]:
     """Read an index file and return it as a dictionary.
 
     Args:
@@ -314,7 +316,7 @@ def cleanup_mode(mode: int) -> int:
 class Index(object):
     """A Git Index file."""
 
-    def __init__(self, filename):
+    def __init__(self, filename: Union[bytes, str]):
         """Open an index file.
 
         Args:
@@ -391,27 +393,28 @@ class Index(object):
         """Remove all contents from this index."""
         self._byname = {}
 
-    def __setitem__(self, name, x):
+    def __setitem__(self, name: bytes, x: IndexEntry):
         assert isinstance(name, bytes)
         assert len(x) == len(IndexEntry._fields)
         # Remove the old entry if any
         self._byname[name] = IndexEntry(*x)
 
-    def __delitem__(self, name):
+    def __delitem__(self, name: bytes):
         assert isinstance(name, bytes)
         del self._byname[name]
 
-    def iteritems(self):
+    def iteritems(self) -> Iterator[Tuple[bytes, IndexEntry]]:
         return self._byname.items()
 
-    def items(self):
+    def items(self) -> Iterator[Tuple[bytes, IndexEntry]]:
         return self._byname.items()
 
-    def update(self, entries):
+    def update(self, entries: Dict[bytes, IndexEntry]):
         for name, value in entries.items():
             self[name] = value
 
-    def changes_from_tree(self, object_store, tree, want_unchanged=False):
+    def changes_from_tree(
+            self, object_store, tree: ObjectID, want_unchanged: bool = False):
         """Find the differences between the contents of this index and a tree.
 
         Args:
@@ -609,8 +612,8 @@ else:
 
 
 def build_file_from_blob(
-    blob, mode, target_path, *, honor_filemode=True, tree_encoding="utf-8",
-    symlink_fn=None
+        blob: Blob, mode: int, target_path: bytes, *, honor_filemode=True,
+        tree_encoding="utf-8", symlink_fn=None
 ):
     """Build a file or symlink on disk based on a Git object.
 
@@ -629,13 +632,12 @@ def build_file_from_blob(
         oldstat = None
     contents = blob.as_raw_string()
     if stat.S_ISLNK(mode):
-        # FIXME: This will fail on Windows. What should we do instead?
         if oldstat:
             os.unlink(target_path)
         if sys.platform == "win32":
             # os.readlink on Python3 on Windows requires a unicode string.
-            contents = contents.decode(tree_encoding)
-            target_path = target_path.decode(tree_encoding)
+            contents = contents.decode(tree_encoding)  # type: ignore
+            target_path = target_path.decode(tree_encoding)  # type: ignore
         (symlink_fn or symlink)(contents, target_path)
     else:
         if oldstat is not None and oldstat.st_size == len(contents):
@@ -656,11 +658,11 @@ def build_file_from_blob(
 INVALID_DOTNAMES = (b".git", b".", b"..", b"")
 
 
-def validate_path_element_default(element):
+def validate_path_element_default(element: bytes) -> bool:
     return element.lower() not in INVALID_DOTNAMES
 
 
-def validate_path_element_ntfs(element):
+def validate_path_element_ntfs(element: bytes) -> bool:
     stripped = element.rstrip(b". ").lower()
     if stripped in INVALID_DOTNAMES:
         return False
@@ -669,7 +671,8 @@ def validate_path_element_ntfs(element):
     return True
 
 
-def validate_path(path, element_validator=validate_path_element_default):
+def validate_path(path: bytes,
+                  element_validator=validate_path_element_default) -> bool:
     """Default path validator that just checks for .git/."""
     parts = path.split(b"/")
     for p in parts:
@@ -680,11 +683,11 @@ def validate_path(path, element_validator=validate_pat
 
 
 def build_index_from_tree(
-    root_path,
-    index_path,
-    object_store,
-    tree_id,
-    honor_filemode=True,
+    root_path: Union[str, bytes],
+    index_path: Union[str, bytes],
+    object_store: "BaseObjectStore",
+    tree_id: bytes,
+    honor_filemode: bool = True,
     validate_path_element=validate_path_element_default,
     symlink_fn=None
 ):
@@ -753,7 +756,8 @@ def build_index_from_tree(
     index.write()
 
 
-def blob_from_path_and_mode(fs_path, mode, tree_encoding="utf-8"):
+def blob_from_path_and_mode(fs_path: bytes, mode: int,
+                            tree_encoding="utf-8"):
     """Create a blob from a path and a stat object.
 
     Args:
@@ -766,8 +770,7 @@ def blob_from_path_and_mode(fs_path, mode, tree_encodi
     if stat.S_ISLNK(mode):
         if sys.platform == "win32":
             # os.readlink on Python3 on Windows requires a unicode string.
-            fs_path = os.fsdecode(fs_path)
-            blob.data = os.readlink(fs_path).encode(tree_encoding)
+            blob.data = os.readlink(os.fsdecode(fs_path)).encode(tree_encoding)
         else:
             blob.data = os.readlink(fs_path)
     else:
@@ -776,7 +779,7 @@ def blob_from_path_and_mode(fs_path, mode, tree_encodi
     return blob
 
 
-def blob_from_path_and_stat(fs_path, st, tree_encoding="utf-8"):
+def blob_from_path_and_stat(fs_path: bytes, st, tree_encoding="utf-8"):
     """Create a blob from a path and a stat object.
 
     Args:
@@ -787,7 +790,7 @@ def blob_from_path_and_stat(fs_path, st, tree_encoding
     return blob_from_path_and_mode(fs_path, st.st_mode, tree_encoding)
 
 
-def read_submodule_head(path):
+def read_submodule_head(path: Union[str, bytes]) -> Optional[bytes]:
     """Read the head commit of a submodule.
 
     Args:
@@ -811,7 +814,7 @@ def read_submodule_head(path):
         return None
 
 
-def _has_directory_changed(tree_path, entry):
+def _has_directory_changed(tree_path: bytes, entry):
     """Check if a directory has changed after getting an error.
 
     When handling an error trying to create a blob from a path, call this
@@ -836,7 +839,9 @@ def _has_directory_changed(tree_path, entry):
     return False
 
 
-def get_unstaged_changes(index: Index, root_path, filter_blob_callback=None):
+def get_unstaged_changes(
+        index: Index, root_path: Union[str, bytes],
+        filter_blob_callback=None):
     """Walk through an index and check for differences against working tree.
 
     Args:
@@ -876,7 +881,7 @@ def get_unstaged_changes(index: Index, root_path, filt
 os_sep_bytes = os.sep.encode("ascii")
 
 
-def _tree_to_fs_path(root_path, tree_path: bytes):
+def _tree_to_fs_path(root_path: bytes, tree_path: bytes):
     """Convert a git tree path to a file system path.
 
     Args:
@@ -893,7 +898,7 @@ def _tree_to_fs_path(root_path, tree_path: bytes):
     return os.path.join(root_path, sep_corrected_path)
 
 
-def _fs_to_tree_path(fs_path):
+def _fs_to_tree_path(fs_path: Union[str, bytes]) -> bytes:
     """Convert a file system path to a git tree path.
 
     Args:
@@ -912,7 +917,7 @@ def _fs_to_tree_path(fs_path):
     return tree_path
 
 
-def index_entry_from_directory(st, path):
+def index_entry_from_directory(st, path: bytes) -> Optional[IndexEntry]:
     if os.path.exists(os.path.join(path, b".git")):
         head = read_submodule_head(path)
         if head is None:
@@ -921,7 +926,9 @@ def index_entry_from_directory(st, path):
     return None
 
 
-def index_entry_from_path(path, object_store=None):
+def index_entry_from_path(
+        path: bytes, object_store: Optional["BaseObjectStore"] = None
+) -> Optional[IndexEntry]:
     """Create an index from a filesystem path.
 
     This returns an index value for files, symlinks
@@ -949,8 +956,9 @@ def index_entry_from_path(path, object_store=None):
 
 
 def iter_fresh_entries(
-    paths, root_path, object_store: Optional["BaseObjectStore"] = None
-):
+    paths: Iterable[bytes], root_path: bytes,
+    object_store: Optional["BaseObjectStore"] = None
+) -> Iterator[Tuple[bytes, Optional[IndexEntry]]]:
     """Iterate over current versions of index entries on disk.
 
     Args:
@@ -968,7 +976,10 @@ def iter_fresh_entries(
         yield path, entry
 
 
-def iter_fresh_objects(paths, root_path, include_deleted=False, object_store=None):
+def iter_fresh_objects(
+        paths: Iterable[bytes], root_path: bytes, include_deleted=False,
+        object_store=None) -> Iterator[
+            Tuple[bytes, Optional[bytes], Optional[int]]]:
     """Iterate over versions of objects on disk referenced by index.
 
     Args:
@@ -978,7 +989,8 @@ def iter_fresh_objects(paths, root_path, include_delet
       object_store: Optional object store to report new items to
     Returns: Iterator over path, sha, mode
     """
-    for path, entry in iter_fresh_entries(paths, root_path, object_store=object_store):
+    for path, entry in iter_fresh_entries(
+            paths, root_path, object_store=object_store):
         if entry is None:
             if include_deleted:
                 yield path, None, None
@@ -987,7 +999,7 @@ def iter_fresh_objects(paths, root_path, include_delet
             yield path, entry.sha, cleanup_mode(entry.mode)
 
 
-def refresh_index(index, root_path):
+def refresh_index(index: Index, root_path: bytes):
     """Refresh the contents of an index.
 
     This is the equivalent to running 'git commit -a'.
@@ -997,7 +1009,8 @@ def refresh_index(index, root_path):
       root_path: Root filesystem path
     """
     for path, entry in iter_fresh_entries(index, root_path):
-        index[path] = path
+        if entry:
+            index[path] = entry
 
 
 class locked_index(object):
@@ -1005,7 +1018,7 @@ class locked_index(object):
 
     Works as a context manager.
     """
-    def __init__(self, path):
+    def __init__(self, path: Union[bytes, str]):
         self._path = path
 
     def __enter__(self):
blob - bcea5a7f4760273a4ff5ba0e09f39cc7cde65dc5
blob + b55d35b31e4034d4e152f4cc27adb383e41b9908
--- dulwich/object_store.py
+++ dulwich/object_store.py
@@ -38,6 +38,7 @@ from dulwich.errors import (
 )
 from dulwich.file import GitFile
 from dulwich.objects import (
+    ObjectID,
     Commit,
     ShaFile,
     Tag,
@@ -67,7 +68,7 @@ from dulwich.pack import (
     PackStreamCopier,
 )
 from dulwich.protocol import DEPTH_INFINITE
-from dulwich.refs import ANNOTATED_TAG_SUFFIX
+from dulwich.refs import ANNOTATED_TAG_SUFFIX, Ref
 
 INFODIR = "info"
 PACKDIR = "pack"
@@ -83,9 +84,9 @@ class BaseObjectStore(object):
 
     def determine_wants_all(
         self,
-        refs: Dict[bytes, bytes],
+        refs: Dict[Ref, ObjectID],
         depth: Optional[int] = None
-    ) -> List[bytes]:
+    ) -> List[ObjectID]:
         def _want_deepen(sha):
             if not depth:
                 return False
@@ -139,7 +140,7 @@ class BaseObjectStore(object):
         """
         raise NotImplementedError(self.get_raw)
 
-    def __getitem__(self, sha):
+    def __getitem__(self, sha: ObjectID):
         """Obtain an object by SHA1."""
         type_num, uncomp = self.get_raw(sha)
         return ShaFile.from_raw_string(type_num, uncomp, sha=sha)
@@ -984,7 +985,7 @@ class MemoryObjectStore(BaseObjectStore):
         """List with pack objects."""
         return []
 
-    def get_raw(self, name):
+    def get_raw(self, name: ObjectID):
         """Obtain the raw text for an object.
 
         Args:
@@ -994,10 +995,10 @@ class MemoryObjectStore(BaseObjectStore):
         obj = self[self._to_hexsha(name)]
         return obj.type_num, obj.as_raw_string()
 
-    def __getitem__(self, name):
+    def __getitem__(self, name: ObjectID):
         return self._data[self._to_hexsha(name)].copy()
 
-    def __delitem__(self, name):
+    def __delitem__(self, name: ObjectID):
         """Delete an object from this store, for testing only."""
         del self._data[self._to_hexsha(name)]
 
blob - 479c2d1e381ac7a33f2c2ef2b156665d3aa17ac4
blob + 8164a6d88390fe56b0eb31808212a26d36c300a7
--- dulwich/porcelain.py
+++ dulwich/porcelain.py
@@ -1036,6 +1036,7 @@ def tag_create(
     tag_time=None,
     tag_timezone=None,
     sign=False,
+    encoding=DEFAULT_ENCODING
 ):
     """Creates a tag in git via dulwich calls:
 
@@ -1063,7 +1064,7 @@ def tag_create(
                 # TODO(jelmer): Don't use repo private method.
                 author = r._get_user_identity(r.get_config_stack())
             tag_obj.tagger = author
-            tag_obj.message = message + "\n".encode()
+            tag_obj.message = message + "\n".encode(encoding)
             tag_obj.name = tag
             tag_obj.object = (type(object), object.id)
             if tag_time is None:
@@ -1889,10 +1890,10 @@ def reset_file(repo, file_path: str, target: bytes = b
       target: branch or commit or b'HEAD' to reset
     """
     tree = parse_tree(repo, treeish=target)
-    file_path = _fs_to_tree_path(file_path)
+    tree_path = _fs_to_tree_path(file_path)
 
-    file_entry = tree.lookup_path(repo.object_store.__getitem__, file_path)
-    full_path = os.path.join(repo.path.encode(), file_path)
+    file_entry = tree.lookup_path(repo.object_store.__getitem__, tree_path)
+    full_path = os.path.join(os.fsencode(repo.path), tree_path)
     blob = repo.object_store[file_entry[1]]
     mode = file_entry[0]
     build_file_from_blob(blob, mode, full_path, symlink_fn=symlink_fn)
blob - 89439b4b22e7fee40a30893f00779eca2eb6a6db
blob + b0a2e020130eca6f5a89b94ad9a2526b9d312222
--- dulwich/repo.py
+++ dulwich/repo.py
@@ -33,7 +33,18 @@ import os
 import sys
 import stat
 import time
-from typing import Optional, Tuple, TYPE_CHECKING, List, Dict, Union, Iterable
+from typing import (
+    Optional,
+    BinaryIO,
+    Callable,
+    Tuple,
+    TYPE_CHECKING,
+    List,
+    Dict,
+    Union,
+    Iterable,
+    Set
+)
 
 if TYPE_CHECKING:
     # There are no circular imports here, but we try to defer imports as long
@@ -70,6 +81,7 @@ from dulwich.objects import (
     ShaFile,
     Tag,
     Tree,
+    ObjectID,
 )
 from dulwich.pack import (
     pack_objects_to_data,
@@ -86,6 +98,7 @@ from dulwich.hooks import (
 from dulwich.line_ending import BlobNormalizer, TreeBlobNormalizer
 
 from dulwich.refs import (  # noqa: F401
+    Ref,
     ANNOTATED_TAG_SUFFIX,
     LOCAL_BRANCH_PREFIX,
     LOCAL_TAG_PREFIX,
@@ -389,7 +402,7 @@ class BaseRepo(object):
         self._put_named_file("config", f.getvalue())
         self._put_named_file(os.path.join("info", "exclude"), b"")
 
-    def get_named_file(self, path):
+    def get_named_file(self, path: str) -> Optional[BinaryIO]:
         """Get a file from the control dir with a specific name.
 
         Although the filename should be interpreted as a filename relative to
@@ -402,7 +415,7 @@ class BaseRepo(object):
         """
         raise NotImplementedError(self.get_named_file)
 
-    def _put_named_file(self, path, contents):
+    def _put_named_file(self, path: str, contents: bytes):
         """Write a file to the control dir with the given name and contents.
 
         Args:
@@ -411,11 +424,11 @@ class BaseRepo(object):
         """
         raise NotImplementedError(self._put_named_file)
 
-    def _del_named_file(self, path):
+    def _del_named_file(self, path: str):
         """Delete a file in the control directory with the given name."""
         raise NotImplementedError(self._del_named_file)
 
-    def open_index(self):
+    def open_index(self) -> "Index":
         """Open the index for this repository.
 
         Raises:
@@ -562,7 +575,9 @@ class BaseRepo(object):
             )
         )
 
-    def generate_pack_data(self, have, want, progress=None, ofs_delta=None):
+    def generate_pack_data(self, have: List[ObjectID], want: List[ObjectID],
+                           progress: Optional[Callable[[str], None]] = None,
+                           ofs_delta: Optional[bool] = None):
         """Generate pack data objects for a set of wants/haves.
 
         Args:
@@ -579,7 +594,8 @@ class BaseRepo(object):
             ofs_delta=ofs_delta,
         )
 
-    def get_graph_walker(self, heads=None):
+    def get_graph_walker(
+            self, heads: List[ObjectID] = None) -> ObjectStoreGraphWalker:
         """Retrieve a graph walker.
 
         A graph walker is used by a remote repository (or proxy)
@@ -640,7 +656,7 @@ class BaseRepo(object):
         """
         return self.object_store[sha]
 
-    def parents_provider(self):
+    def parents_provider(self) -> ParentsProvider:
         return ParentsProvider(
             self.object_store,
             grafts=self._graftpoints,
@@ -660,7 +676,7 @@ class BaseRepo(object):
         """
         return self.parents_provider().get_parents(sha, commit)
 
-    def get_config(self):
+    def get_config(self) -> "ConfigFile":
         """Retrieve the config object.
 
         Returns: `ConfigFile` object for the ``.git/config`` file.
@@ -697,7 +713,7 @@ class BaseRepo(object):
         backends = [self.get_config()] + StackedConfig.default_backends()
         return StackedConfig(backends, writable=backends[0])
 
-    def get_shallow(self):
+    def get_shallow(self) -> Set[ObjectID]:
         """Get the set of shallow commits.
 
         Returns: Set of shallow commits.
@@ -727,7 +743,7 @@ class BaseRepo(object):
         else:
             self._del_named_file("shallow")
 
-    def get_peeled(self, ref):
+    def get_peeled(self, ref: Ref) -> ObjectID:
         """Get the peeled value of a ref.
 
         Args:
@@ -777,7 +793,7 @@ class BaseRepo(object):
 
         return Walker(self.object_store, include, *args, **kwargs)
 
-    def __getitem__(self, name):
+    def __getitem__(self, name: Union[ObjectID, Ref]):
         """Retrieve a Git object by SHA1 or ref.
 
         Args:
@@ -876,19 +892,19 @@ class BaseRepo(object):
 
     def do_commit(  # noqa: C901
         self,
-        message=None,
-        committer=None,
-        author=None,
+        message: Optional[bytes] = None,
+        committer: Optional[bytes] = None,
+        author: Optional[bytes] = None,
         commit_timestamp=None,
         commit_timezone=None,
         author_timestamp=None,
         author_timezone=None,
-        tree=None,
-        encoding=None,
-        ref=b"HEAD",
-        merge_heads=None,
-        no_verify=False,
-        sign=False,
+        tree: Optional[ObjectID] = None,
+        encoding: Optional[bytes] = None,
+        ref: Ref = b"HEAD",
+        merge_heads: Optional[List[ObjectID]] = None,
+        no_verify: bool = False,
+        sign: bool = False,
     ):
         """Create a new commit.
 
@@ -1026,7 +1042,7 @@ class BaseRepo(object):
             if not ok:
                 # Fail if the atomic compare-and-swap failed, leaving the
                 # commit and all its objects as garbage.
-                raise CommitError("%s changed during commit" % (ref,))
+                raise CommitError(f"{ref!r} changed during commit")
 
         self._del_named_file("MERGE_HEAD")
 
@@ -1540,7 +1556,7 @@ class Repo(BaseRepo):
             raise
         return target
 
-    def reset_index(self, tree: Optional[Tree] = None):
+    def reset_index(self, tree: Optional[bytes] = None):
         """Reset the index back to a specific tree.
 
         Args: