1baaed33c68703758d160ee2d32f9e1b0059b471
[hippotat] / hippotat / __init__.py
1 # -*- python -*-
2
import signal
# Restore the default SIGINT disposition (terminate the process) instead of
# Python's KeyboardInterrupt translation; done before the other imports.
signal.signal(signal.SIGINT, signal.SIG_DFL)

import collections
import ipaddress
import sys
import time
from configparser import ConfigParser
from configparser import NoOptionError
from ipaddress import AddressValueError
from optparse import OptionParser

import twisted
from twisted.internet import reactor
from twisted.logger import LogLevel
import twisted.internet.endpoints
import twisted.internet.error
import twisted.internet.protocol

import hippotat.slip as slip
21
# These module-level objects are created here, at import time, so that
# modules doing "from hippotat import *" get these very instances.
cfg = ConfigParser()
optparser = OptionParser()

class ConfigResults:
    """Attribute bag holding processed configuration values.

    The mapping *d* (if given) becomes the instance ``__dict__`` itself, so
    attributes set on the instance are stored in it, and ``repr`` shows
    every value set so far.
    """

    def __init__(self, d=None):
        # A fresh dict per instance: a literal ``{ }`` default would be
        # evaluated once and shared by every ConfigResults() created
        # without an argument.
        self.__dict__ = {} if d is None else d

    def __repr__(self):
        return 'ConfigResults(' + repr(self.__dict__) + ')'

# The global configuration results object, filled in by the
# process_cfg_* functions.
c = ConfigResults()
33
34 #---------- packet parsing ----------
35
def packet_addrs(packet):
    """Return the (source, destination) addresses of raw IP packet *packet*.

    The IP version is the top nibble of the first byte.  For IPv4 the two
    4-byte addresses start at offset 12; for IPv6 the two 16-byte addresses
    start at offset 8.  Any other version raises ValueError.
    """
    ver = packet[0] >> 4
    if ver == 4:
        make, width, start = ipaddress.IPv4Address, 4, 12
    elif ver == 6:
        make, width, start = ipaddress.IPv6Address, 16, 8
    else:
        raise ValueError('unsupported IP version %d' % ver)
    src_end = start + width
    source = make(packet[start:src_end])
    dest = make(packet[src_end:src_end + width])
    return (source, dest)
51
52 #---------- address handling ----------
53
def ipaddr(input):
    """Parse *input* as an IPv4 address, falling back to IPv6.

    Returns an ipaddress.IPv4Address or IPv6Address.  If *input* is
    neither, the AddressValueError from the IPv6 attempt propagates.
    """
    try:
        return ipaddress.IPv4Address(input)
    except AddressValueError:
        return ipaddress.IPv6Address(input)
60
def ipnetwork(input):
    """Parse *input* as an IPv4 network, falling back to IPv6.

    Returns an ipaddress.IPv4Network or IPv6Network.  Only an unparseable
    address part (AddressValueError) triggers the IPv6 attempt; any other
    IPv4 parse error (bad prefix length, host bits set) propagates.
    """
    try:
        r = ipaddress.IPv4Network(input)
    except AddressValueError:
        # IPv4Network raises AddressValueError for a non-IPv4 address part.
        # (There is no "NetworkValueError" in the ipaddress module.)
        r = ipaddress.IPv6Network(input)
    return r
67
68 #---------- ipif (SLIP) subprocess ----------
69
class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
    """ProcessProtocol for the ipif subprocess.

    Data read from the child's stdout is accumulated and split with
    slip.decode; each non-empty packet is passed to
    router(packet, saddr, daddr).
    """

    def __init__(self, router):
        self._router = router
        self._buffer = b''

    def connectionMade(self):
        pass

    def outReceived(self, data):
        self._buffer += data
        pieces = slip.decode(self._buffer)
        # The last element is kept as the buffer for the next read
        # (presumably the trailing, not-yet-delimited data — see slip.decode).
        self._buffer = pieces.pop()
        for pkt in pieces:
            if len(pkt) == 0:
                continue
            saddr, daddr = packet_addrs(pkt)
            self._router(pkt, saddr, daddr)

    def processEnded(self, status):
        # Re-raise whatever ended the child process as an exception.
        status.raiseException()
86
def start_ipif(command, router):
    """Spawn `sh -xc command` as the ipif subprocess.

    The protocol object is stored in the module-global `ipif`.  The child's
    fd 0 is written by us, fd 1 is read by the protocol, and fd 2 is
    mapped to our own stderr.  Decoded packets are passed to *router*.
    """
    global ipif
    argv = ['sh', '-xc', command]
    fds = {0: 'w', 1: 'r', 2: 2}
    ipif = _IpifProcessProtocol(router)
    reactor.spawnProcess(ipif, '/bin/sh', argv, childFDs=fds)
93
def queue_inbound(packet):
    """Write *packet* to the ipif subprocess, SLIP-encoded.

    A SLIP delimiter is written both before and after the encoded packet.
    """
    write = ipif.transport.write
    write(slip.delimiter)
    write(slip.encode(packet))
    write(slip.delimiter)
98
99 #---------- packet queue ----------
100
class PacketQueue():
    """FIFO of packets in which entries older than max_queue_time expire.

    Each packet is stored with the time.monotonic() timestamp at which it
    was appended.  Expired packets are discarded lazily, from the front, by
    nonempty().
    """

    def __init__(self, max_queue_time):
        # Maximum age in seconds (time.monotonic() units) a queued packet
        # may reach before nonempty() discards it.
        self._max_queue_time = max_queue_time
        self._pq = collections.deque()  # (queuetime, packet) pairs

    def append(self, packet):
        """Queue *packet*, stamped with the current monotonic time."""
        self._pq.append((time.monotonic(), packet))

    def nonempty(self):
        """Drop expired packets from the front; return True if any remain."""
        while self._pq:
            queuetime = self._pq[0][0]
            if time.monotonic() - queuetime <= self._max_queue_time:
                return True
            # Too old: strip it off the front and look at the next one.
            self._pq.popleft()
        return False

    def popleft(self):
        """Remove and return the oldest packet, or None if the queue is empty.

        Callers must call nonempty() first so that expired packets have
        already been discarded.
        """
        try:
            (_queuetime, packet) = self._pq.popleft()
        except IndexError:
            return None
        return packet
127
128 #---------- error handling ----------
129
def crash(err):
    """Report *err* on stderr, then stop the reactor if it is running."""
    stream = sys.stderr
    print('CRASH ', err, file=stream)
    try:
        reactor.stop()
    except twisted.internet.error.ReactorNotRunning:
        # Reactor not running (not started yet, or already stopped).
        pass
134
def crash_on_defer(defer):
    """Register an errback on *defer* so that any failure calls crash()."""
    def _on_failure(failure):
        return crash(failure)
    defer.addErrback(_on_failure)
137
def crash_on_critical(event):
    """Log observer: crash on any event logged at critical level or above.

    Events that carry no 'log_level' are ignored; comparing None against a
    LogLevel would raise TypeError inside the observer.
    """
    level = event.get('log_level')
    if level is None:
        return
    if level >= LogLevel.critical:
        crash(twisted.logger.formatEvent(event))
141
142 #---------- config processing ----------
143
def process_cfg_common_always():
    """Set c.mtu from the 'mtu' option of the [virtual] config section.

    The value is stored as the raw configuration string.
    """
    c.mtu = cfg.get('virtual', 'mtu')
147
def process_cfg_ipif(section, varmap):
    """Compute c.ipif_command from the 'ipif' option of *section*.

    *varmap* is a sequence of (dest, src) attribute-name pairs.  For each
    pair where c has an attribute *src*, its value is copied to c.<dest>.
    The option is then read with all of c's attributes supplied as
    interpolation variables.
    """
    missing = object()
    for dest, src in varmap:
        value = getattr(c, src, missing)
        if value is not missing:
            setattr(c, dest, value)

    print(repr(c))

    c.ipif_command = cfg.get(section, 'ipif', vars=c.__dict__)
157
def process_cfg_network():
    """Set c.network from the 'network' option of the [virtual] section.

    Raises ValueError if the network has fewer than 2**3 = 8 addresses
    (i.e. is smaller than an IPv4 /29 or an IPv6 /125).
    """
    c.network = ipnetwork(cfg.get('virtual', 'network'))
    # num_addresses is always a power of two, so this is the only threshold
    # that matches the error message.
    if c.network.num_addresses < 2**3:
        raise ValueError('network needs at least 2^3 addresses')
162
def process_cfg_server():
    """Set c.server from the 'server' option of the [virtual] section.

    If that option is absent, process_cfg_network() is run (setting
    c.network) and the network's first host address is used instead.
    NOTE(review): a configured value is kept as a str, while the default
    is an ipaddress object — callers must cope with either.
    """
    try:
        chosen = cfg.get('virtual', 'server')
    except NoOptionError:
        process_cfg_network()
        chosen = next(c.network.hosts())
    c.server = chosen
169
class ServerAddr():
    """One address/port pair on which the server listens.

    Attributes:
      port -- TCP port number
      addr -- ipaddress.IPv4Address or IPv6Address parsed from addrspec
    """

    def __init__(self, port, addrspec):
        self.port = port
        try:
            self.addr = ipaddress.IPv4Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
            self._inurl = '%s'
        except AddressValueError:
            self.addr = ipaddress.IPv6Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
            # IPv6 literals must be bracketed inside URLs.
            self._inurl = '[%s]'

    def make_endpoint(self):
        """Return a Twisted TCP server endpoint bound to addr and port."""
        # The third positional parameter of TCP{4,6}ServerEndpoint is
        # `backlog`; the listen address must be passed as the `interface`
        # keyword, as a string, or the endpoint is not bound to it.
        return self._endpointfactory(reactor, self.port,
                                     interface=str(self.addr))

    def url(self):
        """Return the http:// URL for this address (port omitted if 80)."""
        url = 'http://' + (self._inurl % self.addr)
        if self.port != 80:
            url += ':%d' % self.port
        url += '/'
        return url
189
def process_cfg_saddrs():
    """Build c.saddrs: one ServerAddr per entry of [server] 'addrs'.

    Every address uses the [server] 'port' option, or 80 if it is unset.
    """
    try:
        port = cfg.getint('server', 'port')
    except NoOptionError:
        port = 80

    c.saddrs = []
    c.saddrs.extend(ServerAddr(port, spec)
                    for spec in cfg.get('server', 'addrs').split())
198
def process_cfg_clients(constructor):
    """Call constructor(addr, section, password) for each client section.

    A section counts as a client if its name contains ':' or '.'; addr is
    that name parsed by ipaddr() and password is its 'password' option.
    c.clients is reset to an empty list first.  The constructor's return
    value is discarded (presumably it registers the client itself —
    not visible here).
    """
    c.clients = []
    for section in cfg.sections():
        if ':' not in section and '.' not in section:
            continue
        constructor(ipaddr(section), section, cfg.get(section, 'password'))
206
207 #---------- startup ----------
208
def common_startup(defcfg):
    """Shared start-up: install the crash-on-critical log observer, parse
    the command line and load the configuration.

    *defcfg* is a string of default configuration, read before the file
    named by -c/--config (default /etc/hippotat/config), so settings in
    the file override it.  Non-option arguments are rejected.
    """
    twisted.logger.globalLogPublisher.addObserver(crash_on_critical)

    optparser.add_option('-c', '--config', dest='configfile',
                         default='/etc/hippotat/config')
    opts, args = optparser.parse_args()
    if args:
        optparser.error('no non-option arguments please')

    cfg.read_string(defcfg)
    # ConfigParser.read skips files it cannot open, so a missing config
    # file is not reported here.
    cfg.read(opts.configfile)
219
def common_run():
    """Run the Twisted reactor; once it returns, report on stderr.

    reactor.run() returns only after the reactor has been stopped; in the
    code visible here, that happens only in crash().
    """
    reactor.run()
    stream = sys.stderr
    print('CRASHED (end)', file=stream)