Expose transaction abort function
[stgit] / stgit / lib / transaction.py
1 from stgit import exception
2 from stgit.out import *
3 from stgit.lib import git
4
class TransactionException(exception.StgException):
    """Error raised by stack transactions, e.g. when a command has to be
    aborted and its changes rolled back."""
7
class TransactionHalted(TransactionException):
    """Raised to stop a transaction operation part-way, e.g. when a patch
    does not apply cleanly or its merge produces conflicts."""
10
11 def _print_current_patch(old_applied, new_applied):
12 def now_at(pn):
13 out.info('Now at patch "%s"' % pn)
14 if not old_applied and not new_applied:
15 pass
16 elif not old_applied:
17 now_at(new_applied[-1])
18 elif not new_applied:
19 out.info('No patch applied')
20 elif old_applied[-1] == new_applied[-1]:
21 pass
22 else:
23 now_at(new_applied[-1])
24
25 class _TransPatchMap(dict):
26 def __init__(self, stack):
27 dict.__init__(self)
28 self.__stack = stack
29 def __getitem__(self, pn):
30 try:
31 return dict.__getitem__(self, pn)
32 except KeyError:
33 return self.__stack.patches.get(pn).commit
34
class StackTransaction(object):
    """Accumulates changes to a stack and writes them out in run().

    Patch commits and the applied/unapplied order are edited in memory
    (and the index+worktree, when one is passed to an operation). The
    branch head, the stack's patches and the stored patch order are only
    updated by run(). A transaction that is not run can be rolled back
    with abort().
    """

    def __init__(self, stack, msg):
        self.__stack = stack
        # Message passed along with every head/patch update made in run().
        self.__msg = msg
        self.__patches = _TransPatchMap(stack)
        self.__applied = list(self.__stack.patchorder.applied)
        self.__unapplied = list(self.__stack.patchorder.unapplied)
        # Set by __halt(); printed by run().
        self.__error = None
        # Tree currently checked out in the index+worktree, as far as this
        # transaction knows.
        self.__current_tree = self.__stack.head.data.tree

    stack = property(lambda self: self.__stack)
    patches = property(lambda self: self.__patches)

    def __set_applied(self, val):
        self.__applied = list(val)
    # Applied patch names, bottom first (the last one is the top).
    applied = property(lambda self: self.__applied, __set_applied)

    def __set_unapplied(self, val):
        self.__unapplied = list(val)
    unapplied = property(lambda self: self.__unapplied, __set_unapplied)

    def __checkout(self, tree, iw):
        """Make the index+worktree iw hold tree.

        Aborts (raises TransactionException) if HEAD and the stack top
        differ, e.g. because the branch was modified with plain git. If
        tree is already the current tree nothing is checked out; otherwise
        iw must not be None.
        """
        if not self.__stack.head_top_equal():
            out.error(
                'HEAD and top are not the same.',
                'This can happen if you modify a branch with git.',
                '"stg repair --help" explains more about what to do next.')
            self.__abort()
        if self.__current_tree != tree:
            # A tree change can only be realised in a real index+worktree.
            assert iw is not None
            iw.checkout(self.__current_tree, tree)
            self.__current_tree = tree

    @staticmethod
    def __abort():
        """Raise the exception that aborts the whole command."""
        raise TransactionException(
            'Command aborted (all changes rolled back)')

    def __check_consistency(self):
        """Assert that every patch touched by the transaction is either a
        deletion of an existing patch or still listed as (un)applied."""
        remaining = set(self.__applied + self.__unapplied)
        for pn, commit in self.__patches.iteritems():
            if commit is None:
                assert self.__stack.patches.exists(pn)
            else:
                assert pn in remaining

    @property
    def __head(self):
        """Commit the branch head should point to: the topmost applied
        patch, or the stack base when no patch is applied."""
        if self.__applied:
            return self.__patches[self.__applied[-1]]
        return self.__stack.base

    def abort(self, iw=None):
        """Roll back a transaction that is not going to be run.

        run() is what updates the branch head and the stack's patches, so
        the only state to restore here is the index+worktree: if iw is
        given, the stack head's tree is checked back out into it.
        """
        if iw is not None:
            self.__checkout(self.__stack.head.data.tree, iw)

    def run(self, iw=None):
        """Write the transaction to the stack.

        Checks the new head's tree out into iw (required whenever the tree
        changes), moves the branch head, prints any error recorded by a
        halted operation, then writes the patch commits (creating, updating
        or deleting patches) and the new applied/unapplied order.

        Raises TransactionException if the checkout fails (after trying to
        restore iw) or if HEAD and the stack top differ.
        """
        self.__check_consistency()
        new_head = self.__head

        # Set branch head.
        try:
            self.__checkout(new_head.data.tree, iw)
        except git.CheckoutException:
            # We have to abort the transaction.
            self.abort(iw)
            self.__abort()
        self.__stack.set_head(new_head, self.__msg)

        if self.__error:
            out.error(self.__error)

        # Write patches.
        for pn, commit in self.__patches.iteritems():
            if self.__stack.patches.exists(pn):
                p = self.__stack.patches.get(pn)
                if commit is None:
                    p.delete()
                else:
                    p.set_commit(commit, self.__msg)
            else:
                self.__stack.patches.new(pn, commit, self.__msg)
        _print_current_patch(self.__stack.patchorder.applied, self.__applied)
        self.__stack.patchorder.applied = self.__applied
        self.__stack.patchorder.unapplied = self.__unapplied

    def __halt(self, msg):
        """Record msg for run() to print, then stop the current operation
        by raising TransactionHalted."""
        self.__error = msg
        raise TransactionHalted(msg)

    @staticmethod
    def __print_popped(popped):
        """Report popped patches; popped is in applied (bottom-first) order."""
        if len(popped) == 1:
            out.info('Popped %s' % popped[0])
        elif popped:
            out.info('Popped %s -- %s' % (popped[-1], popped[0]))

    @staticmethod
    def __empty_suffix(cd):
        """Return ' (empty)' if cd.is_nochange() is true, else ''."""
        if cd.is_nochange():
            return ' (empty)'
        return ''

    def __pop_applied_from(self, p):
        """Cut the applied list at the first (lowest) patch pn with p(pn)
        true and return the removed tail, bottom first; return [] and leave
        the list alone if no applied patch matches."""
        for i, pn in enumerate(self.__applied):
            if p(pn):
                popped = self.__applied[i:]
                del self.__applied[i:]
                return popped
        return []

    def pop_patches(self, p):
        """Pop all patches pn for which p(pn) is true. Return the list of
        other patches that had to be popped to accomplish this."""
        popped = self.__pop_applied_from(p)
        # Patches popped only because they sat above a matching patch are
        # placed first in the unapplied list, then the requested ones.
        popped1 = [pn for pn in popped if not p(pn)]
        popped2 = [pn for pn in popped if p(pn)]
        self.unapplied = popped1 + popped2 + self.unapplied
        self.__print_popped(popped)
        return popped1

    def delete_patches(self, p):
        """Delete all patches pn for which p(pn) is true. Return the list of
        other patches that had to be popped to accomplish this."""
        # Snapshot before popping so applied patches are deleted as well.
        all_patches = self.applied + self.unapplied
        popped = [pn for pn in self.__pop_applied_from(p) if not p(pn)]
        self.unapplied = popped + [pn for pn in self.unapplied if not p(pn)]
        self.__print_popped(popped)
        for pn in all_patches:
            if p(pn):
                s = self.__empty_suffix(self.patches[pn].data)
                # None marks the patch for deletion in run().
                self.patches[pn] = None
                out.info('Deleted %s%s' % (pn, s))
        return popped

    def push_patch(self, pn, iw=None):
        """Attempt to push the named patch. If this results in conflicts,
        halts the transaction (raises TransactionHalted). If index+worktree
        are given, spill any conflicts to them."""
        i = self.unapplied.index(pn)
        cd = self.patches[pn].data
        s = self.__empty_suffix(cd)
        oldparent = cd.parent
        cd = cd.set_parent(self.__head)
        # Merge the patch's change (old parent tree -> patch tree) onto the
        # tree of its new parent, the current top.
        base = oldparent.data.tree
        ours = cd.parent.data.tree
        theirs = cd.tree
        tree = self.__stack.repository.simple_merge(base, ours, theirs)
        merge_conflict = False
        if not tree:
            # simple_merge produced no tree; retrying the merge needs an
            # index+worktree.
            if iw is None:
                self.__halt('%s does not apply cleanly' % pn)
            try:
                self.__checkout(ours, iw)
            except git.CheckoutException:
                self.__halt('Index/worktree dirty')
            try:
                iw.merge(base, ours, theirs)
                tree = iw.index.write_tree()
                self.__current_tree = tree
                s = ' (modified)'
            except git.MergeException:
                # Record the patch with the parent's tree and halt below;
                # the conflict is left for the user in iw.
                tree = ours
                merge_conflict = True
                s = ' (conflict)'
        cd = cd.set_tree(tree)
        self.patches[pn] = self.__stack.repository.commit(cd)
        del self.unapplied[i]
        self.applied.append(pn)
        out.info('Pushed %s%s' % (pn, s))
        if merge_conflict:
            self.__halt('Merge conflict')