064909100fbe409fe30ee67b772c902140060443
[hippotat] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
# Reset SIGINT to its default disposition (SIG_DFL) before anything else
# is imported, replacing Python's KeyboardInterrupt-raising handler.
signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 import twisted.internet.endpoints
11 import twisted.logger
12 from twisted.logger import LogLevel
13 import twisted.python.constants
14 from twisted.python.constants import NamedConstant
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 import collections
24 import time
25 import codecs
26
27 import re as regexp
28
29 import hippotat.slip as slip
30
# Debug message categories.  One of these is passed as the `dflag'
# argument of log_debug() and appears at the start of the log line.
# In this file DROP is used for discarded packets, QUEUE and QUEUE_CTRL
# by PacketQueue, and INIT at startup.
class DBG(twisted.python.constants.Names):
    ROUTE = NamedConstant()
    DROP = NamedConstant()
    FLOW = NamedConstant()
    HTTP = NamedConstant()
    HTTP_CTRL = NamedConstant()
    INIT = NamedConstant()
    QUEUE = NamedConstant()
    QUEUE_CTRL = NamedConstant()
40
# Encoder used by log_debug() to render binary data as hexadecimal.
# Calling it returns a tuple whose first element is the encoded bytes.
_hex_codec = codecs.getencoder('hex_codec')

# Logger through which log_debug() emits its messages.
log = twisted.logger.Logger()
44
def log_debug(dflag, msg, idof=None, d=None):
    """Emit one debug message, tagged with category `dflag', at info level.

    dflag -- debug category (a DBG constant) shown before the message.
    msg   -- message text.
    idof  -- optional object; when given, the message is prefixed with
             `[<id(idof)>]' so that lines about one object can be matched.
    d     -- optional binary data; its first 64 bytes are appended to the
             message as hexadecimal.
    """
    pieces = [msg if idof is None else '[%d] %s' % (id(idof), msg)]
    if d is not None:
        # the encoder returns (encoded_bytes, length_consumed)
        hexed, _consumed = _hex_codec(d[0:64])
        pieces.append(hexed.decode('ascii'))
    log.info('{dflag} {msgcore}', dflag=dflag, msgcore=' '.join(pieces))
53
# Built-in default configuration, in ConfigParser syntax.
#
# common_startup() deletes everything from `#' to end of line before
# parsing this text, which is why trailing `# ...' remarks are allowed
# after the values below.
#
# NOTE(review): the `extra interpolations' remark mentions %(rnet)s but
# the ipif setting interpolates %(rnets)s -- confirm which is intended.
defcfg = '''
[DEFAULT]
#[<client>] overrides
max_batch_down = 65536 # used by server, subject to [limits]
max_queue_time = 10 # used by server, subject to [limits]
max_request_time = 54 # used by server, subject to [limits]
target_requests_outstanding = 3 # must match; subject to [limits] on server
max_requests_outstanding = 4 # used by client
max_batch_up = 4000 # used by client
http_timeout = 30 # used by client
http_retry = 5 # used by client

#[server] or [<client>] overrides
ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
# extra interpolations: %(local)s %(peer)s %(rnet)s
# obtained on server [virtual]server [virtual]relay [virtual]network
# from on client <client> [virtual]server [virtual]routes

[virtual]
mtu = 1500
routes = ''
# network = <prefix>/<len> # mandatory for server
# server = <ipaddr> # used by both, default is computed from `network'
# relay = <ipaddr> # used by server, default from `network' and `server'
# default server is first host in network
# default relay is first host which is not server

[server]
# addrs = 127.0.0.1 ::1 # mandatory for server
port = 80 # used by server
# url # used by client; default from first `addrs' and `port'

# [<client-ip4-or-ipv6-address>]
# password = <password> # used by both, must match

[limits]
max_batch_down = 262144 # used by server
max_queue_time = 121 # used by server
max_request_time = 121 # used by server
target_requests_outstanding = 10 # used by server
'''
95
# These are created at module level (rather than inside common_startup(),
# which populates them) so that `from hippotat import *' picks them up.
cfg = ConfigParser()
optparser = OptionParser()
99
# Byte translation table which exchanges `-' with the SLIP ESC byte
# and leaves every other byte value alone.
_mimetrans = bytes.maketrans(slip.esc + b'-', b'-' + slip.esc)

def mime_translate(s):
    """Return s with every `-' and every SLIP ESC byte exchanged.

    SLIP-encoded data cannot contain ESC ESC, so after the exchange
    the result cannot contain `--'.
    """
    swapped = s.translate(_mimetrans)
    return swapped
105
class ConfigResults:
    """Plain attribute container for processed configuration values.

    d -- optional dict to use *as* the instance's attribute dictionary
         (it is aliased, not copied).  When omitted, each instance gets
         its own fresh empty dict.
    """
    def __init__(self, d=None):
        # A `{ }' default would be evaluated once and then shared by
        # every instance constructed without an argument.
        self.__dict__ = {} if d is None else d

    def __repr__(self):
        return 'ConfigResults(' + repr(self.__dict__) + ')'

# Module-wide configuration results, filled in by the process_cfg_* functions.
c = ConfigResults()
113
def log_discard(packet, iface, saddr, daddr, why):
    """Log, under DBG.DROP, that `packet' was discarded.

    iface, saddr, daddr and why are formatted into the message;
    the packet bytes are passed to log_debug() as its data argument.
    """
    summary = 'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why)
    log_debug(DBG.DROP, summary, d=packet)
118
119 #---------- packet parsing ----------
120
def packet_addrs(packet):
    """Return (source, destination) addresses from a raw IP packet.

    The IP version is taken from the top four bits of the first byte.
    Returns a pair of ipaddress.IPv4Address or ipaddress.IPv6Address.
    Raises ValueError if the version is neither 4 nor 6.
    """
    # version -> (address class, address length, offset of source address)
    layouts = {
        4: (ipaddress.IPv4Address, 4, 3*4),
        6: (ipaddress.IPv6Address, 16, 2*4),
    }
    ip_version = packet[0] >> 4
    layout = layouts.get(ip_version)
    if layout is None:
        raise ValueError('unsupported IP version %d' % ip_version)
    make_addr, width, src_start = layout

    # the destination address immediately follows the source address
    dst_start = src_start + width
    source = make_addr(packet[src_start:dst_start])
    destination = make_addr(packet[dst_start:dst_start + width])
    return (source, destination)
136
137 #---------- address handling ----------
138
def ipaddr(input):
    """Parse `input' as an IP address, trying IPv4 first and then IPv6.

    Returns an ipaddress.IPv4Address or ipaddress.IPv6Address.
    If neither parse succeeds, the IPv6 parser's exception propagates.
    """
    try:
        return ipaddress.IPv4Address(input)
    except AddressValueError:
        return ipaddress.IPv6Address(input)
145
def ipnetwork(input):
    """Parse `input' as an IP network, trying IPv4 first and then IPv6.

    Returns an ipaddress.IPv4Network or ipaddress.IPv6Network.
    If neither parse succeeds, the IPv6 parser's exception propagates.
    """
    try:
        r = ipaddress.IPv4Network(input)
    except AddressValueError:
        # The address part is not IPv4.  (The name previously caught here,
        # NetworkValueError, is not defined anywhere, so this fallback
        # used to fail with NameError.)
        r = ipaddress.IPv6Network(input)
    return r
152
153 #---------- ipif (SLIP) subprocess ----------
154
class SlipStreamDecoder():
    """Incremental decoder turning a SLIP byte stream into packets.

    on_packet(<packet>) is called once for each non-empty packet.
    """
    def __init__(self, on_packet):
        self._on_packet = on_packet
        self._buffer = b''      # data received but not yet delivered

    def inputdata(self, data):
        """Accept more stream data and deliver any packets it completes."""
        self._buffer += data
        pieces = slip.decode(self._buffer)
        # the last piece is kept back as the new buffer; the pieces
        # before it are handed on as packets
        self._buffer = pieces.pop()
        for piece in pieces:
            self._maybe_packet(piece)

    def _maybe_packet(self, packet):
        """Deliver `packet' to the callback unless it is empty."""
        if packet:
            self._on_packet(packet)

    def flush(self):
        """Deliver whatever is buffered as a final packet; empty the buffer."""
        self._maybe_packet(self._buffer)
        self._buffer = b''
176
class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
    """Process protocol for the ipif subprocess.

    The subprocess's stdout is SLIP-decoded; each packet is passed to
    router(packet, saddr, daddr), except packets with a link-local
    source or destination address, which are logged and dropped.
    """
    def __init__(self, router):
        self._router = router
        self._decoder = SlipStreamDecoder(self.slip_on_packet)

    def connectionMade(self):
        pass

    def outReceived(self, data):
        self._decoder.inputdata(data)

    def slip_on_packet(self, packet):
        """Handle one decoded packet: route it, or discard it."""
        saddr, daddr = packet_addrs(packet)
        if not (saddr.is_link_local or daddr.is_link_local):
            self._router(packet, saddr, daddr)
            return
        log_discard(packet, 'ipif', saddr, daddr, 'link-local')

    def processEnded(self, status):
        # re-raise whatever exception `status' carries
        status.raiseException()
192
def start_ipif(command, router):
    """Spawn the ipif subprocess and remember its protocol in global `ipif'.

    command -- shell command line, run as `sh -xc <command>'.
    router  -- callable given to _IpifProcessProtocol; it receives the
               packets read from the subprocess.
    """
    global ipif
    ipif = _IpifProcessProtocol(router)
    shell_argv = ['sh', '-xc', command]
    fd_map = {0: 'w', 1: 'r', 2: 2}
    reactor.spawnProcess(ipif, '/bin/sh', shell_argv,
                         childFDs=fd_map,
                         env=None)
200
def queue_inbound(packet):
    """Write `packet' to the ipif subprocess, SLIP-encoded.

    A SLIP delimiter is written both before and after the packet.
    """
    ipif.transport.write(slip.delimiter)
    encoded = slip.encode(packet)
    ipif.transport.write(encoded)
    ipif.transport.write(slip.delimiter)
205
206 #---------- packet queue ----------
207
class PacketQueue():
    """FIFO of packets which discards entries older than max_queue_time.

    desc           -- description used as a prefix in debug messages.
    max_queue_time -- maximum age, in seconds on the time.monotonic()
                      clock, before a queued packet is dropped.
    """
    def __init__(self, desc, max_queue_time):
        self._desc = desc
        self._max_queue_time = max_queue_time
        self._pq = collections.deque()  # of (enqueue time, packet)

    def _log(self, dflag, msg, **kwargs):
        log_debug(dflag, self._desc + ' pq: ' + msg, **kwargs)

    def append(self, packet):
        """Add `packet' at the back, stamped with the current time."""
        self._log(DBG.QUEUE, 'append', d=packet)
        stamped = (time.monotonic(), packet)
        self._pq.append(stamped)

    def nonempty(self):
        """Drop expired packets from the front; report whether any remain."""
        self._log(DBG.QUEUE, 'nonempty ?')
        queue = self._pq
        while queue:
            enqueued_at, packet = queue[0]
            if time.monotonic() - enqueued_at <= self._max_queue_time:
                self._log(DBG.QUEUE, 'nonempty ? nonempty.')
                return True
            # strip old packets off the front
            self._log(DBG.QUEUE, 'dropping (old)', d=packet)
            queue.popleft()
        self._log(DBG.QUEUE, 'nonempty ? empty.')
        return False

    def process(self, sizequery, moredata, max_batch):
        """Move packets from the front of the queue into a batch.

        sizequery() should return the size of the batch so far.
        moredata(s) should add s to the batch.
        Packets are SLIP-encoded and separated by slip.delimiter.
        Stops when the queue is empty, or when adding the next packet
        to an already non-empty batch would exceed max_batch; a packet
        going into an empty batch is always added.
        """
        self._log(DBG.QUEUE, 'process...')
        while self._pq:
            _enqueued_at, packet = self._pq[0]
            self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)

            encoded = slip.encode(packet)
            sofar = sizequery()
            self._log(DBG.QUEUE_CTRL,
                      'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
                      d=encoded)

            if sofar > 0:
                needed = sofar + len(slip.delimiter) + len(encoded)
                if needed > max_batch:
                    # leave the packet queued for a later batch
                    self._log(DBG.QUEUE_CTRL, 'process... overflow')
                    return
                moredata(slip.delimiter)

            moredata(encoded)
            self._pq.popleft()

        self._log(DBG.QUEUE, 'process... empty')
266
267 #---------- error handling ----------
268
# Set by crash(); common_run() will not start the reactor once it is true.
_crashing = False

def crash(err):
    """Record that we are crashing, report `err' on stderr, stop the reactor.

    A reactor which is not running is tolerated.
    """
    global _crashing
    _crashing = True
    print('CRASH ', err, file=sys.stderr)
    try:
        reactor.stop()
    except twisted.internet.error.ReactorNotRunning:
        # nothing to stop
        pass
277
def crash_on_defer(defer):
    """Arrange for a failure of `defer' to be passed to crash()."""
    def _on_failure(failure):
        return crash(failure)
    defer.addErrback(_on_failure)
280
def crash_on_critical(event):
    """Log observer: crash() on any event at critical level or above.

    event -- a log event dict.  Events with no `log_level' (or with it
             set to None) are ignored; comparing None against a
             LogLevel constant would raise TypeError.
    """
    level = event.get('log_level')
    if level is not None and level >= LogLevel.critical:
        crash(twisted.logger.formatEvent(event))
284
285 #---------- config processing ----------
286
def process_cfg_common_always():
    """Copy the [virtual] `mtu' setting into c.mtu.

    The value is stored as ConfigParser returns it, i.e. as a string.
    """
    # No `global mtu' here: the only thing assigned is the attribute
    # c.mtu; no module-level `mtu' variable is ever bound.
    c.mtu = cfg.get('virtual', 'mtu')
290
def process_cfg_ipif(section, varmap):
    """Compute c.ipif_command from the `ipif' setting in `section'.

    varmap -- iterable of (dest, source) attribute-name pairs.  For each
              pair, if c has an attribute named `source', its value is
              also stored as c.<dest>; pairs whose source is missing are
              skipped.
    The attributes of c are then offered to ConfigParser as additional
    interpolation variables when reading the `ipif' setting.
    """
    for dest, source in varmap:
        if hasattr(c, source):
            setattr(c, dest, getattr(c, source))

    c.ipif_command = cfg.get(section, 'ipif', vars=c.__dict__)
300
def process_cfg_network():
    """Parse [virtual] `network' into c.network and check it is big enough.

    Raises ValueError if the network has fewer than 5 addresses.
    Network sizes are powers of two, so the smallest acceptable
    network has 2^3 addresses.
    """
    c.network = ipnetwork(cfg.get('virtual', 'network'))
    too_small = c.network.num_addresses < 3 + 2
    if too_small:
        raise ValueError('network needs at least 2^3 addresses')
305
def process_cfg_server():
    """Set c.server from [virtual] `server', or default it from the network.

    When the option is absent, process_cfg_network() is run and the
    first host address of c.network is used.
    """
    # NOTE(review): the configured value is kept as a str, whereas the
    # computed default is an ipaddress host object -- confirm that the
    # users of c.server accept both.
    try:
        chosen = cfg.get('virtual', 'server')
    except NoOptionError:
        process_cfg_network()
        chosen = next(c.network.hosts())
    c.server = chosen
312
class ServerAddr():
    """One address (plus port) on which the server listens.

    port     -- TCP port number.
    addrspec -- IPv4 or IPv6 address, in any form accepted by the
                ipaddress address constructors.
    Attributes: port, and addr (an IPv4Address or IPv6Address).
    """
    def __init__(self, port, addrspec):
        self.port = port
        try:
            self.addr = ipaddress.IPv4Address(addrspec)
        except AddressValueError:
            self.addr = ipaddress.IPv6Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
            self._inurl = b'[%s]'   # IPv6 literals are bracketed in URLs
        else:
            self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
            self._inurl = b'%s'

    def make_endpoint(self):
        """Return a Twisted server endpoint bound to this address and port."""
        # The third positional parameter of the TCP server endpoints is
        # `backlog', so the address has to be passed by keyword, and as
        # a string.
        return self._endpointfactory(reactor, self.port,
                                     interface=str(self.addr))

    def url(self):
        """Return the http URL for this address, as bytes.

        The port is omitted when it is 80.
        """
        url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
        if self.port != 80:
            url += b':%d' % self.port
        url += b'/'
        return url
332
def process_cfg_saddrs():
    """Build c.saddrs, one ServerAddr per entry of [server] `addrs'.

    `addrs' is split on whitespace.  Every entry uses the [server]
    `port', or 80 if that option is absent.
    """
    try:
        port = cfg.getint('server', 'port')
    except NoOptionError:
        port = 80

    c.saddrs = []
    addrspecs = cfg.get('server', 'addrs').split()
    for addrspec in addrspecs:
        c.saddrs.append(ServerAddr(port, addrspec))
341
def process_cfg_clients(constructor):
    """Call constructor(ci, cs, pw) for every client section in the config.

    A client section is one whose name contains `:' or `.'.
    ci is the section name parsed as an IP address, cs the section name
    itself, and pw that section's `password' encoded as UTF-8.
    c.clients is reset to an empty list first; nothing is appended to
    it here.
    """
    c.clients = []
    for section in cfg.sections():
        if ':' not in section and '.' not in section:
            continue
        client_ip = ipaddr(section)
        password = cfg.get(section, 'password').encode('utf-8')
        constructor(client_ip, section, password)
350
351 #---------- startup ----------
352
def common_startup():
    """Start logging, parse the command line and load the configuration.

    Log events go to stderr in classic text format and are also given
    to crash_on_critical.  The only option is -c/--config (default
    /etc/hippotat/config); non-option arguments are rejected.  The
    built-in defaults in defcfg are loaded first, then the config file.
    """
    observers = [
        twisted.logger.FileLogObserver(
            sys.stderr, twisted.logger.formatEventAsClassicLogText),
        crash_on_critical,
    ]
    twisted.logger.globalLogBeginner.beginLoggingTo(observers)

    optparser.add_option('-c', '--config', dest='configfile',
                         default='/etc/hippotat/config')
    opts, args = optparser.parse_args()
    if args:
        optparser.error('no non-option arguments please')

    # strip `#' comments from the built-in defaults before parsing them
    builtin_defaults = regexp.sub('#.*', '', defcfg)
    cfg.read_string(builtin_defaults)
    # ConfigParser.read() silently skips a file it cannot open, so a
    # missing config file is not reported here.
    cfg.read(opts.configfile)
368
def common_run():
    """Run the reactor, unless crash() has already been called.

    Always prints `CRASHED (end)' on stderr before returning.
    """
    log_debug(DBG.INIT, 'entering reactor')
    if not _crashing:
        reactor.run()
    print('CRASHED (end)', file=sys.stderr)