Skip to content

Start adding types to Submodule, add py.typed to manifest #1282

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jul 1, 2021
Merged
Prev Previous commit
Next Next commit
Type Tree.traverse() better
  • Loading branch information
Yobmod committed Jun 30, 2021
commit 237966a20a61237a475135ed8a13b90f65dcb2ca
6 changes: 5 additions & 1 deletion git/objects/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
if TYPE_CHECKING:
from git.repo import Repo
from gitdb.base import OStream
# from .tree import Tree, Blob, Commit, TagObject
from .tree import Tree
from .blob import Blob
from .submodule.base import Submodule

IndexObjUnion = Union['Tree', 'Blob', 'Submodule']

# --------------------------------------------------------------------------

Expand Down
Empty file removed git/objects/output.txt
Empty file.
37 changes: 18 additions & 19 deletions git/objects/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from git.util import to_bin_sha

from . import util
from .base import IndexObject
from .base import IndexObject, IndexObjUnion
from .blob import Blob
from .submodule.base import Submodule

Expand All @@ -28,10 +28,11 @@

if TYPE_CHECKING:
from git.repo import Repo
from git.objects.util import TraversedTup
from io import BytesIO

T_Tree_cache = TypeVar('T_Tree_cache', bound=Union[Tuple[bytes, int, str]])
T_Tree_cache = TypeVar('T_Tree_cache', bound=Tuple[bytes, int, str])
TraversedTreeTup = Union[Tuple[Union['Tree', None], IndexObjUnion,
Tuple['Submodule', 'Submodule']]]

#--------------------------------------------------------

Expand Down Expand Up @@ -201,7 +202,7 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable):
symlink_id = 0o12
tree_id = 0o04

_map_id_to_type: Dict[int, Union[Type[Submodule], Type[Blob], Type['Tree']]] = {
_map_id_to_type: Dict[int, Type[IndexObjUnion]] = {
commit_id: Submodule,
blob_id: Blob,
symlink_id: Blob
Expand Down Expand Up @@ -229,7 +230,7 @@ def _set_cache_(self, attr: str) -> None:
# END handle attribute

def _iter_convert_to_object(self, iterable: Iterable[Tuple[bytes, int, str]]
) -> Iterator[Union[Blob, 'Tree', Submodule]]:
) -> Iterator[IndexObjUnion]:
"""Iterable yields tuples of (binsha, mode, name), which will be converted
to the respective object representation"""
for binsha, mode, name in iterable:
Expand All @@ -240,7 +241,7 @@ def _iter_convert_to_object(self, iterable: Iterable[Tuple[bytes, int, str]]
raise TypeError("Unknown mode %o found in tree data for path '%s'" % (mode, path)) from e
# END for each item

def join(self, file: str) -> Union[Blob, 'Tree', Submodule]:
def join(self, file: str) -> IndexObjUnion:
"""Find the named object in this tree's contents
:return: ``git.Blob`` or ``git.Tree`` or ``git.Submodule``

Expand Down Expand Up @@ -273,7 +274,7 @@ def join(self, file: str) -> Union[Blob, 'Tree', Submodule]:
raise KeyError(msg % file)
# END handle long paths

def __truediv__(self, file: str) -> Union['Tree', Blob, Submodule]:
def __truediv__(self, file: str) -> IndexObjUnion:
"""For PY3 only"""
return self.join(file)

Expand All @@ -296,17 +297,16 @@ def cache(self) -> TreeModifier:
See the ``TreeModifier`` for more information on how to alter the cache"""
return TreeModifier(self._cache)

def traverse(self,
predicate: Callable[[Union['Tree', 'Submodule', 'Blob',
'TraversedTup'], int], bool] = lambda i, d: True,
prune: Callable[[Union['Tree', 'Submodule', 'Blob', 'TraversedTup'], int], bool] = lambda i, d: False,
def traverse(self, # type: ignore # overrides super()
predicate: Callable[[Union[IndexObjUnion, TraversedTreeTup], int], bool] = lambda i, d: True,
prune: Callable[[Union[IndexObjUnion, TraversedTreeTup], int], bool] = lambda i, d: False,
depth: int = -1,
branch_first: bool = True,
visit_once: bool = False,
ignore_self: int = 1,
as_edge: bool = False
) -> Union[Iterator[Union['Tree', 'Blob', 'Submodule']],
Iterator[Tuple[Union['Tree', 'Submodule', None], Union['Tree', 'Blob', 'Submodule']]]]:
) -> Union[Iterator[IndexObjUnion],
Iterator[TraversedTreeTup]]:
"""For documentation, see util.Traversable._traverse()
Trees are set to visit_once = False to gain more performance in the traversal"""

Expand All @@ -320,23 +320,22 @@ def traverse(self,
# ret_tup = itertools.tee(ret, 2)
# assert is_tree_traversed(ret_tup), f"Type is {[type(x) for x in list(ret_tup[0])]}"
# return ret_tup[0]"""
return cast(Union[Iterator[Union['Tree', 'Blob', 'Submodule']],
Iterator[Tuple[Union['Tree', 'Submodule', None], Union['Tree', 'Blob', 'Submodule']]]],
return cast(Union[Iterator[IndexObjUnion], Iterator[TraversedTreeTup]],
super(Tree, self).traverse(predicate, prune, depth, # type: ignore
branch_first, visit_once, ignore_self))

# List protocol

def __getslice__(self, i: int, j: int) -> List[Union[Blob, 'Tree', Submodule]]:
def __getslice__(self, i: int, j: int) -> List[IndexObjUnion]:
return list(self._iter_convert_to_object(self._cache[i:j]))

def __iter__(self) -> Iterator[Union[Blob, 'Tree', Submodule]]:
def __iter__(self) -> Iterator[IndexObjUnion]:
return self._iter_convert_to_object(self._cache)

def __len__(self) -> int:
return len(self._cache)

def __getitem__(self, item: Union[str, int, slice]) -> Union[Blob, 'Tree', Submodule]:
def __getitem__(self, item: Union[str, int, slice]) -> IndexObjUnion:
if isinstance(item, int):
info = self._cache[item]
return self._map_id_to_type[info[1] >> 12](self.repo, info[0], info[1], join_path(self.path, info[2]))
Expand All @@ -348,7 +347,7 @@ def __getitem__(self, item: Union[str, int, slice]) -> Union[Blob, 'Tree', Submo

raise TypeError("Invalid index type: %r" % item)

def __contains__(self, item: Union[IndexObject, PathLike]) -> bool:
def __contains__(self, item: Union[IndexObjUnion, PathLike]) -> bool:
if isinstance(item, IndexObject):
for info in self._cache:
if item.binsha == info[0]:
Expand Down
6 changes: 4 additions & 2 deletions git/objects/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,14 @@
from .commit import Commit
from .blob import Blob
from .tag import TagObject
from .tree import Tree
from .tree import Tree, TraversedTreeTup
from subprocess import Popen


T_TIobj = TypeVar('T_TIobj', bound='TraversableIterableObj') # for TraversableIterableObj.traverse()
TraversedTup = Tuple[Union['Traversable', None], Union['Traversable', 'Blob']] # for Traversable.traverse()

TraversedTup = Union[Tuple[Union['Traversable', None], 'Traversable'], # for commit, submodule
TraversedTreeTup] # for tree.traverse()

# --------------------------------------------------------------------

Expand Down
2 changes: 1 addition & 1 deletion git/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,7 +1072,7 @@ class IterableObj():
Subclasses = [Submodule, Commit, Reference, PushInfo, FetchInfo, Remote]"""

__slots__ = ()
_id_attribute_ = "attribute that most suitably identifies your instance"
_id_attribute_: str

@classmethod
def list_items(cls, repo: 'Repo', *args: Any, **kwargs: Any) -> IterableList[T_IterableObj]:
Expand Down