X-Git-Url: https://git.distorted.org.uk/~mdw/stgit/blobdiff_plain/37a4d1bfabaca3dd947799f10c6cfc369a9edb5e..d7fade4be437ade1eaf92ca29cc99b566b0e3c18:/stgit/commands/common.py diff --git a/stgit/commands/common.py b/stgit/commands/common.py index c6d4535..c6ca514 100644 --- a/stgit/commands/common.py +++ b/stgit/commands/common.py @@ -1,4 +1,4 @@ -"""Function/variables commmon to all the commands +"""Function/variables common to all the commands """ __copyright__ = """ @@ -22,7 +22,10 @@ import sys, os, re from optparse import OptionParser, make_option from stgit.utils import * -from stgit import stack, git +from stgit import stack, git, basedir +from stgit.config import config, file_extensions + +crt_series = None # Command exception class @@ -31,50 +34,55 @@ class CmdException(Exception): # Utility functions -def git_id(string, strict = False): +def git_id(rev): """Return the GIT id """ - if not string: + if not rev: return None - string_list = string.split('/') - - if len(string_list) == 1: - patch_name = None - git_id = string_list[0] - - if git_id == 'HEAD': - return git.get_head() - if git_id == 'base': - return read_string(crt_series.get_base_file()) - - for path in [os.path.join(git.base_dir, 'refs', 'heads'), - os.path.join(git.base_dir, 'refs', 'tags')]: - id_file = os.path.join(path, git_id) - if os.path.isfile(id_file): - return read_string(id_file) - - # maybe GIT knows more about this id - if not strict: - return git_id - elif len(string_list) == 2: - patch_name = string_list[0] - if patch_name == '': - patch_name = crt_series.get_current() - git_id = string_list[1] + rev_list = rev.split('/') + if len(rev_list) == 2: + patch_id = rev_list[1] + if not patch_id: + patch_id = 'top' + elif len(rev_list) == 1: + patch_id = 'top' + else: + patch_id = None + + patch_branch = rev_list[0].split('@') + if len(patch_branch) == 1: + series = crt_series + elif len(patch_branch) == 2: + series = stack.Series(patch_branch[1]) + else: + raise CmdException, 'Unknown id: %s' % rev + patch_name = patch_branch[0] + if not patch_name: + patch_name = series.get_current() if not patch_name: raise CmdException, 'No patches applied' - elif not (patch_name in crt_series.get_applied() - + crt_series.get_unapplied()): - raise CmdException, 'Unknown patch "%s"' % patch_name - if git_id == 'bottom': - return crt_series.get_patch(patch_name).get_bottom() - if git_id == 'top': - return crt_series.get_patch(patch_name).get_top() - - raise CmdException, 'Unknown id: %s' % string + # patch + if patch_name in series.get_applied() \ + or patch_name in series.get_unapplied(): + if patch_id == 'top': + return series.get_patch(patch_name).get_top() + elif patch_id == 'bottom': + return series.get_patch(patch_name).get_bottom() + # Note we can return None here. + elif patch_id == 'top.old': + return series.get_patch(patch_name).get_old_top() + elif patch_id == 'bottom.old': + return series.get_patch(patch_name).get_old_bottom() + + # base + if patch_name == 'base' and len(rev_list) == 1: + return read_string(series.get_base_file()) + + # anything else failed + return git.rev_parse(rev + '^{commit}') def check_local_changes(): if git.local_changes(): @@ -85,50 +93,108 @@ def check_head_top_equal(): if not crt_series.head_top_equal(): raise CmdException, \ 'HEAD and top are not the same. You probably committed\n' \ - ' changes to the tree ouside of StGIT. If you know what you\n' \ + ' changes to the tree outside of StGIT. If you know what you\n' \ ' are doing, use the "refresh -f" command' def check_conflicts(): - if os.path.exists(os.path.join(git.base_dir, 'conflicts')): + if os.path.exists(os.path.join(basedir.get(), 'conflicts')): raise CmdException, 'Unsolved conflicts. Please resolve them first' -def print_crt_patch(): - patch = crt_series.get_current() +def print_crt_patch(branch = None): + if not branch: + patch = crt_series.get_current() + else: + patch = stack.Series(branch).get_current() + if patch: print 'Now at patch "%s"' % patch else: print 'No patches applied' -def resolved(filename): +def resolved(filename, reset = None): + if reset: + reset_file = filename + file_extensions()[reset] + if os.path.isfile(reset_file): + if os.path.isfile(filename): + os.remove(filename) + os.rename(reset_file, filename) + git.update_cache([filename], force = True) - for ext in ['.local', '.older', '.remote']: + + for ext in file_extensions().values(): fn = filename + ext if os.path.isfile(fn): os.remove(fn) -def resolved_all(): +def resolved_all(reset = None): conflicts = git.get_conflicts() if conflicts: for filename in conflicts: - resolved(filename) - os.remove(os.path.join(git.base_dir, 'conflicts')) + resolved(filename, reset) + os.remove(os.path.join(basedir.get(), 'conflicts')) + +def push_patches(patches, check_merged = False): + """Push multiple patches onto the stack. This function is shared + between the push and pull commands + """ + forwarded = crt_series.forward_patches(patches) + if forwarded > 1: + print 'Fast-forwarded patches "%s" - "%s"' % (patches[0], + patches[forwarded - 1]) + elif forwarded == 1: + print 'Fast-forwarded patch "%s"' % patches[0] + + names = patches[forwarded:] -def name_email(string): + # check for patches merged upstream + if check_merged: + print 'Checking for patches merged upstream...', + sys.stdout.flush() + + merged = crt_series.merged_patches(names) + + print 'done (%d found)' % len(merged) + else: + merged = [] + + for p in names: + print 'Pushing patch "%s"...' % p, + sys.stdout.flush() + + if p in merged: + crt_series.push_patch(p, empty = True) + print 'done (merged upstream)' + else: + modified = crt_series.push_patch(p) + + if crt_series.empty_patch(p): + print 'done (empty patch)' + elif modified: + print 'done (modified)' + else: + print 'done' + +def name_email(address): """Return a tuple consisting of the name and email parsed from a - standard 'name ' string + standard 'name ' or 'email (name)' string """ - str_list = re.findall('^(.*)\s+<(.*)>$', string) + address = re.sub('[\\\\"]', '\\\\\g<0>', address) + str_list = re.findall('^(.*)\s*<(.*)>\s*$', address) if not str_list: - raise CmdException, 'Incorrect "name " string: %s' % string + str_list = re.findall('^(.*)\s*\((.*)\)\s*$', address) + if not str_list: + raise CmdException, 'Incorrect "name "/"email (name)" string: %s' % address + return ( str_list[0][1], str_list[0][0] ) return str_list[0] -def name_email_date(string): +def name_email_date(address): """Return a tuple consisting of the name, email and date parsed from a 'name date' string """ - str_list = re.findall('^(.*)\s+<(.*)>\s+(.*)$', string) + address = re.sub('[\\\\"]', '\\\\\g<0>', address) + str_list = re.findall('^(.*)\s*<(.*)>\s*(.*)\s*$', address) if not str_list: - raise CmdException, 'Incorrect "name date" string: %s' % string + raise CmdException, 'Incorrect "name date" string: %s' % address return str_list[0]