Discard stderr output from git apply if the caller wants
[stgit] / stgit / lib / git.py
index c5b048f..6929698 100644 (file)
@@ -1,39 +1,72 @@
+"""A Python class hierarchy wrapping a git repository and its
+contents."""
+
 import os, os.path, re
 from datetime import datetime, timedelta, tzinfo
 
 from stgit import exception, run, utils
 from stgit.config import config
 
+class Immutable(object):
+    """I{Immutable} objects cannot be modified once created. Any
+    modification methods will return a new object, leaving the
+    original object as it was.
+
+    The reason for this is that we want to be able to represent git
+    objects, which are immutable, and want to be able to create new
+    git objects that are just slight modifications of other git
+    objects. (Such as, for example, modifying the commit message of a
+    commit object while leaving the rest of it intact. This involves
+    creating a whole new commit object that's exactly like the old one
+    except for the commit message.)
+
+    The L{Immutable} class doesn't actually enforce immutability --
+    that is up to the individual immutable subclasses. It just serves
+    as documentation."""
+
 class RepositoryException(exception.StgException):
-    pass
+    """Base class for all exceptions due to failed L{Repository}
+    operations."""
+
+class BranchException(exception.StgException):
+    """Exception raised by failed L{Branch} operations."""
 
 class DateException(exception.StgException):
+    """Exception raised when a date+time string could not be parsed."""
     def __init__(self, string, type):
         exception.StgException.__init__(
             self, '"%s" is not a valid %s' % (string, type))
 
 class DetachedHeadException(RepositoryException):
+    """Exception raised when HEAD is detached (that is, there is no
+    current branch)."""
     def __init__(self):
         RepositoryException.__init__(self, 'Not on any branch')
 
 class Repr(object):
+    """Utility class that defines C{__reps__} in terms of C{__str__}."""
     def __repr__(self):
         return str(self)
 
 class NoValue(object):
+    """A handy default value that is guaranteed to be distinct from any
+    real argument value."""
     pass
 
 def make_defaults(defaults):
-    def d(val, attr):
+    def d(val, attr, default_fun = lambda: None):
         if val != NoValue:
             return val
         elif defaults != NoValue:
             return getattr(defaults, attr)
         else:
-            return None
+            return default_fun()
     return d
 
 class TimeZone(tzinfo, Repr):
+    """A simple time zone class for static offsets from UTC. (We have to
+    define our own since Python's standard library doesn't define any
+    time zone classes.)"""
     def __init__(self, tzstring):
         m = re.match(r'^([+-])(\d{2}):?(\d{2})$', tzstring)
         if not m:
@@ -54,8 +87,8 @@ class TimeZone(tzinfo, Repr):
     def __str__(self):
         return self.__name
 
-class Date(Repr):
-    """Immutable."""
+class Date(Immutable, Repr):
+    """Represents a timestamp used in git commits."""
     def __init__(self, datestring):
         # Try git-formatted date.
         m = re.match(r'^(\d+)\s+([+-]\d\d:?\d\d)$', datestring)
@@ -88,12 +121,15 @@ class Date(Repr):
                           self.__time.tzinfo)
     @classmethod
     def maybe(cls, datestring):
+        """Return a new object initialized with the argument if it contains a
+        value (otherwise, just return the argument)."""
         if datestring in [None, NoValue]:
             return datestring
         return cls(datestring)
 
-class Person(Repr):
-    """Immutable."""
+class Person(Immutable, Repr):
+    """Represents an author or committer in a git commit object. Contains
+    name, email and timestamp."""
     def __init__(self, name = NoValue, email = NoValue,
                  date = NoValue, defaults = NoValue):
         d = make_defaults(defaults)
@@ -146,23 +182,23 @@ class Person(Repr):
                 defaults = cls.user())
         return cls.__committer
 
-class Tree(Repr):
-    """Immutable."""
+class Tree(Immutable, Repr):
+    """Represents a git tree object."""
     def __init__(self, sha1):
         self.__sha1 = sha1
     sha1 = property(lambda self: self.__sha1)
     def __str__(self):
         return 'Tree<%s>' % self.sha1
 
-class Commitdata(Repr):
-    """Immutable."""
+class CommitData(Immutable, Repr):
+    """Represents the actual data contents of a git commit object."""
     def __init__(self, tree = NoValue, parents = NoValue, author = NoValue,
                  committer = NoValue, message = NoValue, defaults = NoValue):
         d = make_defaults(defaults)
         self.__tree = d(tree, 'tree')
         self.__parents = d(parents, 'parents')
-        self.__author = d(author, 'author')
-        self.__committer = d(committer, 'committer')
+        self.__author = d(author, 'author', Person.author)
+        self.__committer = d(committer, 'committer', Person.committer)
         self.__message = d(message, 'message')
     tree = property(lambda self: self.__tree)
     parents = property(lambda self: self.__parents)
@@ -199,7 +235,7 @@ class Commitdata(Repr):
             parents = None
         else:
             parents = [p.sha1 for p in self.parents]
-        return ('Commitdata<tree: %s, parents: %s, author: %s,'
+        return ('CommitData<tree: %s, parents: %s, author: %s,'
                 ' committer: %s, message: "%s">'
                 ) % (tree, parents, self.author, self.committer, self.message)
     @classmethod
@@ -223,8 +259,10 @@ class Commitdata(Repr):
                 assert False
         assert False
 
-class Commit(Repr):
-    """Immutable."""
+class Commit(Immutable, Repr):
+    """Represents a git commit object. All the actual data contents of the
+    commit object is stored in the L{data} member, which is a
+    L{CommitData} object."""
     def __init__(self, repository, sha1):
         self.__sha1 = sha1
         self.__repository = repository
@@ -233,7 +271,7 @@ class Commit(Repr):
     @property
     def data(self):
         if self.__data == None:
-            self.__data = Commitdata.parse(
+            self.__data = CommitData.parse(
                 self.__repository,
                 self.__repository.cat_object(self.sha1))
         return self.__data
@@ -241,21 +279,26 @@ class Commit(Repr):
         return 'Commit<sha1: %s, data: %s>' % (self.sha1, self.__data)
 
 class Refs(object):
+    """Accessor for the refs stored in a git repository. Will
+    transparently cache the values of all refs."""
     def __init__(self, repository):
         self.__repository = repository
         self.__refs = None
     def __cache_refs(self):
+        """(Re-)Build the cache of all refs in the repository."""
         self.__refs = {}
         for line in self.__repository.run(['git', 'show-ref']).output_lines():
             m = re.match(r'^([0-9a-f]{40})\s+(\S+)$', line)
             sha1, ref = m.groups()
             self.__refs[ref] = sha1
     def get(self, ref):
-        """Throws KeyError if ref doesn't exist."""
+        """Get the Commit the given ref points to. Throws KeyError if ref
+        doesn't exist."""
         if self.__refs == None:
             self.__cache_refs()
         return self.__repository.get_commit(self.__refs[ref])
     def exists(self, ref):
+        """Check if the given ref exists."""
         try:
             self.get(ref)
         except KeyError:
@@ -263,6 +306,8 @@ class Refs(object):
         else:
             return True
     def set(self, ref, commit, msg):
+        """Write the sha1 of the given Commit to the ref. The ref may or may
+        not already exist."""
         if self.__refs == None:
             self.__cache_refs()
         old_sha1 = self.__refs.get(ref, '0'*40)
@@ -272,6 +317,7 @@ class Refs(object):
                                    ref, new_sha1, old_sha1]).no_output()
             self.__refs[ref] = new_sha1
     def delete(self, ref):
+        """Delete the given ref. Throws KeyError if ref doesn't exist."""
         if self.__refs == None:
             self.__cache_refs()
         self.__repository.run(['git', 'update-ref',
@@ -280,7 +326,8 @@ class Refs(object):
 
 class ObjectCache(object):
     """Cache for Python objects, for making sure that we create only one
-    Python object per git object."""
+    Python object per git object. This reduces memory consumption and
+    makes object comparison very cheap."""
     def __init__(self, create):
         self.__objects = {}
         self.__create = create
@@ -296,13 +343,27 @@ class ObjectCache(object):
 
 class RunWithEnv(object):
     def run(self, args, env = {}):
+        """Run the given command with an environment given by self.env.
+
+        @type args: list of strings
+        @param args: Command and argument vector
+        @type env: dict
+        @param env: Extra environment"""
         return run.Run(*args).env(utils.add_dict(self.env, env))
 
 class RunWithEnvCwd(RunWithEnv):
     def run(self, args, env = {}):
+        """Run the given command with an environment given by self.env, and
+        current working directory given by self.cwd.
+
+        @type args: list of strings
+        @param args: Command and argument vector
+        @type env: dict
+        @param env: Extra environment"""
         return RunWithEnv.run(self, args, env).cwd(self.cwd)
 
 class Repository(RunWithEnv):
+    """Represents a git repository."""
     def __init__(self, directory):
         self.__git_dir = directory
         self.__refs = Refs(self)
@@ -321,16 +382,25 @@ class Repository(RunWithEnv):
         except run.RunException:
             raise RepositoryException('Cannot find git repository')
     @property
+    def current_branch_name(self):
+        """Return the name of the current branch."""
+        return utils.strip_leading('refs/heads/', self.head_ref)
+    @property
     def default_index(self):
+        """An L{Index} object representing the default index file for the
+        repository."""
         if self.__default_index == None:
             self.__default_index = Index(
                 self, (os.environ.get('GIT_INDEX_FILE', None)
                        or os.path.join(self.__git_dir, 'index')))
         return self.__default_index
     def temp_index(self):
+        """Return an L{Index} object representing a new temporary index file
+        for the repository."""
         return Index(self, self.__git_dir)
     @property
     def default_worktree(self):
+        """A L{Worktree} object representing the default work tree."""
         if self.__default_worktree == None:
             path = os.environ.get('GIT_WORK_TREE', None)
             if not path:
@@ -342,6 +412,8 @@ class Repository(RunWithEnv):
         return self.__default_worktree
     @property
     def default_iw(self):
+        """An L{IndexAndWorktree} object representing the default index and
+        work tree for this repository."""
         if self.__default_iw == None:
             self.__default_iw = IndexAndWorktree(self.default_index,
                                                  self.default_worktree)
@@ -378,18 +450,18 @@ class Repository(RunWithEnv):
                                                 ).output_one_line()
         return self.get_commit(sha1)
     @property
-    def head(self):
+    def head_ref(self):
         try:
             return self.run(['git', 'symbolic-ref', '-q', 'HEAD']
                             ).output_one_line()
         except run.RunException:
             raise DetachedHeadException()
-    def set_head(self, ref, msg):
+    def set_head_ref(self, ref, msg):
         self.run(['git', 'symbolic-ref', '-m', msg, 'HEAD', ref]).no_output()
     def simple_merge(self, base, ours, theirs):
-        """Given three trees, tries to do an in-index merge in a temporary
-        index with a temporary index. Returns the result tree, or None if
-        the merge failed (due to conflicts)."""
+        """Given three L{Tree}s, tries to do an in-index merge with a
+        temporary index. Returns the result L{Tree}, or None if the
+        merge failed (due to conflicts)."""
         assert isinstance(base, Tree)
         assert isinstance(ours, Tree)
         assert isinstance(theirs, Tree)
@@ -411,9 +483,9 @@ class Repository(RunWithEnv):
                 return None
         finally:
             index.delete()
-    def apply(self, tree, patch_text):
-        """Given a tree and a patch, will either return the new tree that
-        results when the patch is applied, or None if the patch
+    def apply(self, tree, patch_text, quiet):
+        """Given a L{Tree} and a patch, will either return the new L{Tree}
+        that results when the patch is applied, or None if the patch
         couldn't be applied."""
         assert isinstance(tree, Tree)
         if not patch_text:
@@ -422,25 +494,33 @@ class Repository(RunWithEnv):
         try:
             index.read_tree(tree)
             try:
-                index.apply(patch_text)
+                index.apply(patch_text, quiet)
                 return index.write_tree()
             except MergeException:
                 return None
         finally:
             index.delete()
     def diff_tree(self, t1, t2, diff_opts):
+        """Given two L{Tree}s C{t1} and C{t2}, return the patch that takes
+        C{t1} to C{t2}.
+
+        @type diff_opts: list of strings
+        @param diff_opts: Extra diff options
+        @rtype: String
+        @return: Patch text"""
         assert isinstance(t1, Tree)
         assert isinstance(t2, Tree)
         return self.run(['git', 'diff-tree', '-p'] + list(diff_opts)
                         + [t1.sha1, t2.sha1]).raw_output()
 
 class MergeException(exception.StgException):
-    pass
+    """Exception raised when a merge fails for some reason."""
 
 class MergeConflictException(MergeException):
-    pass
+    """Exception raised when a merge fails due to conflicts."""
 
 class Index(RunWithEnv):
+    """Represents a git index file."""
     def __init__(self, repository, filename):
         self.__repository = repository
         if os.path.isdir(filename):
@@ -472,11 +552,13 @@ class Index(RunWithEnv):
         """In-index merge, no worktree involved."""
         self.run(['git', 'read-tree', '-m', '-i', '--aggressive',
                   base.sha1, ours.sha1, theirs.sha1]).no_output()
-    def apply(self, patch_text):
+    def apply(self, patch_text, quiet):
         """In-index patch application, no worktree involved."""
         try:
-            self.run(['git', 'apply', '--cached']
-                     ).raw_input(patch_text).no_output()
+            r = self.run(['git', 'apply', '--cached']).raw_input(patch_text)
+            if quiet:
+                r = r.discard_stderr()
+            r.no_output()
         except run.RunException:
             raise MergeException('Patch does not apply cleanly')
     def delete(self):
@@ -492,15 +574,20 @@ class Index(RunWithEnv):
         return paths
 
 class Worktree(object):
+    """Represents a git worktree (that is, a checked-out file tree)."""
     def __init__(self, directory):
         self.__directory = directory
     env = property(lambda self: { 'GIT_WORK_TREE': '.' })
     directory = property(lambda self: self.__directory)
 
 class CheckoutException(exception.StgException):
-    pass
+    """Exception raised when a checkout fails."""
 
 class IndexAndWorktree(RunWithEnvCwd):
+    """Represents a git index and a worktree. Anything that an index or
+    worktree can do on their own are handled by the L{Index} and
+    L{Worktree} classes; this class concerns itself with the
+    operations that require both."""
     def __init__(self, index, worktree):
         self.__index = index
         self.__worktree = worktree
@@ -541,3 +628,50 @@ class IndexAndWorktree(RunWithEnvCwd):
     def update_index(self, files):
         self.run(['git', 'update-index', '--remove', '-z', '--stdin']
                  ).input_nulterm(files).discard_output()
+
+class Branch(object):
+    """Represents a Git branch."""
+    def __init__(self, repository, name):
+        self.__repository = repository
+        self.__name = name
+        try:
+            self.head
+        except KeyError:
+            raise BranchException('%s: no such branch' % name)
+
+    name = property(lambda self: self.__name)
+    repository = property(lambda self: self.__repository)
+
+    def __ref(self):
+        return 'refs/heads/%s' % self.__name
+    @property
+    def head(self):
+        return self.__repository.refs.get(self.__ref())
+    def set_head(self, commit, msg):
+        self.__repository.refs.set(self.__ref(), commit, msg)
+
+    def set_parent_remote(self, name):
+        value = config.set('branch.%s.remote' % self.__name, name)
+    def set_parent_branch(self, name):
+        if config.get('branch.%s.remote' % self.__name):
+            # Never set merge if remote is not set to avoid
+            # possibly-erroneous lookups into 'origin'
+            config.set('branch.%s.merge' % self.__name, name)
+
+    @classmethod
+    def create(cls, repository, name, create_at = None):
+        """Create a new Git branch and return the corresponding
+        L{Branch} object."""
+        try:
+            branch = cls(repository, name)
+        except BranchException:
+            branch = None
+        if branch:
+            raise BranchException('%s: branch already exists' % name)
+
+        cmd = ['git', 'branch']
+        if create_at:
+            cmd.append(create_at.sha1)
+        repository.run(['git', 'branch', create_at.sha1]).discard_output()
+
+        return cls(repository, name)