X-Git-Url: https://git.distorted.org.uk/~mdw/stgit/blobdiff_plain/d851da81409f07105c55d5088157e276cbc4fe13..a3d7efccc515eeb12001a46ea0b781133768b23c:/stgit/lib/transaction.py diff --git a/stgit/lib/transaction.py b/stgit/lib/transaction.py index 1ece01e..20d87e1 100644 --- a/stgit/lib/transaction.py +++ b/stgit/lib/transaction.py @@ -1,13 +1,22 @@ +"""The L{StackTransaction} class makes it possible to make complex +updates to an StGit stack in a safe and convenient way.""" + +import atexit +import itertools as it + from stgit import exception, utils from stgit.utils import any, all from stgit.out import * -from stgit.lib import git +from stgit.lib import git, log class TransactionException(exception.StgException): - pass + """Exception raised when something goes wrong with a + L{StackTransaction}.""" class TransactionHalted(TransactionException): - pass + """Exception raised when a L{StackTransaction} stops part-way through. + Used to make a non-local jump from the transaction setup to the + part of the transaction code where the transaction is run.""" def _print_current_patch(old_applied, new_applied): def now_at(pn): @@ -24,6 +33,7 @@ def _print_current_patch(old_applied, new_applied): now_at(new_applied[-1]) class _TransPatchMap(dict): + """Maps patch names to sha1 strings.""" def __init__(self, stack): dict.__init__(self) self.__stack = stack @@ -34,15 +44,52 @@ class _TransPatchMap(dict): return self.__stack.patches.get(pn).commit class StackTransaction(object): - def __init__(self, stack, msg): + """A stack transaction, used for making complex updates to an StGit + stack in one single operation that will either succeed or fail + cleanly. + + The basic theory of operation is the following: + + 1. Create a transaction object. + + 2. Inside a:: + + try + ... + except TransactionHalted: + pass + + block, update the transaction with e.g. methods like + L{pop_patches} and L{push_patch}. 
This may create new git + objects such as commits, but will not write any refs; this means + that in case of a fatal error we can just walk away, no clean-up + required. + + (Some operations may need to touch your index and working tree, + though. But they are cleaned up when needed.) + + 3. After the C{try} block -- whether or not the setup ran to + completion or halted part-way through by raising a + L{TransactionHalted} exception -- call the transaction's L{run} + method. This will either succeed in writing the updated state to + your refs and index+worktree, or fail without having done + anything.""" + def __init__(self, stack, msg, allow_conflicts = False): self.__stack = stack self.__msg = msg self.__patches = _TransPatchMap(stack) self.__applied = list(self.__stack.patchorder.applied) self.__unapplied = list(self.__stack.patchorder.unapplied) + self.__hidden = list(self.__stack.patchorder.hidden) + self.__conflicting_push = None self.__error = None self.__current_tree = self.__stack.head.data.tree self.__base = self.__stack.base + if isinstance(allow_conflicts, bool): + self.__allow_conflicts = lambda trans: allow_conflicts + else: + self.__allow_conflicts = allow_conflicts + self.__temp_index = self.temp_index_tree = None stack = property(lambda self: self.__stack) patches = property(lambda self: self.__patches) def __set_applied(self, val): @@ -51,11 +98,20 @@ class StackTransaction(object): def __set_unapplied(self, val): self.__unapplied = list(val) unapplied = property(lambda self: self.__unapplied, __set_unapplied) + def __set_hidden(self, val): + self.__hidden = list(val) + hidden = property(lambda self: self.__hidden, __set_hidden) def __set_base(self, val): assert (not self.__applied or self.patches[self.applied[0]].data.parent == val) self.__base = val base = property(lambda self: self.__base, __set_base) + @property + def temp_index(self): + if not self.__temp_index: + self.__temp_index = self.__stack.repository.temp_index() + 
atexit.register(self.__temp_index.delete) + return self.__temp_index def __checkout(self, tree, iw): if not self.__stack.head_top_equal(): out.error( @@ -63,16 +119,25 @@ class StackTransaction(object): 'This can happen if you modify a branch with git.', '"stg repair --help" explains more about what to do next.') self.__abort() - if self.__current_tree != tree: - assert iw != None - iw.checkout(self.__current_tree, tree) - self.__current_tree = tree + if self.__current_tree == tree: + # No tree change, but we still want to make sure that + # there are no unresolved conflicts. Conflicts + # conceptually "belong" to the topmost patch, and just + # carrying them along to another patch is confusing. + if (self.__allow_conflicts(self) or iw == None + or not iw.index.conflicts()): + return + out.error('Need to resolve conflicts first') + self.__abort() + assert iw != None + iw.checkout(self.__current_tree, tree) + self.__current_tree = tree @staticmethod def __abort(): raise TransactionException( 'Command aborted (all changes rolled back)') def __check_consistency(self): - remaining = set(self.__applied + self.__unapplied) + remaining = set(self.__applied + self.__unapplied + self.__hidden) for pn, commit in self.__patches.iteritems(): if commit == None: assert self.__stack.patches.exists(pn) @@ -88,36 +153,48 @@ class StackTransaction(object): # The only state we need to restore is index+worktree. if iw: self.__checkout(self.__stack.head.data.tree, iw) - def run(self, iw = None): + def run(self, iw = None, set_head = True): + """Execute the transaction. Will either succeed, or fail (with an + exception) and do nothing.""" self.__check_consistency() new_head = self.__head # Set branch head. - if iw: - try: - self.__checkout(new_head.data.tree, iw) - except git.CheckoutException: - # We have to abort the transaction. 
- self.abort(iw) - self.__abort() - self.__stack.set_head(new_head, self.__msg) + if set_head: + if iw: + try: + self.__checkout(new_head.data.tree, iw) + except git.CheckoutException: + # We have to abort the transaction. + self.abort(iw) + self.__abort() + self.__stack.set_head(new_head, self.__msg) if self.__error: out.error(self.__error) # Write patches. - for pn, commit in self.__patches.iteritems(): - if self.__stack.patches.exists(pn): - p = self.__stack.patches.get(pn) - if commit == None: - p.delete() + def write(msg): + for pn, commit in self.__patches.iteritems(): + if self.__stack.patches.exists(pn): + p = self.__stack.patches.get(pn) + if commit == None: + p.delete() + else: + p.set_commit(commit, msg) else: - p.set_commit(commit, self.__msg) - else: - self.__stack.patches.new(pn, commit, self.__msg) - _print_current_patch(self.__stack.patchorder.applied, self.__applied) - self.__stack.patchorder.applied = self.__applied - self.__stack.patchorder.unapplied = self.__unapplied + self.__stack.patches.new(pn, commit, msg) + self.__stack.patchorder.applied = self.__applied + self.__stack.patchorder.unapplied = self.__unapplied + self.__stack.patchorder.hidden = self.__hidden + log.log_entry(self.__stack, msg) + old_applied = self.__stack.patchorder.applied + write(self.__msg) + if self.__conflicting_push != None: + self.__patches = _TransPatchMap(self.__stack) + self.__conflicting_push() + write(self.__msg + ' (CONFLICT)') + _print_current_patch(old_applied, self.__applied) if self.__error: return utils.STGIT_CONFLICT @@ -139,7 +216,8 @@ class StackTransaction(object): def pop_patches(self, p): """Pop all patches pn for which p(pn) is true. Return the list of - other patches that had to be popped to accomplish this.""" + other patches that had to be popped to accomplish this. 
Always + succeeds.""" popped = [] for i in xrange(len(self.applied)): if p(self.applied[i]): @@ -154,9 +232,10 @@ class StackTransaction(object): def delete_patches(self, p): """Delete all patches pn for which p(pn) is true. Return the list of - other patches that had to be popped to accomplish this.""" + other patches that had to be popped to accomplish this. Always + succeeds.""" popped = [] - all_patches = self.applied + self.unapplied + all_patches = self.applied + self.unapplied + self.hidden for i in xrange(len(self.applied)): if p(self.applied[i]): popped = self.applied[i:] @@ -164,6 +243,7 @@ class StackTransaction(object): break popped = [pn for pn in popped if not p(pn)] self.unapplied = popped + [pn for pn in self.unapplied if not p(pn)] + self.hidden = [pn for pn in self.hidden if not p(pn)] self.__print_popped(popped) for pn in all_patches: if p(pn): @@ -184,7 +264,8 @@ class StackTransaction(object): base = oldparent.data.tree ours = cd.parent.data.tree theirs = cd.tree - tree = self.__stack.repository.simple_merge(base, ours, theirs) + tree, self.temp_index_tree = self.temp_index.merge( + base, ours, theirs, self.temp_index_tree) merge_conflict = False if not tree: if iw == None: @@ -207,11 +288,41 @@ class StackTransaction(object): cd = cd.set_tree(tree) if any(getattr(cd, a) != getattr(orig_cd, a) for a in ['parent', 'tree', 'author', 'message']): - self.patches[pn] = self.__stack.repository.commit(cd) + comm = self.__stack.repository.commit(cd) else: + comm = None s = ' (unmodified)' - del self.unapplied[self.unapplied.index(pn)] - self.applied.append(pn) out.info('Pushed %s%s' % (pn, s)) + def update(): + if comm: + self.patches[pn] = comm + if pn in self.hidden: + x = self.hidden + else: + x = self.unapplied + del x[x.index(pn)] + self.applied.append(pn) if merge_conflict: + # We've just caused conflicts, so we must allow them in + # the final checkout. 
+ self.__allow_conflicts = lambda trans: True + + # Save this update so that we can run it a little later. + self.__conflicting_push = update self.__halt('Merge conflict') + else: + # Update immediately. + update() + + def reorder_patches(self, applied, unapplied, hidden, iw = None): + """Push and pop patches to attain the given ordering.""" + common = len(list(it.takewhile(lambda (a, b): a == b, + zip(self.applied, applied)))) + to_pop = set(self.applied[common:]) + self.pop_patches(lambda pn: pn in to_pop) + for pn in applied[common:]: + self.push_patch(pn, iw) + assert self.applied == applied + assert set(self.unapplied + self.hidden) == set(unapplied + hidden) + self.unapplied = unapplied + self.hidden = hidden