svc/conntrack.in: Contemplate multiple address families.
[tripe] / svc / conntrack.in
index 99491b1..cc753d4 100644 (file)
@@ -58,6 +58,8 @@ class struct (object):
 ### Address manipulation.
 
 class InetAddress (object):
+  AF = S.AF_INET
+  AFNAME = 'IPv4'
   def __init__(me, addrstr, maskstr = None):
     me.addr = me._addrstr_to_int(addrstr)
     if maskstr is None:
@@ -116,7 +118,7 @@ def straddr(a): return a is None and '#<none>' or str(a)
 ## this service are largely going to be satellite notes, I don't think
 ## scalability's going to be a problem.
 
-TESTADDR = InetAddress('1.2.3.4')
+TESTADDRS = [InetAddress('1.2.3.4')]
 
 CONFSYNTAX = [
   ('COMMENT', RX.compile(r'^\s*($|[;#])')),
@@ -137,11 +139,11 @@ class Config (object):
 
   The most interesting thing is probably the `groups' slot, which stores a
   list of pairs (NAME, PATTERNS); the NAME is a string, and the PATTERNS a
-  list of (TAG, PEER, NET) triples.  The implication is that there should be
+  list of (TAG, PEER, NETS) triples.  The implication is that there should be
   precisely one peer from the set, and that it should be named TAG, where
-  (TAG, PEER, NET) is the first triple such that the host's primary IP
+  (TAG, PEER, NETS) is the first triple such that the host's primary IP
   address (if PEER is None -- or the IP address it would use for
-  communicating with PEER) is within the NET.
+  communicating with PEER) is within one of the networks defined by NETS.
   """
 
   def __init__(me, file):
@@ -167,7 +169,7 @@ class Config (object):
     if T._debug: print '# reread config'
 
     ## Initial state.
-    testaddr = None
+    testaddrs = {}
     groups = {}
     grpname = None
     grplist = []
@@ -215,9 +217,10 @@ class Config (object):
                   raise ConfigError(me._file, lno,
                                     "invalid IP address `%s': %s" %
                                     (astr, e))
-                if testaddr is not None:
-                  raise ConfigError(me._file, lno, 'duplicate test-address')
-                testaddr = a
+                if a.AF in testaddrs:
+                  raise ConfigError(me._file, lno,
+                                    'duplicate %s test-address' % a.AFNAME)
+                testaddrs[a.AF] = a
             else:
               raise ConfigError(me._file, lno,
                                 "unknown global option `%s'" % name)
@@ -230,6 +233,7 @@ class Config (object):
             ## Check for an explicit target address.
             if i >= len(spec) or spec[i].find('/') >= 0:
               peer = None
+              af = None
             else:
               try:
                 peer = parse_address(spec[i])
@@ -237,44 +241,61 @@ class Config (object):
                 raise ConfigError(me._file, lno,
                                   "invalid IP address `%s': %s" %
                                   (spec[i], e))
+              af = peer.AF
               i += 1
 
-            ## Parse the local network.
-            if len(spec) != i + 1:
-              raise ConfigError(me._file, lno, 'no network defined')
-            try:
-              net = parse_net(spec[i])
-            except Exception, e:
-              raise ConfigError(me._file, lno,
-                                "invalid IP network `%s': %s" %
-                                (spec[i], e))
+            ## Parse the list of local networks.
+            nets = []
+            while i < len(spec):
+              try:
+                net = parse_net(spec[i])
+              except Exception, e:
+                raise ConfigError(me._file, lno,
+                                  "invalid IP network `%s': %s" %
+                                  (spec[i], e))
+              else:
+                nets.append(net)
+              i += 1
+            if not nets:
+              raise ConfigError(me._file, lno, 'no networks defined')
+
+            ## Make sure that the addresses are consistent.
+            for net in nets:
+              if af is None:
+                af = net.AF
+              elif net.AF != af:
+                raise ConfigError(me._file, lno,
+                                  "net %s doesn't match" % net)
 
             ## Add this entry to the list.
-            grplist.append((name, peer, net))
+            grplist.append((name, peer, nets))
 
-    ## Fill in the default test address if necessary.
-    if testaddr is None: testaddr = TESTADDR
+    ## Fill in the default test addresses if necessary.
+    for a in TESTADDRS: testaddrs.setdefault(a.AF, a)
 
     ## Done.
     if grpname is not None: groups[grpname] = grplist
-    me.testaddr = testaddr
+    me.testaddrs = testaddrs
     me.groups = groups
 
 ### This will be a configuration file.
 CF = None
 
 def cmd_showconfig():
-  T.svcinfo('test-addr=%s' % CF.testaddr)
+  T.svcinfo('test-addr=%s' %
+            ' '.join(str(a)
+                     for a in sorted(CF.testaddrs.itervalues(),
+                                     key = lambda a: a.AFNAME)))
 def cmd_showgroups():
   for g in sorted(CF.groups.iterkeys()):
     T.svcinfo(g)
 def cmd_showgroup(g):
   try: pats = CF.groups[g]
   except KeyError: raise T.TripeJobError('unknown-group', g)
-  for t, p, n in pats:
+  for t, p, nn in pats:
     T.svcinfo('peer', t,
               'target', p and str(p) or '(default)',
-              'net', str(n))
+              'net', ' '.join(map(str, nn)))
 
 ###--------------------------------------------------------------------------
 ### Responding to a network up/down event.
@@ -283,12 +304,12 @@ def localaddr(peer):
   """
   Return the local IP address used for talking to PEER.
   """
-  sk = S.socket(S.AF_INET, S.SOCK_DGRAM)
+  sk = S.socket(peer.AF, S.SOCK_DGRAM)
   try:
     try:
       sk.connect(peer.sockaddr(1))
       addr = sk.getsockname()
-      return InetAddress.from_sockaddr(addr)[0]
+      return type(peer).from_sockaddr(addr)[0]
     except S.error:
       return None
   finally:
@@ -343,16 +364,18 @@ def kickpeers():
     ## Find the current list of peers.
     peers = SM.list()
 
-    ## Work out the primary IP address.
+    ## Work out the primary IP addresses.
+    locals = {}
     if upness:
-      addr = localaddr(CF.testaddr)
-      if addr is None:
-        upness = False
-    else:
-      addr = None
+      for af, remote in CF.testaddrs.iteritems():
+        local = localaddr(remote)
+        if local is not None: locals[af] = local
+      if not locals: upness = False
     if not T._debug: pass
-    elif addr: print '#   local address = %s' % straddr(addr)
-    else: print '#   offline'
+    elif not locals: print '#   offline'
+    else:
+      for local in locals.itervalues():
+        print '#   local %s address = %s' % (local.AFNAME, local)
 
     ## Now decide what to do.
     changes = []
@@ -360,27 +383,24 @@ def kickpeers():
       if T._debug: print '#   check group %s' % g
 
       ## Find out which peer in the group ought to be active.
-      ip = None
       statemap = {}
       want = None
-      for t, p, n in pp:
-        if p is None or not upness:
-          ipq = addr
-        else:
-          ipq = localaddr(p)
+      matchp = False
+      for t, p, nn in pp:
+        af = nn[0].AF
+        if p is None or not upness: ip = locals.get(af)
+        else: ip = localaddr(p)
         if T._debug:
-          info = 'peer=%s; target=%s; net=%s; local=%s' % (
-            t, p or '(default)', n, straddr(ipq))
-        if upness and ip is None and \
-              ipq is not None and ipq.withinp(n):
+          info = 'peer = %s; target = %s; nets = %s; local = %s' % (
+            t, p or '(default)', ', '.join(map(str, nn)), straddr(ip))
+        if upness and not matchp and \
+           ip is not None and any(ip.withinp(n) for n in nn):
           if T._debug: print '#     %s: SELECTED' % info
           statemap[t] = 'up'
           select.append('%s=%s' % (g, t))
-          if t == 'down' or t.startswith('down/'):
-            want = None
-          else:
-            want = t
-          ip = ipq
+          if t == 'down' or t.startswith('down/'): want = None
+          else: want = t
+          matchp = True
         else:
           statemap[t] = 'down'
           if T._debug: print '#     %s: skipped' % info