Fix the parse_patches() function to work with tuples
[stgit] / stgit / commands / common.py
index 1ea6025..029ec65 100644 (file)
@@ -21,20 +21,20 @@ Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 import sys, os, os.path, re
 from optparse import OptionParser, make_option
 
+from stgit.exception import *
 from stgit.utils import *
+from stgit.out import *
+from stgit.run import *
 from stgit import stack, git, basedir
 from stgit.config import config, file_extensions
-
-crt_series = None
-
+from stgit.lib import stack as libstack
 
 # Command exception class
-class CmdException(Exception):
+class CmdException(StgException):
     pass
 
-
 # Utility functions
-class RevParseException(Exception):
+class RevParseException(StgException):
     """Revision spec parse error."""
     pass
 
@@ -42,9 +42,7 @@ def parse_rev(rev):
     """Parse a revision specification into its
     patchname@branchname//patch_id parts. If no branch name has a slash
     in it, also accept / instead of //."""
-    files, dirs = list_files_and_dirs(os.path.join(basedir.get(),
-                                                   'refs', 'heads'))
-    if len(dirs) != 0:
+    if '/' in ''.join(git.get_heads()):
         # We have branch names with / in them.
         branch_chars = r'[^@]'
         patch_id_mark = r'//'
@@ -74,11 +72,19 @@ def parse_rev(rev):
     # No, we can't parse that.
     raise RevParseException
 
-def git_id(rev):
+def git_id(crt_series, rev):
     """Return the GIT id
     """
     if not rev:
         return None
+
+    # try a GIT revision first
+    try:
+        return git.rev_parse(rev + '^{commit}')
+    except git.GitException:
+        pass
+
+    # try an StGIT patch name
     try:
         patch, branch, patch_id = parse_rev(rev)
         if branch == None:
@@ -89,7 +95,8 @@ def git_id(rev):
             patch = series.get_current()
             if not patch:
                 raise CmdException, 'No patches applied'
-        if patch in series.get_applied() or patch in series.get_unapplied():
+        if patch in series.get_applied() or patch in series.get_unapplied() or \
+               patch in series.get_hidden():
             if patch_id in ['top', '', None]:
                 return series.get_patch(patch).get_top()
             elif patch_id == 'bottom':
@@ -101,121 +108,106 @@ def git_id(rev):
             elif patch_id == 'log':
                 return series.get_patch(patch).get_log()
         if patch == 'base' and patch_id == None:
-            return read_string(series.get_base_file())
+            return series.get_base()
     except RevParseException:
         pass
-    return git.rev_parse(rev + '^{commit}')
+    except stack.StackException:
+        pass
+
+    raise CmdException, 'Unknown patch or revision: %s' % rev
 
 def check_local_changes():
     if git.local_changes():
-        raise CmdException, \
-              'local changes in the tree. Use "refresh" to commit them'
+        raise CmdException('local changes in the tree. Use "refresh" or'
+                           ' "status --reset"')
 
-def check_head_top_equal():
+def check_head_top_equal(crt_series):
     if not crt_series.head_top_equal():
-        raise CmdException(
-            'HEAD and top are not the same. You probably committed\n'
-            '  changes to the tree outside of StGIT. To bring them\n'
-            '  into StGIT, use the "assimilate" command')
+        raise CmdException('HEAD and top are not the same. This can happen'
+                           ' if you modify a branch with git. "stg repair'
+                           ' --help" explains more about what to do next.')
 
 def check_conflicts():
-    if os.path.exists(os.path.join(basedir.get(), 'conflicts')):
-        raise CmdException, 'Unsolved conflicts. Please resolve them first'
+    if git.get_conflicts():
+        raise CmdException('Unsolved conflicts. Please resolve them first'
+                           ' or revert the changes with "status --reset"')
 
-def print_crt_patch(branch = None):
+def print_crt_patch(crt_series, branch = None):
     if not branch:
         patch = crt_series.get_current()
     else:
         patch = stack.Series(branch).get_current()
 
     if patch:
-        print 'Now at patch "%s"' % patch
+        out.info('Now at patch "%s"' % patch)
     else:
-        print 'No patches applied'
-
-def resolved(filename, reset = None):
-    if reset:
-        reset_file = filename + file_extensions()[reset]
-        if os.path.isfile(reset_file):
-            if os.path.isfile(filename):
-                os.remove(filename)
-            os.rename(reset_file, filename)
-
-    git.update_cache([filename], force = True)
-
-    for ext in file_extensions().values():
-        fn = filename + ext
-        if os.path.isfile(fn):
-            os.remove(fn)
+        out.info('No patches applied')
 
 def resolved_all(reset = None):
     conflicts = git.get_conflicts()
-    if conflicts:
-        for filename in conflicts:
-            resolved(filename, reset)
-        os.remove(os.path.join(basedir.get(), 'conflicts'))
+    git.resolved(conflicts, reset)
 
-def push_patches(patches, check_merged = False):
+def push_patches(crt_series, patches, check_merged = False):
     """Push multiple patches onto the stack. This function is shared
     between the push and pull commands
     """
     forwarded = crt_series.forward_patches(patches)
     if forwarded > 1:
-        print 'Fast-forwarded patches "%s" - "%s"' % (patches[0],
-                                                      patches[forwarded - 1])
+        out.info('Fast-forwarded patches "%s" - "%s"'
+                 % (patches[0], patches[forwarded - 1]))
     elif forwarded == 1:
-        print 'Fast-forwarded patch "%s"' % patches[0]
+        out.info('Fast-forwarded patch "%s"' % patches[0])
 
     names = patches[forwarded:]
 
     # check for patches merged upstream
-    if check_merged:
-        print 'Checking for patches merged upstream...',
-        sys.stdout.flush()
+    if names and check_merged:
+        out.start('Checking for patches merged upstream')
 
         merged = crt_series.merged_patches(names)
 
-        print 'done (%d found)' % len(merged)
+        out.done('%d found' % len(merged))
     else:
         merged = []
 
     for p in names:
-        print 'Pushing patch "%s"...' % p,
-        sys.stdout.flush()
+        out.start('Pushing patch "%s"' % p)
 
         if p in merged:
-            crt_series.push_patch(p, empty = True)
-            print 'done (merged upstream)'
+            crt_series.push_empty_patch(p)
+            out.done('merged upstream')
         else:
             modified = crt_series.push_patch(p)
 
             if crt_series.empty_patch(p):
-                print 'done (empty patch)'
+                out.done('empty patch')
             elif modified:
-                print 'done (modified)'
+                out.done('modified')
             else:
-                print 'done'
+                out.done()
 
-def pop_patches(patches, keep = False):
+def pop_patches(crt_series, patches, keep = False):
     """Pop the patches in the list from the stack. It is assumed that
     the patches are listed in the stack reverse order.
     """
-    p = patches[-1]
-    if len(patches) == 1:
-        print 'Popping patch "%s"...' % p,
+    if len(patches) == 0:
+        out.info('Nothing to push/pop')
     else:
-        print 'Popping "%s" - "%s" patches...' % (patches[0], p),
-    sys.stdout.flush()
-
-    crt_series.pop_patch(p, keep)
-
-    print 'done'
+        p = patches[-1]
+        if len(patches) == 1:
+            out.start('Popping patch "%s"' % p)
+        else:
+            out.start('Popping patches "%s" - "%s"' % (patches[0], p))
+        crt_series.pop_patch(p, keep)
+        out.done()
 
-def parse_patches(patch_args, patch_list):
+def parse_patches(patch_args, patch_list, boundary = 0, ordered = False):
     """Parse patch_args list for patch names in patch_list and return
     a list. The names can be individual patches and/or in the
     patch1..patch2 format.
     """
+    # in case it receives a tuple
+    patch_list = list(patch_list)
     patches = []
 
     for name in patch_args:
@@ -233,12 +225,26 @@ def parse_patches(patch_args, patch_list):
             if pair[0]:
                 first = patch_list.index(pair[0])
             else:
-                first = 0
+                first = -1
             # exclusive boundary
             if pair[1]:
                 last = patch_list.index(pair[1]) + 1
             else:
-                last = len(patch_list)
+                last = -1
+
+            # only cross the boundary if explicitly asked
+            if not boundary:
+                boundary = len(patch_list)
+            if first < 0:
+                if last <= boundary:
+                    first = 0
+                else:
+                    first = boundary
+            if last < 0:
+                if first < boundary:
+                    last = boundary
+                else:
+                    last = len(patch_list)
 
             if last > first:
                 pl = patch_list[first:last]
@@ -254,39 +260,282 @@ def parse_patches(patch_args, patch_list):
 
         patches += pl
 
+    if ordered:
+        patches = [p for p in patch_list if p in patches]
+
     return patches
 
 def name_email(address):
-    """Return a tuple consisting of the name and email parsed from a
-    standard 'name <email>' or 'email (name)' string
+    p = parse_name_email(address)
+    if p:
+        return p
+    else:
+        raise CmdException('Incorrect "name <email>"/"email (name)" string: %s'
+                           % address)
+
+def name_email_date(address):
+    p = parse_name_email_date(address)
+    if p:
+        return p
+    else:
+        raise CmdException('Incorrect "name <email> date" string: %s' % address)
+
+def address_or_alias(addr_str):
+    """Return the address if it contains an e-mail address or look up
+    the aliases in the config files.
     """
-    address = re.sub('[\\\\"]', '\\\\\g<0>', address)
-    str_list = re.findall('^(.*)\s*<(.*)>\s*$', address)
-    if not str_list:
-        str_list = re.findall('^(.*)\s*\((.*)\)\s*$', address)
-        if not str_list:
-            raise CmdException, 'Incorrect "name <email>"/"email (name)" string: %s' % address
-        return ( str_list[0][1], str_list[0][0] )
+    def __address_or_alias(addr):
+        if not addr:
+            return None
+        if addr.find('@') >= 0:
+            # it's an e-mail address
+            return addr
+        alias = config.get('mail.alias.'+addr)
+        if alias:
+            # it's an alias
+            return alias
+        raise CmdException, 'unknown e-mail alias: %s' % addr
+
+    addr_list = [__address_or_alias(addr.strip())
+                 for addr in addr_str.split(',')]
+    return ', '.join([addr for addr in addr_list if addr])
+
+def prepare_rebase(crt_series):
+    # pop all patches
+    applied = crt_series.get_applied()
+    if len(applied) > 0:
+        out.start('Popping all applied patches')
+        crt_series.pop_patch(applied[0])
+        out.done()
+    return applied
+
+def rebase(crt_series, target):
+    try:
+        tree_id = git_id(crt_series, target)
+    except:
+        # it might be that we use a custom rebase command with its own
+        # target type
+        tree_id = target
+    if tree_id == git.get_head():
+        out.info('Already at "%s", no need for rebasing.' % target)
+        return
+    if target:
+        out.start('Rebasing to "%s"' % target)
+    else:
+        out.start('Rebasing to the default target')
+    git.rebase(tree_id = tree_id)
+    out.done()
+
+def post_rebase(crt_series, applied, nopush, merged):
+    # memorize that we rebased to here
+    crt_series._set_field('orig-base', git.get_head())
+    # push the patches back
+    if not nopush:
+        push_patches(crt_series, applied, merged)
+
+#
+# Patch description/e-mail/diff parsing
+#
+def __end_descr(line):
+    return re.match('---\s*$', line) or re.match('diff -', line) or \
+            re.match('Index: ', line)
+
+def __split_descr_diff(string):
+    """Return the description and the diff from the given string
+    """
+    descr = diff = ''
+    top = True
+
+    for line in string.split('\n'):
+        if top:
+            if not __end_descr(line):
+                descr += line + '\n'
+                continue
+            else:
+                top = False
+        diff += line + '\n'
 
-    return str_list[0]
+    return (descr.rstrip(), diff)
 
-def name_email_date(address):
-    """Return a tuple consisting of the name, email and date parsed
-    from a 'name <email> date' string
+def __parse_description(descr):
+    """Parse the patch description and return the new description and
+    author information (if any).
     """
-    address = re.sub('[\\\\"]', '\\\\\g<0>', address)
-    str_list = re.findall('^(.*)\s*<(.*)>\s*(.*)\s*$', address)
-    if not str_list:
-        raise CmdException, 'Incorrect "name <email> date" string: %s' % address
+    subject = body = ''
+    authname = authemail = authdate = None
+
+    descr_lines = [line.rstrip() for line in  descr.split('\n')]
+    if not descr_lines:
+        raise CmdException, "Empty patch description"
+
+    lasthdr = 0
+    end = len(descr_lines)
+
+    # Parse the patch header
+    for pos in range(0, end):
+        if not descr_lines[pos]:
+           continue
+        # check for a "From|Author:" line
+        if re.match('\s*(?:from|author):\s+', descr_lines[pos], re.I):
+            auth = re.findall('^.*?:\s+(.*)$', descr_lines[pos])[0]
+            authname, authemail = name_email(auth)
+            lasthdr = pos + 1
+            continue
+        # check for a "Date:" line
+        if re.match('\s*date:\s+', descr_lines[pos], re.I):
+            authdate = re.findall('^.*?:\s+(.*)$', descr_lines[pos])[0]
+            lasthdr = pos + 1
+            continue
+        if subject:
+            break
+        # get the subject
+        subject = descr_lines[pos]
+        lasthdr = pos + 1
+
+    # get the body
+    if lasthdr < end:
+        body = reduce(lambda x, y: x + '\n' + y, descr_lines[lasthdr:], '')
+
+    return (subject + body, authname, authemail, authdate)
+
+def parse_mail(msg):
+    """Parse the message object and return (description, authname,
+    authemail, authdate, diff)
+    """
+    from email.Header import decode_header, make_header
+
+    def __decode_header(header):
+        """Decode a qp-encoded e-mail header as per rfc2047"""
+        try:
+            words_enc = decode_header(header)
+            hobj = make_header(words_enc)
+        except Exception, ex:
+            raise CmdException, 'header decoding error: %s' % str(ex)
+        return unicode(hobj).encode('utf-8')
+
+    # parse the headers
+    if msg.has_key('from'):
+        authname, authemail = name_email(__decode_header(msg['from']))
+    else:
+        authname = authemail = None
 
-    return str_list[0]
+    # '\n\t' can be found on multi-line headers
+    descr = __decode_header(msg['subject']).replace('\n\t', ' ')
+    authdate = msg['date']
 
-def make_patch_name(msg):
-    """Return a string to be used as a patch name. This is generated
-    from the top line of the string passed as argument.
+    # remove the '[*PATCH*]' expression in the subject
+    if descr:
+        descr = re.findall('^(\[.*?[Pp][Aa][Tt][Cc][Hh].*?\])?\s*(.*)$',
+                           descr)[0][1]
+    else:
+        raise CmdException, 'Subject: line not found'
+
+    # the rest of the message
+    msg_text = ''
+    for part in msg.walk():
+        if part.get_content_type() == 'text/plain':
+            msg_text += part.get_payload(decode = True)
+
+    rem_descr, diff = __split_descr_diff(msg_text)
+    if rem_descr:
+        descr += '\n\n' + rem_descr
+
+    # parse the description for author information
+    descr, descr_authname, descr_authemail, descr_authdate = \
+           __parse_description(descr)
+    if descr_authname:
+        authname = descr_authname
+    if descr_authemail:
+        authemail = descr_authemail
+    if descr_authdate:
+       authdate = descr_authdate
+
+    return (descr, authname, authemail, authdate, diff)
+
+def parse_patch(text):
+    """Parse the input text and return (description, authname,
+    authemail, authdate, diff)
     """
-    if not msg:
-        return None
+    descr, diff = __split_descr_diff(text)
+    descr, authname, authemail, authdate = __parse_description(descr)
+
+    # we don't yet have an agreed place for the creation date.
+    # Just return None
+    return (descr, authname, authemail, authdate, diff)
+
+def readonly_constant_property(f):
+    """Decorator that converts a function that computes a value to an
+    attribute that returns the value. The value is computed only once,
+    the first time it is accessed."""
+    def new_f(self):
+        n = '__' + f.__name__
+        if not hasattr(self, n):
+            setattr(self, n, f(self))
+        return getattr(self, n)
+    return property(new_f)
+
+class DirectoryException(StgException):
+    pass
+
+class _Directory(object):
+    def __init__(self, needs_current_series = True):
+        self.needs_current_series =  needs_current_series
+    @readonly_constant_property
+    def git_dir(self):
+        try:
+            return Run('git', 'rev-parse', '--git-dir'
+                       ).discard_stderr().output_one_line()
+        except RunException:
+            raise DirectoryException('No git repository found')
+    @readonly_constant_property
+    def __topdir_path(self):
+        try:
+            lines = Run('git', 'rev-parse', '--show-cdup'
+                        ).discard_stderr().output_lines()
+            if len(lines) == 0:
+                return '.'
+            elif len(lines) == 1:
+                return lines[0]
+            else:
+                raise RunException('Too much output')
+        except RunException:
+            raise DirectoryException('No git repository found')
+    @readonly_constant_property
+    def is_inside_git_dir(self):
+        return { 'true': True, 'false': False
+                 }[Run('git', 'rev-parse', '--is-inside-git-dir'
+                       ).output_one_line()]
+    @readonly_constant_property
+    def is_inside_worktree(self):
+        return { 'true': True, 'false': False
+                 }[Run('git', 'rev-parse', '--is-inside-work-tree'
+                       ).output_one_line()]
+    def cd_to_topdir(self):
+        os.chdir(self.__topdir_path)
+
+class DirectoryAnywhere(_Directory):
+    def setup(self):
+        pass
 
-    subject_line = msg.lstrip().split('\n', 1)[0].lower()
-    return re.sub('[\W]+', '-', subject_line).strip('-')
+class DirectoryHasRepository(_Directory):
+    def setup(self):
+        self.git_dir # might throw an exception
+
+class DirectoryInWorktree(DirectoryHasRepository):
+    def setup(self):
+        DirectoryHasRepository.setup(self)
+        if not self.is_inside_worktree:
+            raise DirectoryException('Not inside a git worktree')
+
+class DirectoryGotoToplevel(DirectoryInWorktree):
+    def setup(self):
+        DirectoryInWorktree.setup(self)
+        self.cd_to_topdir()
+
+class DirectoryHasRepositoryLib(_Directory):
+    """For commands that use the new infrastructure in stgit.lib.*."""
+    def __init__(self):
+        self.needs_current_series = False
+    def setup(self):
+        # This will throw an exception if we don't have a repository.
+        self.repository = libstack.Repository.default()