6af49b6da754e46f88a3e65b663aec32df46c5bd
[hippotat] / hippotat / __init__.py
1 # -*- python -*-
2
import signal
# Restore the default SIGINT action before anything else is imported.
signal.signal(signal.SIGINT, signal.SIG_DFL)

import collections
import ipaddress
import sys
import time
from configparser import ConfigParser
from configparser import NoOptionError
from ipaddress import AddressValueError
from optparse import OptionParser

import twisted
from twisted.internet import reactor
from twisted.logger import LogLevel

import hippotat.slip as slip
20
# Shared configuration parser; common_startup() loads the built-in defaults
# and then the configuration file into it.
cfg = ConfigParser()
# Shared command-line parser; common_startup() adds -c/--config to it and
# then parses the command line.
optparser = OptionParser()
23
24 #---------- packet parsing ----------
25
def packet_addrs(packet):
    """Return the (source, destination) addresses of a raw IP packet.

    The IP version is read from the top four bits of the first byte of
    *packet*.  The two addresses are stored back to back in the header,
    so the destination begins where the source ends.

    Returns a tuple ``(saddr, daddr)`` of ``ipaddress.IPv4Address`` or
    ``ipaddress.IPv6Address`` objects.

    Raises ``ValueError`` if the version is neither 4 nor 6.
    """
    ip_version = packet[0] >> 4

    # version -> (address class, address length in bytes,
    #             byte offset of the source address in the header)
    layouts = {
        4: (ipaddress.IPv4Address, 4, 3*4),
        6: (ipaddress.IPv6Address, 16, 2*4),
    }
    layout = layouts.get(ip_version)
    if layout is None:
        raise ValueError('unsupported IP version %d' % ip_version)

    make_addr, width, src_start = layout
    dst_start = src_start + width
    saddr = make_addr(packet[src_start:dst_start])
    daddr = make_addr(packet[dst_start:dst_start + width])
    return (saddr, daddr)
41
42 #---------- address handling ----------
43
def ipaddr(input):
    """Parse *input* as an IP address, trying IPv4 first and then IPv6.

    Returns an ``ipaddress.IPv4Address`` if *input* is accepted as IPv4,
    otherwise whatever ``ipaddress.IPv6Address(input)`` returns.  If the
    IPv6 attempt fails as well, its exception propagates to the caller.
    """
    try:
        return ipaddress.IPv4Address(input)
    except AddressValueError:
        # Not an IPv4 address; let the IPv6 parser have the final say.
        return ipaddress.IPv6Address(input)
50
def ipnetwork(input):
    """Parse *input* as an IP network, trying IPv4 first and then IPv6.

    Returns an ``ipaddress.IPv4Network`` if *input* is accepted as IPv4,
    otherwise whatever ``ipaddress.IPv6Network(input)`` returns.  If the
    IPv6 attempt fails as well, its exception propagates to the caller.
    """
    try:
        return ipaddress.IPv4Network(input)
    except AddressValueError:
        # The address part is not IPv4 (this is what an IPv6 network
        # string raises), so retry as IPv6.  The name caught here must be
        # the real ipaddress exception: a non-existent one would turn
        # every IPv6 network into a NameError.
        return ipaddress.IPv6Network(input)
57
58 #---------- ipif (SLIP) subprocess ----------
59
class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
    """Process protocol that turns the subprocess's SLIP output into packets.

    Each non-empty decoded packet is handed to the *router* callable as
    ``router(packet, saddr, daddr)``.
    """

    def __init__(self, router):
        self._router = router
        # Bytes received so far that have not yet been emitted as packets.
        self._buffer = b''

    def connectionMade(self):
        """Nothing to do when the subprocess starts."""
        pass

    def outReceived(self, data):
        """Accumulate *data* from the child's stdout and route the packets."""
        self._buffer += data
        frames = slip.decode(self._buffer)
        # The last element is kept as the new buffer -- presumably the
        # trailing, still incomplete packet (see slip.decode to confirm).
        self._buffer = frames.pop()
        for frame in frames:
            if len(frame) == 0:
                # Empty frame: nothing to route.
                continue
            (src, dst) = packet_addrs(frame)
            self._router(frame, src, dst)

    def processEnded(self, status):
        """Re-raise the termination status when the subprocess ends."""
        status.raiseException()
76
def start_ipif(command, router):
    """Spawn *command* under ``/bin/sh -xc`` as the ipif subprocess.

    The protocol object is stored in the module-level ``ipif`` so that
    queue_inbound() can write to the subprocess.  *router* is passed to
    the protocol, which calls it for every packet the subprocess emits.
    """
    global ipif
    ipif = _IpifProcessProtocol(router)

    shell_argv = ['sh', '-xc', command]
    # fd 0: we write to the child; fd 1: we read from the child;
    # fd 2: the child shares our own stderr.
    fd_map = {0: 'w', 1: 'r', 2: 2}
    reactor.spawnProcess(ipif, '/bin/sh', shell_argv, childFDs=fd_map)
83
def queue_inbound(packet):
    """Write *packet*, SLIP-encoded, to the ipif subprocess.

    The encoded packet is written with a SLIP delimiter on both sides.
    start_ipif() must have been called first so that ``ipif`` exists.
    """
    send = ipif.transport.write
    send(slip.delimiter)
    send(slip.encode(packet))
    send(slip.delimiter)
88
89 #---------- packet queue ----------
90
class PacketQueue:
    """FIFO queue of packets that discards entries older than a time limit.

    Each packet is stored together with the monotonic time at which it
    was queued.  Packets older than *max_queue_time* seconds are dropped
    from the front of the queue when nonempty() is called.
    """

    def __init__(self, max_queue_time):
        self._max_queue_time = max_queue_time
        # Entries are (queue time, packet) tuples, oldest first.
        self._pq = collections.deque()

    def append(self, packet):
        """Add *packet* to the back of the queue, stamped with the current time."""
        self._pq.append((time.monotonic(), packet))

    def nonempty(self):
        """Drop expired packets, then report whether any packet remains.

        Returns True if an unexpired packet is at the front of the queue,
        False if the queue is (or has become) empty.
        """
        while self._pq:
            queuetime, _packet = self._pq[0]
            age = time.monotonic() - queuetime
            if age <= self._max_queue_time:
                return True
            # Too old: strip it off the front and look at the next one.
            self._pq.popleft()
        return False

    def popleft(self):
        """Remove and return the packet at the front of the queue.

        The caller should have checked nonempty() first; if the queue is
        empty anyway, None is returned.
        """
        try:
            # Actually remove the entry; merely peeking at it would hand
            # out the same packet on every call.
            (_queuetime, packet) = self._pq.popleft()
        except IndexError:
            return None
        return packet
117
118 #---------- error handling ----------
119
def crash(err):
    """Report *err* on stderr and stop the reactor.

    *err* is printed as-is after a ``CRASH`` prefix.  Stopping a reactor
    that is not running is not treated as an error, so this can be
    called at any point.

    Relies on the module-level ``sys`` import for ``sys.stderr``.
    """
    print('CRASH ', err, file=sys.stderr)
    try:
        reactor.stop()
    except twisted.internet.error.ReactorNotRunning:
        # Nothing is running, so there is nothing left to shut down.
        pass
124
def crash_on_defer(defer):
    """Arrange for any failure of *defer* to be passed to crash().

    Only registers the errback; crash() is not called here.
    """
    def _report_failure(failure):
        return crash(failure)

    defer.addErrback(_report_failure)
127
def crash_on_critical(event):
    """Log observer: call crash() for events at critical level or above.

    *event* is a log event dictionary.  Events that carry no
    ``log_level`` (missing or None) are ignored, since there is no level
    to compare and comparing None against a level would raise.
    """
    level = event.get('log_level')
    if level is not None and level >= LogLevel.critical:
        crash(twisted.logger.formatEvent(event))
131
132 #---------- startup ----------
133
def common_startup(defcfg):
    """Install the crash observer, parse the command line, load the config.

    *defcfg* is configuration text holding the defaults.  It is read into
    the shared ``cfg`` first, and the file named by ``-c``/``--config``
    (default ``/etc/hippotat/config``) is read on top of it.

    Non-option command-line arguments are rejected via optparser.error().
    """
    twisted.logger.globalLogPublisher.addObserver(crash_on_critical)

    optparser.add_option(
        '-c', '--config',
        dest='configfile',
        default='/etc/hippotat/config',
    )
    opts, leftover = optparser.parse_args()
    if leftover:
        optparser.error('no non-option arguments please')

    # Defaults first, so the configuration file is applied on top of them.
    cfg.read_string(defcfg)
    cfg.read(opts.configfile)
144
def common_run():
    """Run the reactor, and report on stderr once it has returned.

    Relies on the module-level ``sys`` import for ``sys.stderr``.
    """
    reactor.run()
    # Reached only after reactor.run() has returned.
    print('CRASHED (end)', file=sys.stderr)