fixes
[hippotat] / hippotat / __init__.py
1 # -*- python -*-
2
import signal
signal.signal(signal.SIGINT, signal.SIG_DFL)

import sys
import time

import twisted
from twisted.internet import reactor
from twisted.logger import LogLevel
import twisted.internet.endpoints

import ipaddress
from ipaddress import AddressValueError

from optparse import OptionParser
from configparser import ConfigParser
from configparser import NoOptionError

import collections

import re as regexp

import hippotat.slip as slip
25
# Built-in default configuration.  common_startup() deletes everything from
# '#' to the end of each line (regexp '#.*') before handing this text to
# ConfigParser, so '#' starts a comment anywhere on a line -- and therefore
# no value below may contain a '#' character.
defcfg = '''
[DEFAULT]
#[<client>] overrides
max_batch_down = 65536 # used by server, subject to [limits]
max_queue_time = 10 # used by server, subject to [limits]
max_request_time = 54 # used by server, subject to [limits]
target_requests_outstanding = 3 # must match; subject to [limits] on server
max_requests_outstanding = 4 # used by client
max_batch_up = 4000 # used by client
http_timeout = 30 # used by client

#[server] or [<client>] overrides
ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
# extra interpolations: %(local)s %(peer)s %(rnets)s
# obtained on server from [virtual]server [virtual]relay [virtual]network
# obtained on client from <client> [virtual]server [virtual]routes

[virtual]
mtu = 1500
routes = ''
# network = <prefix>/<len> # mandatory for server
# server = <ipaddr> # used by both, default is computed from `network'
# relay = <ipaddr> # used by server, default from `network' and `server'
# default server is first host in network
# default relay is first host which is not server

[server]
# addrs = 127.0.0.1 ::1 # mandatory for server
port = 80 # used by server
# url # used by client; default from first `addrs' and `port'

# [<client-ipv4-or-ipv6-address>]
# password = <password> # used by both, must match

[limits]
max_batch_down = 262144 # used by server
max_queue_time = 121 # used by server
max_request_time = 121 # used by server
target_requests_outstanding = 10 # used by server
'''
66
# Module-level singletons.  They are created here, at import time, so that
# modules doing `import *` from this one bind these very same objects.
cfg, optparser = ConfigParser(), OptionParser()
70
# Translation table mapping b'-' to slip.esc and slip.esc to b'-'
# (assumes slip.esc is a single byte -- TODO confirm in hippotat.slip).
_mimetrans = bytes.maketrans(b'-' + slip.esc, slip.esc + b'-')


def mime_translate(s):
    """Return a copy of s with the bytes b'-' and slip.esc exchanged.

    SLIP-encoded packets never contain ESC ESC, so once ESC and '-' have
    been swapped the translated data cannot contain b'--'.
    """
    swapped = s.translate(_mimetrans)
    return swapped
76
class ConfigResults:
    """Attribute-style namespace holding processed configuration values.

    The instance's attribute dictionary *is* the dict given to the
    constructor (it is not copied), so attributes set on the instance
    appear in that dict and vice versa.  process_cfg_ipif relies on this
    when it passes c.__dict__ to ConfigParser as interpolation vars.
    """

    def __init__(self, d=None):
        # Use a fresh dict per instance.  A `d = { }` default would be a
        # single dict shared by every instance created without an argument.
        if d is None:
            d = {}
        self.__dict__ = d

    def __repr__(self):
        return 'ConfigResults(' + repr(self.__dict__) + ')'

# Global holder for processed configuration values.
c = ConfigResults()
84
def log_discard(packet, saddr, daddr, why):
    """Report a dropped packet on stdout.

    Prints 'DROP ' followed by the source address, destination address and
    reason.  The packet contents are not printed.
    """
    # TODO: send this to syslog at LOG_DEBUG instead, formatted as
    #   'discarded packet %s -> %s (%s)' % (saddr, daddr, why)
    fields = (saddr, daddr, why)
    print('DROP ', *fields)
89
90 #---------- packet parsing ----------
91
def packet_addrs(packet):
    """Return (source, destination) addresses of a raw IP packet.

    The IP version is read from the high nibble of the first byte:
      * IPv4: 4-byte addresses at byte offsets 12 and 16;
      * IPv6: 16-byte addresses at byte offsets 8 and 24.
    Raises ValueError for any other version.
    """
    version = packet[0] >> 4
    # version -> (address class, offset of source address, address length)
    layouts = {
        4: (ipaddress.IPv4Address, 3*4, 4),
        6: (ipaddress.IPv6Address, 2*4, 16),
    }
    layout = layouts.get(version)
    if layout is None:
        raise ValueError('unsupported IP version %d' % version)
    make_addr, src_at, width = layout
    dst_at = src_at + width
    return (make_addr(packet[src_at:dst_at]),
            make_addr(packet[dst_at:dst_at + width]))
107
108 #---------- address handling ----------
109
def ipaddr(input):
    """Parse input as an IPv4 address, falling back to IPv6.

    Returns an ipaddress.IPv4Address or ipaddress.IPv6Address.  If input is
    neither, the AddressValueError from the IPv6 attempt propagates.
    """
    try:
        return ipaddress.IPv4Address(input)
    except AddressValueError:
        return ipaddress.IPv6Address(input)
116
def ipnetwork(input):
    """Parse input as an IPv4 network, falling back to IPv6.

    Returns an ipaddress.IPv4Network or ipaddress.IPv6Network (strict
    parsing, so host bits must be zero).  Raises ValueError (or one of its
    subclasses AddressValueError / NetmaskValueError) if input is not a
    valid network of either family.
    """
    try:
        r = ipaddress.IPv4Network(input)
    except AddressValueError:
        # The address part is not IPv4; try IPv6 (same approach as ipaddr()).
        # Other IPv4 errors (bad prefix length, host bits set) describe the
        # real problem and are allowed to propagate unchanged.
        r = ipaddress.IPv6Network(input)
    return r
123
124 #---------- ipif (SLIP) subprocess ----------
125
class SlipStreamDecoder():
    """Split a SLIP byte stream, delivered in arbitrary chunks, into packets.

    The callback on_packet(packet) is invoked once for every non-empty
    packet produced.
    """

    def __init__(self, on_packet):
        # Bytes not yet delivered as a packet, carried between chunks.
        self._buffer = b''
        self._on_packet = on_packet

    def inputdata(self, data):
        """Append a chunk of stream data and deliver any packets it completes."""
        self._buffer += data
        pieces = slip.decode(self._buffer)
        # The last element returned by slip.decode is kept back as the
        # still-undelivered remainder; all earlier elements are delivered.
        self._buffer = pieces.pop()
        for piece in pieces:
            self._maybe_packet(piece)

    def _maybe_packet(self, packet):
        """Pass packet to the callback unless it is empty."""
        if len(packet) == 0:
            return
        self._on_packet(packet)

    def flush(self):
        """Deliver whatever remains buffered (if non-empty) and clear it."""
        self._maybe_packet(self._buffer)
        self._buffer = b''
147
class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
    """Twisted process protocol attached to the ipif (SLIP) child process.

    Data read from the child's stdout is SLIP-decoded into IP packets.  Each
    packet is passed to router(packet, saddr, daddr) unless its source or
    destination address is link-local, in which case it is logged and
    dropped.
    """

    def __init__(self, router):
        self._router = router
        self._decoder = SlipStreamDecoder(self.slip_on_packet)

    def connectionMade(self):
        pass

    def outReceived(self, data):
        self._decoder.inputdata(data)

    def slip_on_packet(self, packet):
        saddr, daddr = packet_addrs(packet)
        if not (saddr.is_link_local or daddr.is_link_local):
            self._router(packet, saddr, daddr)
            return
        log_discard(packet, saddr, daddr, 'link-local')

    def processEnded(self, status):
        # status is presumably Twisted's Failure wrapping ProcessDone /
        # ProcessTerminated; re-raise it so the child's exit is not silent.
        status.raiseException()
163
def start_ipif(command, router):
    """Spawn `command` via /bin/sh -xc as the ipif child process.

    A new _IpifProcessProtocol(router) is stored in the module global
    `ipif` (queue_inbound writes through it).  The childFDs mapping gives
    the parent a writable pipe to the child's fd 0 and a readable pipe from
    its fd 1, and lets the child inherit fd 2.
    """
    global ipif
    ipif = _IpifProcessProtocol(router)
    argv = ['sh', '-xc', command]
    fd_map = {0: 'w', 1: 'r', 2: 2}
    reactor.spawnProcess(ipif, '/bin/sh', argv, childFDs=fd_map)
170
def queue_inbound(packet):
    """Write packet to the ipif process as SLIP, framed by a delimiter on each side."""
    write = ipif.transport.write
    write(slip.delimiter)
    write(slip.encode(packet))
    write(slip.delimiter)
175
176 #---------- packet queue ----------
177
class PacketQueue():
    """FIFO of packets awaiting transmission, with age-based expiry.

    Each packet is stored together with the time.monotonic() value at which
    it was queued.  nonempty() discards packets at the front of the queue
    that are older than max_queue_time seconds.
    """

    def __init__(self, max_queue_time):
        # Maximum age, in seconds of time.monotonic(), of a sendable packet.
        self._max_queue_time = max_queue_time
        self._pq = collections.deque() # (queuetime, packet) pairs

    def append(self, packet):
        """Queue packet, stamped with the current monotonic time."""
        self._pq.append((time.monotonic(), packet))

    def nonempty(self):
        """Return True if a packet that has not expired is queued.

        Expired packets at the front of the queue are discarded as a side
        effect.  Packets are appended in time order, so once the front
        packet is fresh enough, all the packets behind it are as well.
        """
        while self._pq:
            queuetime, _ = self._pq[0]
            if time.monotonic() - queuetime > self._max_queue_time:
                # strip old packets off the front
                self._pq.popleft()
                continue
            return True
        return False

    def process(self, sizequery, moredata, max_batch):
        """Move queued packets into a batch using the caller's callbacks.

        sizequery() must return the size of the batch built so far, and
        moredata(s) must append the bytes s to it.  Packets are SLIP-encoded
        and separated by slip.delimiter.  Stops when the queue is empty or
        when adding the next packet would make the batch exceed max_batch.
        A packet going into an empty batch (sizequery() == 0) is always
        added, even if it alone exceeds max_batch.  Expired packets are not
        discarded here; call nonempty() for that.
        """
        while self._pq:
            _, packet = self._pq[0]

            encoded = slip.encode(packet)
            sofar = sizequery()

            if sofar > 0:
                if sofar + len(slip.delimiter) + len(encoded) > max_batch:
                    break
                moredata(slip.delimiter)

            moredata(encoded)
            self._pq.popleft()
216
217 #---------- error handling ----------
218
def crash(err):
    """Report err on stderr, then ask the reactor to stop.

    ReactorNotRunning from reactor.stop() is ignored: the reactor is
    already stopped, or stopping, so there is nothing more to do.
    """
    print('CRASH ', err, file=sys.stderr)
    try:
        reactor.stop()
    except twisted.internet.error.ReactorNotRunning:
        pass
223
def crash_on_defer(defer):
    """Make any failure of the Deferred `defer` call crash()."""
    def _on_failure(failure):
        # `crash` is looked up when the errback fires, not when it is added.
        return crash(failure)
    defer.addErrback(_on_failure)
226
def crash_on_critical(event):
    """twisted.logger observer: crash() on any event logged at critical level.

    Events carrying no 'log_level' key are ignored.  Comparing None with a
    LogLevel would raise TypeError inside the log observer.
    """
    level = event.get('log_level')
    if level is None:
        return
    if level >= LogLevel.critical:
        crash(twisted.logger.formatEvent(event))
230
231 #---------- config processing ----------
232
def process_cfg_common_always():
    """Store the [virtual] mtu option on c.mtu, as the raw config string."""
    # The value lives on the shared ConfigResults object `c`; no module
    # global is assigned, so no `global` declaration is needed.
    c.mtu = cfg.get('virtual','mtu')
236
def process_cfg_ipif(section, varmap):
    """Set c.ipif_command from the `ipif` option of `section`.

    varmap is an iterable of (dest, src) name pairs.  For each pair where c
    has an attribute `src`, its value is copied to c.<dest>.  The ipif
    option is then read with all of c's attributes supplied as
    interpolation variables, so the copied names are usable as %(dest)s.
    """
    _missing = object()
    for dest, src in varmap:
        value = getattr(c, src, _missing)
        if value is _missing:
            continue
        setattr(c, dest, value)

    # NOTE(review): prints the whole configuration to stdout -- looks like
    # leftover debugging output; confirm whether it is wanted.
    print(repr(c))

    c.ipif_command = cfg.get(section, 'ipif', vars=c.__dict__)
246
def process_cfg_network():
    """Parse [virtual] network into c.network and check that it is large enough.

    Raises ValueError if the network has fewer than 2^3 addresses.
    """
    c.network = ipnetwork(cfg.get('virtual', 'network'))
    # num_addresses is always a power of two, so "fewer than 2**3" rejects
    # exactly the networks with fewer than 5 addresses: the smallest
    # acceptable network has 8 (an IPv4 /29).
    if c.network.num_addresses < 2**3:
        raise ValueError('network needs at least 2^3 addresses')
251
def process_cfg_server():
    """Set c.server to the server's virtual address.

    Uses [virtual] server if it is set.  Otherwise it processes
    [virtual] network (setting c.network) and uses that network's first
    host address.  In both cases c.server is an ipaddress address object.

    Raises AddressValueError if an explicit [virtual] server is not a valid
    IPv4 or IPv6 address.
    """
    try:
        # Parse the explicit setting so it has the same type as the computed
        # default below (previously it was left as a plain string).
        c.server = ipaddr(cfg.get('virtual','server'))
    except NoOptionError:
        process_cfg_network()
        c.server = next(c.network.hosts())
258
class ServerAddr():
    """One address and port for the HTTP server to listen on.

    Attributes:
      port -- TCP port number.
      addr -- ipaddress.IPv4Address or IPv6Address parsed from addrspec.
    """

    def __init__(self, port, addrspec):
        self.port = port
        try:
            self.addr = ipaddress.IPv4Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
            self._inurl = '%s'
        except AddressValueError:
            self.addr = ipaddress.IPv6Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
            # IPv6 literals must be bracketed inside URLs.
            self._inurl = '[%s]'

    def make_endpoint(self):
        """Return a Twisted server endpoint listening on addr:port."""
        # The third positional parameter of TCP4ServerEndpoint and
        # TCP6ServerEndpoint is `backlog`; the address to bind must be given
        # as the `interface` keyword, as a string.
        return self._endpointfactory(reactor, self.port,
                                     interface=str(self.addr))

    def url(self):
        """Return the http:// URL for this address, omitting the port if it is 80."""
        url = 'http://' + (self._inurl % self.addr)
        if self.port != 80:
            url += ':%d' % self.port
        url += '/'
        return url
278
def process_cfg_saddrs():
    """Build c.saddrs: one ServerAddr for each whitespace-separated [server] addrs entry.

    All entries share the port from [server] port, or 80 if that option
    is absent.
    """
    try:
        listen_port = cfg.getint('server', 'port')
    except NoOptionError:
        listen_port = 80

    c.saddrs = []
    for spec in cfg.get('server', 'addrs').split():
        c.saddrs.append(ServerAddr(listen_port, spec))
287
def process_cfg_clients(constructor):
    """Call constructor(addr, section, password) for each client section.

    A config section counts as a client if its name contains ':' or '.'
    (presumably an IPv6 or IPv4 address).  The name is parsed with ipaddr()
    and its `password` option is read.  c.clients is reset to an empty
    list first.
    """
    # NOTE(review): nothing here appends to c.clients; presumably the
    # constructor does -- confirm with the callers.
    c.clients = []
    for section in cfg.sections():
        looks_like_addr = ':' in section or '.' in section
        if not looks_like_addr:
            continue
        constructor(ipaddr(section), section, cfg.get(section, 'password'))
295
296 #---------- startup ----------
297
def common_startup():
    """Startup steps shared by client and server.

    Installs crash_on_critical as a global Twisted log observer and parses
    the command line (only -c/--config, default /etc/hippotat/config; no
    positional arguments).  It then loads the built-in defcfg, with
    everything from '#' to end of line removed, followed by the config file.
    """
    twisted.logger.globalLogPublisher.addObserver(crash_on_critical)

    optparser.add_option('-c', '--config', dest='configfile',
                         default='/etc/hippotat/config')
    opts, positional = optparser.parse_args()
    if positional:
        optparser.error('no non-option arguments please')

    comment_re = regexp.compile('#.*')
    cfg.read_string(comment_re.sub('', defcfg))
    # NOTE(review): ConfigParser.read skips files it cannot open, so a
    # missing config file is silently ignored here.
    cfg.read(opts.configfile)
309
def common_run():
    """Run the Twisted reactor.  Its return is reported on stderr as a crash."""
    reactor.run()
    # reactor.run() returns only once the reactor has stopped (crash()
    # stops it), so reaching this point means the program is ending.
    print('CRASHED (end)', file=sys.stderr)