Implement fast-forward when only tree (but not
[stgit] / stgit / stack.py
index f43f94b..7109186 100644 (file)
@@ -29,10 +29,21 @@ from stgit.config import config
 class StackException(Exception):
     pass
 
+class FilterUntil:
+    """Keep lines not starting with `prefix` until `until_test` matches;
+    from the matching line on, every line is rejected."""
+    def __init__(self):
+        self.should_print = True
+    def __call__(self, x, until_test, prefix):
+        if until_test(x):
+            self.should_print = False
+        return self.should_print and not x.startswith(prefix)
+
 #
 # Functions
 #
 __comment_prefix = 'STG:'
+__patch_prefix = 'STG_PATCH:'
 
 def __clean_comments(f):
     """Removes lines marked for status in a commit file
@@ -40,8 +51,12 @@ def __clean_comments(f):
     f.seek(0)
 
     # remove status-prefixed lines
-    lines = filter(lambda x: x[0:len(__comment_prefix)] != __comment_prefix,
-                   f.readlines())
+    lines = f.readlines()
+    # also drop the STG_PATCH: marker line and everything after it
+    patch_filter = FilterUntil()
+    until_test = lambda t: t == (__patch_prefix + '\n')
+    lines = [l for l in lines if patch_filter(l, until_test, __comment_prefix)]
+
     # remove empty lines at the end
     while len(lines) != 0 and lines[-1] == '\n':
         del lines[-1]
@@ -49,7 +64,7 @@ def __clean_comments(f):
     f.seek(0); f.truncate()
     f.writelines(lines)
 
-def edit_file(string, comment):
+def edit_file(series, string, comment, show_patch = True):
     fname = '.stgit.msg'
     tmpl = os.path.join(git.base_dir, 'patchdescr.tmpl')
 
@@ -66,10 +81,20 @@ def edit_file(string, comment):
           % __comment_prefix
     print >> f, __comment_prefix, \
           'Trailing empty lines will be automatically removed.'
+
+    current = show_patch and series.get_current()
+    if current:  # nothing to diff without a current patch
+        print >> f, __patch_prefix
+        git.diff([], series.get_patch(current).get_bottom(), None, f)
+
+    # Vim only scans the first/last few lines for modelines: keep it last.
+    print >> f, __comment_prefix, 'vi: set textwidth=75 filetype=diff:'
     f.close()
 
     # the editor
-    if 'EDITOR' in os.environ:
+    if config.has_option('stgit', 'editor'):
+        editor = config.get('stgit', 'editor')
+    elif 'EDITOR' in os.environ:
         editor = os.environ['EDITOR']
     else:
         editor = 'vi'
@@ -145,7 +170,11 @@ class Patch:
 
     def set_bottom(self, string, backup = False):
         if backup:
-            self.__set_field('bottom.old', self.__get_field('bottom'))
+            # undo info is the previous value, or None if it is unchanged
+            old = self.__get_field('bottom')
+            if old == string:
+                old = None
+            self.__set_field('bottom.old', old)
         self.__set_field('bottom', string)
 
     def get_top(self):
@@ -153,7 +182,11 @@ class Patch:
 
     def set_top(self, string, backup = False):
         if backup:
-            self.__set_field('top.old', self.__get_field('top'))
+            # undo info is the previous value, or None if it is unchanged
+            old = self.__get_field('top')
+            if old == string:
+                old = None
+            self.__set_field('top.old', old)
         self.__set_field('top', string)
 
     def restore_old_boundaries(self):
@@ -163,8 +196,9 @@ class Patch:
         if top and bottom:
             self.__set_field('bottom', bottom)
             self.__set_field('top', top)
+            return True   # old boundaries restored
         else:
-            raise StackException, 'No patch undo information'
+            return False  # no patch undo information
 
     def get_description(self):
         return self.__get_field('description', True)
@@ -232,6 +266,11 @@ class Series:
             self.__unapplied_file = os.path.join(self.__patch_dir, 'unapplied')
             self.__current_file = os.path.join(self.__patch_dir, 'current')
 
+    def get_branch(self):
+        """Return the name of the branch this Series object refers to."""
+        branch = self.__name
+        return branch
+
     def __set_current(self, name):
         """Sets the topmost patch
         """
@@ -324,11 +363,14 @@ class Series:
 
         create_empty_file(self.__applied_file)
         create_empty_file(self.__unapplied_file)
+        self.__begin_stack_check()
 
-    def refresh_patch(self, message = None, edit = False,
+    def refresh_patch(self, message = None, edit = False, show_patch = False,
+                      cache_update = True,  # False if the GIT cache is current
                       author_name = None, author_email = None,
                       author_date = None,
-                      committer_name = None, committer_email = None):
+                      committer_name = None, committer_email = None,
+                      commit_only = False):  # True: leave the patch as is
         """Generates a new commit for the given patch
         """
         name = self.get_current()
@@ -345,9 +387,9 @@ class Series:
             descr = message
 
         if not message and edit:
-            descr = edit_file(descr.rstrip(), \
+            descr = edit_file(self, descr.rstrip(), \
                               'Please edit the description for patch "%s" ' \
-                              'above.' % name)
+                              'above.' % name, show_patch)
 
         if not author_name:
             author_name = patch.get_authname()
@@ -361,6 +403,7 @@ class Series:
             committer_email = patch.get_commemail()
 
         commit_id = git.commit(message = descr, parents = [patch.get_bottom()],
+                               cache_update = cache_update,
                                allowempty = True,
                                author_name = author_name,
                                author_email = author_email,
@@ -368,15 +411,20 @@ class Series:
                                committer_name = committer_name,
                                committer_email = committer_email)
 
-        patch.set_top(commit_id)
-        patch.set_description(descr)
-        patch.set_authname(author_name)
-        patch.set_authemail(author_email)
-        patch.set_authdate(author_date)
-        patch.set_commname(committer_name)
-        patch.set_commemail(committer_email)
+        if not commit_only:  # otherwise only return the new commit id
+            patch.set_top(commit_id)
+            patch.set_description(descr)
+            patch.set_authname(author_name)
+            patch.set_authemail(author_email)
+            patch.set_authdate(author_date)
+            patch.set_commname(committer_name)
+            patch.set_commemail(committer_email)
 
-    def new_patch(self, name, message = None, edit = False,
+        return commit_id
+
+    def new_patch(self, name, message = None, can_edit = True,
+                  unapplied = False, show_patch = False,
+                  top = None, bottom = None,
                   author_name = None, author_email = None, author_date = None,
                   committer_name = None, committer_email = None):
         """Creates a new patch
@@ -384,10 +432,10 @@ class Series:
         if self.__patch_applied(name) or self.__patch_unapplied(name):
             raise StackException, 'Patch "%s" already exists' % name
 
-        if not message:
-            descr = edit_file(None, \
+        if not message and can_edit:
+            descr = edit_file(self, None, \
                               'Please enter the description for patch "%s" ' \
-                              'above.' % name)
+                              'above.' % name, show_patch)
         else:
             descr = message
 
@@ -397,8 +445,16 @@ class Series:
 
         patch = Patch(name, self.__patch_dir)
         patch.create()
-        patch.set_bottom(head)
-        patch.set_top(head)
+
+        # bottom/top default to head when not supplied; with both defaulted
+        # the new patch starts out empty (bottom == top).
+        if not bottom:
+            bottom = head
+        if not top:
+            top = head
+        patch.set_bottom(bottom)
+        patch.set_top(top)
+
         patch.set_description(descr)
         patch.set_authname(author_name)
         patch.set_authemail(author_email)
@@ -406,8 +462,15 @@ class Series:
         patch.set_commname(committer_name)
         patch.set_commemail(committer_email)
 
-        append_string(self.__applied_file, patch.get_name())
-        self.__set_current(name)
+        if not unapplied:
+            append_string(self.__applied_file, patch.get_name())
+            self.__set_current(name)
+        else:
+            # an unapplied new patch goes to the front of the unapplied list
+            patches = [patch.get_name()] + self.get_unapplied()
+            f = file(self.__unapplied_file, 'w+')
+            f.writelines([line + '\n' for line in patches])
+            f.close()
 
     def delete_patch(self, name):
         """Deletes a patch
@@ -430,6 +493,80 @@ class Series:
         f.writelines([line + '\n' for line in unapplied])
         f.close()
 
+    def forward_patches(self, names):
+        """Try to fast-forward an array of patches.
+
+        On return, patches in names[0:returned_value] have been pushed on the
+        stack, the last of them becoming the current patch. Apply the rest
+        with push_patch. Raises StackException if a name is not unapplied.
+        """
+        unapplied = self.get_unapplied()
+        self.__begin_stack_check()
+
+        forwarded = 0
+        top = git.get_head()
+
+        for name in names:
+            if name not in unapplied:
+                raise StackException, 'Patch "%s" is not unapplied' % name
+            patch = Patch(name, self.__patch_dir)
+
+            head = top
+            bottom = patch.get_bottom()
+            top = patch.get_top()
+
+            # top != bottom always since we have a commit for each patch
+            if head == bottom:
+                # already based on head: only reset the backup information
+                patch.set_bottom(head, backup = True)
+                patch.set_top(top, backup = True)
+            else:
+                head_tree = git.get_commit(head).get_tree()
+                bottom_tree = git.get_commit(bottom).get_tree()
+                if head_tree == bottom_tree:
+                    # Same base tree under a different commit: re-create the
+                    # patch commit on top of head, reusing its top tree
+                    descr = patch.get_description()
+                    author_name = patch.get_authname()
+                    author_email = patch.get_authemail()
+                    author_date = patch.get_authdate()
+                    committer_name = patch.get_commname()
+                    committer_email = patch.get_commemail()
+
+                    top_tree = git.get_commit(top).get_tree()
+                    top = git.commit(message = descr, parents = [head],
+                                     cache_update = False,
+                                     tree_id = top_tree,
+                                     allowempty = True,
+                                     author_name = author_name,
+                                     author_email = author_email,
+                                     author_date = author_date,
+                                     committer_name = committer_name,
+                                     committer_email = committer_email)
+
+                    patch.set_bottom(head, backup = True)
+                    patch.set_top(top, backup = True)
+                else:
+                    top = head
+                    # stop the fast-forwarding, must do a real merge
+                    break
+
+            forwarded += 1
+            unapplied.remove(name)
+
+        git.switch(top)
+
+        append_strings(self.__applied_file, names[0:forwarded])
+
+        f = file(self.__unapplied_file, 'w+')
+        f.writelines([line + '\n' for line in unapplied])
+        f.close()
+        # 'name' may be unforwarded (break) or unbound (empty names)
+        if forwarded:
+            self.__set_current(names[forwarded - 1])
+
+        return forwarded
+
     def push_patch(self, name):
         """Pushes a patch on the stack
         """
@@ -476,19 +613,23 @@ class Series:
 
         self.__set_current(name)
 
-        if not ex:
-            # if the merge was OK and no conflicts, just refresh the patch
-            self.refresh_patch()
-        else:
-            raise StackException, str(ex)
+        # In the head == bottom case the patch needs no refresh; otherwise
+        # refresh it unless the merge reported a problem.
+        if head != bottom:
+            if ex:
+                raise StackException, str(ex)
+            # The GIT cache was already updated by the merge operation, so
+            # the refresh can skip updating it
+            self.refresh_patch(cache_update = False)
 
     def undo_push(self):
         name = self.get_current()
         assert(name)
 
         patch = Patch(name, self.__patch_dir)
+        git.reset()
         self.pop_patch(name)
-        patch.restore_old_boundaries()
+        return patch.restore_old_boundaries()
 
     def pop_patch(self, name):
         """Pops the top patch from the stack
@@ -535,7 +676,8 @@ class Series:
 
         if bottom == top:
             return True
-        elif git.Commit(top).get_tree() == git.Commit(bottom).get_tree():
+        elif (git.get_commit(top).get_tree() ==
+              git.get_commit(bottom).get_tree()):
             return True
 
         return False