Let the caller supply the diff text to diffstat()
[stgit] / stgit / commands / diff.py
index aeca4ab..fd6be34 100644 (file)
@@ -27,7 +27,7 @@ from stgit import stack, git
 
 
 help = 'show the tree diff'
-usage = """%prog [options] [<files...>]
+usage = """%prog [options] [<files or dirs>]
 
 Show the diff (default) or diffstat between the current working copy
 or a tree-ish object and another tree-ish object. File names can also
@@ -42,19 +42,21 @@ rev = '([patch][//[bottom | top]]) | <tree-ish> | base'
 If neither bottom nor top are given but a '//' is present, the command
 shows the specified patch (defaulting to the current one)."""
 
+directory = DirectoryHasRepository()  # func() calls directory.cd_to_topdir() after resolving file args
 options = [make_option('-r', '--range',
                        metavar = 'rev1[..[rev2]]', dest = 'revs',
                        help = 'show the diff between revisions'),
-           make_option('-O', '--diff-opts',
-                       help = 'options to pass to git-diff'),
            make_option('-s', '--stat',
                        help = 'show the stat instead of the diff',
-                       action = 'store_true')]
-
+                       action = 'store_true')
+           ] + make_diff_opts_option()  # presumably supplies the old -O flag; func() reads options.diff_flags
 
 def func(parser, options, args):
     """Show the tree diff
     """
+    args = git.ls_files(args)
+    directory.cd_to_topdir()
+
     if options.revs:
         rev_list = options.revs.split('..')
         rev_list_len = len(rev_list)
@@ -79,15 +81,11 @@ def func(parser, options, args):
         rev1 = 'HEAD'
         rev2 = None
 
-    if options.diff_opts:
-        diff_flags = options.diff_opts.split()
-    else:
-        diff_flags = []
-
+    diff_str = git.diff(args, git_id(crt_series, rev1),
+                        git_id(crt_series, rev2),
+                        diff_flags = options.diff_flags)
     if options.stat:
-        out.stdout_raw(git.diffstat(args, git_id(rev1), git_id(rev2)) + '\n')
+        out.stdout_raw(git.diffstat(diff_str) + '\n')
     else:
-        diff_str = git.diff(args, git_id(rev1), git_id(rev2),
-                            diff_flags = diff_flags )
         if diff_str:
             pager(diff_str)