82f0634bf234d583456c0d5e48860fb2c6bc469c
[hippotat] / hippotat / __init__.py
1 # -*- python -*-
2
import signal
# Put SIGINT back to the OS default action (terminate the process) instead
# of Python's default handler, which raises KeyboardInterrupt.  Done at
# package import time, ahead of the other imports below.
signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 import twisted.internet.endpoints
11 import twisted.logger
12 from twisted.logger import LogLevel
13 import twisted.python.constants
14 from twisted.python.constants import NamedConstant
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 import collections
24 import time
25 import codecs
26 import traceback
27
28 import re as regexp
29
30 import hippotat.slip as slip
31
class DBG(twisted.python.constants.Names):
    """Debug categories, passed as the `dflag' argument of log_debug()."""
    # Declaration order is significant for twisted Names constants (it sets
    # their iteration/comparison order), so append new flags at the end.
    ROUTE = NamedConstant()
    DROP = NamedConstant()        # used by log_discard()
    FLOW = NamedConstant()
    HTTP = NamedConstant()
    HTTP_CTRL = NamedConstant()
    INIT = NamedConstant()
    QUEUE = NamedConstant()       # PacketQueue high-level events
    QUEUE_CTRL = NamedConstant()  # PacketQueue per-packet batching detail
    HTTP_FULL = NamedConstant()
42
# Encoder function returning (hex_bytes, length_consumed) for a bytes-like
# input; log_debug() uses it to hex-dump packet data.
_hex_codec = codecs.getencoder('hex_codec')

# Module-level twisted logger that log_debug() writes through.
log = twisted.logger.Logger()
46
def log_debug(dflag, msg, idof=None, d=None):
    """Log *msg* tagged with the debug category *dflag*.

    If *idof* is given, the message is prefixed with ``[<id(idof)>]`` so
    that messages about the same object can be correlated.  If *d* (a
    bytes-like object) is given, its hex dump is appended.  The message is
    emitted at info level through the module logger.
    """
    if idof is None:
        text = msg
    else:
        text = '[%d] %s' % (id(idof), msg)
    if d is not None:
        text += ' ' + _hex_codec(d)[0].decode('ascii')
    log.info('{dflag} {msgcore}', dflag=dflag, msgcore=text)
56
# Built-in default configuration, loaded before the user's config file.
# common_startup() deletes everything from `#' to end of line before
# parsing it, so the remarks inside (including those trailing the values)
# are documentation only.
defcfg = '''
[DEFAULT]
#[<client>] overrides
max_batch_down = 65536 # used by server, subject to [limits]
max_queue_time = 10 # used by server, subject to [limits]
max_request_time = 54 # used by server, subject to [limits]
target_requests_outstanding = 3 # must match; subject to [limits] on server
max_requests_outstanding = 4 # used by client
max_batch_up = 4000 # used by client
http_timeout = 30 # used by client
http_retry = 5 # used by client

#[server] or [<client>] overrides
ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
# extra interpolations: %(local)s %(peer)s %(rnets)s
# obtained on server from [virtual]server [virtual]relay [virtual]network
# obtained on client from <client> [virtual]server [virtual]routes

[virtual]
mtu = 1500
routes = ''
# network = <prefix>/<len> # mandatory for server
# server = <ipaddr> # used by both, default is computed from `network'
# relay = <ipaddr> # used by server, default from `network' and `server'
# default server is first host in network
# default relay is first host which is not server

[server]
# addrs = 127.0.0.1 ::1 # mandatory for server
port = 80 # used by server
# url # used by client; default from first `addrs' and `port'

# [<client-ip4-or-ipv6-address>]
# password = <password> # used by both, must match

[limits]
max_batch_down = 262144 # used by server
max_queue_time = 121 # used by server
max_request_time = 121 # used by server
target_requests_outstanding = 10 # used by server
'''
98
# These shared singletons must exist at module level (not be created inside
# a startup function) so that `from hippotat import *' hands importers the
# very objects that common_startup() later populates.
cfg = ConfigParser()
optparser = OptionParser()
102
# Byte translation table exchanging `-' and the SLIP ESC byte.
_mimetrans = bytes.maketrans(
    b''.join((b'-', slip.esc)),
    b''.join((slip.esc, b'-')),
)

def mime_translate(s):
    """Return *s* with every `-' byte and every SLIP ESC byte swapped.

    SLIP-encoded data never contains ESC ESC, so after the swap the result
    can never contain `--' (presumably so it is safe inside a MIME body,
    where `--' introduces boundaries -- confirm against the callers).
    """
    swapped = s.translate(_mimetrans)
    return swapped
108
class ConfigResults:
    """Attribute bag holding processed configuration values.

    If a dict *d* is supplied it becomes the instance ``__dict__`` itself
    (it is not copied), so attributes set on the instance appear in *d*.
    Without *d*, each instance gets its own fresh, empty dict.
    """

    def __init__(self, d=None):
        # A `{}` default would be one dict shared by every instance
        # created without an argument; build a fresh one per instance.
        self.__dict__ = {} if d is None else d

    def __repr__(self):
        return 'ConfigResults(' + repr(self.__dict__) + ')'
114
# Module-level configuration results; the process_cfg_*() functions store
# their parsed values on it as attributes.
c = ConfigResults()
116
def log_discard(packet, iface, saddr, daddr, why):
    """Log, under DBG.DROP, that *packet* seen on *iface* was dropped.

    The message names the source and destination addresses and the reason
    *why*; log_debug() appends a hex dump of the packet.
    """
    description = 'discarded packet [%s] %s -> %s: %s' % (
        iface, saddr, daddr, why)
    log_debug(DBG.DROP, description, d=packet)
121
122 #---------- packet parsing ----------
123
def packet_addrs(packet):
    """Return ``(source, destination)`` addresses of a raw IP packet.

    *packet* is a bytes-like object (bytes, bytearray or memoryview) that
    starts with an IPv4 or IPv6 header; the version is the high nibble of
    the first byte.  Returns a pair of ipaddress.IPv4Address or
    ipaddress.IPv6Address objects.

    Raises ValueError for an empty packet or an unsupported IP version.
    A packet too short to hold both addresses makes the ipaddress
    constructor raise AddressValueError (a ValueError subclass).
    """
    if len(packet) == 0:
        # Previously packet[0] raised IndexError here.
        raise ValueError('empty packet')
    version = packet[0] >> 4
    if version == 4:
        addrlen = 4
        saddroff = 3*4      # IPv4 source address is at byte offset 12
        factory = ipaddress.IPv4Address
    elif version == 6:
        addrlen = 16
        saddroff = 2*4      # IPv6 source address is at byte offset 8
        factory = ipaddress.IPv6Address
    else:
        raise ValueError('unsupported IP version %d' % version)
    # In both header formats the destination follows the source directly.
    # bytes() so that bytearray/memoryview slices are accepted: the
    # ipaddress constructors only treat real bytes as packed addresses.
    saddr = factory(bytes(packet[saddroff : saddroff + addrlen]))
    daddr = factory(bytes(packet[saddroff + addrlen : saddroff + addrlen*2]))
    return (saddr, daddr)
139
140 #---------- address handling ----------
141
def ipaddr(input):
    """Parse *input* as an IPv4 address, falling back to IPv6.

    Returns an ipaddress.IPv4Address or ipaddress.IPv6Address; raises
    AddressValueError if *input* is neither.
    """
    try:
        return ipaddress.IPv4Address(input)
    except AddressValueError:
        return ipaddress.IPv6Address(input)
148
def ipnetwork(input):
    """Parse *input* as an IPv4 network, falling back to IPv6.

    Mirrors ipaddr().  The IPv6 parse is tried only when the address part
    is not valid IPv4 (AddressValueError).  Other IPv4 errors propagate
    unchanged: a bad prefix length (NetmaskValueError) and host bits set
    (ValueError).
    """
    try:
        r = ipaddress.IPv4Network(input)
    except AddressValueError:
        # The ipaddress module has no `NetworkValueError'; naming it here
        # made every non-IPv4 input fail with NameError instead of trying
        # IPv6.
        r = ipaddress.IPv6Network(input)
    return r
155
156 #---------- ipif (SLIP) subprocess ----------
157
class SlipStreamDecoder:
    """Incrementally split a SLIP byte stream into packets.

    Feed raw bytes with inputdata().  Each non-empty decoded piece except
    the last is passed to the *on_packet* callback given to the
    constructor.  The last piece returned by slip.decode() is kept as the
    buffer; presumably it is the incomplete tail of the stream.  flush()
    delivers whatever is buffered.
    """

    def __init__(self, on_packet):
        self._buffer = b''            # not-yet-delivered tail of the stream
        self._on_packet = on_packet   # called as on_packet(<packet>)

    def inputdata(self, data):
        """Append *data* to the buffer and deliver every completed packet."""
        self._buffer += data
        pieces = slip.decode(self._buffer)
        self._buffer = pieces.pop()
        # NOTE: the buffer has already been replaced, so if on_packet raises,
        # the remaining pieces from this call are not delivered.
        for piece in pieces:
            self._maybe_packet(piece)

    def _maybe_packet(self, packet):
        # Empty pieces are skipped rather than passed to the callback.
        if len(packet) > 0:
            self._on_packet(packet)

    def flush(self):
        """Deliver any buffered tail as a final packet, then clear it."""
        tail = self._buffer
        self._maybe_packet(tail)
        self._buffer = b''
179
class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
    """Process protocol for the ipif subprocess.

    Bytes the subprocess writes to stdout are SLIP-decoded.  For each
    packet, the addresses are extracted.  If neither end is link-local,
    the packet goes to router(packet, saddr, daddr); otherwise it is
    logged as discarded.  Any termination of the process is re-raised.
    """

    def __init__(self, router):
        self._router = router
        self._decoder = SlipStreamDecoder(self.slip_on_packet)

    def connectionMade(self):
        pass

    def outReceived(self, data):
        self._decoder.inputdata(data)

    def slip_on_packet(self, packet):
        # NOTE(review): packet_addrs() raises ValueError for non-IPv4/IPv6
        # data.  It is not caught here, so it propagates out of
        # outReceived(); confirm this is intended.
        saddr, daddr = packet_addrs(packet)
        if not (saddr.is_link_local or daddr.is_link_local):
            self._router(packet, saddr, daddr)
            return
        log_discard(packet, 'ipif', saddr, daddr, 'link-local')

    def processEnded(self, status):
        # status is presumably a twisted Failure; it is re-raised whatever
        # the exit reason, so ipif exiting is never silent.
        status.raiseException()
195
def start_ipif(command, router):
    """Spawn the ipif subprocess, running *command* via ``sh -xc``.

    The _IpifProcessProtocol wrapping *router* is stored in the module
    global ``ipif``, which queue_inbound() writes through.  The child's
    fd 0 is a pipe we write to, fd 1 a pipe we read from, and fd 2 is
    inherited from this process.
    """
    global ipif
    protocol = _IpifProcessProtocol(router)
    ipif = protocol
    reactor.spawnProcess(
        protocol,
        '/bin/sh',
        ['sh', '-xc', command],
        childFDs={0: 'w', 1: 'r', 2: 2},
        env=None,   # presumably: inherit our environment -- confirm in twisted docs
    )
203
def queue_inbound(packet):
    """Send *packet* to the ipif subprocess as a SLIP frame.

    The encoded packet is written between two delimiters, presumably so
    that any stray partial frame already in the stream cannot merge with
    it.
    """
    log_debug(DBG.FLOW, "queue_inbound", d=packet)
    out = ipif.transport
    out.write(slip.delimiter)
    out.write(slip.encode(packet))
    out.write(slip.delimiter)
209
210 #---------- packet queue ----------
211
class PacketQueue:
    """FIFO of packets waiting to be sent in batches.

    Each entry is stored with the time.monotonic() value at which it was
    queued.  nonempty() discards entries older than *max_queue_time*
    seconds from the front.  process() moves packets into a caller-built
    batch.  *desc* (a non-empty str) prefixes this queue's log messages.
    """

    def __init__(self, desc, max_queue_time):
        self._desc = desc
        # Cheap check that desc is a non-empty str: `+ ''` fails for
        # non-strings and an empty result is falsy.  Skipped under -O.
        assert desc + ''
        self._max_queue_time = max_queue_time
        self._pq = collections.deque()   # (queue_time, packet) pairs

    def _log(self, dflag, msg, **kwargs):
        log_debug(dflag, self._desc + ' pq: ' + msg, **kwargs)

    def append(self, packet):
        """Queue *packet*, stamped with the current monotonic time."""
        self._log(DBG.QUEUE, 'append', d=packet)
        self._pq.append((time.monotonic(), packet))

    def nonempty(self):
        """Drop expired packets from the front; return whether any remain."""
        self._log(DBG.QUEUE, 'nonempty ?')
        while self._pq:
            queued_at, oldest = self._pq[0]
            if time.monotonic() - queued_at > self._max_queue_time:
                # Entries are in enqueue order, so the oldest are in front.
                self._log(DBG.QUEUE, 'dropping (old)', d=oldest)
                self._pq.popleft()
                continue
            self._log(DBG.QUEUE, 'nonempty ? nonempty.')
            return True
        self._log(DBG.QUEUE, 'nonempty ? empty.')
        return False

    def process(self, sizequery, moredata, max_batch):
        """Move queued packets, SLIP-encoded, into a batch.

        sizequery() must return the size of the batch built so far, and
        moredata(s) must append bytes *s* to it.  Packets are separated by
        slip.delimiter.  The first packet of an empty batch is always
        taken, even if larger than *max_batch*.  Stops when the queue is
        empty or the next packet would push the batch past *max_batch*;
        that packet stays queued.
        """
        self._log(DBG.QUEUE, 'process...')
        while self._pq:
            _, pkt = self._pq[0]
            self._log(DBG.QUEUE_CTRL, 'process... packet', d=pkt)

            encoded = slip.encode(pkt)
            sofar = sizequery()
            self._log(DBG.QUEUE_CTRL,
                      'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
                      d=encoded)

            if sofar > 0:
                needed = sofar + len(slip.delimiter) + len(encoded)
                if needed > max_batch:
                    self._log(DBG.QUEUE_CTRL, 'process... overflow')
                    return
                moredata(slip.delimiter)

            moredata(encoded)
            self._pq.popleft()
        self._log(DBG.QUEUE, 'process... empty')
271
272 #---------- error handling ----------
273
# Set once crash() has run; common_run() then does not start the reactor.
_crashing = False

def crash(err):
    """Report a fatal error on stderr and stop the reactor.

    *err* is printed as-is.  Callers in this module pass a Deferred
    failure or a formatted log event.  The crash is also recorded in
    ``_crashing``, so a crash that happens before the reactor starts keeps
    common_run() from starting it.
    """
    global _crashing
    _crashing = True
    print('CRASH ', err, file=sys.stderr)
    try:
        reactor.stop()
    except twisted.internet.error.ReactorNotRunning:
        # Reactor not running (yet or any more): nothing to stop.
        pass
282
def crash_on_defer(defer):
    """Make any failure of the Deferred *defer* fatal, via crash()."""
    def _crash_errback(failure):
        return crash(failure)
    defer.addErrback(_crash_errback)
285
def crash_on_critical(event):
    """Log observer: crash the program on any event at critical level or above.

    Events with no ``log_level`` are ignored.  Comparing ``None`` with
    ``LogLevel.critical`` would raise TypeError inside the logging system.
    """
    level = event.get('log_level')
    if level is None:
        return
    if level >= LogLevel.critical:
        crash(twisted.logger.formatEvent(event))
289
290 #---------- config processing ----------
291
def process_cfg_common_always():
    """Store the [virtual] mtu setting on ``c.mtu``.

    The raw config string is kept.  It is consumed as ``%(mtu)s`` in the
    ipif command template, which process_cfg_ipif() interpolates from the
    attributes of ``c``.
    """
    # (The former `global mtu' declaration was dead: no module-level `mtu'
    # is ever assigned; the value lives on `c'.)
    c.mtu = cfg.get('virtual', 'mtu')
295
def process_cfg_ipif(section, varmap):
    """Set ``c.ipif_command`` from the ``ipif`` option of *section*.

    *varmap* is an iterable of ``(dest, src)`` attribute-name pairs.  For
    each pair where ``c`` already has attribute *src*, that value is
    copied to attribute *dest* (presumably to supply the template's
    %(local)s / %(peer)s / %(rnets)s names).  The ipif template is then
    interpolated with every attribute of ``c`` available as a variable.
    """
    missing = object()
    for dest, src in varmap:
        value = getattr(c, src, missing)
        if value is not missing:
            setattr(c, dest, value)

    c.ipif_command = cfg.get(section, 'ipif', vars=c.__dict__)
305
def process_cfg_network():
    """Parse [virtual] network into ``c.network`` and check its size.

    Raises ValueError if the network has fewer than 5 addresses.  Network
    sizes are powers of two, so this is equivalent to requiring at least
    2^3 = 8 addresses, as the error message says.
    """
    net = ipnetwork(cfg.get('virtual', 'network'))
    c.network = net
    if net.num_addresses < 3 + 2:
        raise ValueError('network needs at least 2^3 addresses')
310
def process_cfg_server():
    """Set ``c.server`` to [virtual] server, or to the first host of
    [virtual] network when no server is configured.  In the fallback
    case, ``c.network`` is set as well.

    NOTE(review): a configured server is stored as a str, but the computed
    default is an ipaddress object.  Confirm callers accept both.
    """
    try:
        configured = cfg.get('virtual', 'server')
    except NoOptionError:
        process_cfg_network()
        c.server = next(c.network.hosts())
    else:
        c.server = configured
317
class ServerAddr():
    """A server listening address (IPv4 or IPv6) together with its TCP port."""

    def __init__(self, port, addrspec):
        self.port = port
        # Also sets self.addr: an ipaddress.IPv4Address or IPv6Address.
        try:
            self.addr = ipaddress.IPv4Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
            self._inurl = b'%s'
        except AddressValueError:
            self.addr = ipaddress.IPv6Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
            # IPv6 literals must be bracketed inside URLs.
            self._inurl = b'[%s]'

    def make_endpoint(self):
        """Return a twisted server endpoint for this address and port."""
        # TCP4/6ServerEndpoint take (reactor, port, backlog, interface).
        # Passing the address positionally made it the listen *backlog*.
        # Give it as the interface, as a string.
        return self._endpointfactory(reactor, self.port,
                                     interface=str(self.addr))

    def url(self):
        """Return the base URL for this address as bytes, e.g. b'http://[::1]:8080/'.

        The port is omitted when it is 80.
        """
        url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
        if self.port != 80:
            url += b':%d' % self.port
        url += b'/'
        return url
337
def process_cfg_saddrs():
    """Set ``c.saddrs`` to a ServerAddr for each entry in [server] addrs.

    All addresses share [server] port, which is 80 if unset.
    ``c.saddrs`` is reset to an empty list before [server] addrs is read.
    """
    try:
        port = cfg.getint('server', 'port')
    except NoOptionError:
        port = 80

    c.saddrs = []
    c.saddrs.extend(ServerAddr(port, spec)
                    for spec in cfg.get('server', 'addrs').split())
346
def process_cfg_clients(constructor):
    """Call ``constructor(addr, section, password)`` for each client section.

    Client sections are the config sections whose names contain ':' or
    '.', i.e. look like an IPv6 or IPv4 address.  The name is parsed with
    ipaddr(), and the section's ``password`` is passed UTF-8 encoded.
    ``c.clients`` is reset to an empty list first.
    """
    c.clients = []
    for section in cfg.sections():
        if ':' not in section and '.' not in section:
            continue
        client_addr = ipaddr(section)
        password = cfg.get(section, 'password').encode('utf-8')
        constructor(client_addr, section, password)
355
356 #---------- startup ----------
357
def common_startup():
    """Start logging, parse the command line, and load the configuration.

    Log events go to stderr as classic text, and are also passed to
    crash_on_critical.  Adds the -c/--config option, rejects positional
    arguments, then reads the built-in defaults (defcfg with `#' comments
    stripped) followed by the config file.
    """
    stderr_observer = twisted.logger.FileLogObserver(
        sys.stderr, twisted.logger.formatEventAsClassicLogText)
    twisted.logger.globalLogBeginner.beginLoggingTo(
        [stderr_observer, crash_on_critical])

    optparser.add_option('-c', '--config', dest='configfile',
                         default='/etc/hippotat/config')
    opts, args = optparser.parse_args()
    if args:
        optparser.error('no non-option arguments please')

    comment_re = regexp.compile('#.*')
    cfg.read_string(comment_re.sub('', defcfg))
    # NOTE(review): ConfigParser.read() silently skips files it cannot open,
    # so a missing or mistyped config file is not reported here.
    cfg.read(opts.configfile)
373
def common_run():
    """Run the twisted reactor unless a crash was already recorded.

    Prints 'CRASHED (end)' to stderr afterwards, whether or not the
    reactor was actually run.
    """
    log_debug(DBG.INIT, 'entering reactor')
    already_crashed = _crashing
    if not already_crashed:
        reactor.run()
    print('CRASHED (end)', file=sys.stderr)