hippotat/__init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7 import os
8
9 from zope.interface import implementer
10
11 import twisted
12 from twisted.internet import reactor, error, protocol
13 import twisted.internet.endpoints
14 import twisted.logger
15 from twisted.logger import LogLevel
16 import twisted.python.constants
17 from twisted.python.constants import NamedConstant
18
19 import ipaddress
20 from ipaddress import AddressValueError, NetmaskValueError
21
22 from optparse import OptionParser
23 import configparser
24 from configparser import ConfigParser
25 from configparser import NoOptionError
26
27 from functools import partial
28
29 import collections
30 import time
31 import codecs
32 import traceback
33
34 import re as regexp
35
36 import hippotat.slip as slip
37
38 class DBG(twisted.python.constants.Names):
39 INIT = NamedConstant()
40 CONFIG = NamedConstant()
41 ROUTE = NamedConstant()
42 DROP = NamedConstant()
43 FLOW = NamedConstant()
44 HTTP = NamedConstant()
45 TWISTED = NamedConstant()
46 QUEUE = NamedConstant()
47 HTTP_CTRL = NamedConstant()
48 QUEUE_CTRL = NamedConstant()
49 HTTP_FULL = NamedConstant()
50 CTRL_DUMP = NamedConstant()
51 SLIP_FULL = NamedConstant()
52 DATA_COMPLETE = NamedConstant()
53
54 _hex_codec = codecs.getencoder('hex_codec')
55
56 #---------- logging ----------
57
58 org_stderr = sys.stderr
59
60 log = twisted.logger.Logger()
61
62 debug_set = set()
63 debug_def_detail = DBG.HTTP
64
65 def log_debug(dflag, msg, idof=None, d=None):
66 if dflag not in debug_set: return
67 #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
68 if idof is not None:
69 msg = '[%#x] %s' % (id(idof), msg)
70 if d is not None:
71 trunc = ''
72 if not DBG.DATA_COMPLETE in debug_set:
73 if len(d) > 64:
74 d = d[0:64]
75 trunc = '...'
76 d = _hex_codec(d)[0].decode('ascii')
77 msg += ' ' + d + trunc
78 log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
79
80 @implementer(twisted.logger.ILogFilterPredicate)
81 class LogNotBoringTwisted:
82 def __call__(self, event):
83 yes = twisted.logger.PredicateResult.yes
84 no = twisted.logger.PredicateResult.no
85 try:
86 if event.get('log_level') != LogLevel.info:
87 return yes
88 dflag = event.get('dflag')
89 if dflag is False : return yes
90 if dflag in debug_set: return yes
91 if dflag is None and DBG.TWISTED in debug_set: return yes
92 return no
93 except Exception:
94 print(traceback.format_exc(), file=org_stderr)
95 return yes
96
97 #---------- default config ----------
98
99 defcfg = '''
100 [DEFAULT]
101 max_batch_down = 65536
102 max_queue_time = 10
103 target_requests_outstanding = 3
104 http_timeout = 30
105 http_timeout_grace = 5
106 max_requests_outstanding = 6
107 max_batch_up = 4000
108 http_retry = 5
109 port = 80
110 vroutes = ''
111
112 #[server] or [<client>] overrides
113 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
114
115 # relating to virtual network
116 mtu = 1500
117
118 [SERVER]
119 server = SERVER
120 # addrs = 127.0.0.1 ::1
121 # url
122
123 # relating to virtual network
124 vvnetwork = 172.24.230.192
125 # vnetwork = <prefix>/<len>
126 # vaddr = <ipaddr>
127 # vrelay = <ipaddr>
128
129
130 # [<client-ip4-or-ipv6-address>]
131 # password = <password> # used by both, must match
132
133 [LIMIT]
134 max_batch_down = 262144
135 max_queue_time = 121
136 http_timeout = 121
137 target_requests_outstanding = 10
138 '''
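# Illustrative only: a hypothetical configuration fragment layered on top of
# the defaults above (the server URL and the addresses are made up):
#
#   [SERVER]
#   url = http://hippotat.example.org/
#   vnetwork = 172.24.230.192/28
#
#   [172.24.230.194]
#   password = examplepassword
#   target_requests_outstanding = 4
#
# Per-client values handled by cfg_process_client_limited (such as
# target_requests_outstanding above) are clipped against [LIMIT].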
139
140 # these need to be defined here so that they can be imported by import *
141 cfg = ConfigParser(strict=False)
142 optparser = OptionParser()
143
144 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
145 def mime_translate(s):
146 # SLIP-encoded packets cannot contain ESC ESC.
147 # Swap `-' and ESC. The result cannot contain `--'
148 return s.translate(_mimetrans)
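# For example (assuming slip.esc is the usual SLIP ESC byte, 0xdb): the swap
# turns b'\xdbab-cd' into b'-ab\xdbcd', and because it is a pure byte swap it
# is its own inverse:
#   mime_translate(mime_translate(s)) == s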
149
150 class ConfigResults:
151 def __init__(self):
152 pass
153 def __repr__(self):
154 return 'ConfigResults('+repr(self.__dict__)+')'
155
156 def log_discard(packet, iface, saddr, daddr, why):
157 log_debug(DBG.DROP,
158 'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
159 d=packet)
160
161 #---------- packet parsing ----------
162
163 def packet_addrs(packet):
164 version = packet[0] >> 4
165 if version == 4:
166 addrlen = 4
167 saddroff = 3*4
168 factory = ipaddress.IPv4Address
169 elif version == 6:
170 addrlen = 16
171 saddroff = 2*4
172 factory = ipaddress.IPv6Address
173 else:
174 raise ValueError('unsupported IP version %d' % version)
175 saddr = factory(packet[ saddroff : saddroff + addrlen ])
176 daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
177 return (saddr, daddr)
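# The offsets above follow the standard IP header layout: for IPv4 the source
# address is at bytes 12..15 and the destination at 16..19; for IPv6 they are
# bytes 8..23 and 24..39.  Illustrative example (a minimal IPv4 header with a
# dummy checksum):
#   packet_addrs(bytes.fromhex('450000540000400040010000' 'c0a80001' 'c0a80002'))
#     == (IPv4Address('192.168.0.1'), IPv4Address('192.168.0.2'))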
178
179 #---------- address handling ----------
180
181 def ipaddr(input):
182 try:
183 r = ipaddress.IPv4Address(input)
184 except AddressValueError:
185 r = ipaddress.IPv6Address(input)
186 return r
187
188 def ipnetwork(input):
189 try:
190 r = ipaddress.IPv4Network(input)
191 except (AddressValueError, NetmaskValueError):
192 r = ipaddress.IPv6Network(input)
193 return r
194
195 #---------- ipif (SLIP) subprocess ----------
196
197 class SlipStreamDecoder():
198 def __init__(self, desc, on_packet):
199 self._buffer = b''
200 self._on_packet = on_packet
201 self._desc = desc
202 self._log('__init__')
203
204 def _log(self, msg, **kwargs):
205 log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)
206
207 def inputdata(self, data):
208 self._log('inputdata', d=data)
209 packets = slip.decode(data)
210 packets[0] = self._buffer + packets[0]
211 self._buffer = packets.pop()
212 for packet in packets:
213 self._maybe_packet(packet)
214 self._log('bufremain', d=self._buffer)
215
216 def _maybe_packet(self, packet):
217 self._log('maybepacket', d=packet)
218 if len(packet):
219 self._on_packet(packet)
220
221 def flush(self):
222 self._log('flush')
223 self._maybe_packet(self._buffer)
224 self._buffer = b''
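# Usage sketch (callback and data are hypothetical): the decoder reassembles
# SLIP frames that may arrive split across arbitrary read boundaries, calling
# on_packet once per complete non-empty frame:
#   d = SlipStreamDecoder('example', lambda p: print('packet', p))
#   d.inputdata(b'E\x00')          # first part of a frame
#   d.inputdata(b'\x00\x28\xc0')   # rest of it, ending with the SLIP delimiter
#   d.flush()                      # emit any trailing undelimited frame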
225
226 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
227 def __init__(self, router):
228 self._router = router
229 self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)
230 def connectionMade(self): pass
231 def outReceived(self, data):
232 self._decoder.inputdata(data)
233 def slip_on_packet(self, packet):
234 (saddr, daddr) = packet_addrs(packet)
235 if saddr.is_link_local or daddr.is_link_local:
236 log_discard(packet, 'ipif', saddr, daddr, 'link-local')
237 return
238 self._router(packet, saddr, daddr)
239 def processEnded(self, status):
240 status.raiseException()
241
242 def start_ipif(command, router):
243 global ipif
244 ipif = _IpifProcessProtocol(router)
245 reactor.spawnProcess(ipif,
246 '/bin/sh',['sh','-xc', command],
247 childFDs={0:'w', 1:'r', 2:2},
248 env=None)
249
250 def queue_inbound(packet):
251 log_debug(DBG.FLOW, "queue_inbound", d=packet)
252 ipif.transport.write(slip.delimiter)
253 ipif.transport.write(slip.encode(packet))
254 ipif.transport.write(slip.delimiter)
255
256 #---------- packet queue ----------
257
258 class PacketQueue():
259 def __init__(self, desc, max_queue_time):
260 self._desc = desc
261 assert(desc + '')
262 self._max_queue_time = max_queue_time
263 self._pq = collections.deque() # packets
264
265 def _log(self, dflag, msg, **kwargs):
266 log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
267
268 def append(self, packet):
269 self._log(DBG.QUEUE, 'append', d=packet)
270 self._pq.append((time.monotonic(), packet))
271
272 def nonempty(self):
273 self._log(DBG.QUEUE, 'nonempty ?')
274 while True:
275 try: (queuetime, packet) = self._pq[0]
276 except IndexError:
277 self._log(DBG.QUEUE, 'nonempty ? empty.')
278 return False
279
280 age = time.monotonic() - queuetime
281 if age > self._max_queue_time:
282 # strip old packets off the front
283 self._log(DBG.QUEUE, 'dropping (old)', d=packet)
284 self._pq.popleft()
285 continue
286
287 self._log(DBG.QUEUE, 'nonempty ? nonempty.')
288 return True
289
290 def process(self, sizequery, moredata, max_batch):
291 # sizequery() should return size of batch so far
292 # moredata(s) should add s to batch
293 self._log(DBG.QUEUE, 'process...')
294 while True:
295 try: (dummy, packet) = self._pq[0]
296 except IndexError:
297 self._log(DBG.QUEUE, 'process... empty')
298 break
299
300 self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
301
302 encoded = slip.encode(packet)
303 sofar = sizequery()
304
305 self._log(DBG.QUEUE_CTRL,
306 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
307 d=encoded)
308
309 if sofar > 0:
310 if sofar + len(slip.delimiter) + len(encoded) > max_batch:
311 self._log(DBG.QUEUE_CTRL, 'process... overflow')
312 break
313 moredata(slip.delimiter)
314
315 moredata(encoded)
316 self._pq.popleft()
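# Usage sketch (illustrative): a caller drains the queue into a size-limited
# batch by passing two callbacks that share a buffer:
#   batch = bytearray()
#   pq.process(lambda: len(batch), batch.extend, max_batch)
# Packets are SLIP-encoded and separated by slip.delimiter; a packet that
# would push the batch past max_batch is left on the queue for next time.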
317
318 #---------- error handling ----------
319
320 _crashing = False
321
322 def crash(err):
323 global _crashing
324 _crashing = True
325 print('========== CRASH ==========', err,
326 '===========================', file=sys.stderr)
327 try: reactor.stop()
328 except twisted.internet.error.ReactorNotRunning: pass
329
330 def crash_on_defer(defer):
331 defer.addErrback(lambda err: crash(err))
332
333 def crash_on_critical(event):
334 if event.get('log_level') >= LogLevel.critical:
335 crash(twisted.logger.formatEvent(event))
336
337 #---------- config processing ----------
338
339 def _cfg_process_putatives():
340 servers = { }
341 clients = { }
342 # maps from abstract object to canonical name for cs's
343
344 def putative(cmap, abstract, canoncs):
345 try:
346 current_canoncs = cmap[abstract]
347 except KeyError:
348 pass
349 else:
350 assert(current_canoncs == canoncs)
351 cmap[abstract] = canoncs
352
353 server_pat = r'[-.0-9A-Za-z]+'
354 client_pat = r'[.:0-9a-f]+'
355 server_re = regexp.compile(server_pat)
356 serverclient_re = regexp.compile(server_pat + r' ' + client_pat)
357
358 for cs in cfg.sections():
359 if cs == 'LIMIT':
360 # plan A "[LIMIT]"
361 continue
362
363 try:
364 # plan B "[<client>]" part 1
365 ci = ipaddr(cs)
366 except AddressValueError:
367
368 if server_re.fullmatch(cs):
369 # plan C "[<servername>]"
370 putative(servers, cs, cs)
371 continue
372
373 if serverclient_re.fullmatch(cs):
374 # plan D "[<servername> <client>]" part 1
375 (pss,pcs) = cs.split(' ')
376
377 if pcs == 'LIMIT':
378 # plan E "[<servername> LIMIT]"
379 continue
380
381 try:
382 # plan D "[<servername> <client>]" part 2
383 ci = ipaddr(pcs)
384 except AddressValueError:
385 # plan F "[<some thing we do not understand>]"
386 # well, we ignore this
387 print('warning: ignoring config section %s' % cs, file=sys.stderr)
388 continue
389
390 else: # no AddressValueError
391 # plan D "[<servername> <client>]" part 3
392 putative(clients, ci, pcs)
393 putative(servers, pss, pss)
394 continue
395
396 else: # no AddressValueError
397 # plan B "[<client>]" part 2
398 putative(clients, ci, cs)
399 continue
400
401 return (servers, clients)
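# For example (hypothetical section names), the "plans" above classify config
# section headers roughly like this:
#   [LIMIT]                     plan A: global limits, skipped here
#   [172.24.230.194]            plan B: a client, keyed by its IP address
#   [fooserver]                 plan C: a server name
#   [fooserver 172.24.230.194]  plan D: a per-server client section
#   [fooserver deadbeef]        plan F: shaped like plan D but the client part
#                               is not an address, so it is warned about and ignored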
402
403 def cfg_process_common(c, ss):
404 c.mtu = cfg.getint(ss, 'mtu')
405
406 def cfg_process_saddrs(c, ss):
407 class ServerAddr():
408 def __init__(self, port, addrspec):
409 self.port = port
410 # also self.addr
411 try:
412 self.addr = ipaddress.IPv4Address(addrspec)
413 self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
414 self._inurl = b'%s'
415 except AddressValueError:
416 self.addr = ipaddress.IPv6Address(addrspec)
417 self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
418 self._inurl = b'[%s]'
419 def make_endpoint(self):
420 return self._endpointfactory(reactor, self.port, self.addr)
421 def url(self):
422 url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
423 if self.port != 80: url += b':%d' % self.port
424 url += b'/'
425 return url
426
427 c.port = cfg.getint(ss,'port')
428 c.saddrs = [ ]
429 for addrspec in cfg.get(ss, 'addrs').split():
430 sa = ServerAddr(c.port, addrspec)
431 c.saddrs.append(sa)
432
433 def cfg_process_vnetwork(c, ss):
434 c.vnetwork = ipnetwork(cfg.get(ss,'vnetwork'))
435 if c.vnetwork.num_addresses < 3 + 2:
436 raise ValueError('vnetwork needs at least 2^3 addresses')
437
438 def cfg_process_vaddr(c, ss):
439 try:
440 c.vaddr = cfg.get(ss,'vaddr')
441 except NoOptionError:
442 cfg_process_vnetwork(c, ss)
443 c.vaddr = next(c.vnetwork.hosts())
444
445 def cfg_search_section(key,sections):
446 for section in sections:
447 if cfg.has_option(section, key):
448 return section
449 raise NoOptionError(key, repr(sections))
450
451 def cfg_search(getter,key,sections):
452 section = cfg_search_section(key,sections)
453 return getter(section, key)
454
455 def cfg_process_client_limited(cc,ss,sections,key):
456 val = cfg_search(cfg.getint, key, sections)
457 lim = cfg_search(cfg.getint, key, ['%s LIMIT' % ss, 'LIMIT'])
458 cc.__dict__[key] = min(val,lim)
459
460 def cfg_process_client_common(cc,ss,cs,ci):
461 # returns sections to search in, iff password is defined, otherwise None
462 cc.ci = ci
463
464 sections = ['%s %s' % (ss,cs),
465 cs,
466 ss,
467 'DEFAULT']
468
469 try: pwsection = cfg_search_section('password', sections)
470 except NoOptionError: return None
471
472 pw = cfg.get(pwsection, 'password')
473 cc.password = pw.encode('utf-8')
474
475 cfg_process_client_limited(cc,ss,sections,'target_requests_outstanding')
476 cfg_process_client_limited(cc,ss,sections,'http_timeout')
477
478 return sections
479
480 def cfg_process_ipif(c, sections, varmap):
481 for d, s in varmap:
482 try: v = getattr(c, s)
483 except AttributeError: continue
484 setattr(c, d, v)
485
486 #print('CFGIPIF',repr((varmap, sections, c.__dict__)),file=sys.stderr)
487
488 section = cfg_search_section('ipif', sections)
489 c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
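# Illustrative expansion (the attribute names are whatever the caller put in
# varmap; the addresses are made up): with the default ipif template and
# local=172.24.230.193, peer=172.24.230.194, mtu=1500, rnets='', cfg.get()
# interpolates to roughly:
#   userv root ipif 172.24.230.193,172.24.230.194,1500,slip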
490
491 #---------- startup ----------
492
493 def common_startup(process_cfg):
494 # calls process_cfg(putative_clients, putative_servers)
495
496 # ConfigParser hates #-comments after values
497 trailingcomments_re = regexp.compile(r'#.*')
498 cfg.read_string(trailingcomments_re.sub('', defcfg))
499 need_defcfg = True
500
501 def readconfig(pathname, mandatory=True):
502 def log(m, p=pathname):
503 if not DBG.CONFIG in debug_set: return
504 print('DBG.CONFIG: %s: %s' % (m, p))
505
506 try:
507 files = os.listdir(pathname)
508
509 except FileNotFoundError:
510 if mandatory: raise
511 log('skipped')
512 return
513
514 except NotADirectoryError:
515 cfg.read(pathname)
516 log('read file')
517 return
518
519 # is a directory
520 log('directory')
521 re = regexp.compile('[^-A-Za-z0-9_]')
522 for f in files:
523 if re.search(f): continue
524 subpath = pathname + '/' + f
525 try:
526 os.stat(subpath)
527 except FileNotFoundError:
528 log('entry skipped', subpath)
529 continue
530 cfg.read(subpath)
531 log('entry read', subpath)
532
533 def oc_config(od,os, value, op):
534 nonlocal need_defcfg
535 need_defcfg = False
536 readconfig(value)
537
538 def dfs_less_detailed(dl):
539 return [df for df in DBG.iterconstants() if df <= dl]
540
541 def ds_default(od,os,dl,op):
542 global debug_set
543 debug_set = set(dfs_less_detailed(debug_def_detail))
544
545 def ds_select(od,os, spec, op):
546 for it in spec.split(','):
547
548 if it.startswith('-'):
549 mutator = debug_set.discard
550 it = it[1:]
551 else:
552 mutator = debug_set.add
553
554 if it == '+':
555 dfs = DBG.iterconstants()
556
557 else:
558 if it.endswith('+'):
559 mapper = dfs_less_detailed
560 it = it[0:len(it)-1]
561 else:
562 mapper = lambda x: [x]
563
564 try:
565 dfspec = DBG.lookupByName(it)
566 except ValueError:
567 optparser.error('unknown debug flag %s in --debug-select' % it)
568
569 dfs = mapper(dfspec)
570
571 for df in dfs:
572 mutator(df)
573
574 optparser.add_option('-D', '--debug',
575 nargs=0,
576 action='callback',
577 help='enable default debug (to stdout)',
578 callback= ds_default)
579
580 optparser.add_option('--debug-select',
581 nargs=1,
582 type='string',
583 metavar='[-]DFLAG[+]|[-]+,...',
584 help=
585 '''enable (`-': disable) each specified DFLAG;
586 `+': do same for all "more interesting" DFLAGs;
587 just `+': all DFLAGs.
588 DFLAGS: ''' + ' '.join([df.name for df in DBG.iterconstants()]),
589 action='callback',
590 callback= ds_select)
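# For example: --debug-select=QUEUE+,-TWISTED enables QUEUE and every flag
# declared before it (INIT .. TWISTED), then removes TWISTED again;
# --debug-select=+ enables every DFLAG.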
591
592 optparser.add_option('-c', '--config',
593 nargs=1,
594 type='string',
595 metavar='CONFIGFILE',
596 dest='configfile',
597 action='callback',
598 callback= oc_config)
599
600 (opts, args) = optparser.parse_args()
601 if len(args): optparser.error('no non-option arguments please')
602
603 if need_defcfg:
604 readconfig('/etc/hippotat/config', False)
605 readconfig('/etc/hippotat/config.d', False)
606
607 try:
608 (pss, pcs) = _cfg_process_putatives()
609 process_cfg(pss, pcs)
610 except (configparser.Error, ValueError):
611 traceback.print_exc(file=sys.stderr)
612 print('\nInvalid configuration, giving up.', file=sys.stderr)
613 sys.exit(12)
614
615 #print(repr(debug_set), file=sys.stderr)
616
617 log_formatter = twisted.logger.formatEventAsClassicLogText
618 stdout_obs = twisted.logger.FileLogObserver(sys.stdout, log_formatter)
619 stderr_obs = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
620 pred = twisted.logger.LogLevelFilterPredicate(LogLevel.error)
621 stdsomething_obs = twisted.logger.FilteringLogObserver(
622 stderr_obs, [pred], stdout_obs
623 )
624 log_observer = twisted.logger.FilteringLogObserver(
625 stdsomething_obs, [LogNotBoringTwisted()]
626 )
627 #log_observer = stdsomething_obs
628 twisted.logger.globalLogBeginner.beginLoggingTo(
629 [ log_observer, crash_on_critical ]
630 )
631
632 def common_run():
633 log_debug(DBG.INIT, 'entering reactor')
634 if not _crashing: reactor.run()
635 print('CRASHED (end)', file=sys.stderr)