svc/conntrack.in: Split out a base class from `InetAddress'.
[tripe] / svc / conntrack.in
index 043c969..bfdd7b8 100644 (file)
@@ -42,6 +42,7 @@ for i in ['mainloop', 'mainloop.glib']:
 try: from gi.repository import GLib as G
 except ImportError: import gobject as G
 from struct import pack, unpack
+from cStringIO import StringIO
 
 SM = T.svcmgr
 ##__import__('rmcr').__debug = True
@@ -54,45 +55,89 @@ class struct (object):
   def __init__(me, **kw):
     me.__dict__.update(kw)
 
+def loadb(s):
+  n = 0
+  for ch in s: n = 256*n + ord(ch)
+  return n
+
+def storeb(n, wd = None):
+  if wd is None: wd = n.bit_length()
+  s = StringIO()
+  for i in xrange((wd - 1)&-8, -8, -8): s.write(chr((n >> i)&0xff))
+  return s.getvalue()
+
 ###--------------------------------------------------------------------------
 ### Address manipulation.
+###
+### I think this is the most demanding application, in terms of address
+### hacking, in the entire TrIPE suite.  At least we don't have to do it in
+### C.
 
-class InetAddress (object):
+class BaseAddress (object):
   def __init__(me, addrstr, maskstr = None):
-    me.addr = me._addrstr_to_int(addrstr)
+    me._setaddr(addrstr)
     if maskstr is None:
       me.mask = -1
     elif maskstr.isdigit():
-      me.mask = (1 << 32) - (1 << 32 - int(maskstr))
+      me.mask = (1 << me.NBITS) - (1 << me.NBITS - int(maskstr))
     else:
-      me.mask = me._addrstr_to_int(maskstr)
+      me._setmask(maskstr)
     if me.addr&~me.mask:
       raise ValueError('network contains bits set beyond mask')
   def _addrstr_to_int(me, addrstr):
-    return unpack('>L', S.inet_aton(addrstr))[0]
+    try: return loadb(S.inet_pton(me.AF, addrstr))
+    except S.error: raise ValueError('bad address syntax')
   def _int_to_addrstr(me, n):
-    return S.inet_ntoa(pack('>L', n))
+    return S.inet_ntop(me.AF, storeb(me.addr, me.NBITS))
+  def _setmask(me, maskstr):
+    raise ValueError('only prefix masked supported')
+  def _maskstr(me):
+    raise ValueError('only prefix masked supported')
   def sockaddr(me, port = 0):
     if me.mask != -1: raise ValueError('not a simple address')
-    return me._int_to_addrstr(me.addr), port
+    return me._sockaddr(port)
   def __str__(me):
-    addrstr = me._int_to_addrstr(me.addr)
+    addrstr = me._addrstr()
     if me.mask == -1:
       return addrstr
     else:
-      inv = me.mask ^ ((1 << 32) - 1)
+      inv = me.mask ^ ((1 << me.NBITS) - 1)
       if (inv&(inv + 1)) == 0:
-        return '%s/%d' % (addrstr, 32 - inv.bit_length())
+        return '%s/%d' % (addrstr, me.NBITS - inv.bit_length())
       else:
-        return '%s/%s' % (addrstr, me._int_to_addrstr(me.mask))
+        return '%s/%s' % (addrstr, me._maskstr())
   def withinp(me, net):
+    if type(net) != type(me): return False
     if (me.mask&net.mask) != net.mask: return False
     if (me.addr ^ net.addr)&net.mask: return False
-    return True
+    return me._withinp(net)
   def eq(me, other):
+    if type(me) != type(other): return False
     if me.mask != other.mask: return False
     if me.addr != other.addr: return False
+    return me._eq(other)
+  def _withinp(me, net):
+    return True
+  def _eq(me, other):
     return True
+
+class InetAddress (BaseAddress):
+  AF = S.AF_INET
+  AFNAME = 'IPv4'
+  NBITS = 32
+  def _addrstr_to_int(me, addrstr):
+    try: return loadb(S.inet_aton(addrstr))
+    except S.error: raise ValueError('bad address syntax')
+  def _setaddr(me, addrstr):
+    me.addr = me._addrstr_to_int(addrstr)
+  def _setmask(me, maskstr):
+    me.mask = me._addrstr_to_int(maskstr)
+  def _addrstr(me):
+    return me._int_to_addrstr(me.addr)
+  def _maskstr(me):
+    return me._int_to_addrstr(me.mask)
+  def _sockaddr(me, port = 0):
+    return (me._addrstr(), port)
   @classmethod
   def from_sockaddr(cls, sa):
     addr, port = (lambda a, p: (a, p))(*sa)
@@ -116,7 +161,7 @@ def straddr(a): return a is None and '#<none>' or str(a)
 ## this service are largely going to be satellite notes, I don't think
 ## scalability's going to be a problem.
 
-TESTADDR = InetAddress('1.2.3.4')
+TESTADDRS = [InetAddress('1.2.3.4')]
 
 CONFSYNTAX = [
   ('COMMENT', RX.compile(r'^\s*($|[;#])')),
@@ -137,11 +182,11 @@ class Config (object):
 
   The most interesting thing is probably the `groups' slot, which stores a
   list of pairs (NAME, PATTERNS); the NAME is a string, and the PATTERNS a
-  list of (TAG, PEER, NET) triples.  The implication is that there should be
+  list of (TAG, PEER, NETS) triples.  The implication is that there should be
   precisely one peer from the set, and that it should be named TAG, where
-  (TAG, PEER, NET) is the first triple such that the host's primary IP
+  (TAG, PEER, NETS) is the first triple such that the host's primary IP
   address (if PEER is None -- or the IP address it would use for
-  communicating with PEER) is within the NET.
+  communicating with PEER) is within one of the networks defined by NETS.
   """
 
   def __init__(me, file):
@@ -167,7 +212,7 @@ class Config (object):
     if T._debug: print '# reread config'
 
     ## Initial state.
-    testaddr = None
+    testaddrs = {}
     groups = {}
     grpname = None
     grplist = []
@@ -215,9 +260,10 @@ class Config (object):
                   raise ConfigError(me._file, lno,
                                     "invalid IP address `%s': %s" %
                                     (astr, e))
-                if testaddr is not None:
-                  raise ConfigError(me._file, lno, 'duplicate test-address')
-                testaddr = a
+                if a.AF in testaddrs:
+                  raise ConfigError(me._file, lno,
+                                    'duplicate %s test-address' % a.AFNAME)
+                testaddrs[a.AF] = a
             else:
               raise ConfigError(me._file, lno,
                                 "unknown global option `%s'" % name)
@@ -230,6 +276,7 @@ class Config (object):
             ## Check for an explicit target address.
             if i >= len(spec) or spec[i].find('/') >= 0:
               peer = None
+              af = None
             else:
               try:
                 peer = parse_address(spec[i])
@@ -237,44 +284,61 @@ class Config (object):
                 raise ConfigError(me._file, lno,
                                   "invalid IP address `%s': %s" %
                                   (spec[i], e))
+              af = peer.AF
               i += 1
 
-            ## Parse the local network.
-            if len(spec) != i + 1:
-              raise ConfigError(me._file, lno, 'no network defined')
-            try:
-              net = parse_net(spec[i])
-            except Exception, e:
-              raise ConfigError(me._file, lno,
-                                "invalid IP network `%s': %s" %
-                                (spec[i], e))
+            ## Parse the list of local networks.
+            nets = []
+            while i < len(spec):
+              try:
+                net = parse_net(spec[i])
+              except Exception, e:
+                raise ConfigError(me._file, lno,
+                                  "invalid IP network `%s': %s" %
+                                  (spec[i], e))
+              else:
+                nets.append(net)
+              i += 1
+            if not nets:
+              raise ConfigError(me._file, lno, 'no networks defined')
+
+            ## Make sure that the addresses are consistent.
+            for net in nets:
+              if af is None:
+                af = net.AF
+              elif net.AF != af:
+                raise ConfigError(me._file, lno,
+                                  "net %s doesn't match" % net)
 
             ## Add this entry to the list.
-            grplist.append((name, peer, net))
+            grplist.append((name, peer, nets))
 
-    ## Fill in the default test address if necessary.
-    if testaddr is None: testaddr = TESTADDR
+    ## Fill in the default test addresses if necessary.
+    for a in TESTADDRS: testaddrs.setdefault(a.AF, a)
 
     ## Done.
     if grpname is not None: groups[grpname] = grplist
-    me.testaddr = testaddr
+    me.testaddrs = testaddrs
     me.groups = groups
 
 ### This will be a configuration file.
 CF = None
 
 def cmd_showconfig():
-  T.svcinfo('test-addr=%s' % CF.testaddr)
+  T.svcinfo('test-addr=%s' %
+            ' '.join(str(a)
+                     for a in sorted(CF.testaddrs.itervalues(),
+                                     key = lambda a: a.AFNAME)))
 def cmd_showgroups():
   for g in sorted(CF.groups.iterkeys()):
     T.svcinfo(g)
 def cmd_showgroup(g):
   try: pats = CF.groups[g]
   except KeyError: raise T.TripeJobError('unknown-group', g)
-  for t, p, n in pats:
+  for t, p, nn in pats:
     T.svcinfo('peer', t,
               'target', p and str(p) or '(default)',
-              'net', str(n))
+              'net', ' '.join(map(str, nn)))
 
 ###--------------------------------------------------------------------------
 ### Responding to a network up/down event.
@@ -283,12 +347,12 @@ def localaddr(peer):
   """
   Return the local IP address used for talking to PEER.
   """
-  sk = S.socket(S.AF_INET, S.SOCK_DGRAM)
+  sk = S.socket(peer.AF, S.SOCK_DGRAM)
   try:
     try:
       sk.connect(peer.sockaddr(1))
       addr = sk.getsockname()
-      return InetAddress.from_sockaddr(addr)[0]
+      return type(peer).from_sockaddr(addr)[0]
     except S.error:
       return None
   finally:
@@ -343,16 +407,18 @@ def kickpeers():
     ## Find the current list of peers.
     peers = SM.list()
 
-    ## Work out the primary IP address.
+    ## Work out the primary IP addresses.
+    locals = {}
     if upness:
-      addr = localaddr(CF.testaddr)
-      if addr is None:
-        upness = False
-    else:
-      addr = None
+      for af, remote in CF.testaddrs.iteritems():
+        local = localaddr(remote)
+        if local is not None: locals[af] = local
+      if not locals: upness = False
     if not T._debug: pass
-    elif addr: print '#   local address = %s' % straddr(addr)
-    else: print '#   offline'
+    elif not locals: print '#   offline'
+    else:
+      for local in locals.itervalues():
+        print '#   local %s address = %s' % (local.AFNAME, local)
 
     ## Now decide what to do.
     changes = []
@@ -360,36 +426,33 @@ def kickpeers():
       if T._debug: print '#   check group %s' % g
 
       ## Find out which peer in the group ought to be active.
-      ip = None
-      map = {}
+      statemap = {}
       want = None
-      for t, p, n in pp:
-        if p is None or not upness:
-          ipq = addr
-        else:
-          ipq = localaddr(p)
+      matchp = False
+      for t, p, nn in pp:
+        af = nn[0].AF
+        if p is None or not upness: ip = locals.get(af)
+        else: ip = localaddr(p)
         if T._debug:
-          info = 'peer=%s; target=%s; net=%s; local=%s' % (
-            t, p or '(default)', n, straddr(ipq))
-        if upness and ip is None and \
-              ipq is not None and ipq.withinp(n):
+          info = 'peer = %s; target = %s; nets = %s; local = %s' % (
+            t, p or '(default)', ', '.join(map(str, nn)), straddr(ip))
+        if upness and not matchp and \
+           ip is not None and any(ip.withinp(n) for n in nn):
           if T._debug: print '#     %s: SELECTED' % info
-          map[t] = 'up'
+          statemap[t] = 'up'
           select.append('%s=%s' % (g, t))
-          if t == 'down' or t.startswith('down/'):
-            want = None
-          else:
-            want = t
-          ip = ipq
+          if t == 'down' or t.startswith('down/'): want = None
+          else: want = t
+          matchp = True
         else:
-          map[t] = 'down'
+          statemap[t] = 'down'
           if T._debug: print '#     %s: skipped' % info
 
       ## Shut down the wrong ones.
       found = False
-      if T._debug: print '#   peer-map = %r' % map
+      if T._debug: print '#   peer-map = %r' % statemap
       for p in peers:
-        what = map.get(p, 'leave')
+        what = statemap.get(p, 'leave')
         if what == 'up':
           found = True
           if T._debug: print '#   peer %s: already up' % p