Move stack reset function to a shared location
[stgit] / stgit / lib / transaction.py
1 """The L{StackTransaction} class makes it possible to make complex
2 updates to an StGit stack in a safe and convenient way."""
3
4 import atexit
5 import itertools as it
6
7 from stgit import exception, utils
8 from stgit.utils import any, all
9 from stgit.out import *
10 from stgit.lib import git, log
11
class TransactionException(exception.StgException):
    """Base class for the errors signalled by a L{StackTransaction}
    when it cannot proceed normally."""
15
class TransactionHalted(TransactionException):
    """Signals that the setup of a L{StackTransaction} stopped
    part-way through. It acts as a non-local jump: raised from within
    the transaction setup code, and meant to be caught just before the
    point where the transaction is run."""
20
def _print_current_patch(old_applied, new_applied):
    """Tell the user which patch is topmost now, if that has changed.

    @param old_applied: Applied patch names before the operation
    @param new_applied: Applied patch names after the operation"""
    if not new_applied:
        # Nothing is applied any more; that is only news if something
        # was applied before.
        if old_applied:
            out.info('No patch applied')
        return
    top = new_applied[-1]
    if old_applied and old_applied[-1] == top:
        # Same topmost patch as before: stay quiet.
        return
    out.info('Now at patch "%s"' % top)
34
class _TransPatchMap(dict):
    """Dictionary from patch name to commit that falls back on the
    stack: looking up a name that has no entry of its own yields the
    commit of the stack's patch with that name."""

    def __init__(self, stack):
        super(_TransPatchMap, self).__init__()
        self.__stack = stack

    def __getitem__(self, pn):
        try:
            commit = super(_TransPatchMap, self).__getitem__(pn)
        except KeyError:
            # No entry recorded here; ask the stack instead.
            commit = self.__stack.patches.get(pn).commit
        return commit
45
class StackTransaction(object):
    """A stack transaction, used for making complex updates to an StGit
    stack in one single operation that will either succeed or fail
    cleanly.

    The basic theory of operation is the following:

      1. Create a transaction object.

      2. Inside a::

         try:
           ...
         except TransactionHalted:
           pass

      block, update the transaction with e.g. methods like
      L{pop_patches} and L{push_patch}. This may create new git
      objects such as commits, but will not write any refs; this means
      that in case of a fatal error we can just walk away, no clean-up
      required.

      (Some operations may need to touch your index and working tree,
      though. But they are cleaned up when needed.)

      3. After the C{try} block -- whether or not the setup ran to
      completion or halted part-way through by raising a
      L{TransactionHalted} exception -- call the transaction's L{run}
      method. This will either succeed in writing the updated state to
      your refs and index+worktree, or fail without having done
      anything."""

    def __init__(self, stack, msg, discard_changes=False,
                 allow_conflicts=False):
        """Create a new L{StackTransaction}.

        @param stack: The stack to operate on
        @param msg: Message passed along when L{run} sets the branch
            head, writes the patches and makes the log entry
        @param discard_changes: Discard any changes in index+worktree
        @type discard_changes: bool
        @param allow_conflicts: Whether to allow pre-existing conflicts
        @type allow_conflicts: bool or function of L{StackTransaction}"""
        self.__stack = stack
        self.__msg = msg
        self.__patches = _TransPatchMap(stack)
        # Work on private copies of the patch order; the stack itself
        # is only written to in run().
        self.__applied = list(self.__stack.patchorder.applied)
        self.__unapplied = list(self.__stack.patchorder.unapplied)
        self.__hidden = list(self.__stack.patchorder.hidden)
        self.__conflicting_push = None
        self.__error = None
        self.__current_tree = self.__stack.head.data.tree
        self.__base = self.__stack.base
        self.__discard_changes = discard_changes
        # Normalize allow_conflicts to a predicate on the transaction.
        if isinstance(allow_conflicts, bool):
            self.__allow_conflicts = lambda trans: allow_conflicts
        else:
            self.__allow_conflicts = allow_conflicts
        self.__temp_index = self.temp_index_tree = None

    stack = property(lambda self: self.__stack)
    patches = property(lambda self: self.__patches)

    # The getters below hand out the internal lists themselves, so
    # callers may modify them in place; the setters store a copy.
    def __set_applied(self, val):
        self.__applied = list(val)
    applied = property(lambda self: self.__applied, __set_applied)

    def __set_unapplied(self, val):
        self.__unapplied = list(val)
    unapplied = property(lambda self: self.__unapplied, __set_unapplied)

    def __set_hidden(self, val):
        self.__hidden = list(val)
    hidden = property(lambda self: self.__hidden, __set_hidden)

    all_patches = property(lambda self: (self.__applied + self.__unapplied
                                         + self.__hidden))

    def __set_base(self, val):
        # A new base must be the parent of the bottommost applied
        # patch, if there is one.
        assert (not self.__applied
                or self.patches[self.applied[0]].data.parent == val)
        self.__base = val
    base = property(lambda self: self.__base, __set_base)

    @property
    def temp_index(self):
        """Temporary index obtained from the stack's repository. It is
        created on first use, and its deletion is registered to happen
        at interpreter exit."""
        if not self.__temp_index:
            self.__temp_index = self.__stack.repository.temp_index()
            atexit.register(self.__temp_index.delete)
        return self.__temp_index

    def __checkout(self, tree, iw):
        """Check out C{tree} in the index+worktree C{iw}, and remember
        it as the current tree. Raises L{TransactionException} if HEAD
        and stack top differ, or if there are conflicts that are not
        allowed."""
        if not self.__stack.head_top_equal():
            out.error(
                'HEAD and top are not the same.',
                'This can happen if you modify a branch with git.',
                '"stg repair --help" explains more about what to do next.')
            self.__abort()
        if self.__current_tree == tree and not self.__discard_changes:
            # No tree change, but we still want to make sure that
            # there are no unresolved conflicts. Conflicts
            # conceptually "belong" to the topmost patch, and just
            # carrying them along to another patch is confusing.
            if (self.__allow_conflicts(self) or iw is None
                    or not iw.index.conflicts()):
                return
            out.error('Need to resolve conflicts first')
            self.__abort()
        assert iw is not None
        if self.__discard_changes:
            iw.checkout_hard(tree)
        else:
            iw.checkout(self.__current_tree, tree)
        self.__current_tree = tree

    @staticmethod
    def __abort():
        """Raise the L{TransactionException} that aborts the command."""
        raise TransactionException(
            'Command aborted (all changes rolled back)')

    def __check_consistency(self):
        """Assert that every patch marked for deletion exists in the
        stack, and that every other updated patch is part of the
        transaction's patch order."""
        remaining = set(self.all_patches)
        for pn, commit in self.__patches.items():
            if commit is None:
                assert self.__stack.patches.exists(pn)
            else:
                assert pn in remaining

    @property
    def __head(self):
        """Commit of the topmost applied patch, or the base if no
        patch is applied."""
        if self.__applied:
            return self.__patches[self.__applied[-1]]
        else:
            return self.__base

    def abort(self, iw=None):
        """Abandon the transaction. If index+worktree are given, they
        are checked out to the tree of the stack's head."""
        # The only state we need to restore is index+worktree.
        if iw:
            self.__checkout(self.__stack.head.data.tree, iw)

    def run(self, iw=None, set_head=True):
        """Execute the transaction. Will either succeed, or fail (with an
        exception) and do nothing.

        @param iw: Index+worktree to check the new head out in, if any
        @param set_head: If false, neither check out the new head nor
            update the branch head
        @return: C{utils.STGIT_CONFLICT} if the transaction was halted
            with an error message, C{utils.STGIT_SUCCESS} otherwise"""
        self.__check_consistency()
        new_head = self.__head

        # Set branch head.
        if set_head:
            if iw:
                try:
                    self.__checkout(new_head.data.tree, iw)
                except git.CheckoutException:
                    # We have to abort the transaction.
                    self.abort(iw)
                    self.__abort()
            self.__stack.set_head(new_head, self.__msg)

        if self.__error:
            out.error(self.__error)

        # Write patches.
        def write(msg):
            for pn, commit in self.__patches.items():
                if self.__stack.patches.exists(pn):
                    p = self.__stack.patches.get(pn)
                    if commit is None:
                        p.delete()
                    else:
                        p.set_commit(commit, msg)
                else:
                    self.__stack.patches.new(pn, commit, msg)
            self.__stack.patchorder.applied = self.__applied
            self.__stack.patchorder.unapplied = self.__unapplied
            self.__stack.patchorder.hidden = self.__hidden
            log.log_entry(self.__stack, msg)
        old_applied = self.__stack.patchorder.applied
        write(self.__msg)
        if self.__conflicting_push is not None:
            # The state before the conflicting push has been written;
            # now apply the postponed push on a fresh patch map and
            # write the result as a second step.
            self.__patches = _TransPatchMap(self.__stack)
            self.__conflicting_push()
            write(self.__msg + ' (CONFLICT)')
        _print_current_patch(old_applied, self.__applied)

        if self.__error:
            return utils.STGIT_CONFLICT
        else:
            return utils.STGIT_SUCCESS

    def __halt(self, msg):
        """Record C{msg} as the transaction's error and raise
        L{TransactionHalted}."""
        self.__error = msg
        raise TransactionHalted(msg)

    @staticmethod
    def __print_popped(popped):
        """Print which patches were popped; C{popped} is ordered from
        bottommost to topmost."""
        if len(popped) == 0:
            pass
        elif len(popped) == 1:
            out.info('Popped %s' % popped[0])
        else:
            out.info('Popped %s -- %s' % (popped[-1], popped[0]))

    def pop_patches(self, p):
        """Pop all patches pn for which p(pn) is true. Return the list of
        other patches that had to be popped to accomplish this. Always
        succeeds."""
        popped = []
        # Everything from the bottommost matching patch upwards goes.
        for i, pn in enumerate(self.applied):
            if p(pn):
                popped = self.applied[i:]
                del self.applied[i:]
                break
        popped1 = [pn for pn in popped if not p(pn)]
        popped2 = [pn for pn in popped if p(pn)]
        # Patches that were only popped to get at the others end up
        # first among the unapplied ones.
        self.unapplied = popped1 + popped2 + self.unapplied
        self.__print_popped(popped)
        return popped1

    def delete_patches(self, p):
        """Delete all patches pn for which p(pn) is true. Return the list of
        other patches that had to be popped to accomplish this. Always
        succeeds."""
        popped = []
        # Snapshot of all patch names, taken before the lists are
        # modified below.
        all_patches = self.applied + self.unapplied + self.hidden
        for i, pn in enumerate(self.applied):
            if p(pn):
                popped = self.applied[i:]
                del self.applied[i:]
                break
        popped = [pn for pn in popped if not p(pn)]
        self.unapplied = popped + [pn for pn in self.unapplied if not p(pn)]
        self.hidden = [pn for pn in self.hidden if not p(pn)]
        self.__print_popped(popped)
        for pn in all_patches:
            if p(pn):
                s = ['', ' (empty)'][self.patches[pn].data.is_nochange()]
                # A None entry marks the patch for deletion in run().
                self.patches[pn] = None
                out.info('Deleted %s%s' % (pn, s))
        return popped

    def push_patch(self, pn, iw=None):
        """Attempt to push the named patch. If this results in conflicts,
        halts the transaction. If index+worktree are given, spill any
        conflicts to them."""
        orig_cd = self.patches[pn].data
        cd = orig_cd.set_committer(None)
        s = ['', ' (empty)'][cd.is_nochange()]
        oldparent = cd.parent
        cd = cd.set_parent(self.__head)
        # Three-way merge: from the tree of the old parent (base) to
        # the patch's tree (theirs), on top of the new parent (ours).
        base = oldparent.data.tree
        ours = cd.parent.data.tree
        theirs = cd.tree
        tree, self.temp_index_tree = self.temp_index.merge(
            base, ours, theirs, self.temp_index_tree)
        merge_conflict = False
        if not tree:
            # The merge in the temporary index did not give a tree;
            # retry in the real index+worktree, if we have one.
            if iw is None:
                self.__halt('%s does not apply cleanly' % pn)
            try:
                self.__checkout(ours, iw)
            except git.CheckoutException:
                self.__halt('Index/worktree dirty')
            try:
                iw.merge(base, ours, theirs)
                tree = iw.index.write_tree()
                self.__current_tree = tree
                s = ' (modified)'
            except git.MergeConflictException:
                tree = ours
                merge_conflict = True
                s = ' (conflict)'
            except git.MergeException:
                # Fetch the exception through sys.exc_info() so that
                # neither the "except X, e" nor the "except X as e"
                # spelling is needed; this parses on all Python versions.
                import sys
                self.__halt(str(sys.exc_info()[1]))
        cd = cd.set_tree(tree)
        if any(getattr(cd, a) != getattr(orig_cd, a) for a in
               ['parent', 'tree', 'author', 'message']):
            comm = self.__stack.repository.commit(cd)
        else:
            # Nothing changed, so no new commit is needed.
            comm = None
            s = ' (unmodified)'
        out.info('Pushed %s%s' % (pn, s))

        def update():
            if comm:
                self.patches[pn] = comm
            if pn in self.hidden:
                x = self.hidden
            else:
                x = self.unapplied
            x.remove(pn)
            self.applied.append(pn)
        if merge_conflict:
            # We've just caused conflicts, so we must allow them in
            # the final checkout.
            self.__allow_conflicts = lambda trans: True

            # Save this update so that we can run it a little later.
            self.__conflicting_push = update
            self.__halt('Merge conflict')
        else:
            # Update immediately.
            update()

    def reorder_patches(self, applied, unapplied, hidden, iw=None):
        """Push and pop patches to attain the given ordering."""
        # Length of the prefix that the current and the wanted applied
        # lists have in common; those patches can stay where they are.
        common = len(list(it.takewhile(lambda pair: pair[0] == pair[1],
                                       zip(self.applied, applied))))
        to_pop = set(self.applied[common:])
        self.pop_patches(lambda pn: pn in to_pop)
        for pn in applied[common:]:
            self.push_patch(pn, iw)
        assert self.applied == applied
        assert set(self.unapplied + self.hidden) == set(unapplied + hidden)
        self.unapplied = unapplied
        self.hidden = hidden