fa43065759d82332d9f61c5b4d8305c5346bf32c
[hippotat] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
# Restore the default SIGINT disposition (terminate the process) instead of
# Python's KeyboardInterrupt handler, so Ctrl-C kills the program outright.
# Done before the remaining imports so it is in effect from the very start.
signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 import twisted.internet.endpoints
11 import twisted.logger
12 from twisted.logger import LogLevel
13 import twisted.python.constants
14 from twisted.python.constants import NamedConstant
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 from functools import partial
24
25 import collections
26 import time
27 import codecs
28 import traceback
29
30 import re as regexp
31
32 import hippotat.slip as slip
33
class DBG(twisted.python.constants.Names):
    """Debug message categories.

    Each constant is passed as the *dflag* argument of log_debug(), which
    includes the flag in the logged message.
    """
    ROUTE = NamedConstant()
    DROP = NamedConstant()        # discarded packets (log_discard)
    FLOW = NamedConstant()        # packet flow, e.g. queue_inbound
    HTTP = NamedConstant()
    HTTP_CTRL = NamedConstant()
    INIT = NamedConstant()        # startup, e.g. common_run
    QUEUE = NamedConstant()       # PacketQueue operations
    QUEUE_CTRL = NamedConstant()  # PacketQueue.process batching details
    HTTP_FULL = NamedConstant()
    SLIP_FULL = NamedConstant()   # SlipStreamDecoder tracing, with data dumps
    CTRL_DUMP = NamedConstant()
45 CTRL_DUMP = NamedConstant()
46
# Bytes -> hex encoder.  Like every codecs encoder it returns a
# (output, length_consumed) tuple, so callers take element [0].
_hex_codec = codecs.getencoder('hex_codec')

# Module-wide twisted logger, used by log_debug().
log = twisted.logger.Logger()
50
def log_debug(dflag, msg, idof=None, d=None):
    """Log *msg* at info level, tagged with the debug category *dflag*.

    If *idof* is given, the message is prefixed with ``[0x...]``, the hex
    id() of that object, so lines about the same object can be matched up.
    If *d* is given (bytes-like), its hex dump is appended to the message.
    The line is emitted as '<dflag> <message>'.
    """
    text = msg
    if idof is not None:
        text = '[%#x] %s' % (id(idof), text)
    if d is not None:
        hexdump = _hex_codec(d)[0].decode('ascii')
        text += ' ' + hexdump
    log.info('{dflag} {msgcore}', dflag=dflag, msgcore=text)
60
# Built-in configuration defaults, loaded before the user's config file.
# common_startup() removes everything from `#' to end of line (regexp '#.*')
# before parsing this text.  The annotations below therefore never reach
# ConfigParser, and for the same reason no value here may contain a `#'.
defcfg = '''
[DEFAULT]
#[<client>] overrides
max_batch_down = 65536           # used by server, subject to [limits]
max_queue_time = 10              # used by server, subject to [limits]
target_requests_outstanding = 3  # must match; subject to [limits] on server
http_timeout = 30                # used by both } must be
http_timeout_grace = 5           # used by both } compatible
max_requests_outstanding = 4     # used by client
max_batch_up = 4000              # used by client
http_retry = 5                   # used by client

#[server] or [<client>] overrides
ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
# extra interpolations:  %(local)s        %(peer)s          %(rnets)s
#  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
#      from   on client  <client>         [virtual]server   [virtual]routes

[virtual]
mtu = 1500
routes = ''
# network = <prefix>/<len>  # mandatory for server
# server  = <ipaddr>        # used by both, default is computed from `network'
# relay   = <ipaddr>        # used by server, default from `network' and `server'
#  default server is first host in network
#  default relay is first host which is not server

[server]
# addrs = 127.0.0.1 ::1    # mandatory for server
port = 80                  # used by server
# url                      # used by client; default from first `addrs' and `port'

# [<client-ip4-or-ipv6-address>]
# password = <password>    # used by both, must match

[limits]
max_batch_down = 262144           # used by server
max_queue_time = 121              # used by server
http_timeout = 121                # used by server
target_requests_outstanding = 10  # used by server
'''
102
# These must be defined at module level so that `from hippotat import *'
# picks them up.  common_startup() adds the -c/--config option to
# optparser, parses the command line, and fills cfg from the built-in
# defaults plus the config file.
cfg = ConfigParser()
optparser = OptionParser()
106
# Byte translation table exchanging `-' and the SLIP ESC byte.
_mimetrans = bytes.maketrans(b'-' + slip.esc, slip.esc + b'-')

def mime_translate(s):
    """Return *s* with `-' and the SLIP ESC byte swapped for each other.

    SLIP-encoded packets cannot contain ESC ESC, so after the swap the result
    cannot contain `--'.  Because it is a pure swap (slip.esc presumably being
    a single byte), applying it twice gives back the original.
    """
    translated = s.translate(_mimetrans)
    return translated
112
class ConfigResults:
    """Attribute bag holding processed configuration values.

    The instance's attributes live directly in the dict *d*.  A dict passed
    in is used as-is, not copied, so changes made through either the dict
    or the attributes are visible through the other.  With no argument,
    each instance gets its own fresh dict.
    """

    def __init__(self, d=None):
        # The former `d = { }` default was a single dict shared by every
        # ConfigResults() created without an argument (including `c` below),
        # so their attributes leaked into one another.
        self.__dict__ = {} if d is None else d

    def __repr__(self):
        return 'ConfigResults(' + repr(self.__dict__) + ')'

# Module-wide configuration results, filled in by the process_cfg_* helpers.
c = ConfigResults()
120
def log_discard(packet, iface, saddr, daddr, why):
    """Log under DBG.DROP that *packet* was discarded, with a hex dump of it.

    *iface* is a short label for where the packet came from (e.g. 'ipif'),
    *saddr*/*daddr* are its source and destination addresses, and *why* is
    a short human-readable reason.
    """
    summary = 'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why)
    log_debug(DBG.DROP, summary, d=packet)
125
126 #---------- packet parsing ----------
127
def packet_addrs(packet):
    """Return the (source, destination) addresses of a raw IP packet.

    The IP version is the high nibble of the first byte.  IPv4 addresses are
    read at header offsets 12 and 16; IPv6 addresses at offsets 8 and 24.
    Returns a pair of ipaddress.IPv4Address or IPv6Address objects.

    Raises ValueError for any other IP version.  If the packet is too short
    to hold both addresses, the ipaddress constructor raises
    AddressValueError (a ValueError subclass).  An empty packet raises
    IndexError.
    """
    version = packet[0] >> 4
    # version -> (address class, address width in bytes, source offset)
    layouts = {
        4: (ipaddress.IPv4Address, 4, 3 * 4),
        6: (ipaddress.IPv6Address, 16, 2 * 4),
    }
    layout = layouts.get(version)
    if layout is None:
        raise ValueError('unsupported IP version %d' % version)
    make, width, src_at = layout
    dst_at = src_at + width
    return (make(packet[src_at:dst_at]), make(packet[dst_at:dst_at + width]))
143
144 #---------- address handling ----------
145
def ipaddr(input):
    """Parse *input* as an IPv4 address, falling back to IPv6.

    Returns an ipaddress.IPv4Address or IPv6Address.  If *input* is
    neither, the AddressValueError from the IPv6 parse propagates.
    """
    try:
        return ipaddress.IPv4Address(input)
    except AddressValueError:
        return ipaddress.IPv6Address(input)
152
def ipnetwork(input):
    """Parse *input* as an IPv4 network, falling back to IPv6.

    Mirrors ipaddr(): returns an ipaddress.IPv4Network or IPv6Network.
    Only "not an IPv4 address" is retried as IPv6.  Other errors, such as a
    bad IPv4 netmask (NetmaskValueError) or host bits set (ValueError),
    propagate unchanged.
    """
    try:
        return ipaddress.IPv4Network(input)
    except AddressValueError:
        # IPv4Network raises AddressValueError for a non-IPv4 address part.
        # The previous `except NetworkValueError` named a class that does
        # not exist, so every IPv6 network crashed with NameError.
        return ipaddress.IPv6Network(input)
159
160 #---------- ipif (SLIP) subprocess ----------
161
class SlipStreamDecoder():
    """Split a SLIP byte stream, delivered in arbitrary chunks, into packets.

    slip.decode() is assumed to return the pieces between delimiters, the
    last of which may be incomplete.  That trailing piece is held in a
    buffer and prefixed to the next chunk.  Every complete, non-empty packet
    is passed to on_packet(packet).  *desc* labels this decoder's debug log
    lines.
    """

    def __init__(self, desc, on_packet):
        self._desc = desc
        self._on_packet = on_packet
        self._buffer = b''
        self._log('__init__')

    def _log(self, msg, **kwargs):
        # All decoder tracing goes out under the SLIP_FULL debug category.
        log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)

    def inputdata(self, data):
        """Feed one chunk of stream data; deliver any packets it completes."""
        self._log('inputdata', d=data)
        pieces = slip.decode(data)
        # The held-back partial packet continues into this chunk's first piece.
        pieces[0] = self._buffer + pieces[0]
        # The last piece may still be incomplete: keep it for next time.
        *complete, self._buffer = pieces
        for packet in complete:
            self._maybe_packet(packet)
        self._log('bufremain', d=self._buffer)

    def _maybe_packet(self, packet):
        self._log('maybepacket', d=packet)
        # Empty pieces (presumably from back-to-back delimiters) are skipped.
        if len(packet) == 0:
            return
        self._on_packet(packet)

    def flush(self):
        """Deliver whatever is buffered as a final packet, then clear it."""
        self._log('flush')
        self._maybe_packet(self._buffer)
        self._buffer = b''
190
class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
    """Process protocol for the ipif subprocess's SLIP stream.

    Packets the subprocess writes to the pipe we read are SLIP-decoded and
    passed to router(packet, saddr, daddr).  Link-local packets are logged
    and dropped instead.
    """

    def __init__(self, router):
        self._router = router
        self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)

    def connectionMade(self):
        pass

    def outReceived(self, data):
        # Raw bytes from the child; the decoder handles the SLIP framing.
        self._decoder.inputdata(data)

    def slip_on_packet(self, packet):
        saddr, daddr = packet_addrs(packet)
        if not (saddr.is_link_local or daddr.is_link_local):
            self._router(packet, saddr, daddr)
            return
        log_discard(packet, 'ipif', saddr, daddr, 'link-local')

    def processEnded(self, status):
        # Any termination of the ipif process is treated as an error: the
        # exception carried by *status* is re-raised unconditionally.
        status.raiseException()
206
def start_ipif(command, router):
    """Spawn *command* under `/bin/sh -xc' as the ipif process.

    The new _IpifProcessProtocol is stored in the module global `ipif`,
    which queue_inbound() writes to.  We write the child's fd 0 and read
    its fd 1; its fd 2 is mapped to our own fd 2.
    """
    global ipif
    ipif = _IpifProcessProtocol(router)
    argv = ['sh', '-xc', command]
    fds = {0: 'w', 1: 'r', 2: 2}
    reactor.spawnProcess(ipif, '/bin/sh', argv, childFDs=fds, env=None)
214
def queue_inbound(packet):
    """Send *packet* to the ipif process, SLIP-encoded between delimiters.

    Requires start_ipif() to have run, since it sets the global `ipif`.
    """
    log_debug(DBG.FLOW, "queue_inbound", d=packet)
    write = ipif.transport.write
    # The leading delimiter terminates any stray partial frame the receiver
    # may be holding; the trailing one ends this packet.
    write(slip.delimiter)
    write(slip.encode(packet))
    write(slip.delimiter)
220
221 #---------- packet queue ----------
222
class PacketQueue():
    """FIFO of packets awaiting transmission, with age-based expiry.

    Each entry records the time.monotonic() time at which it was queued.
    nonempty() discards entries older than *max_queue_time* seconds from
    the front.  process() moves as many packets as fit into a batch built
    by the caller.  *desc* (a str) prefixes this queue's debug log lines.
    """

    def __init__(self, desc, max_queue_time):
        # Explicit check instead of the former `assert(desc + '')`.  That
        # assert was stripped under `python -O` and also rejected an empty
        # desc.  Non-str values still raise TypeError as before.
        if not isinstance(desc, str):
            raise TypeError('PacketQueue desc must be a str, not %r'
                            % (type(desc),))
        self._desc = desc
        self._max_queue_time = max_queue_time
        self._pq = collections.deque()  # (monotonic time queued, packet)

    def _log(self, dflag, msg, **kwargs):
        log_debug(dflag, self._desc+' pq: '+msg, **kwargs)

    def append(self, packet):
        """Queue *packet*, stamped with the current monotonic time."""
        self._log(DBG.QUEUE, 'append', d=packet)
        self._pq.append((time.monotonic(), packet))

    def nonempty(self):
        """Drop expired packets from the front; return True if any remain.

        Entries are appended in non-decreasing time order, so the scan can
        stop at the first packet that is young enough.
        """
        self._log(DBG.QUEUE, 'nonempty ?')
        while True:
            try: (queuetime, packet) = self._pq[0]
            except IndexError:
                self._log(DBG.QUEUE, 'nonempty ? empty.')
                return False

            age = time.monotonic() - queuetime
            if age > self._max_queue_time:
                # strip old packets off the front
                self._log(DBG.QUEUE, 'dropping (old)', d=packet)
                self._pq.popleft()
                continue

            self._log(DBG.QUEUE, 'nonempty ? nonempty.')
            return True

    def process(self, sizequery, moredata, max_batch):
        """Move queued packets into a batch until it would exceed *max_batch*.

        sizequery() must return the size of the batch so far, and
        moredata(s) must append bytes *s* to it.  Packets are SLIP-encoded
        and separated by slip.delimiter.  The first packet of an empty batch
        is always taken, even if it alone exceeds *max_batch*.  Packet age
        is not checked here; expiry happens only in nonempty().
        """
        self._log(DBG.QUEUE, 'process...')
        while True:
            try: (dummy, packet) = self._pq[0]
            except IndexError:
                self._log(DBG.QUEUE, 'process... empty')
                break

            self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)

            encoded = slip.encode(packet)
            sofar = sizequery()

            self._log(DBG.QUEUE_CTRL,
                      'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
                      d=encoded)

            if sofar > 0:
                if sofar + len(slip.delimiter) + len(encoded) > max_batch:
                    self._log(DBG.QUEUE_CTRL, 'process... overflow')
                    break
                moredata(slip.delimiter)

            moredata(encoded)
            self._pq.popleft()
282
283 #---------- error handling ----------
284
# Set once crash() has been called; common_run() then skips starting the
# reactor.
_crashing = False

def crash(err):
    """Report fatal error *err* on stderr, set _crashing, stop the reactor.

    Calling this while the reactor is not running is harmless.
    """
    global _crashing
    _crashing = True
    banner = '========== CRASH =========='
    rule = '==========================='
    print(banner, err, rule, file=sys.stderr)
    try:
        reactor.stop()
    except twisted.internet.error.ReactorNotRunning:
        pass  # reactor not running: nothing to stop
294
def crash_on_defer(defer):
    """Make any failure of the Deferred *defer* fatal, via crash()."""
    def _on_failure(failure):
        return crash(failure)
    defer.addErrback(_on_failure)
297
def crash_on_critical(event):
    """Log observer: call crash() for any event at critical level or above.

    Events without a 'log_level' key are ignored.  Previously a missing key
    yielded None, and `None >= LogLevel.critical` would raise TypeError
    inside the observer.
    """
    level = event.get('log_level')
    if level is not None and level >= LogLevel.critical:
        crash(twisted.logger.formatEvent(event))
301
302 #---------- config processing ----------
303
def process_cfg_common_always():
    """Copy settings needed in every mode from cfg into c.

    Currently sets only c.mtu, kept as the raw string from [virtual]mtu.
    """
    # A stale `global mtu` was removed here: nothing assigns a module-level
    # `mtu`, and the value lives in c.mtu.
    c.mtu = cfg.get('virtual', 'mtu')
307
def process_cfg_ipif(section, varmap):
    """Compute c.ipif_command from the `ipif' option of *section*.

    *varmap* is a sequence of (dest, src) attribute-name pairs.  For each
    pair where c has attribute *src*, c.<dest> is set to that value;
    presumably this supplies the template's %(local)s/%(peer)s/%(rnets)s.
    Pairs whose *src* is missing are skipped.  The option is then read with
    all of c's attributes available as interpolation variables.
    """
    for dest, src in varmap:
        if not hasattr(c, src):
            continue
        setattr(c, dest, getattr(c, src))
    c.ipif_command = cfg.get(section, 'ipif', vars=c.__dict__)
317
def process_cfg_network():
    """Set c.network from [virtual]network (IPv4 or IPv6).

    Raises ValueError if the network has fewer than 3 + 2 addresses.  CIDR
    network sizes are powers of two, so this means fewer than 2^3.
    """
    network = ipnetwork(cfg.get('virtual', 'network'))
    c.network = network
    if network.num_addresses < 3 + 2:
        raise ValueError('network needs at least 2^3 addresses')
322
def process_cfg_server():
    """Set c.server from [virtual]server, defaulting to the network's first host.

    If [virtual]server is absent, process_cfg_network() is run and the first
    host of c.network is used.
    NOTE(review): c.server is a str when configured but an ipaddress object
    when defaulted; callers must cope with both.
    """
    try:
        configured = cfg.get('virtual', 'server')
    except NoOptionError:
        process_cfg_network()
        c.server = next(c.network.hosts())
    else:
        c.server = configured
329
class ServerAddr():
    """One address and port on which the server listens.

    Attributes:
      port: listening port; url() formats it with %d, so an int is expected.
      addr: ipaddress.IPv4Address or IPv6Address parsed from *addrspec*
        (IPv4 is tried first).
    """

    def __init__(self, port, addrspec):
        self.port = port
        try:
            self.addr = ipaddress.IPv4Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
            self._inurl = b'%s'
        except AddressValueError:
            self.addr = ipaddress.IPv6Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
            # IPv6 literals must be bracketed inside URLs (RFC 3986).
            self._inurl = b'[%s]'

    def make_endpoint(self):
        """Return a Twisted server endpoint listening on addr:port."""
        # The third positional parameter of TCP4ServerEndpoint and
        # TCP6ServerEndpoint is `backlog', not the interface.  Passing the
        # address positionally bound to all interfaces with a bogus backlog.
        # Pass it by keyword, in the string form Twisted expects.
        return self._endpointfactory(reactor, self.port,
                                     interface=str(self.addr))

    def url(self):
        """Return the http:// URL (bytes) for this address.

        ':port' is included only when port is not 80.
        """
        url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
        if self.port != 80:
            url += b':%d' % self.port
        url += b'/'
        return url
349
def process_cfg_saddrs():
    """Build c.saddrs: one ServerAddr per whitespace-separated [server]addrs entry.

    All entries share the port from [server]port, or 80 if that option is
    absent.
    """
    try:
        port = cfg.getint('server', 'port')
    except NoOptionError:
        port = 80
    c.saddrs = [ ]
    addrspecs = cfg.get('server', 'addrs').split()
    c.saddrs.extend(ServerAddr(port, spec) for spec in addrspecs)
358
def process_cfg_clients(constructor):
    """Call constructor(addr, section, password) for each client section.

    A client section is any section whose name contains ':' or '.', i.e.
    looks like an IPv6 or IPv4 address.  Its name is parsed with ipaddr(),
    and its `password' option is passed UTF-8 encoded.
    NOTE(review): c.clients is reset to [] here, but nothing in this
    function adds to it; presumably the constructor does.  Confirm.
    """
    c.clients = [ ]
    for section in cfg.sections():
        if ':' not in section and '.' not in section:
            continue
        addr = ipaddr(section)
        password = cfg.get(section, 'password').encode('utf-8')
        constructor(addr, section, password)
367
368 #---------- startup ----------
369
def common_startup():
    """Set up logging, parse the command line and load the configuration.

    Log events at error level and above go to stderr, the rest to stdout.
    Every event is also passed to crash_on_critical().  The built-in
    defaults (defcfg, with `#' comments stripped) are loaded first and are
    then overridden by the file named by -c/--config.  An unreadable config
    file is reported through optparser.error(), which exits.
    """
    log_formatter = twisted.logger.formatEventAsClassicLogText
    stdout_obs = twisted.logger.FileLogObserver(sys.stdout, log_formatter)
    stderr_obs = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
    pred = twisted.logger.LogLevelFilterPredicate(LogLevel.error)
    log_observer = twisted.logger.FilteringLogObserver(
        stderr_obs, [pred], stdout_obs
    )
    twisted.logger.globalLogBeginner.beginLoggingTo(
        [ log_observer, crash_on_critical ]
    )

    optparser.add_option('-c', '--config', dest='configfile',
                         default='/etc/hippotat/config')
    (opts, args) = optparser.parse_args()
    if len(args): optparser.error('no non-option arguments please')

    # Strip `#' comments (to end of line) from the defaults before parsing.
    comment_re = regexp.compile('#.*')
    cfg.read_string(comment_re.sub('', defcfg))
    # ConfigParser.read() silently skips files it cannot open, so a
    # mistyped -c path or a missing default config would otherwise go
    # unnoticed and fail later with a confusing NoOptionError.
    if not cfg.read(opts.configfile):
        optparser.error('could not read config file %s' % opts.configfile)
390
def common_run():
    """Run the reactor, unless crash() already happened, then report.

    'CRASHED (end)' is printed to stderr whenever this returns, whether the
    reactor was skipped or has stopped.
    """
    log_debug(DBG.INIT, 'entering reactor')
    already_crashed = _crashing
    if not already_crashed:
        reactor.run()
    print('CRASHED (end)', file=sys.stderr)