wip, new log stuff
[hippotat] / hippotat / __init__.py
1 # -*- python -*-
2
import signal
# Reset SIGINT to the operating system's default action (SIG_DFL) instead of
# Python's default handler, which raises KeyboardInterrupt.
# NOTE(review): presumably so that ^C terminates the process immediately even
# while the twisted reactor is running -- confirm intent.
signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 import twisted.internet.endpoints
11 import twisted.logger
12 from twisted.logger import LogLevel
13 import twisted.python.constants
14 from twisted.python.constants import NamedConstant
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 import collections
24 import time
25 import codecs
26
27 import re as regexp
28
29 import hippotat.slip as slip
30
class DBG(twisted.python.constants.Names):
    """Debug categories passed as the `dflag` argument of log_debug.

    log_debug places the flag at the start of every message it logs.
    NOTE(review): the meaning of each category is inferred from its name
    only -- confirm against the callers that use it.
    """
    ROUTE = NamedConstant()
    FLOW = NamedConstant()
    HTTP = NamedConstant()
    HTTP_CTRL = NamedConstant()
    INIT = NamedConstant()        # e.g. the 'ready' message in common_run
    QUEUE = NamedConstant()       # PacketQueue append/nonempty/process events
    QUEUE_CTRL = NamedConstant()  # PacketQueue.process per-packet detail
39
# Encoder turning bytes into their hex representation; returns (bytes, len).
_hexcodec = codecs.getencoder('hex_codec')

# Module-wide twisted logger used by log_debug.
log = twisted.logger.Logger()

def log_debug(dflag, msg, idof=None, d=None):
    """Emit a debug message through the module logger (at info level).

    dflag -- a DBG constant naming the debug category; it prefixes the
             logged text.
    msg   -- the message text (str).
    idof  -- optional object; if given, its id() is prepended as "[id]" so
             messages concerning the same object can be correlated.
    d     -- optional bytes-like payload; at most its first 64 bytes are
             appended in hex.
    """
    if idof is not None:
        msg = '[%d] %s' % (id(idof), msg)
    if d is not None:
        # Only the first 64 bytes are shown, to keep log lines short.
        # The codec yields bytes, which must be decoded before being joined
        # to the str message (previously this also called the misspelled,
        # undefined name `_hex_codec`).
        msg += ' ' + _hexcodec(d[0:64])[0].decode('ascii')
    log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
52
# Built-in default configuration.  common_startup parses this with
# cfg.read_string() before reading the user's config file, after stripping
# everything from '#' to end of line, so the '#' notes inside are
# documentation only and never reach ConfigParser.
# NOTE(review): the "extra interpolations" note mentions %(rnet)s but the
# ipif template uses %(rnets)s -- one of the two is probably a typo.
# NOTE(review): ConfigParser does not unquote values, so `routes = ''`
# yields the two-character string "''", not an empty string -- confirm this
# is what consumers of [virtual] routes expect.
defcfg = '''
[DEFAULT]
#[<client>] overrides
max_batch_down = 65536 # used by server, subject to [limits]
max_queue_time = 10 # used by server, subject to [limits]
max_request_time = 54 # used by server, subject to [limits]
target_requests_outstanding = 3 # must match; subject to [limits] on server
max_requests_outstanding = 4 # used by client
max_batch_up = 4000 # used by client
http_timeout = 30 # used by client
http_retry = 5 # used by client

#[server] or [<client>] overrides
ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
# extra interpolations: %(local)s %(peer)s %(rnet)s
# obtained on server [virtual]server [virtual]relay [virtual]network
# from on client <client> [virtual]server [virtual]routes

[virtual]
mtu = 1500
routes = ''
# network = <prefix>/<len> # mandatory for server
# server = <ipaddr> # used by both, default is computed from `network'
# relay = <ipaddr> # used by server, default from `network' and `server'
# default server is first host in network
# default relay is first host which is not server

[server]
# addrs = 127.0.0.1 ::1 # mandatory for server
port = 80 # used by server
# url # used by client; default from first `addrs' and `port'

# [<client-ip4-or-ipv6-address>]
# password = <password> # used by both, must match

[limits]
max_batch_down = 262144 # used by server
max_queue_time = 121 # used by server
max_request_time = 121 # used by server
target_requests_outstanding = 10 # used by server
'''
94
# Module-level singletons: created eagerly (not on first use) so that
# `import *` from this package hands them to the importing module.
cfg, optparser = ConfigParser(), OptionParser()
98
# Byte translation table mapping b'-' -> slip.esc and slip.esc -> b'-'.
_mimetrans = bytes.maketrans(b'-' + slip.esc, slip.esc + b'-')

def mime_translate(s):
    """Return a copy of s with b'-' and the SLIP ESC byte exchanged.

    SLIP-encoded data never contains ESC ESC, so after the swap the result
    never contains b'--'.  Assuming slip.esc is a single byte, the mapping
    is its own inverse.
    """
    swapped = s.translate(_mimetrans)
    return swapped
104
class ConfigResults:
    """Plain attribute bag holding processed configuration values.

    The instance's attribute namespace is the dict passed in (not a copy):
    attributes set on the instance show up in that dict and vice versa.
    With no argument, each instance gets its own fresh, empty namespace.
    """
    def __init__(self, d=None):
        # The former `d = { }` default was one dict shared by every
        # default-constructed instance, so they all silently shared state.
        self.__dict__ = {} if d is None else d

    def __repr__(self):
        return 'ConfigResults(' + repr(self.__dict__) + ')'

# Module-wide configuration results, filled in by the process_cfg_* functions.
c = ConfigResults()
112
def log_discard(packet, saddr, daddr, why):
    """Report, on stdout, that a packet is being dropped.

    packet -- the dropped packet (not currently included in the report).
    saddr, daddr -- its source and destination addresses.
    why -- short human-readable reason, e.g. 'link-local'.
    """
    # Previously called an undefined `Print` with undefined `Saddr`/`Daddr`,
    # so every discard raised NameError instead of being reported.
    print('Drop ', saddr, daddr, why)
118
119 #---------- packet parsing ----------
120
def packet_addrs(packet):
    """Extract (source, destination) addresses from a raw IP packet.

    The IP version is read from the high nibble of the first byte.  IPv4
    addresses are 4 bytes starting at offset 12; IPv6 addresses are 16
    bytes starting at offset 8.  Destination immediately follows source.
    Raises ValueError for any other IP version.
    """
    ipver = packet[0] >> 4
    # version -> (address constructor, address width, source offset)
    layouts = {
        4: (ipaddress.IPv4Address, 4, 3 * 4),
        6: (ipaddress.IPv6Address, 16, 2 * 4),
    }
    layout = layouts.get(ipver)
    if layout is None:
        raise ValueError('unsupported IP version %d' % ipver)
    make_addr, width, start = layout
    src = make_addr(packet[start:start + width])
    dst = make_addr(packet[start + width:start + 2 * width])
    return (src, dst)
136
137 #---------- address handling ----------
138
def ipaddr(input):
    """Parse input as an IPv4 address, falling back to IPv6.

    Returns an ipaddress.IPv4Address or IPv6Address.  If input is neither,
    the AddressValueError from the IPv6 attempt propagates.
    """
    try:
        return ipaddress.IPv4Address(input)
    except AddressValueError:
        return ipaddress.IPv6Address(input)

def ipnetwork(input):
    """Parse input as an IPv4 network, falling back to IPv6.

    Returns an ipaddress.IPv4Network or IPv6Network.  If input is neither,
    the error from the IPv6 attempt propagates.  Other IPv4 errors (bad
    prefix length, host bits set) propagate directly.
    """
    try:
        return ipaddress.IPv4Network(input)
    except AddressValueError:
        # The handler used to name `NetworkValueError`, which does not exist,
        # so every non-IPv4 input raised NameError instead of trying IPv6.
        return ipaddress.IPv6Network(input)
152
153 #---------- ipif (SLIP) subprocess ----------
154
class SlipStreamDecoder():
    """Reassemble SLIP-framed packets from an arbitrarily chunked byte stream.

    Feed received bytes to inputdata(); every complete, non-empty packet is
    passed to the on_packet callback given at construction.  The trailing
    incomplete fragment is kept in _buffer until more data arrives or
    flush() is called.
    """

    def __init__(self, on_packet):
        # on_packet(<packet>) is invoked once per complete non-empty packet.
        self._on_packet = on_packet
        self._buffer = b''

    def inputdata(self, data):
        """Append data to the pending buffer and dispatch complete packets."""
        self._buffer += data
        # NOTE(review): assumes slip.decode returns a list whose last element
        # is the incomplete trailing fragment -- confirm in hippotat.slip.
        pieces = slip.decode(self._buffer)
        self._buffer = pieces.pop()
        for piece in pieces:
            self._maybe_packet(piece)

    def _maybe_packet(self, packet):
        """Forward packet to the callback unless it is empty."""
        if not len(packet):
            return
        self._on_packet(packet)

    def flush(self):
        """Dispatch whatever is buffered as a final packet, then clear it."""
        leftover = self._buffer
        self._maybe_packet(leftover)
        self._buffer = b''
176
class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
    """Process protocol for the ipif subprocess.

    Bytes the child writes to its stdout are SLIP-decoded.  Each resulting
    packet is handed to router(packet, saddr, daddr), except packets whose
    source or destination is link-local, which are discarded.
    """

    def __init__(self, router):
        self._router = router
        self._decoder = SlipStreamDecoder(self.slip_on_packet)

    def connectionMade(self):
        pass

    def outReceived(self, data):
        self._decoder.inputdata(data)

    def slip_on_packet(self, packet):
        saddr, daddr = packet_addrs(packet)
        link_local = saddr.is_link_local or daddr.is_link_local
        if link_local:
            log_discard(packet, saddr, daddr, 'link-local')
        else:
            self._router(packet, saddr, daddr)

    def processEnded(self, status):
        # status is presumably a twisted Failure describing how the child
        # exited; re-raise it so the process ending is treated as an error.
        status.raiseException()
192
def start_ipif(command, router):
    """Spawn `command` under `/bin/sh -xc`, attached to a new _IpifProcessProtocol.

    The protocol is stored in the module-global `ipif` so that queue_inbound
    can write to the child.  The child's fd 0 is opened for us to write,
    fd 1 for us to read, and fd 2 is inherited.
    """
    global ipif
    ipif = _IpifProcessProtocol(router)
    shell_argv = ['sh', '-xc', command]
    fd_map = {0: 'w', 1: 'r', 2: 2}
    reactor.spawnProcess(ipif, '/bin/sh', shell_argv,
                         childFDs=fd_map, env=None)

def queue_inbound(packet):
    """Write packet, SLIP-encoded, to the ipif child's stdin.

    A SLIP delimiter is written both before and after the encoded packet.
    Requires start_ipif() to have been called.
    """
    write = ipif.transport.write
    write(slip.delimiter)
    write(slip.encode(packet))
    write(slip.delimiter)
205
206 #---------- packet queue ----------
207
class PacketQueue():
    """FIFO of packets awaiting transmission, with age-based expiry.

    Each entry records the time.monotonic() value at which it was queued.
    nonempty() drops entries older than max_queue_time seconds from the
    front of the queue.
    """

    def __init__(self, desc, max_queue_time):
        self._desc = desc                      # label prefixed to log messages
        self._max_queue_time = max_queue_time  # seconds (compared to monotonic)
        self._pq = collections.deque()         # (queue time, packet) pairs

    def _log(self, dflag, msg, **kwargs):
        log_debug(dflag, self._desc + ' pq: ' + msg, **kwargs)

    def append(self, packet):
        """Queue packet, stamped with the current monotonic time."""
        self._log(DBG.QUEUE, 'append', d=packet)
        self._pq.append((time.monotonic(), packet))

    def nonempty(self):
        """Return True iff an unexpired packet is queued.

        Expired packets at the head of the queue are discarded first.
        """
        self._log(DBG.QUEUE, 'nonempty ?')
        while True:
            try:
                (queuetime, packet) = self._pq[0]
            except IndexError:
                self._log(DBG.QUEUE, 'nonempty ? empty.')
                return False

            age = time.monotonic() - queuetime
            if age > self._max_queue_time:
                # strip old packets off the front
                self._log(DBG.QUEUE, 'dropping (old)', d=packet)
                self._pq.popleft()
                continue

            self._log(DBG.QUEUE, 'nonempty ? nonempty.')
            return True

    def process(self, sizequery, moredata, max_batch):
        """Move queued packets into a batch, SLIP-encoding each one.

        sizequery() -- returns the size of the batch built so far.
        moredata(s) -- appends bytes s to the batch.
        max_batch   -- size limit.  A packet that, together with the
                       separating delimiter, would push a non-empty batch
                       over the limit stays queued and processing stops.
                       The first packet of an empty batch is always taken,
                       even if on its own it exceeds max_batch.
        A packet is removed from the queue only after it has been added to
        the batch.  Unlike nonempty(), this does not check packet age.
        """
        self._log(DBG.QUEUE, 'process...')
        while True:
            try:
                (_queuetime, packet) = self._pq[0]
            except IndexError:
                self._log(DBG.QUEUE, 'process... empty')
                break

            self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)

            encoded = slip.encode(packet)
            sofar = sizequery()

            self._log(DBG.QUEUE_CTRL,
                      'process... (sofar=%d, max=%d) encoded'
                      % (sofar, max_batch),
                      d=encoded)

            if sofar > 0:
                # Packets within a batch are separated by a SLIP delimiter,
                # which counts towards the limit.
                if sofar + len(slip.delimiter) + len(encoded) > max_batch:
                    self._log(DBG.QUEUE_CTRL, 'process... overflow')
                    break
                moredata(slip.delimiter)

            moredata(encoded)
            self._pq.popleft()
266
267 #---------- error handling ----------
268
def crash(err):
    """Report a fatal error on stderr and stop the reactor if it is running."""
    print('CRASH ', err, file=sys.stderr)
    try:
        reactor.stop()
    except twisted.internet.error.ReactorNotRunning:
        pass

def crash_on_defer(defer):
    """Arrange for any failure of the Deferred `defer` to call crash()."""
    defer.addErrback(crash)

def crash_on_critical(event):
    """Log observer: call crash() for events at LogLevel.critical or above.

    Events with no 'log_level' key are ignored.  Comparing None with a
    LogLevel is not meaningful and may raise, and an exception raised
    inside a log observer would disrupt logging.
    """
    level = event.get('log_level')
    if level is not None and level >= LogLevel.critical:
        crash(twisted.logger.formatEvent(event))
280
281 #---------- config processing ----------
282
def process_cfg_common_always():
    """Record settings needed in every mode: [virtual] mtu -> c.mtu.

    c.mtu is kept as the raw config string.  It is available to the ipif
    command template as %(mtu)s, because process_cfg_ipif passes c's
    attributes as interpolation variables.
    """
    # (A stale `global mtu` declaration was removed: nothing assigned a
    # module-level `mtu`; the value lives in c.mtu.)
    c.mtu = cfg.get('virtual', 'mtu')
286
def process_cfg_ipif(section, varmap):
    """Compute c.ipif_command from the 'ipif' option of `section`.

    varmap -- iterable of (dest, src) attribute-name pairs.  Each existing
              c.<src> is copied to c.<dest>; missing ones are skipped.  This
              lets the ipif template refer to values under the names it
              uses (e.g. %(local)s, %(peer)s).
    The option is read with all of c's attributes supplied as extra
    interpolation variables.
    """
    for dest, src in varmap:
        try:
            value = getattr(c, src)
        except AttributeError:
            continue
        setattr(c, dest, value)

    # Dump the collected settings through the debug log instead of printing
    # them unconditionally to stdout.
    log_debug(DBG.INIT, 'config ' + repr(c))

    c.ipif_command = cfg.get(section, 'ipif', vars=c.__dict__)
296
def process_cfg_network():
    """Parse [virtual] network into c.network and check that it is big enough.

    c.network is set before the size check.  Raises ValueError if the
    network has fewer than 5 addresses.  Because network sizes are powers
    of two, that is the same as requiring at least 2^3, as the message says.
    """
    net = ipnetwork(cfg.get('virtual', 'network'))
    c.network = net
    min_addrs = 3 + 2
    if net.num_addresses < min_addrs:
        raise ValueError('network needs at least 2^3 addresses')
301
def process_cfg_server():
    """Set c.server, the server's virtual address, as an ipaddress object.

    Taken from [virtual] server if that option is set.  Otherwise it is the
    first host of [virtual] network, which is then also parsed into
    c.network.
    """
    try:
        server = cfg.get('virtual', 'server')
    except NoOptionError:
        process_cfg_network()
        c.server = next(c.network.hosts())
    else:
        # Parse the configured value so c.server has the same type in both
        # branches.  Previously it stayed a plain str here, which never
        # compares equal to the address objects produced from c.network.
        c.server = ipaddr(server)
308
class ServerAddr():
    """One address/port combination that the server listens on.

    Attributes:
      port -- TCP port number.
      addr -- ipaddress.IPv4Address or IPv6Address parsed from addrspec.
    """
    def __init__(self, port, addrspec):
        self.port = port
        try:
            self.addr = ipaddress.IPv4Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
            self._inurl = b'%s'
        except AddressValueError:
            self.addr = ipaddress.IPv6Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
            # IPv6 literals must be bracketed inside URLs.
            self._inurl = b'[%s]'

    def make_endpoint(self):
        """Return a twisted server endpoint listening on addr:port."""
        # TCP4/TCP6ServerEndpoint take (reactor, port, backlog, interface).
        # The address must be given as `interface` (a string).  Passing it
        # positionally made it the listen backlog, and the endpoint bound
        # to all interfaces.
        return self._endpointfactory(reactor, self.port,
                                     interface=str(self.addr))

    def url(self):
        """Return this address's base URL as bytes, e.g. b'http://[::1]:8080/'.

        The port is omitted when it is 80.
        """
        url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
        if self.port != 80:
            url += b':%d' % self.port
        url += b'/'
        return url
328
def process_cfg_saddrs():
    """Build c.saddrs: one ServerAddr per whitespace-separated [server] addrs entry.

    Every address uses [server] port, or 80 if that option is absent.
    """
    try:
        port = cfg.getint('server', 'port')
    except NoOptionError:
        port = 80

    c.saddrs = []
    c.saddrs.extend(ServerAddr(port, spec)
                    for spec in cfg.get('server', 'addrs').split())
337
def process_cfg_clients(constructor):
    """Call constructor(addr, section, password) for every client section.

    A section counts as a client if its name contains ':' or '.', i.e. it
    looks like an IPv6 or IPv4 address.  addr is the parsed address,
    section the raw section name, and password the section's 'password'
    option encoded as UTF-8 bytes.  c.clients is reset to an empty list
    first.  NOTE(review): presumably the constructor populates it -- confirm.
    """
    c.clients = []
    for section in cfg.sections():
        looks_like_addr = ':' in section or '.' in section
        if not looks_like_addr:
            continue
        client_ip = ipaddr(section)
        password = cfg.get(section, 'password').encode('utf-8')
        constructor(client_ip, section, password)
346
347 #---------- startup ----------
348
def common_startup():
    """Set up logging, parse the command line and load the configuration.

    Log events go to stderr (classic text format) and to crash_on_critical.
    The configuration is the built-in defcfg, with '#' comments removed,
    overlaid by the file named by -c/--config (default
    /etc/hippotat/config).  Non-option arguments are rejected.  An
    unreadable config file is reported via optparser.error.
    """
    log_formatter = twisted.logger.formatEventAsClassicLogText
    log_observer = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
    twisted.logger.globalLogBeginner.beginLoggingTo(
        [log_observer, crash_on_critical]
    )

    optparser.add_option('-c', '--config', dest='configfile',
                         default='/etc/hippotat/config')
    (opts, args) = optparser.parse_args()
    if args:
        optparser.error('no non-option arguments please')

    # defcfg carries trailing '#' notes.  cfg is a default ConfigParser,
    # which does not recognise inline comments, so strip them before parsing.
    comment_re = regexp.compile('#.*')
    cfg.read_string(comment_re.sub('', defcfg))

    # ConfigParser.read() silently skips files it cannot open, which would
    # leave us running on the built-in defaults alone; fail loudly instead.
    if not cfg.read(opts.configfile):
        optparser.error('cannot read config file %s' % opts.configfile)
364
def common_run():
    """Log readiness, then run the twisted reactor.

    reactor.run() blocks while the event loop runs.  Once it returns, the
    loop has stopped, and this is reported on stderr.
    """
    log_debug(DBG.INIT, 'ready')
    reactor.run()
    # Only reached after the reactor has been stopped (e.g. via crash()).
    print('CRASHED (end)', file=sys.stderr)