wip, towards target
[hippotat] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
# Put SIGINT back to the operating system's default action.  This runs
# before any of the twisted imports below.
# NOTE(review): presumably so that ^C kills the process outright instead
# of raising KeyboardInterrupt inside the reactor -- confirm.
signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 import twisted.internet.endpoints
11 import twisted.logger
12 from twisted.logger import LogLevel
13 import twisted.python.constants
14 from twisted.python.constants import NamedConstant
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 import collections
24 import time
25 import codecs
26
27 import re as regexp
28
29 import hippotat.slip as slip
30
class DBG(twisted.python.constants.Names):
    # Debug categories, passed as the `dflag' argument of log_debug().
    # Within this file: DROP is used by log_discard(), FLOW by
    # queue_inbound(), QUEUE and QUEUE_CTRL by PacketQueue and INIT by
    # common_run().  The remaining flags are not referenced in this
    # file.
    ROUTE = NamedConstant()
    DROP = NamedConstant()
    FLOW = NamedConstant()
    HTTP = NamedConstant()
    HTTP_CTRL = NamedConstant()
    INIT = NamedConstant()
    QUEUE = NamedConstant()
    QUEUE_CTRL = NamedConstant()
    HTTP_FULL = NamedConstant()
41
# Encoder used by log_debug() to render binary data as hex digits.
_hex_codec = codecs.getencoder('hex_codec')

# Logger through which log_debug() emits its messages.
log = twisted.logger.Logger()
45
def log_debug(dflag, msg, idof=None, d=None):
    """Emit one debug message at info level.

    dflag -- debug category (a DBG constant); logged in front of the text
    msg   -- message text
    idof  -- optional object; if given, `[<id(idof)>] ' is prefixed to
             the message so that messages about one object can be
             correlated
    d     -- optional binary data; if given, it is appended to the
             message as hex digits
    """
    text = msg
    if idof is not None:
        text = '[%d] %s' % (id(idof), text)
    if d is not None:
        # the encoder returns (encoded, length consumed)
        hexdigits, _consumed = _hex_codec(d)
        text = text + ' ' + hexdigits.decode('ascii')
    log.info('{dflag} {msgcore}', dflag=dflag, msgcore=text)
55
# Built-in default configuration.  common_startup() deletes everything
# from each `#' to the end of its line and feeds the result to
# cfg.read_string() before the user's config file is read, so the
# annotations inside the string never reach the parser.
# NOTE(review): the `extra interpolations' remark below mentions
# %(rnet)s while the ipif template uses %(rnets)s -- confirm which
# spelling is intended.
# NOTE(review): `routes = ''' looks like it sets routes to two literal
# quote characters rather than to the empty string -- confirm.
defcfg = '''
[DEFAULT]
#[<client>] overrides
max_batch_down = 65536 # used by server, subject to [limits]
max_queue_time = 10 # used by server, subject to [limits]
max_request_time = 54 # used by server, subject to [limits]
target_requests_outstanding = 3 # must match; subject to [limits] on server
max_requests_outstanding = 4 # used by client
max_batch_up = 4000 # used by client
http_timeout = 30 # used by client
http_retry = 5 # used by client

#[server] or [<client>] overrides
ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
# extra interpolations: %(local)s %(peer)s %(rnet)s
# obtained on server [virtual]server [virtual]relay [virtual]network
# from on client <client> [virtual]server [virtual]routes

[virtual]
mtu = 1500
routes = ''
# network = <prefix>/<len> # mandatory for server
# server = <ipaddr> # used by both, default is computed from `network'
# relay = <ipaddr> # used by server, default from `network' and `server'
# default server is first host in network
# default relay is first host which is not server

[server]
# addrs = 127.0.0.1 ::1 # mandatory for server
port = 80 # used by server
# url # used by client; default from first `addrs' and `port'

# [<client-ip4-or-ipv6-address>]
# password = <password> # used by both, must match

[limits]
max_batch_down = 262144 # used by server
max_queue_time = 121 # used by server
max_request_time = 121 # used by server
target_requests_outstanding = 10 # used by server
'''
97
# These are created at module level (rather than inside
# common_startup(), which fills them in) so that they can be picked up
# by `import *'.
cfg = ConfigParser()
optparser = OptionParser()
101
# Translation table mapping `-' to slip.esc and slip.esc to `-'.
_mimetrans = bytes.maketrans(b'-' + slip.esc, slip.esc + b'-')

def mime_translate(s):
    """Return s with every `-' and every SLIP ESC byte exchanged.

    SLIP-encoded packets cannot contain ESC ESC, so after the swap the
    result cannot contain `--'.  Applying the translation a second time
    undoes it.
    """
    swapped = s.translate(_mimetrans)
    return swapped
107
class ConfigResults:
    """Plain attribute container for processed configuration values.

    d -- optional dict to use directly as the instance's attribute
         dictionary (it is not copied, so later attribute assignments
         show up in d).  When omitted, the instance gets a fresh dict
         of its own.
    """

    def __init__(self, d=None):
        # A literal `{ }' default would be created once and then shared
        # as the __dict__ of every default-constructed instance.
        self.__dict__ = {} if d is None else d

    def __repr__(self):
        return 'ConfigResults(' + repr(self.__dict__) + ')'


# The module-wide configuration results, filled in by process_cfg_*().
c = ConfigResults()
115
def log_discard(packet, iface, saddr, daddr, why):
    """Log, under DBG.DROP, that packet is being discarded.

    iface, saddr, daddr and why are formatted into the message; the
    packet contents are appended as hex by log_debug().
    """
    summary = 'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why)
    log_debug(DBG.DROP, summary, d=packet)
120
121 #---------- packet parsing ----------
122
def packet_addrs(packet):
    """Return (saddr, daddr) parsed from the header of an IP packet.

    packet is a bytes-like object starting with the IP header.  The IP
    version is taken from the top four bits of the first byte.  The
    addresses are returned as ipaddress.IPv4Address or
    ipaddress.IPv6Address objects.

    Raises ValueError for an IP version other than 4 or 6.
    """
    # version -> (address class, address length, offset of source address)
    layouts = {
        4: (ipaddress.IPv4Address, 4, 3 * 4),
        6: (ipaddress.IPv6Address, 16, 2 * 4),
    }
    version = packet[0] >> 4
    layout = layouts.get(version)
    if layout is None:
        raise ValueError('unsupported IP version %d' % version)
    make_addr, width, src_start = layout

    # the destination address immediately follows the source address
    dst_start = src_start + width
    dst_end = dst_start + width
    saddr = make_addr(packet[src_start:dst_start])
    daddr = make_addr(packet[dst_start:dst_end])
    return (saddr, daddr)
138
139 #---------- address handling ----------
140
def ipaddr(input):
    """Parse input as an IPv4 address or, failing that, as an IPv6 one.

    Returns an ipaddress.IPv4Address or ipaddress.IPv6Address.  If the
    input is valid as neither, the AddressValueError from the IPv6
    attempt propagates.
    """
    try:
        return ipaddress.IPv4Address(input)
    except AddressValueError:
        return ipaddress.IPv6Address(input)
147
def ipnetwork(input):
    """Parse input as an IPv4 network or, failing that, as an IPv6 one.

    Returns an ipaddress.IPv4Network or ipaddress.IPv6Network.

    Only a failure to parse the address part as IPv4 triggers the IPv6
    attempt; an IPv4 network with a bad netmask is reported as such.
    If the input is valid as neither, the exception from the IPv6
    attempt propagates.
    """
    try:
        return ipaddress.IPv4Network(input)
    except AddressValueError:
        # `NetworkValueError' does not exist; the ipaddress module
        # signals "not an IPv4 address" with AddressValueError.
        return ipaddress.IPv6Network(input)
154
155 #---------- ipif (SLIP) subprocess ----------
156
class SlipStreamDecoder:
    """Incremental decoder turning a SLIP byte stream into packets.

    on_packet(<packet>) is called for every nonempty packet found.
    """

    def __init__(self, on_packet):
        self._on_packet = on_packet
        self._buffer = b''    # input not yet delivered as a packet

    def inputdata(self, data):
        """Feed more stream data; deliver every packet it completes."""
        self._buffer += data
        pieces = slip.decode(self._buffer)
        # The last piece returned by slip.decode() is kept back as the
        # new buffer; all earlier pieces are delivered now.
        self._buffer = pieces.pop()
        for piece in pieces:
            self._maybe_packet(piece)

    def _maybe_packet(self, packet):
        """Deliver packet to the callback unless it is empty."""
        if not len(packet):
            return
        self._on_packet(packet)

    def flush(self):
        """Deliver whatever is buffered as a packet and empty the buffer."""
        self._maybe_packet(self._buffer)
        self._buffer = b''
178
class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
    """Process protocol for the ipif subprocess.

    Data received from the subprocess is SLIP-decoded, and each packet
    is handed to router(packet, saddr, daddr), except that packets with
    a link-local source or destination address are discarded.
    """

    def __init__(self, router):
        self._router = router
        self._decoder = SlipStreamDecoder(self.slip_on_packet)

    def connectionMade(self):
        pass

    def outReceived(self, data):
        self._decoder.inputdata(data)

    def slip_on_packet(self, packet):
        """Handle one decoded packet: discard it or pass it to the router."""
        saddr, daddr = packet_addrs(packet)
        link_local = saddr.is_link_local or daddr.is_link_local
        if link_local:
            log_discard(packet, 'ipif', saddr, daddr, 'link-local')
        else:
            self._router(packet, saddr, daddr)

    def processEnded(self, status):
        # The subprocess ending is turned into an exception, whatever
        # the reason recorded in status.
        status.raiseException()
194
def start_ipif(command, router):
    """Spawn the ipif subprocess and record its protocol in global `ipif'.

    command -- shell command line, run via `/bin/sh -xc'
    router  -- callable(packet, saddr, daddr) for packets read from the
               subprocess
    """
    global ipif
    ipif = _IpifProcessProtocol(router)
    argv = ['sh', '-xc', command]
    # fd 0 is written by us, fd 1 is read by us, fd 2 is passed through
    fdmap = {0: 'w', 1: 'r', 2: 2}
    reactor.spawnProcess(ipif, '/bin/sh', argv, childFDs=fdmap, env=None)
202
def queue_inbound(packet):
    """Write packet to the ipif subprocess, SLIP-encoded.

    The encoded packet is sent with a SLIP delimiter both before and
    after it.
    """
    log_debug(DBG.FLOW, "queue_inbound", d=packet)
    send = ipif.transport.write
    send(slip.delimiter)
    send(slip.encode(packet))
    send(slip.delimiter)
208
209 #---------- packet queue ----------
210
class PacketQueue:
    """FIFO of packets which expire after max_queue_time seconds.

    desc           -- string naming the queue in debug messages
    max_queue_time -- maximum age, in seconds, of a packet at the head
                      of the queue before nonempty() drops it
    """

    def __init__(self, desc, max_queue_time):
        self._desc = desc
        # desc must be something a str can be added to (and nonempty)
        assert(desc + '')
        self._max_queue_time = max_queue_time
        self._pq = collections.deque()    # of (enqueue time, packet)

    def _log(self, dflag, msg, **kwargs):
        log_debug(dflag, self._desc+' pq: '+msg, **kwargs)

    def append(self, packet):
        """Add packet at the tail, stamped with the current time."""
        self._log(DBG.QUEUE, 'append', d=packet)
        entry = (time.monotonic(), packet)
        self._pq.append(entry)

    def nonempty(self):
        """Drop expired packets from the head; report whether any remain."""
        self._log(DBG.QUEUE, 'nonempty ?')
        while self._pq:
            queued_at, packet = self._pq[0]
            age = time.monotonic() - queued_at
            if age <= self._max_queue_time:
                self._log(DBG.QUEUE, 'nonempty ? nonempty.')
                return True
            # strip old packets off the front
            self._log(DBG.QUEUE, 'dropping (old)', d=packet)
            self._pq.popleft()
        self._log(DBG.QUEUE, 'nonempty ? empty.')
        return False

    def process(self, sizequery, moredata, max_batch):
        """Move packets from the head of the queue into a batch.

        sizequery() should return the size of the batch so far
        moredata(s) should add s to the batch

        Packets are SLIP-encoded and separated by slip.delimiter.  A
        packet is left on the queue (and processing stops) if adding it
        to a nonempty batch would take the batch beyond max_batch; the
        first packet of an empty batch is added regardless of its size.
        """
        self._log(DBG.QUEUE, 'process...')
        while True:
            if not self._pq:
                self._log(DBG.QUEUE, 'process... empty')
                return
            packet = self._pq[0][1]

            self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)

            encoded = slip.encode(packet)
            sofar = sizequery()

            self._log(DBG.QUEUE_CTRL,
                      'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
                      d=encoded)

            if sofar > 0:
                needed = sofar + len(slip.delimiter) + len(encoded)
                if needed > max_batch:
                    self._log(DBG.QUEUE_CTRL, 'process... overflow')
                    return
                moredata(slip.delimiter)

            moredata(encoded)
            self._pq.popleft()
270
271 #---------- error handling ----------
272
# Set by crash(); common_run() does not enter the reactor once this is
# true.
_crashing = False

def crash(err):
    """Report err on stderr, note that we are crashing, stop the reactor."""
    global _crashing
    _crashing = True
    print('CRASH ', err, file=sys.stderr)
    try: reactor.stop()
    # the reactor may not be running; _crashing covers that case
    except twisted.internet.error.ReactorNotRunning: pass
281
def crash_on_defer(defer):
    """Arrange for crash() to be called with any error reaching defer."""
    def _fatal(err):
        return crash(err)
    defer.addErrback(_fatal)
284
def crash_on_critical(event):
    """Log observer: crash() on any event of level critical or above.

    event is a log event dictionary.  Events which carry no
    `log_level' are ignored; comparing a missing level (None) with a
    LogLevel would raise TypeError inside the observer.
    """
    level = event.get('log_level')
    if level is not None and level >= LogLevel.critical:
        crash(twisted.logger.formatEvent(event))
288
289 #---------- config processing ----------
290
def process_cfg_common_always():
    """Process the settings needed by every program: currently c.mtu.

    c.mtu is the [virtual] mtu setting, kept as the string read from
    the configuration.
    """
    # The value lives on c; no module-level `mtu' is ever assigned, so
    # the former `global mtu' declaration was dropped.
    c.mtu = cfg.get('virtual','mtu')
294
def process_cfg_ipif(section, varmap):
    """Compute c.ipif_command from the `ipif' setting of section.

    varmap is an iterable of (destination, source) attribute name
    pairs.  Each source attribute which exists on c is copied to the
    destination attribute; missing sources are skipped.  The `ipif'
    setting is then read with the attributes of c available for
    interpolation.
    """
    for dest, src in varmap:
        if hasattr(c, src):
            setattr(c, dest, getattr(c, src))

    c.ipif_command = cfg.get(section, 'ipif', vars=c.__dict__)
304
def process_cfg_network():
    """Set c.network from the [virtual] network setting.

    Raises ValueError if the network has fewer than 3 + 2 addresses.
    """
    network = ipnetwork(cfg.get('virtual','network'))
    c.network = network
    if network.num_addresses < 3 + 2:
        raise ValueError('network needs at least 2^3 addresses')
309
def process_cfg_server():
    """Set c.server from [virtual] server, or default it from the network.

    When the setting is absent, the network is processed and the first
    host in it is used.
    NOTE(review): a configured value is stored as the string read from
    the configuration, whereas the default is an ipaddress object, and
    c.network is only set on the default path -- confirm this is
    intended.
    """
    try:
        configured = cfg.get('virtual','server')
    except NoOptionError:
        process_cfg_network()
        c.server = next(c.network.hosts())
    else:
        c.server = configured
316
class ServerAddr:
    """One address (IPv4 or IPv6) and port on which the server listens.

    port     -- TCP port number
    addrspec -- address as a string; tried as IPv4 first, then as IPv6

    Attributes: self.port, and self.addr (an ipaddress.IPv4Address or
    ipaddress.IPv6Address).
    """

    def __init__(self, port, addrspec):
        self.port = port
        # also self.addr
        try:
            self.addr = ipaddress.IPv4Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
            self._inurl = b'%s'
        except AddressValueError:
            self.addr = ipaddress.IPv6Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
            # IPv6 literals are bracketed inside URLs
            self._inurl = b'[%s]'

    def make_endpoint(self):
        """Return a twisted server endpoint bound to this address and port."""
        # The third positional parameter of the TCP server endpoints is
        # the listen backlog, so the address has to be given by keyword,
        # and as a string rather than an ipaddress object.
        return self._endpointfactory(reactor, self.port,
                                     interface=str(self.addr))

    def url(self):
        """Return the http URL (bytes) for this address and port.

        The port is left out of the URL when it is 80.
        """
        host = self._inurl % str(self.addr).encode('ascii')
        url = b'http://' + host
        if self.port != 80:
            url += b':%d' % self.port
        url += b'/'
        return url
336
def process_cfg_saddrs():
    """Set c.saddrs to a ServerAddr for each address in [server] addrs.

    Every address uses the [server] port setting, or 80 if there is
    none.
    """
    try:
        port = cfg.getint('server','port')
    except NoOptionError:
        port = 80

    listeners = c.saddrs = [ ]
    for addrspec in cfg.get('server','addrs').split():
        listeners.append(ServerAddr(port, addrspec))
345
def process_cfg_clients(constructor):
    """Call constructor(ci, cs, pw) for every client config section.

    A section counts as a client section if its name contains `:' or
    `.'.  ci is the section name parsed as an IP address, cs the
    section name itself, and pw the section's password encoded as
    UTF-8.  c.clients is reset to an empty list beforehand.
    """
    c.clients = [ ]
    for cs in cfg.sections():
        if ':' not in cs and '.' not in cs:
            continue
        ci = ipaddr(cs)
        pw = cfg.get(cs, 'password').encode('utf-8')
        constructor(ci,cs,pw)
354
355 #---------- startup ----------
356
def common_startup():
    """Set up logging, parse the command line and load the configuration.

    Log events go to stderr in classic log text format, and also to
    crash_on_critical.  The built-in defaults (defcfg, with everything
    from each `#' to end of line removed) are loaded first, then the
    file named by -c/--config.
    """
    observers = [
        twisted.logger.FileLogObserver(
            sys.stderr, twisted.logger.formatEventAsClassicLogText),
        crash_on_critical,
    ]
    twisted.logger.globalLogBeginner.beginLoggingTo(observers)

    optparser.add_option('-c', '--config', dest='configfile',
                         default='/etc/hippotat/config')
    opts, args = optparser.parse_args()
    if len(args):
        optparser.error('no non-option arguments please')

    defaults = regexp.sub('#.*', '', defcfg)
    cfg.read_string(defaults)
    cfg.read(opts.configfile)
372
def common_run():
    """Run the reactor, unless crash() has already been called."""
    log_debug(DBG.INIT, 'entering reactor')
    # crash() before the reactor starts cannot stop it, so it only sets
    # _crashing; honour that here.
    if not _crashing: reactor.run()
    # reached whenever the reactor returns or was never entered
    print('CRASHED (end)', file=sys.stderr)