[PATCH] Allow fast-forward pushing.
[stgit] / stgit / stack.py
index cef4ae5..cfac219 100644 (file)
@@ -29,10 +29,21 @@ from stgit.config import config
 class StackException(Exception):
     pass
 
+class FilterUntil:
+    """Line predicate: keep lines not starting with prefix, until a line
+    satisfies until_test; that line and every later one are dropped."""
+    def __init__(self):
+        self.should_print = True
+    def __call__(self, x, until_test, prefix):
+        if until_test(x):
+            self.should_print = False
+        return self.should_print and not x.startswith(prefix)
+
 #
 # Functions
 #
 __comment_prefix = 'STG:'
+__patch_prefix = 'STG_PATCH:'
 
 def __clean_comments(f):
     """Removes lines marked for status in a commit file
@@ -40,8 +51,12 @@ def __clean_comments(f):
     f.seek(0)
 
     # remove status-prefixed lines
-    lines = filter(lambda x: x[0:len(__comment_prefix)] != __comment_prefix,
-                   f.readlines())
+    lines = f.readlines()
+    # drop STG: comment lines, and the STG_PATCH: marker plus all after it
+    patch_filter = FilterUntil()
+    until_test = lambda t: t.rstrip('\r\n') == __patch_prefix
+    lines = [s for s in lines if patch_filter(s, until_test, __comment_prefix)]
+
     # remove empty lines at the end
     while len(lines) != 0 and lines[-1] == '\n':
         del lines[-1]
@@ -49,7 +64,7 @@ def __clean_comments(f):
     f.seek(0); f.truncate()
     f.writelines(lines)
 
-def edit_file(string, comment):
+def edit_file(series, string, comment, show_patch = True):
     fname = '.stgit.msg'
     tmpl = os.path.join(git.base_dir, 'patchdescr.tmpl')
 
@@ -66,6 +81,14 @@ def edit_file(string, comment):
           % __comment_prefix
     print >> f, __comment_prefix, \
           'Trailing empty lines will be automatically removed.'
+
+    if show_patch:
+        print >> f, __patch_prefix
+        # current patch's diff from its bottom (rev2=None: presumably the work tree)
+        git.diff([], series.get_patch(series.get_current()).get_bottom(), None, f)
+
+    # Vim only scans the first/last few lines for modelines, so keep it last
+    print >> f, __comment_prefix, 'vi: set textwidth=75 filetype=diff:'
     f.close()
 
     # the editor
@@ -115,6 +138,13 @@ class Patch:
     def get_name(self):
         return self.__name
 
+    def rename(self, newname):
+        """Rename the patch directory; self is updated only if os.rename succeeds."""
+        newdir = os.path.join(self.__patch_dir, newname)
+        os.rename(self.__dir, newdir)
+        self.__name = newname
+        self.__dir = newdir
+
     def __get_field(self, name, multiline = False):
         id_file = os.path.join(self.__dir, name)
         if os.path.isfile(id_file):
@@ -317,11 +347,14 @@ class Series:
 
         create_empty_file(self.__applied_file)
         create_empty_file(self.__unapplied_file)
+        self.__begin_stack_check()
 
-    def refresh_patch(self, message = None, edit = False,
+    def refresh_patch(self, message = None, edit = False, show_patch = False,
+                      cache_update = True,
                       author_name = None, author_email = None,
                       author_date = None,
-                      committer_name = None, committer_email = None):
+                      committer_name = None, committer_email = None,
+                      commit_only = False):
         """Generates a new commit for the given patch
         """
         name = self.get_current()
@@ -338,9 +371,9 @@ class Series:
             descr = message
 
         if not message and edit:
-            descr = edit_file(descr.rstrip(), \
+            descr = edit_file(self, descr.rstrip(), \
                               'Please edit the description for patch "%s" ' \
-                              'above.' % name)
+                              'above.' % name, show_patch)
 
         if not author_name:
             author_name = patch.get_authname()
@@ -354,6 +387,7 @@ class Series:
             committer_email = patch.get_commemail()
 
         commit_id = git.commit(message = descr, parents = [patch.get_bottom()],
+                               cache_update = cache_update,
                                allowempty = True,
                                author_name = author_name,
                                author_email = author_email,
@@ -361,15 +395,18 @@ class Series:
                                committer_name = committer_name,
                                committer_email = committer_email)
 
-        patch.set_top(commit_id)
-        patch.set_description(descr)
-        patch.set_authname(author_name)
-        patch.set_authemail(author_email)
-        patch.set_authdate(author_date)
-        patch.set_commname(committer_name)
-        patch.set_commemail(committer_email)
+        if not commit_only:
+            patch.set_top(commit_id)
+            patch.set_description(descr)
+            patch.set_authname(author_name)
+            patch.set_authemail(author_email)
+            patch.set_authdate(author_date)
+            patch.set_commname(committer_name)
+            patch.set_commemail(committer_email)
 
-    def new_patch(self, name, message = None, edit = False,
+        return commit_id
+
+    def new_patch(self, name, message = None, can_edit = True, show_patch = False,
                   author_name = None, author_email = None, author_date = None,
                   committer_name = None, committer_email = None):
         """Creates a new patch
@@ -377,10 +414,10 @@ class Series:
         if self.__patch_applied(name) or self.__patch_unapplied(name):
             raise StackException, 'Patch "%s" already exists' % name
 
-        if not message:
-            descr = edit_file(None, \
+        if not message and can_edit:
+            descr = edit_file(self, None, \
                               'Please enter the description for patch "%s" ' \
-                              'above.' % name)
+                              'above.' % name, show_patch)
         else:
             descr = message
 
@@ -423,6 +460,53 @@ class Series:
         f.writelines([line + '\n' for line in unapplied])
         f.close()
 
+    def forward_patches(self, names):
+        """Try to fast-forward an array of patches.
+
+        On return, patches in names[0:returned_value] have been pushed on the
+        stack and the last of them made current. Apply the rest with push_patch
+        """
+        unapplied = self.get_unapplied()
+        self.__begin_stack_check()
+
+        forwarded = 0
+        top = git.get_head()
+
+        for name in names:
+            assert(name in unapplied)
+            patch = Patch(name, self.__patch_dir)
+            head = top
+            bottom = patch.get_bottom()
+            top = patch.get_top()
+
+            # top != bottom always since we have a commit for each patch
+            if head != bottom:
+                # stop the fast-forwarding, must do a real merge
+                top = head
+                break
+
+            # reset the backup information
+            patch.set_bottom(bottom, backup = True)
+            patch.set_top(top, backup = True)
+
+            forwarded += 1
+            unapplied.remove(name)
+
+        # nothing forwarded (or names empty): no switch, no series file rewrite
+        if forwarded == 0:
+            return 0
+
+        git.switch(top)
+        append_strings(self.__applied_file, names[0:forwarded])
+
+        f = file(self.__unapplied_file, 'w+')
+        f.writelines([line + '\n' for line in unapplied])
+        f.close()
+
+        # not `name`: after a break it is the first patch left unapplied
+        self.__set_current(names[forwarded - 1])
+        return forwarded
+
     def push_patch(self, name):
         """Pushes a patch on the stack
         """
@@ -469,17 +553,21 @@ class Series:
 
         self.__set_current(name)
 
-        if not ex:
-            # if the merge was OK and no conflicts, just refresh the patch
-            self.refresh_patch()
-        else:
-            raise StackException, str(ex)
+        # head == bottom case doesn't need to refresh the patch
+        if head != bottom:
+            if not ex:
+                # if the merge was OK and no conflicts, just refresh the patch
+                # The GIT cache was already updated by the merge operation
+                self.refresh_patch(cache_update = False)
+            else:
+                raise StackException, str(ex)
 
     def undo_push(self):
         name = self.get_current()
         assert(name)
 
         patch = Patch(name, self.__patch_dir)
+        git.reset()
         self.pop_patch(name)
         patch.restore_old_boundaries()
 
@@ -528,7 +616,35 @@ class Series:
 
         if bottom == top:
             return True
-        elif git.Commit(top).get_tree() == git.Commit(bottom).get_tree():
+        elif git.get_commit(top).get_tree() \
+                 == git.get_commit(bottom).get_tree():
             return True
 
         return False
+
+    def rename_patch(self, oldname, newname):
+        """Rename patch oldname to newname, keeping its place in the series.
+
+        Raises StackException if newname is taken or oldname is unknown.
+        """
+        applied = self.get_applied()
+        unapplied = self.get_unapplied()
+
+        if newname in applied or newname in unapplied:
+            raise StackException, 'Patch "%s" already exists' % newname
+
+        # the series list holding oldname, and the file that stores it
+        if oldname in unapplied:
+            patches, patches_file = unapplied, self.__unapplied_file
+        elif oldname in applied:
+            patches, patches_file = applied, self.__applied_file
+        else:
+            raise StackException, 'Unknown patch "%s"' % oldname
+
+        Patch(oldname, self.__patch_dir).rename(newname)
+        if oldname == self.get_current():
+            self.__set_current(newname)
+        patches[patches.index(oldname)] = newname
+        f = file(patches_file, 'w+')
+        f.writelines([line + '\n' for line in patches])
+        f.close()