Don't use refs/bases/<branchname>
[stgit] / stgit / stack.py
index 33010d9..2477ac6 100644 (file)
@@ -18,7 +18,7 @@ along with this program; if not, write to the Free Software
 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 """
 
-import sys, os
+import sys, os, re
 
 from stgit.utils import *
 from stgit import git, basedir, templates
@@ -91,18 +91,7 @@ def edit_file(series, line, comment, show_patch = True):
     print >> f, __comment_prefix, 'vi: set textwidth=75 filetype=diff nobackup:'
     f.close()
 
-    # the editor
-    if config.has_option('stgit', 'editor'):
-        editor = config.get('stgit', 'editor')
-    elif 'EDITOR' in os.environ:
-        editor = os.environ['EDITOR']
-    else:
-        editor = 'vi'
-    editor += ' %s' % fname
-
-    print 'Invoking the editor: "%s"...' % editor,
-    sys.stdout.flush()
-    print 'done (exit code: %d)' % os.system(editor)
+    call_editor(fname)
 
     f = file(fname, 'r+')
 
@@ -302,11 +291,10 @@ class Series(StgitObject):
         self._set_dir(os.path.join(self.__base_dir, 'patches', self.__name))
         self.__refs_dir = os.path.join(self.__base_dir, 'refs', 'patches',
                                        self.__name)
-        self.__base_file = os.path.join(self.__base_dir, 'refs', 'bases',
-                                        self.__name)
 
         self.__applied_file = os.path.join(self._dir(), 'applied')
         self.__unapplied_file = os.path.join(self._dir(), 'unapplied')
+        self.__hidden_file = os.path.join(self._dir(), 'hidden')
         self.__current_file = os.path.join(self._dir(), 'current')
         self.__descr_file = os.path.join(self._dir(), 'description')
 
@@ -326,6 +314,12 @@ class Series(StgitObject):
         if self.is_initialised() and not os.path.isdir(self.__trash_dir):
             os.makedirs(self.__trash_dir)
 
+    def __patch_name_valid(self, name):
+        """Raise StackException unless name is safe as a file and ref name:
+        only [A-Za-z0-9_.-], no leading '.', no '..', no '.lock' suffix."""
+        if not name or re.search(r'[^\w.-]|^\.|\.\.|\.lock$', name):
+            raise StackException, 'Invalid patch name: "%s"' % name
+
     def get_branch(self):
         """Return the branch name for the Series object
         """
@@ -374,9 +368,31 @@ class Series(StgitObject):
         f.close()
         return names
 
-    def get_base_file(self):
-        self.__begin_stack_check()
-        return self.__base_file
+    def get_hidden(self):
+        if not os.path.isfile(self.__hidden_file):
+            return []
+        f = file(self.__hidden_file)
+        names = [line.strip() for line in f.readlines()]
+        f.close()
+        return names
+
+    def get_base(self):
+        """Return the parent of the bottommost patch, or HEAD if none."""
+        if os.path.isfile(self.__applied_file):
+            bottommost = file(self.__applied_file).readline().strip()
+            if bottommost:
+                return self.get_patch(bottommost).get_bottom()
+        # No bottommost patch, so just return HEAD
+        return git.get_head()
+
+    def get_head(self):
+        """Return the top of the current patch, or the stack base if
+        there is no current patch."""
+        crt = self.get_current_patch()
+        if crt:
+            return crt.get_top()
+        else:
+            return self.get_base()
 
     def get_protected(self):
         return os.path.isfile(os.path.join(self._dir(), 'protected'))
@@ -397,6 +413,52 @@ class Series(StgitObject):
     def set_description(self, line):
         self._set_field('description', line)
 
+    def get_parent_remote(self):
+        value = config.get('branch.%s.remote' % self.__name)
+        if value:
+            return value
+        elif 'origin' in git.remotes_list():
+            print 'Notice: no parent remote declared for stack "%s", ' \
+                  'defaulting to "origin". Consider setting "branch.%s.remote" ' \
+                  'and "branch.%s.merge" with "git repo-config".' \
+                  % (self.__name, self.__name, self.__name)
+            return 'origin'
+        else:
+            raise StackException, 'Cannot find a parent remote for "%s"' % self.__name
+
+    def __set_parent_remote(self, remote):
+        config.set('branch.%s.remote' % self.__name, remote)
+
+    def get_parent_branch(self):
+        value = config.get('branch.%s.stgit.parentbranch' % self.__name)
+        if value:
+            return value
+        elif git.rev_parse('heads/origin'):
+            print 'Notice: no parent branch declared for stack "%s", ' \
+                  'defaulting to "heads/origin". Consider setting ' \
+                  '"branch.%s.stgit.parentbranch" with "git repo-config".' \
+                  % (self.__name, self.__name)
+            return 'heads/origin'
+        else:
+            raise StackException, 'Cannot find a parent branch for "%s"' % self.__name
+
+    def __set_parent_branch(self, name):
+        if config.get('branch.%s.remote' % self.__name):
+            # Never set merge if remote is not set to avoid
+            # possibly-erroneous lookups into 'origin'
+            config.set('branch.%s.merge' % self.__name, name)
+        config.set('branch.%s.stgit.parentbranch' % self.__name, name)
+
+    def set_parent(self, remote, localbranch):
+        # policy: record local branches as remote='.'
+        recordremote = remote or '.'
+        if localbranch:
+            self.__set_parent_remote(recordremote)
+            self.__set_parent_branch(localbranch)
+        # We'll enforce this later
+#         else:
+#             raise StackException, 'Parent branch (%s) should be specified for %s' % (localbranch, self.__name)
+
     def __patch_is_current(self, patch):
         return patch.get_name() == self.get_current()
 
@@ -410,27 +472,16 @@ class Series(StgitObject):
         """
         return name in self.get_unapplied()
 
+    def patch_hidden(self, name):
+        """Return true if the patch is hidden.
+        """
+        return name in self.get_hidden()
+
     def patch_exists(self, name):
         """Return true if there is a patch with the given name, false
         otherwise."""
         return self.patch_applied(name) or self.patch_unapplied(name)
 
-    def __begin_stack_check(self):
-        """Save the current HEAD into .git/refs/heads/base if the stack
-        is empty
-        """
-        if len(self.get_applied()) == 0:
-            head = git.get_head()
-            write_string(self.__base_file, head)
-
-    def __end_stack_check(self):
-        """Remove .git/refs/heads/base if the stack is empty.
-        This warning should never happen
-        """
-        if len(self.get_applied()) == 0 \
-           and read_string(self.__base_file) != git.get_head():
-            print 'Warning: stack empty but the HEAD and base are different'
-
     def head_top_equal(self):
         """Return true if the head and the top are the same
         """
@@ -445,31 +496,27 @@ class Series(StgitObject):
         """
         return os.path.isdir(self.__patch_dir)
 
-    def init(self, create_at=False):
+    def init(self, create_at=False, parent_remote=None, parent_branch=None):
         """Initialises the stgit series
         """
-        bases_dir = os.path.join(self.__base_dir, 'refs', 'bases')
-
         if os.path.exists(self.__patch_dir):
             raise StackException, self.__patch_dir + ' already exists'
         if os.path.exists(self.__refs_dir):
             raise StackException, self.__refs_dir + ' already exists'
-        if os.path.exists(self.__base_file):
-            raise StackException, self.__base_file + ' already exists'
 
         if (create_at!=False):
             git.create_branch(self.__name, create_at)
 
         os.makedirs(self.__patch_dir)
 
-        create_dirs(bases_dir)
+        self.set_parent(parent_remote, parent_branch)
 
         self.create_empty_field('applied')
         self.create_empty_field('unapplied')
         self.create_empty_field('description')
         os.makedirs(os.path.join(self._dir(), 'patches'))
         os.makedirs(self.__refs_dir)
-        self.__begin_stack_check()
+        self._set_field('orig-base', git.get_head())
 
     def convert(self):
         """Either convert to use a separate patch directory, or
@@ -503,7 +550,7 @@ class Series(StgitObject):
                 os.rmdir(self.__patch_dir)
                 print 'done'
             else:
-                print 'Patch directory %s is not empty.' % self.__name
+                print 'Patch directory %s is not empty.' % self.__patch_dir
 
             self.__patch_dir = self._dir()
 
@@ -514,21 +561,20 @@ class Series(StgitObject):
 
         if to_stack.is_initialised():
             raise StackException, '"%s" already exists' % to_stack.get_branch()
-        if os.path.exists(to_stack.__base_file):
-            os.remove(to_stack.__base_file)
 
         git.rename_branch(self.__name, to_name)
 
         if os.path.isdir(self._dir()):
             rename(os.path.join(self.__base_dir, 'patches'),
                    self.__name, to_stack.__name)
-        if os.path.exists(self.__base_file):
-            rename(os.path.join(self.__base_dir, 'refs', 'bases'),
-                   self.__name, to_stack.__name)
         if os.path.exists(self.__refs_dir):
             rename(os.path.join(self.__base_dir, 'refs', 'patches'),
                    self.__name, to_stack.__name)
 
+        # Rename the config section
+        config.rename_section("branch.%s" % self.__name,
+                              "branch.%s" % to_name)
+
         self.__init__(to_name)
 
     def clone(self, target_series):
@@ -536,7 +582,7 @@ class Series(StgitObject):
         """
         try:
             # allow cloning of branches not under StGIT control
-            base = read_string(self.get_base_file())
+            base = self.get_base()
         except:
             base = git.get_head()
         Series(target_series).init(create_at = base)
@@ -567,6 +613,19 @@ class Series(StgitObject):
         # fast forward the cloned series to self's top
         new_series.forward_patches(applied)
 
+        # Clone parent information
+        value = config.get('branch.%s.remote' % self.__name)
+        if value:
+            config.set('branch.%s.remote' % target_series, value)
+
+        value = config.get('branch.%s.merge' % self.__name)
+        if value:
+            config.set('branch.%s.merge' % target_series, value)
+
+        value = config.get('branch.%s.stgit.parentbranch' % self.__name)
+        if value:
+            config.set('branch.%s.stgit.parentbranch' % target_series, value)
+
     def delete(self, force = False):
         """Deletes an stgit series
         """
@@ -584,33 +643,41 @@ class Series(StgitObject):
             os.rmdir(self.__trash_dir)
 
             # FIXME: find a way to get rid of those manual removals
-            # (move functionnality to StgitObject ?)
+            # (move functionality to StgitObject ?)
             if os.path.exists(self.__applied_file):
                 os.remove(self.__applied_file)
             if os.path.exists(self.__unapplied_file):
                 os.remove(self.__unapplied_file)
+            if os.path.exists(self.__hidden_file):
+                os.remove(self.__hidden_file)
             if os.path.exists(self.__current_file):
                 os.remove(self.__current_file)
             if os.path.exists(self.__descr_file):
                 os.remove(self.__descr_file)
+            if os.path.exists(os.path.join(self._dir(), 'orig-base')):
+                os.remove(os.path.join(self._dir(), 'orig-base'))
+
             if not os.listdir(self.__patch_dir):
                 os.rmdir(self.__patch_dir)
             else:
-                print 'Patch directory %s is not empty.' % self.__name
-            if not os.listdir(self._dir()):
-                remove_dirs(os.path.join(self.__base_dir, 'patches'),
-                            self.__name)
-            else:
-                print 'Series directory %s is not empty.' % self.__name
-            if not os.listdir(self.__refs_dir):
-                remove_dirs(os.path.join(self.__base_dir, 'refs', 'patches'),
-                            self.__name)
-            else:
+                print 'Patch directory %s is not empty.' % self.__patch_dir
+
+            try:
+                os.removedirs(self._dir())
+            except OSError:
+                raise StackException, 'Series directory %s is not empty.' % self._dir()
+
+            try:
+                os.removedirs(self.__refs_dir)
+            except OSError:
                 print 'Refs directory %s is not empty.' % self.__refs_dir
 
-        if os.path.exists(self.__base_file):
-            remove_file_and_dirs(
-                os.path.join(self.__base_dir, 'refs', 'bases'), self.__name)
+        # Clean up parent information
+        # FIXME: should one day make use of git-config --remove-section,
+        # scheduled for 1.5.1
+        config.unset('branch.%s.remote' % self.__name)
+        config.unset('branch.%s.merge' % self.__name)
+        config.unset('branch.%s.stgit.parentbranch' % self.__name)
 
     def refresh_patch(self, files = None, message = None, edit = False,
                       show_patch = False,
@@ -651,7 +718,12 @@ class Series(StgitObject):
             committer_email = patch.get_commemail()
 
         if sign_str:
-            descr = '%s\n%s: %s <%s>\n' % (descr.rstrip(), sign_str,
+            descr = descr.rstrip()
+            if descr.find("\nSigned-off-by:") < 0 \
+               and descr.find("\nAcked-by:") < 0:
+                descr = descr + "\n"  # blank line before the first sign-off
+
+            descr = '%s\n%s: %s <%s>\n' % (descr, sign_str,
                                            committer_name, committer_email)
 
         bottom = patch.get_bottom()
@@ -708,6 +780,8 @@ class Series(StgitObject):
                   before_existing = False, refresh = True):
         """Creates a new patch
         """
+        self.__patch_name_valid(name)
+
         if self.patch_applied(name) or self.patch_unapplied(name):
             raise StackException, 'Patch "%s" already exists' % name
 
@@ -720,8 +794,6 @@ class Series(StgitObject):
 
         head = git.get_head()
 
-        self.__begin_stack_check()
-
         patch = Patch(name, self.__patch_dir, self.__refs_dir)
         patch.create()
 
@@ -764,6 +836,7 @@ class Series(StgitObject):
     def delete_patch(self, name):
         """Deletes a patch
         """
+        self.__patch_name_valid(name)
         patch = Patch(name, self.__patch_dir, self.__refs_dir)
 
         if self.__patch_is_current(patch):
@@ -784,7 +857,9 @@ class Series(StgitObject):
         f = file(self.__unapplied_file, 'w+')
         f.writelines([line + '\n' for line in unapplied])
         f.close()
-        self.__begin_stack_check()
+        # Also drop the name from the hidden list, if it is there
+        if self.patch_hidden(name):
+            self.unhide_patch(name)
 
     def forward_patches(self, names):
         """Try to fast-forward an array of patches.
@@ -793,7 +868,6 @@ class Series(StgitObject):
         stack. Apply the rest with push_patch
         """
         unapplied = self.get_unapplied()
-        self.__begin_stack_check()
 
         forwarded = 0
         top = git.get_head()
@@ -892,8 +966,6 @@ class Series(StgitObject):
         unapplied = self.get_unapplied()
         assert(name in unapplied)
 
-        self.__begin_stack_check()
-
         patch = Patch(name, self.__patch_dir, self.__refs_dir)
 
         head = git.get_head()
@@ -937,7 +1009,8 @@ class Series(StgitObject):
                 except git.GitException, ex:
                     print >> sys.stderr, \
                           'The merge failed during "push". ' \
-                          'Use "refresh" after fixing the conflicts'
+                          'Use "refresh" after fixing the conflicts or ' \
+                          'revert the operation with "push --undo".'
 
         append_string(self.__applied_file, name)
 
@@ -959,6 +1032,11 @@ class Series(StgitObject):
                     log = 'push'
                 self.refresh_patch(cache_update = False, log = log)
             else:
+                # we store the correctly merged files only for
+                # tracking the conflict history. Note that the
+                # git.merge() operations should always leave the index
+                # in a valid state (i.e. only stage 0 files)
+                self.refresh_patch(cache_update = False, log = 'push(c)')
                 raise StackException, str(ex)
 
         return modified
@@ -1025,11 +1103,10 @@ class Series(StgitObject):
         else:
             self.__set_current(applied[-1])
 
-        self.__end_stack_check()
-
     def empty_patch(self, name):
         """Returns True if the patch is empty
         """
+        self.__patch_name_valid(name)
         patch = Patch(name, self.__patch_dir, self.__refs_dir)
         bottom = patch.get_bottom()
         top = patch.get_top()
@@ -1043,6 +1120,8 @@ class Series(StgitObject):
         return False
 
     def rename_patch(self, oldname, newname):
+        self.__patch_name_valid(newname)
+
         applied = self.get_applied()
         unapplied = self.get_unapplied()
 
@@ -1052,6 +1131,10 @@ class Series(StgitObject):
         if newname in applied or newname in unapplied:
             raise StackException, 'Patch "%s" already exists' % newname
 
+        if self.patch_hidden(oldname):
+            self.unhide_patch(oldname)
+            # not hide_patch(): it rejects newname, not in the series yet
+            append_string(self.__hidden_file, newname)
         if oldname in unapplied:
             Patch(oldname, self.__patch_dir, self.__refs_dir).rename(newname)
             unapplied[unapplied.index(oldname)] = newname
@@ -1088,3 +1171,31 @@ class Series(StgitObject):
                          cache_update = False, tree_id = top.get_tree(),
                          allowempty = True)
         patch.set_log(log)
+
+    def hide_patch(self, name):
+        """Add the patch to the hidden list.
+        """
+        if not self.patch_exists(name):
+            raise StackException, 'Unknown patch "%s"' % name
+        elif self.patch_hidden(name):
+            raise StackException, 'Patch "%s" already hidden' % name
+
+        append_string(self.__hidden_file, name)
+
+    def unhide_patch(self, name):
+        """Remove the patch from the hidden list.
+
+        Only membership in the hidden list is required, so a stale entry
+        can still be dropped after the patch has left the series.
+        """
+        hidden = self.get_hidden()
+        if name not in hidden:
+            if not self.patch_exists(name):
+                raise StackException, 'Unknown patch "%s"' % name
+            raise StackException, 'Patch "%s" not hidden' % name
+
+        hidden.remove(name)
+
+        f = file(self.__hidden_file, 'w+')
+        f.writelines([line + '\n' for line in hidden])
+        f.close()