filter out twisted stuff
[hippotat] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 from zope.interface import implementer
9
10 import twisted
11 from twisted.internet import reactor
12 import twisted.internet.endpoints
13 import twisted.logger
14 from twisted.logger import LogLevel
15 import twisted.python.constants
16 from twisted.python.constants import NamedConstant
17
18 import ipaddress
19 from ipaddress import AddressValueError
20
21 from optparse import OptionParser
22 from configparser import ConfigParser
23 from configparser import NoOptionError
24
25 from functools import partial
26
27 import collections
28 import time
29 import codecs
30 import traceback
31
32 import re as regexp
33
34 import hippotat.slip as slip
35
class DBG(twisted.python.constants.Names):
    """Debug-message categories understood by log_debug().

    Declaration order matters: NamedConstant values compare according to
    the order in which they are defined, and the module-level debug_set is
    initialised with every flag that sorts at or before HTTP.
    """

    # Initially enabled (ordered at or before HTTP):
    INIT = NamedConstant()
    ROUTE = NamedConstant()
    DROP = NamedConstant()
    FLOW = NamedConstant()
    HTTP = NamedConstant()

    # Initially disabled.  TWISTED is the category assumed for info-level
    # log events that were not produced by log_debug().
    TWISTED = NamedConstant()
    QUEUE = NamedConstant()
    HTTP_CTRL = NamedConstant()
    QUEUE_CTRL = NamedConstant()
    HTTP_FULL = NamedConstant()
    CTRL_DUMP = NamedConstant()
    SLIP_FULL = NamedConstant()
49
# Bytes -> hex encoder used by log_debug() for packet dumps.
_hex_codec = codecs.getencoder('hex_codec')

#---------- logging ----------

# sys.stderr as it was at import time, i.e. before common_startup() hands
# the standard streams to Twisted's logging.  Presumably kept so that the
# log filter can report its own failures without re-entering the logging
# system.
org_stderr = sys.stderr

log = twisted.logger.Logger()

# Categories currently emitted by log_debug(); starts as every DBG flag
# declared up to and including DBG.HTTP.
debug_set = {flag for flag in DBG.iterconstants() if flag <= DBG.HTTP}
59
def log_debug(dflag, msg, idof=None, d=None):
    """Log *msg* at info level, provided category *dflag* is in debug_set.

    dflag -- a DBG constant; it is attached to the event as 'dflag' so
             that LogNotBoringTwisted can filter on it
    idof  -- optional object whose id() is prefixed as "[0x...]"
    d     -- optional bytes, appended to the message as a hex dump
    """
    if dflag not in debug_set:
        return

    text = msg if idof is None else '[%#x] %s' % (id(idof), msg)
    if d is not None:
        text += ' ' + _hex_codec(d)[0].decode('ascii')

    # The message goes in as a field value, never as the format string,
    # so braces inside it cannot confuse Twisted's formatter.
    log.info('{dflag} {msgcore}', dflag=dflag, msgcore=text)
70
@implementer(twisted.logger.ILogFilterPredicate)
class LogNotBoringTwisted:
    """Log filter predicate that suppresses unwanted info-level messages.

    Events at any level other than info always pass.  An info event passes
    only if its debug category is in debug_set.  Events without a 'dflag'
    key (anything not emitted by log_debug(), e.g. Twisted's own messages)
    are treated as category DBG.TWISTED.
    """

    def __call__(self, event):
        yes = twisted.logger.PredicateResult.yes
        no = twisted.logger.PredicateResult.no
        try:
            if event.get('log_level') != LogLevel.info:
                return yes
            # dict.get() never raises KeyError, so the DBG.TWISTED fallback
            # must be given as the default; otherwise such events get None
            # and are dropped even when DBG.TWISTED is enabled.
            dflag = event.get('dflag', DBG.TWISTED)
            return yes if dflag in debug_set else no
        except Exception:
            # A broken filter must not lose events: report the problem on
            # the stderr captured at import time and let the event through.
            print(traceback.format_exc(), file=org_stderr)
            return yes
87
88 #---------- default config ----------
89
# Built-in configuration defaults, read before the user's config file.
# common_startup() deletes everything from '#' to end of line before
# parsing this text, so the '#' notes inside it are documentation only.
# NOTE(review): ConfigParser does not strip quotes, so `routes = ''` gives
# the two-character string "''" rather than an empty value -- confirm the
# consumer of [virtual] routes expects that.
defcfg = '''
[DEFAULT]
#[<client>] overrides
max_batch_down = 65536 # used by server, subject to [limits]
max_queue_time = 10 # used by server, subject to [limits]
target_requests_outstanding = 3 # must match; subject to [limits] on server
http_timeout = 30 # used by both } must be
http_timeout_grace = 5 # used by both } compatible
max_requests_outstanding = 4 # used by client
max_batch_up = 4000 # used by client
http_retry = 5 # used by client

#[server] or [<client>] overrides
ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
# extra interpolations:  %(local)s        %(peer)s          %(rnets)s
#  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
#      from   on client  <client>         [virtual]server   [virtual]routes

[virtual]
mtu = 1500
routes = ''
# network = <prefix>/<len> # mandatory for server
# server = <ipaddr> # used by both, default is computed from `network'
# relay = <ipaddr> # used by server, default from `network' and `server'
# default server is first host in network
# default relay is first host which is not server

[server]
# addrs = 127.0.0.1 ::1 # mandatory for server
port = 80 # used by server
# url # used by client; default from first `addrs' and `port'

# [<client-ip4-or-ipv6-address>]
# password = <password> # used by both, must match

[limits]
max_batch_down = 262144 # used by server
max_queue_time = 121 # used by server
http_timeout = 121 # used by server
target_requests_outstanding = 10 # used by server
'''
131
# Module-level singletons, created at import time (and filled in later by
# common_startup()) so that other modules picking them up via `import *`
# all share the same objects.
optparser = OptionParser()
cfg = ConfigParser()
135
# Byte translation table that exchanges b'-' and the SLIP ESC byte
# (assumes slip.esc is a single byte -- TODO confirm in hippotat.slip).
_mimetrans = bytes.maketrans(b'-' + slip.esc, slip.esc + b'-')

def mime_translate(s):
    """Return *s* with every b'-' and SLIP ESC byte swapped for each other.

    SLIP-encoded data never contains ESC ESC, so the translated form never
    contains b'--'.  The swap is its own inverse, so the same call undoes it.
    """
    swapped = s.translate(_mimetrans)
    return swapped
141
class ConfigResults:
    """Attribute bag for processed configuration values.

    If a dict *d* is supplied it becomes the instance's __dict__ itself
    (not a copy), so the dict and the object's attributes stay in sync.
    Without *d*, each instance gets its own fresh, empty namespace.
    """

    def __init__(self, d=None):
        # A `{}` default would be one dict shared by every instance created
        # without an argument, silently aliasing their attributes.
        self.__dict__ = {} if d is None else d

    def __repr__(self):
        return 'ConfigResults(' + repr(self.__dict__) + ')'

# The global configuration results, populated by the process_cfg_* functions.
c = ConfigResults()
149
def log_discard(packet, iface, saddr, daddr, why):
    """Log, under DBG.DROP, that *packet* was dropped (with a hex dump).

    iface        -- label of where the packet came from (e.g. 'ipif')
    saddr, daddr -- the packet's source and destination addresses
    why          -- short reason for the drop
    """
    summary = 'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why)
    log_debug(DBG.DROP, summary, d=packet)
154
155 #---------- packet parsing ----------
156
def packet_addrs(packet):
    """Return (source, destination) addresses of a raw IPv4 or IPv6 packet.

    The IP version is taken from the top four bits of the first byte.
    *packet* may be bytes, bytearray or memoryview.  The result is a pair
    of ipaddress.IPv4Address or ipaddress.IPv6Address objects.

    Raises ValueError if the packet is empty, is neither IPv4 nor IPv6, or
    is too short to contain both addresses.
    """
    if not packet:
        raise ValueError('empty packet')
    version = packet[0] >> 4
    if version == 4:
        addrlen = 4
        saddroff = 3*4          # IPv4 header: source address at byte 12
        factory = ipaddress.IPv4Address
    elif version == 6:
        addrlen = 16
        saddroff = 2*4          # IPv6 header: source address at byte 8
        factory = ipaddress.IPv6Address
    else:
        raise ValueError('unsupported IP version %d' % version)

    daddroff = saddroff + addrlen   # destination immediately follows source
    end = daddroff + addrlen
    if len(packet) < end:
        raise ValueError('truncated IPv%d packet (%d bytes, need %d)'
                         % (version, len(packet), end))

    # The ipaddress constructors accept only `bytes` for packed addresses.
    saddr = factory(bytes(packet[saddroff:daddroff]))
    daddr = factory(bytes(packet[daddroff:end]))
    return (saddr, daddr)
172
173 #---------- address handling ----------
174
def ipaddr(input):
    """Parse *input* as an IP address, trying IPv4 before IPv6.

    Returns an ipaddress.IPv4Address or ipaddress.IPv6Address.  If neither
    parse succeeds, the IPv6 parser's AddressValueError propagates.
    """
    try:
        return ipaddress.IPv4Address(input)
    except AddressValueError:
        return ipaddress.IPv6Address(input)
181
def ipnetwork(input):
    """Parse *input* as an IP network, trying IPv4 before IPv6.

    Returns an ipaddress.IPv4Network or ipaddress.IPv6Network (strict
    parsing: host bits must be zero).  Like ipaddr(), only
    AddressValueError from the IPv4 attempt triggers the IPv6 attempt.
    Other IPv4 errors (bad netmask, host bits set) propagate unchanged.
    """
    try:
        return ipaddress.IPv4Network(input)
    except AddressValueError:
        # ipaddress has no `NetworkValueError`; naming it here raised
        # NameError for every IPv6 network.
        return ipaddress.IPv6Network(input)
188
189 #---------- ipif (SLIP) subprocess ----------
190
class SlipStreamDecoder():
    """Reassemble packets from a SLIP byte stream delivered in chunks.

    Chunks are fed to inputdata().  Every complete packet is handed to
    on_packet(packet).  The last, possibly incomplete, piece of each chunk
    is buffered and prepended to the first piece of the next chunk.  Empty
    pieces are never delivered.

    desc -- label used in DBG.SLIP_FULL debug messages
    """

    def __init__(self, desc, on_packet):
        self._desc = desc
        self._on_packet = on_packet
        self._buffer = b''      # undelivered tail of the stream so far
        self._log('__init__')

    def _log(self, msg, **kwargs):
        log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)

    def inputdata(self, data):
        """Consume one chunk of stream bytes, delivering finished packets."""
        self._log('inputdata', d=data)
        # Assumes slip.decode() returns a mutable list of the pieces between
        # delimiters, the final element being whatever follows the last
        # delimiter -- TODO confirm in hippotat.slip.
        # NOTE(review): pieces are decoded before being joined to the
        # buffered tail; if a chunk boundary can fall inside an escape
        # sequence, this join may mis-decode -- confirm.
        pieces = slip.decode(data)
        pieces[0] = self._buffer + pieces[0]
        *complete, self._buffer = pieces
        for packet in complete:
            self._maybe_packet(packet)
        self._log('bufremain', d=self._buffer)

    def _maybe_packet(self, packet):
        self._log('maybepacket', d=packet)
        if len(packet) == 0:
            return
        self._on_packet(packet)

    def flush(self):
        """Deliver any buffered tail as a final packet, then clear it."""
        self._log('flush')
        self._maybe_packet(self._buffer)
        self._buffer = b''
219
class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
    """Process protocol for the ipif child: decodes its SLIP output.

    Packets read from the child's stdout are passed on as
    router(packet, saddr, daddr).  Packets whose source or destination is
    link-local are discarded (and logged) instead.
    """

    def __init__(self, router):
        self._router = router
        self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)

    def connectionMade(self):
        pass

    def outReceived(self, data):
        self._decoder.inputdata(data)

    def slip_on_packet(self, packet):
        saddr, daddr = packet_addrs(packet)
        if not (saddr.is_link_local or daddr.is_link_local):
            self._router(packet, saddr, daddr)
            return
        log_discard(packet, 'ipif', saddr, daddr, 'link-local')

    def processEnded(self, status):
        # The ipif child ending for any reason is fatal: re-raise the
        # Failure Twisted reports.
        status.raiseException()
235
def start_ipif(command, router):
    """Spawn the ipif helper as `/bin/sh -xc command`.

    The child's stdin is a pipe we write to (see queue_inbound), its stdout
    is read and decoded by _IpifProcessProtocol (which calls *router*), and
    its stderr is our fd 2.  The protocol instance is stored in the module
    global `ipif`.
    """
    global ipif
    ipif = _IpifProcessProtocol(router)
    argv = ['sh', '-xc', command]
    fd_map = {0: 'w', 1: 'r', 2: 2}
    reactor.spawnProcess(ipif, '/bin/sh', argv, childFDs=fd_map, env=None)
243
def queue_inbound(packet):
    """Send *packet* to the ipif child's stdin as a delimited SLIP frame.

    Requires start_ipif() to have been called (it sets the global `ipif`).
    """
    log_debug(DBG.FLOW, "queue_inbound", d=packet)
    write = ipif.transport.write
    write(slip.delimiter)
    write(slip.encode(packet))
    write(slip.delimiter)
249
250 #---------- packet queue ----------
251
class PacketQueue():
    """FIFO of packets awaiting transmission, with an age limit.

    Each packet is stored together with the time.monotonic() value at which
    it was queued.  nonempty() discards head packets older than
    max_queue_time seconds.  process() moves packets into a batch that the
    caller is building.
    """

    def __init__(self, desc, max_queue_time):
        self._desc = desc
        # `desc + ''` raises TypeError unless desc is a str; the assert
        # also rejects ''.  (Skipped entirely under python -O.)
        assert(desc + '')
        self._max_queue_time = max_queue_time
        self._pq = collections.deque()  # of (queue time, packet)

    def _log(self, dflag, msg, **kwargs):
        log_debug(dflag, ''.join((self._desc, ' pq: ', msg)), **kwargs)

    def append(self, packet):
        """Queue *packet*, stamped with the current monotonic time."""
        self._log(DBG.QUEUE, 'append', d=packet)
        entry = (time.monotonic(), packet)
        self._pq.append(entry)

    def nonempty(self):
        """Drop expired packets from the head; return True if any remain."""
        self._log(DBG.QUEUE, 'nonempty ?')
        while self._pq:
            queuetime, packet = self._pq[0]
            age = time.monotonic() - queuetime
            if age > self._max_queue_time:
                # Too old: discard the head and examine the next packet.
                self._log(DBG.QUEUE, 'dropping (old)', d=packet)
                self._pq.popleft()
                continue
            self._log(DBG.QUEUE, 'nonempty ? nonempty.')
            return True
        self._log(DBG.QUEUE, 'nonempty ? empty.')
        return False

    def process(self, sizequery, moredata, max_batch):
        """Move queued packets into the caller's batch until full or empty.

        sizequery()  -- returns the size of the batch built so far
        moredata(s)  -- appends bytes s to the batch
        max_batch    -- size limit.  It is only checked once the batch is
                        non-empty, so an oversized packet is still accepted
                        into an empty batch.
        Packets are SLIP-encoded and separated by slip.delimiter.  A packet
        leaves the queue only after being handed to moredata().  Unlike
        nonempty(), this does not discard old packets.
        """
        self._log(DBG.QUEUE, 'process...')
        while self._pq:
            packet = self._pq[0][1]
            self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)

            encoded = slip.encode(packet)
            sofar = sizequery()
            self._log(DBG.QUEUE_CTRL,
                      'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
                      d=encoded)

            if sofar > 0:
                needed = sofar + len(slip.delimiter) + len(encoded)
                if needed > max_batch:
                    self._log(DBG.QUEUE_CTRL, 'process... overflow')
                    break
                moredata(slip.delimiter)

            moredata(encoded)
            self._pq.popleft()
        else:
            # Reached only when the queue ran dry (not on overflow).
            self._log(DBG.QUEUE, 'process... empty')
311
312 #---------- error handling ----------
313
# Becomes True once crash() has been called; common_run() consults it.
_crashing = False

def crash(err):
    """Report a fatal error on stderr and stop the reactor.

    Safe to call whether or not the reactor is running.
    """
    global _crashing
    _crashing = True
    print('========== CRASH ==========', err,
          '===========================', file=sys.stderr)
    try:
        reactor.stop()
    except twisted.internet.error.ReactorNotRunning:
        # e.g. crash() called before the reactor was started
        pass
323
def crash_on_defer(defer):
    """Make any failure of the Deferred *defer* call crash().

    The errback returns None, so after crash() has reported the failure
    the Deferred continues with a None result.
    """
    def _errback(failure):
        return crash(failure)
    defer.addErrback(_errback)
326
def crash_on_critical(event):
    """Log observer: call crash() for any event at critical level or above.

    Events without a 'log_level' key are ignored.  Comparing None with a
    LogLevel constant raises TypeError, which would make this observer
    itself fail.
    """
    level = event.get('log_level')
    if level is not None and level >= LogLevel.critical:
        crash(twisted.logger.formatEvent(event))
330
331 #---------- config processing ----------
332
def process_cfg_common_always():
    """Copy [virtual] mtu into c.mtu.

    The value is kept as the string ConfigParser returns (presumably it is
    consumed via the %(mtu)s interpolation in the ipif command).  It is
    stored only on c; no module-level `mtu` is set.
    """
    c.mtu = cfg.get('virtual', 'mtu')
336
def process_cfg_ipif(section, varmap):
    """Expand the `ipif` option of *section* into c.ipif_command.

    varmap is an iterable of (dest, src) attribute-name pairs.  For each
    pair where c has attribute src, that value is also stored as c.<dest>
    (pairs with a missing src are skipped).  All of c's attributes are then
    supplied to ConfigParser as extra interpolation variables.
    """
    absent = object()
    for dest, src in varmap:
        value = getattr(c, src, absent)
        if value is absent:
            continue
        setattr(c, dest, value)

    c.ipif_command = cfg.get(section, 'ipif', vars=c.__dict__)
346
def process_cfg_network():
    """Parse [virtual] network into c.network, rejecting networks too small."""
    c.network = ipnetwork(cfg.get('virtual', 'network'))
    # Room for 3 hosts plus 2 more addresses (presumably server, relay and
    # at least one client, plus network and broadcast).  num_addresses is
    # always a power of two, so this amounts to "at least 2^3".
    min_addresses = 3 + 2
    if c.network.num_addresses < min_addresses:
        raise ValueError('network needs at least 2^3 addresses')
351
def process_cfg_server():
    """Set c.server, the server's virtual address, as an ipaddress object.

    Taken from [virtual] server when present.  Otherwise it is the first
    host of [virtual] network, which is then also parsed into c.network.
    Raises AddressValueError if a configured server is not a valid address.
    """
    try:
        # Parse the configured value so that c.server has the same type
        # (IPv4Address/IPv6Address) as in the fallback path below; a bare
        # string would never compare equal to a parsed address.
        c.server = ipaddr(cfg.get('virtual', 'server'))
    except NoOptionError:
        process_cfg_network()
        c.server = next(c.network.hosts())
358
class ServerAddr():
    """One address and port on which the server listens.

    Attributes:
      port -- TCP port, as passed in
      addr -- ipaddress.IPv4Address or IPv6Address parsed from addrspec
              (IPv4 is tried first)
    """

    def __init__(self, port, addrspec):
        self.port = port
        try:
            self.addr = ipaddress.IPv4Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
            self._inurl = b'%s'
        except AddressValueError:
            self.addr = ipaddress.IPv6Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
            self._inurl = b'[%s]'   # IPv6 literals are bracketed in URLs

    def make_endpoint(self):
        """Return a Twisted server endpoint bound to self.addr, self.port."""
        # The third positional parameter of TCP4ServerEndpoint and
        # TCP6ServerEndpoint is `backlog`, so the address must be given as
        # the `interface` keyword, as a string.
        return self._endpointfactory(reactor, self.port,
                                     interface=str(self.addr))

    def url(self):
        """Return the http:// URL (bytes) for this address; ':80' is omitted."""
        url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
        if self.port != 80:
            url += b':%d' % self.port
        url += b'/'
        return url
378
def process_cfg_saddrs():
    """Build c.saddrs: one ServerAddr per whitespace-separated [server] addrs entry.

    All entries share [server] port, or 80 if that option is absent.
    """
    try:
        port = cfg.getint('server', 'port')
    except NoOptionError:
        port = 80

    c.saddrs = saddrs = []
    for spec in cfg.get('server', 'addrs').split():
        saddrs.append(ServerAddr(port, spec))
387
def process_cfg_clients(constructor):
    """Call constructor(addr, section, password) for every client section.

    A client section is any section whose name contains ':' or '.', i.e.
    looks like an IPv6 or IPv4 address.  addr is that name parsed with
    ipaddr(), and password is the section's `password` encoded as UTF-8.
    c.clients is reset to an empty list first.  Nothing here appends to
    it; presumably *constructor* does -- confirm with callers.
    """
    c.clients = [ ]
    for section in cfg.sections():
        if ':' not in section and '.' not in section:
            continue
        addr = ipaddr(section)
        password = cfg.get(section, 'password').encode('utf-8')
        constructor(addr, section, password)
396
397 #---------- startup ----------
398
def _begin_logging():
    """Start Twisted logging: error and above to stderr, the rest to stdout.

    Events first pass through LogNotBoringTwisted.  Every event is also
    delivered to crash_on_critical.
    """
    fmt = twisted.logger.formatEventAsClassicLogText
    to_stdout = twisted.logger.FileLogObserver(sys.stdout, fmt)
    to_stderr = twisted.logger.FileLogObserver(sys.stderr, fmt)
    # Events the level predicate rejects (below error) go to the third,
    # "negative" observer, i.e. stdout.
    by_severity = twisted.logger.FilteringLogObserver(
        to_stderr,
        [twisted.logger.LogLevelFilterPredicate(LogLevel.error)],
        to_stdout,
    )
    twisted.logger.globalLogBeginner.beginLoggingTo([
        twisted.logger.FilteringLogObserver(by_severity, [LogNotBoringTwisted()]),
        crash_on_critical,
    ])


def _load_config():
    """Parse the command line, then load defcfg followed by the config file."""
    optparser.add_option('-c', '--config', dest='configfile',
                         default='/etc/hippotat/config')
    opts, args = optparser.parse_args()
    if args:
        optparser.error('no non-option arguments please')

    # defcfg's '#' notes, including trailing ones, are removed before parsing.
    comment = regexp.compile('#.*')
    cfg.read_string(comment.sub('', defcfg))
    # NOTE(review): ConfigParser.read() silently skips files it cannot open,
    # so a missing or mistyped --config path leaves only the defaults in
    # effect without any error -- confirm that is intended.
    cfg.read(opts.configfile)


def common_startup():
    """Start-up steps shared by client and server: logging, options, config."""
    _begin_logging()
    _load_config()
423
def common_run():
    """Run the reactor, unless crash() has already been called.

    "CRASHED (end)" is reported only if crash() was called, either before
    the reactor started or while it ran.  A reactor stop that did not go
    through crash() (for example one triggered by a signal Twisted handles)
    is not misreported as a crash.
    """
    log_debug(DBG.INIT, 'entering reactor')
    if not _crashing:
        reactor.run()
    if _crashing:
        print('CRASHED (end)', file=sys.stderr)