b6a84d9546903868edf9d6830f4b8df2de92b08c
[hippotat] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 import twisted.internet.endpoints
11 import twisted.logger
12 from twisted.logger import LogLevel
13 import twisted.python.constants
14 from twisted.python.constants import NamedConstant
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 from functools import partial
24
25 import collections
26 import time
27 import codecs
28 import traceback
29
30 import re as regexp
31
32 import hippotat.slip as slip
33
class DBG(twisted.python.constants.Names):
    """Debug message categories understood by log_debug().

    The order of declaration matters: the default `debug_set' is chosen
    by comparing each constant against DBG.HTTP, so do not rearrange
    these without checking that comparison.
    """

    # Categories up to and including HTTP are in the default debug_set.
    INIT = NamedConstant()
    ROUTE = NamedConstant()
    DROP = NamedConstant()
    FLOW = NamedConstant()
    HTTP = NamedConstant()

    # Categories declared after HTTP.
    TWISTED = NamedConstant()
    QUEUE = NamedConstant()
    HTTP_CTRL = NamedConstant()
    QUEUE_CTRL = NamedConstant()
    HTTP_FULL = NamedConstant()
    CTRL_DUMP = NamedConstant()
    SLIP_FULL = NamedConstant()
47
# Encoder used by log_debug() to render packet bytes as hexadecimal.
_hex_codec = codecs.getencoder('hex_codec')

# Logger through which log_debug() emits its messages.
log = twisted.logger.Logger()

# Debug categories enabled by default: every DBG constant which compares
# less than or equal to DBG.HTTP.
debug_set = {flag for flag in DBG.iterconstants() if flag <= DBG.HTTP}
53
def log_debug(dflag, msg, idof=None, d=None):
    """Log a debug message if its category is enabled.

    dflag -- DBG constant; the message is dropped unless it is in debug_set.
    msg   -- message text.
    idof  -- optional object; when given, id(idof) is prefixed in hex.
    d     -- optional bytes-like data; when given, it is appended as a
             hex dump.
    """
    if dflag not in debug_set:
        return

    text = msg
    if idof is not None:
        text = '[%#x] %s' % (id(idof), text)
    if d is not None:
        hexbytes, _consumed = _hex_codec(d)
        text = text + ' ' + hexbytes.decode('ascii')

    log.info('{dflag} {msgcore}', dflag=dflag, msgcore=text)
64
# Built-in default configuration.  common_startup() removes everything
# from `#' to end of line and feeds the result to the config parser
# before reading the user's config file.
defcfg = '''
[DEFAULT]
#[<client>] overrides
max_batch_down = 65536 # used by server, subject to [limits]
max_queue_time = 10 # used by server, subject to [limits]
target_requests_outstanding = 3 # must match; subject to [limits] on server
http_timeout = 30 # used by both } must be
http_timeout_grace = 5 # used by both } compatible
max_requests_outstanding = 4 # used by client
max_batch_up = 4000 # used by client
http_retry = 5 # used by client

#[server] or [<client>] overrides
ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
# extra interpolations: %(local)s %(peer)s %(rnet)s
# obtained on server [virtual]server [virtual]relay [virtual]network
# from on client <client> [virtual]server [virtual]routes

[virtual]
mtu = 1500
routes = ''
# network = <prefix>/<len> # mandatory for server
# server = <ipaddr> # used by both, default is computed from `network'
# relay = <ipaddr> # used by server, default from `network' and `server'
# default server is first host in network
# default relay is first host which is not server

[server]
# addrs = 127.0.0.1 ::1 # mandatory for server
port = 80 # used by server
# url # used by client; default from first `addrs' and `port'

# [<client-ip4-or-ipv6-address>]
# password = <password> # used by both, must match

[limits]
max_batch_down = 262144 # used by server
max_queue_time = 121 # used by server
http_timeout = 121 # used by server
target_requests_outstanding = 10 # used by server
'''
106
# The shared parser objects are created here, at import time, so that a
# module doing `import *' from this one picks up these same instances.
cfg = ConfigParser()
optparser = OptionParser()
110
# Byte translation table which exchanges `-' with the SLIP ESC byte.
_mimetrans = bytes.maketrans(b'-' + slip.esc, slip.esc + b'-')


def mime_translate(s):
    """Return s with every `-' byte and every SLIP ESC byte exchanged.

    SLIP-encoded packets cannot contain ESC ESC, so once `-' and ESC
    have been swapped the result cannot contain `--'.
    """
    return s.translate(_mimetrans)
116
class ConfigResults:
    """Attribute-style holder for processed configuration values.

    d -- optional mapping to use as the instance's attribute dictionary.
         When supplied, it is used directly (not copied), so changes made
         through the instance are visible in `d' and vice versa.  When
         omitted, the instance starts with its own empty dictionary.
    """

    def __init__(self, d=None):
        # A `{ }' default would be evaluated once and then shared as the
        # __dict__ of every instance constructed without an argument, so
        # build a fresh dictionary per instance instead.
        self.__dict__ = {} if d is None else d

    def __repr__(self):
        return 'ConfigResults(' + repr(self.__dict__) + ')'


# Module-wide results object; the process_cfg_* functions store into it.
c = ConfigResults()
124
def log_discard(packet, iface, saddr, daddr, why):
    """Report, under DBG.DROP, that a packet has been discarded.

    packet       -- the packet data; passed to log_debug for hex dumping.
    iface        -- label for where the packet came from.
    saddr, daddr -- source and destination addresses, for the message.
    why          -- reason for discarding.
    """
    summary = 'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why)
    log_debug(DBG.DROP, summary, d=packet)
129
130 #---------- packet parsing ----------
131
def packet_addrs(packet):
    """Extract the source and destination addresses from a raw IP packet.

    packet -- bytes-like object beginning with the IP header.

    Returns (saddr, daddr): ipaddress.IPv4Address objects for a version 4
    packet, ipaddress.IPv6Address objects for a version 6 packet.

    Raises ValueError if the packet is empty, is too short to contain
    both addresses, or has an IP version other than 4 or 6.
    """
    if not packet:
        # Indexing packet[0] below would otherwise raise IndexError.
        raise ValueError('empty packet')

    version = packet[0] >> 4
    if version == 4:
        addrlen = 4
        saddroff = 3*4
        factory = ipaddress.IPv4Address
    elif version == 6:
        addrlen = 16
        saddroff = 2*4
        factory = ipaddress.IPv6Address
    else:
        raise ValueError('unsupported IP version %d' % version)

    daddroff = saddroff + addrlen
    end = daddroff + addrlen
    if len(packet) < end:
        raise ValueError('truncated IPv%d packet (%d bytes, need %d)'
                         % (version, len(packet), end))

    # The ipaddress constructors only take packed addresses as `bytes',
    # so convert in case the packet is some other bytes-like type.
    saddr = factory(bytes(packet[saddroff:daddroff]))
    daddr = factory(bytes(packet[daddroff:end]))
    return (saddr, daddr)
147
148 #---------- address handling ----------
149
def ipaddr(input):
    """Parse input as an IP address, trying IPv4 first and then IPv6.

    Returns an ipaddress.IPv4Address or ipaddress.IPv6Address.  If the
    IPv4 attempt fails with AddressValueError, the IPv6 constructor is
    tried and any exception it raises propagates to the caller.
    """
    try:
        return ipaddress.IPv4Address(input)
    except AddressValueError:
        return ipaddress.IPv6Address(input)
156
def ipnetwork(input):
    """Parse input as an IP network, trying IPv4 first and then IPv6.

    Returns an ipaddress.IPv4Network or ipaddress.IPv6Network.

    The IPv6 constructor is tried when the IPv4 one rejects the address
    or netmask part; any exception the IPv6 constructor raises
    propagates.  Other IPv4 errors (such as host bits being set)
    propagate directly.
    """
    try:
        r = ipaddress.IPv4Network(input)
    except (AddressValueError, ipaddress.NetmaskValueError):
        # These are the exceptions the ipaddress module defines for a
        # bad address or netmask; there is no `NetworkValueError'.
        r = ipaddress.IPv6Network(input)
    return r
163
164 #---------- ipif (SLIP) subprocess ----------
165
class SlipStreamDecoder():
    """Reassemble packets from SLIP data which arrives in arbitrary chunks.

    desc      -- label included in this decoder's debug messages.
    on_packet -- callable invoked with each non-empty packet.
    """

    def __init__(self, desc, on_packet):
        self._desc = desc
        self._on_packet = on_packet
        self._buffer = b''  # data received but not yet handed on
        self._log('__init__')

    def _log(self, msg, **kwargs):
        """Emit a DBG.SLIP_FULL debug message tagged with our description."""
        log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)

    def inputdata(self, data):
        """Feed in a chunk of the stream, delivering any packets found."""
        self._log('inputdata', d=data)
        # NOTE(review): presumably slip.decode() returns a non-empty list
        # of pieces split at the SLIP delimiter, the last being whatever
        # follows the final delimiter -- confirm against hippotat.slip.
        pieces = slip.decode(data)
        # The first piece carries on from the data held over last time.
        pieces[0] = self._buffer + pieces[0]
        # The last piece is held back until more data (or flush) arrives.
        self._buffer = pieces.pop()
        for piece in pieces:
            self._maybe_packet(piece)
        self._log('bufremain', d=self._buffer)

    def _maybe_packet(self, packet):
        """Pass packet to the callback unless it is empty."""
        self._log('maybepacket', d=packet)
        if len(packet) == 0:
            return
        self._on_packet(packet)

    def flush(self):
        """Deliver any held-over data as a packet and empty the buffer."""
        self._log('flush')
        self._maybe_packet(self._buffer)
        self._buffer = b''
194
class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
    """Process protocol for the ipif subprocess.

    SLIP data read from the subprocess is decoded into packets; each
    packet, apart from link-local traffic, is passed to `router' as
    router(packet, saddr, daddr).
    """

    def __init__(self, router):
        self._router = router
        self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)

    def connectionMade(self):
        pass

    def outReceived(self, data):
        self._decoder.inputdata(data)

    def slip_on_packet(self, packet):
        """Route one decoded packet, discarding link-local traffic."""
        saddr, daddr = packet_addrs(packet)
        if saddr.is_link_local or daddr.is_link_local:
            log_discard(packet, 'ipif', saddr, daddr, 'link-local')
        else:
            self._router(packet, saddr, daddr)

    def processEnded(self, status):
        # Re-raise whatever `status' describes, so that the end of the
        # subprocess is not silently ignored.
        status.raiseException()
210
def start_ipif(command, router):
    """Start the ipif subprocess and remember its protocol object.

    command -- shell command line, run as `sh -xc <command>'.
    router  -- callable given to _IpifProcessProtocol to receive packets.

    The protocol object is stored in the module global `ipif', which
    queue_inbound() uses to write to the subprocess.
    """
    global ipif
    ipif = _IpifProcessProtocol(router)

    argv = ['sh', '-xc', command]
    # fd 0 is a pipe we write to, fd 1 a pipe we read from, and fd 2 is
    # our own fd 2 passed through.
    fd_map = {0: 'w', 1: 'r', 2: 2}
    reactor.spawnProcess(ipif, '/bin/sh', argv, childFDs=fd_map, env=None)
218
def queue_inbound(packet):
    """Write packet to the ipif subprocess, SLIP-encoded.

    The encoded packet is written with a SLIP delimiter both before and
    after it.  start_ipif() must have been called first.
    """
    log_debug(DBG.FLOW, "queue_inbound", d=packet)
    transport = ipif.transport
    transport.write(slip.delimiter)
    transport.write(slip.encode(packet))
    transport.write(slip.delimiter)
224
225 #---------- packet queue ----------
226
class PacketQueue():
    """FIFO of packets in which entries that wait too long are dropped.

    desc           -- label prefixed to this queue's debug messages.
    max_queue_time -- maximum age of a queued packet, in the units of
                      time.monotonic(); older packets are dropped by
                      nonempty().
    """

    def __init__(self, desc, max_queue_time):
        self._desc = desc
        # Sanity check on desc: the concatenation raises TypeError unless
        # desc can be joined to a str, and the assertion fails if the
        # result is empty.
        assert(desc + '')
        self._max_queue_time = max_queue_time
        self._pq = collections.deque()  # of (time queued, packet)

    def _log(self, dflag, msg, **kwargs):
        """Emit a debug message prefixed with this queue's description."""
        log_debug(dflag, self._desc + ' pq: ' + msg, **kwargs)

    def append(self, packet):
        """Add packet to the back of the queue, noting the time."""
        self._log(DBG.QUEUE, 'append', d=packet)
        self._pq.append((time.monotonic(), packet))

    def nonempty(self):
        """Drop over-age packets from the front; say whether any remain."""
        self._log(DBG.QUEUE, 'nonempty ?')
        while self._pq:
            queued_at, packet = self._pq[0]
            waited = time.monotonic() - queued_at
            if waited > self._max_queue_time:
                # strip old packets off the front
                self._log(DBG.QUEUE, 'dropping (old)', d=packet)
                self._pq.popleft()
                continue

            self._log(DBG.QUEUE, 'nonempty ? nonempty.')
            return True

        self._log(DBG.QUEUE, 'nonempty ? empty.')
        return False

    def process(self, sizequery, moredata, max_batch):
        """Move packets from the front of the queue into a batch.

        sizequery() should return the size of the batch so far.
        moredata(s) should add s to the batch.
        max_batch is the size the batch must not exceed, except that the
        first item of an empty batch is added whatever its size.

        Packets are SLIP-encoded and separated by the SLIP delimiter.
        Processing stops when the queue is empty or when the next packet
        would take a non-empty batch over max_batch; in the latter case
        that packet stays on the queue.
        """
        self._log(DBG.QUEUE, 'process...')
        while True:
            if not self._pq:
                self._log(DBG.QUEUE, 'process... empty')
                return
            _queued_at, packet = self._pq[0]

            self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)

            encoded = slip.encode(packet)
            sofar = sizequery()

            self._log(DBG.QUEUE_CTRL,
                      'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
                      d=encoded)

            if sofar > 0:
                projected = sofar + len(slip.delimiter) + len(encoded)
                if projected > max_batch:
                    self._log(DBG.QUEUE_CTRL, 'process... overflow')
                    return
                moredata(slip.delimiter)

            moredata(encoded)
            self._pq.popleft()
286
287 #---------- error handling ----------
288
# Becomes True once crash() has been called; common_run() consults it
# before starting the reactor.
_crashing = False


def crash(err):
    """Report err on stderr and stop the reactor.

    Also sets the module flag _crashing.  It is not an error for the
    reactor not to be running.
    """
    global _crashing
    _crashing = True

    banner_top = '========== CRASH =========='
    banner_bottom = '==========================='
    print(banner_top, err, banner_bottom, file=sys.stderr)

    try:
        reactor.stop()
    except twisted.internet.error.ReactorNotRunning:
        pass
298
def crash_on_defer(defer):
    """Arrange for a failure of `defer' to be passed to crash()."""
    def _on_failure(err):
        return crash(err)

    defer.addErrback(_on_failure)
301
def crash_on_critical(event):
    """Log observer which calls crash() for events of critical level.

    event -- log event mapping.  An event without a 'log_level' (or with
             it set to None) is ignored.
    """
    level = event.get('log_level')
    if level is None:
        # Comparing None against a LogLevel would raise TypeError.
        return
    if level >= LogLevel.critical:
        crash(twisted.logger.formatEvent(event))
305
306 #---------- config processing ----------
307
def process_cfg_common_always():
    """Store the [virtual] `mtu' setting in c.mtu.

    The value is kept exactly as the config parser returns it.
    """
    mtu_setting = cfg.get('virtual', 'mtu')
    c.mtu = mtu_setting
311
def process_cfg_ipif(section, varmap):
    """Compute c.ipif_command from the `ipif' setting in `section'.

    section -- config section from which to read `ipif'.
    varmap  -- iterable of (dest, src) attribute name pairs.  For each
               pair, if c has an attribute `src', its value is copied to
               attribute `dest'; otherwise the pair is skipped.

    The attributes of c are then supplied to the config parser as extra
    interpolation variables for the `ipif' value.
    """
    absent = object()
    for dest, src in varmap:
        value = getattr(c, src, absent)
        if value is absent:
            continue
        setattr(c, dest, value)

    c.ipif_command = cfg.get(section, 'ipif', vars=vars(c))
321
def process_cfg_network():
    """Parse the [virtual] `network' setting into c.network.

    Raises ValueError if the network has fewer than 3 + 2 addresses.
    """
    spec = cfg.get('virtual', 'network')
    c.network = ipnetwork(spec)

    # Network sizes are powers of two, so requiring at least 5 addresses
    # means the smallest acceptable network has 2^3.
    min_addresses = 3 + 2
    if c.network.num_addresses < min_addresses:
        raise ValueError('network needs at least 2^3 addresses')
326
def process_cfg_server():
    """Set c.server from [virtual] `server', or default it from the network.

    When `server' is configured, c.server is the configured string.
    Otherwise the network is processed (see process_cfg_network) and
    c.server is the first host address in it, as an address object.
    """
    try:
        server = cfg.get('virtual', 'server')
    except NoOptionError:
        process_cfg_network()
        server = next(c.network.hosts())
    c.server = server
333
class ServerAddr():
    """One address and port on which the server listens.

    port     -- TCP port number.
    addrspec -- IPv4 or IPv6 address, in a form the ipaddress module
                accepts.  IPv4 is tried first; if that fails, the
                IPv6Address constructor is used and any exception it
                raises propagates.

    Attributes: `port', and `addr' (the parsed address object).
    """

    def __init__(self, port, addrspec):
        self.port = port
        # also self.addr
        try:
            self.addr = ipaddress.IPv4Address(addrspec)
        except AddressValueError:
            self.addr = ipaddress.IPv6Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
            self._inurl = b'[%s]'  # IPv6 literals are bracketed in URLs
        else:
            self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
            self._inurl = b'%s'

    def make_endpoint(self):
        """Return a Twisted server endpoint listening on this address."""
        # The third positional parameter of TCP4ServerEndpoint and
        # TCP6ServerEndpoint is `backlog', so the address has to be
        # passed by keyword, and as a string.
        return self._endpointfactory(reactor, self.port,
                                     interface=str(self.addr))

    def url(self):
        """Return the http URL for this address, as bytes.

        The port is included only when it is not 80.
        """
        url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
        if self.port != 80:
            url += b':%d' % self.port
        url += b'/'
        return url
353
def process_cfg_saddrs():
    """Build c.saddrs, one ServerAddr per entry in [server] `addrs'.

    `addrs' is a whitespace-separated list.  Every entry uses the
    [server] `port' setting, or 80 if there is none.
    """
    try:
        port = cfg.getint('server', 'port')
    except NoOptionError:
        port = 80

    c.saddrs = []
    addrspecs = cfg.get('server', 'addrs').split()
    for addrspec in addrspecs:
        c.saddrs.append(ServerAddr(port, addrspec))
362
def process_cfg_clients(constructor):
    """Call constructor for each client section in the configuration.

    A section counts as a client if its name contains `:' or `.'; the
    name is parsed as an IP address.  For each one, constructor is
    called as constructor(ci, cs, pw), where ci is the parsed address,
    cs the section name and pw the section's `password' encoded as
    UTF-8.

    c.clients is reset to an empty list first; this function does not
    itself add anything to it.
    """
    c.clients = []
    for cs in cfg.sections():
        if ':' not in cs and '.' not in cs:
            continue
        ci = ipaddr(cs)
        pw = cfg.get(cs, 'password').encode('utf-8')
        constructor(ci, cs, pw)
371
372 #---------- startup ----------
373
def common_startup():
    """Set up logging, parse the command line and load the configuration.

    The built-in defaults (defcfg) are loaded first, then the file named
    by -c/--config (default /etc/hippotat/config).
    """
    formatter = twisted.logger.formatEventAsClassicLogText
    to_stdout = twisted.logger.FileLogObserver(sys.stdout, formatter)
    to_stderr = twisted.logger.FileLogObserver(sys.stderr, formatter)
    error_pred = twisted.logger.LogLevelFilterPredicate(LogLevel.error)
    # Events which pass the predicate go to stderr; the others to stdout.
    splitter = twisted.logger.FilteringLogObserver(
        to_stderr, [error_pred], to_stdout
    )
    observers = [splitter, crash_on_critical]
    twisted.logger.globalLogBeginner.beginLoggingTo(observers)

    optparser.add_option('-c', '--config', dest='configfile',
                         default='/etc/hippotat/config')
    opts, args = optparser.parse_args()
    if args:
        optparser.error('no non-option arguments please')

    # The config parser is not set up to accept trailing `#' comments,
    # so remove all comments from the built-in defaults before parsing.
    comment_re = regexp.compile('#.*')
    cfg.read_string(comment_re.sub('', defcfg))
    cfg.read(opts.configfile)
394
def common_run():
    """Run the reactor, unless crash() has already been called.

    Prints a final message on stderr once the reactor has returned (or
    was never started).
    """
    log_debug(DBG.INIT, 'entering reactor')
    already_crashed = _crashing
    if not already_crashed:
        reactor.run()
    print('CRASHED (end)', file=sys.stderr)