Allow configurable file extensions for merge conflicts
[stgit] / stgit / commands / resolved.py
index d98ca9a..186bc73 100644 (file)
@@ -21,7 +21,8 @@ from optparse import OptionParser, make_option
 
 from stgit.commands.common import *
 from stgit.utils import *
-from stgit import stack, git
+from stgit import stack, git, basedir
+from stgit.config import file_extensions
 
 
 help = 'mark a file conflict as solved'
@@ -29,18 +30,25 @@ usage = """%prog [options] [<files...>]
 
 Mark a merge conflict as resolved. The conflicts can be seen with the
 'status' command, the corresponding files being prefixed with a
-'C'. This command also removes any <file>.{local,remote,older} files."""
+'C'. This command also removes any <file>.{ancestor,current,patched}
+files."""
 
 options = [make_option('-a', '--all',
                        help = 'mark all conflicts as solved',
-                       action = 'store_true')]
+                       action = 'store_true'),
+           make_option('-r', '--reset', metavar = '(ancestor|current|patched)',
+                       help = 'reset the file(s) to the given state')]
 
 
 def func(parser, options, args):
     """Mark the conflict as resolved
     """
+    if options.reset \
+           and options.reset not in file_extensions():
+        raise CmdException, 'Unknown reset state: %s' % options.reset
+
     if options.all:
-        resolved_all()
+        resolved_all(options.reset)
         return
 
     if len(args) == 0:
@@ -55,13 +63,13 @@ def func(parser, options, args):
             raise CmdException, 'No conflicts for "%s"' % filename
     # resolved
     for filename in args:
-        resolved(filename)
+        resolved(filename, options.reset)
         del conflicts[conflicts.index(filename)]
 
     # save or remove the conflicts file
     if conflicts == []:
-        os.remove(os.path.join(git.base_dir, 'conflicts'))
+        os.remove(os.path.join(basedir.get(), 'conflicts'))
     else:
-        f = file(os.path.join(git.base_dir, 'conflicts'), 'w+')
+        f = file(os.path.join(basedir.get(), 'conflicts'), 'w+')
         f.writelines([line + '\n' for line in conflicts])
         f.close()