Refactor message printing
[stgit] / stgit / utils.py
index d7d4777..ad9b1f1 100644 (file)
@@ -1,7 +1,7 @@
 """Common utility functions
 """
 
-import errno, os, os.path, sys
+import errno, os, os.path, re, sys
 from stgit.config import config
 
 __copyright__ = """
@@ -21,6 +21,86 @@ along with this program; if not, write to the Free Software
 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 """
 
+class MessagePrinter(object):
+    """Print user-facing messages, with nesting for long-running
+    operations.
+
+    Informational and progress messages are shown only when stdout is
+    a terminal. Warnings and errors are always shown: on stdout when it
+    is a terminal, otherwise on stderr, so that they are not lost when
+    stdout is piped or redirected."""
+    def __init__(self):
+        class Output(object):
+            """Line-oriented writer that remembers whether the cursor
+            is at the start of a line and indents by nesting level."""
+            def __init__(self, write, flush):
+                self.write = write
+                self.flush = flush
+                self.at_start_of_line = True
+                self.level = 0
+            def new_line(self):
+                """Ensure that we're at the beginning of a line."""
+                if not self.at_start_of_line:
+                    self.write('\n')
+                    self.at_start_of_line = True
+            def single_line(self, msg, print_newline = True,
+                            need_newline = True):
+                """Write a single line, indented by the current level.
+                Newline before and after are separately configurable."""
+                if need_newline:
+                    self.new_line()
+                if self.at_start_of_line:
+                    self.write('  '*self.level)
+                self.write(msg)
+                if print_newline:
+                    self.write('\n')
+                    self.at_start_of_line = True
+                else:
+                    # The line is left open (e.g. "foo ... " waiting for
+                    # "done"), so flush to make it visible right away.
+                    self.flush()
+                    self.at_start_of_line = False
+            def tagged_lines(self, tag, lines):
+                """Write each line on its own line, prefixing the first
+                with 'tag: ' and aligning the rest underneath it."""
+                tag += ': '
+                for line in lines:
+                    self.single_line(tag + line)
+                    tag = ' '*len(tag)
+            def write_line(self, line):
+                """Write one line of text on a line of its own, not
+                indented."""
+                self.new_line()
+                self.write('%s\n' % line)
+                self.at_start_of_line = True
+            def write_raw(self, string):
+                """Write an arbitrary string, possibly containing
+                newlines."""
+                self.new_line()
+                self.write(string)
+                if string:
+                    # An empty string leaves us where new_line() put us,
+                    # i.e. at the start of a line.
+                    self.at_start_of_line = string.endswith('\n')
+        self.__stdout = Output(sys.stdout.write, sys.stdout.flush)
+        if sys.stdout.isatty():
+            self.__out = self.__stdout
+            self.__err = self.__stdout
+        else:
+            # Not a terminal: drop informational/progress chatter, but
+            # keep warnings and errors visible on stderr.
+            self.__out = Output(lambda msg: None, lambda: None)
+            self.__err = Output(sys.stderr.write, sys.stderr.flush)
+    def stdout(self, line):
+        """Write a line to stdout."""
+        self.__stdout.write_line(line)
+    def stdout_raw(self, string):
+        """Write a string possibly containing newlines to stdout."""
+        self.__stdout.write_raw(string)
+    def info(self, *msgs):
+        """Print each message on its own (indented) line."""
+        for msg in msgs:
+            self.__out.single_line(msg)
+    def note(self, *msgs):
+        """Print the messages tagged with 'Notice: '."""
+        self.__out.tagged_lines('Notice', msgs)
+    def warn(self, *msgs):
+        """Print the messages tagged with 'Warning: '."""
+        self.__err.tagged_lines('Warning', msgs)
+    def error(self, *msgs):
+        """Print the messages tagged with 'Error: '."""
+        self.__err.tagged_lines('Error', msgs)
+    def start(self, msg):
+        """Start a long-running operation: print 'msg ... ' without a
+        newline and indent subsequent messages one level deeper."""
+        self.__out.single_line('%s ... ' % msg, print_newline = False)
+        self.__out.level += 1
+    def done(self, extramsg = None):
+        """Finish the long-running operation begun by start(), printing
+        'done' or 'done (extramsg)'."""
+        self.__out.level -= 1
+        if extramsg:
+            msg = 'done (%s)' % extramsg
+        else:
+            msg = 'done'
+        self.__out.single_line(msg, need_newline = False)
+
+out = MessagePrinter()
+
 def mkdir_file(filename, mode):
     """Opens filename with the given mode, creating the directory it's
     in if it doesn't already exist."""
@@ -116,24 +196,16 @@ def strip_suffix(suffix, string):
     assert string.endswith(suffix)
     return string[:-len(suffix)]
 
-def remove_dirs(basedir, dirs):
-    """Starting at join(basedir, dirs), remove the directory if empty,
-    and try the same with its parent, until we find a nonempty
-    directory or reach basedir."""
-    path = dirs
-    while path:
-        try:
-            os.rmdir(os.path.join(basedir, path))
-        except OSError:
-            return # can't remove nonempty directory
-        path = os.path.dirname(path)
-
 def remove_file_and_dirs(basedir, file):
     """Remove join(basedir, file), and then remove the directory it
     was in if empty, and try the same with its parent, until we find a
     nonempty directory or reach basedir."""
     os.remove(os.path.join(basedir, file))
-    remove_dirs(basedir, os.path.dirname(file))
+    try:
+        os.removedirs(os.path.join(basedir, os.path.dirname(file)))
+    except OSError:
+        # file's parent dir may not be empty after removal
+        pass
 
 def create_dirs(directory):
     """Create the given directory, if the path doesn't already exist."""
@@ -152,7 +224,14 @@ def rename(basedir, file1, file2):
     full_file2 = os.path.join(basedir, file2)
     create_dirs(os.path.dirname(full_file2))
     os.rename(os.path.join(basedir, file1), full_file2)
-    remove_dirs(basedir, os.path.dirname(file1))
+    try:
+        os.removedirs(os.path.join(basedir, os.path.dirname(file1)))
+    except OSError:
+        # file1's parent dir may not be empty after move
+        pass
+
+class EditorException(Exception):
+    """Raised by call_editor() when the editor command exits with a
+    nonzero status."""
 
 def call_editor(filename):
     """Run the editor on the specified filename."""
@@ -167,9 +246,32 @@ def call_editor(filename):
         editor = 'vi'
     editor += ' %s' % filename
+    # NOTE(review): filename is interpolated into a shell command line
+    # unquoted (run via os.system below), so paths containing spaces or
+    # shell metacharacters will not work as-is.
 
-    print 'Invoking the editor: "%s"...' % editor,
-    sys.stdout.flush()
+    out.start('Invoking the editor: "%s"' % editor)
+    # NOTE(review): on failure EditorException is raised before
+    # out.done(), so out's nesting level stays incremented and the
+    # "Invoking the editor" line is left open; confirm callers treat
+    # EditorException as fatal.
     err = os.system(editor)
     if err:
-        raise Exception, 'editor failed, exit code: %d' % err
-    print 'done'
+        raise EditorException, 'editor failed, exit code: %d' % err
+    out.done()
+
+def patch_name_from_msg(msg, max_len = 30):
+    """Return a string to be used as a patch name, or None if msg is
+    empty.
+
+    The name is generated from the first non-blank line of msg:
+    lower-cased, with every run of non-word characters replaced by a
+    single '-', without leading or trailing '-', and at most max_len
+    (default 30) characters long. The result is an empty string if
+    that line contains no word characters."""
+    if not msg:
+        return None
+
+    # Skip leading blank lines so a message that starts with a newline
+    # still yields a name from its first real line.
+    subject_line = msg.lstrip().split('\n', 1)[0].lower()
+    name = re.sub(r'\W+', '-', subject_line).strip('-')
+    # Strip again after truncating: the cut may land right after a '-'.
+    return name[:max_len].rstrip('-')
+
+def make_patch_name(msg, unacceptable, default_name = 'patch'):
+    """Derive a patch name from the commit message msg, falling back to
+    default_name when no name can be derived from it. If
+    unacceptable(name) is true, try name-0, name-1, ... and return the
+    first candidate for which unacceptable() is false."""
+    base = patch_name_from_msg(msg) or default_name
+    if not unacceptable(base):
+        return base
+    suffix = 0
+    candidate = '%s-%d' % (base, suffix)
+    while unacceptable(candidate):
+        suffix += 1
+        candidate = '%s-%d' % (base, suffix)
+    return candidate