8563834b5aca1ff16b48561d7bdba61f0034b9c1
[hippotat] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7 import os
8
9 from zope.interface import implementer
10
11 import twisted
12 from twisted.internet import reactor
13 import twisted.internet.endpoints
14 import twisted.logger
15 from twisted.logger import LogLevel
16 import twisted.python.constants
17 from twisted.python.constants import NamedConstant
18
19 import ipaddress
20 from ipaddress import AddressValueError
21
22 from optparse import OptionParser
23 import configparser
24 from configparser import ConfigParser
25 from configparser import NoOptionError
26
27 from functools import partial
28
29 import collections
30 import time
31 import codecs
32 import traceback
33
34 import re as regexp
35
36 import hippotat.slip as slip
37
class DBG(twisted.python.constants.Names):
    """Debug flags that can be enabled individually for log_debug().

    The declaration order matters: common_startup() compares flags with
    `<=` to select "this flag and everything more interesting", which
    relies on Twisted ordering constants by definition order.  Do not
    reorder these without checking that code.
    """
    INIT = NamedConstant()
    CONFIG = NamedConstant()
    ROUTE = NamedConstant()
    DROP = NamedConstant()
    FLOW = NamedConstant()
    HTTP = NamedConstant()
    TWISTED = NamedConstant()
    QUEUE = NamedConstant()
    HTTP_CTRL = NamedConstant()
    QUEUE_CTRL = NamedConstant()
    HTTP_FULL = NamedConstant()
    CTRL_DUMP = NamedConstant()
    SLIP_FULL = NamedConstant()
    # When enabled, log_debug() dumps data in full instead of truncating it.
    DATA_COMPLETE = NamedConstant()
53
# Encoder used by log_debug() to render binary data as hex digits.
_hex_codec = codecs.getencoder('hex_codec')

#---------- logging ----------

# sys.stderr as it was when this module was imported; used for
# reporting failures inside the log filter itself.
org_stderr = sys.stderr

log = twisted.logger.Logger()

# Set of DBG flags currently enabled; filled in by the option callbacks
# in common_startup().
debug_set = set()
# Plain -D enables this flag and every flag declared before it.
debug_def_detail = DBG.HTTP
64
def log_debug(dflag, msg, idof=None, d=None):
    """Log `msg` at info level if the debug flag `dflag` is enabled.

    dflag -- a DBG constant; nothing happens unless it is in debug_set.
    msg   -- message text.
    idof  -- optional object; its id() is prefixed to the message in hex.
    d     -- optional bytes to append as a hex dump.  The dump is cut to
             the first 64 bytes (and marked `...') unless
             DBG.DATA_COMPLETE is enabled.
    """
    if dflag not in debug_set:
        return
    #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
    if idof is not None:
        msg = '[%#x] %s' % (id(idof), msg)
    if d is not None:
        suffix = ''
        if DBG.DATA_COMPLETE not in debug_set and len(d) > 64:
            d = d[0:64]
            suffix = '...'
        hexdump = _hex_codec(d)[0].decode('ascii')
        msg += ' ' + hexdump + suffix
    log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
79
@implementer(twisted.logger.ILogFilterPredicate)
class LogNotBoringTwisted:
    """Log filter predicate that suppresses uninteresting info-level events.

    Events at any level other than info always pass.  An info event
    passes if its `dflag' is enabled in debug_set, or if it has no
    `dflag' at all and DBG.TWISTED is enabled.
    """
    def __call__(self, event):
        accept = twisted.logger.PredicateResult.yes
        reject = twisted.logger.PredicateResult.no
        try:
            if event.get('log_level') != LogLevel.info:
                return accept
            dflag = event.get('dflag')
            if dflag in debug_set:
                return accept
            if dflag is None and DBG.TWISTED in debug_set:
                return accept
        except Exception:
            # A broken filter must not lose messages: report the problem
            # on the original stderr and let the event through.
            print(traceback.format_exc(), file=org_stderr)
            return accept
        return reject
95
96 #---------- default config ----------
97
# Built-in configuration, loaded by common_startup() before any config
# file (after `#...' comments have been stripped from it).
defcfg = '''
[DEFAULT]
max_batch_down = 65536
max_queue_time = 10
target_requests_outstanding = 3
http_timeout = 30
http_timeout_grace = 5
max_requests_outstanding = 6
max_batch_up = 4000
http_retry = 5

#[server] or [<client>] overrides
ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s

# relating to virtual network
mtu = 1500

[server]
# addrs = 127.0.0.1 ::1
port = 80
# url

# relating to virtual network
routes = ''
vnetwork = 172.24.230.192
# network = <prefix>/<len>
# server = <ipaddr>
# relay = <ipaddr>


# [<client-ip4-or-ipv6-address>]
# password = <password> # used by both, must match

[limits]
max_batch_down = 262144
max_queue_time = 121
http_timeout = 121
target_requests_outstanding = 10
'''

# these need to be defined here so that they can be imported by import *
cfg = ConfigParser(strict=False)
optparser = OptionParser()
141
# Byte translation table exchanging `-' with the SLIP ESC byte.
_mimetrans = bytes.maketrans(b'-' + slip.esc, slip.esc + b'-')

def mime_translate(s):
    """Return `s` with every `-' and SLIP ESC byte exchanged.

    SLIP-encoded packets cannot contain ESC ESC, so after the swap the
    result cannot contain `--'.  Applying the function twice gives back
    the original data.
    """
    swapped = s.translate(_mimetrans)
    return swapped
147
class ConfigResults:
    """Plain attribute bag holding processed configuration values.

    d -- optional dict to use directly as the instance's attribute
         dictionary (it is not copied, so changes are visible both
         ways).  When omitted, each instance gets its own fresh dict.
    """
    def __init__(self, d=None):
        # A literal `{ }' default would be created once and shared by
        # every default-constructed instance, so attributes set on one
        # would leak into all the others.
        self.__dict__ = {} if d is None else d

    def __repr__(self):
        return 'ConfigResults('+repr(self.__dict__)+')'

# The module-wide configuration results, filled in by process_cfg_*.
c = ConfigResults()
155
def log_discard(packet, iface, saddr, daddr, why):
    """Record (under DBG.DROP) that `packet` was discarded, and why.

    iface, saddr and daddr identify where the packet came from and its
    addresses; the packet itself is attached as a hex dump.
    """
    summary = 'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why)
    log_debug(DBG.DROP, summary, d=packet)
160
161 #---------- packet parsing ----------
162
def packet_addrs(packet):
    """Extract the source and destination addresses from an IP packet.

    packet -- the raw packet as bytes, starting at the IP header.

    Returns (saddr, daddr) as ipaddress.IPv4Address or IPv6Address
    objects according to the version nibble in the first byte.  Raises
    ValueError if the version is neither 4 nor 6.
    """
    # version -> (address length, offset of source address, constructor)
    layouts = {
        4: (4, 3*4, ipaddress.IPv4Address),
        6: (16, 2*4, ipaddress.IPv6Address),
    }
    version = packet[0] >> 4
    layout = layouts.get(version)
    if layout is None:
        raise ValueError('unsupported IP version %d' % version)
    addrlen, saddroff, factory = layout
    # The destination address immediately follows the source address.
    daddroff = saddroff + addrlen
    saddr = factory(packet[saddroff:daddroff])
    daddr = factory(packet[daddroff:daddroff + addrlen])
    return (saddr, daddr)
178
179 #---------- address handling ----------
180
def ipaddr(input):
    """Parse `input` as an IP address, trying IPv4 first and then IPv6.

    Returns an ipaddress.IPv4Address or IPv6Address.  If `input` is
    neither, the AddressValueError from the IPv6 attempt propagates.
    """
    try:
        return ipaddress.IPv4Address(input)
    except AddressValueError:
        return ipaddress.IPv6Address(input)
187
def ipnetwork(input):
    """Parse `input` as an IP network, trying IPv4 first and then IPv6.

    Returns an ipaddress.IPv4Network or IPv6Network.  If the address
    part is not valid IPv4, the input is retried as IPv6 and any error
    from that attempt propagates.  Other IPv4 errors (for example a bad
    prefix length on a valid IPv4 address) propagate unchanged; all of
    them are ValueError subclasses.
    """
    try:
        r = ipaddress.IPv4Network(input)
    except AddressValueError:
        # The ipaddress module has no `NetworkValueError'; a non-IPv4
        # address part is reported as AddressValueError, same as ipaddr().
        r = ipaddress.IPv6Network(input)
    return r
194
195 #---------- ipif (SLIP) subprocess ----------
196
class SlipStreamDecoder():
    """Reassemble packets from a SLIP byte stream delivered in chunks.

    desc      -- short description used in debug messages.
    on_packet -- callable invoked with each complete, non-empty packet.
    """
    def __init__(self, desc, on_packet):
        self._desc = desc
        self._on_packet = on_packet
        # Data received so far that does not yet form a complete packet.
        self._buffer = b''
        self._log('__init__')

    def _log(self, msg, **kwargs):
        text = 'slip %s: %s' % (self._desc, msg)
        log_debug(DBG.SLIP_FULL, text, **kwargs)

    def inputdata(self, data):
        """Feed a chunk of stream data, delivering any completed packets."""
        self._log('inputdata', d=data)
        pieces = slip.decode(data)
        # The first piece continues whatever was left over last time;
        # the last piece is kept back as the new leftover.
        pieces[0] = self._buffer + pieces[0]
        self._buffer = pieces.pop()
        for piece in pieces:
            self._maybe_packet(piece)
        self._log('bufremain', d=self._buffer)

    def _maybe_packet(self, packet):
        self._log('maybepacket', d=packet)
        if len(packet) == 0:
            return
        self._on_packet(packet)

    def flush(self):
        """Deliver any buffered data as a final packet and clear the buffer."""
        self._log('flush')
        self._maybe_packet(self._buffer)
        self._buffer = b''
225
class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
    """Process protocol for the ipif subprocess.

    Decodes SLIP data from the child's stdout and passes each packet to
    `router(packet, saddr, daddr)`, except packets with a link-local
    source or destination, which are discarded.
    """
    def __init__(self, router):
        self._router = router
        self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)

    def connectionMade(self):
        pass

    def outReceived(self, data):
        self._decoder.inputdata(data)

    def slip_on_packet(self, packet):
        saddr, daddr = packet_addrs(packet)
        link_local = saddr.is_link_local or daddr.is_link_local
        if not link_local:
            self._router(packet, saddr, daddr)
            return
        log_discard(packet, 'ipif', saddr, daddr, 'link-local')

    def processEnded(self, status):
        # The child going away is treated as an error: re-raise the
        # exception carried by the status Failure.
        status.raiseException()
241
def start_ipif(command, router):
    """Spawn the ipif subprocess and remember its protocol in global `ipif`.

    command -- shell command line, run via `/bin/sh -xc'.
    router  -- callable given (packet, saddr, daddr) for each packet
               read from the subprocess.
    """
    global ipif
    ipif = _IpifProcessProtocol(router)
    argv = ['sh', '-xc', command]
    # fd 0: we write to the child; fd 1: we read from the child;
    # fd 2: the child shares our own fd 2.
    fdmap = {0: 'w', 1: 'r', 2: 2}
    # NOTE(review): env=None presumably makes the child inherit our
    # environment -- confirm against the Twisted spawnProcess docs.
    reactor.spawnProcess(ipif, '/bin/sh', argv, childFDs=fdmap, env=None)
249
def queue_inbound(packet):
    """Send `packet` to the ipif subprocess, SLIP-encoded.

    The encoded packet is written with a delimiter both before and
    after it.
    """
    log_debug(DBG.FLOW, "queue_inbound", d=packet)
    out = ipif.transport
    out.write(slip.delimiter)
    out.write(slip.encode(packet))
    out.write(slip.delimiter)
255
256 #---------- packet queue ----------
257
class PacketQueue():
    """FIFO of packets which are dropped if they wait too long.

    desc           -- non-empty string used as a prefix in debug messages.
    max_queue_time -- maximum age, in time.monotonic() units, a packet
                      may reach before nonempty() discards it.
    """
    def __init__(self, desc, max_queue_time):
        self._desc = desc
        # Fails unless desc is a non-empty str.
        assert(desc + '')
        self._max_queue_time = max_queue_time
        self._pq = collections.deque() # (enqueue time, packet) pairs

    def _log(self, dflag, msg, **kwargs):
        log_debug(dflag, self._desc+' pq: '+msg, **kwargs)

    def append(self, packet):
        """Add `packet` to the back of the queue, timestamped now."""
        self._log(DBG.QUEUE, 'append', d=packet)
        entry = (time.monotonic(), packet)
        self._pq.append(entry)

    def nonempty(self):
        """Discard expired packets from the front; return whether any remain."""
        self._log(DBG.QUEUE, 'nonempty ?')
        while self._pq:
            queuetime, packet = self._pq[0]
            expired = (time.monotonic() - queuetime) > self._max_queue_time
            if not expired:
                self._log(DBG.QUEUE, 'nonempty ? nonempty.')
                return True
            # strip old packets off the front
            self._log(DBG.QUEUE, 'dropping (old)', d=packet)
            self._pq.popleft()
        self._log(DBG.QUEUE, 'nonempty ? empty.')
        return False

    def process(self, sizequery, moredata, max_batch):
        """Move queued packets, SLIP-encoded, into a batch.

        sizequery() should return size of batch so far
        moredata(s) should add s to batch

        Packets are separated by slip.delimiter.  Processing stops when
        the queue is empty or when the next packet would take a
        non-empty batch over max_batch; the first packet of an empty
        batch is always added, whatever its size.
        """
        self._log(DBG.QUEUE, 'process...')
        while self._pq:
            _queuetime, packet = self._pq[0]

            self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)

            encoded = slip.encode(packet)
            sofar = sizequery()

            self._log(DBG.QUEUE_CTRL,
                      'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
                      d=encoded)

            if sofar > 0:
                needed = sofar + len(slip.delimiter) + len(encoded)
                if needed > max_batch:
                    self._log(DBG.QUEUE_CTRL, 'process... overflow')
                    break
                moredata(slip.delimiter)

            moredata(encoded)
            self._pq.popleft()
        else:
            # Only reached when the loop ran the queue dry (no break).
            self._log(DBG.QUEUE, 'process... empty')
317
318 #---------- error handling ----------
319
# Set once crash() has been called; common_run() then does not start
# the reactor.
_crashing = False

def crash(err):
    """Report a fatal error on stderr and stop the reactor.

    err -- whatever describes the failure; it is printed as-is.
    It is harmless to call this when the reactor is not running.
    """
    global _crashing
    _crashing = True
    banner_top = '========== CRASH =========='
    banner_end = '==========================='
    print(banner_top, err, banner_end, file=sys.stderr)
    try:
        reactor.stop()
    except twisted.internet.error.ReactorNotRunning:
        pass
329
def crash_on_defer(defer):
    """Arrange for a failure of Deferred `defer` to crash the program."""
    def _on_failure(err):
        return crash(err)
    defer.addErrback(_on_failure)
332
def crash_on_critical(event):
    """Log observer: crash the program on any event at critical level."""
    level = event.get('log_level')
    if level >= LogLevel.critical:
        description = twisted.logger.formatEvent(event)
        crash(description)
336
337 #---------- config processing ----------
338
def process_cfg_common_always():
    """Copy settings needed by every program from `cfg` into `c`.

    Sets c.mtu to the `mtu' option of the [virtual] section, as the
    string returned by ConfigParser.get (it is not converted to int).
    """
    # The former `global mtu' declaration was dead: nothing here
    # assigns a module-level `mtu', the value is stored on `c'.
    c.mtu = cfg.get('virtual','mtu')
342
def process_cfg_ipif(section, varmap):
    """Compute c.ipif_command from the `ipif' option of `section`.

    varmap -- iterable of (dest, src) attribute-name pairs.  For each
              pair, c.<src> is copied to c.<dest> if c.<src> exists, so
              the option's %(...)s substitutions can refer to <dest>.

    All attributes of `c` are available for interpolation.
    """
    missing = object()
    for dest, src in varmap:
        value = getattr(c, src, missing)
        if value is missing:
            continue
        setattr(c, dest, value)

    #print(repr(c))

    c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
352
def process_cfg_network():
    """Set c.network from the `network' option of the [virtual] section.

    Raises ValueError if the network has fewer than 3 + 2 addresses.
    """
    spec = cfg.get('virtual','network')
    c.network = ipnetwork(spec)
    # NOTE(review): the meaning of the 3 and the 2 is not stated here;
    # since network sizes are powers of two, the smallest size that
    # passes is 8, which matches the message.
    too_small = c.network.num_addresses < 3 + 2
    if too_small:
        raise ValueError('network needs at least 2^3 addresses')
357
def process_cfg_server():
    """Set c.server, the server's address on the virtual network.

    Uses the `server' option of the [virtual] section if present (kept
    as the configured string).  Otherwise processes the network setting
    and uses the first host address of c.network (an ipaddress object).
    """
    try:
        chosen = cfg.get('virtual','server')
    except NoOptionError:
        process_cfg_network()
        chosen = next(c.network.hosts())
    c.server = chosen
364
class ServerAddr():
    """One address/port on which the server listens.

    port     -- TCP port number.
    addrspec -- IPv4 or IPv6 address as a string; anything that is not
                a valid IPv4 address is parsed as IPv6, and an error
                from that attempt propagates.

    Attributes: `port`, and `addr` (an ipaddress address object).
    """
    def __init__(self, port, addrspec):
        self.port = port
        # also self.addr
        try:
            self.addr = ipaddress.IPv4Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
            self._inurl = b'%s'
        except AddressValueError:
            self.addr = ipaddress.IPv6Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
            # IPv6 literals must be bracketed inside a URL.
            self._inurl = b'[%s]'

    def make_endpoint(self):
        """Return a Twisted server endpoint bound to this address and port."""
        # The address must go in by keyword: the third positional
        # parameter of the TCP server endpoints is `backlog', not the
        # interface to bind to.
        return self._endpointfactory(reactor, self.port,
                                     interface=str(self.addr))

    def url(self):
        """Return the http URL for this address, as bytes.

        The port is included only when it is not 80.
        """
        url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
        if self.port != 80: url += b':%d' % self.port
        url += b'/'
        return url
384
def process_cfg_saddrs():
    """Set c.saddrs to a ServerAddr for each address in [server] `addrs'.

    `addrs' is a whitespace-separated list.  All addresses share the
    [server] `port' option, defaulting to 80 when it is absent.
    """
    try:
        port = cfg.getint('server','port')
    except NoOptionError:
        port = 80

    c.saddrs = [ ]
    addrspecs = cfg.get('server','addrs').split()
    for addrspec in addrspecs:
        c.saddrs.append(ServerAddr(port, addrspec))
393
def process_cfg_clients(constructor):
    """Call `constructor` for every client section in the configuration.

    A section counts as a client if its name contains `:' or `.'; the
    name is then parsed as an IP address.  constructor is called as
    constructor(address, section_name, password_bytes), the password
    being the section's `password' option encoded as UTF-8.

    c.clients is reset to an empty list first; this function does not
    itself add anything to it.
    """
    c.clients = [ ]
    for cs in cfg.sections():
        looks_like_address = ':' in cs or '.' in cs
        if not looks_like_address:
            continue
        ci = ipaddr(cs)
        pw = cfg.get(cs, 'password').encode('utf-8')
        constructor(ci,cs,pw)
402
403 #---------- startup ----------
404
def common_startup(process_cfg):
    """Parse the command line, load configuration and start logging.

    process_cfg -- callable taking no arguments, invoked once all
                   configuration has been read into `cfg`.  If it raises
                   configparser.Error or ValueError, the traceback is
                   printed and the program exits with status 12.

    Steps, in order: load the built-in defaults (defcfg); register and
    parse the command line options (-D, --debug-select, -c), which read
    config files and adjust debug_set as they are encountered; if no -c
    was given, read /etc/hippotat/config and /etc/hippotat/config.d if
    they exist; call process_cfg; finally begin Twisted logging, with
    events that pass the LogLevel.error predicate going to stderr and
    the rest to stdout, filtered by LogNotBoringTwisted, and with
    crash_on_critical installed as an observer.
    """
    # ConfigParser hates #-comments after values
    trailingcomments_re = regexp.compile(r'#.*')
    cfg.read_string(trailingcomments_re.sub('', defcfg))
    need_defcfg = True

    def readconfig(pathname, mandatory=True):
        # Read `pathname` into cfg.  A directory has each suitably
        # named entry read; anything else is read as a single file.
        # A missing path is an error only if `mandatory`.
        def log(m, p=pathname):
            if DBG.CONFIG not in debug_set: return
            # Report the path this message is about, which for
            # directory entries is the entry rather than the directory.
            print('DBG.CONFIG: %s: %s' % (m, p))

        try:
            files = os.listdir(pathname)

        except FileNotFoundError:
            if mandatory: raise
            log('skipped')
            return

        except NotADirectoryError:
            cfg.read(pathname)
            log('read file')
            return

        # is a directory
        log('directory')
        # Entries with any character outside this set are ignored.
        unsuitable_re = regexp.compile(r'[^-A-Za-z0-9_]')
        for f in files:
            if unsuitable_re.search(f): continue
            subpath = pathname + '/' + f
            try:
                os.stat(subpath)
            except FileNotFoundError:
                log('entry skipped', subpath)
                continue
            cfg.read(subpath)
            log('entry read', subpath)

    def oc_config(od,os, value, op):
        # -c: read the named file or directory; it must exist, and it
        # replaces the default config locations.
        nonlocal need_defcfg
        need_defcfg = False
        readconfig(value)

    def dfs_less_detailed(dl):
        # All flags declared at or before `dl`.
        return [df for df in DBG.iterconstants() if df <= dl]

    def ds_default(od,os,dl,op):
        # -D: replace the debug set with the default selection.
        global debug_set
        debug_set = set(dfs_less_detailed(debug_def_detail))

    def ds_select(od,os, spec, op):
        # --debug-select: comma-separated items, each [-]DFLAG[+] or [-]+
        for it in spec.split(','):

            if it.startswith('-'):
                mutator = debug_set.discard
                it = it[1:]
            else:
                mutator = debug_set.add

            if it == '+':
                dfs = DBG.iterconstants()

            else:
                if it.endswith('+'):
                    mapper = dfs_less_detailed
                    it = it[0:len(it)-1]
                else:
                    mapper = lambda x: [x]

                try:
                    dfspec = DBG.lookupByName(it)
                except ValueError:
                    # optparser.error() exits, so dfspec is always set below.
                    optparser.error('unknown debug flag %s in --debug-select' % it)

                dfs = mapper(dfspec)

            for df in dfs:
                mutator(df)

    optparser.add_option('-D', '--debug',
                         nargs=0,
                         action='callback',
                         help='enable default debug (to stdout)',
                         callback= ds_default)

    optparser.add_option('--debug-select',
                         nargs=1,
                         type='string',
                         metavar='[-]DFLAG[+]|[-]+,...',
                         help=
'''enable (`-': disable) each specified DFLAG;
`+': do same for all "more interesting" DFLAGSs;
just `+': all DFLAGs.
DFLAGS: ''' + ' '.join([df.name for df in DBG.iterconstants()]),
                         action='callback',
                         callback= ds_select)

    optparser.add_option('-c', '--config',
                         nargs=1,
                         type='string',
                         metavar='CONFIGFILE',
                         dest='configfile',
                         action='callback',
                         callback= oc_config)

    (opts, args) = optparser.parse_args()
    if len(args): optparser.error('no non-option arguments please')

    if need_defcfg:
        readconfig('/etc/hippotat/config', False)
        readconfig('/etc/hippotat/config.d', False)

    try: process_cfg()
    except (configparser.Error, ValueError):
        traceback.print_exc(file=sys.stderr)
        print('\nInvalid configuration, giving up.', file=sys.stderr)
        sys.exit(12)

    #print(repr(debug_set), file=sys.stderr)

    log_formatter = twisted.logger.formatEventAsClassicLogText
    stdout_obs = twisted.logger.FileLogObserver(sys.stdout, log_formatter)
    stderr_obs = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
    pred = twisted.logger.LogLevelFilterPredicate(LogLevel.error)
    # Events accepted by `pred` go to stderr, everything else to stdout.
    stdsomething_obs = twisted.logger.FilteringLogObserver(
        stderr_obs, [pred], stdout_obs
    )
    log_observer = twisted.logger.FilteringLogObserver(
        stdsomething_obs, [LogNotBoringTwisted()]
    )
    #log_observer = stdsomething_obs
    twisted.logger.globalLogBeginner.beginLoggingTo(
        [ log_observer, crash_on_critical ]
    )
539
def common_run():
    """Run the reactor until it stops, unless a crash already happened.

    Once the reactor has stopped (or was never started), `CRASHED (end)'
    is printed on stderr.
    """
    log_debug(DBG.INIT, 'entering reactor')
    already_crashed = _crashing
    if not already_crashed:
        reactor.run()
    print('CRASHED (end)', file=sys.stderr)