Check for invalid patch names before acting
[stgit] / stgit / stack.py
index 5237084..f57e4f0 100644 (file)
@@ -18,7 +18,7 @@ along with this program; if not, write to the Free Software
 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 """
 
-import sys, os
+import sys, os, re
 
 from stgit.utils import *
 from stgit import git, basedir, templates
@@ -92,8 +92,9 @@ def edit_file(series, line, comment, show_patch = True):
     f.close()
 
     # the editor
-    if config.has_option('stgit', 'editor'):
-        editor = config.get('stgit', 'editor')
+    editor = config.get('stgit.editor')
+    if editor:
+        pass
     elif 'EDITOR' in os.environ:
         editor = os.environ['EDITOR']
     else:
@@ -327,6 +328,12 @@ class Series(StgitObject):
         if self.is_initialised() and not os.path.isdir(self.__trash_dir):
             os.makedirs(self.__trash_dir)
 
+    def __patch_name_valid(self, name):
+        """Raise an exception if the patch name is not valid.
+        """
+        if not name or re.search('[^\w.-]', name):
+            raise StackException, 'Invalid patch name: "%s"' % name
+
     def get_branch(self):
         """Return the branch name for the Series object
         """
@@ -406,6 +413,32 @@ class Series(StgitObject):
     def set_description(self, line):
         self._set_field('description', line)
 
+    def get_parent_remote(self):
+        return config.get('branch.%s.remote' % self.__name) or 'origin'  # falsy config value -> 'origin'
+
+    def __set_parent_remote(self, remote):
+        value = config.set('branch.%s.remote' % self.__name, remote)
+
+    def get_parent_branch(self):
+        value = config.get('branch.%s.merge' % self.__name)
+        if value:
+            return value
+        elif git.rev_parse('heads/origin'):
+            return 'heads/origin'
+        else:
+            raise StackException, 'Cannot find a parent branch for "%s"' % self.__name
+
+    def __set_parent_branch(self, name):
+        config.set('branch.%s.merge' % self.__name, name)  # NOTE(review): looks like git's own branch.<name>.merge key -- confirm expected ref format
+
+    def set_parent(self, remote, localbranch):
+        if localbranch:
+            self.__set_parent_branch(localbranch)
+            if remote:
+                self.__set_parent_remote(remote)
+        elif remote:
+            raise StackException, 'Remote "%s" without a branch cannot be used as parent' % remote
+
     def __patch_is_current(self, patch):
         return patch.get_name() == self.get_current()
 
@@ -459,7 +492,7 @@ class Series(StgitObject):
         """
         return os.path.isdir(self.__patch_dir)
 
-    def init(self, create_at=False):
+    def init(self, create_at=False, parent_remote=None, parent_branch=None):
         """Initialises the stgit series
         """
         bases_dir = os.path.join(self.__base_dir, 'refs', 'bases')
@@ -476,6 +509,8 @@ class Series(StgitObject):
 
         os.makedirs(self.__patch_dir)
 
+        self.set_parent(parent_remote, parent_branch)
+        
         create_dirs(bases_dir)
 
         self.create_empty_field('applied')
@@ -724,6 +759,8 @@ class Series(StgitObject):
                   before_existing = False, refresh = True):
         """Creates a new patch
         """
+        self.__patch_name_valid(name)
+
         if self.patch_applied(name) or self.patch_unapplied(name):
             raise StackException, 'Patch "%s" already exists' % name
 
@@ -780,6 +817,7 @@ class Series(StgitObject):
     def delete_patch(self, name):
         """Deletes a patch
         """
+        self.__patch_name_valid(name)
         patch = Patch(name, self.__patch_dir, self.__refs_dir)
 
         if self.__patch_is_current(patch):
@@ -1055,6 +1093,7 @@ class Series(StgitObject):
     def empty_patch(self, name):
         """Returns True if the patch is empty
         """
+        self.__patch_name_valid(name)
         patch = Patch(name, self.__patch_dir, self.__refs_dir)
         bottom = patch.get_bottom()
         top = patch.get_top()
@@ -1068,6 +1107,8 @@ class Series(StgitObject):
         return False
 
     def rename_patch(self, oldname, newname):
+        self.__patch_name_valid(newname)
+
         applied = self.get_applied()
         unapplied = self.get_unapplied()