Merge branch 'stable'
[stgit] / stgit / lib / transaction.py
1 from stgit import exception, utils
2 from stgit.utils import any, all
3 from stgit.out import *
4 from stgit.lib import git
5
class TransactionException(exception.StgException):
    """Base exception for stack transaction errors.

    StackTransaction raises it directly when it aborts a command
    ('Command aborted (all changes rolled back)')."""
8
class TransactionHalted(TransactionException):
    """Raised by StackTransaction when an operation stops part-way,
    e.g. a patch that does not apply cleanly. The transaction keeps the
    halt message and its run() method reports it."""
11
def _print_current_patch(old_applied, new_applied):
    """Tell the user which patch is on top after the applied list changed
    from old_applied to new_applied (both ordered bottom to top).

    Nothing is printed when no patch was applied before or after, or when
    the topmost applied patch did not change."""
    if not new_applied:
        # Everything got popped; only worth saying if something was applied.
        if old_applied:
            out.info('No patch applied')
        return
    if old_applied and old_applied[-1] == new_applied[-1]:
        # Same top patch as before: nothing to report.
        return
    out.info('Now at patch "%s"' % new_applied[-1])
25
class _TransPatchMap(dict):
    """Patch name -> commit map used by a stack transaction.

    Entries stored explicitly (including None values) are returned as-is.
    Any other name is looked up in the underlying stack, and that patch's
    current commit is returned. Iteration only sees the explicitly stored
    entries."""
    def __init__(self, stack):
        dict.__init__(self)
        self.__stack = stack

    def __getitem__(self, pn):
        if dict.__contains__(self, pn):
            return dict.__getitem__(self, pn)
        # Not touched by the transaction: read through to the stack.
        return self.__stack.patches.get(pn).commit
35
36 class StackTransaction(object):
37 def __init__(self, stack, msg, allow_conflicts = False):
38 self.__stack = stack
39 self.__msg = msg
40 self.__patches = _TransPatchMap(stack)
41 self.__applied = list(self.__stack.patchorder.applied)
42 self.__unapplied = list(self.__stack.patchorder.unapplied)
43 self.__error = None
44 self.__current_tree = self.__stack.head.data.tree
45 self.__base = self.__stack.base
46 if isinstance(allow_conflicts, bool):
47 self.__allow_conflicts = lambda trans: allow_conflicts
48 else:
49 self.__allow_conflicts = allow_conflicts
50 stack = property(lambda self: self.__stack)
51 patches = property(lambda self: self.__patches)
52 def __set_applied(self, val):
53 self.__applied = list(val)
54 applied = property(lambda self: self.__applied, __set_applied)
55 def __set_unapplied(self, val):
56 self.__unapplied = list(val)
57 unapplied = property(lambda self: self.__unapplied, __set_unapplied)
58 def __set_base(self, val):
59 assert (not self.__applied
60 or self.patches[self.applied[0]].data.parent == val)
61 self.__base = val
62 base = property(lambda self: self.__base, __set_base)
63 def __checkout(self, tree, iw):
64 if not self.__stack.head_top_equal():
65 out.error(
66 'HEAD and top are not the same.',
67 'This can happen if you modify a branch with git.',
68 '"stg repair --help" explains more about what to do next.')
69 self.__abort()
70 if self.__current_tree == tree:
71 # No tree change, but we still want to make sure that
72 # there are no unresolved conflicts. Conflicts
73 # conceptually "belong" to the topmost patch, and just
74 # carrying them along to another patch is confusing.
75 if (self.__allow_conflicts(self) or iw == None
76 or not iw.index.conflicts()):
77 return
78 out.error('Need to resolve conflicts first')
79 self.__abort()
80 assert iw != None
81 iw.checkout(self.__current_tree, tree)
82 self.__current_tree = tree
83 @staticmethod
84 def __abort():
85 raise TransactionException(
86 'Command aborted (all changes rolled back)')
87 def __check_consistency(self):
88 remaining = set(self.__applied + self.__unapplied)
89 for pn, commit in self.__patches.iteritems():
90 if commit == None:
91 assert self.__stack.patches.exists(pn)
92 else:
93 assert pn in remaining
94 @property
95 def __head(self):
96 if self.__applied:
97 return self.__patches[self.__applied[-1]]
98 else:
99 return self.__base
100 def abort(self, iw = None):
101 # The only state we need to restore is index+worktree.
102 if iw:
103 self.__checkout(self.__stack.head.data.tree, iw)
104 def run(self, iw = None):
105 self.__check_consistency()
106 new_head = self.__head
107
108 # Set branch head.
109 if iw:
110 try:
111 self.__checkout(new_head.data.tree, iw)
112 except git.CheckoutException:
113 # We have to abort the transaction.
114 self.abort(iw)
115 self.__abort()
116 self.__stack.set_head(new_head, self.__msg)
117
118 if self.__error:
119 out.error(self.__error)
120
121 # Write patches.
122 for pn, commit in self.__patches.iteritems():
123 if self.__stack.patches.exists(pn):
124 p = self.__stack.patches.get(pn)
125 if commit == None:
126 p.delete()
127 else:
128 p.set_commit(commit, self.__msg)
129 else:
130 self.__stack.patches.new(pn, commit, self.__msg)
131 _print_current_patch(self.__stack.patchorder.applied, self.__applied)
132 self.__stack.patchorder.applied = self.__applied
133 self.__stack.patchorder.unapplied = self.__unapplied
134
135 if self.__error:
136 return utils.STGIT_CONFLICT
137 else:
138 return utils.STGIT_SUCCESS
139
140 def __halt(self, msg):
141 self.__error = msg
142 raise TransactionHalted(msg)
143
144 @staticmethod
145 def __print_popped(popped):
146 if len(popped) == 0:
147 pass
148 elif len(popped) == 1:
149 out.info('Popped %s' % popped[0])
150 else:
151 out.info('Popped %s -- %s' % (popped[-1], popped[0]))
152
153 def pop_patches(self, p):
154 """Pop all patches pn for which p(pn) is true. Return the list of
155 other patches that had to be popped to accomplish this."""
156 popped = []
157 for i in xrange(len(self.applied)):
158 if p(self.applied[i]):
159 popped = self.applied[i:]
160 del self.applied[i:]
161 break
162 popped1 = [pn for pn in popped if not p(pn)]
163 popped2 = [pn for pn in popped if p(pn)]
164 self.unapplied = popped1 + popped2 + self.unapplied
165 self.__print_popped(popped)
166 return popped1
167
168 def delete_patches(self, p):
169 """Delete all patches pn for which p(pn) is true. Return the list of
170 other patches that had to be popped to accomplish this."""
171 popped = []
172 all_patches = self.applied + self.unapplied
173 for i in xrange(len(self.applied)):
174 if p(self.applied[i]):
175 popped = self.applied[i:]
176 del self.applied[i:]
177 break
178 popped = [pn for pn in popped if not p(pn)]
179 self.unapplied = popped + [pn for pn in self.unapplied if not p(pn)]
180 self.__print_popped(popped)
181 for pn in all_patches:
182 if p(pn):
183 s = ['', ' (empty)'][self.patches[pn].data.is_nochange()]
184 self.patches[pn] = None
185 out.info('Deleted %s%s' % (pn, s))
186 return popped
187
188 def push_patch(self, pn, iw = None):
189 """Attempt to push the named patch. If this results in conflicts,
190 halts the transaction. If index+worktree are given, spill any
191 conflicts to them."""
192 orig_cd = self.patches[pn].data
193 cd = orig_cd.set_committer(None)
194 s = ['', ' (empty)'][cd.is_nochange()]
195 oldparent = cd.parent
196 cd = cd.set_parent(self.__head)
197 base = oldparent.data.tree
198 ours = cd.parent.data.tree
199 theirs = cd.tree
200 tree = self.__stack.repository.simple_merge(base, ours, theirs)
201 merge_conflict = False
202 if not tree:
203 if iw == None:
204 self.__halt('%s does not apply cleanly' % pn)
205 try:
206 self.__checkout(ours, iw)
207 except git.CheckoutException:
208 self.__halt('Index/worktree dirty')
209 try:
210 iw.merge(base, ours, theirs)
211 tree = iw.index.write_tree()
212 self.__current_tree = tree
213 s = ' (modified)'
214 except git.MergeConflictException:
215 tree = ours
216 merge_conflict = True
217 s = ' (conflict)'
218 except git.MergeException, e:
219 self.__halt(str(e))
220 cd = cd.set_tree(tree)
221 if any(getattr(cd, a) != getattr(orig_cd, a) for a in
222 ['parent', 'tree', 'author', 'message']):
223 self.patches[pn] = self.__stack.repository.commit(cd)
224 else:
225 s = ' (unmodified)'
226 del self.unapplied[self.unapplied.index(pn)]
227 self.applied.append(pn)
228 out.info('Pushed %s%s' % (pn, s))
229 if merge_conflict:
230 # We've just caused conflicts, so we must allow them in
231 # the final checkout.
232 self.__allow_conflicts = lambda trans: True
233
234 self.__halt('Merge conflict')