Add a --hard flag to stg reset
[stgit] / stgit / lib / transaction.py
index 6880662..84d72d5 100644 (file)
@@ -2,11 +2,12 @@
 updates to an StGit stack in a safe and convenient way."""
 
 import atexit
+import itertools as it
 
 from stgit import exception, utils
 from stgit.utils import any, all
 from stgit.out import *
-from stgit.lib import git
+from stgit.lib import git, log
 
 class TransactionException(exception.StgException):
     """Exception raised when something goes wrong with a
@@ -73,16 +74,25 @@ class StackTransaction(object):
       method. This will either succeed in writing the updated state to
       your refs and index+worktree, or fail without having done
       anything."""
-    def __init__(self, stack, msg, allow_conflicts = False):
+    def __init__(self, stack, msg, discard_changes = False,
+                 allow_conflicts = False):
+        """Create a new L{StackTransaction}.
+
+        @param discard_changes: Discard any changes in index+worktree
+        @type discard_changes: bool
+        @param allow_conflicts: Whether to allow pre-existing conflicts
+        @type allow_conflicts: bool or function of L{StackTransaction}"""
         self.__stack = stack
         self.__msg = msg
         self.__patches = _TransPatchMap(stack)
         self.__applied = list(self.__stack.patchorder.applied)
         self.__unapplied = list(self.__stack.patchorder.unapplied)
         self.__hidden = list(self.__stack.patchorder.hidden)
+        self.__conflicting_push = None
         self.__error = None
         self.__current_tree = self.__stack.head.data.tree
         self.__base = self.__stack.base
+        self.__discard_changes = discard_changes
         if isinstance(allow_conflicts, bool):
             self.__allow_conflicts = lambda trans: allow_conflicts
         else:
@@ -117,7 +127,7 @@ class StackTransaction(object):
                 'This can happen if you modify a branch with git.',
                 '"stg repair --help" explains more about what to do next.')
             self.__abort()
-        if self.__current_tree == tree:
+        if self.__current_tree == tree and not self.__discard_changes:
             # No tree change, but we still want to make sure that
             # there are no unresolved conflicts. Conflicts
             # conceptually "belong" to the topmost patch, and just
@@ -128,7 +138,10 @@ class StackTransaction(object):
             out.error('Need to resolve conflicts first')
             self.__abort()
         assert iw != None
-        iw.checkout(self.__current_tree, tree)
+        if self.__discard_changes:
+            iw.checkout_hard(tree)
+        else:
+            iw.checkout(self.__current_tree, tree)
         self.__current_tree = tree
     @staticmethod
     def __abort():
@@ -172,19 +185,27 @@ class StackTransaction(object):
             out.error(self.__error)
 
         # Write patches.
-        for pn, commit in self.__patches.iteritems():
-            if self.__stack.patches.exists(pn):
-                p = self.__stack.patches.get(pn)
-                if commit == None:
-                    p.delete()
+        def write(msg):  # persist patches, patch order and a log entry (msg)
+            for pn, commit in self.__patches.iteritems():
+                if self.__stack.patches.exists(pn):
+                    p = self.__stack.patches.get(pn)
+                    if commit == None:  # None: drop the patch
+                        p.delete()
+                    else:
+                        p.set_commit(commit, msg)
                 else:
-                    p.set_commit(commit, self.__msg)
-            else:
-                self.__stack.patches.new(pn, commit, self.__msg)
-        _print_current_patch(self.__stack.patchorder.applied, self.__applied)
-        self.__stack.patchorder.applied = self.__applied
-        self.__stack.patchorder.unapplied = self.__unapplied
-        self.__stack.patchorder.hidden = self.__hidden
+                    self.__stack.patches.new(pn, commit, msg)  # new patch
+            self.__stack.patchorder.applied = self.__applied
+            self.__stack.patchorder.unapplied = self.__unapplied
+            self.__stack.patchorder.hidden = self.__hidden
+            log.log_entry(self.__stack, msg)
+        old_applied = self.__stack.patchorder.applied
+        write(self.__msg)
+        if self.__conflicting_push != None:
+            self.__patches = _TransPatchMap(self.__stack)
+            self.__conflicting_push()
+            write(self.__msg + ' (CONFLICT)')
+        _print_current_patch(old_applied, self.__applied)
 
         if self.__error:
             return utils.STGIT_CONFLICT
@@ -278,19 +299,41 @@ class StackTransaction(object):
         cd = cd.set_tree(tree)
         if any(getattr(cd, a) != getattr(orig_cd, a) for a in
                ['parent', 'tree', 'author', 'message']):
-            self.patches[pn] = self.__stack.repository.commit(cd)
+            comm = self.__stack.repository.commit(cd)
         else:
+            comm = None
             s = ' (unmodified)'
-        if pn in self.hidden:
-            x = self.hidden
-        else:
-            x = self.unapplied
-        del x[x.index(pn)]
-        self.applied.append(pn)
         out.info('Pushed %s%s' % (pn, s))
+        def update():
+            if comm:
+                self.patches[pn] = comm
+            if pn in self.hidden:
+                x = self.hidden
+            else:
+                x = self.unapplied
+            del x[x.index(pn)]
+            self.applied.append(pn)
         if merge_conflict:
             # We've just caused conflicts, so we must allow them in
             # the final checkout.
             self.__allow_conflicts = lambda trans: True
 
+            # Save this update so that we can run it a little later.
+            self.__conflicting_push = update
             self.__halt('Merge conflict')
+        else:
+            # Update immediately.
+            update()
+
+    def reorder_patches(self, applied, unapplied, hidden, iw = None):
+        """Pop and push patches so the stack ends up in the given order."""
+        pairs = zip(self.applied, applied)  # (current, wanted) name pairs
+        common = len(list(it.takewhile(lambda ab: ab[0] == ab[1], pairs)))
+        to_pop = set(self.applied[common:])  # all above the shared prefix
+        self.pop_patches(lambda pn: pn in to_pop)
+        for pn in applied[common:]:
+            self.push_patch(pn, iw)
+        assert self.applied == applied
+        assert set(self.unapplied + self.hidden) == set(unapplied + hidden)
+        self.unapplied = unapplied  # union with hidden is unchanged (see
+        self.hidden = hidden        # assert above); take the new split