Create a StgitObject base class to factor out the file-backed property handling shared by Patch and Series.
[stgit] / stgit / stack.py
index 4df306a..69fa03b 100644 (file)
@@ -119,27 +119,57 @@ def edit_file(series, line, comment, show_patch = True):
 # Classes
 #
 
-class Patch:
+class StgitObject:
+    """An object whose stgit properties are stored as one file each
+    in a directory (the directory is set with _set_dir)."""
+    def _set_dir(self, dir):
+        self.__dir = dir
+    def _dir(self):
+        return self.__dir
+
+    def create_empty_field(self, name):
+        create_empty_file(os.path.join(self.__dir, name))
+
+    def _get_field(self, name, multiline = False):
+        fname = os.path.join(self.__dir, name)
+        if os.path.isfile(fname):
+            line = read_string(fname, multiline)
+            if line == '':  # an empty file means the field is unset
+                return None
+            else:
+                return line
+        else:
+            return None
+
+    def _set_field(self, name, value, multiline = False):
+        fname = os.path.join(self.__dir, name)
+        if value:  # None or '' clears the field by removing its file
+            write_string(fname, value, multiline)
+        elif os.path.isfile(fname):
+            os.remove(fname)
+
+
+class Patch(StgitObject):
     """Basic patch implementation
     """
     def __init__(self, name, series_dir, refs_dir):
         self.__series_dir = series_dir
         self.__name = name
-        self.__dir = os.path.join(self.__series_dir, self.__name)
+        self._set_dir(os.path.join(self.__series_dir, self.__name))
         self.__refs_dir = refs_dir
         self.__top_ref_file = os.path.join(self.__refs_dir, self.__name)
         self.__log_ref_file = os.path.join(self.__refs_dir,
                                            self.__name + '.log')
 
     def create(self):
-        os.mkdir(self.__dir)
-        create_empty_file(os.path.join(self.__dir, 'bottom'))
-        create_empty_file(os.path.join(self.__dir, 'top'))
+        os.mkdir(self._dir())
+        self.create_empty_field('bottom')
+        self.create_empty_field('top')
 
     def delete(self):
-        for f in os.listdir(self.__dir):
-            os.remove(os.path.join(self.__dir, f))
-        os.rmdir(self.__dir)
+        for f in os.listdir(self._dir()):
+            os.remove(os.path.join(self._dir(), f))
+        os.rmdir(self._dir())
         os.remove(self.__top_ref_file)
         if os.path.exists(self.__log_ref_file):
             os.remove(self.__log_ref_file)
@@ -148,16 +178,16 @@ class Patch:
         return self.__name
 
     def rename(self, newname):
-        olddir = self.__dir
+        olddir = self._dir()
         old_top_ref_file = self.__top_ref_file
         old_log_ref_file = self.__log_ref_file
         self.__name = newname
-        self.__dir = os.path.join(self.__series_dir, self.__name)
+        self._set_dir(os.path.join(self.__series_dir, self.__name))
         self.__top_ref_file = os.path.join(self.__refs_dir, self.__name)
         self.__log_ref_file = os.path.join(self.__refs_dir,
                                            self.__name + '.log')
 
-        os.rename(olddir, self.__dir)
+        os.rename(olddir, self._dir())
         os.rename(old_top_ref_file, self.__top_ref_file)
         if os.path.exists(old_log_ref_file):
             os.rename(old_log_ref_file, self.__log_ref_file)
@@ -173,106 +203,88 @@ class Patch:
         if top:
             self.__update_top_ref(top)
 
-    def __get_field(self, name, multiline = False):
-        id_file = os.path.join(self.__dir, name)
-        if os.path.isfile(id_file):
-            line = read_string(id_file, multiline)
-            if line == '':
-                return None
-            else:
-                return line
-        else:
-            return None
-
-    def __set_field(self, name, value, multiline = False):
-        fname = os.path.join(self.__dir, name)
-        if value and value != '':
-            write_string(fname, value, multiline)
-        elif os.path.isfile(fname):
-            os.remove(fname)
-
     def get_old_bottom(self):
-        return self.__get_field('bottom.old')
+        return self._get_field('bottom.old')
 
     def get_bottom(self):
-        return self.__get_field('bottom')
+        return self._get_field('bottom')
 
     def set_bottom(self, value, backup = False):
         if backup:
-            curr = self.__get_field('bottom')
-            self.__set_field('bottom.old', curr)
-        self.__set_field('bottom', value)
+            curr = self._get_field('bottom')
+            self._set_field('bottom.old', curr)
+        self._set_field('bottom', value)
 
     def get_old_top(self):
-        return self.__get_field('top.old')
+        return self._get_field('top.old')
 
     def get_top(self):
-        return self.__get_field('top')
+        return self._get_field('top')
 
     def set_top(self, value, backup = False):
         if backup:
-            curr = self.__get_field('top')
-            self.__set_field('top.old', curr)
-        self.__set_field('top', value)
+            curr = self._get_field('top')
+            self._set_field('top.old', curr)
+        self._set_field('top', value)
         self.__update_top_ref(value)
 
     def restore_old_boundaries(self):
-        bottom = self.__get_field('bottom.old')
-        top = self.__get_field('top.old')
+        bottom = self._get_field('bottom.old')
+        top = self._get_field('top.old')
 
         if top and bottom:
-            self.__set_field('bottom', bottom)
-            self.__set_field('top', top)
+            self._set_field('bottom', bottom)
+            self._set_field('top', top)
             self.__update_top_ref(top)
             return True
         else:
             return False
 
     def get_description(self):
-        return self.__get_field('description', True)
+        return self._get_field('description', True)
 
     def set_description(self, line):
-        self.__set_field('description', line, True)
+        self._set_field('description', line, True)
 
     def get_authname(self):
-        return self.__get_field('authname')
+        return self._get_field('authname')
 
     def set_authname(self, name):
-        self.__set_field('authname', name or git.author().name)
+        self._set_field('authname', name or git.author().name)
 
     def get_authemail(self):
-        return self.__get_field('authemail')
+        return self._get_field('authemail')
 
     def set_authemail(self, email):
-        self.__set_field('authemail', email or git.author().email)
+        self._set_field('authemail', email or git.author().email)
 
     def get_authdate(self):
-        return self.__get_field('authdate')
+        return self._get_field('authdate')
 
     def set_authdate(self, date):
-        self.__set_field('authdate', date or git.author().date)
+        self._set_field('authdate', date or git.author().date)
 
     def get_commname(self):
-        return self.__get_field('commname')
+        return self._get_field('commname')
 
     def set_commname(self, name):
-        self.__set_field('commname', name or git.committer().name)
+        self._set_field('commname', name or git.committer().name)
 
     def get_commemail(self):
-        return self.__get_field('commemail')
+        return self._get_field('commemail')
 
     def set_commemail(self, email):
-        self.__set_field('commemail', email or git.committer().email)
+        self._set_field('commemail', email or git.committer().email)
 
     def get_log(self):
-        return self.__get_field('log')
+        return self._get_field('log')
 
     def set_log(self, value, backup = False):
-        self.__set_field('log', value)
+        self._set_field('log', value)
         self.__update_log_ref(value)
 
 
-class Series:
+class Series(StgitObject):
     """Class including the operations on series
     """
     def __init__(self, name = None):
@@ -287,22 +299,21 @@ class Series:
         except git.GitException, ex:
             raise StackException, 'GIT tree not initialised: %s' % ex
 
-        self.__series_dir = os.path.join(self.__base_dir, 'patches',
-                                         self.__name)
+        self._set_dir(os.path.join(self.__base_dir, 'patches', self.__name))
         self.__refs_dir = os.path.join(self.__base_dir, 'refs', 'patches',
                                        self.__name)
         self.__base_file = os.path.join(self.__base_dir, 'refs', 'bases',
                                         self.__name)
 
-        self.__applied_file = os.path.join(self.__series_dir, 'applied')
-        self.__unapplied_file = os.path.join(self.__series_dir, 'unapplied')
-        self.__current_file = os.path.join(self.__series_dir, 'current')
-        self.__descr_file = os.path.join(self.__series_dir, 'description')
+        self.__applied_file = os.path.join(self._dir(), 'applied')
+        self.__unapplied_file = os.path.join(self._dir(), 'unapplied')
+        self.__current_file = os.path.join(self._dir(), 'current')
+        self.__descr_file = os.path.join(self._dir(), 'description')
 
         # where this series keeps its patches
-        self.__patch_dir = os.path.join(self.__series_dir, 'patches')
+        self.__patch_dir = os.path.join(self._dir(), 'patches')
         if not os.path.isdir(self.__patch_dir):
-            self.__patch_dir = self.__series_dir
+            self.__patch_dir = self._dir()
 
         # if no __refs_dir, create and populate it (upgrade old repositories)
         if self.is_initialised() and not os.path.isdir(self.__refs_dir):
@@ -311,7 +322,7 @@ class Series:
                 self.get_patch(patch).update_top_ref()
 
         # trash directory
-        self.__trash_dir = os.path.join(self.__series_dir, 'trash')
+        self.__trash_dir = os.path.join(self._dir(), 'trash')
         if self.is_initialised() and not os.path.isdir(self.__trash_dir):
             os.makedirs(self.__trash_dir)
 
@@ -323,10 +334,7 @@ class Series:
     def __set_current(self, name):
         """Sets the topmost patch
         """
-        if name:
-            write_string(self.__current_file, name)
-        else:
-            create_empty_file(self.__current_file)
+        self._set_field('current', name)
 
     def get_patch(self, name):
         """Return a Patch object for the given name
@@ -344,10 +352,7 @@ class Series:
     def get_current(self):
         """Return the name of the topmost patch, or None if there is
         no such patch."""
-        if os.path.isfile(self.__current_file):
-            name = read_string(self.__current_file)
-        else:
-            return None
+        name = self._get_field('current')
         if name == '':
             return None
         else:
@@ -374,26 +379,26 @@ class Series:
         return self.__base_file
 
     def get_protected(self):
-        return os.path.isfile(os.path.join(self.__series_dir, 'protected'))
+        return os.path.isfile(os.path.join(self._dir(), 'protected'))
 
     def protect(self):
-        protect_file = os.path.join(self.__series_dir, 'protected')
+        protect_file = os.path.join(self._dir(), 'protected')
         if not os.path.isfile(protect_file):
             create_empty_file(protect_file)
 
     def unprotect(self):
-        protect_file = os.path.join(self.__series_dir, 'protected')
+        protect_file = os.path.join(self._dir(), 'protected')
         if os.path.isfile(protect_file):
             os.remove(protect_file)
 
     def get_description(self):
-        if os.path.isfile(self.__descr_file):
-            return read_string(self.__descr_file)
-        else:
-            return ''
+        return self._get_field('description')
+
+    def set_description(self, line):
+        self._set_field('description', line)
 
     def __patch_is_current(self, patch):
-        return patch.get_name() == read_string(self.__current_file)
+        return patch.get_name() == self.get_current()
 
     def patch_applied(self, name):
         """Return true if the patch exists in the applied list
@@ -459,10 +464,10 @@ class Series:
 
         create_dirs(bases_dir)
 
-        create_empty_file(self.__applied_file)
-        create_empty_file(self.__unapplied_file)
-        create_empty_file(self.__descr_file)
-        os.makedirs(os.path.join(self.__series_dir, 'patches'))
+        self.create_empty_field('applied')
+        self.create_empty_field('unapplied')
+        self.create_empty_field('description')
+        os.makedirs(os.path.join(self._dir(), 'patches'))
         os.makedirs(self.__refs_dir)
         self.__begin_stack_check()
 
@@ -471,15 +476,15 @@ class Series:
         unconvert to place the patches in the same directory with
         series control files
         """
-        if self.__patch_dir == self.__series_dir:
+        if self.__patch_dir == self._dir():
             print 'Converting old-style to new-style...',
             sys.stdout.flush()
 
-            self.__patch_dir = os.path.join(self.__series_dir, 'patches')
+            self.__patch_dir = os.path.join(self._dir(), 'patches')
             os.makedirs(self.__patch_dir)
 
             for p in self.get_applied() + self.get_unapplied():
-                src = os.path.join(self.__series_dir, p)
+                src = os.path.join(self._dir(), p)
                 dest = os.path.join(self.__patch_dir, p)
                 os.rename(src, dest)
 
@@ -491,7 +496,7 @@ class Series:
 
             for p in self.get_applied() + self.get_unapplied():
                 src = os.path.join(self.__patch_dir, p)
-                dest = os.path.join(self.__series_dir, p)
+                dest = os.path.join(self._dir(), p)
                 os.rename(src, dest)
 
             if not os.listdir(self.__patch_dir):
@@ -500,7 +505,7 @@ class Series:
             else:
                 print 'Patch directory %s is not empty.' % self.__name
 
-            self.__patch_dir = self.__series_dir
+            self.__patch_dir = self._dir()
 
     def rename(self, to_name):
         """Renames a series
@@ -514,7 +519,7 @@ class Series:
 
         git.rename_branch(self.__name, to_name)
 
-        if os.path.isdir(self.__series_dir):
+        if os.path.isdir(self._dir()):
             rename(os.path.join(self.__base_dir, 'patches'),
                    self.__name, to_stack.__name)
         if os.path.exists(self.__base_file):
@@ -534,7 +539,7 @@ class Series:
         new_series = Series(target_series)
 
         # generate an artificial description file
-        write_string(new_series.__descr_file, 'clone of "%s"' % self.__name)
+        new_series.set_description('clone of "%s"' % self.__name)
 
         # clone self's entire series as unapplied patches
         patches = self.get_applied() + self.get_unapplied()
@@ -568,6 +573,8 @@ class Series:
                 os.remove(fname)
             os.rmdir(self.__trash_dir)
 
+            # FIXME: find a way to get rid of those manual removals
+            # (move functionality to StgitObject?)
             if os.path.exists(self.__applied_file):
                 os.remove(self.__applied_file)
             if os.path.exists(self.__unapplied_file):
@@ -580,7 +587,7 @@ class Series:
                 os.rmdir(self.__patch_dir)
             else:
                 print 'Patch directory %s is not empty.' % self.__name
-            if not os.listdir(self.__series_dir):
+            if not os.listdir(self._dir()):
                 remove_dirs(os.path.join(self.__base_dir, 'patches'),
                             self.__name)
             else: