Add automatic git-mergetool invocation to the new infrastructure
[stgit] / stgit / lib / transaction.py
index 20d87e1..4b5398a 100644 (file)
@@ -8,6 +8,7 @@ from stgit import exception, utils
 from stgit.utils import any, all
 from stgit.out import *
 from stgit.lib import git, log
+from stgit.config import config
 
 class TransactionException(exception.StgException):
     """Exception raised when something goes wrong with a
@@ -74,7 +75,15 @@ class StackTransaction(object):
       method. This will either succeed in writing the updated state to
       your refs and index+worktree, or fail without having done
       anything."""
-    def __init__(self, stack, msg, allow_conflicts = False):
+    def __init__(self, stack, msg, discard_changes = False,
+                 allow_conflicts = False, allow_bad_head = False,
+                 check_clean_iw = None):
+        """Create a new L{StackTransaction}.
+
+        @param discard_changes: Discard any changes in index+worktree
+        @type discard_changes: bool
+        @param allow_conflicts: Whether to allow pre-existing conflicts
+        @type allow_conflicts: bool or function of L{StackTransaction}"""
         self.__stack = stack
         self.__msg = msg
         self.__patches = _TransPatchMap(stack)
@@ -85,11 +94,18 @@ class StackTransaction(object):
         self.__error = None
         self.__current_tree = self.__stack.head.data.tree
         self.__base = self.__stack.base
+        self.__discard_changes = discard_changes
+        self.__bad_head = None
+        self.__conflicts = None
         if isinstance(allow_conflicts, bool):
             self.__allow_conflicts = lambda trans: allow_conflicts
         else:
             self.__allow_conflicts = allow_conflicts
         self.__temp_index = self.temp_index_tree = None
+        if not allow_bad_head:
+            self.__assert_head_top_equal()
+        if check_clean_iw:
+            self.__assert_index_worktree_clean(check_clean_iw)
     stack = property(lambda self: self.__stack)
     patches = property(lambda self: self.__patches)
     def __set_applied(self, val):
@@ -101,6 +117,8 @@ class StackTransaction(object):
     def __set_hidden(self, val):
         self.__hidden = list(val)
     hidden = property(lambda self: self.__hidden, __set_hidden)
+    all_patches = property(lambda self: (self.__applied + self.__unapplied
+                                         + self.__hidden))
     def __set_base(self, val):
         assert (not self.__applied
                 or self.patches[self.applied[0]].data.parent == val)
@@ -112,14 +130,36 @@ class StackTransaction(object):
             self.__temp_index = self.__stack.repository.temp_index()
             atexit.register(self.__temp_index.delete)
         return self.__temp_index
-    def __checkout(self, tree, iw):
+    @property
+    def top(self):
+        if self.__applied:
+            return self.__patches[self.__applied[-1]]
+        else:
+            return self.__base
+    def __get_head(self):
+        if self.__bad_head:
+            return self.__bad_head
+        else:
+            return self.top
+    def __set_head(self, val):
+        self.__bad_head = val
+    head = property(__get_head, __set_head)
+    def __assert_head_top_equal(self):
         if not self.__stack.head_top_equal():
             out.error(
                 'HEAD and top are not the same.',
                 'This can happen if you modify a branch with git.',
                 '"stg repair --help" explains more about what to do next.')
             self.__abort()
-        if self.__current_tree == tree:
+    def __assert_index_worktree_clean(self, iw):
+        if not iw.worktree_clean():
+            self.__halt('Worktree not clean. Use "refresh" or "status --reset"')
+        if not iw.index.is_clean(self.stack.head):
+            self.__halt('Index not clean. Use "refresh" or "status --reset"')
+    def __checkout(self, tree, iw, allow_bad_head):
+        if not allow_bad_head:
+            self.__assert_head_top_equal()
+        if self.__current_tree == tree and not self.__discard_changes:
             # No tree change, but we still want to make sure that
             # there are no unresolved conflicts. Conflicts
             # conceptually "belong" to the topmost patch, and just
@@ -130,40 +170,40 @@ class StackTransaction(object):
             out.error('Need to resolve conflicts first')
             self.__abort()
         assert iw != None
-        iw.checkout(self.__current_tree, tree)
+        if self.__discard_changes:
+            iw.checkout_hard(tree)
+        else:
+            iw.checkout(self.__current_tree, tree)
         self.__current_tree = tree
     @staticmethod
     def __abort():
         raise TransactionException(
             'Command aborted (all changes rolled back)')
     def __check_consistency(self):
-        remaining = set(self.__applied + self.__unapplied + self.__hidden)
+        remaining = set(self.all_patches)
         for pn, commit in self.__patches.iteritems():
             if commit == None:
                 assert self.__stack.patches.exists(pn)
             else:
                 assert pn in remaining
-    @property
-    def __head(self):
-        if self.__applied:
-            return self.__patches[self.__applied[-1]]
-        else:
-            return self.__base
     def abort(self, iw = None):
         # The only state we need to restore is index+worktree.
         if iw:
-            self.__checkout(self.__stack.head.data.tree, iw)
-    def run(self, iw = None, set_head = True):
+            self.__checkout(self.__stack.head.data.tree, iw,
+                            allow_bad_head = True)
+    def run(self, iw = None, set_head = True, allow_bad_head = False,
+            print_current_patch = True):
         """Execute the transaction. Will either succeed, or fail (with an
         exception) and do nothing."""
         self.__check_consistency()
-        new_head = self.__head
+        log.log_external_mods(self.__stack)
+        new_head = self.head
 
         # Set branch head.
         if set_head:
             if iw:
                 try:
-                    self.__checkout(new_head.data.tree, iw)
+                    self.__checkout(new_head.data.tree, iw, allow_bad_head)
                 except git.CheckoutException:
                     # We have to abort the transaction.
                     self.abort(iw)
@@ -171,7 +211,10 @@ class StackTransaction(object):
             self.__stack.set_head(new_head, self.__msg)
 
         if self.__error:
-            out.error(self.__error)
+            if self.__conflicts:
+                out.error(*([self.__error] + self.__conflicts))
+            else:
+                out.error(self.__error)
 
         # Write patches.
         def write(msg):
@@ -194,7 +237,8 @@ class StackTransaction(object):
             self.__patches = _TransPatchMap(self.__stack)
             self.__conflicting_push()
             write(self.__msg + ' (CONFLICT)')
-        _print_current_patch(old_applied, self.__applied)
+        if print_current_patch:
+            _print_current_patch(old_applied, self.__applied)
 
         if self.__error:
             return utils.STGIT_CONFLICT
@@ -230,7 +274,7 @@ class StackTransaction(object):
         self.__print_popped(popped)
         return popped1
 
-    def delete_patches(self, p):
+    def delete_patches(self, p, quiet = False):
         """Delete all patches pn for which p(pn) is true. Return the list of
         other patches that had to be popped to accomplish this. Always
         succeeds."""
@@ -249,39 +293,43 @@ class StackTransaction(object):
             if p(pn):
                 s = ['', ' (empty)'][self.patches[pn].data.is_nochange()]
                 self.patches[pn] = None
-                out.info('Deleted %s%s' % (pn, s))
+                if not quiet:
+                    out.info('Deleted %s%s' % (pn, s))
         return popped
 
-    def push_patch(self, pn, iw = None):
+    def push_patch(self, pn, iw = None, allow_interactive = False):
         """Attempt to push the named patch. If this results in conflicts,
         halts the transaction. If index+worktree are given, spill any
         conflicts to them."""
         orig_cd = self.patches[pn].data
         cd = orig_cd.set_committer(None)
-        s = ['', ' (empty)'][cd.is_nochange()]
         oldparent = cd.parent
-        cd = cd.set_parent(self.__head)
+        cd = cd.set_parent(self.top)
         base = oldparent.data.tree
         ours = cd.parent.data.tree
         theirs = cd.tree
         tree, self.temp_index_tree = self.temp_index.merge(
             base, ours, theirs, self.temp_index_tree)
+        s = ''
         merge_conflict = False
         if not tree:
             if iw == None:
                 self.__halt('%s does not apply cleanly' % pn)
             try:
-                self.__checkout(ours, iw)
+                self.__checkout(ours, iw, allow_bad_head = False)
             except git.CheckoutException:
                 self.__halt('Index/worktree dirty')
             try:
-                iw.merge(base, ours, theirs)
+                interactive = (allow_interactive and
+                               config.get('stgit.autoimerge') == 'yes')
+                iw.merge(base, ours, theirs, interactive = interactive)
                 tree = iw.index.write_tree()
                 self.__current_tree = tree
                 s = ' (modified)'
-            except git.MergeConflictException:
+            except git.MergeConflictException, e:
                 tree = ours
                 merge_conflict = True
+                self.__conflicts = e.conflicts
                 s = ' (conflict)'
             except git.MergeException, e:
                 self.__halt(str(e))
@@ -289,9 +337,12 @@ class StackTransaction(object):
         if any(getattr(cd, a) != getattr(orig_cd, a) for a in
                ['parent', 'tree', 'author', 'message']):
             comm = self.__stack.repository.commit(cd)
+            self.head = comm
         else:
             comm = None
             s = ' (unmodified)'
+        if not merge_conflict and cd.is_nochange():
+            s = ' (empty)'
         out.info('Pushed %s%s' % (pn, s))
         def update():
             if comm:
@@ -309,7 +360,7 @@ class StackTransaction(object):
 
             # Save this update so that we can run it a little later.
             self.__conflicting_push = update
-            self.__halt('Merge conflict')
+            self.__halt("%d merge conflict(s)" % len(self.__conflicts))
         else:
             # Update immediately.
             update()