[PATCH] Really fix import --edit invoking editor twice
[stgit] / stgit / stack.py
index f43f94b..8970c04 100644 (file)
@@ -29,10 +29,21 @@ from stgit.config import config
 class StackException(Exception):
     pass
 
+class FilterUntil:
+    def __init__(self):
+        self.should_print = True
+    def __call__(self, x, until_test, prefix):
+        if until_test(x):
+            self.should_print = False
+        if self.should_print:
+            return x[0:len(prefix)] != prefix
+        return False
+
 #
 # Functions
 #
 __comment_prefix = 'STG:'
+__patch_prefix = 'STG_PATCH:'
 
 def __clean_comments(f):
     """Removes lines marked for status in a commit file
@@ -40,8 +51,12 @@ def __clean_comments(f):
     f.seek(0)
 
     # remove status-prefixed lines
-    lines = filter(lambda x: x[0:len(__comment_prefix)] != __comment_prefix,
-                   f.readlines())
+    lines = f.readlines()
+
+    patch_filter = FilterUntil()
+    until_test = lambda t: t == (__patch_prefix + '\n')
+    lines = [l for l in lines if patch_filter(l, until_test, __comment_prefix)]
+
     # remove empty lines at the end
     while len(lines) != 0 and lines[-1] == '\n':
         del lines[-1]
@@ -49,7 +64,7 @@ def __clean_comments(f):
     f.seek(0); f.truncate()
     f.writelines(lines)
 
-def edit_file(string, comment):
+def edit_file(series, string, comment, show_patch = True):
     fname = '.stgit.msg'
     tmpl = os.path.join(git.base_dir, 'patchdescr.tmpl')
 
@@ -66,6 +81,14 @@ def edit_file(string, comment):
           % __comment_prefix
     print >> f, __comment_prefix, \
           'Trailing empty lines will be automatically removed.'
+
+    if show_patch:
+       print >> f, __patch_prefix
+       # series.get_patch(series.get_current()).get_top()
+       git.diff([], series.get_patch(series.get_current()).get_bottom(), None, f)
+
+    #Vim modeline must be near the end.
+    print >> f, __comment_prefix, 'vi: set textwidth=75 filetype=diff:'
     f.close()
 
     # the editor
@@ -324,11 +347,14 @@ class Series:
 
         create_empty_file(self.__applied_file)
         create_empty_file(self.__unapplied_file)
+        self.__begin_stack_check()
 
-    def refresh_patch(self, message = None, edit = False,
+    def refresh_patch(self, message = None, edit = False, show_patch = False,
+                      cache_update = True,
                       author_name = None, author_email = None,
                       author_date = None,
-                      committer_name = None, committer_email = None):
+                      committer_name = None, committer_email = None,
+                      commit_only = False):
         """Generates a new commit for the given patch
         """
         name = self.get_current()
@@ -345,9 +371,9 @@ class Series:
             descr = message
 
         if not message and edit:
-            descr = edit_file(descr.rstrip(), \
+            descr = edit_file(self, descr.rstrip(), \
                               'Please edit the description for patch "%s" ' \
-                              'above.' % name)
+                              'above.' % name, show_patch)
 
         if not author_name:
             author_name = patch.get_authname()
@@ -361,6 +387,7 @@ class Series:
             committer_email = patch.get_commemail()
 
         commit_id = git.commit(message = descr, parents = [patch.get_bottom()],
+                               cache_update = cache_update,
                                allowempty = True,
                                author_name = author_name,
                                author_email = author_email,
@@ -368,15 +395,18 @@ class Series:
                                committer_name = committer_name,
                                committer_email = committer_email)
 
-        patch.set_top(commit_id)
-        patch.set_description(descr)
-        patch.set_authname(author_name)
-        patch.set_authemail(author_email)
-        patch.set_authdate(author_date)
-        patch.set_commname(committer_name)
-        patch.set_commemail(committer_email)
+        if not commit_only:
+            patch.set_top(commit_id)
+            patch.set_description(descr)
+            patch.set_authname(author_name)
+            patch.set_authemail(author_email)
+            patch.set_authdate(author_date)
+            patch.set_commname(committer_name)
+            patch.set_commemail(committer_email)
+
+        return commit_id
 
-    def new_patch(self, name, message = None, edit = False,
+    def new_patch(self, name, message = None, can_edit = True, show_patch = False,
                   author_name = None, author_email = None, author_date = None,
                   committer_name = None, committer_email = None):
         """Creates a new patch
@@ -384,10 +414,10 @@ class Series:
         if self.__patch_applied(name) or self.__patch_unapplied(name):
             raise StackException, 'Patch "%s" already exists' % name
 
-        if not message:
-            descr = edit_file(None, \
+        if not message and can_edit:
+            descr = edit_file(self, None, \
                               'Please enter the description for patch "%s" ' \
-                              'above.' % name)
+                              'above.' % name, show_patch)
         else:
             descr = message
 
@@ -476,17 +506,21 @@ class Series:
 
         self.__set_current(name)
 
-        if not ex:
-            # if the merge was OK and no conflicts, just refresh the patch
-            self.refresh_patch()
-        else:
-            raise StackException, str(ex)
+        # head == bottom case doesn't need to refresh the patch
+        if head != bottom:
+            if not ex:
+                # if the merge was OK and no conflicts, just refresh the patch
+                # The GIT cache was already updated by the merge operation
+                self.refresh_patch(cache_update = False)
+            else:
+                raise StackException, str(ex)
 
     def undo_push(self):
         name = self.get_current()
         assert(name)
 
         patch = Patch(name, self.__patch_dir)
+        git.reset()
         self.pop_patch(name)
         patch.restore_old_boundaries()
 
@@ -535,7 +569,8 @@ class Series:
 
         if bottom == top:
             return True
-        elif git.Commit(top).get_tree() == git.Commit(bottom).get_tree():
+        elif git.get_commit(top).get_tree() \
+                 == git.get_commit(bottom).get_tree():
             return True
 
         return False