wip
[hippotat] / server
1 #!/usr/bin/python3
2
import collections
import ipaddress
import os
import sys
import syslog
import time
from configparser import ConfigParser
from configparser import NoOptionError
from ipaddress import AddressValueError
from optparse import OptionParser

import twisted
import twisted.internet
import twisted.internet.endpoints
import twisted.internet.error
import twisted.internet.protocol
from twisted.internet import reactor
import twisted.logger
from twisted.logger import LogLevel
import twisted.web.resource
import twisted.web.server
from twisted.web.server import NOT_DONE_YET

#from twisted.web.server import Site
#from twisted.web.resource import Resource
26
# Configured clients, keyed by address object (as produced by ipaddr());
# populated by process_cfg() and looked up by route().
clients = dict()
28
def ipaddr(input):
    """Parse *input* as an IP address, trying IPv4 first, then IPv6.

    Returns an ipaddress.IPv4Address or ipaddress.IPv6Address.  If the
    text is not valid IPv4, the IPv6 parse decides; its AddressValueError
    propagates when that fails too.
    """
    try:
        return ipaddress.IPv4Address(input)
    except AddressValueError:
        pass
    return ipaddress.IPv6Address(input)
35
def ipnetwork(input):
    """Parse *input* as an IP network, trying IPv4 first, then IPv6.

    Returns an ipaddress.IPv4Network or ipaddress.IPv6Network.  Raises
    ValueError (including AddressValueError / NetmaskValueError) if the
    text is not a valid network in either family, or has host bits set.
    """
    try:
        r = ipaddress.IPv4Network(input)
    except AddressValueError:
        # ipaddress has no NetworkValueError: an IPv6 spec fails the IPv4
        # parse with AddressValueError, so that is what we fall back on.
        # Other IPv4 errors (bad netmask, host bits set) propagate as-is.
        r = ipaddress.IPv6Network(input)
    return r
42
# Built-in configuration defaults; startup() loads this text before the
# configuration file, so the file only needs to override what differs.
#  [DEFAULT]  inherited by every section, so each client section gets these
#             per-client limits unless it sets its own; Client caps each one
#             at the corresponding [limits] value.
#  [virtual]  'network' has no default and must come from the config file;
#             'host' and 'relay' are optional (process_cfg() picks addresses
#             from the network when they are absent).
#  [server]   'ipif' is a shell command whose %(host)s, %(relay)s, %(mtu)s
#             and %(network)s references are filled in by process_cfg().
defcfg = '''
[DEFAULT]
max_batch_down = 65536
max_queue_time = 10
max_request_time = 54

[virtual]
mtu = 1500
# network
# [host]
# [relay]

[server]
ipif = userv root ipif %(host)s,%(relay)s,%(mtu)s,slip %(network)s
addrs = 127.0.0.1 ::1
port = 80

[limits]
max_batch_down = 262144
max_queue_time = 121
max_request_time = 121
'''
65
66 #---------- "router" ----------
67
def route(packet, saddr, daddr):
    """Decide where one packet goes, given its source/destination addresses.

    Packets for a configured client are queued for that client; packets
    for the host address or for anywhere outside the virtual network are
    handed to ipif; multicast, link-local, relay-addressed and unknown
    in-network destinations are discarded (and reported).
    """
    print('TRACE ', saddr, daddr, packet)
    # Previously this assigned 'client' but tested 'dclient', so any
    # packet for a known client raised NameError.
    dclient = clients.get(daddr)
    if dclient is not None:
        dclient.queue_outbound(packet)
    elif daddr.is_multicast:
        log_discard(packet, saddr, daddr, 'multicast')
    elif daddr.is_link_local:
        log_discard(packet, saddr, daddr, 'link-local')
    elif daddr == host or daddr not in network:
        print('TRACE INBOUND ', saddr, daddr, packet)
        queue_inbound(packet)
    elif daddr == relay:
        log_discard(packet, saddr, daddr, 'relay')
    else:
        log_discard(packet, saddr, daddr, 'no client')
85
def log_discard(packet, saddr, daddr, why):
    """Report a dropped packet on stdout, giving *why* as the reason.

    A syslog-based variant is kept commented out below.
    """
    fields = ('DROP ', saddr, daddr, why, packet)
    print(*fields)
    # syslog.syslog(syslog.LOG_DEBUG,
    #               'discarded packet %s -> %s (%s)' % (saddr, daddr, why))
90
91 #---------- ipif (slip subprocess) ----------
92
class IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
    """Exchanges SLIP-framed packets with the ipif child process.

    Bytes from the child's stdout are split into frames, decoded, and
    each packet is passed to route().
    """

    def __init__(self):
        # Raw, still SLIP-encoded bytes of a frame whose terminating
        # delimiter has not arrived yet.
        self._buffer = b''

    def connectionMade(self): pass

    def outReceived(self, data):
        self._buffer += data
        frames = self._buffer.split(slip_end)
        # The last piece is an incomplete frame.  Keep it *undecoded*: the
        # old code stored the decoded form and decoded it again on the next
        # read, corrupting escaped bytes, and crashed on an escape
        # sequence split across two reads.
        self._buffer = frames.pop()
        for frame in frames:
            if not frame: continue
            for packet in slip_decode(frame):
                (saddr, daddr) = packet_addrs(packet)
                route(packet, saddr, daddr)

    def processEnded(self, status):
        # twisted passes a Failure even for a clean exit, so any
        # termination of the ipif child is raised here.
        status.raiseException()
108
def start_ipif():
    """Spawn ``/bin/sh -xc <ipif_command>`` with an IpifProcessProtocol.

    The protocol instance is stored in the global ``ipif``.  The child's
    fd 0 is written by us, its fd 1 is read by us, and fd 2 is our own.
    """
    global ipif
    ipif = IpifProcessProtocol()
    argv = ['sh', '-xc', ipif_command]
    fds = {0: 'w', 1: 'r', 2: 2}
    reactor.spawnProcess(ipif, '/bin/sh', argv, childFDs=fds)
115
def queue_inbound(packet):
    """Write *packet* to the ipif child as one SLIP frame.

    The frame is a delimiter, the escaped packet, then another delimiter,
    sent as three transport writes.
    """
    write = ipif.transport.write
    write(slip_delimiter)
    write(slip_encode(packet))
    write(slip_delimiter)
120
121 #---------- SLIP handling ----------
122
# SLIP framing bytes (RFC 1055).  END separates frames; a literal END or ESC
# inside a packet is sent as ESC ESC_END or ESC ESC_ESC respectively.
slip_end = b'\300'
slip_esc = b'\333'
slip_esc_end = b'\334'
slip_esc_esc = b'\335'
slip_delimiter = slip_end

def slip_encode(packet):
    """Return *packet* (bytes) with END and ESC bytes SLIP-escaped.

    No frame delimiters are added; callers write slip_delimiter themselves.
    """
    # ESC must be escaped first, otherwise the ESC bytes introduced while
    # escaping END would be escaped a second time.
    return (packet
            .replace(slip_esc, slip_esc + slip_esc_esc)
            .replace(slip_end, slip_esc + slip_esc_end))

def slip_decode(data):
    """Split *data* on END and unescape each piece.

    Returns a list of bytes, one per END-separated piece, including empty
    pieces (e.g. before a leading or after a trailing END).

    Raises ValueError on an unknown or truncated escape sequence.
    """
    print('DECODE ', repr(data))
    out = []
    for frame in data.split(slip_end):
        parts = []
        while True:
            eix = frame.find(slip_esc)
            if eix == -1:
                parts.append(frame)
                break
            parts.append(frame[:eix])
            # Slice, don't index: indexing bytes yields an int, which never
            # equals the one-byte constants, so every escape used to be
            # rejected.  At the end of the frame the slice is b'' and is
            # rejected below instead of raising IndexError.
            ck = frame[eix+1 : eix+2]
            if ck == slip_esc_esc: parts.append(slip_esc)
            elif ck == slip_esc_end: parts.append(slip_end)
            else: raise ValueError('invalid SLIP escape')
            frame = frame[eix+2:]
        out.append(b''.join(parts))
    print('DECODED ', repr(out))
    return out
154
155 #---------- packet parsing ----------
156
def packet_addrs(packet):
    """Return (source, destination) address objects of a raw IP packet.

    Supports IPv4 and IPv6, chosen from the version nibble of the first
    byte.  Raises ValueError for an empty packet, an unsupported IP
    version, or a packet too short to contain both addresses.
    """
    if not packet:
        raise ValueError('empty packet')
    version = packet[0] >> 4
    if version == 4:
        addrlen = 4
        saddroff = 3*4     # IPv4 header: source address at byte 12
        factory = ipaddress.IPv4Address
    elif version == 6:
        addrlen = 16
        saddroff = 2*4     # IPv6 header: source address at byte 8
        factory = ipaddress.IPv6Address
    else:
        raise ValueError('unsupported IP version %d' % version)
    # The destination address immediately follows the source address.
    daddroff = saddroff + addrlen
    if len(packet) < daddroff + addrlen:
        raise ValueError('truncated IPv%d packet (%d bytes)'
                         % (version, len(packet)))
    saddr = factory(packet[saddroff : daddroff])
    daddr = factory(packet[daddroff : daddroff + addrlen])
    return (saddr, daddr)
172
173 #---------- client ----------
174
class Client():
    """State for one configured client address.

    Holds the client's password and limits, its pending long-poll HTTP
    requests (_rq) and the downstream packets waiting for it (_pq).
    """

    def __init__(self, ip, cs):
        """Create the client at address *ip* from config section *cs*."""
        self._ip = ip
        self._cs = cs
        self.pw = cfg.get(cs, 'password')
        self._rq = collections.deque()  # pending requests, oldest first
        self._pq = collections.deque()  # (time.monotonic() when queued, packet)
        # Limits .max_batch_down, .max_queue_time, .max_request_time: the
        # section's value (possibly inherited from [DEFAULT]), capped at
        # the [limits] value.
        for k in ('max_batch_down', 'max_queue_time', 'max_request_time'):
            req = cfg.getint(cs, k)
            limit = cfg.getint('limits', k)
            setattr(self, k, min(req, limit))

    def process_arriving_data(self, d):
        """Route every SLIP-framed packet in *d* (bytes sent by this client).

        Raises ValueError if a packet's source is not this client's address.
        """
        for packet in slip_decode(d):
            if not packet:
                # Empty piece from a leading, trailing or doubled delimiter.
                continue
            (saddr, daddr) = packet_addrs(packet)
            if saddr != self._ip:
                raise ValueError('wrong source address %s' % saddr)
            route(packet, saddr, daddr)

    def _req_cancel(self, request):
        """Long-poll timeout: finish *request* unless it already finished."""
        if not request.finished:
            request.finish()

    def _req_error(self, err, request):
        """notifyFinish errback: the request's connection was lost.

        twisted raises RuntimeError for finish() on a lost request, so the
        request is only dropped from the queue.  Returning None marks the
        failure as handled.
        """
        try:
            self._rq.remove(request)
        except ValueError:
            pass

    def queue_outbound(self, packet):
        """Queue *packet* for this client and try to deliver it right away."""
        self._pq.append((time.monotonic(), packet))
        # Without this, packets waited for the next request even when a
        # request was already parked.
        self._check_outbound()

    def http_request(self, request):
        """Park *request* until there are packets for it, or it times out.

        It is answered by _check_outbound(), or finished with no body
        after max_request_time seconds.
        """
        request.setHeader('Content-Type','application/octet-stream')
        timeout = reactor.callLater(self.max_request_time,
                                    self._req_cancel, request)

        def _stop_timeout(result):
            # Once the request is over (finished or connection lost) the
            # timeout must not fire any more.
            if timeout.active():
                timeout.cancel()
            return result

        finished = request.notifyFinish()
        finished.addBoth(_stop_timeout)
        finished.addErrback(self._req_error, request)
        self._rq.append(request)
        self._check_outbound()

    # Alias for callers that use the name new_request.
    new_request = http_request

    def _check_outbound(self):
        """Move queued packets into pending requests while both exist.

        Finished requests and packets older than max_queue_time are
        dropped.  Each request gets as many packets as fit within
        max_batch_down bytes (at least one), then is finished.
        """
        while True:
            try: request = self._rq[0]
            except IndexError: request = None
            if request is not None and request.finished:
                self._rq.popleft()
                continue

            # now request is an unfinished request, or None
            try: (queuetime, packet) = self._pq[0]
            except IndexError:
                # no packets, oh well
                break

            age = time.monotonic() - queuetime
            if age > self.max_queue_time:
                self._pq.popleft()
                continue

            if request is None:
                # no request
                break

            # request, and also some non-expired packets
            while True:
                try: (dummy, packet) = self._pq[0]
                except IndexError: break

                encoded = slip_encode(packet)

                if request.sentLength > 0:
                    if (request.sentLength + len(slip_delimiter)
                        + len(encoded) > self.max_batch_down):
                        break
                    request.write(slip_delimiter)

                request.write(encoded)
                self._pq.popleft()

            assert(request.sentLength)
            self._rq.popleft()
            request.finish()
            # round again, looking for more to do
258
class IphttpResource(twisted.web.resource.Resource):
    """HTTP endpoint for clients.

    A POST carries the following form fields:
      - i: the client address
      - pw: the password
      - mbd / mqt / mrt: optional limit changes
      - d: optional upstream SLIP data

    The request is then kept open until the client object has downstream
    data for it or it times out.
    """

    # Render every request from this resource.  Otherwise twisted would
    # look up a child resource for the request path, and there are none.
    isLeaf = True

    def render_POST(self, request):
        # twisted.web gives request.args as {bytes name: [bytes value, ...]}.
        args = request.args

        # Identify and authenticate the client.
        ci = ipaddr(args[b'i'][0].decode('ascii'))
        c = clients[ci]
        pw = args[b'pw'][0].decode('utf-8')
        if pw != c.pw: raise ValueError('bad password')

        # The client may adjust its limits, but never beyond [limits].
        for r, w in ((b'mbd', 'max_batch_down'),
                     (b'mqt', 'max_queue_time'),
                     (b'mrt', 'max_request_time')):
            try: v = args[r][0]
            except KeyError: continue
            setattr(c, w, min(int(v), cfg.getint('limits', w)))

        d = args.get(b'd', [b''])[0]

        c.process_arriving_data(d)
        c.http_request(request)
        # The response is written and finished later by the client object.
        return NOT_DONE_YET
281
def start_http():
    """Listen for HTTP on each address in [server] addrs, on [server] port."""
    resource = IphttpResource()
    sitefactory = twisted.web.server.Site(resource)
    port = cfg.getint('server','port')
    for addrspec in cfg.get('server','addrs').split():
        addr = ipaddr(addrspec)
        if addr.version == 4:
            endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
        else:
            endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
        # The third positional parameter of these endpoints is the listen
        # backlog, so the bind address must be passed as interface= (a str).
        ep = endpointfactory(reactor, port, interface=str(addr))
        ep.listen(sitefactory)
294
295 #---------- config and setup ----------
296
def process_cfg():
    """Derive the virtual-network globals and the client table from cfg.

    Sets the globals network, host, relay, mtu and ipif_command.  Adds a
    Client to clients for every section whose name contains ':' or '.'.

    Raises ValueError for a too-small network, a client outside the
    network, or duplicate client sections.
    """
    global network
    global host
    global relay
    global mtu
    global ipif_command

    network = ipnetwork(cfg.get('virtual','network'))
    # num_addresses is always a power of two, so this rejects anything
    # smaller than 8 (= 2^3) addresses, as the message says.
    if network.num_addresses < 3 + 2:
        raise ValueError('network needs at least 2^3 addresses')

    # Configured host/relay are parsed into address objects so that they
    # compare equal to packet addresses and to network.hosts() entries
    # (a plain str never compares equal to an address object).
    try:
        host = ipaddr(cfg.get('virtual','host'))
    except NoOptionError:
        host = next(network.hosts())

    try:
        relay = ipaddr(cfg.get('virtual','relay'))
    except NoOptionError:
        # first usable address in the network other than host
        for search in network.hosts():
            if search == host: continue
            relay = search
            break

    for cs in cfg.sections():
        if not (':' in cs or '.' in cs): continue
        ci = ipaddr(cs)
        if ci not in network:
            raise ValueError('client %s not in network' % ci)
        if ci in clients:
            raise ValueError('multiple client cfg sections for %s' % ci)
        clients[ci] = Client(ci, cs)

    mtu = cfg.get('virtual','mtu')

    # Values for the %(host)s, %(relay)s, %(mtu)s and %(network)s
    # references in [server] ipif; ConfigParser str()s them.
    iic_vars = {'host': host, 'relay': relay, 'mtu': mtu, 'network': network}
    ipif_command = cfg.get('server','ipif', vars=iic_vars)
337
def crash_on_critical(event):
    """twisted.logger observer: stop the reactor on any critical event.

    Events without a log_level are ignored.  Previously such an event made
    the comparison None >= LogLevel.critical raise TypeError.
    """
    level = event.get('log_level')
    if level is None or not level >= LogLevel.critical:
        return
    print('crashing: ', twisted.logger.formatEvent(event), file=sys.stderr)
    # Hard-exit alternative, currently disabled: os._exit(1)
    try: reactor.stop()
    except twisted.internet.error.ReactorNotRunning: pass
345
def startup():
    """Parse the command line, load the config, and start ipif and HTTP.

    Sets the globals cfg and opts.  Running the reactor is left to the
    caller.
    """
    global cfg
    global opts

    op = OptionParser()
    # NOTE(review): 'hippottd' may be a typo for the project name; the
    # path is left unchanged pending confirmation.
    op.add_option('-c', '--config', dest='configfile',
                  default='/etc/hippottd/server.conf')
    (opts, args) = op.parse_args()
    if len(args): op.error('no non-option arguments please')

    # Stop the reactor on any critical log event.
    twisted.logger.globalLogPublisher.addObserver(crash_on_critical)

    cfg = ConfigParser()
    cfg.read_string(defcfg)
    # ConfigParser.read() silently skips unreadable files.  Without the
    # file, [virtual] network is missing and process_cfg() would only fail
    # later with a less helpful NoOptionError.
    if not cfg.read(opts.configfile):
        op.error('cannot read config file %s' % opts.configfile)
    process_cfg()

    start_ipif()
    start_http()

startup()
reactor.run()