wip
[hippotat] / server
1 #!/usr/bin/python3
2
# Restore the default SIGINT action so that ^C terminates the process
# immediately.  NOTE(review): this presumably also stops twisted from
# installing its own SIGINT handler (it seems to leave a non-default
# handler alone) -- confirm against the twisted version in use.
import signal
signal.signal(signal.SIGINT, signal.SIG_DFL)
5
import sys
import os
import collections
import hmac
import ipaddress
import syslog
import time
from ipaddress import AddressValueError
from optparse import OptionParser
from configparser import ConfigParser
from configparser import NoOptionError

import twisted
import twisted.internet
import twisted.internet.endpoints
import twisted.internet.error
import twisted.internet.protocol
from twisted.internet import reactor
import twisted.logger
from twisted.logger import LogLevel
import twisted.web.resource
import twisted.web.server
from twisted.web.server import NOT_DONE_YET

#import twisted.web.server import Site
#from twisted.web.resource import Resource
29
# Registry of configured clients: virtual IP address (an ipaddress
# address object) -> Client instance.  Filled in by process_cfg().
clients = dict()
31
def ipaddr(input):
    """Parse *input* as an IPv4 address, falling back to IPv6.

    Returns an ipaddress.IPv4Address or ipaddress.IPv6Address.  If *input*
    is neither, the AddressValueError from the IPv6 attempt propagates.
    """
    try:
        return ipaddress.IPv4Address(input)
    except AddressValueError:
        pass
    return ipaddress.IPv6Address(input)
38
def ipnetwork(input):
    """Parse *input* as an IPv4 network, falling back to IPv6.

    Mirrors ipaddr(): input that is not IPv4 address syntax is parsed as
    an IPv6 network instead, and the IPv6 parser's error (an
    AddressValueError or NetmaskValueError, both ValueError subclasses)
    propagates if that fails as well.  IPv4 netmask / host-bit errors are
    reported directly rather than being retried as IPv6.
    """
    try:
        r = ipaddress.IPv4Network(input)
    except AddressValueError:
        # ipaddress has no NetworkValueError; catching that name made every
        # IPv6 network fail with NameError instead of falling through here.
        r = ipaddress.IPv6Network(input)
    return r
45
# Built-in default configuration.  startup() loads this first and then
# reads the configuration file over it.
#   [DEFAULT]  default per-client max_* settings; client sections (named
#              by the client's IP address) inherit them.
#   [virtual]  network has no default and must be configured; host and
#              relay are optional, and process_cfg() picks them from the
#              network when absent.
#   [server]   ipif is the shell command start_ipif() runs as the SLIP
#              subprocess; process_cfg() interpolates host, relay, mtu and
#              network into it.  addrs/port are where start_http() listens.
#   [limits]   upper bounds for each client's max_* settings.
defcfg = '''
[DEFAULT]
max_batch_down = 65536
max_queue_time = 10
max_request_time = 54

[virtual]
mtu = 1500
# network
# [host]
# [relay]

[server]
ipif = userv root ipif %(host)s,%(relay)s,%(mtu)s,slip %(network)s
addrs = 127.0.0.1 ::1
port = 8099

[limits]
max_batch_down = 262144
max_queue_time = 121
max_request_time = 121
'''
68
69 #---------- error handling ----------
70
def crash(err):
    """Report a fatal error on stderr and stop the reactor.

    Calling this when the reactor is not running (for instance after an
    earlier crash() already stopped it) only prints the report.
    """
    message = ('CRASH ', err)
    print(*message, file=sys.stderr)
    try:
        reactor.stop()
    except twisted.internet.error.ReactorNotRunning:
        # already stopped; nothing further to do
        pass
75
def crash_on_defer(defer):
    """Arrange for a failure of the Deferred *defer* to crash() the server."""
    def _on_failure(failure):
        # crash() returns None, so the failure is consumed once reported.
        return crash(failure)
    defer.addErrback(_on_failure)
78
def crash_on_critical(event):
    """Log observer: crash() on any event logged at critical level or above.

    Events that carry no log_level are ignored.  Comparing None with a
    LogLevel raises TypeError, which would otherwise be raised from inside
    the log observer itself.
    """
    level = event.get('log_level')
    if level is not None and level >= LogLevel.critical:
        crash(twisted.logger.formatEvent(event))
82
83 #---------- "router" ----------
84
def route(packet, saddr, daddr):
    """Deliver one IP packet according to its destination address.

    - A packet for a configured client is queued for that client.
    - Link-local traffic is dropped.
    - A packet for the host itself, or for outside the virtual network,
      is sent to the ipif subprocess.
    - Anything else (the relay address, or an unconfigured client address)
      is dropped.
    """
    print('TRACE ', saddr, daddr, packet)
    # The old try/except assigned "client" on success but then tested the
    # never-bound "dclient", so every packet for a client crashed here.
    dclient = clients.get(daddr)
    if dclient is not None:
        dclient.queue_outbound(packet)
    elif saddr.is_link_local or daddr.is_link_local:
        log_discard(packet, saddr, daddr, 'link-local')
    elif daddr == host or daddr not in network:
        print('TRACE INBOUND ', saddr, daddr, packet)
        queue_inbound(packet)
    elif daddr == relay:
        log_discard(packet, saddr, daddr, 'relay')
    else:
        log_discard(packet, saddr, daddr, 'no client')
100
def log_discard(packet, saddr, daddr, why):
    """Report on stdout that a packet from *saddr* to *daddr* was dropped.

    *why* is a short reason.  *packet* itself is not reported.  Reporting
    to syslog at LOG_DEBUG is sketched below but disabled.
    """
    fields = ('DROP ', saddr, daddr, why)
    print(*fields)
    # syslog.syslog(syslog.LOG_DEBUG,
    #               'discarded packet %s -> %s (%s)' % (saddr, daddr, why))
105
106 #---------- ipif (slip subprocess) ----------
107
class IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
    """Protocol for the ipif subprocess, which speaks SLIP on stdin/stdout.

    Each complete SLIP frame read from the subprocess is decoded and passed
    to route().  Packets are written to the subprocess by queue_inbound().
    """

    def __init__(self):
        # Raw, still SLIP-encoded bytes received since the last END byte:
        # the beginning of a frame whose end has not arrived yet.
        self._buffer = b''

    def connectionMade(self): pass

    def outReceived(self, data):
        #print('RECV ', repr(data))
        self._buffer += data
        # Decode only up to the last END byte and keep the tail raw.  The
        # old code stored the *decoded* tail.  Escaped END/ESC bytes in it
        # were then re-parsed as specials, and a chunk split in the middle
        # of an escape sequence failed to decode at all.
        complete, delim, tail = self._buffer.rpartition(slip_end)
        self._buffer = tail
        if not delim:
            return
        for packet in slip_decode(complete):
            if not len(packet):
                continue    # empty frame between adjacent delimiters
            (saddr, daddr) = packet_addrs(packet)
            route(packet, saddr, daddr)

    def processEnded(self, status):
        # Raise the reason the subprocess ended (status is a twisted
        # Failure).  NOTE(review): presumably meant to reach crash() via
        # the critical-level log observer -- confirm.
        status.raiseException()
123
def start_ipif():
    """Spawn ipif_command under /bin/sh as the SLIP subprocess.

    The protocol instance is kept in the global `ipif`.  The child's stdin
    is written by us and its stdout is read by the protocol.  Its stderr
    is our stderr.  `sh -x` traces the command as it runs.
    """
    global ipif
    ipif = IpifProcessProtocol()
    argv = ['sh', '-xc', ipif_command]
    fds = {0: 'w', 1: 'r', 2: 2}
    reactor.spawnProcess(ipif, '/bin/sh', argv, childFDs=fds)
130
def queue_inbound(packet):
    """Write *packet* to the ipif subprocess as one SLIP frame.

    The encoded packet is written with a delimiter on each side of it.
    """
    write = ipif.transport.write
    write(slip_delimiter)
    write(slip_encode(packet))
    write(slip_delimiter)
135
136 #---------- SLIP handling ----------
137
# SLIP special bytes (RFC 1055).  END terminates a frame.  ESC starts a
# two-byte escape: ESC ESC_END stands for a literal END byte and ESC
# ESC_ESC stands for a literal ESC byte.
slip_end = b'\xc0'
slip_esc = b'\xdb'
slip_esc_end = b'\xdc'
slip_esc_esc = b'\xdd'
# Frames are separated by a single END byte.
slip_delimiter = slip_end
143
def slip_encode(packet):
    """Return *packet* with its SLIP special bytes escaped (no delimiters).

    ESC bytes are escaped first.  Otherwise the ESC bytes introduced when
    escaping END would themselves be escaped again.
    """
    esc_escaped = packet.replace(slip_esc, slip_esc + slip_esc_esc)
    return esc_escaped.replace(slip_end, slip_esc + slip_esc_end)
148
def slip_decode(data):
    """Split SLIP *data* on END bytes and undo the escaping in each piece.

    Returns a list of bytes objects, one per END-separated piece.  This
    includes empty pieces and whatever follows the last END.  Raises
    ValueError if an ESC byte is followed by anything other than ESC_END
    or ESC_ESC, including an ESC at the very end of a piece.
    """
    print('DECODE ', repr(data))
    out = []
    for frame in data.split(slip_end):
        # After splitting on ESC, every piece but the first must begin with
        # the escape code byte.  An empty piece means a truncated escape;
        # the old code hit an IndexError there instead of ValueError.  This
        # also avoids the old quadratic re-slicing of the frame.
        pieces = frame.split(slip_esc)
        pdata = [pieces[0]]
        for piece in pieces[1:]:
            code = piece[:1]
            if code == slip_esc_esc: pdata.append(slip_esc)
            elif code == slip_esc_end: pdata.append(slip_end)
            else: raise ValueError('invalid SLIP escape')
            pdata.append(piece[1:])
        out.append(b''.join(pdata))
    print('DECODED ', repr(out))
    return out
170
171 #---------- packet parsing ----------
172
def packet_addrs(packet):
    """Return the (source, destination) addresses of a raw IP packet.

    The IP version is the top four bits of the first byte.  The addresses
    sit at fixed offsets in the fixed header: bytes 12-19 for IPv4 and
    8-39 for IPv6, so IPv4 options do not affect them.  *packet* may be
    bytes, bytearray or memoryview.

    Raises ValueError for an empty packet, an unsupported IP version, or a
    packet too short to contain both addresses.
    """
    if not len(packet):
        raise ValueError('empty IP packet')
    version = packet[0] >> 4
    if version == 4:
        addrlen = 4
        saddroff = 3*4
        factory = ipaddress.IPv4Address
    elif version == 6:
        addrlen = 16
        saddroff = 2*4
        factory = ipaddress.IPv6Address
    else:
        raise ValueError('unsupported IP version %d' % version)
    daddroff = saddroff + addrlen
    if len(packet) < daddroff + addrlen:
        raise ValueError('truncated IPv%d packet (%d bytes)'
                         % (version, len(packet)))
    # ipaddress treats only real bytes objects as packed addresses.
    saddr = factory(bytes(packet[saddroff : daddroff]))
    daddr = factory(bytes(packet[daddroff : daddroff + addrlen]))
    return (saddr, daddr)
188
189 #---------- client ----------
190
class Client():
    """State for one configured client.

    Holds the client's password and per-client limits.  Also holds two
    queues: downstream packets waiting to be sent to the client, and the
    client's HTTP requests waiting (long-polling) to carry those packets.
    """

    def __init__(self, ip, cs):
        # instance data members
        self._ip = ip                   # the client's virtual address
        self._cs = cs                   # name of its cfg section
        self.pw = cfg.get(cs, 'password')
        self._rq = collections.deque()  # waiting requests, oldest first
        self._pq = collections.deque()  # (time.monotonic(), packet), oldest first
        # plus from config, each capped by the same key in [limits]:
        #  .max_batch_down
        #  .max_queue_time
        #  .max_request_time
        for k in ('max_batch_down','max_queue_time','max_request_time'):
            req = cfg.getint(cs, k)
            limit = cfg.getint('limits',k)
            setattr(self, k, min(req, limit))

    def process_arriving_data(self, d):
        """Route each packet in SLIP data *d* sent up by this client.

        Raises ValueError if a packet's source is not the client's address
        (or if the data or a packet is malformed).
        """
        for packet in slip_decode(d):
            if not packet:
                # leading/trailing/doubled delimiters give empty frames
                continue
            (saddr, daddr) = packet_addrs(packet)
            if saddr != self._ip:
                raise ValueError('wrong source address %s' % saddr)
            route(packet, saddr, daddr)

    def _req_cancel(self, request):
        # Request timeout: finish it with whatever has been written
        # (normally nothing), unless it has been finished already.
        if not request.finished:
            request.finish()

    def _req_error(self, err, request):
        # The connection went away before we finished the request.  It can
        # no longer be written to or finished (finish() would raise), so
        # just forget it.  Returning None consumes the failure.
        try:
            self._rq.remove(request)
        except ValueError:
            pass

    def _req_stop_timer(self, result, timer):
        # The request has finished or gone: its timeout is no longer needed.
        if timer.active():
            timer.cancel()
        return result

    def queue_outbound(self, packet):
        """Queue *packet* for this client and try to send it right away."""
        self._pq.append((time.monotonic(), packet))
        # Without this, a waiting long-poll request would not be woken
        # until the client made another request.
        self._check_outbound()

    def http_request(self, request):
        """Hold *request* open until there is downstream data or it times out."""
        request.setHeader('Content-Type','application/octet-stream')
        timer = reactor.callLater(self.max_request_time,
                                  self._req_cancel, request)
        finished = request.notifyFinish()
        finished.addErrback(self._req_error, request)
        finished.addBoth(self._req_stop_timer, timer)
        self._rq.append(request)
        self._check_outbound()

    def _check_outbound(self):
        """Send queued packets on waiting requests, as far as possible.

        Finished requests, and packets older than max_queue_time, are
        dropped from the fronts of their queues.  Whenever there is both a
        waiting request and a queued packet, packets are written to the
        oldest request, SLIP-delimited.  At least one packet is written,
        plus as many more as keep the body within max_batch_down.  That
        request is then finished.
        """
        while True:
            try: request = self._rq[0]
            except IndexError: request = None
            if request is not None and request.finished:
                self._rq.popleft()
                continue

            # now request is an unfinished request, or None
            try: (queuetime, packet) = self._pq[0]
            except IndexError:
                # no packets, oh well
                break

            age = time.monotonic() - queuetime
            if age > self.max_queue_time:
                self._pq.popleft()
                continue

            if request is None:
                # no request
                break

            # request, and also some non-expired packets
            while self._pq:
                (dummy, packet) = self._pq[0]
                encoded = slip_encode(packet)

                if request.sentLength > 0:
                    if (request.sentLength + len(slip_delimiter)
                            + len(encoded) > self.max_batch_down):
                        break
                    request.write(slip_delimiter)

                request.write(encoded)
                self._pq.popleft()

            assert(request.sentLength)
            self._rq.popleft()
            request.finish()
            # round again, looking for more to do
274
class IphttpResource(twisted.web.resource.Resource):
    """HTTP endpoint through which clients exchange SLIP-framed packets.

    A POST authenticates the client and optionally updates its limits.  It
    then routes any upstream packets it carries and is held open (via
    Client.http_request) to carry downstream packets back.
    """
    isLeaf = True

    @staticmethod
    def _arg(request, name, default=None):
        """Return the first value (bytes) of form field *name*, or *default*.

        Twisted on Python 3 keys request.args by bytes and stores each
        field's values as a list of bytes.
        """
        values = request.args.get(name.encode('ascii'))
        if not values:
            return default
        return values[0]

    @staticmethod
    def _reject(request, code, why):
        """Answer *request* immediately with HTTP status *code*."""
        request.setResponseCode(code)
        request.setHeader('Content-Type', 'text/plain')
        return why.encode('ascii') + b'\n'

    def render_POST(self, request):
        # Errors caused by the remote party get an HTTP error status rather
        # than an exception.  NOTE(review): twisted.web presumably logs an
        # exception escaping render() at critical level, which
        # crash_on_critical would turn into a shutdown of the whole server.

        # find and authenticate the client
        try:
            ci = ipaddr(self._arg(request, 'i', b'').decode('ascii'))
            c = clients[ci]
        except (KeyError, ValueError):
            return self._reject(request, 403, 'authentication failed')
        pw = self._arg(request, 'pw', b'')
        if not hmac.compare_digest(pw, c.pw.encode('utf-8')):
            return self._reject(request, 403, 'authentication failed')

        # update config, capped by [limits] just as Client.__init__ does
        for r, w in (('mbd', 'max_batch_down'),
                     ('mqt', 'max_queue_time'),
                     ('mrt', 'max_request_time')):
            v = self._arg(request, r)
            if v is None: continue
            try:
                v = int(v)
            except ValueError:
                return self._reject(request, 400, 'bad parameter %s' % r)
            setattr(c, w, min(v, cfg.getint('limits', w)))

        d = self._arg(request, 'd', b'')
        try:
            c.process_arriving_data(d)
        except ValueError:
            return self._reject(request, 400, 'bad data')

        # The client finishes the request later (data, or timeout).
        c.http_request(request)
        return NOT_DONE_YET

    def render_GET(self, request):
        return b'<html><body>hippotit</body></html>'
301
def start_http():
    """Serve IphttpResource on every address listed in [server] addrs.

    Each address gets its own TCP endpoint on [server] port.  A failure to
    listen is passed to crash() via crash_on_defer().
    """
    site = twisted.web.server.Site(IphttpResource())
    port = cfg.getint('server','port')
    for addrspec in cfg.get('server','addrs').split():
        addr = ipaddr(addrspec)
        if addr.version == 4:
            endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
        else:
            endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
        # The third positional parameter of these endpoints is the listen
        # backlog, so the address must be given as interface=, as a string.
        ep = endpointfactory(reactor, port, interface=str(addr))
        crash_on_defer(ep.listen(site))
314
315 #---------- config and setup ----------
316
def process_cfg():
    """Derive the global routing and subprocess settings from cfg.

    Sets the globals network, host, relay, mtu and ipif_command.  Also adds
    a Client to `clients` for every cfg section whose name contains '.' or
    ':' (i.e. is taken to be a client IP address).  Raises ValueError for
    a too-small network, a client outside the network, or duplicate
    client sections.
    """
    global network
    global host
    global relay
    global mtu
    global ipif_command

    network = ipnetwork(cfg.get('virtual','network'))
    # Network sizes are powers of two, so this is the same test as
    # "fewer than 2^3 addresses".
    if network.num_addresses < 3 + 2:
        raise ValueError('network needs at least 2^3 addresses')

    # host and relay must be address objects: route() compares them with
    # packet addresses, and a str never compares equal to an address.
    try:
        host = ipaddr(cfg.get('virtual','host'))
    except NoOptionError:
        host = next(network.hosts())

    try:
        relay = ipaddr(cfg.get('virtual','relay'))
    except NoOptionError:
        relay = next(a for a in network.hosts() if a != host)

    for cs in cfg.sections():
        if not (':' in cs or '.' in cs): continue
        ci = ipaddr(cs)
        if ci not in network:
            raise ValueError('client %s not in network' % ci)
        if ci in clients:
            raise ValueError('multiple client cfg sections for %s' % ci)
        clients[ci] = Client(ci, cs)

    mtu = cfg.get('virtual','mtu')

    # ConfigParser converts the vars values to str for interpolation.
    iic_vars = {'host': host, 'relay': relay, 'mtu': mtu, 'network': network}
    ipif_command = cfg.get('server','ipif', vars=iic_vars)
357
def startup():
    """Parse the command line, load the configuration and start serving.

    The built-in defcfg is loaded first, then the file given by -c/--config
    is read over it.  ConfigParser.read() silently skips a file it cannot
    read.  Then the ipif subprocess and the HTTP listeners are started.
    """
    global cfg
    global opts

    parser = OptionParser()
    parser.add_option('-c', '--config', dest='configfile',
                      default='/etc/hippottd/server.conf')
    (opts, args) = parser.parse_args()
    if args:
        parser.error('no non-option arguments please')

    # From here on, any event logged at critical level crashes the server.
    twisted.logger.globalLogPublisher.addObserver(crash_on_critical)

    cfg = ConfigParser()
    cfg.read_string(defcfg)
    cfg.read(opts.configfile)
    process_cfg()

    start_ipif()
    start_http()
377
# Run only when executed as a script, so importing the module has no side
# effects.
if __name__ == '__main__':
    startup()
    reactor.run()
    # reactor.run() returns only once the reactor has been stopped, e.g.
    # by crash().
    print('CRASHED (end)', file=sys.stderr)