Have only a single command in each test_expect_failure
[stgit] / stgit / stack.py
index f57e4f0..d9c4b99 100644 (file)
@@ -23,6 +23,7 @@ import sys, os, re
 from stgit.utils import *
 from stgit import git, basedir, templates
 from stgit.config import config
+from shutil import copyfile
 
 
 # stack exception class
@@ -91,19 +92,7 @@ def edit_file(series, line, comment, show_patch = True):
     print >> f, __comment_prefix, 'vi: set textwidth=75 filetype=diff nobackup:'
     f.close()
 
-    # the editor
-    editor = config.get('stgit.editor')
-    if editor:
-        pass
-    elif 'EDITOR' in os.environ:
-        editor = os.environ['EDITOR']
-    else:
-        editor = 'vi'
-    editor += ' %s' % fname
-
-    print 'Invoking the editor: "%s"...' % editor,
-    sys.stdout.flush()
-    print 'done (exit code: %d)' % os.system(editor)
+    call_editor(fname)
 
     f = file(fname, 'r+')
 
@@ -303,14 +292,10 @@ class Series(StgitObject):
         self._set_dir(os.path.join(self.__base_dir, 'patches', self.__name))
         self.__refs_dir = os.path.join(self.__base_dir, 'refs', 'patches',
                                        self.__name)
-        self.__base_file = os.path.join(self.__base_dir, 'refs', 'bases',
-                                        self.__name)
 
         self.__applied_file = os.path.join(self._dir(), 'applied')
         self.__unapplied_file = os.path.join(self._dir(), 'unapplied')
         self.__hidden_file = os.path.join(self._dir(), 'hidden')
-        self.__current_file = os.path.join(self._dir(), 'current')
-        self.__descr_file = os.path.join(self._dir(), 'description')
 
         # where this series keeps its patches
         self.__patch_dir = os.path.join(self._dir(), 'patches')
@@ -339,11 +324,6 @@ class Series(StgitObject):
         """
         return self.__name
 
-    def __set_current(self, name):
-        """Sets the topmost patch
-        """
-        self._set_field('current', name)
-
     def get_patch(self, name):
         """Return a Patch object for the given name
         """
@@ -360,11 +340,16 @@ class Series(StgitObject):
     def get_current(self):
         """Return the name of the topmost patch, or None if there is
         no such patch."""
-        name = self._get_field('current')
-        if name == '':
+        try:
+            applied = self.get_applied()
+        except StackException:
+            # No "applied" file: branch is not initialized.
+            return None
+        try:
+            return applied[-1]
+        except IndexError:
+            # No patches applied.
             return None
-        else:
-            return name
 
     def get_applied(self):
         if not os.path.isfile(self.__applied_file):
@@ -390,9 +375,23 @@ class Series(StgitObject):
         f.close()
         return names
 
-    def get_base_file(self):
-        self.__begin_stack_check()
-        return self.__base_file
+    def get_base(self):
+        # Return the parent of the bottommost patch, if there is one.
+        if os.path.isfile(self.__applied_file):
+            bottommost = file(self.__applied_file).readline().strip()
+            if bottommost:
+                return self.get_patch(bottommost).get_bottom()
+        # No bottommost patch, so just return HEAD
+        return git.get_head()
+
+    def get_head(self):
+        """Return the head of the branch
+        """
+        crt = self.get_current_patch()
+        if crt:
+            return crt.get_top()
+        else:
+            return self.get_base()
 
     def get_protected(self):
         return os.path.isfile(os.path.join(self._dir(), 'protected'))
@@ -407,37 +406,67 @@ class Series(StgitObject):
         if os.path.isfile(protect_file):
             os.remove(protect_file)
 
+    def __branch_descr(self):
+        return 'branch.%s.description' % self.get_branch()
+
     def get_description(self):
-        return self._get_field('description') or ''
+        # Fall back to the .git/patches/<branch>/description file if
+        # the config variable is unset.
+        return (config.get(self.__branch_descr())
+                or self._get_field('description') or '')
 
     def set_description(self, line):
-        self._set_field('description', line)
+        if line:
+            config.set(self.__branch_descr(), line)
+        else:
+            config.unset(self.__branch_descr())
+        # Delete the old .git/patches/<branch>/description file if it
+        # exists.
+        self._set_field('description', None)
 
     def get_parent_remote(self):
-        return config.get('branch.%s.remote' % self.__name) or 'origin'
+        value = config.get('branch.%s.remote' % self.__name)
+        if value:
+            return value
+        elif 'origin' in git.remotes_list():
+            print 'Notice: no parent remote declared for stack "%s", ' \
+                  'defaulting to "origin". Consider setting "branch.%s.remote" ' \
+                  'and "branch.%s.merge" with "git repo-config".' \
+                  % (self.__name, self.__name, self.__name)
+            return 'origin'
+        else:
+            raise StackException, 'Cannot find a parent remote for "%s"' % self.__name
 
     def __set_parent_remote(self, remote):
         value = config.set('branch.%s.remote' % self.__name, remote)
 
     def get_parent_branch(self):
-        value = config.get('branch.%s.merge' % self.__name)
+        value = config.get('branch.%s.stgit.parentbranch' % self.__name)
         if value:
             return value
         elif git.rev_parse('heads/origin'):
+            print 'Notice: no parent branch declared for stack "%s", ' \
+                  'defaulting to "heads/origin". Consider setting ' \
+                  '"branch.%s.stgit.parentbranch" with "git repo-config".' \
+                  % (self.__name, self.__name)
             return 'heads/origin'
         else:
             raise StackException, 'Cannot find a parent branch for "%s"' % self.__name
 
     def __set_parent_branch(self, name):
-        config.set('branch.%s.merge' % self.__name, name)
+        if config.get('branch.%s.remote' % self.__name):
+            # Never set merge if remote is not set to avoid
+            # possibly-erroneous lookups into 'origin'
+            config.set('branch.%s.merge' % self.__name, name)
+        config.set('branch.%s.stgit.parentbranch' % self.__name, name)
 
     def set_parent(self, remote, localbranch):
         if localbranch:
+            self.__set_parent_remote(remote)
             self.__set_parent_branch(localbranch)
-            if remote:
-                self.__set_parent_remote(remote)
-        elif remote:
-            raise StackException, 'Remote "%s" without a branch cannot be used as parent' % remote
+        # We'll enforce this later
+#         else:
+#             raise StackException, 'Parent branch (%s) should be specified for %s' % localbranch, self.__name
 
     def __patch_is_current(self, patch):
         return patch.get_name() == self.get_current()
@@ -462,22 +491,6 @@ class Series(StgitObject):
         otherwise."""
         return self.patch_applied(name) or self.patch_unapplied(name)
 
-    def __begin_stack_check(self):
-        """Save the current HEAD into .git/refs/heads/base if the stack
-        is empty
-        """
-        if len(self.get_applied()) == 0:
-            head = git.get_head()
-            write_string(self.__base_file, head)
-
-    def __end_stack_check(self):
-        """Remove .git/refs/heads/base if the stack is empty.
-        This warning should never happen
-        """
-        if len(self.get_applied()) == 0 \
-           and read_string(self.__base_file) != git.get_head():
-            print 'Warning: stack empty but the HEAD and base are different'
-
     def head_top_equal(self):
         """Return true if the head and the top are the same
         """
@@ -495,14 +508,10 @@ class Series(StgitObject):
     def init(self, create_at=False, parent_remote=None, parent_branch=None):
         """Initialises the stgit series
         """
-        bases_dir = os.path.join(self.__base_dir, 'refs', 'bases')
-
         if os.path.exists(self.__patch_dir):
             raise StackException, self.__patch_dir + ' already exists'
         if os.path.exists(self.__refs_dir):
             raise StackException, self.__refs_dir + ' already exists'
-        if os.path.exists(self.__base_file):
-            raise StackException, self.__base_file + ' already exists'
 
         if (create_at!=False):
             git.create_branch(self.__name, create_at)
@@ -510,15 +519,12 @@ class Series(StgitObject):
         os.makedirs(self.__patch_dir)
 
         self.set_parent(parent_remote, parent_branch)
-        
-        create_dirs(bases_dir)
 
         self.create_empty_field('applied')
         self.create_empty_field('unapplied')
-        self.create_empty_field('description')
         os.makedirs(os.path.join(self._dir(), 'patches'))
         os.makedirs(self.__refs_dir)
-        self.__begin_stack_check()
+        self._set_field('orig-base', git.get_head())
 
     def convert(self):
         """Either convert to use a separate patch directory, or
@@ -552,7 +558,7 @@ class Series(StgitObject):
                 os.rmdir(self.__patch_dir)
                 print 'done'
             else:
-                print 'Patch directory %s is not empty.' % self.__name
+                print 'Patch directory %s is not empty.' % self.__patch_dir
 
             self.__patch_dir = self._dir()
 
@@ -563,21 +569,20 @@ class Series(StgitObject):
 
         if to_stack.is_initialised():
             raise StackException, '"%s" already exists' % to_stack.get_branch()
-        if os.path.exists(to_stack.__base_file):
-            os.remove(to_stack.__base_file)
 
         git.rename_branch(self.__name, to_name)
 
         if os.path.isdir(self._dir()):
             rename(os.path.join(self.__base_dir, 'patches'),
                    self.__name, to_stack.__name)
-        if os.path.exists(self.__base_file):
-            rename(os.path.join(self.__base_dir, 'refs', 'bases'),
-                   self.__name, to_stack.__name)
         if os.path.exists(self.__refs_dir):
             rename(os.path.join(self.__base_dir, 'refs', 'patches'),
                    self.__name, to_stack.__name)
 
+        # Rename the config section
+        config.rename_section("branch.%s" % self.__name,
+                              "branch.%s" % to_name)
+
         self.__init__(to_name)
 
     def clone(self, target_series):
@@ -585,7 +590,7 @@ class Series(StgitObject):
         """
         try:
             # allow cloning of branches not under StGIT control
-            base = read_string(self.get_base_file())
+            base = self.get_base()
         except:
             base = git.get_head()
         Series(target_series).init(create_at = base)
@@ -605,17 +610,35 @@ class Series(StgitObject):
             patches = applied = unapplied = []
         for p in patches:
             patch = self.get_patch(p)
-            new_series.new_patch(p, message = patch.get_description(),
-                                 can_edit = False, unapplied = True,
-                                 bottom = patch.get_bottom(),
-                                 top = patch.get_top(),
-                                 author_name = patch.get_authname(),
-                                 author_email = patch.get_authemail(),
-                                 author_date = patch.get_authdate())
+            newpatch = new_series.new_patch(p, message = patch.get_description(),
+                                            can_edit = False, unapplied = True,
+                                            bottom = patch.get_bottom(),
+                                            top = patch.get_top(),
+                                            author_name = patch.get_authname(),
+                                            author_email = patch.get_authemail(),
+                                            author_date = patch.get_authdate())
+            if patch.get_log():
+                print "setting log to %s" %  patch.get_log()
+                newpatch.set_log(patch.get_log())
+            else:
+                print "no log for %s" % p
 
         # fast forward the cloned series to self's top
         new_series.forward_patches(applied)
 
+        # Clone parent information
+        value = config.get('branch.%s.remote' % self.__name)
+        if value:
+            config.set('branch.%s.remote' % target_series, value)
+
+        value = config.get('branch.%s.merge' % self.__name)
+        if value:
+            config.set('branch.%s.merge' % target_series, value)
+
+        value = config.get('branch.%s.stgit.parentbranch' % self.__name)
+        if value:
+            config.set('branch.%s.stgit.parentbranch' % target_series, value)
+
     def delete(self, force = False):
         """Deletes an stgit series
         """
@@ -629,39 +652,50 @@ class Series(StgitObject):
 
             # remove the trash directory
             for fname in os.listdir(self.__trash_dir):
-                os.remove(fname)
+                os.remove(os.path.join(self.__trash_dir, fname))
             os.rmdir(self.__trash_dir)
 
             # FIXME: find a way to get rid of those manual removals
-            # (move functionnality to StgitObject ?)
+            # (move functionality to StgitObject ?)
             if os.path.exists(self.__applied_file):
                 os.remove(self.__applied_file)
             if os.path.exists(self.__unapplied_file):
                 os.remove(self.__unapplied_file)
             if os.path.exists(self.__hidden_file):
                 os.remove(self.__hidden_file)
-            if os.path.exists(self.__current_file):
-                os.remove(self.__current_file)
-            if os.path.exists(self.__descr_file):
-                os.remove(self.__descr_file)
+            if os.path.exists(self._dir()+'/orig-base'):
+                os.remove(self._dir()+'/orig-base')
+
+            # Remove obsolete files that StGIT no longer uses, but
+            # that might still be around if this is an old repository.
+            for obsolete in ([os.path.join(self._dir(), fn)
+                              for fn in ['current', 'description']]
+                             + [os.path.join(self.__base_dir,
+                                             'refs', 'bases', self.__name)]):
+                if os.path.exists(obsolete):
+                    os.remove(obsolete)
+
             if not os.listdir(self.__patch_dir):
                 os.rmdir(self.__patch_dir)
             else:
-                print 'Patch directory %s is not empty.' % self.__name
-            if not os.listdir(self._dir()):
-                remove_dirs(os.path.join(self.__base_dir, 'patches'),
-                            self.__name)
-            else:
-                print 'Series directory %s is not empty.' % self.__name
-            if not os.listdir(self.__refs_dir):
-                remove_dirs(os.path.join(self.__base_dir, 'refs', 'patches'),
-                            self.__name)
-            else:
+                print 'Patch directory %s is not empty.' % self.__patch_dir
+
+            try:
+                os.removedirs(self._dir())
+            except OSError:
+                raise StackException, 'Series directory %s is not empty.' % self._dir()
+
+            try:
+                os.removedirs(self.__refs_dir)
+            except OSError:
                 print 'Refs directory %s is not empty.' % self.__refs_dir
 
-        if os.path.exists(self.__base_file):
-            remove_file_and_dirs(
-                os.path.join(self.__base_dir, 'refs', 'bases'), self.__name)
+        # Clean up parent information
+        # FIXME: should one day make use of git-config --remove-section,
+        # scheduled for 1.5.1
+        config.unset('branch.%s.remote' % self.__name)
+        config.unset('branch.%s.merge' % self.__name)
+        config.unset('branch.%s.stgit.parentbranch' % self.__name)
 
     def refresh_patch(self, files = None, message = None, edit = False,
                       show_patch = False,
@@ -702,7 +736,12 @@ class Series(StgitObject):
             committer_email = patch.get_commemail()
 
         if sign_str:
-            descr = '%s\n%s: %s <%s>\n' % (descr.rstrip(), sign_str,
+            descr = descr.rstrip()
+            if descr.find("\nSigned-off-by:") < 0 \
+               and descr.find("\nAcked-by:") < 0:
+                descr = descr + "\n"
+
+            descr = '%s\n%s: %s <%s>\n' % (descr, sign_str,
                                            committer_name, committer_email)
 
         bottom = patch.get_bottom()
@@ -759,21 +798,24 @@ class Series(StgitObject):
                   before_existing = False, refresh = True):
         """Creates a new patch
         """
-        self.__patch_name_valid(name)
 
-        if self.patch_applied(name) or self.patch_unapplied(name):
-            raise StackException, 'Patch "%s" already exists' % name
+        if name != None:
+            self.__patch_name_valid(name)
+            if self.patch_applied(name) or self.patch_unapplied(name):
+                raise StackException, 'Patch "%s" already exists' % name
 
         if not message and can_edit:
-            descr = edit_file(self, None, \
-                              'Please enter the description for patch "%s" ' \
-                              'above.' % name, show_patch)
+            descr = edit_file(
+                self, None,
+                'Please enter the description for the patch above.',
+                show_patch)
         else:
             descr = message
 
         head = git.get_head()
 
-        self.__begin_stack_check()
+        if name == None:
+            name = make_patch_name(descr, self.patch_exists)
 
         patch = Patch(name, self.__patch_dir, self.__refs_dir)
         patch.create()
@@ -806,14 +848,13 @@ class Series(StgitObject):
             self.log_patch(patch, 'new')
 
             insert_string(self.__applied_file, patch.get_name())
-            if not self.get_current():
-                self.__set_current(name)
         else:
             append_string(self.__applied_file, patch.get_name())
-            self.__set_current(name)
             if refresh:
                 self.refresh_patch(cache_update = False, log = 'new')
 
+        return patch
+
     def delete_patch(self, name):
         """Deletes a patch
         """
@@ -842,8 +883,6 @@ class Series(StgitObject):
         if self.patch_hidden(name):
             self.unhide_patch(name)
 
-        self.__begin_stack_check()
-
     def forward_patches(self, names):
         """Try to fast-forward an array of patches.
 
@@ -851,7 +890,6 @@ class Series(StgitObject):
         stack. Apply the rest with push_patch
         """
         unapplied = self.get_unapplied()
-        self.__begin_stack_check()
 
         forwarded = 0
         top = git.get_head()
@@ -920,8 +958,6 @@ class Series(StgitObject):
         f.writelines([line + '\n' for line in unapplied])
         f.close()
 
-        self.__set_current(name)
-
         return forwarded
 
     def merged_patches(self, names):
@@ -950,8 +986,6 @@ class Series(StgitObject):
         unapplied = self.get_unapplied()
         assert(name in unapplied)
 
-        self.__begin_stack_check()
-
         patch = Patch(name, self.__patch_dir, self.__refs_dir)
 
         head = git.get_head()
@@ -995,7 +1029,8 @@ class Series(StgitObject):
                 except git.GitException, ex:
                     print >> sys.stderr, \
                           'The merge failed during "push". ' \
-                          'Use "refresh" after fixing the conflicts'
+                          'Use "refresh" after fixing the conflicts or ' \
+                          'revert the operation with "push --undo".'
 
         append_string(self.__applied_file, name)
 
@@ -1004,8 +1039,6 @@ class Series(StgitObject):
         f.writelines([line + '\n' for line in unapplied])
         f.close()
 
-        self.__set_current(name)
-
         # head == bottom case doesn't need to refresh the patch
         if empty or head != bottom:
             if not ex:
@@ -1019,7 +1052,7 @@ class Series(StgitObject):
             else:
                 # we store the correctly merged files only for
                 # tracking the conflict history. Note that the
-                # git.merge() operations shouls always leave the index
+                # git.merge() operations should always leave the index
                 # in a valid state (i.e. only stage 0 files)
                 self.refresh_patch(cache_update = False, log = 'push(c)')
                 raise StackException, str(ex)
@@ -1058,12 +1091,13 @@ class Series(StgitObject):
 
         patch = Patch(name, self.__patch_dir, self.__refs_dir)
 
-        # only keep the local changes
-        if keep and not git.apply_diff(git.get_head(), patch.get_bottom()):
-            raise StackException, \
-                  'Failed to pop patches while preserving the local changes'
-
-        git.switch(patch.get_bottom(), keep)
+        if git.get_head_file() == self.get_branch():
+            if keep and not git.apply_diff(git.get_head(), patch.get_bottom()):
+                raise StackException(
+                    'Failed to pop patches while preserving the local changes')
+            git.switch(patch.get_bottom(), keep)
+        else:
+            git.set_branch(self.get_branch(), patch.get_bottom())
 
         # save the new applied list
         idx = applied.index(name) + 1
@@ -1083,13 +1117,6 @@ class Series(StgitObject):
         f.writelines([line + '\n' for line in applied])
         f.close()
 
-        if applied == []:
-            self.__set_current(None)
-        else:
-            self.__set_current(applied[-1])
-
-        self.__end_stack_check()
-
     def empty_patch(self, name):
         """Returns True if the patch is empty
         """
@@ -1131,8 +1158,6 @@ class Series(StgitObject):
             f.close()
         elif oldname in applied:
             Patch(oldname, self.__patch_dir, self.__refs_dir).rename(newname)
-            if oldname == self.get_current():
-                self.__set_current(newname)
 
             applied[applied.index(oldname)] = newname