wip
[hippotat] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 import twisted.internet.endpoints
11 import twisted.logger
12 from twisted.logger import LogLevel
13 import twisted.python.constants
14 from twisted.python.constants import NamedConstant
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 import collections
24 import time
25 import codecs
26
27 import re as regexp
28
29 import hippotat.slip as slip
30
# Debug-category flags, passed as the `dflag' argument of log_debug().
# In the code visible here: DROP tags discarded-packet messages
# (log_discard), QUEUE and QUEUE_CTRL tag PacketQueue messages, and
# INIT tags the startup message in common_run.
# NOTE(review): ROUTE, FLOW, HTTP, HTTP_CTRL and HTTP_FULL are not used
# in this module; presumably the client and server programs use them.
class DBG(twisted.python.constants.Names):
    ROUTE = NamedConstant()
    DROP = NamedConstant()
    FLOW = NamedConstant()
    HTTP = NamedConstant()
    HTTP_CTRL = NamedConstant()
    INIT = NamedConstant()
    QUEUE = NamedConstant()
    QUEUE_CTRL = NamedConstant()
    HTTP_FULL = NamedConstant()
41
# Encoder used to render packet bytes as hex in debug messages.
_hex_codec = codecs.getencoder('hex_codec')

log = twisted.logger.Logger()


def log_debug(dflag, msg, idof=None, d=None):
    """Emit a debug message through the module logger at info level.

    dflag -- debug category (a DBG constant); logged in front of the text
    msg   -- message text
    idof  -- optional object; when given, its id() is prefixed as `[<id>] '
    d     -- optional bytes; when given, appended to the message as hex
    """
    text = msg if idof is None else '[%d] %s' % (id(idof), msg)
    if d is not None:
        # the encoder returns (encoded bytes, length consumed)
        hexed, _consumed = _hex_codec(d)
        text = text + ' ' + hexed.decode('ascii')
    log.info('{dflag} {msgcore}', dflag=dflag, msgcore=text)
55
# Built-in default configuration, in ConfigParser syntax.
# common_startup() deletes everything from `#' to end of line (so the
# trailing remarks below never reach the parser), loads the result into
# `cfg', and only then reads the user's config file into the same parser.
# NOTE(review): the commentary inside the string lists `%(rnet)s' as an
# extra interpolation, but the `ipif' template uses `%(rnets)s' (and also
# `%(mtu)s'); one of the two spellings is presumably a typo -- confirm.
defcfg = '''
[DEFAULT]
#[<client>] overrides
max_batch_down = 65536 # used by server, subject to [limits]
max_queue_time = 10 # used by server, subject to [limits]
max_request_time = 54 # used by server, subject to [limits]
target_requests_outstanding = 3 # must match; subject to [limits] on server
max_requests_outstanding = 4 # used by client
max_batch_up = 4000 # used by client
http_timeout = 30 # used by client
http_retry = 5 # used by client

#[server] or [<client>] overrides
ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
# extra interpolations: %(local)s %(peer)s %(rnet)s
# obtained on server [virtual]server [virtual]relay [virtual]network
# from on client <client> [virtual]server [virtual]routes

[virtual]
mtu = 1500
routes = ''
# network = <prefix>/<len> # mandatory for server
# server = <ipaddr> # used by both, default is computed from `network'
# relay = <ipaddr> # used by server, default from `network' and `server'
# default server is first host in network
# default relay is first host which is not server

[server]
# addrs = 127.0.0.1 ::1 # mandatory for server
port = 80 # used by server
# url # used by client; default from first `addrs' and `port'

# [<client-ip4-or-ipv6-address>]
# password = <password> # used by both, must match

[limits]
max_batch_down = 262144 # used by server
max_queue_time = 121 # used by server
max_request_time = 121 # used by server
target_requests_outstanding = 10 # used by server
'''
97
# The shared config parser and option parser are created at import time,
# at module level, so that they exist as module attributes and can be
# picked up by `import *'.  common_startup() fills them in later.
cfg = ConfigParser()
optparser = OptionParser()
101
# Byte translation table exchanging `-' with the SLIP ESC byte.
_mimetrans = bytes.maketrans(b'-' + slip.esc, slip.esc + b'-')


def mime_translate(s):
    """Return s with every `-' and every SLIP ESC byte exchanged.

    SLIP-encoded packets cannot contain ESC ESC, so after the swap the
    result cannot contain `--'.  Applying the function twice gives back
    the original data, because the table is a plain two-way swap.
    """
    swapped = s.translate(_mimetrans)
    return swapped
107
class ConfigResults:
    """Plain attribute bag holding the processed configuration.

    d -- optional dict adopted (not copied) as the instance's attribute
         dictionary; when omitted each instance gets its own fresh dict.
    """

    def __init__(self, d=None):
        # A literal `{ }' default would be created once and then shared
        # by every instance constructed without an argument, so that
        # setting an attribute on one would set it on all of them.
        self.__dict__ = {} if d is None else d

    def __repr__(self):
        return 'ConfigResults(' + repr(self.__dict__) + ')'


# The module-wide configuration results, filled in by process_cfg_*().
c = ConfigResults()
115
def log_discard(packet, iface, saddr, daddr, why):
    """Log, under DBG.DROP, that a packet was discarded.

    The message names the interface, the source and destination
    addresses and the reason; the packet itself is passed on to
    log_debug as the data to be shown in hex.
    """
    summary = 'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why)
    log_debug(DBG.DROP, summary, d=packet)
120
121 #---------- packet parsing ----------
122
def packet_addrs(packet):
    """Return (source, destination) addresses read from an IP packet.

    packet is a bytes object starting at the IP header.  The IP version
    is taken from the top four bits of the first byte.  Returns a pair
    of ipaddress.IPv4Address or ipaddress.IPv6Address objects.

    Raises ValueError if the version is neither 4 nor 6.
    """
    version = packet[0] >> 4

    # version -> (address length, offset of source address, address class)
    layouts = {
        4: (4, 3 * 4, ipaddress.IPv4Address),
        6: (16, 2 * 4, ipaddress.IPv6Address),
    }
    layout = layouts.get(version)
    if layout is None:
        raise ValueError('unsupported IP version %d' % version)
    addrlen, saddroff, factory = layout

    # the destination address immediately follows the source address
    daddroff = saddroff + addrlen
    saddr = factory(packet[saddroff:daddroff])
    daddr = factory(packet[daddroff:daddroff + addrlen])
    return (saddr, daddr)
138
139 #---------- address handling ----------
140
def ipaddr(input):
    """Parse input as an IP address, trying IPv4 first and then IPv6.

    Returns an ipaddress.IPv4Address or ipaddress.IPv6Address.  If input
    is neither, the AddressValueError from the IPv6 attempt propagates.
    """
    try:
        return ipaddress.IPv4Address(input)
    except AddressValueError:
        # not IPv4; let IPv6 parsing have the final say
        return ipaddress.IPv6Address(input)
147
def ipnetwork(input):
    """Parse input as an IP network, trying IPv4 first and then IPv6.

    Returns an ipaddress.IPv4Network or ipaddress.IPv6Network.  If the
    address part is not IPv4 the input is re-parsed as an IPv6 network,
    and any error from that attempt propagates.  Other IPv4 errors (for
    example a bad prefix length on a valid IPv4 address) propagate
    directly rather than being masked by an IPv6 parse failure.
    """
    try:
        r = ipaddress.IPv4Network(input)
    except AddressValueError:
        # The ipaddress module has no `NetworkValueError'; a non-IPv4
        # address part is reported as AddressValueError, the same
        # exception ipaddr() relies on.
        r = ipaddress.IPv6Network(input)
    return r
154
155 #---------- ipif (SLIP) subprocess ----------
156
class SlipStreamDecoder():
    """Incremental decoder turning a SLIP byte stream into packets.

    on_packet(<packet>) is called once for every nonempty packet.
    """

    def __init__(self, on_packet):
        self._on_packet = on_packet
        self._buffer = b''  # data received but not yet delivered

    def inputdata(self, data):
        """Feed more stream data; deliver every packet now complete."""
        self._buffer += data
        pieces = slip.decode(self._buffer)
        # the last piece returned is kept back as the new buffer; the
        # pieces before it are delivered
        self._buffer = pieces.pop()
        for piece in pieces:
            self._maybe_packet(piece)

    def _maybe_packet(self, packet):
        """Deliver packet to the callback unless it is empty."""
        if not len(packet):
            return
        self._on_packet(packet)

    def flush(self):
        """Deliver whatever is buffered as a final packet, then reset."""
        self._maybe_packet(self._buffer)
        self._buffer = b''
178
class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
    """Process protocol for the ipif subprocess.

    Data from the subprocess's stdout is SLIP-decoded; each packet is
    handed to router(packet, saddr, daddr) unless either address is
    link-local, in which case the packet is logged and dropped.
    """

    def __init__(self, router):
        self._router = router
        self._decoder = SlipStreamDecoder(self.slip_on_packet)

    def connectionMade(self):
        pass

    def outReceived(self, data):
        self._decoder.inputdata(data)

    def slip_on_packet(self, packet):
        saddr, daddr = packet_addrs(packet)
        if not (saddr.is_link_local or daddr.is_link_local):
            self._router(packet, saddr, daddr)
            return
        log_discard(packet, 'ipif', saddr, daddr, 'link-local')

    def processEnded(self, status):
        # re-raise whatever the reactor reported as the exit status
        status.raiseException()
194
def start_ipif(command, router):
    """Spawn the ipif command and attach the SLIP protocol handler.

    command -- shell command line, run as `/bin/sh -xc <command>'
    router  -- callable(packet, saddr, daddr) for packets from the process

    The protocol object is stored in the module global `ipif', which
    queue_inbound() uses to write to the process.
    """
    global ipif
    ipif = _IpifProcessProtocol(router)
    argv = ['sh', '-xc', command]
    # fd 0 is written by us, fd 1 is read by us, fd 2 is our own fd 2
    fds = {0: 'w', 1: 'r', 2: 2}
    reactor.spawnProcess(ipif, '/bin/sh', argv, childFDs=fds, env=None)
202
def queue_inbound(packet):
    """Write packet to the ipif process, SLIP-encoded.

    A delimiter is written both before and after the encoded packet.
    Requires start_ipif() to have been called first.
    """
    send = ipif.transport.write
    send(slip.delimiter)
    send(slip.encode(packet))
    send(slip.delimiter)
207
208 #---------- packet queue ----------
209
class PacketQueue():
    """FIFO of packets which discards entries that have waited too long.

    desc           -- string naming the queue in debug messages
    max_queue_time -- maximum age, compared against differences of
                      time.monotonic() values, before a packet is dropped
    """

    def __init__(self, desc, max_queue_time):
        self._desc = desc
        assert(desc + '')
        self._max_queue_time = max_queue_time
        self._pq = collections.deque()  # of (time queued, packet)

    def _log(self, dflag, msg, **kwargs):
        """Log msg through log_debug, prefixed with the queue's name."""
        log_debug(dflag, self._desc + ' pq: ' + msg, **kwargs)

    def append(self, packet):
        """Add packet at the tail, stamped with the current time."""
        self._log(DBG.QUEUE, 'append', d=packet)
        self._pq.append((time.monotonic(), packet))

    def nonempty(self):
        """Drop over-age packets from the head; return whether any remain."""
        self._log(DBG.QUEUE, 'nonempty ?')
        pending = self._pq
        while pending:
            queued_at, oldest = pending[0]
            if time.monotonic() - queued_at > self._max_queue_time:
                # strip old packets off the front
                self._log(DBG.QUEUE, 'dropping (old)', d=oldest)
                pending.popleft()
                continue
            self._log(DBG.QUEUE, 'nonempty ? nonempty.')
            return True
        self._log(DBG.QUEUE, 'nonempty ? empty.')
        return False

    def process(self, sizequery, moredata, max_batch):
        """Move packets from the head of the queue into a batch.

        sizequery() should return the size of the batch so far.
        moredata(s) should add s to the batch.

        Packets are SLIP-encoded and separated by slip.delimiter.  The
        loop stops when the queue is empty, or when a nonempty batch
        would grow beyond max_batch; the packet that did not fit stays
        queued.  A packet going into an empty batch is always accepted.
        """
        self._log(DBG.QUEUE, 'process...')
        pending = self._pq
        while pending:
            packet = pending[0][1]
            self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)

            encoded = slip.encode(packet)
            sofar = sizequery()
            self._log(DBG.QUEUE_CTRL,
                      'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
                      d=encoded)

            if sofar > 0:
                wanted = sofar + len(slip.delimiter) + len(encoded)
                if wanted > max_batch:
                    self._log(DBG.QUEUE_CTRL, 'process... overflow')
                    return
                moredata(slip.delimiter)

            moredata(encoded)
            pending.popleft()

        self._log(DBG.QUEUE, 'process... empty')
269
270 #---------- error handling ----------
271
# Set once crash() has been called; common_run() checks it before
# starting the reactor.
_crashing = False


def crash(err):
    """Report err on stderr and stop the reactor.

    Also records that a crash happened, so that a crash occurring
    before the reactor is running prevents common_run() from starting it.
    """
    global _crashing
    _crashing = True
    print('CRASH ', err, file=sys.stderr)
    try:
        reactor.stop()
    except twisted.internet.error.ReactorNotRunning:
        # crashed before (or after) the reactor ran: nothing to stop
        pass
280
def crash_on_defer(defer):
    """Arrange for a failure of the deferred `defer' to crash the program."""
    def _on_failure(err):
        return crash(err)

    defer.addErrback(_on_failure)
283
def crash_on_critical(event):
    """Log observer which crashes the program on a critical event.

    event is a log event dictionary.  If its `log_level' is critical or
    higher, the formatted event text is passed to crash().  Events with
    no `log_level' are ignored.
    """
    level = event.get('log_level')
    # A missing level comes back as None, and None cannot be ordered
    # against a LogLevel constant: the comparison would raise TypeError
    # inside the observer instead of simply skipping the event.
    if level is not None and level >= LogLevel.critical:
        crash(twisted.logger.formatEvent(event))
287
288 #---------- config processing ----------
289
def process_cfg_common_always():
    """Record the [virtual] `mtu' setting from cfg as c.mtu.

    The value is stored as the string returned by cfg.get().
    """
    # (A `global mtu' declaration used to sit here; nothing in the
    # function assigns a module-level `mtu', so it had no effect.)
    c.mtu = cfg.get('virtual', 'mtu')
293
def process_cfg_ipif(section, varmap):
    """Compute c.ipif_command from the `ipif' option of `section'.

    varmap is an iterable of (dest, source) attribute-name pairs.  For
    each pair, if c has attribute `source' its value is copied to
    attribute `dest'; pairs whose source is missing are skipped.  All
    of c's attributes are then offered as interpolation variables when
    the `ipif' option is looked up.
    """
    for dest, source in varmap:
        if not hasattr(c, source):
            continue
        setattr(c, dest, getattr(c, source))

    c.ipif_command = cfg.get(section, 'ipif', vars=c.__dict__)
303
def process_cfg_network():
    """Parse [virtual] `network' into c.network and check its size.

    Raises ValueError if the network has fewer than 3 + 2 addresses.
    """
    network = ipnetwork(cfg.get('virtual', 'network'))
    c.network = network
    # NOTE(review): 3 + 2 is presumably three usable hosts plus the
    # network and broadcast addresses -- confirm.  Since network sizes
    # are powers of two, this is what makes 2^3 the minimum in the message.
    minimum = 3 + 2
    if network.num_addresses < minimum:
        raise ValueError('network needs at least 2^3 addresses')
308
def process_cfg_server():
    """Set c.server from [virtual] `server', or derive it from the network.

    When the option is absent, the network is processed (setting
    c.network) and the first host address in it is used.
    NOTE(review): the configured value is stored as a string while the
    derived default is an ipaddress object -- confirm callers accept both.
    """
    try:
        configured = cfg.get('virtual', 'server')
        c.server = configured
    except NoOptionError:
        process_cfg_network()
        hosts = c.network.hosts()
        c.server = next(hosts)
315
class ServerAddr():
    """One address/port on which the server listens.

    port     -- TCP port number
    addrspec -- IPv4 or IPv6 address as a string

    Attributes: self.port, and self.addr (an ipaddress.IPv4Address or
    ipaddress.IPv6Address).  Raises AddressValueError if addrspec is
    neither kind of address.
    """

    def __init__(self, port, addrspec):
        self.port = port
        try:
            self.addr = ipaddress.IPv4Address(addrspec)
        except AddressValueError:
            self.addr = ipaddress.IPv6Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
            self._inurl = b'[%s]'  # IPv6 literals are bracketed in URLs
        else:
            self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
            self._inurl = b'%s'

    def make_endpoint(self):
        """Return a Twisted server endpoint bound to this address and port."""
        # The address must be passed by keyword and as a string: the
        # third positional parameter of the TCP server endpoints is the
        # listen backlog, not the interface.
        return self._endpointfactory(reactor, self.port,
                                     interface=str(self.addr))

    def url(self):
        """Return the http:// URL for this address, as bytes.

        The port is included only when it is not 80.
        """
        url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
        if self.port != 80:
            url += b':%d' % self.port
        url += b'/'
        return url
335
def process_cfg_saddrs():
    """Build c.saddrs, a list of ServerAddr, from [server] `addrs'.

    `addrs' is a whitespace-separated list of addresses.  All of them
    share the [server] `port', which falls back to 80 when unset.
    """
    try:
        port = cfg.getint('server', 'port')
    except NoOptionError:
        port = 80

    saddrs = c.saddrs = []
    for spec in cfg.get('server', 'addrs').split():
        saddrs.append(ServerAddr(port, spec))
344
def process_cfg_clients(constructor):
    """Call constructor(ci, cs, pw) for every client section in cfg.

    A client section is one whose name contains `:' or `.'.  ci is the
    section name parsed as an IP address, cs is the section name itself
    and pw is the section's `password', UTF-8 encoded.

    c.clients is reset to an empty list first; nothing here appends to it.
    NOTE(review): presumably constructor registers each client there.
    """
    c.clients = []
    for section in cfg.sections():
        if ':' not in section and '.' not in section:
            continue
        client_ip = ipaddr(section)
        password = cfg.get(section, 'password').encode('utf-8')
        constructor(client_ip, section, password)
353
354 #---------- startup ----------
355
def common_startup():
    """Set up logging, parse the command line and load the configuration.

    Log events go to stderr in classic text format, and are also given
    to crash_on_critical.  The only option added here is -c/--config;
    positional arguments are rejected.  The built-in defaults (defcfg,
    with `#' comments removed) are loaded into cfg first, then the
    config file.
    """
    stderr_observer = twisted.logger.FileLogObserver(
        sys.stderr, twisted.logger.formatEventAsClassicLogText)
    observers = [stderr_observer, crash_on_critical]
    twisted.logger.globalLogBeginner.beginLoggingTo(observers)

    optparser.add_option('-c', '--config', dest='configfile',
                         default='/etc/hippotat/config')
    opts, args = optparser.parse_args()
    if args:
        optparser.error('no non-option arguments please')

    # strip everything from `#' to end of line before parsing
    uncommented = regexp.sub('#.*', '', defcfg)
    cfg.read_string(uncommented)
    cfg.read(opts.configfile)
371
def common_run():
    """Run the reactor, unless a crash has already been recorded.

    The final message is printed whenever this function gets past the
    reactor, whether or not the reactor was actually run.
    """
    log_debug(DBG.INIT, 'entering reactor')
    if not _crashing:
        reactor.run()
    print('CRASHED (end)', file=sys.stderr)