wip
[hippotat] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 from twisted.logger import LogLevel
11 import twisted.internet.endpoints
12
13 import ipaddress
14 from ipaddress import AddressValueError
15
16 from optparse import OptionParser
17 from configparser import ConfigParser
18 from configparser import NoOptionError
19
20 import collections
21 import time
22
23 import re as regexp
24
25 import hippotat.slip as slip
26
# Built-in default configuration, read before the user's config file.
# Everything from '#' to the end of a line is a comment and is removed
# (common_startup strips it with a regexp) before ConfigParser sees it,
# so the annotations below never become part of any value.
defcfg = '''
[DEFAULT]
#[<client>] overrides
max_batch_down = 65536 # used by server, subject to [limits]
max_queue_time = 10 # used by server, subject to [limits]
max_request_time = 54 # used by server, subject to [limits]
target_requests_outstanding = 3 # must match; subject to [limits] on server
max_requests_outstanding = 4 # used by client
max_batch_up = 4000 # used by client
http_timeout = 30 # used by client

#[server] or [<client>] overrides
ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
# extra interpolations: %(local)s %(peer)s %(rnets)s
# on server, obtained from [virtual]server [virtual]relay [virtual]network
# on client, obtained from <client> [virtual]server [virtual]routes

[virtual]
mtu = 1500
routes = ''
# network = <prefix>/<len> # mandatory for server
# server = <ipaddr> # used by both, default is computed from `network'
# relay = <ipaddr> # used by server, default from `network' and `server'
# default server is first host in network
# default relay is first host which is not server

[server]
# addrs = 127.0.0.1 ::1 # mandatory for server
port = 80 # used by server
# url # used by client; default from first `addrs' and `port'

# [<client-ip4-or-ipv6-address>]
# password = <password> # used by both, must match

[limits]
max_batch_down = 262144 # used by server
max_queue_time = 121 # used by server
max_request_time = 121 # used by server
target_requests_outstanding = 10 # used by server
'''
67
# Module-level singletons.  They are created here, at import time, so that
# every module doing `import *` from this one shares the same objects.
cfg, optparser = ConfigParser(), OptionParser()
71
# Translation table mapping b'-' to slip.esc and slip.esc to b'-'.
_mimetrans = bytes.maketrans(b''.join((b'-', slip.esc)),
                             b''.join((slip.esc, b'-')))

def mime_translate(s):
    """Exchange b'-' and the SLIP ESC byte throughout *s*.

    A SLIP-encoded packet never contains ESC ESC, so once the two bytes
    are exchanged the result never contains b'--'.
    """
    return s.translate(_mimetrans)
77
class ConfigResults:
    """Attribute bag holding processed configuration values.

    If *d* is given it becomes the instance's __dict__ itself (not a
    copy), so changes are visible through both.  Without *d*, each
    instance gets its own fresh, empty dict.
    """
    def __init__(self, d = None):
        # A literal `{ }` default would be evaluated once and then shared as
        # the __dict__ of every instance created without an argument.
        if d is None:
            d = { }
        self.__dict__ = d
    def __repr__(self):
        return 'ConfigResults('+repr(self.__dict__)+')'

# Module-wide configuration results; the process_cfg_* functions set
# attributes on it.
c = ConfigResults()
85
def log_discard(packet, saddr, daddr, why):
    """Report a dropped packet on stdout as `DROP  <saddr> <daddr> <why>`.

    *packet* itself is currently not printed.
    """
    # NOTE(review): a syslog LOG_DEBUG message of the form
    # 'discarded packet %s -> %s (%s)' was sketched here; stdout is used for now.
    fields = ('DROP ', saddr, daddr, why)
    print(*fields)
90
91 #---------- packet parsing ----------
92
def packet_addrs(packet):
    """Return (saddr, daddr) parsed from the header of a raw IP packet.

    packet: the packet as a bytes-like object (bytes, bytearray or
            memoryview), starting at the IP header.
    Returns a pair of ipaddress.IPv4Address or ipaddress.IPv6Address,
    according to the version nibble of the first byte.
    Raises ValueError if the packet is empty, is too short to hold both
    addresses, or has an IP version other than 4 or 6.
    """
    if not len(packet):
        raise ValueError('empty packet')
    version = packet[0] >> 4
    if version == 4:
        addrlen = 4
        saddroff = 3*4    # IPv4: source at byte 12, destination at 16
        factory = ipaddress.IPv4Address
    elif version == 6:
        addrlen = 16
        saddroff = 2*4    # IPv6: source at byte 8, destination at 24
        factory = ipaddress.IPv6Address
    else:
        raise ValueError('unsupported IP version %d' % version)
    daddroff = saddroff + addrlen
    end = daddroff + addrlen
    if len(packet) < end:
        raise ValueError('truncated IPv%d packet: %d bytes, need at least %d'
                         % (version, len(packet), end))
    # The ipaddress constructors only treat `bytes` as a packed address,
    # so convert slices of bytearray/memoryview input explicitly.
    saddr = factory(bytes(packet[ saddroff : daddroff ]))
    daddr = factory(bytes(packet[ daddroff : end ]))
    return (saddr, daddr)
108
109 #---------- address handling ----------
110
def ipaddr(input):
    """Parse *input* as an IPv4 address, falling back to IPv6.

    Returns an ipaddress.IPv4Address or ipaddress.IPv6Address; if *input*
    is neither, the AddressValueError from the IPv6 attempt propagates.
    """
    try:
        return ipaddress.IPv4Address(input)
    except AddressValueError:
        pass
    return ipaddress.IPv6Address(input)
117
def ipnetwork(input):
    """Parse *input* as an IPv4 network, falling back to IPv6.

    Returns an ipaddress.IPv4Network or ipaddress.IPv6Network.  Like
    ipaddr(), only an AddressValueError (the address part is not IPv4)
    triggers the IPv6 attempt.  Other IPv4 errors, such as a bad prefix
    length (NetmaskValueError) or host bits set (ValueError), propagate
    unchanged.
    """
    try:
        r = ipaddress.IPv4Network(input)
    except AddressValueError:
        # Not an IPv4 address, so try IPv6.
        r = ipaddress.IPv6Network(input)
    return r
124
125 #---------- ipif (SLIP) subprocess ----------
126
class SlipStreamDecoder:
    """Split a stream of SLIP-encoded bytes into packets.

    Feed bytes in with inputdata().  Each complete, non-empty packet is
    passed to the on_packet callback.  The trailing piece returned by
    slip.decode() (presumably an unterminated packet) is buffered until
    more data arrives or flush() is called.
    """

    def __init__(self, on_packet):
        # on_packet(<packet>) is called once per non-empty decoded packet
        self._on_packet = on_packet
        self._buffer = b''

    def inputdata(self, data):
        """Append *data* to the buffer and deliver every packet it completes."""
        self._buffer += data
        pieces = slip.decode(self._buffer)
        # NOTE(review): the last element of slip.decode()'s result is kept
        # as the incomplete remainder; confirm that contract in hippotat.slip.
        self._buffer = pieces.pop()
        for piece in pieces:
            self._maybe_packet(piece)

    def _maybe_packet(self, packet):
        # Empty pieces (presumably from adjacent delimiters) are not packets.
        if len(packet) == 0:
            return
        self._on_packet(packet)

    def flush(self):
        """Deliver whatever is buffered as a final packet, then empty the buffer."""
        self._maybe_packet(self._buffer)
        self._buffer = b''
148
class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
    """Twisted process protocol for the ipif (SLIP) subprocess.

    Bytes from the child's stdout are decoded as SLIP.  Each packet is
    handed to *router* as router(packet, saddr, daddr), except packets
    with a link-local source or destination, which are logged and dropped.
    """

    def __init__(self, router):
        self._router = router
        self._decoder = SlipStreamDecoder(self.slip_on_packet)

    def connectionMade(self):
        # nothing to do until the child produces output
        pass

    def outReceived(self, data):
        # the child's stdout carries the SLIP stream
        self._decoder.inputdata(data)

    def slip_on_packet(self, packet):
        saddr, daddr = packet_addrs(packet)
        if not (saddr.is_link_local or daddr.is_link_local):
            self._router(packet, saddr, daddr)
            return
        log_discard(packet, saddr, daddr, 'link-local')

    def processEnded(self, status):
        # Re-raise whatever ended the child (status is presumably a Twisted
        # Failure, which provides raiseException()).
        status.raiseException()
164
def start_ipif(command, router):
    """Spawn *command* under /bin/sh as the ipif subprocess.

    command: shell command line, run as `sh -xc command` (so sh traces it
             on stderr).
    router:  called as router(packet, saddr, daddr) for packets the
             subprocess emits (via _IpifProcessProtocol).
    The protocol instance is stored in the module global `ipif`.
    """
    global ipif
    argv = ['sh', '-xc', command]
    # fd 0: we write to the child's stdin; fd 1: we read its stdout;
    # fd 2: the child shares our stderr.  env=None passes our environment
    # through (per Twisted's spawnProcess documentation, on POSIX).
    fds = {0: 'w', 1: 'r', 2: 2}
    ipif = _IpifProcessProtocol(router)
    reactor.spawnProcess(ipif, '/bin/sh', argv, childFDs=fds, env=None)
172
def queue_inbound(packet):
    """Write *packet* to the ipif subprocess's stdin as a SLIP frame.

    A delimiter is written both before and after the encoded packet.
    Requires start_ipif() to have set the global `ipif` first.
    """
    out = ipif.transport
    out.write(slip.delimiter)
    out.write(slip.encode(packet))
    out.write(slip.delimiter)
177
178 #---------- packet queue ----------
179
class PacketQueue:
    """FIFO of packets that expire max_queue_time seconds after queueing.

    Entries are (time.monotonic() at enqueue, packet) pairs.
    """

    def __init__(self, max_queue_time):
        self._max_queue_time = max_queue_time
        self._pq = collections.deque()  # (enqueue time, packet) pairs

    def append(self, packet):
        """Queue *packet*, stamped with the current monotonic time."""
        self._pq.append((time.monotonic(), packet))

    def nonempty(self):
        """Return True if an unexpired packet is queued.

        Expired packets at the front are discarded as a side effect.
        """
        pq = self._pq
        while pq:
            enqueued, _ = pq[0]
            age = time.monotonic() - enqueued
            if age > self._max_queue_time:
                pq.popleft()  # too old: drop it and look at the next one
            else:
                return True
        return False

    def process(self, sizequery, moredata, max_batch):
        """Move queued packets into a batch while they fit in max_batch.

        sizequery() must return the size of the batch built so far, and
        moredata(s) must append bytes s to it.  Packets are SLIP-encoded
        and separated by slip.delimiter.  A packet is always added when
        the batch is still empty, even if it alone exceeds max_batch.
        Otherwise packing stops at the first packet that would overflow.
        Packet age is not checked here.
        """
        pq = self._pq
        while pq:
            _, packet = pq[0]
            encoded = slip.encode(packet)
            sofar = sizequery()
            if sofar > 0:
                if sofar + len(slip.delimiter) + len(encoded) > max_batch:
                    return
                moredata(slip.delimiter)
            moredata(encoded)
            pq.popleft()
218
219 #---------- error handling ----------
220
def crash(err):
    """Report a fatal error on stderr and stop the reactor.

    A reactor that is not running is left alone, and no error is raised.
    """
    print('CRASH ', err, file=sys.stderr)
    try:
        reactor.stop()
    except twisted.internet.error.ReactorNotRunning:
        return
225
def crash_on_defer(defer):
    """Arrange for any failure of the Deferred *defer* to call crash()."""
    def _on_failure(failure):
        # crash() is looked up at call time, exactly as a lambda would do
        return crash(failure)
    defer.addErrback(_on_failure)
228
def crash_on_critical(event):
    """Log observer: crash the program on any event at critical level or above.

    event: a twisted.logger event dict.  Events without a 'log_level'
    key (or with None) are ignored.  Comparing None with a LogLevel
    constant is not supported and would make the observer itself raise
    TypeError.
    """
    level = event.get('log_level')
    if level is not None and level >= LogLevel.critical:
        crash(twisted.logger.formatEvent(event))
232
233 #---------- config processing ----------
234
def process_cfg_common_always():
    """Record settings needed in every mode: c.mtu from [virtual] mtu.

    The value is kept as the string read from the configuration (it is
    used, e.g., for the %(mtu)s interpolation in the ipif command).
    """
    # Only c.mtu is set here, so no `global` declaration is needed.
    c.mtu = cfg.get('virtual','mtu')
238
def process_cfg_ipif(section, varmap):
    """Compute c.ipif_command from the 'ipif' option of *section*.

    varmap: iterable of (dest, src) attribute-name pairs.  For each pair
    where c already has attribute *src*, c.<dest> is set to its value.
    The ipif template is then interpolated with all of c's attributes
    supplied as ConfigParser vars.
    """
    absent = object()
    for dest, src in varmap:
        value = getattr(c, src, absent)
        if value is absent:
            continue
        setattr(c, dest, value)

    # NOTE(review): looks like leftover debug output on stdout
    print(repr(c))

    c.ipif_command = cfg.get(section, 'ipif', vars=c.__dict__)
248
def process_cfg_network():
    """Parse [virtual] network into c.network and check that it is large enough.

    Raises ValueError if the network has fewer than 3 + 2 addresses.
    num_addresses is always a power of two, so this is equivalent to
    requiring at least 2^3 addresses, as the error message says.
    """
    network = ipnetwork(cfg.get('virtual', 'network'))
    c.network = network
    # presumably 3 hosts (server, relay, a client) plus the network and
    # broadcast addresses; TODO confirm
    minimum = 3 + 2
    if network.num_addresses < minimum:
        raise ValueError('network needs at least 2^3 addresses')
253
def process_cfg_server():
    """Set c.server, the server's address on the virtual network.

    The address comes from [virtual] server if configured.  Otherwise it
    is the first host of [virtual] network, which is then also parsed
    into c.network.  Either way c.server is an ipaddress IPv4Address or
    IPv6Address.  An invalid configured address raises AddressValueError.
    """
    try:
        server = cfg.get('virtual','server')
    except NoOptionError:
        process_cfg_network()
        c.server = next(c.network.hosts())
    else:
        # Parse an explicitly configured address too, so that c.server has
        # the same type however it was obtained.
        c.server = ipaddr(server)
260
class ServerAddr:
    """One address/port pair the server listens on.

    port:     TCP port number (an int; compared with 80 and
              %d-formatted by url()).
    addrspec: an IPv4 or IPv6 address literal.

    Attributes: port, and addr (an ipaddress IPv4Address or IPv6Address).
    """

    def __init__(self, port, addrspec):
        self.port = port
        try:
            self.addr = ipaddress.IPv4Address(addrspec)
            self._inurl = b'%s'
        except AddressValueError:
            self.addr = ipaddress.IPv6Address(addrspec)
            # IPv6 literals must be bracketed inside URLs
            self._inurl = b'[%s]'

    def make_endpoint(self):
        """Return a Twisted TCP server endpoint for self.addr:self.port."""
        endpoints = twisted.internet.endpoints
        if self.addr.version == 4:
            factory = endpoints.TCP4ServerEndpoint
        else:
            factory = endpoints.TCP6ServerEndpoint
        # The third positional parameter of these endpoints is `backlog`,
        # so the listening address must be passed as `interface`, a str.
        return factory(reactor, self.port, interface=str(self.addr))

    def url(self):
        """Return this address's http:// URL as bytes, e.g. b'http://[::1]:8080/'.

        The port is omitted when it is 80.
        """
        url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
        if self.port != 80:
            url += b':%d' % self.port
        url += b'/'
        return url
280
def process_cfg_saddrs():
    """Build c.saddrs, one ServerAddr per address in [server] addrs.

    All addresses share [server] port, or 80 if that option is missing.
    """
    try:
        port = cfg.getint('server', 'port')
    except NoOptionError:
        port = 80

    c.saddrs = listeners = [ ]
    for addrspec in cfg.get('server', 'addrs').split():
        listeners.append(ServerAddr(port, addrspec))
289
def process_cfg_clients(constructor):
    """Call *constructor* for each client section of the configuration.

    A section whose name contains ':' or '.' is a client, named by its
    IPv4/IPv6 address.  For each one, constructor(address, section_name,
    password_bytes) is called, with the password UTF-8 encoded.
    c.clients is reset to an empty list first (presumably the constructor
    fills it in; not visible here).
    """
    c.clients = [ ]
    for section in cfg.sections():
        if ':' not in section and '.' not in section:
            continue
        client_addr = ipaddr(section)
        password = cfg.get(section, 'password').encode('utf-8')
        constructor(client_addr, section, password)
298
299 #---------- startup ----------
300
def common_startup():
    """Shared start-up: install the crash observer, parse argv, load config.

    The built-in defaults (defcfg with '#' comments removed) are read
    first and then overlaid by the file named by -c/--config (default
    /etc/hippotat/config).  Non-option arguments are rejected.
    """
    twisted.logger.globalLogPublisher.addObserver(crash_on_critical)

    optparser.add_option('-c', '--config', dest='configfile',
                         default='/etc/hippotat/config')
    opts, args = optparser.parse_args()
    if args:
        optparser.error('no non-option arguments please')

    # In defcfg, '#' starts a comment that runs to the end of the line.
    cfg.read_string(regexp.sub('#.*', '', defcfg))
    # NOTE(review): ConfigParser.read() silently skips files it cannot
    # open, so a mistyped --config path leaves only the defaults loaded.
    cfg.read(opts.configfile)
312
def common_run():
    """Run the Twisted reactor; once it stops, report that on stderr."""
    reactor.run()
    # Getting here means the reactor has stopped (e.g. crash() stopped it),
    # which this program treats as a crash.
    print('CRASHED (end)', file=sys.stderr)