less absurd
[hippotat] / hippotat / __init__.py
1 # -*- python -*-
2
import signal
# Restore the default SIGINT action (terminate the process) instead of
# Python's own handler, which raises KeyboardInterrupt.
signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 import twisted.internet.endpoints
11 import twisted.logger
12 from twisted.logger import LogLevel
13 import twisted.python.constants
14 from twisted.python.constants import NamedConstant
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 import collections
24 import time
25 import codecs
26 import traceback
27
28 import re as regexp
29
30 import hippotat.slip as slip
31
class DBG(twisted.python.constants.Names):
    """Debug-message categories, passed as the first argument of log_debug.

    log_debug (in this module) logs every category at info level, prefixed
    with the flag; no per-category filtering happens in the code shown here.
    """
    # ROUTE, HTTP, HTTP_CTRL, HTTP_FULL and CTRL_DUMP are not used in the code
    # shown here; presumably they are used by the client/server programs.
    ROUTE = NamedConstant()
    DROP = NamedConstant()        # discarded packets (log_discard)
    FLOW = NamedConstant()        # packet flow, e.g. queue_inbound
    HTTP = NamedConstant()
    HTTP_CTRL = NamedConstant()
    INIT = NamedConstant()        # startup (common_run)
    QUEUE = NamedConstant()       # PacketQueue operations
    QUEUE_CTRL = NamedConstant()  # PacketQueue.process batching detail
    HTTP_FULL = NamedConstant()
    SLIP_FULL = NamedConstant()   # SlipStreamDecoder tracing
    CTRL_DUMP = NamedConstant()
44
# Encoder used by log_debug to render packet bytes as hex text.
_hex_codec = codecs.getencoder('hex_codec')

# Module-wide Twisted logger; log_debug emits through it.
log = twisted.logger.Logger()
48
def log_debug(dflag, msg, idof=None, d=None):
    """Log *msg* at info level, tagged with the debug category *dflag*.

    idof -- if not None, prefix the message with id(idof) in hex so that
            lines about the same object can be correlated.
    d    -- if not None, bytes-like data whose first 64 bytes are appended
            to the message as hex.
    """
    text = msg if idof is None else '[%#x] %s' % (id(idof), msg)
    if d is not None:
        hexdump = _hex_codec(d[:64])[0].decode('ascii')
        text = text + ' ' + hexdump
    # msg travels as a field value, not as part of the format string, so
    # braces inside it are never interpreted by the Twisted formatter.
    log.info('{dflag} {msgcore}', dflag=dflag, msgcore=text)
58
# Built-in default configuration.  All '#' comments (including trailing
# ones after values) are stripped by common_startup() before this is parsed.
defcfg = '''
[DEFAULT]
#[<client>] overrides
max_batch_down = 65536           # used by server, subject to [limits]
max_queue_time = 10              # used by server, subject to [limits]
max_request_time = 54            # used by server, subject to [limits]
target_requests_outstanding = 3  # must match; subject to [limits] on server
max_requests_outstanding = 4     # used by client
max_batch_up = 4000              # used by client
http_timeout = 30                # used by client
http_retry = 5                   # used by client

#[server] or [<client>] overrides
ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
# extra interpolations:  %(local)s        %(peer)s          %(rnets)s
#  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
#      from   on client  <client>         [virtual]server   [virtual]routes

[virtual]
mtu = 1500
routes = ''
# network = <prefix>/<len>   # mandatory for server
# server  = <ipaddr>         # used by both, default is computed from `network'
# relay   = <ipaddr>         # used by server, default from `network' and `server'
#  default server is first host in network
#  default relay is first host which is not server

[server]
# addrs = 127.0.0.1 ::1     # mandatory for server
port = 80                   # used by server
# url                       # used by client; default from first `addrs' and `port'

# [<client-ip4-or-ipv6-address>]
# password = <password>     # used by both, must match

[limits]
max_batch_down = 262144           # used by server
max_queue_time = 121              # used by server
max_request_time = 121            # used by server
target_requests_outstanding = 10  # used by server
'''
100
# these need to be defined here so that they can be imported by import *
# Global configuration: common_startup() loads defcfg and then the config
# file into it.
cfg = ConfigParser()
# Command-line parser: common_startup() adds -c/--config and parses argv.
optparser = OptionParser()
104
# Byte translation table that exchanges b'-' with the SLIP ESC byte.
_mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')


def mime_translate(s):
    """Return *s* with every '-' byte and every SLIP ESC byte swapped.

    SLIP-encoded packets cannot contain ESC ESC, so after the swap the
    result cannot contain '--'.  Assuming slip.esc is a single byte, the
    mapping is its own inverse: translating twice restores the input.
    """
    swapped = s.translate(_mimetrans)
    return swapped
110
class ConfigResults:
    """Attribute bag holding processed configuration values (see global c).

    If *d* is given it becomes the instance's __dict__ directly (not a
    copy), so attributes set on the instance appear in *d* and vice versa.
    """

    def __init__(self, d=None):
        # Use a fresh dict per instance: a `{}` default is evaluated once,
        # so every argument-less ConfigResults() would share one dict.
        self.__dict__ = {} if d is None else d

    def __repr__(self):
        return 'ConfigResults(' + repr(self.__dict__) + ')'


c = ConfigResults()
118
def log_discard(packet, iface, saddr, daddr, why):
    """Log, under DBG.DROP, that *packet* from *iface* was dropped and why."""
    description = 'discarded packet [%s] %s -> %s: %s' % (
        iface, saddr, daddr, why)
    log_debug(DBG.DROP, description, d=packet)
123
124 #---------- packet parsing ----------
125
def packet_addrs(packet):
    """Return (source, destination) addresses from a raw IP *packet*.

    The IP version is taken from the high nibble of the first byte.
    Raises ValueError for versions other than 4 and 6, IndexError for an
    empty packet, and AddressValueError if the packet is too short to hold
    both addresses.
    """
    version = packet[0] >> 4
    # version -> (address class, address length, source address offset)
    layouts = {
        4: (ipaddress.IPv4Address, 4, 3*4),
        6: (ipaddress.IPv6Address, 16, 2*4),
    }
    layout = layouts.get(version)
    if layout is None:
        raise ValueError('unsupported IP version %d' % version)
    factory, addrlen, src_off = layout
    dst_off = src_off + addrlen  # destination immediately follows source
    return (factory(packet[src_off:dst_off]),
            factory(packet[dst_off:dst_off + addrlen]))
141
142 #---------- address handling ----------
143
def ipaddr(input):
    """Parse *input* as an IPv4 address, falling back to IPv6.

    Raises AddressValueError (a ValueError) if it is neither.
    """
    try:
        return ipaddress.IPv4Address(input)
    except AddressValueError:
        return ipaddress.IPv6Address(input)
150
def ipnetwork(input):
    """Parse *input* as an IPv4 network, falling back to IPv6.

    Parsing is strict (the ipaddress default), so host bits must not be
    set.  Raises a ValueError subclass if *input* is not a valid network of
    either family.
    """
    try:
        r = ipaddress.IPv4Network(input)
    except AddressValueError:
        # The address part is not IPv4; try IPv6.  (The ipaddress module
        # has no NetworkValueError -- catching that name made every IPv6
        # network fail with NameError.)
        r = ipaddress.IPv6Network(input)
    return r
157
158 #---------- ipif (SLIP) subprocess ----------
159
class SlipStreamDecoder():
    """Split a chunked SLIP byte stream into packets.

    Each complete element returned by slip.decode is passed to
    *on_packet*; empty ones are skipped.  The last element returned by
    slip.decode is held back in a buffer until more data arrives or
    flush() is called.
    """

    def __init__(self, desc, on_packet):
        # desc is only used to label debug log lines.
        self._desc = desc
        self._on_packet = on_packet
        self._buffer = b''
        self._log('__init__')

    def _log(self, msg, **kwargs):
        """Emit a DBG.SLIP_FULL line labelled with this decoder's desc."""
        log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)

    def inputdata(self, data):
        """Feed the next chunk of stream *data* into the decoder."""
        self._log('inputdata', d=data)
        pending = self._buffer + data
        # Cleared before decoding: if slip.decode raises, the combined data
        # is discarded rather than retained for the next call.
        self._buffer = b''
        decoded = slip.decode(pending)
        # NOTE(review): the held-back tail is slip.decode's last element.
        # If that element is already *decoded* rather than raw encoded
        # bytes, prepending it to the next raw chunk would decode it twice.
        # Confirm against slip.decode's contract.
        self._buffer = decoded.pop()
        for pkt in decoded:
            self._maybe_packet(pkt)
        self._log('bufremain', d=self._buffer)

    def _maybe_packet(self, packet):
        """Deliver *packet* to on_packet unless it is empty."""
        self._log('maybepacket', d=packet)
        if len(packet) == 0:
            return
        self._on_packet(packet)

    def flush(self):
        """Deliver whatever is buffered (if non-empty) and clear the buffer."""
        self._log('flush')
        self._maybe_packet(self._buffer)
        self._buffer = b''
189
class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
    """Protocol for the ipif child process started by start_ipif().

    Data from the child's stdout is SLIP-decoded.  Each packet is passed
    to router(packet, saddr, daddr), except that packets with a link-local
    source or destination are logged via log_discard and dropped.
    """

    def __init__(self, router):
        self._router = router
        self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)

    def connectionMade(self):
        pass

    def outReceived(self, data):
        self._decoder.inputdata(data)

    def slip_on_packet(self, packet):
        """Route one decoded packet, dropping link-local traffic."""
        addrs = packet_addrs(packet)
        saddr, daddr = addrs
        if any(addr.is_link_local for addr in addrs):
            log_discard(packet, 'ipif', saddr, daddr, 'link-local')
            return
        self._router(packet, saddr, daddr)

    def processEnded(self, status):
        # Re-raise the termination reason carried by *status* (presumably a
        # Twisted Failure, given raiseException) -- any end of the ipif
        # process is treated as an error.
        status.raiseException()
205
def start_ipif(command, router):
    """Spawn the ipif helper as ``/bin/sh -xc command``.

    The protocol instance is stored in the module global ``ipif``
    (queue_inbound writes through it).  childFDs: we write the child's
    stdin, read its stdout, and it shares our stderr (where ``sh -x``
    traces the command).  env=None -- per Twisted's spawnProcess docs this
    should pass the parent's environment; confirm if it matters.
    """
    global ipif
    ipif = _IpifProcessProtocol(router)
    argv = ['sh', '-xc', command]
    fds = {0: 'w', 1: 'r', 2: 2}
    reactor.spawnProcess(ipif, '/bin/sh', argv, childFDs=fds, env=None)
213
def queue_inbound(packet):
    """Send *packet* to the ipif process, SLIP-encoded and delimited.

    Despite the name nothing is queued here: the bytes are handed to
    ipif.transport immediately, with a delimiter before and after.
    """
    log_debug(DBG.FLOW, "queue_inbound", d=packet)
    write = ipif.transport.write
    write(slip.delimiter)
    write(slip.encode(packet))
    write(slip.delimiter)
219
220 #---------- packet queue ----------
221
class PacketQueue():
    """FIFO of packets awaiting transmission, with age-based expiry.

    Entries are (time.monotonic() at enqueue, packet).  nonempty() drops
    entries older than *max_queue_time* seconds from the front.
    """

    def __init__(self, desc, max_queue_time):
        self._desc = desc
        # Presumably a check that desc is a non-empty str: str + '' raises
        # TypeError for non-strings and '' fails the assert.  (Skipped
        # under python -O.)
        assert desc + ''
        self._max_queue_time = max_queue_time
        self._pq = collections.deque()  # entries: (enqueue time, packet)

    def _log(self, dflag, msg, **kwargs):
        log_debug(dflag, self._desc+' pq: '+msg, **kwargs)

    def append(self, packet):
        """Enqueue *packet*, timestamped now."""
        self._log(DBG.QUEUE, 'append', d=packet)
        self._pq.append((time.monotonic(), packet))

    def nonempty(self):
        """Discard expired packets at the front; return True if any remain."""
        self._log(DBG.QUEUE, 'nonempty ?')
        while self._pq:
            queuetime, packet = self._pq[0]
            if time.monotonic() - queuetime > self._max_queue_time:
                # strip old packets off the front
                self._log(DBG.QUEUE, 'dropping (old)', d=packet)
                self._pq.popleft()
                continue
            self._log(DBG.QUEUE, 'nonempty ? nonempty.')
            return True
        self._log(DBG.QUEUE, 'nonempty ? empty.')
        return False

    def process(self, sizequery, moredata, max_batch):
        """Move as many queued packets as fit into the current batch.

        sizequery() must return the batch size so far; moredata(s) appends
        bytes s to the batch.  Packets are SLIP-encoded and separated by
        slip.delimiter.  When sizequery() reports 0 the next packet is taken
        unconditionally (even if it alone exceeds max_batch); otherwise
        processing stops at the first packet that would overflow it.
        Expiry is not checked here -- nonempty() does that.
        """
        self._log(DBG.QUEUE, 'process...')
        while self._pq:
            _, packet = self._pq[0]
            self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)

            encoded = slip.encode(packet)
            sofar = sizequery()
            self._log(DBG.QUEUE_CTRL,
                      'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
                      d=encoded)

            if sofar > 0:
                needed = sofar + len(slip.delimiter) + len(encoded)
                if needed > max_batch:
                    self._log(DBG.QUEUE_CTRL, 'process... overflow')
                    break
                moredata(slip.delimiter)

            moredata(encoded)
            self._pq.popleft()
        else:
            # Reached only when the queue ran dry, not on an overflow break.
            self._log(DBG.QUEUE, 'process... empty')
281
282 #---------- error handling ----------
283
# Set by crash(); common_run() checks it and skips reactor.run() if set.
_crashing = False


def crash(err):
    """Report *err* on stderr and stop the reactor if it is running."""
    global _crashing
    _crashing = True
    banner_start = '========== CRASH =========='
    banner_end = '==========================='
    print(banner_start, err, banner_end, file=sys.stderr)
    try:
        reactor.stop()
    except twisted.internet.error.ReactorNotRunning:
        pass
293
def crash_on_defer(defer):
    """Make any failure of Deferred *defer* call crash() with the failure.

    The errback returns None, so the failure is consumed.
    """
    def _on_error(failure):
        crash(failure)

    defer.addErrback(_on_error)
296
def crash_on_critical(event):
    """Log observer: crash() on any event at LogLevel.critical or above.

    Events without a 'log_level' key are treated as non-critical.
    Comparing None against a LogLevel would raise TypeError inside the
    observer instead.
    """
    level = event.get('log_level')
    if level is not None and level >= LogLevel.critical:
        crash(twisted.logger.formatEvent(event))
300
301 #---------- config processing ----------
302
def process_cfg_common_always():
    """Copy [virtual] mtu into c.mtu.

    The value is kept as the config string.  It is presumably consumed as
    the %(mtu)s interpolation of the ipif command, since process_cfg_ipif
    passes c's attributes as interpolation variables.
    """
    # (The former `global mtu` declaration was dead: only c.mtu is set.)
    c.mtu = cfg.get('virtual', 'mtu')
306
def process_cfg_ipif(section, varmap):
    """Set c.ipif_command from the 'ipif' option of config *section*.

    *varmap* is an iterable of (dest, src) attribute-name pairs.  For each
    pair where c has attribute src, c.<dest> is first set to c.<src>.
    All of c's attributes are then supplied as interpolation variables
    when reading the option.
    """
    missing = object()
    for dest, src in varmap:
        value = getattr(c, src, missing)
        if value is not missing:
            setattr(c, dest, value)

    c.ipif_command = cfg.get(section, 'ipif', vars=c.__dict__)
316
def process_cfg_network():
    """Parse [virtual] network into c.network; reject too-small networks."""
    network = ipnetwork(cfg.get('virtual', 'network'))
    c.network = network
    # num_addresses is always a power of two, so "fewer than 3 + 2"
    # (presumably three hosts plus network and broadcast) is the same test
    # as "fewer than 2^3 = 8".
    if network.num_addresses < 3 + 2:
        raise ValueError('network needs at least 2^3 addresses')
321
def process_cfg_server():
    """Set c.server to the server's virtual address (an ipaddress object).

    Uses [virtual] server if configured.  Otherwise it parses [virtual]
    network via process_cfg_network (which also sets c.network) and uses
    that network's first host address.
    """
    try:
        server = cfg.get('virtual', 'server')
    except NoOptionError:
        process_cfg_network()
        c.server = next(c.network.hosts())
    else:
        # Parse the configured value so c.server has the same type
        # (IPv4Address/IPv6Address) as the computed default, rather than
        # being a plain string in this branch only.
        c.server = ipaddr(server)
328
class ServerAddr():
    """One address/port the server listens on.

    Attributes: port (int) and addr (IPv4Address or IPv6Address, parsed
    from *addrspec*).
    """

    def __init__(self, port, addrspec):
        self.port = port
        try:
            self.addr = ipaddress.IPv4Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
            self._inurl = b'%s'
        except AddressValueError:
            self.addr = ipaddress.IPv6Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
            # IPv6 literals must be bracketed inside URLs.
            self._inurl = b'[%s]'

    def make_endpoint(self):
        """Return a Twisted server endpoint bound to this address and port."""
        # The TCP{4,6}ServerEndpoint constructors are (reactor, port,
        # backlog=50, interface=...).  Passing the address positionally made
        # it the backlog and left the listener on the wildcard address.
        return self._endpointfactory(reactor, self.port,
                                     interface=str(self.addr))

    def url(self):
        """Return this address's http:// URL as bytes (':port' omitted for 80)."""
        url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
        if self.port != 80:
            url += b':%d' % self.port
        url += b'/'
        return url
348
def process_cfg_saddrs():
    """Build c.saddrs: one ServerAddr per whitespace-separated [server] addrs entry.

    All entries share [server] port, which falls back to 80 if absent.
    """
    port = 80
    try:
        port = cfg.getint('server', 'port')
    except NoOptionError:
        pass

    c.saddrs = [ ]
    for spec in cfg.get('server', 'addrs').split():
        c.saddrs.append(ServerAddr(port, spec))
357
def process_cfg_clients(constructor):
    """Call constructor(addr, section, password) for each client section.

    A client section is any section whose name contains ':' or '.' (i.e.
    looks like an IPv6/IPv4 address).  Its name is parsed with ipaddr(),
    and its 'password' option is passed UTF-8 encoded.  c.clients is reset
    to an empty list first; presumably the constructor populates it.
    """
    c.clients = [ ]
    for section in cfg.sections():
        if ':' not in section and '.' not in section:
            continue
        addr = ipaddr(section)
        password = cfg.get(section, 'password').encode('utf-8')
        constructor(addr, section, password)
366
367 #---------- startup ----------
368
def common_startup():
    """Start logging, parse the command line and load the configuration.

    Logging goes to stderr as classic text.  crash_on_critical is also
    registered as an observer, so critical events stop the reactor.  The
    only option is -c/--config (default /etc/hippotat/config), and
    positional arguments are rejected.  Built-in defaults (defcfg) are
    loaded first, then the config file, which must be readable.
    """
    log_formatter = twisted.logger.formatEventAsClassicLogText
    log_observer = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
    twisted.logger.globalLogBeginner.beginLoggingTo(
        [ log_observer, crash_on_critical ]
    )

    optparser.add_option('-c', '--config', dest='configfile',
                         default='/etc/hippotat/config')
    (opts, args) = optparser.parse_args()
    if args:
        optparser.error('no non-option arguments please')

    # defcfg has trailing '#' comments after values, which ConfigParser
    # (no inline_comment_prefixes by default) would keep in the values.
    comment_re = regexp.compile('#.*')
    cfg.read_string(comment_re.sub('', defcfg))

    # ConfigParser.read() silently skips files it cannot open and returns
    # the list it did read; without this check a missing or unreadable
    # config file would leave only the built-in defaults in effect.
    if not cfg.read(opts.configfile):
        optparser.error('cannot read config file %s' % opts.configfile)
384
def common_run():
    """Run the reactor unless crash() has already fired.

    'CRASHED (end)' is written to stderr whenever this returns, however
    the reactor stopped.
    """
    log_debug(DBG.INIT, 'entering reactor')
    already_crashed = _crashing
    if not already_crashed:
        reactor.run()
    print('CRASHED (end)', file=sys.stderr)