wip
[hippotat] / server
1 #!/usr/bin/python3
2
import collections
import hmac
import ipaddress
import os
import sys
import syslog
import time
from configparser import ConfigParser
from configparser import NoOptionError
from ipaddress import AddressValueError
from optparse import OptionParser

import twisted
import twisted.internet
import twisted.internet.endpoints
import twisted.internet.error
import twisted.internet.protocol
from twisted.internet import reactor
import twisted.web.resource
import twisted.web.server
from twisted.web.server import NOT_DONE_YET
import twisted.logger
from twisted.logger import LogLevel

#import twisted.web.server import Site
#from twisted.web.resource import Resource
26
# Configured clients, keyed by virtual IP address (the ipaddress address
# objects produced by ipaddr()).  Filled in by process_cfg(); looked up by
# route() and the HTTP resource.
clients = dict()
28
def ipaddr(input):
    """Parse *input* as an IP address, trying IPv4 before IPv6.

    Returns an ipaddress.IPv4Address if *input* is valid IPv4, otherwise
    an ipaddress.IPv6Address.  When *input* is neither, the
    AddressValueError raised by the IPv6 attempt propagates.
    """
    try:
        return ipaddress.IPv4Address(input)
    except AddressValueError:
        return ipaddress.IPv6Address(input)
35
def ipnetwork(input):
    """Parse *input* as an IP network, trying IPv4 before IPv6.

    Returns an ipaddress.IPv4Network when *input* has IPv4 syntax,
    otherwise an ipaddress.IPv6Network.  Parsing is strict: an address
    part with host bits set raises ValueError.
    """
    try:
        return ipaddress.IPv4Network(input)
    except AddressValueError:
        # Not IPv4 syntax: fall back to IPv6.  A well-formed IPv4 address
        # with a bad prefix raises NetmaskValueError instead, which is left
        # to propagate rather than being turned into a misleading IPv6 error.
        return ipaddress.IPv6Network(input)
42
# Built-in configuration defaults.  startup() reads this string first and
# then the site config file, so the file overrides anything set here.
# [DEFAULT] values are inherited by every section, including the
# per-client sections.  [limits] holds the upper bounds that Client clamps
# each client's max_* settings to.
# In [virtual], "network" has no default and must come from the site file.
# "host" and "relay" are optional: process_cfg() picks addresses from the
# network when they are absent.
# The [server] ipif command is %-interpolated by process_cfg() with host,
# relay, mtu and network.
# Units as used by Client: max_batch_down is bytes per response body;
# max_queue_time and max_request_time are seconds.
defcfg = '''
[DEFAULT]
max_batch_down = 65536
max_queue_time = 10
max_request_time = 54

[virtual]
mtu = 1500
# network
# [host]
# [relay]

[server]
ipif = userv root ipif %(host)s,%(relay)s,%(mtu)s,slip %(network)s
addrs = 127.0.0.1 ::1
port = 80

[limits]
max_batch_down = 262144
max_queue_time = 121
max_request_time = 121
'''
65
66 #---------- "router" ----------
67
def route(packet, saddr, daddr):
    """Deliver one IP packet according to its destination address.

    - A destination that is a configured client: the packet is queued on
      that client's outbound queue.
    - Link-local source or destination: discarded.
    - The host address, or anything outside the virtual network: written
      to the ipif (SLIP) subprocess.
    - The relay address, or an unconfigured address inside the network:
      discarded.

    Discards are reported through log_discard().
    """
    print('TRACE ', saddr, daddr, packet)
    # The lookup result must be the same name that is tested below.
    dclient = clients.get(daddr)
    if dclient is not None:
        dclient.queue_outbound(packet)
    elif saddr.is_link_local or daddr.is_link_local:
        log_discard(packet, saddr, daddr, 'link-local')
    elif daddr == host or daddr not in network:
        print('TRACE INBOUND ', saddr, daddr, packet)
        queue_inbound(packet)
    elif daddr == relay:
        log_discard(packet, saddr, daddr, 'relay')
    else:
        log_discard(packet, saddr, daddr, 'no client')
83
def log_discard(packet, saddr, daddr, why):
    """Report a packet that route() chose not to deliver.

    Prints one DROP trace line (source, destination, reason) to stdout.
    The packet contents are not shown.  An alternative syslog LOG_DEBUG
    report is kept, commented out, below.
    """
    fields = ('DROP ', saddr, daddr, why)
    print(*fields)
    # syslog.syslog(syslog.LOG_DEBUG,
    #               'discarded packet %s -> %s (%s)' % (saddr, daddr, why))
88
89 #---------- ipif (slip subprocess) ----------
90
class IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
    """Process protocol for the ipif subprocess, which speaks SLIP.

    Bytes read from the child's stdout are split into SLIP frames.  Each
    complete, non-empty frame is decoded and passed to route().  Packets
    going the other way are written to the child's stdin by queue_inbound().
    """

    def __init__(self):
        # Raw, still SLIP-escaped bytes received after the last END byte,
        # i.e. the start of a frame that has not been completed yet.
        self._buffer = b''

    def connectionMade(self):
        pass

    def outReceived(self, data):
        #print('RECV ', repr(data))
        self._buffer += data
        # Split on the raw END byte *before* un-escaping.  The incomplete
        # tail must stay in escaped form: once decoded, an escaped END
        # inside it would later look like a frame boundary, and an ESC at
        # the very end of a read could not be decoded at all.
        frames = self._buffer.split(slip_end)
        self._buffer = frames.pop()
        for frame in frames:
            if not frame:
                # adjacent delimiters: nothing to deliver
                continue
            # frame contains no END byte, so it decodes to exactly one packet
            (packet,) = slip_decode(frame)
            (saddr, daddr) = packet_addrs(packet)
            route(packet, saddr, daddr)

    def processEnded(self, status):
        # Re-raise the child's termination reason.  Presumably Twisted logs
        # this as an unhandled error, which crash_on_critical is meant to
        # turn into a reactor stop -- NOTE(review): confirm the log level.
        status.raiseException()
106
def start_ipif():
    """Spawn the configured ipif command under /bin/sh -xc.

    The child's stdin is written by us and its stdout is read by the
    protocol; its stderr is our own fd 2.  The protocol instance is stored
    in the module-level ``ipif`` so that queue_inbound() can reach it.
    """
    global ipif
    protocol = IpifProcessProtocol()
    ipif = protocol
    argv = ['sh', '-xc', ipif_command]
    fd_map = {0: 'w', 1: 'r', 2: 2}
    reactor.spawnProcess(protocol, '/bin/sh', argv, childFDs=fd_map)
113
def queue_inbound(packet):
    """Write *packet* to the ipif subprocess as one SLIP frame.

    The escaped packet is written with a delimiter on both sides, in three
    separate writes.
    """
    child_stdin = ipif.transport
    child_stdin.write(slip_delimiter)
    child_stdin.write(slip_encode(packet))
    child_stdin.write(slip_delimiter)
118
119 #---------- SLIP handling ----------
120
# SLIP special bytes (RFC 1055), written in hex.
slip_end = b'\xc0'      # END: marks frame boundaries
slip_esc = b'\xdb'      # ESC: starts a two-byte escape sequence
slip_esc_end = b'\xdc'  # ESC followed by this stands for a data END byte
slip_esc_esc = b'\xdd'  # ESC followed by this stands for a data ESC byte
# What is written between frames: simply END.
slip_delimiter = slip_end
126
def slip_encode(packet):
    """Return *packet* (bytes) with SLIP escaping applied.

    ESC bytes are escaped first, so the ESC bytes introduced when escaping
    END are not escaped a second time.  No delimiters are added; callers
    write slip_delimiter around the result themselves.
    """
    esc_escaped = packet.replace(slip_esc, slip_esc + slip_esc_esc)
    fully_escaped = esc_escaped.replace(slip_end, slip_esc + slip_esc_end)
    return fully_escaped
131
def slip_decode(data):
    """Split SLIP *data* (bytes) on END bytes and un-escape each piece.

    Returns a list with one bytes object per piece.  Leading, trailing or
    adjacent END bytes produce empty pieces.  The last element is whatever
    follows the final END, which a streaming caller may treat as an
    incomplete frame.  Trace lines are printed to stdout.

    Raises ValueError on an invalid escape sequence, including an ESC that
    is the last byte of a piece.
    """
    print('DECODE ', repr(data))
    out = []
    for frame in data.split(slip_end):
        pieces = []
        pos = 0
        while True:
            eix = frame.find(slip_esc, pos)
            if eix == -1:
                pieces.append(frame[pos:])
                break
            #print('ESC ', repr((pieces, frame, eix)))
            pieces.append(frame[pos:eix])
            # Slice rather than index: indexing bytes yields an int, which
            # never equals the one-byte constants.  A truncated escape gives
            # b'' and is rejected below instead of raising IndexError.
            ck = frame[eix+1:eix+2]
            if ck == slip_esc_esc:
                pieces.append(slip_esc)
            elif ck == slip_esc_end:
                pieces.append(slip_end)
            else:
                raise ValueError('invalid SLIP escape')
            pos = eix + 2
        out.append(b''.join(pieces))
    print('DECODED ', repr(out))
    return out
152
153 #---------- packet parsing ----------
154
def packet_addrs(packet):
    """Return (source, destination) addresses of a raw IPv4/IPv6 packet.

    The IP version comes from the high nibble of the first byte.  Only the
    fixed-offset address fields are read; the rest of the header is not
    validated.

    Raises ValueError for an empty or truncated packet, or for an IP
    version other than 4 or 6.
    """
    if not packet:
        raise ValueError('empty packet')
    version = packet[0] >> 4
    if version == 4:
        addrlen = 4
        saddroff = 3*4   # IPv4 header: source at byte 12, destination at 16
        factory = ipaddress.IPv4Address
    elif version == 6:
        addrlen = 16
        saddroff = 2*4   # IPv6 header: source at byte 8, destination at 24
        factory = ipaddress.IPv6Address
    else:
        raise ValueError('unsupported IP version %d' % version)
    daddroff = saddroff + addrlen
    end = daddroff + addrlen
    if len(packet) < end:
        raise ValueError('truncated IPv%d packet (%d bytes)'
                         % (version, len(packet)))
    saddr = factory(packet[saddroff:daddroff])
    daddr = factory(packet[daddroff:end])
    return (saddr, daddr)
170
171 #---------- client ----------
172
class Client():
    """Server-side state for one configured client of the virtual network.

    Each client is identified by its virtual IP address and configured in a
    config section named after that address.  The client sends packets in
    the bodies of HTTP POST requests.  Packets destined for it are queued
    here and returned in the responses to those held-open requests.
    """

    def __init__(self, ip, cs):
        # instance data members
        self._ip = ip                   # the client's virtual address
        self._cs = cs                   # name of its config section
        self.pw = cfg.get(cs, 'password')
        self._rq = collections.deque()  # pending requests, oldest first
        self._pq = collections.deque()  # (monotonic queue time, packet)
        # plus from config, each clamped to the [limits] section:
        #   .max_batch_down    bytes per response body
        #   .max_queue_time    seconds a packet may wait in _pq
        #   .max_request_time  seconds a request may be held open
        for k in ('max_batch_down', 'max_queue_time', 'max_request_time'):
            req = cfg.getint(cs, k)
            limit = cfg.getint('limits', k)
            setattr(self, k, min(req, limit))

    def process_arriving_data(self, d):
        """Decode SLIP-framed packets sent by this client and route them.

        Raises ValueError if a packet's source address is not the client's
        own address, or if the data cannot be decoded.
        """
        for packet in slip_decode(d):
            if not packet:
                # leading, trailing or doubled delimiters give empty pieces
                continue
            (saddr, daddr) = packet_addrs(packet)
            if saddr != self._ip:
                raise ValueError('wrong source address %s' % saddr)
            route(packet, saddr, daddr)

    def _req_forget(self, request):
        """Drop *request* from the pending queue if it is still there."""
        try:
            self._rq.remove(request)
        except ValueError:
            pass

    def _req_cancel(self, request):
        """Timeout: complete a still-pending request, possibly empty."""
        if request.finished:
            return
        self._req_forget(request)
        request.finish()

    def _req_error(self, err, request):
        """The connection was lost before the request was finished.

        Calling finish() on such a request is an error in Twisted, so just
        stop considering it for delivery.
        """
        self._req_forget(request)

    def queue_outbound(self, packet):
        """Queue *packet* for this client and deliver it if a request waits."""
        self._pq.append((time.monotonic(), packet))
        self._check_outbound()

    def http_request(self, request):
        """Hold *request* open until there is something to send.

        The request is answered as soon as packets are available, and at
        the latest after max_request_time seconds.
        """
        request.setHeader('Content-Type', 'application/octet-stream')
        timeout = reactor.callLater(self.max_request_time,
                                    self._req_cancel, request)

        def _done(result):
            # Finished or connection lost: the timeout is no longer needed.
            if timeout.active():
                timeout.cancel()
            return result

        d = request.notifyFinish()
        d.addBoth(_done)
        d.addErrback(self._req_error, request)
        self._rq.append(request)
        self._check_outbound()

    def _check_outbound(self):
        """Answer pending requests with queued packets, as far as possible.

        Packets older than max_queue_time are dropped when they reach the
        head of the queue.  Each response carries one or more SLIP-encoded
        packets separated by slip_delimiter, up to max_batch_down bytes.  The
        first packet is always sent, even if it alone exceeds the limit.
        """
        while True:
            try:
                request = self._rq[0]
            except IndexError:
                request = None
            if request is not None and request.finished:
                self._rq.popleft()
                continue

            # now request is an unfinished request, or None
            try:
                (queuetime, packet) = self._pq[0]
            except IndexError:
                # no packets, oh well
                break

            age = time.monotonic() - queuetime
            if age > self.max_queue_time:
                self._pq.popleft()
                continue

            if request is None:
                # no request
                break

            # request, and also some non-expired packets
            while True:
                try:
                    (dummy, packet) = self._pq[0]
                except IndexError:
                    break

                encoded = slip_encode(packet)

                if request.sentLength > 0:
                    if (request.sentLength + len(slip_delimiter)
                            + len(encoded) > self.max_batch_down):
                        break
                    request.write(slip_delimiter)

                request.write(encoded)
                self._pq.popleft()

            assert request.sentLength
            self._rq.popleft()
            request.finish()
            # round again, looking for more to do
256
class IphttpResource(twisted.web.resource.Resource):
    """HTTP front end: each POST is one round trip for one client.

    Form arguments read from the request:
      i              the client's virtual IP address
      pw             the client's password
      mbd, mqt, mrt  optional new max_batch_down, max_queue_time and
                     max_request_time for the client (kept within
                     [0, the server's [limits]])
      d              optional SLIP-encoded packets from the client

    After the client's data is processed, the request is handed to the
    client's queue and answered later with packets for the client.
    """

    # Serve every path from this resource.  Without isLeaf, Twisted's
    # resource traversal looks for a child even for "/" and answers 404.
    isLeaf = True

    @staticmethod
    def _arg(request, name, default=None):
        """Return the first value (bytes) of form argument *name*, or default.

        Twisted's request.args maps bytes keys to lists of bytes values.
        """
        values = request.args.get(name.encode('ascii'))
        if not values:
            return default
        return values[0]

    def _client(self, request):
        """Return the authenticated Client for *request*, or None."""
        try:
            ci = ipaddr(self._arg(request, 'i', b'').decode('ascii'))
        except ValueError:
            return None
        c = clients.get(ci)
        if c is None:
            return None
        pw = self._arg(request, 'pw', b'')
        # constant-time comparison, so timing does not leak the password
        if not hmac.compare_digest(pw, c.pw.encode('utf-8')):
            return None
        return c

    def render_POST(self, request):
        # Client mistakes get an HTTP error status rather than an exception.
        # An exception escaping render is logged by Twisted as a failure
        # (critical level by default), and crash_on_critical would then stop
        # the whole server.
        c = self._client(request)
        if c is None:
            request.setResponseCode(403)
            return b''

        try:
            # update config, never beyond the server's [limits]
            for r, w in (('mbd', 'max_batch_down'),
                         ('mqt', 'max_queue_time'),
                         ('mrt', 'max_request_time')):
                v = self._arg(request, r)
                if v is None:
                    continue
                v = max(0, min(int(v), cfg.getint('limits', w)))
                setattr(c, w, v)

            c.process_arriving_data(self._arg(request, 'd', b''))
        except ValueError:
            request.setResponseCode(400)
            return b''

        # The client now owns the request and will finish it later.
        c.http_request(request)
        return NOT_DONE_YET
279
def start_http():
    """Listen for HTTP on every address in [server] addrs at [server] port.

    Each address in the whitespace-separated list gets its own TCP4 or
    TCP6 endpoint, according to its IP version.  A failure to listen is
    re-raised from the errback so that it surfaces as an unhandled error.
    """
    resource = IphttpResource()
    sitefactory = twisted.web.server.Site(resource)
    port = cfg.getint('server', 'port')
    for addrspec in cfg.get('server', 'addrs').split():
        addr = ipaddr(addrspec)
        if addr.version == 4:
            endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
        else:
            endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
        # The third positional parameter of these endpoints is the listen
        # backlog, so the bind address must be passed as interface=.
        ep = endpointfactory(reactor, port, interface=str(addr))
        d = ep.listen(sitefactory)
        d.addErrback(lambda err: err.raiseException())
293
294 #---------- config and setup ----------
295
def process_cfg():
    """Derive the module-level settings from the parsed config ``cfg``.

    Sets the globals network, host, relay, mtu and ipif_command.  Also
    fills ``clients`` from every config section whose name contains ':' or
    '.', i.e. looks like an IP address.

    Raises ValueError if the network is too small, or if a client address
    is outside the network, duplicated, or equal to the host or relay
    address.
    """
    global network
    global host
    global relay
    global ipif_command
    global mtu

    network = ipnetwork(cfg.get('virtual', 'network'))
    # Presumably two reserved addresses plus host, relay and one client.
    # num_addresses is a power of two, so "< 5" is the same as "< 2^3".
    if network.num_addresses < 3 + 2:
        raise ValueError('network needs at least 2^3 addresses')

    # Parse configured host/relay into address objects.  route() compares
    # them with packet addresses, and a plain string never compares equal.
    try:
        host = ipaddr(cfg.get('virtual', 'host'))
    except NoOptionError:
        host = next(network.hosts())

    try:
        relay = ipaddr(cfg.get('virtual', 'relay'))
    except NoOptionError:
        relay = next(h for h in network.hosts() if h != host)

    for cs in cfg.sections():
        if not (':' in cs or '.' in cs):
            continue
        ci = ipaddr(cs)
        if ci not in network:
            raise ValueError('client %s not in network' % ci)
        if ci == host or ci == relay:
            raise ValueError('client %s uses the host or relay address' % ci)
        if ci in clients:
            raise ValueError('multiple client cfg sections for %s' % ci)
        clients[ci] = Client(ci, cs)

    mtu = cfg.get('virtual', 'mtu')

    # ConfigParser stringifies these before %-interpolating [server] ipif.
    iic_vars = {'host': host, 'relay': relay, 'mtu': mtu, 'network': network}
    ipif_command = cfg.get('server', 'ipif', vars=iic_vars)
336
def crash_on_critical(event):
    """Twisted log observer: stop the reactor on critical-level events.

    startup() registers this on the global log publisher.  The formatted
    event is printed to stderr before the reactor is stopped.  Events
    without a log_level are ignored: comparing None with a LogLevel would
    raise inside the observer.
    """
    level = event.get('log_level')
    if level is None or level < LogLevel.critical:
        return
    print('crashing: ', twisted.logger.formatEvent(event), file=sys.stderr)
    #print('crashing!', file=sys.stderr)
    #os._exit(1)
    try:
        reactor.stop()
    except twisted.internet.error.ReactorNotRunning:
        # already stopping or stopped; nothing more to do
        pass
344
def startup():
    """Parse the command line and config, then start ipif and HTTP.

    Built-in defaults come from defcfg.  The file named by -c/--config
    overrides them.  Exits through OptionParser.error when there are stray
    arguments or the config file cannot be read.
    """
    global cfg
    global opts

    op = OptionParser()
    # NOTE(review): "hippottd" -- confirm this default path is intended.
    op.add_option('-c', '--config', dest='configfile',
                  default='/etc/hippottd/server.conf')
    (opts, args) = op.parse_args()
    if args:
        op.error('no non-option arguments please')

    twisted.logger.globalLogPublisher.addObserver(crash_on_critical)

    cfg = ConfigParser()
    cfg.read_string(defcfg)
    # ConfigParser.read() silently skips files it cannot open.  Without the
    # site file there is no [virtual] network, and process_cfg() would fail
    # with a confusing NoOptionError, so complain clearly here instead.
    if not cfg.read(opts.configfile):
        op.error('cannot read config file %s' % opts.configfile)
    process_cfg()

    start_ipif()
    start_http()
364
# Script entry point: parse options and config, start the ipif subprocess
# and the HTTP listeners, then run the Twisted event loop until the reactor
# is stopped (crash_on_critical stops it on critical log events).
startup()
reactor.run()