5d5129a86fe74eba3c6961f44e53699a2a7b48c2
[hippotat] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 import twisted.internet.endpoints
11 import twisted.logger
12 from twisted.logger import LogLevel
13 import twisted.python.constants
14 from twisted.python.constants import NamedConstant
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 from functools import partial
24
25 import collections
26 import time
27 import codecs
28 import traceback
29
30 import re as regexp
31
32 import hippotat.slip as slip
33
class DBG(twisted.python.constants.Names):
    """Categories of debug message.

    A member of this class is passed as the `dflag' argument of
    log_debug() to say which kind of message is being emitted.
    """

    # Declaration order is preserved deliberately: Names constants are
    # ordered by the order in which they are defined.
    ROUTE = NamedConstant()
    DROP = NamedConstant()
    FLOW = NamedConstant()
    HTTP = NamedConstant()
    HTTP_CTRL = NamedConstant()
    INIT = NamedConstant()
    QUEUE = NamedConstant()
    QUEUE_CTRL = NamedConstant()
    HTTP_FULL = NamedConstant()
    SLIP_FULL = NamedConstant()
    CTRL_DUMP = NamedConstant()
46
# Encoder used by log_debug() to render packet data as hex digits.
_hex_codec = codecs.getencoder('hex_codec')

log = twisted.logger.Logger()

def log_debug(dflag, msg, idof=None, d=None):
    """Emit a debug message through the module logger, at info level.

    dflag -- debug category; it is formatted in front of the message.
    msg   -- the message text.
    idof  -- optional object; if given, its id() is prefixed to the
             message in hex, inside square brackets.
    d     -- optional bytes-like data; if given, at most its first 64
             bytes are appended to the message as hex digits.
    """
    text = msg if idof is None else '[%#x] %s' % (id(idof), msg)
    if d is not None:
        # The encoder returns (encoded_bytes, length_consumed).
        hexbytes, _consumed = _hex_codec(d[0:64])
        text = text + ' ' + hexbytes.decode('ascii')
    log.info('{dflag} {msgcore}', dflag=dflag, msgcore=text)
60
# Built-in default configuration.  common_startup() strips everything
# from `#' to the end of each line before handing this text to the
# config parser, so the trailing remarks below are documentation only.
defcfg = '''
[DEFAULT]
#[<client>] overrides
max_batch_down = 65536 # used by server, subject to [limits]
max_queue_time = 10 # used by server, subject to [limits]
max_request_time = 54 # used by server, subject to [limits]
target_requests_outstanding = 3 # must match; subject to [limits] on server
max_requests_outstanding = 4 # used by client
max_batch_up = 4000 # used by client
http_timeout = 30 # used by client
http_retry = 5 # used by client

#[server] or [<client>] overrides
ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
# extra interpolations: %(local)s %(peer)s %(rnet)s
# obtained on server [virtual]server [virtual]relay [virtual]network
# from on client <client> [virtual]server [virtual]routes

[virtual]
mtu = 1500
routes = ''
# network = <prefix>/<len> # mandatory for server
# server = <ipaddr> # used by both, default is computed from `network'
# relay = <ipaddr> # used by server, default from `network' and `server'
# default server is first host in network
# default relay is first host which is not server

[server]
# addrs = 127.0.0.1 ::1 # mandatory for server
port = 80 # used by server
# url # used by client; default from first `addrs' and `port'

# [<client-ip4-or-ipv6-address>]
# password = <password> # used by both, must match

[limits]
max_batch_down = 262144 # used by server
max_queue_time = 121 # used by server
max_request_time = 121 # used by server
target_requests_outstanding = 10 # used by server
'''

# The parser objects are created at module level (rather than inside
# common_startup) so that `import *' of this module picks them up.
cfg = ConfigParser()
optparser = OptionParser()
106
# Byte translation table which exchanges `-' and the SLIP ESC byte.
_mimetrans = bytes.maketrans(b'-' + slip.esc, slip.esc + b'-')

def mime_translate(s):
    """Return `s' with every `-' and every SLIP ESC byte swapped.

    SLIP-encoded packets cannot contain ESC ESC, so after the swap the
    result cannot contain `--'.  The mapping is its own inverse.
    """
    swapped = s.translate(_mimetrans)
    return swapped
112
class ConfigResults:
    """Plain attribute container for processed configuration values.

    d -- optional dict to use directly as the instance's attribute
         dictionary (it is aliased, not copied).  When omitted, each
         instance gets its own fresh, empty dict.
    """
    def __init__(self, d=None):
        # A literal `{}' default would be created once and then shared
        # by every instance constructed without an argument, so that
        # setting an attribute on one would set it on all of them.
        self.__dict__ = {} if d is None else d

    def __repr__(self):
        return 'ConfigResults('+repr(self.__dict__)+')'

# Module-wide configuration results, filled in by the process_cfg_*
# functions.
c = ConfigResults()
120
def log_discard(packet, iface, saddr, daddr, why):
    """Record, in the DBG.DROP category, that a packet was discarded.

    packet       -- the packet data (passed to log_debug for hex dumping).
    iface        -- label of where the packet came from.
    saddr, daddr -- its source and destination addresses.
    why          -- reason for the discard.
    """
    summary = 'discarded packet [%s] %s -> %s: %s' % (
        iface, saddr, daddr, why)
    log_debug(DBG.DROP, summary, d=packet)
125
126 #---------- packet parsing ----------
127
def packet_addrs(packet):
    """Return (source, destination) addresses of a raw IP packet.

    The IP version is taken from the top four bits of the first byte.
    Returns a pair of ipaddress.IPv4Address or ipaddress.IPv6Address
    objects.  Raises ValueError if the version is neither 4 nor 6.
    """
    # version -> (address class, address length, offset of source addr)
    layouts = {
        4: (ipaddress.IPv4Address, 4, 3*4),
        6: (ipaddress.IPv6Address, 16, 2*4),
    }
    version = packet[0] >> 4
    layout = layouts.get(version)
    if layout is None:
        raise ValueError('unsupported IP version %d' % version)
    factory, addrlen, saddroff = layout

    # The destination address immediately follows the source address.
    daddroff = saddroff + addrlen
    saddr = factory(packet[saddroff:daddroff])
    daddr = factory(packet[daddroff:daddroff + addrlen])
    return (saddr, daddr)
143
144 #---------- address handling ----------
145
def ipaddr(input):
    """Parse `input' as an IP address.

    IPv4 is tried first; if that fails with AddressValueError the
    input is parsed as IPv6 instead, and any error from that attempt
    propagates to the caller.
    """
    try:
        return ipaddress.IPv4Address(input)
    except AddressValueError:
        return ipaddress.IPv6Address(input)
152
def ipnetwork(input):
    """Parse `input' (e.g. `<prefix>/<len>') as an IP network.

    IPv4 is tried first; if the address part is not valid IPv4
    (AddressValueError) the input is parsed as an IPv6 network
    instead, and any error from that attempt propagates.
    """
    try:
        r = ipaddress.IPv4Network(input)
    except AddressValueError:
        # ipaddress has no `NetworkValueError'; a non-IPv4 address part
        # is reported as AddressValueError, just as in ipaddr().
        r = ipaddress.IPv6Network(input)
    return r
159
160 #---------- ipif (SLIP) subprocess ----------
161
class SlipStreamDecoder:
    """Incremental decoder for a SLIP byte stream.

    Data arrives in arbitrary chunks via inputdata(); each complete,
    non-empty packet is handed to the `on_packet' callback.

    desc      -- label included in this decoder's debug messages.
    on_packet -- callable taking one decoded packet.
    """

    def __init__(self, desc, on_packet):
        self._desc = desc
        self._on_packet = on_packet
        self._buffer = b''    # undecoded tail left over from earlier input
        self._log('__init__')

    def _log(self, msg, **kwargs):
        """Emit a DBG.SLIP_FULL debug message labelled with our desc."""
        text = 'slip %s: %s' % (self._desc, msg)
        log_debug(DBG.SLIP_FULL, text, **kwargs)

    def inputdata(self, data):
        """Feed more stream data in, delivering any completed packets."""
        self._log('inputdata', d=data)
        pending = self._buffer + data
        self._buffer = b''
        pieces = slip.decode(pending)
        # The final piece is held back as the (possibly incomplete)
        # remainder; everything before it is treated as complete.
        self._buffer = pieces.pop()
        for piece in pieces:
            self._maybe_packet(piece)
        self._log('bufremain', d=self._buffer)

    def _maybe_packet(self, packet):
        """Deliver `packet' to the callback unless it is empty."""
        self._log('maybepacket', d=packet)
        if len(packet):
            self._on_packet(packet)

    def flush(self):
        """Treat whatever is buffered as a final packet, then clear it."""
        self._log('flush')
        self._maybe_packet(self._buffer)
        self._buffer = b''
191
class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
    """Process protocol for the ipif subprocess.

    The subprocess's output is SLIP-decoded and each resulting packet
    is passed to `router' as router(packet, saddr, daddr), except that
    packets with a link-local source or destination are discarded.
    """

    def __init__(self, router):
        self._router = router
        self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)

    def connectionMade(self):
        # Nothing to do when the subprocess starts.
        pass

    def outReceived(self, data):
        # Output from the subprocess is SLIP stream data.
        self._decoder.inputdata(data)

    def slip_on_packet(self, packet):
        """Route one decoded packet, dropping link-local traffic."""
        saddr, daddr = packet_addrs(packet)
        if not (saddr.is_link_local or daddr.is_link_local):
            self._router(packet, saddr, daddr)
            return
        log_discard(packet, 'ipif', saddr, daddr, 'link-local')

    def processEnded(self, status):
        # The subprocess ending is always treated as an error: the
        # exception carried by `status' is raised.
        status.raiseException()
207
def start_ipif(command, router):
    """Start the ipif subprocess and remember it in the global `ipif'.

    command -- shell command string, run as `sh -xc <command>'.
    router  -- callable given to _IpifProcessProtocol to receive the
               packets which the subprocess outputs.
    """
    global ipif
    ipif = _IpifProcessProtocol(router)
    argv = ['sh', '-xc', command]
    # fd 0: we write to the child; fd 1: we read from the child;
    # fd 2: the child shares our own stderr.
    fdmap = {0: 'w', 1: 'r', 2: 2}
    reactor.spawnProcess(ipif, '/bin/sh', argv, childFDs=fdmap, env=None)
215
def queue_inbound(packet):
    """Send `packet' to the ipif subprocess, SLIP-encoded.

    The encoded packet is written with a SLIP delimiter both before
    and after it.  start_ipif() must have been called first, since
    this uses the global `ipif'.
    """
    log_debug(DBG.FLOW, "queue_inbound", d=packet)
    send = ipif.transport.write
    send(slip.delimiter)
    send(slip.encode(packet))
    send(slip.delimiter)
221
222 #---------- packet queue ----------
223
class PacketQueue:
    """FIFO of packets awaiting transmission, with age-based expiry.

    desc           -- string label used in this queue's debug messages.
    max_queue_time -- maximum age, in the units of time.monotonic()
                      (seconds), before nonempty() drops a packet.
    """

    def __init__(self, desc, max_queue_time):
        self._desc = desc
        assert desc + ''
        self._max_queue_time = max_queue_time
        self._pq = collections.deque()    # of (enqueue time, packet)

    def _log(self, dflag, msg, **kwargs):
        """Emit a debug message prefixed with this queue's label."""
        labelled = self._desc + ' pq: ' + msg
        log_debug(dflag, labelled, **kwargs)

    def append(self, packet):
        """Add `packet' to the back of the queue, timestamped now."""
        self._log(DBG.QUEUE, 'append', d=packet)
        entry = (time.monotonic(), packet)
        self._pq.append(entry)

    def nonempty(self):
        """Return True iff a packet which is not too old is queued.

        As a side effect, packets at the front of the queue which are
        older than max_queue_time are dropped.
        """
        self._log(DBG.QUEUE, 'nonempty ?')
        while self._pq:
            queued_at, packet = self._pq[0]
            if time.monotonic() - queued_at > self._max_queue_time:
                # strip old packets off the front
                self._log(DBG.QUEUE, 'dropping (old)', d=packet)
                self._pq.popleft()
                continue
            self._log(DBG.QUEUE, 'nonempty ? nonempty.')
            return True
        self._log(DBG.QUEUE, 'nonempty ? empty.')
        return False

    def process(self, sizequery, moredata, max_batch):
        """Move packets from the front of the queue into a batch.

        sizequery() -- should return the size of the batch so far.
        moredata(s) -- should add s to the batch.
        max_batch   -- size limit for the batch.

        Packets are SLIP-encoded and separated by the SLIP delimiter.
        Processing stops when the queue is empty, or when the next
        packet would take a non-empty batch over max_batch.  When the
        batch is empty a packet is added whatever its size.
        """
        self._log(DBG.QUEUE, 'process...')
        while self._pq:
            _queued_at, packet = self._pq[0]
            self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)

            encoded = slip.encode(packet)
            sofar = sizequery()
            self._log(DBG.QUEUE_CTRL,
                      'process... (sofar=%d, max=%d) encoded'
                      % (sofar, max_batch),
                      d=encoded)

            if sofar > 0:
                needed = sofar + len(slip.delimiter) + len(encoded)
                if needed > max_batch:
                    self._log(DBG.QUEUE_CTRL, 'process... overflow')
                    break
                moredata(slip.delimiter)

            moredata(encoded)
            self._pq.popleft()
        else:
            # Reached only when the loop ended because the queue ran
            # dry, not when we broke out on overflow.
            self._log(DBG.QUEUE, 'process... empty')
283
284 #---------- error handling ----------
285
# Set by crash(); checked by common_run() before starting the reactor.
_crashing = False

def crash(err):
    """Report `err' on stderr and shut the reactor down.

    Also sets the module flag `_crashing'.  If the reactor is not
    running, the resulting ReactorNotRunning is ignored.
    """
    global _crashing
    _crashing = True
    banner_top = '========== CRASH =========='
    banner_end = '==========================='
    print(banner_top, err, banner_end, file=sys.stderr)
    try:
        reactor.stop()
    except twisted.internet.error.ReactorNotRunning:
        pass
293 try: reactor.stop()
294 except twisted.internet.error.ReactorNotRunning: pass
295
def crash_on_defer(defer):
    """Arrange that a failure of `defer' results in a call to crash()."""
    def _failed(err):
        return crash(err)
    defer.addErrback(_failed)
298
def crash_on_critical(event):
    """Log observer which calls crash() for critical-level events.

    event -- a log event dict.  Events at LogLevel.critical or above
             are formatted and passed to crash(); all others, including
             events which carry no `log_level' at all, are ignored.
    """
    level = event.get('log_level')
    # Not every event is guaranteed to have a `log_level'.  Comparing
    # None with a LogLevel would raise TypeError from inside this
    # observer, so such events are skipped explicitly.
    if level is not None and level >= LogLevel.critical:
        crash(twisted.logger.formatEvent(event))
302
303 #---------- config processing ----------
304
def process_cfg_common_always():
    """Copy the [virtual] `mtu' setting into c.mtu.

    The value is stored exactly as the config parser returns it.
    """
    mtu_setting = cfg.get('virtual', 'mtu')
    c.mtu = mtu_setting
308
def process_cfg_ipif(section, varmap):
    """Compute c.ipif_command from the `ipif' option of `section'.

    varmap -- iterable of (dest, source) attribute-name pairs.  For
              each pair, if `c' has an attribute named source, its
              value is copied to the attribute named dest; pairs whose
              source attribute is missing are skipped.

    The attributes of `c' are then supplied as extra interpolation
    variables when the `ipif' option is read.
    """
    for dest, source in varmap:
        if hasattr(c, source):
            setattr(c, dest, getattr(c, source))

    c.ipif_command = cfg.get(section, 'ipif', vars=c.__dict__)
318
def process_cfg_network():
    """Parse the [virtual] `network' setting into c.network.

    Raises ValueError if the network contains fewer than 3 + 2
    addresses.
    """
    fewest_allowed = 3 + 2
    c.network = ipnetwork(cfg.get('virtual', 'network'))
    if c.network.num_addresses < fewest_allowed:
        raise ValueError('network needs at least 2^3 addresses')
323
def process_cfg_server():
    """Set c.server, the server's address on the virtual network.

    The [virtual] `server' setting is used if there is one.  Otherwise
    the network is processed (see process_cfg_network) and the first
    host address in it is used.
    """
    # NOTE(review): the configured value is stored as the string the
    # config parser returns, whereas the computed default is an
    # ipaddress object -- confirm that callers cope with both.
    try:
        server = cfg.get('virtual', 'server')
    except NoOptionError:
        process_cfg_network()
        server = next(c.network.hosts())
    c.server = server
330
class ServerAddr():
    """One address (with port) on which the server is reachable.

    port     -- TCP port number.
    addrspec -- textual IPv4 or IPv6 address.

    Attributes: `port' as given, and `addr', the parsed
    ipaddress.IPv4Address or ipaddress.IPv6Address.
    """

    def __init__(self, port, addrspec):
        self.port = port
        # also self.addr
        try:
            self.addr = ipaddress.IPv4Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
            self._inurl = b'%s'
        except AddressValueError:
            # Not IPv4, so it must be IPv6; errors from this propagate.
            self.addr = ipaddress.IPv6Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
            # IPv6 literals are written in square brackets in URLs.
            self._inurl = b'[%s]'

    def make_endpoint(self):
        """Return a Twisted server endpoint bound to this address/port."""
        # The third positional parameter of TCP4ServerEndpoint and
        # TCP6ServerEndpoint is `backlog', not the address, so the
        # address to bind to must be passed by keyword as `interface',
        # and as a string rather than an ipaddress object.
        return self._endpointfactory(reactor, self.port,
                                     interface=str(self.addr))

    def url(self):
        """Return the http URL (as bytes) for this address and port.

        The port is left out of the URL when it is 80.
        """
        url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
        if self.port != 80: url += b':%d' % self.port
        url += b'/'
        return url
350
def process_cfg_saddrs():
    """Set c.saddrs to a list of ServerAddr, one per [server] address.

    The addresses are the whitespace-separated words of the [server]
    `addrs' setting; they all share the [server] `port' setting, or
    80 if that option is absent.
    """
    try:
        port = cfg.getint('server', 'port')
    except NoOptionError:
        port = 80

    c.saddrs = []
    addrspecs = cfg.get('server', 'addrs').split()
    c.saddrs.extend(ServerAddr(port, spec) for spec in addrspecs)
359
def process_cfg_clients(constructor):
    """Call `constructor' once for each client section in the config.

    A section is taken to be a client if its name contains `:' or `.';
    the name is then parsed as an IP address.  For each such section,
    constructor(ci, cs, pw) is called with the parsed address, the
    section name, and the section's `password' encoded as UTF-8.

    c.clients is reset to an empty list; nothing here adds to it.
    """
    c.clients = []
    for cs in cfg.sections():
        if ':' not in cs and '.' not in cs:
            continue
        ci = ipaddr(cs)
        pw = cfg.get(cs, 'password').encode('utf-8')
        constructor(ci, cs, pw)
368
369 #---------- startup ----------
370
def common_startup():
    """Set up logging, parse the command line and load configuration.

    Log events go to stderr in classic log text format, and also to
    crash_on_critical.  The only command line option is -c/--config
    (default /etc/hippotat/config); any other argument is an error.
    The built-in defaults from `defcfg' are loaded first, with
    everything from `#' to end of line removed, and then the config
    file is read on top of them.
    """
    stderr_observer = twisted.logger.FileLogObserver(
        sys.stderr, twisted.logger.formatEventAsClassicLogText)
    observers = [stderr_observer, crash_on_critical]
    twisted.logger.globalLogBeginner.beginLoggingTo(observers)

    optparser.add_option('-c', '--config', dest='configfile',
                         default='/etc/hippotat/config')
    opts, args = optparser.parse_args()
    if len(args):
        optparser.error('no non-option arguments please')

    # Only the built-in text has its `#' comments stripped here; the
    # config file is handed to the parser as it is.
    comment = regexp.compile('#.*')
    builtin_defaults = comment.sub('', defcfg)
    cfg.read_string(builtin_defaults)
    cfg.read(opts.configfile)
385 cfg.read(opts.configfile)
386
def common_run():
    """Run the reactor, unless crash() has already been called.

    `CRASHED (end)' is printed to stderr afterwards in every case,
    i.e. both when the reactor has returned and when it was skipped.
    """
    log_debug(DBG.INIT, 'entering reactor')
    if not _crashing:
        reactor.run()
    print('CRASHED (end)', file=sys.stderr)