wip
[hippotat] / server
1 #!/usr/bin/python3
2
import sys
import os
import time
import hmac

import twisted
import twisted.internet
import twisted.internet.endpoints
from twisted.internet import reactor
from twisted.web.server import NOT_DONE_YET
from twisted.logger import LogLevel

import ipaddress
from ipaddress import AddressValueError

#import twisted.web.server import Site
#from twisted.web.resource import Resource

from optparse import OptionParser
from configparser import ConfigParser
from configparser import NoOptionError

import collections

import syslog
26
# Configured clients, keyed by their address on the virtual network.
# Filled in by process_cfg(); looked up when routing and serving HTTP.
clients = dict()
28
def ipaddr(input):
    """Parse *input* as an IPv4 address, falling back to IPv6.

    Returns an ipaddress.IPv4Address or ipaddress.IPv6Address.
    Raises AddressValueError if *input* is neither.
    """
    try:
        return ipaddress.IPv4Address(input)
    except AddressValueError:
        pass
    return ipaddress.IPv6Address(input)
35
def ipnetwork(input):
    """Parse *input* as an IPv4 network, falling back to IPv6.

    Returns an ipaddress.IPv4Network or ipaddress.IPv6Network.
    Only an address part that is not IPv4 triggers the IPv6 attempt; other
    IPv4 problems (bad prefix length, host bits set) propagate as the
    ValueError subclass raised by the ipaddress module.
    """
    try:
        r = ipaddress.IPv4Network(input)
    except AddressValueError:
        # ipaddress signals "this address part is not IPv4" with
        # AddressValueError (there is no NetworkValueError).
        r = ipaddress.IPv6Network(input)
    return r
42
# Built-in configuration defaults.  startup() reads this string before the
# configuration file, so settings in the file override these.
#   [DEFAULT]  per-client parameters, inherited by each client section:
#              max_batch_down (bytes per downstream HTTP response),
#              max_queue_time (seconds a packet may wait for a request),
#              max_request_time (seconds an HTTP request is held open).
#   [virtual]  the virtual network.  `network` has no default and must be
#              set in the file; `host` and `relay` are optional, and
#              process_cfg() picks addresses from the network when absent.
#   [server]   `ipif` is the command run via `sh -xc` as the SLIP interface;
#              process_cfg() fills in %(host)s, %(relay)s, %(mtu)s and
#              %(network)s.  `addrs`/`port` are where HTTP is served.
#   [limits]   upper bounds; each client's values are capped at these.
defcfg = '''
[DEFAULT]
max_batch_down = 65536
max_queue_time = 10
max_request_time = 54

[virtual]
mtu = 1500
# network
# [host]
# [relay]

[server]
ipif = userv root ipif %(host)s,%(relay)s,%(mtu)s,slip %(network)s
addrs = 127.0.0.1 ::1
port = 80

[limits]
max_batch_down = 262144
max_queue_time = 121
max_request_time = 121
'''
65
66 #---------- "router" ----------
67
def route(packet, saddr, daddr):
    """Deliver one IP packet according to its destination address.

    Order of decisions:
      * destination is a configured client -> queue it for that client;
      * either address is link-local       -> discard;
      * destination is the host, or outside the virtual network
                                           -> hand it to the ipif subprocess;
      * destination is the relay address   -> discard;
      * anything else in the network       -> discard (no such client).

    *saddr*/*daddr* are ipaddress objects (as returned by packet_addrs).
    """
    print('TRACE ', saddr, daddr, packet)
    # A plain lookup: previously the result was bound to the wrong name,
    # so delivery to an existing client failed with UnboundLocalError.
    dclient = clients.get(daddr)
    if dclient is not None:
        dclient.queue_outbound(packet)
    elif saddr.is_link_local or daddr.is_link_local:
        log_discard(packet, saddr, daddr, 'link-local')
    elif daddr == host or daddr not in network:
        print('TRACE INBOUND ', saddr, daddr, packet)
        queue_inbound(packet)
    elif daddr == relay:
        log_discard(packet, saddr, daddr, 'relay')
    else:
        log_discard(packet, saddr, daddr, 'no client')
83
def log_discard(packet, saddr, daddr, why):
    """Report that a packet from *saddr* to *daddr* is dropped, and *why*.

    Currently this only prints a line to stdout; *packet* is not shown.
    """
    # TODO: consider syslog at LOG_DEBUG instead of stdout.
    fields = ('DROP ', saddr, daddr, why)
    print(*fields)
88
89 #---------- ipif (slip subprocess) ----------
90
class IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
    """Protocol for the ipif subprocess, which speaks SLIP on stdin/stdout.

    Bytes read from the child's stdout are split into SLIP frames; each
    complete frame is decoded, its addresses parsed and the packet routed.
    """

    def __init__(self):
        # Raw (still SLIP-encoded) bytes received after the last END
        # delimiter, i.e. a possibly incomplete frame.  It must be kept
        # undecoded so it is decoded exactly once when completed.
        self._buffer = b''

    def connectionMade(self):
        pass

    def outReceived(self, data):
        self._buffer += data
        # The last piece has no terminating END yet; keep it for later.
        *frames, self._buffer = self._buffer.split(slip_end)
        for frame in frames:
            if not frame:
                continue  # consecutive delimiters
            try:
                packet = slip_decode(frame)[0]
                (saddr, daddr) = packet_addrs(packet)
            except ValueError as e:
                # One bad packet must not kill the whole ipif stream.
                print('DROP ', 'bad packet from ipif:', e)
                continue
            route(packet, saddr, daddr)

    def processEnded(self, status):
        # Surface the child's exit status as an exception.
        status.raiseException()
106
def start_ipif():
    """Spawn the ipif command under /bin/sh and remember its protocol.

    The protocol instance is stored in the global `ipif`.  The child's stdin
    is written by us, its stdout is read by us, and its stderr is inherited.
    """
    global ipif
    ipif = IpifProcessProtocol()
    argv = ['sh', '-xc', ipif_command]
    fd_map = {0: 'w', 1: 'r', 2: 2}
    reactor.spawnProcess(ipif, '/bin/sh', argv, childFDs=fd_map)
113
def queue_inbound(packet):
    """Send *packet* to the ipif subprocess as one SLIP frame.

    The frame is written as: delimiter, escaped packet, delimiter.
    """
    write = ipif.transport.write
    write(slip_delimiter)
    write(slip_encode(packet))
    write(slip_delimiter)
118
119 #---------- SLIP handling ----------
120
# SLIP framing bytes (RFC 1055): frames are separated by END; an END or ESC
# byte inside a frame is sent as ESC ESC_END or ESC ESC_ESC respectively.
slip_end = b'\300'
slip_esc = b'\333'
slip_esc_end = b'\334'
slip_esc_esc = b'\335'
slip_delimiter = slip_end

# The byte that follows ESC, mapped to the byte it stands for.
_slip_unescape = {slip_esc_end: slip_end, slip_esc_esc: slip_esc}


def slip_encode(packet):
    """Escape the END and ESC bytes in *packet*; no delimiters are added."""
    # ESC must be escaped first, otherwise the ESCs introduced for END
    # would themselves be escaped again.
    return (packet
            .replace(slip_esc, slip_esc + slip_esc_esc)
            .replace(slip_end, slip_esc + slip_esc_end))


def slip_decode(data):
    """Split SLIP *data* on END and unescape each piece.

    Returns one bytes object per END-separated piece, including empty ones:
    b'' gives [b''], and the text after the last END (possibly an incomplete
    frame) is the final element.

    Raises ValueError on an invalid or truncated escape sequence.
    """
    print('DECODE ', repr(data))
    out = []
    for frame in data.split(slip_end):
        # After splitting on ESC, every piece but the first starts with the
        # escape code byte.  Slicing (not indexing) keeps it as bytes, and a
        # truncated escape yields b'', which is rejected below.
        pieces = frame.split(slip_esc)
        pdata = [pieces[0]]
        for piece in pieces[1:]:
            try:
                pdata.append(_slip_unescape[piece[:1]])
            except KeyError:
                raise ValueError('invalid SLIP escape') from None
            pdata.append(piece[1:])
        out.append(b''.join(pdata))
    print('DECODED ', repr(out))
    return out
152
153 #---------- packet parsing ----------
154
def packet_addrs(packet):
    """Return (source, destination) addresses from a raw IP packet header.

    The IP version is taken from the high nibble of the first byte; IPv4 and
    IPv6 are supported.  Returns a pair of ipaddress objects.

    Raises ValueError for an empty or truncated packet, or any other version.
    """
    if not packet:
        raise ValueError('empty packet')
    version = packet[0] >> 4
    if version == 4:
        addrlen = 4
        saddroff = 3*4   # IPv4: source at byte 12, destination at 16
        factory = ipaddress.IPv4Address
    elif version == 6:
        addrlen = 16
        saddroff = 2*4   # IPv6: source at byte 8, destination at 24
        factory = ipaddress.IPv6Address
    else:
        raise ValueError('unsupported IP version %d' % version)
    daddroff = saddroff + addrlen
    if len(packet) < daddroff + addrlen:
        raise ValueError('truncated IPv%d packet' % version)
    saddr = factory(packet[saddroff:daddroff])
    daddr = factory(packet[daddroff:daddroff + addrlen])
    return (saddr, daddr)
170
171 #---------- client ----------
172
class Client():
    """Server-side state for one configured client.

    *ip* is the client's address on the virtual network and *cs* the name of
    its config section.  A client holds a FIFO of pending HTTP requests
    (`_rq`) and a FIFO of `(time.monotonic(), packet)` pairs waiting to go
    down to it (`_pq`).  Packets are delivered by finishing the oldest
    pending request with a SLIP-delimited batch as its body.
    """

    def __init__(self, ip, cs):
        self._ip = ip
        self._cs = cs
        self.pw = cfg.get(cs, 'password')
        self._rq = collections.deque()  # pending HTTP requests
        self._pq = collections.deque()  # (enqueue time, packet) pairs
        # Per-client parameters: the section's value (or the [DEFAULT] one),
        # capped at the [limits] value.  Sets .max_batch_down,
        # .max_queue_time and .max_request_time.
        for k in ('max_batch_down', 'max_queue_time', 'max_request_time'):
            req = cfg.getint(cs, k)
            limit = cfg.getint('limits', k)
            setattr(self, k, min(req, limit))

    def process_arriving_data(self, d):
        """Route the SLIP-framed packets in *d* sent up by this client.

        Raises ValueError if a packet's source is not the client's address
        (packets before it have already been routed).
        """
        for packet in slip_decode(d):
            if not packet:
                continue  # empty pieces from leading/trailing delimiters
            (saddr, daddr) = packet_addrs(packet)
            if saddr != self._ip:
                raise ValueError('wrong source address %s' % saddr)
            route(packet, saddr, daddr)

    def _forget(self, request):
        """Remove *request* from the pending queue if it is still there."""
        try:
            self._rq.remove(request)
        except ValueError:
            pass

    def _req_cancel(self, request):
        """Request timeout: stop waiting and finish it (possibly empty)."""
        self._forget(request)
        if not request.finished:
            request.finish()

    def _req_error(self, err, request):
        """The request failed, e.g. the connection was lost: forget it.

        It must not be finished here; finishing a request whose connection
        is gone raises RuntimeError in Twisted.
        """
        self._forget(request)

    def queue_outbound(self, packet):
        """Queue *packet* for this client and deliver it if a request waits."""
        self._pq.append((time.monotonic(), packet))
        self._check_outbound()

    def http_request(self, request):
        """Hold *request* open until there are packets or it times out."""
        request.setHeader('Content-Type', 'application/octet-stream')
        timeout = reactor.callLater(self.max_request_time,
                                    self._req_cancel, request)

        def stop_timer(result):
            # Once the request is finished or lost, the timeout is moot.
            if timeout.active():
                timeout.cancel()
            return result

        finished = request.notifyFinish()
        finished.addBoth(stop_timer)
        finished.addErrback(self._req_error, request)
        self._rq.append(request)
        self._check_outbound()

    def _check_outbound(self):
        """Answer pending requests with queued packets, as far as possible.

        Finished requests and packets older than max_queue_time are dropped.
        Each batch holds the oldest queued packets, up to max_batch_down
        bytes, except that a batch always carries at least one packet.
        """
        while True:
            while self._rq and self._rq[0].finished:
                self._rq.popleft()

            # Packets are queued in time order, so expired ones are at the head.
            now = time.monotonic()
            while self._pq and now - self._pq[0][0] > self.max_queue_time:
                self._pq.popleft()

            if not (self._rq and self._pq):
                break  # nothing to send, or no request to send it in

            request = self._rq.popleft()
            sent = 0
            while self._pq:
                (_, packet) = self._pq[0]
                encoded = slip_encode(packet)
                if sent:
                    if (sent + len(slip_delimiter) + len(encoded)
                            > self.max_batch_down):
                        break
                    request.write(slip_delimiter)
                    sent += len(slip_delimiter)
                request.write(encoded)
                sent += len(encoded)
                self._pq.popleft()
            request.finish()
            # round again, looking for more to do
256
class IphttpResource(twisted.web.resource.Resource):
    """HTTP endpoint through which clients exchange SLIP-framed packets.

    Each POST does the following:
      * identifies the client (form field `i`) and checks `pw`;
      * optionally adjusts `mbd`/`mqt`/`mrt` (capped at [limits]);
      * routes any upstream data `d`;
      * is held open until downstream packets arrive or it times out.
    """

    # Render every path here instead of looking for child resources
    # (otherwise a POST to "/" is dispatched to a NoResource).
    isLeaf = True

    @staticmethod
    def _arg(request, name):
        """Return the first value of form field *name* as bytes, or None.

        Twisted's request.args maps bytes keys to lists of bytes values.
        """
        values = request.args.get(name.encode('ascii'))
        if not values:
            return None
        return values[0]

    def render_POST(self, request):
        # find client and authenticate it
        i = self._arg(request, 'i')
        if i is None:
            raise ValueError('missing client address')
        # Decode first: bytes given to ipaddress are read as a packed address.
        ci = ipaddr(i.decode('ascii'))
        c = clients[ci]
        pw = self._arg(request, 'pw') or b''
        if not hmac.compare_digest(pw, c.pw.encode('utf-8')):
            raise ValueError('bad password')

        # update config, never beyond the server's [limits]
        for r, w in (('mbd', 'max_batch_down'),
                     ('mqt', 'max_queue_time'),
                     ('mrt', 'max_request_time')):
            v = self._arg(request, r)
            if v is None:
                continue
            setattr(c, w, min(int(v), cfg.getint('limits', w)))

        d = self._arg(request, 'd') or b''
        if d:
            c.process_arriving_data(d)
        c.http_request(request)
        # The response is completed later by the client's packet queue.
        return NOT_DONE_YET
279
def start_http():
    """Serve IphttpResource on every address in [server] addrs, at [server] port."""
    resource = IphttpResource()
    sitefactory = twisted.web.server.Site(resource)
    port = cfg.getint('server', 'port')
    for addrspec in cfg.get('server', 'addrs').split():
        addr = ipaddr(addrspec)
        if addr.version == 4:
            endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
        else:
            endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
        # The endpoints' third positional parameter is the listen backlog,
        # so the address must be given as the `interface` keyword.
        ep = endpointfactory(reactor, port, interface=str(addr))
        ep.listen(sitefactory)
292
293 #---------- config and setup ----------
294
def process_cfg():
    """Derive the server's global settings from the parsed config `cfg`.

    Sets the globals `network`, `host`, `relay`, `mtu` and `ipif_command`.
    Adds a Client to `clients` for every section whose name contains ':' or
    '.' (an address).  Raises ValueError if the network is too small, if a
    client lies outside the network, or if a client section is duplicated.
    """
    global network, host, relay, mtu, ipif_command

    network = ipnetwork(cfg.get('virtual', 'network'))
    # num_addresses is always a power of two, so "< 3 + 2" is "< 2^3".
    if network.num_addresses < 3 + 2:
        raise ValueError('network needs at least 2^3 addresses')

    # Configured addresses are parsed so they compare equal to the
    # ipaddress objects that route() receives from packet_addrs().
    try:
        host = ipaddr(cfg.get('virtual', 'host'))
    except NoOptionError:
        host = next(network.hosts())

    try:
        relay = ipaddr(cfg.get('virtual', 'relay'))
    except NoOptionError:
        for search in network.hosts():
            if search != host:
                relay = search
                break
        else:
            raise ValueError('no address left in network for relay')

    for cs in cfg.sections():
        if not (':' in cs or '.' in cs):
            continue
        ci = ipaddr(cs)
        if ci not in network:
            raise ValueError('client %s not in network' % ci)
        if ci in clients:
            raise ValueError('multiple client cfg sections for %s' % ci)
        clients[ci] = Client(ci, cs)

    mtu = cfg.get('virtual', 'mtu')

    iic_vars = {'host': host, 'relay': relay, 'mtu': mtu, 'network': network}
    ipif_command = cfg.get('server', 'ipif', vars=iic_vars)
335
def crash_on_critical(event):
    """Log observer: stop the reactor when a critical event is logged.

    The event is printed to stderr first.  Events without a `log_level`
    are ignored; comparing None with a LogLevel would raise TypeError.
    """
    level = event.get('log_level')
    if level is None or level < LogLevel.critical:
        return
    print('crashing: ', twisted.logger.formatEvent(event), file=sys.stderr)
    try:
        reactor.stop()
    except twisted.internet.error.ReactorNotRunning:
        pass
343
def startup():
    """Parse options and configuration, then start ipif and the HTTP server.

    Sets the globals `opts` and `cfg`.  Exits with a usage error if
    arguments are given or the config file cannot be read.
    """
    global cfg
    global opts

    op = OptionParser()
    # NOTE(review): 'hippottd' may be a typo of the project name; confirm
    # the intended default path before changing it.
    op.add_option('-c', '--config', dest='configfile',
                  default='/etc/hippottd/server.conf')
    (opts, args) = op.parse_args()
    if len(args):
        op.error('no non-option arguments please')

    twisted.logger.globalLogPublisher.addObserver(crash_on_critical)

    cfg = ConfigParser()
    cfg.read_string(defcfg)
    # ConfigParser.read silently skips unreadable files; fail loudly instead.
    if not cfg.read(opts.configfile):
        op.error('cannot read config file %s' % opts.configfile)
    process_cfg()

    start_ipif()
    start_http()
363
# Script entry: set everything up, then run the Twisted event loop until it
# is stopped (crash_on_critical stops it on any critical log event).
# NOTE(review): this runs on import too; an `if __name__ == '__main__':`
# guard may be wanted if the module is ever imported.
startup()
reactor.run()