Ask git for author and committer name
[stgit] / stgit / stack.py
index 26a2fc5..4df306a 100644 (file)
@@ -238,53 +238,31 @@ class Patch:
         return self.__get_field('authname')
 
     def set_authname(self, name):
-        if not name:
-            if config.has_option('stgit', 'authname'):
-                name = config.get('stgit', 'authname')
-            elif 'GIT_AUTHOR_NAME' in os.environ:
-                name = os.environ['GIT_AUTHOR_NAME']
-        self.__set_field('authname', name)
+        self.__set_field('authname', name or git.author().name)
 
     def get_authemail(self):
         return self.__get_field('authemail')
 
-    def set_authemail(self, address):
-        if not address:
-            if config.has_option('stgit', 'authemail'):
-                address = config.get('stgit', 'authemail')
-            elif 'GIT_AUTHOR_EMAIL' in os.environ:
-                address = os.environ['GIT_AUTHOR_EMAIL']
-        self.__set_field('authemail', address)
+    def set_authemail(self, email):
+        self.__set_field('authemail', email or git.author().email)
 
     def get_authdate(self):
         return self.__get_field('authdate')
 
     def set_authdate(self, date):
-        if not date and 'GIT_AUTHOR_DATE' in os.environ:
-            date = os.environ['GIT_AUTHOR_DATE']
-        self.__set_field('authdate', date)
+        self.__set_field('authdate', date or git.author().date)
 
     def get_commname(self):
         return self.__get_field('commname')
 
     def set_commname(self, name):
-        if not name:
-            if config.has_option('stgit', 'commname'):
-                name = config.get('stgit', 'commname')
-            elif 'GIT_COMMITTER_NAME' in os.environ:
-                name = os.environ['GIT_COMMITTER_NAME']
-        self.__set_field('commname', name)
+        self.__set_field('commname', name or git.committer().name)
 
     def get_commemail(self):
         return self.__get_field('commemail')
 
-    def set_commemail(self, address):
-        if not address:
-            if config.has_option('stgit', 'commemail'):
-                address = config.get('stgit', 'commemail')
-            elif 'GIT_COMMITTER_EMAIL' in os.environ:
-                address = os.environ['GIT_COMMITTER_EMAIL']
-        self.__set_field('commemail', address)
+    def set_commemail(self, email):
+        self.__set_field('commemail', email or git.committer().email)
 
     def get_log(self):
         return self.__get_field('log')
@@ -332,6 +310,11 @@ class Series:
             for patch in self.get_applied() + self.get_unapplied():
                 self.get_patch(patch).update_top_ref()
 
+        # trash directory (delete_patch saves the removed patch's top there)
+        self.__trash_dir = os.path.join(self.__series_dir, 'trash')
+        if self.is_initialised() and not os.path.isdir(self.__trash_dir):
+            os.makedirs(self.__trash_dir)
+
     def get_branch(self):
         """Return the branch name for the Series object
         """
@@ -350,9 +333,17 @@ class Series:
         """
         return Patch(name, self.__patch_dir, self.__refs_dir)
 
+    def get_current_patch(self):
+        """Return a Patch object representing the topmost patch, or
+        None if there is no such patch."""
+        crt = self.get_current()
+        if not crt:
+            return None
+        return Patch(crt, self.__patch_dir, self.__refs_dir)
+
     def get_current(self):
-        """Return a Patch object representing the topmost patch
-        """
+        """Return the name of the topmost patch, or None if there is
+        no such patch."""
         if os.path.isfile(self.__current_file):
             name = read_string(self.__current_file)
         else:
@@ -404,16 +395,21 @@ class Series:
     def __patch_is_current(self, patch):
         return patch.get_name() == read_string(self.__current_file)
 
-    def __patch_applied(self, name):
+    def patch_applied(self, name):
         """Return true if the patch exists in the applied list
         """
         return name in self.get_applied()
 
-    def __patch_unapplied(self, name):
+    def patch_unapplied(self, name):
         """Return true if the patch exists in the unapplied list
         """
         return name in self.get_unapplied()
 
+    def patch_exists(self, name):
+        """Return true if there is a patch with the given name, false
+        otherwise."""
+        return self.patch_applied(name) or self.patch_unapplied(name)
+
     def __begin_stack_check(self):
         """Save the current HEAD into .git/refs/heads/base if the stack
         is empty
@@ -433,12 +429,11 @@ class Series:
     def head_top_equal(self):
         """Return true if the head and the top are the same
         """
-        crt = self.get_current()
+        crt = self.get_current_patch()
         if not crt:
             # we don't care, no patches applied
             return True
-        return git.get_head() == Patch(crt, self.__patch_dir,
-                                       self.__refs_dir).get_top()
+        return git.get_head() == crt.get_top()
 
     def is_initialised(self):
         """Checks if series is already initialised
@@ -568,6 +563,11 @@ class Series:
             for p in patches:
                 Patch(p, self.__patch_dir, self.__refs_dir).delete()
 
+            # empty and remove the trash directory (listdir gives bare names)
+            for fname in os.listdir(self.__trash_dir):
+                os.remove(os.path.join(self.__trash_dir, fname))
+            os.rmdir(self.__trash_dir)
+
             if os.path.exists(self.__applied_file):
                 os.remove(self.__applied_file)
             if os.path.exists(self.__unapplied_file):
@@ -688,10 +688,10 @@ class Series:
                   top = None, bottom = None,
                   author_name = None, author_email = None, author_date = None,
                   committer_name = None, committer_email = None,
-                  before_existing = False):
+                  before_existing = False, refresh = True):
         """Creates a new patch
         """
-        if self.__patch_applied(name) or self.__patch_unapplied(name):
+        if self.patch_applied(name) or self.patch_unapplied(name):
             raise StackException, 'Patch "%s" already exists' % name
 
         if not message and can_edit:
@@ -741,8 +741,8 @@ class Series:
         else:
             append_string(self.__applied_file, patch.get_name())
             self.__set_current(name)
-
-            self.refresh_patch(cache_update = False, log = 'new')
+            if refresh:
+                self.refresh_patch(cache_update = False, log = 'new')
 
     def delete_patch(self, name):
         """Deletes a patch
@@ -751,12 +751,15 @@ class Series:
 
         if self.__patch_is_current(patch):
             self.pop_patch(name)
-        elif self.__patch_applied(name):
+        elif self.patch_applied(name):
             raise StackException, 'Cannot remove an applied patch, "%s", ' \
                   'which is not current' % name
         elif not name in self.get_unapplied():
             raise StackException, 'Unknown patch "%s"' % name
 
+        # save the top commit id to a trash file named after the patch
+        write_string(os.path.join(self.__trash_dir, name), patch.get_top())
+
         patch.delete()
 
         unapplied = self.get_unapplied()
@@ -975,6 +978,11 @@ class Series:
 
         patch = Patch(name, self.__patch_dir, self.__refs_dir)
 
+        # keep only the local changes; abort the pop if apply_diff fails
+        if keep and not git.apply_diff(git.get_head(), patch.get_bottom()):
+            raise StackException, \
+                  'Failed to pop patches while preserving the local changes'
+
         git.switch(patch.get_bottom(), keep)
 
         # save the new applied list