Print conflict details with the new infrastructure (bug #11181)
[stgit] / stgit / lib / transaction.py
index 9b45729..54de127 100644 (file)
@@ -74,21 +74,34 @@ class StackTransaction(object):
       method. This will either succeed in writing the updated state to
       your refs and index+worktree, or fail without having done
       anything."""
-    def __init__(self, stack, msg, allow_conflicts = False):
+    def __init__(self, stack, msg, discard_changes = False,
+                 allow_conflicts = False, allow_bad_head = False):
+        """Create a new L{StackTransaction}.
+
+        @param discard_changes: Discard any changes in index+worktree
+        @type discard_changes: bool
+        @param allow_conflicts: Whether to allow pre-existing conflicts
+        @type allow_conflicts: bool or function of L{StackTransaction}"""
         self.__stack = stack
         self.__msg = msg
         self.__patches = _TransPatchMap(stack)
         self.__applied = list(self.__stack.patchorder.applied)
         self.__unapplied = list(self.__stack.patchorder.unapplied)
         self.__hidden = list(self.__stack.patchorder.hidden)
+        self.__conflicting_push = None
         self.__error = None
         self.__current_tree = self.__stack.head.data.tree
         self.__base = self.__stack.base
+        self.__discard_changes = discard_changes
+        self.__bad_head = None
+        self.__conflicts = None
         if isinstance(allow_conflicts, bool):
             self.__allow_conflicts = lambda trans: allow_conflicts
         else:
             self.__allow_conflicts = allow_conflicts
         self.__temp_index = self.temp_index_tree = None
+        if not allow_bad_head:
+            self.__assert_head_top_equal()
     stack = property(lambda self: self.__stack)
     patches = property(lambda self: self.__patches)
     def __set_applied(self, val):
@@ -100,6 +113,8 @@ class StackTransaction(object):
     def __set_hidden(self, val):
         self.__hidden = list(val)
     hidden = property(lambda self: self.__hidden, __set_hidden)
+    all_patches = property(lambda self: (self.__applied + self.__unapplied
+                                         + self.__hidden))
     def __set_base(self, val):
         assert (not self.__applied
                 or self.patches[self.applied[0]].data.parent == val)
@@ -111,14 +126,31 @@ class StackTransaction(object):
             self.__temp_index = self.__stack.repository.temp_index()
             atexit.register(self.__temp_index.delete)
         return self.__temp_index
-    def __checkout(self, tree, iw):
+    @property
+    def top(self):
+        if self.__applied:
+            return self.__patches[self.__applied[-1]]
+        else:
+            return self.__base
+    def __get_head(self):
+        if self.__bad_head:
+            return self.__bad_head
+        else:
+            return self.top
+    def __set_head(self, val):
+        self.__bad_head = val
+    head = property(__get_head, __set_head)
+    def __assert_head_top_equal(self):
         if not self.__stack.head_top_equal():
             out.error(
                 'HEAD and top are not the same.',
                 'This can happen if you modify a branch with git.',
                 '"stg repair --help" explains more about what to do next.')
             self.__abort()
-        if self.__current_tree == tree:
+    def __checkout(self, tree, iw, allow_bad_head):
+        if not allow_bad_head:
+            self.__assert_head_top_equal()
+        if self.__current_tree == tree and not self.__discard_changes:
             # No tree change, but we still want to make sure that
             # there are no unresolved conflicts. Conflicts
             # conceptually "belong" to the topmost patch, and just
@@ -129,40 +161,40 @@ class StackTransaction(object):
             out.error('Need to resolve conflicts first')
             self.__abort()
         assert iw != None
-        iw.checkout(self.__current_tree, tree)
+        if self.__discard_changes:
+            iw.checkout_hard(tree)
+        else:
+            iw.checkout(self.__current_tree, tree)
         self.__current_tree = tree
     @staticmethod
     def __abort():
         raise TransactionException(
             'Command aborted (all changes rolled back)')
     def __check_consistency(self):
-        remaining = set(self.__applied + self.__unapplied + self.__hidden)
+        remaining = set(self.all_patches)
         for pn, commit in self.__patches.iteritems():
             if commit == None:
                 assert self.__stack.patches.exists(pn)
             else:
                 assert pn in remaining
-    @property
-    def __head(self):
-        if self.__applied:
-            return self.__patches[self.__applied[-1]]
-        else:
-            return self.__base
     def abort(self, iw = None):
         # The only state we need to restore is index+worktree.
         if iw:
-            self.__checkout(self.__stack.head.data.tree, iw)
-    def run(self, iw = None, set_head = True):
+            self.__checkout(self.__stack.head.data.tree, iw,
+                            allow_bad_head = True)
+    def run(self, iw = None, set_head = True, allow_bad_head = False,
+            print_current_patch = True):
         """Execute the transaction. Will either succeed, or fail (with an
         exception) and do nothing."""
         self.__check_consistency()
-        new_head = self.__head
+        log.log_external_mods(self.__stack)
+        new_head = self.head
 
         # Set branch head.
         if set_head:
             if iw:
                 try:
-                    self.__checkout(new_head.data.tree, iw)
+                    self.__checkout(new_head.data.tree, iw, allow_bad_head)
                 except git.CheckoutException:
                     # We have to abort the transaction.
                     self.abort(iw)
@@ -170,23 +202,34 @@ class StackTransaction(object):
             self.__stack.set_head(new_head, self.__msg)
 
         if self.__error:
-            out.error(self.__error)
+            if self.__conflicts:
+                out.error(*([self.__error] + self.__conflicts))
+            else:
+                out.error(self.__error)
 
         # Write patches.
-        for pn, commit in self.__patches.iteritems():
-            if self.__stack.patches.exists(pn):
-                p = self.__stack.patches.get(pn)
-                if commit == None:
-                    p.delete()
+        def write(msg):
+            for pn, commit in self.__patches.iteritems():
+                if self.__stack.patches.exists(pn):
+                    p = self.__stack.patches.get(pn)
+                    if commit == None:
+                        p.delete()
+                    else:
+                        p.set_commit(commit, msg)
                 else:
-                    p.set_commit(commit, self.__msg)
-            else:
-                self.__stack.patches.new(pn, commit, self.__msg)
-        _print_current_patch(self.__stack.patchorder.applied, self.__applied)
-        self.__stack.patchorder.applied = self.__applied
-        self.__stack.patchorder.unapplied = self.__unapplied
-        self.__stack.patchorder.hidden = self.__hidden
-        log.log_entry(self.__stack, self.__msg)
+                    self.__stack.patches.new(pn, commit, msg)
+            self.__stack.patchorder.applied = self.__applied
+            self.__stack.patchorder.unapplied = self.__unapplied
+            self.__stack.patchorder.hidden = self.__hidden
+            log.log_entry(self.__stack, msg)
+        old_applied = self.__stack.patchorder.applied
+        write(self.__msg)
+        if self.__conflicting_push != None:
+            self.__patches = _TransPatchMap(self.__stack)
+            self.__conflicting_push()
+            write(self.__msg + ' (CONFLICT)')
+        if print_current_patch:
+            _print_current_patch(old_applied, self.__applied)
 
         if self.__error:
             return utils.STGIT_CONFLICT
@@ -222,7 +265,7 @@ class StackTransaction(object):
         self.__print_popped(popped)
         return popped1
 
-    def delete_patches(self, p):
+    def delete_patches(self, p, quiet = False):
         """Delete all patches pn for which p(pn) is true. Return the list of
         other patches that had to be popped to accomplish this. Always
         succeeds."""
@@ -241,7 +284,8 @@ class StackTransaction(object):
             if p(pn):
                 s = ['', ' (empty)'][self.patches[pn].data.is_nochange()]
                 self.patches[pn] = None
-                out.info('Deleted %s%s' % (pn, s))
+                if not quiet:
+                    out.info('Deleted %s%s' % (pn, s))
         return popped
 
     def push_patch(self, pn, iw = None):
@@ -250,20 +294,20 @@ class StackTransaction(object):
         conflicts to them."""
         orig_cd = self.patches[pn].data
         cd = orig_cd.set_committer(None)
-        s = ['', ' (empty)'][cd.is_nochange()]
         oldparent = cd.parent
-        cd = cd.set_parent(self.__head)
+        cd = cd.set_parent(self.top)
         base = oldparent.data.tree
         ours = cd.parent.data.tree
         theirs = cd.tree
         tree, self.temp_index_tree = self.temp_index.merge(
             base, ours, theirs, self.temp_index_tree)
+        s = ''
         merge_conflict = False
         if not tree:
             if iw == None:
                 self.__halt('%s does not apply cleanly' % pn)
             try:
-                self.__checkout(ours, iw)
+                self.__checkout(ours, iw, allow_bad_head = False)
             except git.CheckoutException:
                 self.__halt('Index/worktree dirty')
             try:
@@ -271,31 +315,44 @@ class StackTransaction(object):
                 tree = iw.index.write_tree()
                 self.__current_tree = tree
                 s = ' (modified)'
-            except git.MergeConflictException:
+            except git.MergeConflictException, e:
                 tree = ours
                 merge_conflict = True
+                self.__conflicts = e.conflicts
                 s = ' (conflict)'
             except git.MergeException, e:
                 self.__halt(str(e))
         cd = cd.set_tree(tree)
         if any(getattr(cd, a) != getattr(orig_cd, a) for a in
                ['parent', 'tree', 'author', 'message']):
-            self.patches[pn] = self.__stack.repository.commit(cd)
+            comm = self.__stack.repository.commit(cd)
+            self.head = comm
         else:
+            comm = None
             s = ' (unmodified)'
-        if pn in self.hidden:
-            x = self.hidden
-        else:
-            x = self.unapplied
-        del x[x.index(pn)]
-        self.applied.append(pn)
+        if not merge_conflict and cd.is_nochange():
+            s = ' (empty)'
         out.info('Pushed %s%s' % (pn, s))
+        def update():
+            if comm:
+                self.patches[pn] = comm
+            if pn in self.hidden:
+                x = self.hidden
+            else:
+                x = self.unapplied
+            del x[x.index(pn)]
+            self.applied.append(pn)
         if merge_conflict:
             # We've just caused conflicts, so we must allow them in
             # the final checkout.
             self.__allow_conflicts = lambda trans: True
 
-            self.__halt('Merge conflict')
+            # Save this update so that we can run it a little later.
+            self.__conflicting_push = update
+            self.__halt("%d merge conflict(s)" % len(self.__conflicts))
+        else:
+            # Update immediately.
+            update()
 
     def reorder_patches(self, applied, unapplied, hidden, iw = None):
         """Push and pop patches to attain the given ordering."""