Fix popping on non-active branches
[stgit] / stgit / stack.py
index 0cf7947..52f39a6 100644 (file)
@@ -23,6 +23,7 @@ import sys, os, re
 from stgit.utils import *
 from stgit import git, basedir, templates
 from stgit.config import config
+from shutil import copyfile
 
 
 # stack exception class
@@ -295,7 +296,6 @@ class Series(StgitObject):
         self.__applied_file = os.path.join(self._dir(), 'applied')
         self.__unapplied_file = os.path.join(self._dir(), 'unapplied')
         self.__hidden_file = os.path.join(self._dir(), 'hidden')
-        self.__current_file = os.path.join(self._dir(), 'current')
         self.__descr_file = os.path.join(self._dir(), 'description')
 
         # where this series keeps its patches
@@ -325,11 +325,6 @@ class Series(StgitObject):
         """
         return self.__name
 
-    def __set_current(self, name):
-        """Sets the topmost patch
-        """
-        self._set_field('current', name)
-
     def get_patch(self, name):
         """Return a Patch object for the given name
         """
@@ -346,11 +341,16 @@ class Series(StgitObject):
     def get_current(self):
         """Return the name of the topmost patch, or None if there is
         no such patch."""
-        name = self._get_field('current')
-        if name == '':
+        try:
+            applied = self.get_applied()
+        except StackException:
+            # No "applied" file: the branch is not initialized, so no top.
+            return None
+        try:
+            return applied[-1]  # last applied patch is the topmost one
+        except IndexError:
+            # Empty applied list: no patches applied, so no top.
             return None
-        else:
-            return name
 
     def get_applied(self):
         if not os.path.isfile(self.__applied_file):
@@ -602,13 +602,18 @@ class Series(StgitObject):
             patches = applied = unapplied = []
         for p in patches:
             patch = self.get_patch(p)
-            new_series.new_patch(p, message = patch.get_description(),
-                                 can_edit = False, unapplied = True,
-                                 bottom = patch.get_bottom(),
-                                 top = patch.get_top(),
-                                 author_name = patch.get_authname(),
-                                 author_email = patch.get_authemail(),
-                                 author_date = patch.get_authdate())
+            newpatch = new_series.new_patch(p, message = patch.get_description(),
+                                            can_edit = False, unapplied = True,
+                                            bottom = patch.get_bottom(),
+                                            top = patch.get_top(),
+                                            author_name = patch.get_authname(),
+                                            author_email = patch.get_authemail(),
+                                            author_date = patch.get_authdate())
+            if patch.get_log():
+                print "setting log to %s" %  patch.get_log()
+                newpatch.set_log(patch.get_log())
+            else:
+                print "no log for %s" % p
 
         # fast forward the cloned series to self's top
         new_series.forward_patches(applied)
@@ -650,8 +655,6 @@ class Series(StgitObject):
                 os.remove(self.__unapplied_file)
             if os.path.exists(self.__hidden_file):
                 os.remove(self.__hidden_file)
-            if os.path.exists(self.__current_file):
-                os.remove(self.__current_file)
             if os.path.exists(self.__descr_file):
                 os.remove(self.__descr_file)
             if os.path.exists(self._dir()+'/orig-base'):
@@ -825,14 +828,13 @@ class Series(StgitObject):
             self.log_patch(patch, 'new')
 
             insert_string(self.__applied_file, patch.get_name())
-            if not self.get_current():
-                self.__set_current(name)
         else:
             append_string(self.__applied_file, patch.get_name())
-            self.__set_current(name)
             if refresh:
                 self.refresh_patch(cache_update = False, log = 'new')
 
+        return patch
+
     def delete_patch(self, name):
         """Deletes a patch
         """
@@ -936,8 +938,6 @@ class Series(StgitObject):
         f.writelines([line + '\n' for line in unapplied])
         f.close()
 
-        self.__set_current(name)
-
         return forwarded
 
     def merged_patches(self, names):
@@ -1019,8 +1019,6 @@ class Series(StgitObject):
         f.writelines([line + '\n' for line in unapplied])
         f.close()
 
-        self.__set_current(name)
-
         # head == bottom case doesn't need to refresh the patch
         if empty or head != bottom:
             if not ex:
@@ -1073,12 +1071,13 @@ class Series(StgitObject):
 
         patch = Patch(name, self.__patch_dir, self.__refs_dir)
 
-        # only keep the local changes
-        if keep and not git.apply_diff(git.get_head(), patch.get_bottom()):
-            raise StackException, \
-                  'Failed to pop patches while preserving the local changes'
-
-        git.switch(patch.get_bottom(), keep)
+        if git.get_head_file() == self.get_branch():
+            if keep and not git.apply_diff(git.get_head(), patch.get_bottom()):
+                raise StackException(
+                    'Failed to pop patches while preserving the local changes')
+            git.switch(patch.get_bottom(), keep)
+        else:
+            git.set_branch(self.get_branch(), patch.get_bottom())
 
         # save the new applied list
         idx = applied.index(name) + 1
@@ -1098,11 +1097,6 @@ class Series(StgitObject):
         f.writelines([line + '\n' for line in applied])
         f.close()
 
-        if applied == []:
-            self.__set_current(None)
-        else:
-            self.__set_current(applied[-1])
-
     def empty_patch(self, name):
         """Returns True if the patch is empty
         """
@@ -1144,8 +1138,6 @@ class Series(StgitObject):
             f.close()
         elif oldname in applied:
             Patch(oldname, self.__patch_dir, self.__refs_dir).rename(newname)
-            if oldname == self.get_current():
-                self.__set_current(newname)
 
             applied[applied.index(oldname)] = newname