d8b119a2b5c70d345605231dc6e2e2920dbfba8d
[hippotat] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 from zope.interface import implementer
9
10 import twisted
11 from twisted.internet import reactor
12 import twisted.internet.endpoints
13 import twisted.logger
14 from twisted.logger import LogLevel
15 import twisted.python.constants
16 from twisted.python.constants import NamedConstant
17
18 import ipaddress
19 from ipaddress import AddressValueError
20
21 from optparse import OptionParser
22 from configparser import ConfigParser
23 from configparser import NoOptionError
24
25 from functools import partial
26
27 import collections
28 import time
29 import codecs
30 import traceback
31
32 import re as regexp
33
34 import hippotat.slip as slip
35
class DBG(twisted.python.constants.Names):
    """Debug-message flags, selectable with -D / --debug-select.

    Definition order is significant: the option handling in common_startup()
    picks "less detailed" flags with ``df <= level`` comparisons on these
    constants.  So -D enables debug_def_detail and every flag defined before
    it, and a trailing `+' in --debug-select does the same for the named flag.
    Do not reorder these without considering that.
    """
    INIT = NamedConstant()
    ROUTE = NamedConstant()
    DROP = NamedConstant()
    FLOW = NamedConstant()
    HTTP = NamedConstant()
    TWISTED = NamedConstant()
    QUEUE = NamedConstant()
    HTTP_CTRL = NamedConstant()
    QUEUE_CTRL = NamedConstant()
    HTTP_FULL = NamedConstant()
    CTRL_DUMP = NamedConstant()
    SLIP_FULL = NamedConstant()
    DATA_COMPLETE = NamedConstant()
50
# Encoder used by log_debug() to render packet data as hex.
_hex_codec = codecs.getencoder('hex_codec')

#---------- logging ----------

# sys.stderr as it was at import time; LogNotBoringTwisted prints its own
# tracebacks here.
org_stderr = sys.stderr

log = twisted.logger.Logger()

# DBG flags currently enabled.  Rebound by the -D option and mutated by
# --debug-select (option handling lives in common_startup()); read by
# log_debug() and LogNotBoringTwisted.
debug_set = set()
# -D enables this flag and every DBG flag that compares <= it.
debug_def_detail = DBG.HTTP
61
def log_debug(dflag, msg, idof=None, d=None):
    """Emit an info-level debug message tagged with *dflag*, if enabled.

    Nothing is logged unless dflag is in debug_set.  If *idof* is given,
    the message is prefixed with its id() in hex.  If *d* (bytes-like) is
    given, its hex dump is appended; unless DBG.DATA_COMPLETE is also
    enabled, only the first 64 bytes are shown, followed by `...'.
    """
    if dflag not in debug_set:
        return
    if idof is not None:
        msg = '[%#x] %s' % (id(idof), msg)
    if d is not None:
        suffix = ''
        if DBG.DATA_COMPLETE not in debug_set and len(d) > 64:
            d, suffix = d[:64], '...'
        msg = msg + ' ' + _hex_codec(d)[0].decode('ascii') + suffix
    log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
76
@implementer(twisted.logger.ILogFilterPredicate)
class LogNotBoringTwisted:
    """Log filter predicate that hides uninteresting info-level events.

    Events at any level other than info always pass.  Info events carrying
    a `dflag' (as emitted by log_debug) pass only if that flag is enabled;
    info events with no dflag (presumably Twisted's own chatter) pass only
    if DBG.TWISTED is enabled.  If the predicate itself fails, the traceback
    goes to the original stderr and the event is let through.
    """

    def __call__(self, event):
        results = twisted.logger.PredicateResult
        try:
            if event.get('log_level') != LogLevel.info:
                return results.yes
            dflag = event.get('dflag')
            wanted = (dflag in debug_set
                      or (dflag is None and DBG.TWISTED in debug_set))
            return results.yes if wanted else results.no
        except Exception:
            # Never let a filtering bug swallow log output.
            print(traceback.format_exc(), file=org_stderr)
            return results.yes
92
93 #---------- default config ----------
94
# Built-in default configuration.  common_startup() removes every `#...'
# comment from this text with a regexp before handing it to
# ConfigParser.read_string(), then reads the user's config file on top of
# it, so settings from the config file override these defaults.
# NOTE(review): the comment inside lists %(rnet)s as an extra interpolation,
# but the ipif template itself uses %(rnets)s -- confirm which name is set.
# NOTE(review): ConfigParser does not strip quotes, so `routes' defaults to
# the two-character string '' rather than an empty value -- confirm intended.
defcfg = '''
[DEFAULT]
#[<client>] overrides
max_batch_down = 65536 # used by server, subject to [limits]
max_queue_time = 10 # used by server, subject to [limits]
target_requests_outstanding = 3 # must match; subject to [limits] on server
http_timeout = 30 # used by both } must be
http_timeout_grace = 5 # used by both } compatible
max_requests_outstanding = 4 # used by client
max_batch_up = 4000 # used by client
http_retry = 5 # used by client

#[server] or [<client>] overrides
ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
# extra interpolations: %(local)s %(peer)s %(rnet)s
# obtained on server [virtual]server [virtual]relay [virtual]network
# from on client <client> [virtual]server [virtual]routes

[virtual]
mtu = 1500
routes = ''
# network = <prefix>/<len> # mandatory for server
# server = <ipaddr> # used by both, default is computed from `network'
# relay = <ipaddr> # used by server, default from `network' and `server'
# default server is first host in network
# default relay is first host which is not server

[server]
# addrs = 127.0.0.1 ::1 # mandatory for server
port = 80 # used by server
# url # used by client; default from first `addrs' and `port'

# [<client-ip4-or-ipv6-address>]
# password = <password> # used by both, must match

[limits]
max_batch_down = 262144 # used by server
max_queue_time = 121 # used by server
http_timeout = 121 # used by server
target_requests_outstanding = 10 # used by server
'''
136
# These need to be defined here, at import time, so that they can be
# imported by `import *'.  common_startup() later registers options on
# optparser, parses the command line, and loads the configuration into cfg.
cfg = ConfigParser()
optparser = OptionParser()
140
# Byte table that exchanges b'-' and slip.esc (each maps to the other).
_mimetrans = bytes.maketrans(b'-' + slip.esc, slip.esc + b'-')


def mime_translate(s):
    """Swap every `-' byte with the SLIP ESC byte, and vice versa.

    SLIP-encoded data can never contain ESC ESC, so translating such data
    yields output that can never contain `--' (presumably so it cannot be
    mistaken for a MIME boundary).  Assuming slip.esc is a single byte, the
    mapping is its own inverse.
    """
    swapped = s.translate(_mimetrans)
    return swapped
146
class ConfigResults:
    """Plain attribute bag holding processed configuration values.

    Attributes live directly in the instance ``__dict__``.  If a mapping
    *d* is supplied it becomes that dict (so it is shared, not copied);
    otherwise each instance gets its own fresh, empty dict.
    """

    def __init__(self, d=None):
        # A `{ }' default would be one dict shared by every ConfigResults()
        # created without arguments; allocate a new one per instance.
        self.__dict__ = {} if d is None else d

    def __repr__(self):
        return 'ConfigResults(' + repr(self.__dict__) + ')'


# The program-wide configuration results, filled in by the process_cfg_*
# functions.
c = ConfigResults()
154
def log_discard(packet, iface, saddr, daddr, why):
    """Record (under DBG.DROP) that *packet* from *iface* was dropped, and why."""
    description = 'discarded packet [%s] %s -> %s: %s' % (
        iface, saddr, daddr, why)
    log_debug(DBG.DROP, description, d=packet)
159
160 #---------- packet parsing ----------
161
def packet_addrs(packet):
    """Return ``(saddr, daddr)`` read from the header of a raw IP packet.

    The IP version is taken from the top nibble of the first byte.  IPv4
    addresses sit at byte offsets 12 and 16, IPv6 addresses at 8 and 24.

    Raises:
      ValueError: the version is neither 4 nor 6 (a truncated header makes
        the ipaddress constructors raise AddressValueError, also a
        ValueError).
    """
    version = packet[0] >> 4
    if version == 4:
        factory, offset, width = ipaddress.IPv4Address, 12, 4
    elif version == 6:
        factory, offset, width = ipaddress.IPv6Address, 8, 16
    else:
        raise ValueError('unsupported IP version %d' % version)
    src = packet[offset:offset + width]
    dst = packet[offset + width:offset + 2 * width]
    return factory(src), factory(dst)
177
178 #---------- address handling ----------
179
def ipaddr(input):
    """Parse *input* as an IPv4 address, falling back to IPv6.

    Raises AddressValueError if it is neither.
    """
    try:
        return ipaddress.IPv4Address(input)
    except AddressValueError:
        return ipaddress.IPv6Address(input)
186
def ipnetwork(input):
    """Parse *input* as an IPv4 network, falling back to IPv6.

    Raises a ValueError subclass from the ipaddress module (for example
    AddressValueError or NetmaskValueError) if *input* is not a valid
    network of either family.
    """
    try:
        return ipaddress.IPv4Network(input)
    except AddressValueError:
        # IPv4Network reports a non-IPv4 address part as AddressValueError;
        # only then try IPv6.  Other errors on an IPv4 network, such as an
        # out-of-range prefix length (NetmaskValueError), propagate as is.
        return ipaddress.IPv6Network(input)
193
194 #---------- ipif (SLIP) subprocess ----------
195
class SlipStreamDecoder:
    """Reassemble packets from a SLIP-framed byte stream.

    Data arrives in arbitrary chunks via inputdata(); every complete,
    non-empty packet is handed to on_packet(packet).  The trailing,
    not-yet-delimited part is kept in a buffer until more data or flush().
    *desc* only labels the debug output.
    """

    def __init__(self, desc, on_packet):
        self._buffer = b''
        self._on_packet = on_packet
        self._desc = desc
        self._log('__init__')

    def _log(self, msg, **kwargs):
        text = 'slip %s: %s' % (self._desc, msg)
        log_debug(DBG.SLIP_FULL, text, **kwargs)

    def inputdata(self, data):
        """Feed *data* into the decoder, delivering any completed packets."""
        self._log('inputdata', d=data)
        # NOTE(review): each chunk is passed to slip.decode() on its own and
        # the undelimited tail is kept already decoded; if slip.decode()
        # undoes SLIP escapes, an escape sequence split across two reads may
        # be mishandled -- confirm against hippotat.slip.
        chunks = slip.decode(data)
        chunks[0] = self._buffer + chunks[0]
        # The last chunk has no delimiter after it yet: keep it for later.
        *complete, self._buffer = chunks
        for packet in complete:
            self._maybe_packet(packet)
        self._log('bufremain', d=self._buffer)

    def _maybe_packet(self, packet):
        """Deliver *packet* unless it is empty (e.g. between two delimiters)."""
        self._log('maybepacket', d=packet)
        if packet:
            self._on_packet(packet)

    def flush(self):
        """Deliver whatever is buffered as a final packet, then clear it."""
        self._log('flush')
        self._maybe_packet(self._buffer)
        self._buffer = b''
224
class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
    """Process protocol for the ipif helper started by start_ipif().

    Output read from the child is fed through a SlipStreamDecoder.  Each
    decoded packet is passed to router(packet, saddr, daddr) unless its
    source or destination is link-local, in which case it is logged as
    discarded.
    """

    def __init__(self, router):
        self._router = router
        self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)

    def connectionMade(self):
        # Nothing to do when the child starts.
        pass

    def outReceived(self, data):
        self._decoder.inputdata(data)

    def slip_on_packet(self, packet):
        saddr, daddr = packet_addrs(packet)
        if saddr.is_link_local or daddr.is_link_local:
            log_discard(packet, 'ipif', saddr, daddr, 'link-local')
        else:
            self._router(packet, saddr, daddr)

    def processEnded(self, status):
        # status is a Failure describing how the child ended; re-raising it
        # means the ipif process exiting is never silently ignored.
        status.raiseException()
240
def start_ipif(command, router):
    """Run *command* via ``/bin/sh -xc`` as the ipif subprocess.

    The child's fd 0 is a pipe we write to, its fd 1 a pipe we read from,
    and its fd 2 is our own fd 2.  Decoded packets go to *router*.  The
    protocol instance is stored in the module global `ipif'.
    """
    global ipif
    ipif = _IpifProcessProtocol(router)
    # command is interpreted by the shell; it is presumably the configured
    # ipif command line, i.e. trusted input -- confirm at call sites.
    argv = ['sh', '-xc', command]
    child_fds = {0: 'w', 1: 'r', 2: 2}
    reactor.spawnProcess(ipif, '/bin/sh', argv, childFDs=child_fds, env=None)
248
def queue_inbound(packet):
    """Write *packet* to the ipif subprocess, SLIP-encoded.

    A SLIP delimiter is written both before and after the encoded packet.
    """
    log_debug(DBG.FLOW, "queue_inbound", d=packet)
    sink = ipif.transport
    sink.write(slip.delimiter)
    sink.write(slip.encode(packet))
    sink.write(slip.delimiter)
254
255 #---------- packet queue ----------
256
class PacketQueue:
    """FIFO of packets awaiting transmission, with per-packet expiry.

    Each packet is stored with the time.monotonic() value at which it was
    appended.  nonempty() discards packets older than max_queue_time
    seconds from the front of the queue; process() does not expire packets.
    """

    def __init__(self, desc, max_queue_time):
        self._desc = desc
        # Sanity check: desc must be a str (else TypeError) and non-empty
        # (else AssertionError); it prefixes every debug message.
        assert desc + ''
        self._max_queue_time = max_queue_time
        self._pq = collections.deque()  # (queue time, packet) pairs

    def _log(self, dflag, msg, **kwargs):
        log_debug(dflag, self._desc + ' pq: ' + msg, **kwargs)

    def append(self, packet):
        """Queue *packet*, stamped with the current monotonic time."""
        self._log(DBG.QUEUE, 'append', d=packet)
        entry = (time.monotonic(), packet)
        self._pq.append(entry)

    def nonempty(self):
        """Drop expired packets from the front; return whether any remain."""
        self._log(DBG.QUEUE, 'nonempty ?')
        while self._pq:
            queuetime, packet = self._pq[0]
            if time.monotonic() - queuetime <= self._max_queue_time:
                self._log(DBG.QUEUE, 'nonempty ? nonempty.')
                return True
            self._log(DBG.QUEUE, 'dropping (old)', d=packet)
            self._pq.popleft()
        self._log(DBG.QUEUE, 'nonempty ? empty.')
        return False

    def process(self, sizequery, moredata, max_batch):
        """Move as many queued packets as fit into the caller's batch.

        sizequery() should return the size of the batch so far, and
        moredata(s) should append s to it.  Packets are SLIP-encoded and
        separated by slip.delimiter.  When the batch is still empty the next
        packet is always taken, even if it alone exceeds max_batch.
        """
        self._log(DBG.QUEUE, 'process...')
        while self._pq:
            _, packet = self._pq[0]
            self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)

            encoded = slip.encode(packet)
            sofar = sizequery()
            self._log(DBG.QUEUE_CTRL,
                      'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
                      d=encoded)

            if sofar > 0:
                needed = sofar + len(slip.delimiter) + len(encoded)
                if needed > max_batch:
                    self._log(DBG.QUEUE_CTRL, 'process... overflow')
                    break
                moredata(slip.delimiter)

            moredata(encoded)
            self._pq.popleft()
        else:
            # Reached only when the queue ran dry, not on an overflow break.
            self._log(DBG.QUEUE, 'process... empty')
316
317 #---------- error handling ----------
318
# Set once crash() has run; common_run() will not start the reactor after that.
_crashing = False


def crash(err):
    """Report *err* as a crash on stderr and stop the reactor if it is running."""
    global _crashing
    _crashing = True
    banner_open = '========== CRASH =========='
    banner_close = '==========================='
    print(banner_open, err, banner_close, file=sys.stderr)
    try:
        reactor.stop()
    except twisted.internet.error.ReactorNotRunning:
        # Crashing before (or after) the reactor runs is fine.
        pass
328
def crash_on_defer(defer):
    """Arrange for crash() to be called with the Failure if *defer* fails."""
    defer.addErrback(crash)
331
def crash_on_critical(event):
    """Log observer: crash() on any event logged at critical level or above.

    Events without a `log_level' are ignored.
    """
    level = event.get('log_level')
    # Comparing None with a LogLevel constant raises TypeError, which would
    # make this observer fail on such events; skip them instead.
    if level is not None and level >= LogLevel.critical:
        crash(twisted.logger.formatEvent(event))
335
336 #---------- config processing ----------
337
def process_cfg_common_always():
    """Copy settings every program needs from cfg into c.

    Currently just c.mtu, taken as a string from [virtual] mtu.
    """
    # The value is stored on the shared results object `c'; no module-level
    # global is involved.
    c.mtu = cfg.get('virtual', 'mtu')
341
def process_cfg_ipif(section, varmap):
    """Set c.ipif_command from the `ipif' option of *section*.

    *varmap* is a sequence of (dest, src) attribute-name pairs.  For each
    pair where c already has attribute src, c.<dest> is set to that value.
    All of c's attributes are then supplied to cfg.get() as `vars', so they
    are available as %(name)s interpolations in the ipif command line.
    """
    for dest, src in varmap:
        if not hasattr(c, src):
            continue
        setattr(c, dest, getattr(c, src))

    c.ipif_command = cfg.get(section, 'ipif', vars=c.__dict__)
351
def process_cfg_network():
    """Parse [virtual] network into c.network and check it is large enough.

    Raises ValueError if the network has fewer than 5 addresses.
    """
    network = ipnetwork(cfg.get('virtual', 'network'))
    c.network = network
    # Presumably: network and broadcast addresses plus server, relay and at
    # least one client.  Network sizes are powers of two, so "< 5" rejects
    # exactly the networks smaller than 2^3 addresses, matching the message.
    if network.num_addresses < 3 + 2:
        raise ValueError('network needs at least 2^3 addresses')
356
def process_cfg_server():
    """Set c.server from [virtual] server, else the network's first host.

    In the fallback case c.network is also set (via process_cfg_network()).
    """
    # NOTE(review): c.server is a str when configured explicitly but an
    # ipaddress object when computed from the network -- confirm callers
    # handle both.
    try:
        server = cfg.get('virtual', 'server')
    except NoOptionError:
        process_cfg_network()
        server = next(c.network.hosts())
    c.server = server
363
class ServerAddr():
    """One address the server listens on, parsed from [server] addrs.

    Attributes:
      port -- TCP port number
      addr -- ipaddress.IPv4Address or IPv6Address parsed from addrspec
    """

    def __init__(self, port, addrspec):
        self.port = port
        try:
            self.addr = ipaddress.IPv4Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
            self._inurl = b'%s'
        except AddressValueError:
            self.addr = ipaddress.IPv6Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
            # IPv6 literals must be bracketed inside a URL.
            self._inurl = b'[%s]'

    def make_endpoint(self):
        """Return a Twisted server endpoint listening on addr:port."""
        # The third positional parameter of TCP4/TCP6ServerEndpoint is
        # `backlog', not the interface.  Pass the address by keyword (as a
        # string) so we bind to it instead of to every interface.
        return self._endpointfactory(reactor, self.port,
                                     interface=str(self.addr))

    def url(self):
        """Return this address as an http:// URL (bytes).

        IPv6 addresses are bracketed; the port is omitted when it is 80.
        """
        url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
        if self.port != 80:
            url += b':%d' % self.port
        url += b'/'
        return url
383
def process_cfg_saddrs():
    """Build c.saddrs: one ServerAddr per entry in [server] addrs.

    All of them use [server] port (80 if unset).
    """
    try:
        port = cfg.getint('server', 'port')
    except NoOptionError:
        port = 80

    c.saddrs = addrs = []
    for spec in cfg.get('server', 'addrs').split():
        addrs.append(ServerAddr(port, spec))
392
def process_cfg_clients(constructor):
    """Reset c.clients and call constructor() for each client section.

    A section counts as a client if its name contains `:' or `.', i.e.
    looks like an IP address.  constructor is called as
    constructor(client_ip, section_name, password_bytes), the password
    being the section's `password' option encoded as UTF-8.
    (Presumably constructor appends to c.clients -- confirm at call sites.)
    """
    c.clients = []
    for section in cfg.sections():
        if ':' not in section and '.' not in section:
            continue
        client_ip = ipaddr(section)
        password = cfg.get(section, 'password').encode('utf-8')
        constructor(client_ip, section, password)
401
402 #---------- startup ----------
403
def _dfs_less_detailed(dl):
    """Return every DBG flag that compares <= dl (defined no later than it)."""
    return [df for df in DBG.iterconstants() if df <= dl]


def _ds_default(od, os, dl, op):
    """optparse callback for -D: replace debug_set with the default flags."""
    global debug_set
    debug_set = set(_dfs_less_detailed(debug_def_detail))


def _ds_select(od, os, spec, op):
    """optparse callback for --debug-select: apply a comma-separated spec.

    Each item is [-]DFLAG[+] or [-]+ : a leading `-' disables instead of
    enabling, a trailing `+' also covers all less detailed flags, and a
    bare `+' means every flag.
    """
    for it in spec.split(','):
        if it.startswith('-'):
            mutator = debug_set.discard
            it = it[1:]
        else:
            mutator = debug_set.add

        if it == '+':
            dfs = DBG.iterconstants()
        else:
            expand = it.endswith('+')
            if expand:
                it = it[:-1]
            try:
                dfspec = DBG.lookupByName(it)
            except ValueError:
                # optparser.error() exits, so dfspec is always bound below.
                optparser.error('unknown debug flag %s in --debug-select' % it)
            dfs = _dfs_less_detailed(dfspec) if expand else [dfspec]

        for df in dfs:
            mutator(df)


def _startup_options():
    """Register the common command-line options, parse argv, return opts."""
    optparser.add_option('-c', '--config', dest='configfile',
                         default='/etc/hippotat/config')

    optparser.add_option('-D', '--debug',
                         nargs=0,
                         action='callback',
                         help='enable default debug (to stdout)',
                         callback=_ds_default)

    select_help = (
        "enable (`-': disable) each specified DFLAG; "
        "`+': do same for all \"more interesting\" DFLAGs; "
        "just `+': all DFLAGs. "
        "DFLAGS: " + ' '.join([df.name for df in DBG.iterconstants()]))
    optparser.add_option('--debug-select',
                         nargs=1,
                         type='string',
                         metavar='[-]DFLAG[+]|[-]+,...',
                         help=select_help,
                         action='callback',
                         callback=_ds_select)

    (opts, args) = optparser.parse_args()
    if len(args):
        optparser.error('no non-option arguments please')
    return opts


def _startup_config(configfile):
    """Load the built-in defaults, then *configfile* on top of them, into cfg."""
    # Strip `#...' comments, including trailing ones, from the defaults.
    comment_re = regexp.compile('#.*')
    cfg.read_string(comment_re.sub('', defcfg))
    # ConfigParser.read() silently skips files it cannot open; report that
    # now rather than failing later on some missing mandatory option.
    if not cfg.read(configfile):
        optparser.error('could not read config file %s' % configfile)


def _startup_logging():
    """Start twisted.logger: error-or-worse to stderr, the rest to stdout.

    Uninteresting info events are filtered out first (LogNotBoringTwisted),
    and every event is also passed to crash_on_critical.
    """
    log_formatter = twisted.logger.formatEventAsClassicLogText
    stdout_obs = twisted.logger.FileLogObserver(sys.stdout, log_formatter)
    stderr_obs = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
    pred = twisted.logger.LogLevelFilterPredicate(LogLevel.error)
    # Events passing `pred' (error and above) go to stderr_obs, the rest
    # to stdout_obs.
    stdsomething_obs = twisted.logger.FilteringLogObserver(
        stderr_obs, [pred], stdout_obs
    )
    log_observer = twisted.logger.FilteringLogObserver(
        stdsomething_obs, [LogNotBoringTwisted()]
    )
    twisted.logger.globalLogBeginner.beginLoggingTo(
        [log_observer, crash_on_critical]
    )


def common_startup():
    """Parse the command line, load configuration into cfg, start logging."""
    opts = _startup_options()
    _startup_config(opts.configfile)
    _startup_logging()
485
def common_run():
    """Run the reactor, unless crash() already ran, then report the end.

    Within this file only crash() stops the reactor, so reaching the final
    message is reported as a crash.
    """
    # NOTE(review): the process then returns normally (exit status 0 if the
    # caller does nothing more) -- consider whether a non-zero exit is wanted.
    log_debug(DBG.INIT, 'entering reactor')
    crashed_during_startup = _crashing
    if not crashed_during_startup:
        reactor.run()
    print('CRASHED (end)', file=sys.stderr)