Teach the new infrastructure about the index and worktree
author Karl Hasselström <kha@treskal.com>
Sun, 9 Dec 2007 07:56:12 +0000 (08:56 +0100)
committer Karl Hasselström <kha@treskal.com>
Wed, 9 Jan 2008 23:37:12 +0000 (00:37 +0100)
And use the new powers to make "stg coalesce" able to handle arbitrary
patches, not just consecutive applied patches.

Signed-off-by: Karl Hasselström <kha@treskal.com>
stgit/commands/coalesce.py
stgit/lib/git.py
stgit/lib/stack.py
stgit/lib/transaction.py

index c4c1cf8..e3e1629 100644 (file)
@@ -27,58 +27,86 @@ help = 'coalesce two or more patches into one'
 usage = """%prog [options] <patches>
 
 Coalesce two or more patches, creating one big patch that contains all
-their changes. The patches must all be applied, and must be
-consecutive."""
+their changes.
+
+If there are conflicts when reordering the patches to match the order
+you specify, you will have to resolve them manually just as if you had
+done a sequence of pushes and pops yourself."""
 
 directory = common.DirectoryHasRepositoryLib()
 options = [make_option('-n', '--name', help = 'name of coalesced patch'),
            make_option('-m', '--message',
                        help = 'commit message of coalesced patch')]
 
-def _coalesce(stack, name, msg, patches):
-    applied = stack.patchorder.applied
+def _coalesce_patches(trans, patches, msg):
+    cd = trans.patches[patches[0]].data
+    cd = git.Commitdata(tree = cd.tree, parents = cd.parents)
+    for pn in patches[1:]:
+        c = trans.patches[pn]
+        tree = trans.stack.repository.simple_merge(
+            base = c.data.parent.data.tree,
+            ours = cd.tree, theirs = c.data.tree)
+        if not tree:
+            return None
+        cd = cd.set_tree(tree)
+    if msg == None:
+        msg = '\n\n'.join('%s\n\n%s' % (pn.ljust(70, '-'),
+                                        trans.patches[pn].data.message)
+                          for pn in patches)
+        msg = utils.edit_string(msg, '.stgit-coalesce.txt').strip()
+    cd = cd.set_message(msg)
+
+    return cd
 
-    # Make sure the patches are consecutive.
-    applied_ix = dict((applied[i], i) for i in xrange(len(applied)))
-    ixes = list(sorted(applied_ix[p] for p in patches))
-    i0, i1 = ixes[0], ixes[-1]
-    if i1 - i0 + 1 != len(patches):
-        raise common.CmdException('The patches must be consecutive')
+def _coalesce(stack, iw, name, msg, patches):
 
-    # Make a commit for the coalesced patch.
+    # If a name was supplied on the command line, make sure it's OK.
     def bad_name(pn):
         return pn not in patches and stack.patches.exists(pn)
+    def get_name(cd):
+        return name or utils.make_patch_name(cd.message, bad_name)
     if name and bad_name(name):
         raise common.CmdException('Patch name "%s" already taken')
-    ps = [stack.patches.get(pn) for pn in applied[i0:i1+1]]
-    if msg == None:
-        msg = '\n\n'.join('%s\n\n%s' % (p.name.ljust(70, '-'),
-                                        p.commit.data.message)
-                          for p in ps)
-        msg = utils.edit_string(msg, '.stgit-coalesce.txt').strip()
-    if not name:
-        name = utils.make_patch_name(msg, bad_name)
-    cd = git.Commitdata(tree = ps[-1].commit.data.tree,
-                        parents = ps[0].commit.data.parents, message = msg)
 
-    # Rewrite refs.
+    def make_coalesced_patch(trans, new_commit_data):
+        name = get_name(new_commit_data)
+        trans.patches[name] = stack.repository.commit(new_commit_data)
+        trans.unapplied.insert(0, name)
+
     trans = transaction.StackTransaction(stack, 'stg coalesce')
-    for pn in applied[i0:i1+1]:
-        trans.patches[pn] = None
-    parent = trans.patches[name] = stack.repository.commit(cd)
-    trans.applied = applied[:i0]
-    trans.applied.append(name)
-    for pn in applied[i1+1:]:
-        p = stack.patches.get(pn)
-        parent = trans.patches[pn] = stack.repository.commit(
-            p.commit.data.set_parent(parent))
-        trans.applied.append(pn)
-    trans.run()
+    push_new_patch = bool(set(patches) & set(trans.applied))
+    new_commit_data = _coalesce_patches(trans, patches, msg)
+    try:
+        if new_commit_data:
+            # We were able to construct the coalesced commit
+            # automatically. So just delete its constituent patches.
+            to_push = trans.delete_patches(lambda pn: pn in patches)
+        else:
+            # Automatic construction failed. So push the patches
+            # consecutively, so that a second construction attempt is
+            # guaranteed to work.
+            to_push = trans.pop_patches(lambda pn: pn in patches)
+            for pn in patches:
+                trans.push_patch(pn, iw)
+            new_commit_data = _coalesce_patches(trans, patches, msg)
+            assert not trans.delete_patches(lambda pn: pn in patches)
+        make_coalesced_patch(trans, new_commit_data)
+
+        # Push the new patch if necessary, and any unrelated patches we've
+        # had to pop out of the way.
+        if push_new_patch:
+            trans.push_patch(get_name(new_commit_data), iw)
+        for pn in to_push:
+            trans.push_patch(pn, iw)
+    except transaction.TransactionHalted:
+        pass
+    trans.run(iw)
 
 def func(parser, options, args):
     stack = directory.repository.current_stack
-    applied = set(stack.patchorder.applied)
-    patches = set(common.parse_patches(args, list(stack.patchorder.applied)))
+    patches = common.parse_patches(args, (list(stack.patchorder.applied)
+                                          + list(stack.patchorder.unapplied)))
     if len(patches) < 2:
         raise common.CmdException('Need at least two patches')
-    _coalesce(stack, options.name, options.message, patches)
+    _coalesce(stack, stack.repository.default_iw(),
+              options.name, options.message, patches)
index c4011f9..6aba966 100644 (file)
@@ -95,6 +95,8 @@ class Commitdata(Repr):
         return type(self)(committer = committer, defaults = self)
     def set_message(self, message):
         return type(self)(message = message, defaults = self)
+    def is_nochange(self):
+        return len(self.parents) == 1 and self.tree == self.parent.data.tree
     def __str__(self):
         if self.tree == None:
             tree = None
@@ -218,6 +220,21 @@ class Repository(RunWithEnv):
                                ).output_one_line())
         except run.RunException:
             raise RepositoryException('Cannot find git repository')
+    def default_index(self):
+        return Index(self, (os.environ.get('GIT_INDEX_FILE', None)
+                            or os.path.join(self.__git_dir, 'index')))
+    def temp_index(self):
+        return Index(self, self.__git_dir)
+    def default_worktree(self):
+        path = os.environ.get('GIT_WORK_TREE', None)
+        if not path:
+            o = run.Run('git', 'rev-parse', '--show-cdup').output_lines()
+            o = o or ['.']
+            assert len(o) == 1
+            path = o[0]
+        return Worktree(path)
+    def default_iw(self):
+        return IndexAndWorktree(self.default_index(), self.default_worktree())
     directory = property(lambda self: self.__git_dir)
     refs = property(lambda self: self.__refs)
     def cat_object(self, sha1):
@@ -258,3 +275,113 @@ class Repository(RunWithEnv):
             raise DetachedHeadException()
     def set_head(self, ref, msg):
         self.run(['git', 'symbolic-ref', '-m', msg, 'HEAD', ref]).no_output()
+    def simple_merge(self, base, ours, theirs):
+        """Given three trees, tries to do an in-index merge in a temporary
+        index. Returns the result tree, or None if
+        the merge failed (due to conflicts)."""
+        assert isinstance(base, Tree)
+        assert isinstance(ours, Tree)
+        assert isinstance(theirs, Tree)
+
+        # Take care of the really trivial cases.
+        if base == ours:
+            return theirs
+        if base == theirs:
+            return ours
+        if ours == theirs:
+            return ours
+
+        index = self.temp_index()
+        try:
+            index.merge(base, ours, theirs)
+            try:
+                return index.write_tree()
+            except MergeException:
+                return None
+        finally:
+            index.delete()
+
+class MergeException(exception.StgException):
+    pass
+
+class Index(RunWithEnv):
+    def __init__(self, repository, filename):
+        self.__repository = repository
+        if os.path.isdir(filename):
+            # Create a temp index in the given directory.
+            self.__filename = os.path.join(
+                filename, 'index.temp-%d-%x' % (os.getpid(), id(self)))
+            self.delete()
+        else:
+            self.__filename = filename
+    env = property(lambda self: utils.add_dict(
+            self.__repository.env, { 'GIT_INDEX_FILE': self.__filename }))
+    def read_tree(self, tree):
+        self.run(['git', 'read-tree', tree.sha1]).no_output()
+    def write_tree(self):
+        try:
+            return self.__repository.get_tree(
+                self.run(['git', 'write-tree']).discard_stderr(
+                    ).output_one_line())
+        except run.RunException:
+            raise MergeException('Conflicting merge')
+    def is_clean(self):
+        try:
+            self.run(['git', 'update-index', '--refresh']).discard_output()
+        except run.RunException:
+            return False
+        else:
+            return True
+    def merge(self, base, ours, theirs):
+        """In-index merge, no worktree involved."""
+        self.run(['git', 'read-tree', '-m', '-i', '--aggressive',
+                  base.sha1, ours.sha1, theirs.sha1]).no_output()
+    def delete(self):
+        if os.path.isfile(self.__filename):
+            os.remove(self.__filename)
+
+class Worktree(object):
+    def __init__(self, directory):
+        self.__directory = directory
+    env = property(lambda self: { 'GIT_WORK_TREE': self.__directory })
+
+class CheckoutException(exception.StgException):
+    pass
+
+class IndexAndWorktree(RunWithEnv):
+    def __init__(self, index, worktree):
+        self.__index = index
+        self.__worktree = worktree
+    index = property(lambda self: self.__index)
+    env = property(lambda self: utils.add_dict(self.__index.env,
+                                               self.__worktree.env))
+    def checkout(self, old_tree, new_tree):
+        # TODO: Optionally do a 3-way instead of doing nothing when we
+        # have a problem. Or maybe we should stash changes in a patch?
+        assert isinstance(old_tree, Tree)
+        assert isinstance(new_tree, Tree)
+        try:
+            self.run(['git', 'read-tree', '-u', '-m',
+                      '--exclude-per-directory=.gitignore',
+                      old_tree.sha1, new_tree.sha1]
+                     ).discard_output()
+        except run.RunException:
+            raise CheckoutException('Index/workdir dirty')
+    def merge(self, base, ours, theirs):
+        assert isinstance(base, Tree)
+        assert isinstance(ours, Tree)
+        assert isinstance(theirs, Tree)
+        try:
+            self.run(['git', 'merge-recursive', base.sha1, '--', ours.sha1,
+                      theirs.sha1],
+                     env = { 'GITHEAD_%s' % base.sha1: 'ancestor',
+                             'GITHEAD_%s' % ours.sha1: 'current',
+                             'GITHEAD_%s' % theirs.sha1: 'patched'}
+                     ).discard_output()
+        except run.RunException, e:
+            raise MergeException('Index/worktree dirty')
+    def changed_files(self):
+        return self.run(['git', 'diff-files', '--name-only']).output_lines()
+    def update_index(self, files):
+        self.run(['git', 'update-index', '--remove', '-z', '--stdin']
+                 ).input_nulterm(files).discard_output()
index 0e821d9..9d47d68 100644 (file)
@@ -70,8 +70,7 @@ class Patch(object):
     def is_applied(self):
         return self.name in self.__stack.patchorder.applied
     def is_empty(self):
-        c = self.commit
-        return c.data.tree == c.data.parent.data.tree
+        return self.commit.data.is_nochange()
 
 class PatchOrder(object):
     """Keeps track of patch order, and which patches are applied.
@@ -155,6 +154,10 @@ class Stack(object):
                                     ).commit.data.parent
         else:
             return self.head
+    def head_top_equal(self):
+        if not self.patchorder.applied:
+            return True
+        return self.head == self.patches.get(self.patchorder.applied[-1]).commit
 
 class Repository(git.Repository):
     def __init__(self, *args, **kwargs):
index 991e64e..77333b3 100644 (file)
@@ -1,10 +1,14 @@
 from stgit import exception
 from stgit.out import *
+from stgit.lib import git
 
 class TransactionException(exception.StgException):
     pass
 
-def print_current_patch(old_applied, new_applied):
+class TransactionHalted(TransactionException):
+    pass
+
+def _print_current_patch(old_applied, new_applied):
     def now_at(pn):
         out.info('Now at patch "%s"' % pn)
     if not old_applied and not new_applied:
@@ -18,22 +22,48 @@ def print_current_patch(old_applied, new_applied):
     else:
         now_at(new_applied[-1])
 
+class _TransPatchMap(dict):
+    def __init__(self, stack):
+        dict.__init__(self)
+        self.__stack = stack
+    def __getitem__(self, pn):
+        try:
+            return dict.__getitem__(self, pn)
+        except KeyError:
+            return self.__stack.patches.get(pn).commit
+
 class StackTransaction(object):
     def __init__(self, stack, msg):
         self.__stack = stack
         self.__msg = msg
-        self.__patches = {}
+        self.__patches = _TransPatchMap(stack)
         self.__applied = list(self.__stack.patchorder.applied)
         self.__unapplied = list(self.__stack.patchorder.unapplied)
-    def __set_patches(self, val):
-        self.__patches = dict(val)
-    patches = property(lambda self: self.__patches, __set_patches)
+        self.__error = None
+        self.__current_tree = self.__stack.head.data.tree
+    stack = property(lambda self: self.__stack)
+    patches = property(lambda self: self.__patches)
     def __set_applied(self, val):
         self.__applied = list(val)
     applied = property(lambda self: self.__applied, __set_applied)
     def __set_unapplied(self, val):
         self.__unapplied = list(val)
     unapplied = property(lambda self: self.__unapplied, __set_unapplied)
+    def __checkout(self, tree, iw):
+        if not self.__stack.head_top_equal():
+            out.error(
+                'HEAD and top are not the same.',
+                'This can happen if you modify a branch with git.',
+                '"stg repair --help" explains more about what to do next.')
+            self.__abort()
+        if self.__current_tree != tree:
+            assert iw != None
+            iw.checkout(self.__current_tree, tree)
+            self.__current_tree = tree
+    @staticmethod
+    def __abort():
+        raise TransactionException(
+            'Command aborted (all changes rolled back)')
     def __check_consistency(self):
         remaining = set(self.__applied + self.__unapplied)
         for pn, commit in self.__patches.iteritems():
@@ -41,29 +71,29 @@ class StackTransaction(object):
                 assert self.__stack.patches.exists(pn)
             else:
                 assert pn in remaining
-    def run(self):
-        self.__check_consistency()
-
-        # Get new head commit.
+    @property
+    def __head(self):
         if self.__applied:
-            top_patch = self.__applied[-1]
-            try:
-                new_head = self.__patches[top_patch]
-            except KeyError:
-                new_head = self.__stack.patches.get(top_patch).commit
+            return self.__patches[self.__applied[-1]]
         else:
-            new_head = self.__stack.base
+            return self.__stack.base
+    def run(self, iw = None):
+        self.__check_consistency()
+        new_head = self.__head
 
         # Set branch head.
-        if new_head == self.__stack.head:
-            pass # same commit: OK
-        elif new_head.data.tree == self.__stack.head.data.tree:
-            pass # same tree: OK
-        else:
-            # We can't handle this case yet.
-            raise TransactionException('Error: HEAD tree changed')
+        try:
+            self.__checkout(new_head.data.tree, iw)
+        except git.CheckoutException:
+            # We have to abort the transaction. The only state we need
+            # to restore is index+worktree.
+            self.__checkout(self.__stack.head.data.tree, iw)
+            self.__abort()
         self.__stack.set_head(new_head, self.__msg)
 
+        if self.__error:
+            out.error(self.__error)
+
         # Write patches.
         for pn, commit in self.__patches.iteritems():
             if self.__stack.patches.exists(pn):
@@ -74,6 +104,92 @@ class StackTransaction(object):
                     p.set_commit(commit, self.__msg)
             else:
                 self.__stack.patches.new(pn, commit, self.__msg)
-        print_current_patch(self.__stack.patchorder.applied, self.__applied)
+        _print_current_patch(self.__stack.patchorder.applied, self.__applied)
         self.__stack.patchorder.applied = self.__applied
         self.__stack.patchorder.unapplied = self.__unapplied
+
+    def __halt(self, msg):
+        self.__error = msg
+        raise TransactionHalted(msg)
+
+    @staticmethod
+    def __print_popped(popped):
+        if len(popped) == 0:
+            pass
+        elif len(popped) == 1:
+            out.info('Popped %s' % popped[0])
+        else:
+            out.info('Popped %s -- %s' % (popped[-1], popped[0]))
+
+    def pop_patches(self, p):
+        """Pop all patches pn for which p(pn) is true. Return the list of
+        other patches that had to be popped to accomplish this."""
+        popped = []
+        for i in xrange(len(self.applied)):
+            if p(self.applied[i]):
+                popped = self.applied[i:]
+                del self.applied[i:]
+                break
+        popped1 = [pn for pn in popped if not p(pn)]
+        popped2 = [pn for pn in popped if p(pn)]
+        self.unapplied = popped1 + popped2 + self.unapplied
+        self.__print_popped(popped)
+        return popped1
+
+    def delete_patches(self, p):
+        """Delete all patches pn for which p(pn) is true. Return the list of
+        other patches that had to be popped to accomplish this."""
+        popped = []
+        all_patches = self.applied + self.unapplied
+        for i in xrange(len(self.applied)):
+            if p(self.applied[i]):
+                popped = self.applied[i:]
+                del self.applied[i:]
+                break
+        popped = [pn for pn in popped if not p(pn)]
+        self.unapplied = popped + [pn for pn in self.unapplied if not p(pn)]
+        self.__print_popped(popped)
+        for pn in all_patches:
+            if p(pn):
+                s = ['', ' (empty)'][self.patches[pn].data.is_nochange()]
+                self.patches[pn] = None
+                out.info('Deleted %s%s' % (pn, s))
+        return popped
+
+    def push_patch(self, pn, iw = None):
+        """Attempt to push the named patch. If this results in conflicts,
+        halts the transaction. If index+worktree are given, spill any
+        conflicts to them."""
+        i = self.unapplied.index(pn)
+        cd = self.patches[pn].data
+        s = ['', ' (empty)'][cd.is_nochange()]
+        oldparent = cd.parent
+        cd = cd.set_parent(self.__head)
+        base = oldparent.data.tree
+        ours = cd.parent.data.tree
+        theirs = cd.tree
+        tree = self.__stack.repository.simple_merge(base, ours, theirs)
+        merge_conflict = False
+        if not tree:
+            if iw == None:
+                self.__halt('%s does not apply cleanly' % pn)
+            try:
+                self.__checkout(ours, iw)
+            except git.CheckoutException:
+                self.__halt('Index/worktree dirty')
+            try:
+                iw.merge(base, ours, theirs)
+                tree = iw.index.write_tree()
+                self.__current_tree = tree
+                s = ' (modified)'
+            except git.MergeException:
+                tree = ours
+                merge_conflict = True
+                s = ' (conflict)'
+        cd = cd.set_tree(tree)
+        self.patches[pn] = self.__stack.repository.commit(cd)
+        del self.unapplied[i]
+        self.applied.append(pn)
+        out.info('Pushed %s%s' % (pn, s))
+        if merge_conflict:
+            self.__halt('Merge conflict')