Handle branch names with slashes
[stgit] / stgit / commands / common.py
index c6ca514..9b97eb6 100644 (file)
@@ -18,7 +18,7 @@ along with this program; if not, write to the Free Software
 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 """
 
-import sys, os, re
+import sys, os, os.path, re
 from optparse import OptionParser, make_option
 
 from stgit.utils import *
@@ -34,54 +34,74 @@ class CmdException(Exception):
 
 
 # Utility functions
+class RevParseException(Exception):
+    """Revision spec parse error."""
+    pass
+
+def parse_rev(rev):
+    """Parse a revision specification into its
+    patchname@branchname//patch_id parts. If no branch name has a slash
+    in it, also accept / instead of //."""
+    files, dirs = list_files_and_dirs(os.path.join(basedir.get(),
+                                                   'refs', 'heads'))
+    if len(dirs) != 0:
+        # We have branch names with / in them.
+        branch_chars = r'[^@]'
+        patch_id_mark = r'//'
+    else:
+        # No / in branch names.
+        branch_chars = r'[^@/]'
+        patch_id_mark = r'(/|//)'
+    patch_re = r'(?P<patch>[^@/]+)'
+    branch_re = r'@(?P<branch>%s+)' % branch_chars
+    patch_id_re = r'%s(?P<patch_id>[a-z.]*)' % patch_id_mark
+
+    # Try //patch_id.
+    m = re.match(r'^%s$' % patch_id_re, rev)
+    if m:
+        return None, None, m.group('patch_id')
+
+    # Try patch[@branch]//patch_id.
+    m = re.match(r'^%s(%s)?%s$' % (patch_re, branch_re, patch_id_re), rev)
+    if m:
+        return m.group('patch'), m.group('branch'), m.group('patch_id')
+
+    # Try patch[@branch].
+    m = re.match(r'^%s(%s)?$' % (patch_re, branch_re), rev)
+    if m:
+        return m.group('patch'), m.group('branch'), None
+
+    # No, we can't parse that.
+    raise RevParseException
+
 def git_id(rev):
     """Return the GIT id
     """
     if not rev:
         return None
-    
-    rev_list = rev.split('/')
-    if len(rev_list) == 2:
-        patch_id = rev_list[1]
-        if not patch_id:
-            patch_id = 'top'
-    elif len(rev_list) == 1:
-        patch_id = 'top'
-    else:
-        patch_id = None
-
-    patch_branch = rev_list[0].split('@')
-    if len(patch_branch) == 1:
-        series = crt_series
-    elif len(patch_branch) == 2:
-        series = stack.Series(patch_branch[1])
-    else:
-        raise CmdException, 'Unknown id: %s' % rev
-
-    patch_name = patch_branch[0]
-    if not patch_name:
-        patch_name = series.get_current()
-        if not patch_name:
-            raise CmdException, 'No patches applied'
-
-    # patch
-    if patch_name in series.get_applied() \
-           or patch_name in series.get_unapplied():
-        if patch_id == 'top':
-            return series.get_patch(patch_name).get_top()
-        elif patch_id == 'bottom':
-            return series.get_patch(patch_name).get_bottom()
-        # Note we can return None here.
-        elif patch_id == 'top.old':
-            return series.get_patch(patch_name).get_old_top()
-        elif patch_id == 'bottom.old':
-            return series.get_patch(patch_name).get_old_bottom()
-
-    # base
-    if patch_name == 'base' and len(rev_list) == 1:
-        return read_string(series.get_base_file())
-
-    # anything else failed
+    try:
+        patch, branch, patch_id = parse_rev(rev)
+        if branch == None:
+            series = crt_series
+        else:
+            series = stack.Series(branch)
+        if patch == None:
+            patch = series.get_current()
+            if not patch:
+                raise CmdException, 'No patches applied'
+        if patch in series.get_applied() or patch in series.get_unapplied():
+            if patch_id in ['top', '', None]:
+                return series.get_patch(patch).get_top()
+            elif patch_id == 'bottom':
+                return series.get_patch(patch).get_bottom()
+            elif patch_id == 'top.old':
+                return series.get_patch(patch).get_old_top()
+            elif patch_id == 'bottom.old':
+                return series.get_patch(patch).get_old_bottom()
+        if patch == 'base' and patch_id == None:
+            return read_string(series.get_base_file())
+    except RevParseException:
+        pass
     return git.rev_parse(rev + '^{commit}')
 
 def check_local_changes():