Don't use refs/bases/<branchname>
[stgit] / stgit / stack.py
index dbdda01..2477ac6 100644 (file)
@@ -18,7 +18,7 @@ along with this program; if not, write to the Free Software
 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 """
 
-import sys, os
+import sys, os, re
 
 from stgit.utils import *
 from stgit import git, basedir, templates
@@ -91,19 +91,7 @@ def edit_file(series, line, comment, show_patch = True):
     print >> f, __comment_prefix, 'vi: set textwidth=75 filetype=diff nobackup:'
     f.close()
 
-    # the editor
-    editor = config.get('stgit.editor')
-    if editor:
-        pass
-    elif 'EDITOR' in os.environ:
-        editor = os.environ['EDITOR']
-    else:
-        editor = 'vi'
-    editor += ' %s' % fname
-
-    print 'Invoking the editor: "%s"...' % editor,
-    sys.stdout.flush()
-    print 'done (exit code: %d)' % os.system(editor)
+    call_editor(fname)
 
     f = file(fname, 'r+')
 
@@ -303,8 +291,6 @@ class Series(StgitObject):
         self._set_dir(os.path.join(self.__base_dir, 'patches', self.__name))
         self.__refs_dir = os.path.join(self.__base_dir, 'refs', 'patches',
                                        self.__name)
-        self.__base_file = os.path.join(self.__base_dir, 'refs', 'bases',
-                                        self.__name)
 
         self.__applied_file = os.path.join(self._dir(), 'applied')
         self.__unapplied_file = os.path.join(self._dir(), 'unapplied')
@@ -328,6 +314,12 @@ class Series(StgitObject):
         if self.is_initialised() and not os.path.isdir(self.__trash_dir):
             os.makedirs(self.__trash_dir)
 
+    def __patch_name_valid(self, name):
+        """Raise StackException unless name is non-empty and contains only
+        word characters, dots and dashes."""
+        if not name or re.search(r'[^\w.-]', name):
+            raise StackException, 'Invalid patch name: "%s"' % name
+
     def get_branch(self):
         """Return the branch name for the Series object
         """
@@ -384,9 +376,23 @@ class Series(StgitObject):
         f.close()
         return names
 
-    def get_base_file(self):
-        self.__begin_stack_check()
-        return self.__base_file
+    def get_base(self):
+        """Return the bottom of the first applied patch, or HEAD if none."""
+        if os.path.isfile(self.__applied_file):
+            bottommost = file(self.__applied_file).readline().strip()
+            if bottommost:
+                return self.get_patch(bottommost).get_bottom()
+        # No applied patch, so the base is the current HEAD
+        return git.get_head()
+
+    def get_head(self):
+        """Return the head of the stack: the current patch's top, or the
+        stack base when no patch is current."""
+        crt = self.get_current_patch()
+        if crt:
+            return crt.get_top()
+        else:
+            return self.get_base()
 
     def get_protected(self):
         return os.path.isfile(os.path.join(self._dir(), 'protected'))
@@ -408,30 +414,50 @@ class Series(StgitObject):
         self._set_field('description', line)
 
     def get_parent_remote(self):
-        return config.get('branch.%s.remote' % self.__name) or 'origin'
+        value = config.get('branch.%s.remote' % self.__name)
+        if value:
+            return value
+        elif 'origin' in git.remotes_list():
+            print 'Notice: no parent remote declared for stack "%s", ' \
+                  'defaulting to "origin". Consider setting "branch.%s.remote" ' \
+                  'and "branch.%s.merge" with "git repo-config".' \
+                  % (self.__name, self.__name, self.__name)
+            return 'origin'
+        else:
+            raise StackException, 'Cannot find a parent remote for "%s"' % self.__name
 
     def __set_parent_remote(self, remote):
         value = config.set('branch.%s.remote' % self.__name, remote)
 
     def get_parent_branch(self):
-        value = config.get('branch.%s.merge' % self.__name)
+        value = config.get('branch.%s.stgit.parentbranch' % self.__name)
         if value:
             return value
         elif git.rev_parse('heads/origin'):
+            print 'Notice: no parent branch declared for stack "%s", ' \
+                  'defaulting to "heads/origin". Consider setting ' \
+                  '"branch.%s.stgit.parentbranch" with "git repo-config".' \
+                  % (self.__name, self.__name)
             return 'heads/origin'
         else:
             raise StackException, 'Cannot find a parent branch for "%s"' % self.__name
 
     def __set_parent_branch(self, name):
-        config.set('branch.%s.merge' % self.__name, name)
+        if config.get('branch.%s.remote' % self.__name):
+            # Never set merge if remote is not set to avoid
+            # possibly-erroneous lookups into 'origin'
+            config.set('branch.%s.merge' % self.__name, name)
+        config.set('branch.%s.stgit.parentbranch' % self.__name, name)
 
     def set_parent(self, remote, localbranch):
+        # policy: record local branches as remote='.'
+        recordremote = remote or '.'
         if localbranch:
+            self.__set_parent_remote(recordremote)
             self.__set_parent_branch(localbranch)
-            if remote:
-                self.__set_parent_remote(remote)
-        elif remote:
-            raise StackException, 'Remote "%s" without a branch cannot be used as parent' % remote
+        # TODO: enforce that a parent branch is always specified:
+#         else:
+#             raise StackException, 'Parent branch (%s) should be specified for %s' % (localbranch, self.__name)
 
     def __patch_is_current(self, patch):
         return patch.get_name() == self.get_current()
@@ -456,22 +482,6 @@ class Series(StgitObject):
         otherwise."""
         return self.patch_applied(name) or self.patch_unapplied(name)
 
-    def __begin_stack_check(self):
-        """Save the current HEAD into .git/refs/heads/base if the stack
-        is empty
-        """
-        if len(self.get_applied()) == 0:
-            head = git.get_head()
-            write_string(self.__base_file, head)
-
-    def __end_stack_check(self):
-        """Remove .git/refs/heads/base if the stack is empty.
-        This warning should never happen
-        """
-        if len(self.get_applied()) == 0 \
-           and read_string(self.__base_file) != git.get_head():
-            print 'Warning: stack empty but the HEAD and base are different'
-
     def head_top_equal(self):
         """Return true if the head and the top are the same
         """
@@ -489,14 +499,10 @@ class Series(StgitObject):
     def init(self, create_at=False, parent_remote=None, parent_branch=None):
         """Initialises the stgit series
         """
-        bases_dir = os.path.join(self.__base_dir, 'refs', 'bases')
-
         if os.path.exists(self.__patch_dir):
             raise StackException, self.__patch_dir + ' already exists'
         if os.path.exists(self.__refs_dir):
             raise StackException, self.__refs_dir + ' already exists'
-        if os.path.exists(self.__base_file):
-            raise StackException, self.__base_file + ' already exists'
 
         if (create_at!=False):
             git.create_branch(self.__name, create_at)
@@ -504,15 +510,13 @@ class Series(StgitObject):
         os.makedirs(self.__patch_dir)
 
         self.set_parent(parent_remote, parent_branch)
-        
-        create_dirs(bases_dir)
 
         self.create_empty_field('applied')
         self.create_empty_field('unapplied')
         self.create_empty_field('description')
         os.makedirs(os.path.join(self._dir(), 'patches'))
         os.makedirs(self.__refs_dir)
-        self.__begin_stack_check()
+        self._set_field('orig-base', git.get_head())
 
     def convert(self):
         """Either convert to use a separate patch directory, or
@@ -546,7 +550,7 @@ class Series(StgitObject):
                 os.rmdir(self.__patch_dir)
                 print 'done'
             else:
-                print 'Patch directory %s is not empty.' % self.__name
+                print 'Patch directory %s is not empty.' % self.__patch_dir
 
             self.__patch_dir = self._dir()
 
@@ -557,21 +561,20 @@ class Series(StgitObject):
 
         if to_stack.is_initialised():
             raise StackException, '"%s" already exists' % to_stack.get_branch()
-        if os.path.exists(to_stack.__base_file):
-            os.remove(to_stack.__base_file)
 
         git.rename_branch(self.__name, to_name)
 
         if os.path.isdir(self._dir()):
             rename(os.path.join(self.__base_dir, 'patches'),
                    self.__name, to_stack.__name)
-        if os.path.exists(self.__base_file):
-            rename(os.path.join(self.__base_dir, 'refs', 'bases'),
-                   self.__name, to_stack.__name)
         if os.path.exists(self.__refs_dir):
             rename(os.path.join(self.__base_dir, 'refs', 'patches'),
                    self.__name, to_stack.__name)
 
+        # Rename the config section
+        config.rename_section("branch.%s" % self.__name,
+                              "branch.%s" % to_name)
+
         self.__init__(to_name)
 
     def clone(self, target_series):
@@ -579,7 +582,7 @@ class Series(StgitObject):
         """
         try:
             # allow cloning of branches not under StGIT control
-            base = read_string(self.get_base_file())
+            base = self.get_base()
         except:
             base = git.get_head()
         Series(target_series).init(create_at = base)
@@ -610,6 +613,19 @@ class Series(StgitObject):
         # fast forward the cloned series to self's top
         new_series.forward_patches(applied)
 
+        # Clone parent information
+        value = config.get('branch.%s.remote' % self.__name)
+        if value:
+            config.set('branch.%s.remote' % target_series, value)
+
+        value = config.get('branch.%s.merge' % self.__name)
+        if value:
+            config.set('branch.%s.merge' % target_series, value)
+
+        value = config.get('branch.%s.stgit.parentbranch' % self.__name)
+        if value:
+            config.set('branch.%s.stgit.parentbranch' % target_series, value)
+
     def delete(self, force = False):
         """Deletes an stgit series
         """
@@ -627,7 +643,7 @@ class Series(StgitObject):
             os.rmdir(self.__trash_dir)
 
             # FIXME: find a way to get rid of those manual removals
-            # (move functionnality to StgitObject ?)
+            # (move functionality to StgitObject ?)
             if os.path.exists(self.__applied_file):
                 os.remove(self.__applied_file)
             if os.path.exists(self.__unapplied_file):
@@ -638,24 +654,30 @@ class Series(StgitObject):
                 os.remove(self.__current_file)
             if os.path.exists(self.__descr_file):
                 os.remove(self.__descr_file)
+            if os.path.exists(os.path.join(self._dir(), 'orig-base')):
+                os.remove(os.path.join(self._dir(), 'orig-base'))
+
             if not os.listdir(self.__patch_dir):
                 os.rmdir(self.__patch_dir)
             else:
-                print 'Patch directory %s is not empty.' % self.__name
-            if not os.listdir(self._dir()):
-                remove_dirs(os.path.join(self.__base_dir, 'patches'),
-                            self.__name)
-            else:
-                print 'Series directory %s is not empty.' % self.__name
-            if not os.listdir(self.__refs_dir):
-                remove_dirs(os.path.join(self.__base_dir, 'refs', 'patches'),
-                            self.__name)
-            else:
+                print 'Patch directory %s is not empty.' % self.__patch_dir
+
+            try:
+                os.removedirs(self._dir())
+            except OSError:
+                raise StackException, 'Series directory %s is not empty.' % self._dir()
+
+            try:
+                os.removedirs(self.__refs_dir)
+            except OSError:
                 print 'Refs directory %s is not empty.' % self.__refs_dir
 
-        if os.path.exists(self.__base_file):
-            remove_file_and_dirs(
-                os.path.join(self.__base_dir, 'refs', 'bases'), self.__name)
+        # Clean up parent information
+        # FIXME: should one day make use of git-config --section-remove,
+        # scheduled for 1.5.1
+        config.unset('branch.%s.remote' % self.__name)
+        config.unset('branch.%s.merge' % self.__name)
+        config.unset('branch.%s.stgit.parentbranch' % self.__name)
 
     def refresh_patch(self, files = None, message = None, edit = False,
                       show_patch = False,
@@ -696,7 +718,12 @@ class Series(StgitObject):
             committer_email = patch.get_commemail()
 
         if sign_str:
-            descr = '%s\n%s: %s <%s>\n' % (descr.rstrip(), sign_str,
+            descr = descr.rstrip()
+            if "\nSigned-off-by:" not in descr \
+               and "\nAcked-by:" not in descr:
+                descr = descr + "\n"
+
+            descr = '%s\n%s: %s <%s>\n' % (descr, sign_str,
                                            committer_name, committer_email)
 
         bottom = patch.get_bottom()
@@ -753,6 +780,8 @@ class Series(StgitObject):
                   before_existing = False, refresh = True):
         """Creates a new patch
         """
+        self.__patch_name_valid(name)
+
         if self.patch_applied(name) or self.patch_unapplied(name):
             raise StackException, 'Patch "%s" already exists' % name
 
@@ -765,8 +794,6 @@ class Series(StgitObject):
 
         head = git.get_head()
 
-        self.__begin_stack_check()
-
         patch = Patch(name, self.__patch_dir, self.__refs_dir)
         patch.create()
 
@@ -809,6 +836,7 @@ class Series(StgitObject):
     def delete_patch(self, name):
         """Deletes a patch
         """
+        self.__patch_name_valid(name)
         patch = Patch(name, self.__patch_dir, self.__refs_dir)
 
         if self.__patch_is_current(patch):
@@ -833,8 +861,6 @@ class Series(StgitObject):
         if self.patch_hidden(name):
             self.unhide_patch(name)
 
-        self.__begin_stack_check()
-
     def forward_patches(self, names):
         """Try to fast-forward an array of patches.
 
@@ -842,7 +868,6 @@ class Series(StgitObject):
         stack. Apply the rest with push_patch
         """
         unapplied = self.get_unapplied()
-        self.__begin_stack_check()
 
         forwarded = 0
         top = git.get_head()
@@ -941,8 +966,6 @@ class Series(StgitObject):
         unapplied = self.get_unapplied()
         assert(name in unapplied)
 
-        self.__begin_stack_check()
-
         patch = Patch(name, self.__patch_dir, self.__refs_dir)
 
         head = git.get_head()
@@ -986,7 +1009,8 @@ class Series(StgitObject):
                 except git.GitException, ex:
                     print >> sys.stderr, \
                           'The merge failed during "push". ' \
-                          'Use "refresh" after fixing the conflicts'
+                          'Use "refresh" after fixing the conflicts or ' \
+                          'revert the operation with "push --undo".'
 
         append_string(self.__applied_file, name)
 
@@ -1010,7 +1034,7 @@ class Series(StgitObject):
             else:
                 # we store the correctly merged files only for
                 # tracking the conflict history. Note that the
-                # git.merge() operations shouls always leave the index
+                # git.merge() operations should always leave the index
                 # in a valid state (i.e. only stage 0 files)
                 self.refresh_patch(cache_update = False, log = 'push(c)')
                 raise StackException, str(ex)
@@ -1079,11 +1103,10 @@ class Series(StgitObject):
         else:
             self.__set_current(applied[-1])
 
-        self.__end_stack_check()
-
     def empty_patch(self, name):
         """Returns True if the patch is empty
         """
+        self.__patch_name_valid(name)
         patch = Patch(name, self.__patch_dir, self.__refs_dir)
         bottom = patch.get_bottom()
         top = patch.get_top()
@@ -1097,6 +1120,8 @@ class Series(StgitObject):
         return False
 
     def rename_patch(self, oldname, newname):
+        self.__patch_name_valid(newname)
+
         applied = self.get_applied()
         unapplied = self.get_unapplied()