e13b31f642429f443da389eb7095e350155d0baa
[hippotat] / hippotat / __init__.py
1 # -*- python -*-
2
import signal
# Reset SIGINT to the OS default action (terminate the process) right at
# import time, so Ctrl-C kills the program immediately rather than raising
# KeyboardInterrupt inside Python code.
signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 import twisted.internet.endpoints
11 import twisted.logger
12 from twisted.logger import LogLevel
13 import twisted.python.constants
14 from twisted.python.constants import NamedConstant
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 from functools import partial
24
25 import collections
26 import time
27 import codecs
28 import traceback
29
30 import re as regexp
31
32 import hippotat.slip as slip
33
class DBG(twisted.python.constants.Names):
    # Debug-message categories: the first argument of log_debug().
    # Definition ORDER is significant: Twisted named constants compare by
    # the order in which they were defined, and log_debug() discards any
    # flag that sorts after DBG.HTTP.  Do not reorder these lines.
    INIT = NamedConstant()
    ROUTE = NamedConstant()
    DROP = NamedConstant()
    FLOW = NamedConstant()
    HTTP = NamedConstant()
    # Flags from here on sort after DBG.HTTP, so log_debug() drops them.
    TWISTED = NamedConstant()
    QUEUE = NamedConstant()
    HTTP_CTRL = NamedConstant()
    QUEUE_CTRL = NamedConstant()
    HTTP_FULL = NamedConstant()
    CTRL_DUMP = NamedConstant()
    SLIP_FULL = NamedConstant()
47
# Encoder mapping bytes to their ASCII hex form; returns (output, consumed).
_hex_codec = codecs.getencoder('hex_codec')

log = twisted.logger.Logger()

def log_debug(dflag, msg, idof=None, d=None):
    """Emit debug message `msg` in category `dflag` at info level.

    Categories that sort after DBG.HTTP are silently dropped.  When `idof`
    is given, the message is prefixed with the hex id() of that object.
    When `d` (bytes) is given, a hex dump of it is appended.
    """
    if dflag > DBG.HTTP:
        return
    core = msg if idof is None else '[%#x] %s' % (id(idof), msg)
    if d is not None:
        core += ' ' + _hex_codec(d)[0].decode('ascii')
    log.info('{dflag} {msgcore}', dflag=dflag, msgcore=core)
62
# Built-in default configuration, loaded before the user's config file.
# ConfigParser does not treat an inline '#' as a comment by default, so
# common_startup() strips everything from '#' to end of line before
# parsing.  The '#' notes inside this string are therefore documentation
# only.
# NOTE(review): the "extra interpolations" note lists %(rnet)s, but the
# ipif template below interpolates %(rnets)s -- confirm which is intended.
# NOTE(review): `routes = ''` yields the literal two-character value '',
# which presumably becomes an empty argument once the interpolated ipif
# command is run through a shell -- TODO confirm.
defcfg = '''
[DEFAULT]
#[<client>] overrides
max_batch_down = 65536 # used by server, subject to [limits]
max_queue_time = 10 # used by server, subject to [limits]
target_requests_outstanding = 3 # must match; subject to [limits] on server
http_timeout = 30 # used by both } must be
http_timeout_grace = 5 # used by both } compatible
max_requests_outstanding = 4 # used by client
max_batch_up = 4000 # used by client
http_retry = 5 # used by client

#[server] or [<client>] overrides
ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
# extra interpolations: %(local)s %(peer)s %(rnet)s
# obtained on server [virtual]server [virtual]relay [virtual]network
# from on client <client> [virtual]server [virtual]routes

[virtual]
mtu = 1500
routes = ''
# network = <prefix>/<len> # mandatory for server
# server = <ipaddr> # used by both, default is computed from `network'
# relay = <ipaddr> # used by server, default from `network' and `server'
# default server is first host in network
# default relay is first host which is not server

[server]
# addrs = 127.0.0.1 ::1 # mandatory for server
port = 80 # used by server
# url # used by client; default from first `addrs' and `port'

# [<client-ip4-or-ipv6-address>]
# password = <password> # used by both, must match

[limits]
max_batch_down = 262144 # used by server
max_queue_time = 121 # used by server
http_timeout = 121 # used by server
target_requests_outstanding = 10 # used by server
'''
104
# these need to be defined here so that they can be imported by import *
# Module-wide configuration; common_startup() loads the defaults and the
# config file into it.
cfg = ConfigParser()
# Module-wide option parser; common_startup() adds -c/--config and parses
# the command line with it.
optparser = OptionParser()
108
# Byte translation table exchanging b'-' with the SLIP escape byte.
_mimetrans = bytes.maketrans(b'-' + slip.esc, slip.esc + b'-')

def mime_translate(s):
    """Return `s` with every `-` and SLIP ESC byte swapped.

    SLIP-encoded packets cannot contain ESC ESC, so after the swap the
    result cannot contain `--`.  Presumably (if slip.esc is a single byte)
    applying the function twice restores the input -- TODO confirm.
    """
    table = _mimetrans
    return s.translate(table)
114
class ConfigResults:
    """Attribute bag holding processed configuration values.

    Attributes live directly in the instance's __dict__.  If a dict `d` is
    supplied, it *becomes* that __dict__ (it is not copied), so attribute
    assignments show up in `d` and vice versa.
    """

    def __init__(self, d=None):
        # Create a fresh dict per instance.  A `{ }` default would be one
        # dict shared by every ConfigResults() built without an argument.
        self.__dict__ = {} if d is None else d

    def __repr__(self):
        return 'ConfigResults(' + repr(self.__dict__) + ')'


# Module-wide results object filled in by the process_cfg_* functions.
c = ConfigResults()
122
def log_discard(packet, iface, saddr, daddr, why):
    """Log, under DBG.DROP, that `packet` seen on `iface` was discarded.

    The source/destination addresses and the reason `why` go into the
    message text; the packet itself is appended as a hex dump.
    """
    text = 'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why)
    log_debug(DBG.DROP, text, d=packet)
127
128 #---------- packet parsing ----------
129
# IP version -> (byte offset of source address, address length, class).
# The destination address immediately follows the source address.
_PACKET_ADDR_LAYOUT = {
    4: (3 * 4, 4, ipaddress.IPv4Address),
    6: (2 * 4, 16, ipaddress.IPv6Address),
}

def packet_addrs(packet):
    """Return (source, destination) addresses of a raw IP packet (bytes).

    The IP version is read from the high nibble of the first byte.

    Raises ValueError if the version is neither 4 nor 6.  A packet too
    short to hold both addresses makes the ipaddress constructor raise
    AddressValueError (a ValueError subclass).  An empty packet raises
    IndexError.
    """
    version = packet[0] >> 4
    layout = _PACKET_ADDR_LAYOUT.get(version)
    if layout is None:
        raise ValueError('unsupported IP version %d' % version)
    offset, width, factory = layout
    src = factory(packet[offset : offset + width])
    dst = factory(packet[offset + width : offset + 2 * width])
    return (src, dst)
145
146 #---------- address handling ----------
147
def ipaddr(input):
    """Parse `input` as an IPv4 address, falling back to IPv6.

    Returns an ipaddress.IPv4Address or IPv6Address.  If `input` is
    neither, the AddressValueError from the IPv6 parser propagates.
    """
    try:
        return ipaddress.IPv4Address(input)
    except AddressValueError:
        return ipaddress.IPv6Address(input)
154
def ipnetwork(input):
    """Parse `input` as an IPv4 network, falling back to IPv6.

    Returns an ipaddress.IPv4Network or IPv6Network (strict parsing, so
    host bits must be zero).  Only "not IPv4 syntax" (AddressValueError)
    triggers the IPv6 attempt.  Other IPv4 errors, such as host bits set
    (ValueError) or a bad prefix length (NetmaskValueError), propagate
    with their original message.
    """
    try:
        return ipaddress.IPv4Network(input)
    except AddressValueError:
        # The ipaddress module has no NetworkValueError.  AddressValueError
        # is what IPv4Network raises for non-IPv4 address text.
        return ipaddress.IPv6Network(input)
161
162 #---------- ipif (SLIP) subprocess ----------
163
class SlipStreamDecoder():
    """Split a chunked SLIP byte stream into packets.

    Each chunk passed to inputdata() is decoded with slip.decode().  Every
    complete packet is handed to the `on_packet` callback, and the trailing
    incomplete piece is held until the next chunk or flush().  Empty
    packets are skipped.

    NOTE(review): each chunk is decoded independently and the decoded
    leftover is prepended to the next chunk's first piece.  If a SLIP
    escape sequence can straddle two chunks, this may mis-decode it,
    depending on what slip.decode() does with a trailing ESC.  Confirm
    against hippotat.slip.
    """

    def __init__(self, desc, on_packet):
        self._buffer = b''           # decoded, not-yet-terminated tail
        self._on_packet = on_packet  # called with each non-empty packet
        self._desc = desc            # label used in debug messages
        self._log('__init__')

    def _log(self, msg, **kwargs):
        log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)

    def inputdata(self, data):
        """Feed one chunk of the SLIP stream."""
        self._log('inputdata', d=data)
        pieces = slip.decode(data)
        # Join the leftover from the previous chunk onto the first piece.
        pieces[0] = self._buffer + pieces[0]
        # The last piece has no terminating delimiter yet, so keep it.
        *complete, self._buffer = pieces
        for piece in complete:
            self._maybe_packet(piece)
        self._log('bufremain', d=self._buffer)

    def _maybe_packet(self, packet):
        self._log('maybepacket', d=packet)
        if not packet:
            return
        self._on_packet(packet)

    def flush(self):
        """Deliver whatever is buffered as a final packet, then reset."""
        self._log('flush')
        self._maybe_packet(self._buffer)
        self._buffer = b''
192
class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
    """Twisted process protocol for the ipif (SLIP) subprocess.

    Bytes the child writes to its stdout are SLIP-decoded.  Each resulting
    packet is passed to `router(packet, saddr, daddr)` unless its source or
    destination is link-local, in which case it is logged as discarded.
    When the process ends, the Failure passed to processEnded() is
    re-raised.
    """

    def __init__(self, router):
        self._router = router
        self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)

    def connectionMade(self):
        pass

    def outReceived(self, data):
        self._decoder.inputdata(data)

    def slip_on_packet(self, packet):
        saddr, daddr = packet_addrs(packet)
        if not (saddr.is_link_local or daddr.is_link_local):
            self._router(packet, saddr, daddr)
            return
        log_discard(packet, 'ipif', saddr, daddr, 'link-local')

    def processEnded(self, status):
        status.raiseException()
208
def start_ipif(command, router):
    """Spawn `command` via `/bin/sh -xc` as the ipif subprocess.

    The protocol instance is stored in the module global `ipif`, which
    queue_inbound() writes to.  Per the childFDs mapping, we write to the
    child's stdin, read its stdout, and its stderr is our fd 2.  With
    env=None the child presumably inherits our environment (Twisted's
    POSIX behaviour -- confirm).
    """
    global ipif
    ipif = _IpifProcessProtocol(router)
    argv = ['sh', '-xc', command]
    fds = {0: 'w', 1: 'r', 2: 2}
    reactor.spawnProcess(ipif, '/bin/sh', argv, childFDs=fds, env=None)
216
def queue_inbound(packet):
    """Write `packet`, SLIP-encoded and delimited, to the ipif subprocess.

    A delimiter is written both before and after the encoded packet.
    start_ipif() must have been called first (it sets the global `ipif`).
    """
    log_debug(DBG.FLOW, "queue_inbound", d=packet)
    send = ipif.transport.write
    send(slip.delimiter)
    send(slip.encode(packet))
    send(slip.delimiter)
222
223 #---------- packet queue ----------
224
class PacketQueue():
    """FIFO of packets awaiting transmission, with age-based expiry.

    Each packet is stored together with the time.monotonic() value at
    which it was appended.  nonempty() drops packets from the head that
    have been queued longer than `max_queue_time` seconds.  process()
    moves packets into a caller-supplied batch.
    """

    def __init__(self, desc, max_queue_time):
        self._desc = desc
        # `desc + ''` raises TypeError for a non-str desc; the assert also
        # rejects the empty string.
        assert(desc + '')
        self._max_queue_time = max_queue_time
        self._pq = collections.deque()  # (queued_at, packet), oldest first

    def _log(self, dflag, msg, **kwargs):
        log_debug(dflag, self._desc + ' pq: ' + msg, **kwargs)

    def append(self, packet):
        """Queue `packet`, stamped with the current monotonic time."""
        self._log(DBG.QUEUE, 'append', d=packet)
        self._pq.append((time.monotonic(), packet))

    def nonempty(self):
        """Discard expired packets from the head; return True if any remain."""
        self._log(DBG.QUEUE, 'nonempty ?')
        while self._pq:
            queued_at, packet = self._pq[0]
            if time.monotonic() - queued_at <= self._max_queue_time:
                self._log(DBG.QUEUE, 'nonempty ? nonempty.')
                return True
            # strip old packets off the front
            self._log(DBG.QUEUE, 'dropping (old)', d=packet)
            self._pq.popleft()
        self._log(DBG.QUEUE, 'nonempty ? empty.')
        return False

    def process(self, sizequery, moredata, max_batch):
        """Move packets, oldest first, into a batch.

        sizequery() must return the size of the batch built so far, and
        moredata(s) must append s to it.  Whenever the batch is already
        non-empty, slip.delimiter is added before the next packet.
        Processing stops when the queue is empty, or when the next packet
        would push a non-empty batch past max_batch.  A packet going into
        an empty batch is always accepted, even if it alone exceeds
        max_batch.
        """
        self._log(DBG.QUEUE, 'process...')
        while self._pq:
            _, packet = self._pq[0]
            self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)

            encoded = slip.encode(packet)
            sofar = sizequery()
            self._log(DBG.QUEUE_CTRL,
                      'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
                      d=encoded)

            if sofar > 0:
                needed = sofar + len(slip.delimiter) + len(encoded)
                if needed > max_batch:
                    self._log(DBG.QUEUE_CTRL, 'process... overflow')
                    return
                moredata(slip.delimiter)

            moredata(encoded)
            self._pq.popleft()
        self._log(DBG.QUEUE, 'process... empty')
284
285 #---------- error handling ----------
286
# Set by crash().  common_run() checks it so that it does not start the
# reactor after a crash that happened during startup.
_crashing = False

def crash(err):
    """Report a fatal error `err` on stderr and stop the reactor if running."""
    global _crashing
    _crashing = True
    print('========== CRASH ==========', err,
          '===========================', file=sys.stderr)
    try:
        reactor.stop()
    except twisted.internet.error.ReactorNotRunning:
        # Not running (yet, or any more): nothing left to stop.
        pass
296
def crash_on_defer(defer):
    """Arrange for any failure of Deferred `defer` to crash the program."""
    def _on_failure(err):
        crash(err)
    defer.addErrback(_on_failure)
299
def crash_on_critical(event):
    """Log observer: crash the program on any event at critical level or above.

    Events carrying no 'log_level' key are ignored.  Comparing None with a
    LogLevel constant would otherwise raise TypeError inside the observer.
    """
    level = event.get('log_level')
    if level is not None and level >= LogLevel.critical:
        crash(twisted.logger.formatEvent(event))
303
304 #---------- config processing ----------
305
def process_cfg_common_always():
    """Store [virtual] mtu in c.mtu, as the raw configuration string."""
    mtu_text = cfg.get('virtual', 'mtu')
    c.mtu = mtu_text
309
def process_cfg_ipif(section, varmap):
    """Compute c.ipif_command from the `ipif` option of `section`.

    `varmap` is an iterable of (dest, src) attribute-name pairs.  For each
    pair whose source attribute exists on `c`, c.<dest> is set to c.<src>,
    so the value is available under the name the ipif template uses.
    c.__dict__ is then passed to ConfigParser as `vars`, so every attribute
    of `c` can be interpolated (and takes precedence over config values).
    """
    missing = object()
    for dest, src in varmap:
        value = getattr(c, src, missing)
        if value is not missing:
            setattr(c, dest, value)
    c.ipif_command = cfg.get(section, 'ipif', vars=c.__dict__)
319
def process_cfg_network():
    """Parse [virtual] network into c.network and check that it is big enough.

    At least 3 + 2 addresses are required (presumably three hosts plus the
    network and broadcast addresses).  Network sizes are powers of two, so
    "at least 5" is the same as "at least 2^3 = 8", as the message says.
    """
    net = ipnetwork(cfg.get('virtual', 'network'))
    c.network = net
    if net.num_addresses < 3 + 2:
        raise ValueError('network needs at least 2^3 addresses')
324
def process_cfg_server():
    """Set c.server from [virtual] server.

    If that option is absent, [virtual] network is processed (setting
    c.network) and the network's first host address is used instead.

    NOTE(review): a configured server is stored as the raw string, while
    the default is an ipaddress object.  Confirm that callers only rely on
    str(c.server).
    """
    try:
        server = cfg.get('virtual', 'server')
    except NoOptionError:
        process_cfg_network()
        server = next(c.network.hosts())
    c.server = server
331
class ServerAddr():
    """One listening address of the server.

    Attributes:
      port -- the TCP port to listen on
      addr -- ipaddress.IPv4Address or IPv6Address parsed from `addrspec`
    """

    def __init__(self, port, addrspec):
        self.port = port
        try:
            self.addr = ipaddress.IPv4Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
            self._inurl = b'%s'
        except AddressValueError:
            self.addr = ipaddress.IPv6Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
            # IPv6 literals must be bracketed inside URLs.
            self._inurl = b'[%s]'

    def make_endpoint(self):
        """Return a Twisted TCP server endpoint bound to addr and port."""
        # The third positional parameter of TCP4ServerEndpoint and
        # TCP6ServerEndpoint is `backlog`.  The bind address has to be
        # passed as the `interface` keyword, as a string.
        return self._endpointfactory(reactor, self.port,
                                     interface=str(self.addr))

    def url(self):
        """Return this address's http:// URL as bytes.

        ':port' is included only when the port is not 80.
        """
        url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
        if self.port != 80:
            url += b':%d' % self.port
        url += b'/'
        return url
351
def process_cfg_saddrs():
    """Build c.saddrs: one ServerAddr per entry of [server] addrs.

    All entries share [server] port, or 80 if that option is missing.
    """
    try:
        port = cfg.getint('server', 'port')
    except NoOptionError:
        port = 80

    c.saddrs = saddrs = []
    saddrs.extend(ServerAddr(port, spec)
                  for spec in cfg.get('server', 'addrs').split())
360
def process_cfg_clients(constructor):
    """Call constructor(addr, section, password) for each client section.

    A section is treated as a client if its name contains ':' or '.'
    (i.e. looks like an IPv6 or IPv4 address).  `addr` is the parsed
    address and `password` is the section's `password` option encoded as
    UTF-8.  c.clients is reset to an empty list first; presumably the
    constructor populates it -- confirm against callers.
    """
    c.clients = []
    client_sections = (name for name in cfg.sections()
                       if ':' in name or '.' in name)
    for section in client_sections:
        addr = ipaddr(section)
        secret = cfg.get(section, 'password').encode('utf-8')
        constructor(addr, section, secret)
369
370 #---------- startup ----------
371
def common_startup():
    """Set up logging, parse the command line and load the configuration.

    Log events at error level and above go to stderr, all others to
    stdout, and crash_on_critical is registered as an observer.  The
    built-in defaults (defcfg, with '#' comments stripped) are loaded
    first, then the file named by -c/--config (default
    /etc/hippotat/config).  Exits via optparser.error() if non-option
    arguments are given or if the config file cannot be opened.
    """
    log_formatter = twisted.logger.formatEventAsClassicLogText
    stdout_obs = twisted.logger.FileLogObserver(sys.stdout, log_formatter)
    stderr_obs = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
    pred = twisted.logger.LogLevelFilterPredicate(LogLevel.error)
    log_observer = twisted.logger.FilteringLogObserver(
        stderr_obs, [pred], stdout_obs
    )
    twisted.logger.globalLogBeginner.beginLoggingTo(
        [ log_observer, crash_on_critical ]
    )

    optparser.add_option('-c', '--config', dest='configfile',
                         default='/etc/hippotat/config')
    (opts, args) = optparser.parse_args()
    if len(args): optparser.error('no non-option arguments please')

    # ConfigParser does not treat an inline '#' as a comment by default,
    # so strip everything from '#' to end of line out of the defaults.
    comment_re = regexp.compile('#.*')
    cfg.read_string(comment_re.sub('', defcfg))

    # cfg.read() silently skips files it cannot open, which would leave us
    # running on the defaults alone.  Insist that the config file is
    # readable.
    try:
        with open(opts.configfile) as f:
            cfg.read_file(f)
    except OSError as e:
        optparser.error('cannot read config file: %s' % e)
392
def common_run():
    """Run the Twisted reactor unless crash() has already been called.

    'CRASHED (end)' is printed to stderr whenever this returns: either
    after reactor.run() comes back, or immediately if a crash happened
    during startup.
    NOTE(review): the message is also printed if the reactor stops for a
    non-crash reason (presumably e.g. a SIGTERM handled by Twisted).
    Confirm that this is intended.
    """
    log_debug(DBG.INIT, 'entering reactor')
    already_crashed = _crashing
    if not already_crashed:
        reactor.run()
    print('CRASHED (end)', file=sys.stderr)