fixes
[hippotat] / hippotat / __init__.py
1 # -*- python -*-
2
import signal
# Restore the OS default action for SIGINT (terminate the process).
# Python otherwise installs its own handler that raises KeyboardInterrupt,
# which would have to propagate out of whatever code is running
# (presumably the twisted reactor) before the program actually stops.
signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 import twisted.internet.endpoints
11 import twisted.logger
12 from twisted.logger import LogLevel
13 import twisted.python.constants
14 from twisted.python.constants import NamedConstant
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 import collections
24 import time
25 import codecs
26
27 import re as regexp
28
29 import hippotat.slip as slip
30
class DBG(twisted.python.constants.Names):
    """Debug categories, passed as the ``dflag`` argument of log_debug().

    log_debug() puts the flag at the front of every message it emits; it
    does not filter on it.  The flags therefore label the kind of event
    being traced (routing, drops, HTTP, queueing, ...).
    """
    ROUTE = NamedConstant()
    DROP = NamedConstant()
    FLOW = NamedConstant()
    HTTP = NamedConstant()
    HTTP_CTRL = NamedConstant()
    INIT = NamedConstant()
    QUEUE = NamedConstant()
    QUEUE_CTRL = NamedConstant()
    HTTP_FULL = NamedConstant()
41
# Encoder function of the bytes-to-bytes 'hex_codec' codec.  Like every
# codecs encoder it returns (encoded_output, length_consumed); log_debug()
# uses element [0] to render packet data as hex.
_hex_codec = codecs.getencoder('hex_codec')

# Module-wide twisted Logger used by log_debug().  It is created without an
# explicit namespace, so twisted derives one from the calling module.
log = twisted.logger.Logger()
45
def log_debug(dflag, msg, idof=None, d=None):
    """Emit a debug message tagged with the category *dflag*.

    idof -- if not None, the message is prefixed with ``[id(idof)]`` so that
            messages about the same object can be correlated.
    d    -- if not None, bytes-like data; at most its first 64 bytes are
            appended to the message as lowercase hex.

    Everything goes to the module logger at info level.  *dflag* appears in
    the text but is not used for filtering.
    """
    if idof is not None:
        msg = '[%d] %s' % (id(idof), msg)
    if d is not None:
        head = d[0:64]
        msg += ' ' + _hex_codec(head)[0].decode('ascii')
    log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
54
# Built-in default configuration, loaded by common_startup() before the
# user's config file.  common_startup() deletes everything from '#' to end
# of line (regex '#.*') before parsing, so the '#' text below is purely
# documentation, and no value here may contain a '#'.
# Note that `routes = ''` is deliberate: the value is substituted into the
# ipif shell command, where '' becomes an explicit empty argument.
defcfg = '''
[DEFAULT]
#[<client>] overrides
max_batch_down = 65536 # used by server, subject to [limits]
max_queue_time = 10 # used by server, subject to [limits]
max_request_time = 54 # used by server, subject to [limits]
target_requests_outstanding = 3 # must match; subject to [limits] on server
max_requests_outstanding = 4 # used by client
max_batch_up = 4000 # used by client
http_timeout = 30 # used by client
http_retry = 5 # used by client

#[server] or [<client>] overrides
ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
# extra interpolations:  %(local)s        %(peer)s          %(rnets)s
#  obtained on server    [virtual]server  [virtual]relay    [virtual]network
#      from on client    <client>         [virtual]server   [virtual]routes

[virtual]
mtu = 1500
routes = ''
# network = <prefix>/<len> # mandatory for server
# server = <ipaddr> # used by both, default is computed from `network'
# relay = <ipaddr> # used by server, default from `network' and `server'
# default server is first host in network
# default relay is first host which is not server

[server]
# addrs = 127.0.0.1 ::1 # mandatory for server
port = 80 # used by server
# url # used by client; default from first `addrs' and `port'

# [<client-ip4-or-ipv6-address>]
# password = <password> # used by both, must match

[limits]
max_batch_down = 262144 # used by server
max_queue_time = 121 # used by server
max_request_time = 121 # used by server
target_requests_outstanding = 10 # used by server
'''
96
# Process-wide singletons: the parsed configuration and the command-line
# option parser.  They must exist at import time so that modules doing
# `from hippotat import *` receive these very objects.
cfg, optparser = ConfigParser(), OptionParser()
100
# Translation table that swaps b'-' and slip.esc with each other.
_mimetrans = bytes.maketrans(b'-' + slip.esc, slip.esc + b'-')


def mime_translate(s):
    """Return *s* with every b'-' and slip.esc byte exchanged.

    SLIP-encoded packets cannot contain ESC ESC, so after the swap the
    result cannot contain b'--'.  Because it is a pure swap (assuming
    slip.esc is a single byte other than b'-'), applying it twice gives
    back the original.
    """
    swapped = s.translate(_mimetrans)
    return swapped
106
class ConfigResults:
    """Attribute-style container for processed configuration values.

    Values are stored directly in the instance ``__dict__``.  If *d* is
    given, that dict itself (not a copy) becomes the instance dict, so
    attributes set on the instance appear in *d* and vice versa.  Without
    *d*, each instance gets its own fresh, empty dict.
    """

    def __init__(self, d=None):
        # A new dict per call: a literal default ({}) would be one dict
        # shared by every ConfigResults() created without an argument.
        self.__dict__ = {} if d is None else d

    def __repr__(self):
        return 'ConfigResults(' + repr(self.__dict__) + ')'
112
# Global container for processed configuration; the process_cfg_*
# functions set attributes on it (e.g. c.mtu, c.server, c.saddrs).
c = ConfigResults()
114
def log_discard(packet, iface, saddr, daddr, why):
    """Log, under DBG.DROP, that *packet* seen on *iface* was discarded.

    The message names the source and destination addresses and the reason
    *why*; log_debug() appends a hex dump of the packet's leading bytes.
    """
    description = ('discarded packet [%s] %s -> %s: %s'
                   % (iface, saddr, daddr, why))
    log_debug(DBG.DROP, description, d=packet)
119
120 #---------- packet parsing ----------
121
def packet_addrs(packet):
    """Return (source, destination) addresses from a raw IP packet header.

    The IP version is the high nibble of the first byte.  IPv4 addresses sit
    at byte offsets 12 and 16; IPv6 addresses at offsets 8 and 24.

    Raises ValueError for a version other than 4 or 6.  An empty packet
    raises IndexError, and a packet too short to hold both addresses makes
    the ipaddress constructor raise AddressValueError.
    """
    version = packet[0] >> 4
    # version -> (address class, address length, source-address offset)
    layouts = {
        4: (ipaddress.IPv4Address, 4, 3 * 4),
        6: (ipaddress.IPv6Address, 16, 2 * 4),
    }
    layout = layouts.get(version)
    if layout is None:
        raise ValueError('unsupported IP version %d' % version)
    factory, addrlen, srcoff = layout
    dstoff = srcoff + addrlen
    saddr = factory(packet[srcoff:dstoff])
    daddr = factory(packet[dstoff:dstoff + addrlen])
    return (saddr, daddr)
137
138 #---------- address handling ----------
139
def ipaddr(input):
    """Parse *input* as an IPv4 address, falling back to IPv6.

    Returns an ipaddress.IPv4Address or IPv6Address.  If *input* is
    neither, the IPv6 parser's AddressValueError propagates.
    """
    try:
        return ipaddress.IPv4Address(input)
    except AddressValueError:
        return ipaddress.IPv6Address(input)
146
def ipnetwork(input):
    """Parse *input* as an IPv4 network, falling back to IPv6.

    Returns an ipaddress.IPv4Network or IPv6Network.  Parsing is strict, so
    host bits must be zero.  Only AddressValueError (the address part is not
    IPv4) triggers the IPv6 fallback.  Other errors from the IPv4 parser
    propagate unchanged, such as NetmaskValueError for a bad prefix length
    or ValueError for host bits set.
    """
    try:
        return ipaddress.IPv4Network(input)
    except AddressValueError:
        # The ipaddress module has no NetworkValueError; a non-IPv4 address
        # part is reported as AddressValueError.
        return ipaddress.IPv6Network(input)
153
154 #---------- ipif (SLIP) subprocess ----------
155
class SlipStreamDecoder:
    """Incremental SLIP decoder for a byte stream.

    Feed bytes to inputdata().  Each complete, non-empty packet is passed to
    the on_packet callback given at construction.  The last element
    returned by slip.decode() is kept in a buffer (presumably an incomplete
    trailing packet) until more data arrives or flush() is called.
    """

    def __init__(self, on_packet):
        # on_packet(<packet>) is called once per decoded packet
        self._buffer = b''
        self._on_packet = on_packet

    def inputdata(self, data):
        """Append *data* to the buffer and deliver every complete packet."""
        self._buffer += data
        pieces = slip.decode(self._buffer)
        self._buffer = pieces.pop()
        for piece in pieces:
            self._maybe_packet(piece)

    def _maybe_packet(self, packet):
        # Empty pieces (e.g. between adjacent delimiters) are ignored.
        if len(packet) > 0:
            self._on_packet(packet)

    def flush(self):
        """Deliver whatever is buffered (if non-empty), then clear it."""
        pending = self._buffer
        self._maybe_packet(pending)
        self._buffer = b''
177
class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
    """Process protocol for the ipif subprocess's SLIP stream.

    The child's stdout is SLIP-decoded.  Each packet is handed to
    router(packet, saddr, daddr) unless its source or destination address
    is link-local, in which case it is logged as discarded.
    """

    def __init__(self, router):
        self._router = router
        self._decoder = SlipStreamDecoder(self.slip_on_packet)

    def connectionMade(self):
        pass

    def outReceived(self, data):
        # bytes read from the child's stdout: raw SLIP stream
        self._decoder.inputdata(data)

    def slip_on_packet(self, packet):
        saddr, daddr = packet_addrs(packet)
        if not (saddr.is_link_local or daddr.is_link_local):
            self._router(packet, saddr, daddr)
            return
        log_discard(packet, 'ipif', saddr, daddr, 'link-local')

    def processEnded(self, status):
        # Raise the reason the child ended.  NOTE(review): twisted passes a
        # Failure here even for a clean exit, so any exit of ipif is raised.
        status.raiseException()
193
def start_ipif(command, router):
    """Run *command* via ``/bin/sh -xc`` as the ipif (SLIP) subprocess.

    The new _IpifProcessProtocol (delivering packets to *router*) is stored
    in the module global ``ipif``, which queue_inbound() writes to.
    """
    global ipif
    ipif = _IpifProcessProtocol(router)
    argv = ['sh', '-xc', command]
    # child fd 0: we write to it; fd 1: we read it; fd 2: our own stderr
    fdmap = {0: 'w', 1: 'r', 2: 2}
    reactor.spawnProcess(ipif, '/bin/sh', argv, childFDs=fdmap, env=None)
201
def queue_inbound(packet):
    """Write *packet* to the ipif subprocess as a SLIP frame.

    The SLIP-encoded packet is written with a delimiter before and after it.
    Requires start_ipif() to have set the global ``ipif``.
    """
    write = ipif.transport.write
    write(slip.delimiter)
    write(slip.encode(packet))
    write(slip.delimiter)
206
207 #---------- packet queue ----------
208
class PacketQueue:
    """FIFO of packets awaiting transmission, with age-based expiry.

    Entries are (time.monotonic() when queued, packet).  nonempty() drops
    entries older than max_queue_time seconds from the front.  process()
    SLIP-encodes packets from the front into a caller-owned batch.
    """

    def __init__(self, desc, max_queue_time):
        # desc: label prefixed to this queue's debug messages
        self._desc = desc
        assert desc + ''  # desc must be a non-empty str
        self._max_queue_time = max_queue_time
        self._pq = collections.deque()  # of (queue time, packet)

    def _log(self, dflag, msg, **kwargs):
        log_debug(dflag, self._desc + ' pq: ' + msg, **kwargs)

    def append(self, packet):
        """Queue *packet*, stamped with the current monotonic time."""
        self._log(DBG.QUEUE, 'append', d=packet)
        self._pq.append((time.monotonic(), packet))

    def nonempty(self):
        """Discard expired packets at the front; report whether any remain.

        A packet is expired once it has been queued for more than
        max_queue_time seconds.  Only the front of the queue is examined,
        stopping at the first unexpired packet.
        """
        self._log(DBG.QUEUE, 'nonempty ?')
        queue = self._pq
        while queue:
            queued_at, packet = queue[0]
            if time.monotonic() - queued_at > self._max_queue_time:
                # strip old packets off the front
                self._log(DBG.QUEUE, 'dropping (old)', d=packet)
                queue.popleft()
                continue
            self._log(DBG.QUEUE, 'nonempty ? nonempty.')
            return True
        self._log(DBG.QUEUE, 'nonempty ? empty.')
        return False

    def process(self, sizequery, moredata, max_batch):
        """Move SLIP-encoded packets from the front of the queue into a batch.

        sizequery() should return size of batch so far;
        moredata(s) should add s to batch.

        A slip.delimiter precedes each packet unless the batch is still
        empty.  Processing stops, leaving the packet queued, when adding it
        would push the batch past max_batch.  A packet going into an empty
        batch is always added, whatever its size.  Expiry is not checked
        here.
        """
        self._log(DBG.QUEUE, 'process...')
        queue = self._pq
        while queue:
            _, packet = queue[0]
            self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)

            encoded = slip.encode(packet)
            sofar = sizequery()
            self._log(DBG.QUEUE_CTRL,
                      'process... (sofar=%d, max=%d) encoded'
                      % (sofar, max_batch),
                      d=encoded)

            if sofar > 0:
                would_be = sofar + len(slip.delimiter) + len(encoded)
                if would_be > max_batch:
                    self._log(DBG.QUEUE_CTRL, 'process... overflow')
                    return
                moredata(slip.delimiter)

            moredata(encoded)
            queue.popleft()
        self._log(DBG.QUEUE, 'process... empty')
268
269 #---------- error handling ----------
270
# Becomes True once crash() has been called; common_run() checks it and
# skips starting the reactor if a crash already happened.
_crashing = False


def crash(err):
    """Report *err* on stderr, flag the crash, and stop the reactor.

    Safe to call whether or not the reactor is currently running.
    """
    global _crashing
    _crashing = True
    print('CRASH ', err, file=sys.stderr)
    try:
        reactor.stop()
    except twisted.internet.error.ReactorNotRunning:
        # Nothing to stop; _crashing keeps common_run() from starting it.
        pass
279
def crash_on_defer(defer):
    """Arrange for any failure of the Deferred *defer* to crash the process."""
    def _on_failure(err):
        crash(err)
    defer.addErrback(_on_failure)
282
def crash_on_critical(event):
    """Log observer: crash the process on any event at critical level or above.

    *event* is a twisted.logger event dict.  An event without a
    ``log_level`` key is ignored.  Comparing None against a LogLevel would
    otherwise raise TypeError inside the log observer.
    """
    level = event.get('log_level')
    if level is not None and level >= LogLevel.critical:
        crash(twisted.logger.formatEvent(event))
286
287 #---------- config processing ----------
288
def process_cfg_common_always():
    """Record settings needed by every role: c.mtu from [virtual] mtu.

    The value is kept as the raw string from the config.  It is substituted
    into the ipif command via %(mtu)s when process_cfg_ipif() passes c's
    attributes as interpolation variables.
    """
    c.mtu = cfg.get('virtual', 'mtu')
292
def process_cfg_ipif(section, varmap):
    """Compute c.ipif_command from the 'ipif' option of config *section*.

    *varmap* is an iterable of (dest, src) attribute-name pairs.  For each
    pair where c has attribute *src*, its value is copied to c.<dest>.  All
    attributes of c are then supplied to cfg.get() as interpolation
    variables (e.g. %(local)s, %(peer)s, %(mtu)s, %(rnets)s).
    """
    absent = object()
    for dest, src in varmap:
        value = getattr(c, src, absent)
        if value is absent:
            continue
        setattr(c, dest, value)

    c.ipif_command = cfg.get(section, 'ipif', vars=c.__dict__)
302
def process_cfg_network():
    """Parse [virtual] network into c.network and check it is big enough.

    Raises ValueError if the network has fewer than 3 + 2 addresses.
    NOTE(review): presumably this is two reserved addresses (e.g. IPv4
    network and broadcast) plus three hosts.  num_addresses is always a
    power of two, so "< 3 + 2" rejects exactly the networks smaller than
    2^3, which matches the error message.
    """
    c.network = ipnetwork(cfg.get('virtual', 'network'))
    minimum = 3 + 2
    if c.network.num_addresses < minimum:
        raise ValueError('network needs at least 2^3 addresses')
307
def process_cfg_server():
    """Set c.server, the server's address on the virtual network.

    Uses [virtual] server if configured.  Otherwise the network is parsed
    into c.network and its first host address is used.  In both cases
    c.server is an ipaddress address object, so it compares correctly with
    addresses parsed from packets or produced by c.network.hosts().
    Previously a configured value stayed a plain string.
    """
    try:
        spec = cfg.get('virtual', 'server')
    except NoOptionError:
        process_cfg_network()
        c.server = next(c.network.hosts())
    else:
        c.server = ipaddr(spec)
314
class ServerAddr:
    """One address on which the server listens, with its TCP port.

    Attributes:
      port -- the TCP port, as given.
      addr -- ipaddress.IPv4Address, or IPv6Address if *addrspec* is not
              a valid IPv4 address.
    """

    def __init__(self, port, addrspec):
        self.port = port
        # also self.addr
        try:
            self.addr = ipaddress.IPv4Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
            self._inurl = b'%s'
        except AddressValueError:
            self.addr = ipaddress.IPv6Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
            # IPv6 literals must be bracketed inside URLs
            self._inurl = b'[%s]'

    def make_endpoint(self):
        """Return a twisted server endpoint bound to addr:port."""
        # The third positional parameter of TCP{4,6}ServerEndpoint is the
        # listen backlog, so the bind address must be passed as the
        # `interface` keyword, as a string.
        return self._endpointfactory(reactor, self.port,
                                     interface=str(self.addr))

    def url(self):
        """Return the bytes URL http://<addr>[:<port>]/ (port omitted if 80)."""
        url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
        if self.port != 80:
            url += b':%d' % self.port
        url += b'/'
        return url
334
def process_cfg_saddrs():
    """Build c.saddrs: one ServerAddr per whitespace-separated [server] addrs.

    Every entry shares the [server] port, or 80 if that option is missing.
    """
    try:
        port = cfg.getint('server', 'port')
    except NoOptionError:
        port = 80

    c.saddrs = []
    addrspecs = cfg.get('server', 'addrs').split()
    c.saddrs.extend(ServerAddr(port, spec) for spec in addrspecs)
343
def process_cfg_clients(constructor):
    """Call constructor(addr, section_name, password) for each client section.

    A config section counts as a client if its name contains ':' or '.',
    i.e. it looks like an IPv6 or IPv4 address.  *addr* is the parsed
    address and *password* is the section's 'password' value encoded as
    UTF-8 bytes.  c.clients is reset to an empty list; this function does
    not add to it itself.
    """
    c.clients = []
    for section in cfg.sections():
        if ':' not in section and '.' not in section:
            continue
        addr = ipaddr(section)
        password = cfg.get(section, 'password').encode('utf-8')
        constructor(addr, section, password)
352
353 #---------- startup ----------
354
def common_startup():
    """Common process initialisation for both client and server.

    1. Start twisted logging to stderr in classic text format.  The
       crash_on_critical observer makes critical events crash the process.
    2. Parse the command line: -c/--config (default /etc/hippotat/config);
       positional arguments are rejected.
    3. Load the built-in defaults (defcfg, with '#' comments stripped), then
       the config file on top.  NOTE(review): ConfigParser.read() silently
       skips files it cannot open, so a missing config file is not an error
       here.
    """
    observer = twisted.logger.FileLogObserver(
        sys.stderr, twisted.logger.formatEventAsClassicLogText)
    twisted.logger.globalLogBeginner.beginLoggingTo(
        [observer, crash_on_critical])

    optparser.add_option('-c', '--config', dest='configfile',
                         default='/etc/hippotat/config')
    opts, args = optparser.parse_args()
    if args:
        optparser.error('no non-option arguments please')

    strip_comments = regexp.compile('#.*')
    cfg.read_string(strip_comments.sub('', defcfg))
    cfg.read(opts.configfile)
370
def common_run():
    """Run the twisted reactor unless a crash has already been recorded.

    'CRASHED (end)' is printed to stderr whenever this returns: after
    reactor.run() returns for any reason, or immediately if crash() ran
    before the reactor started.
    """
    log_debug(DBG.INIT, 'entering reactor')
    already_crashed = _crashing
    if not already_crashed:
        reactor.run()
    print('CRASHED (end)', file=sys.stderr)