server/admin.c: Remove spurious `ping' in usage message.
[tripe] / svc / conntrack.in
index 261cf64..28e4b0b 100644 (file)
@@ -42,6 +42,7 @@ for i in ['mainloop', 'mainloop.glib']:
 try: from gi.repository import GLib as G
 except ImportError: import gobject as G
 from struct import pack, unpack
 try: from gi.repository import GLib as G
 except ImportError: import gobject as G
 from struct import pack, unpack
+from cStringIO import StringIO
 
 SM = T.svcmgr
 ##__import__('rmcr').__debug = True
 
 SM = T.svcmgr
 ##__import__('rmcr').__debug = True
@@ -54,52 +55,132 @@ class struct (object):
   def __init__(me, **kw):
     me.__dict__.update(kw)
 
   def __init__(me, **kw):
     me.__dict__.update(kw)
 
+def loadb(s):
+  n = 0
+  for ch in s: n = 256*n + ord(ch)
+  return n
+
+def storeb(n, wd = None):
+  if wd is None: wd = n.bit_length()
+  s = StringIO()
+  for i in xrange((wd - 1)&-8, -8, -8): s.write(chr((n >> i)&0xff))
+  return s.getvalue()
+
 ###--------------------------------------------------------------------------
 ### Address manipulation.
 ###--------------------------------------------------------------------------
 ### Address manipulation.
+###
+### I think this is the most demanding application, in terms of address
+### hacking, in the entire TrIPE suite.  At least we don't have to do it in
+### C.
 
 
-class InetAddress (object):
+class BaseAddress (object):
   def __init__(me, addrstr, maskstr = None):
   def __init__(me, addrstr, maskstr = None):
-    me.addr = me._addrstr_to_int(addrstr)
+    me._setaddr(addrstr)
     if maskstr is None:
       me.mask = -1
     elif maskstr.isdigit():
     if maskstr is None:
       me.mask = -1
     elif maskstr.isdigit():
-      me.mask = (1 << 32) - (1 << 32 - int(maskstr))
+      me.mask = (1 << me.NBITS) - (1 << me.NBITS - int(maskstr))
     else:
     else:
-      me.mask = me._addrstr_to_int(maskstr)
+      me._setmask(maskstr)
     if me.addr&~me.mask:
       raise ValueError('network contains bits set beyond mask')
   def _addrstr_to_int(me, addrstr):
     if me.addr&~me.mask:
       raise ValueError('network contains bits set beyond mask')
   def _addrstr_to_int(me, addrstr):
-    return unpack('>L', S.inet_aton(addrstr))[0]
+    try: return loadb(S.inet_pton(me.AF, addrstr))
+    except S.error: raise ValueError('bad address syntax')
   def _int_to_addrstr(me, n):
   def _int_to_addrstr(me, n):
-    return S.inet_ntoa(pack('>L', n))
+    return S.inet_ntop(me.AF, storeb(me.addr, me.NBITS))
+  def _setmask(me, maskstr):
+    raise ValueError('only prefix masked supported')
+  def _maskstr(me):
+    raise ValueError('only prefix masked supported')
   def sockaddr(me, port = 0):
     if me.mask != -1: raise ValueError('not a simple address')
   def sockaddr(me, port = 0):
     if me.mask != -1: raise ValueError('not a simple address')
-    return me._int_to_addrstr(me.addr), port
+    return me._sockaddr(port)
   def __str__(me):
   def __str__(me):
-    addrstr = me._int_to_addrstr(me.addr)
+    addrstr = me._addrstr()
     if me.mask == -1:
       return addrstr
     else:
     if me.mask == -1:
       return addrstr
     else:
-      inv = me.mask ^ ((1 << 32) - 1)
+      inv = me.mask ^ ((1 << me.NBITS) - 1)
       if (inv&(inv + 1)) == 0:
       if (inv&(inv + 1)) == 0:
-        return '%s/%d' % (addrstr, 32 - inv.bit_length())
+        return '%s/%d' % (addrstr, me.NBITS - inv.bit_length())
       else:
       else:
-        return '%s/%s' % (addrstr, me._int_to_addrstr(me.mask))
+        return '%s/%s' % (addrstr, me._maskstr())
   def withinp(me, net):
   def withinp(me, net):
+    if type(net) != type(me): return False
     if (me.mask&net.mask) != net.mask: return False
     if (me.addr ^ net.addr)&net.mask: return False
     if (me.mask&net.mask) != net.mask: return False
     if (me.addr ^ net.addr)&net.mask: return False
-    return True
+    return me._withinp(net)
   def eq(me, other):
   def eq(me, other):
+    if type(me) != type(other): return False
     if me.mask != other.mask: return False
     if me.addr != other.addr: return False
     if me.mask != other.mask: return False
     if me.addr != other.addr: return False
+    return me._eq(other)
+  def _withinp(me, net):
     return True
     return True
+  def _eq(me, other):
+    return True
+
+class InetAddress (BaseAddress):
+  AF = S.AF_INET
+  AFNAME = 'IPv4'
+  NBITS = 32
+  def _addrstr_to_int(me, addrstr):
+    try: return loadb(S.inet_aton(addrstr))
+    except S.error: raise ValueError('bad address syntax')
+  def _setaddr(me, addrstr):
+    me.addr = me._addrstr_to_int(addrstr)
+  def _setmask(me, maskstr):
+    me.mask = me._addrstr_to_int(maskstr)
+  def _addrstr(me):
+    return me._int_to_addrstr(me.addr)
+  def _maskstr(me):
+    return me._int_to_addrstr(me.mask)
+  def _sockaddr(me, port = 0):
+    return (me._addrstr(), port)
   @classmethod
   def from_sockaddr(cls, sa):
     addr, port = (lambda a, p: (a, p))(*sa)
     return cls(addr), port
 
   @classmethod
   def from_sockaddr(cls, sa):
     addr, port = (lambda a, p: (a, p))(*sa)
     return cls(addr), port
 
+class Inet6Address (BaseAddress):
+  AF = S.AF_INET6
+  AFNAME = 'IPv6'
+  NBITS = 128
+  def _setaddr(me, addrstr):
+    pc = addrstr.find('%')
+    if pc == -1:
+      me.addr = me._addrstr_to_int(addrstr)
+      me.scope = 0
+    else:
+      me.addr = me._addrstr_to_int(addrstr[:pc])
+      ais = S.getaddrinfo(addrstr, 0, S.AF_INET6, S.SOCK_DGRAM, 0,
+                          S.AI_NUMERICHOST | S.AI_NUMERICSERV)
+      me.scope = ais[0][4][3]
+  def _addrstr(me):
+    addrstr = me._int_to_addrstr(me.addr)
+    if me.scope == 0:
+      return addrstr
+    else:
+      name, _ = S.getnameinfo((addrstr, 0, 0, me.scope),
+                              S.NI_NUMERICHOST | S.NI_NUMERICSERV)
+      return name
+  def _sockaddr(me, port = 0):
+    return (me._addrstr(), port, 0, me.scope)
+  @classmethod
+  def from_sockaddr(cls, sa):
+    addr, port, _, scope = (lambda a, p, f = 0, s = 0: (a, p, f, s))(*sa)
+    me = cls(addr)
+    me.scope = scope
+    return me, port
+  def _withinp(me, net):
+    return net.scope == 0 or me.scope == net.scope
+  def _eq(me, other):
+    return me.scope == other.scope
+
 def parse_address(addrstr, maskstr = None):
 def parse_address(addrstr, maskstr = None):
-  return InetAddress(addrstr, maskstr)
+  if addrstr.find(':') >= 0: return Inet6Address(addrstr, maskstr)
+  else: return InetAddress(addrstr, maskstr)
 
 def parse_net(netstr):
   try: sl = netstr.index('/')
 
 def parse_net(netstr):
   try: sl = netstr.index('/')
@@ -116,7 +197,7 @@ def straddr(a): return a is None and '#<none>' or str(a)
 ## this service are largely going to be satellite notes, I don't think
 ## scalability's going to be a problem.
 
 ## this service are largely going to be satellite notes, I don't think
 ## scalability's going to be a problem.
 
-TESTADDR = InetAddress('1.2.3.4')
+TESTADDRS = [InetAddress('1.2.3.4'), Inet6Address('2001::1')]
 
 CONFSYNTAX = [
   ('COMMENT', RX.compile(r'^\s*($|[;#])')),
 
 CONFSYNTAX = [
   ('COMMENT', RX.compile(r'^\s*($|[;#])')),
@@ -137,11 +218,11 @@ class Config (object):
 
   The most interesting thing is probably the `groups' slot, which stores a
   list of pairs (NAME, PATTERNS); the NAME is a string, and the PATTERNS a
 
   The most interesting thing is probably the `groups' slot, which stores a
   list of pairs (NAME, PATTERNS); the NAME is a string, and the PATTERNS a
-  list of (TAG, PEER, NET) triples.  The implication is that there should be
+  list of (TAG, PEER, NETS) triples.  The implication is that there should be
   precisely one peer from the set, and that it should be named TAG, where
   precisely one peer from the set, and that it should be named TAG, where
-  (TAG, PEER, NET) is the first triple such that the host's primary IP
+  (TAG, PEER, NETS) is the first triple such that the host's primary IP
   address (if PEER is None -- or the IP address it would use for
   address (if PEER is None -- or the IP address it would use for
-  communicating with PEER) is within the NET.
+  communicating with PEER) is within one of the networks defined by NETS.
   """
 
   def __init__(me, file):
   """
 
   def __init__(me, file):
@@ -167,7 +248,7 @@ class Config (object):
     if T._debug: print '# reread config'
 
     ## Initial state.
     if T._debug: print '# reread config'
 
     ## Initial state.
-    testaddr = None
+    testaddrs = {}
     groups = {}
     grpname = None
     grplist = []
     groups = {}
     grpname = None
     grplist = []
@@ -215,9 +296,10 @@ class Config (object):
                   raise ConfigError(me._file, lno,
                                     "invalid IP address `%s': %s" %
                                     (astr, e))
                   raise ConfigError(me._file, lno,
                                     "invalid IP address `%s': %s" %
                                     (astr, e))
-                if testaddr is not None:
-                  raise ConfigError(me._file, lno, 'duplicate test-address')
-                testaddr = a
+                if a.AF in testaddrs:
+                  raise ConfigError(me._file, lno,
+                                    'duplicate %s test-address' % a.AFNAME)
+                testaddrs[a.AF] = a
             else:
               raise ConfigError(me._file, lno,
                                 "unknown global option `%s'" % name)
             else:
               raise ConfigError(me._file, lno,
                                 "unknown global option `%s'" % name)
@@ -230,6 +312,7 @@ class Config (object):
             ## Check for an explicit target address.
             if i >= len(spec) or spec[i].find('/') >= 0:
               peer = None
             ## Check for an explicit target address.
             if i >= len(spec) or spec[i].find('/') >= 0:
               peer = None
+              af = None
             else:
               try:
                 peer = parse_address(spec[i])
             else:
               try:
                 peer = parse_address(spec[i])
@@ -237,44 +320,61 @@ class Config (object):
                 raise ConfigError(me._file, lno,
                                   "invalid IP address `%s': %s" %
                                   (spec[i], e))
                 raise ConfigError(me._file, lno,
                                   "invalid IP address `%s': %s" %
                                   (spec[i], e))
+              af = peer.AF
               i += 1
 
               i += 1
 
-            ## Parse the local network.
-            if len(spec) != i + 1:
-              raise ConfigError(me._file, lno, 'no network defined')
-            try:
-              net = parse_net(spec[i])
-            except Exception, e:
-              raise ConfigError(me._file, lno,
-                                "invalid IP network `%s': %s" %
-                                (spec[i], e))
+            ## Parse the list of local networks.
+            nets = []
+            while i < len(spec):
+              try:
+                net = parse_net(spec[i])
+              except Exception, e:
+                raise ConfigError(me._file, lno,
+                                  "invalid IP network `%s': %s" %
+                                  (spec[i], e))
+              else:
+                nets.append(net)
+              i += 1
+            if not nets:
+              raise ConfigError(me._file, lno, 'no networks defined')
+
+            ## Make sure that the addresses are consistent.
+            for net in nets:
+              if af is None:
+                af = net.AF
+              elif net.AF != af:
+                raise ConfigError(me._file, lno,
+                                  "net %s doesn't match" % net)
 
             ## Add this entry to the list.
 
             ## Add this entry to the list.
-            grplist.append((name, peer, net))
+            grplist.append((name, peer, nets))
 
 
-    ## Fill in the default test address if necessary.
-    if testaddr is None: testaddr = TESTADDR
+    ## Fill in the default test addresses if necessary.
+    for a in TESTADDRS: testaddrs.setdefault(a.AF, a)
 
     ## Done.
     if grpname is not None: groups[grpname] = grplist
 
     ## Done.
     if grpname is not None: groups[grpname] = grplist
-    me.testaddr = testaddr
+    me.testaddrs = testaddrs
     me.groups = groups
 
 ### This will be a configuration file.
 CF = None
 
 def cmd_showconfig():
     me.groups = groups
 
 ### This will be a configuration file.
 CF = None
 
 def cmd_showconfig():
-  T.svcinfo('test-addr=%s' % CF.testaddr)
+  T.svcinfo('test-addr=%s' %
+            ' '.join(str(a)
+                     for a in sorted(CF.testaddrs.itervalues(),
+                                     key = lambda a: a.AFNAME)))
 def cmd_showgroups():
   for g in sorted(CF.groups.iterkeys()):
     T.svcinfo(g)
 def cmd_showgroup(g):
   try: pats = CF.groups[g]
   except KeyError: raise T.TripeJobError('unknown-group', g)
 def cmd_showgroups():
   for g in sorted(CF.groups.iterkeys()):
     T.svcinfo(g)
 def cmd_showgroup(g):
   try: pats = CF.groups[g]
   except KeyError: raise T.TripeJobError('unknown-group', g)
-  for t, p, n in pats:
+  for t, p, nn in pats:
     T.svcinfo('peer', t,
               'target', p and str(p) or '(default)',
     T.svcinfo('peer', t,
               'target', p and str(p) or '(default)',
-              'net', str(n))
+              'net', ' '.join(map(str, nn)))
 
 ###--------------------------------------------------------------------------
 ### Responding to a network up/down event.
 
 ###--------------------------------------------------------------------------
 ### Responding to a network up/down event.
@@ -283,12 +383,12 @@ def localaddr(peer):
   """
   Return the local IP address used for talking to PEER.
   """
   """
   Return the local IP address used for talking to PEER.
   """
-  sk = S.socket(S.AF_INET, S.SOCK_DGRAM)
+  sk = S.socket(peer.AF, S.SOCK_DGRAM)
   try:
     try:
       sk.connect(peer.sockaddr(1))
       addr = sk.getsockname()
   try:
     try:
       sk.connect(peer.sockaddr(1))
       addr = sk.getsockname()
-      return InetAddress.from_sockaddr(addr)[0]
+      return type(peer).from_sockaddr(addr)[0]
     except S.error:
       return None
   finally:
     except S.error:
       return None
   finally:
@@ -343,16 +443,18 @@ def kickpeers():
     ## Find the current list of peers.
     peers = SM.list()
 
     ## Find the current list of peers.
     peers = SM.list()
 
-    ## Work out the primary IP address.
+    ## Work out the primary IP addresses.
+    locals = {}
     if upness:
     if upness:
-      addr = localaddr(CF.testaddr)
-      if addr is None:
-        upness = False
-    else:
-      addr = None
+      for af, remote in CF.testaddrs.iteritems():
+        local = localaddr(remote)
+        if local is not None: locals[af] = local
+      if not locals: upness = False
     if not T._debug: pass
     if not T._debug: pass
-    elif addr: print '#   local address = %s' % straddr(addr)
-    else: print '#   offline'
+    elif not locals: print '#   offline'
+    else:
+      for local in locals.itervalues():
+        print '#   local %s address = %s' % (local.AFNAME, local)
 
     ## Now decide what to do.
     changes = []
 
     ## Now decide what to do.
     changes = []
@@ -363,14 +465,15 @@ def kickpeers():
       statemap = {}
       want = None
       matchp = False
       statemap = {}
       want = None
       matchp = False
-      for t, p, n in pp:
-        if p is None or not upness: ip = addr
+      for t, p, nn in pp:
+        af = nn[0].AF
+        if p is None or not upness: ip = locals.get(af)
         else: ip = localaddr(p)
         if T._debug:
         else: ip = localaddr(p)
         if T._debug:
-          info = 'peer=%s; target=%s; net=%s; local=%s' % (
-            t, p or '(default)', n, straddr(ip))
+          info = 'peer = %s; target = %s; nets = %s; local = %s' % (
+            t, p or '(default)', ', '.join(map(str, nn)), straddr(ip))
         if upness and not matchp and \
         if upness and not matchp and \
-              ip is not None and ip.withinp(n):
+           ip is not None and any(ip.withinp(n) for n in nn):
           if T._debug: print '#     %s: SELECTED' % info
           statemap[t] = 'up'
           select.append('%s=%s' % (g, t))
           if T._debug: print '#     %s: SELECTED' % info
           statemap[t] = 'up'
           select.append('%s=%s' % (g, t))