6f0d3e87d8a2ba2dcf5f7d4e1f7ea154b716f22e
[hippotat] / hippotat / __init__.py
1 # -*- python -*-
2
import signal
# Default SIGINT disposition: Ctrl-C terminates the process outright.
signal.signal(signal.SIGINT, signal.SIG_DFL)

import collections
import ipaddress
import sys
import time
from configparser import ConfigParser
from configparser import NoOptionError
from ipaddress import AddressValueError
from optparse import OptionParser

import twisted
from twisted.internet import reactor
from twisted.logger import LogLevel

import hippotat.slip as slip
20
# The shared ConfigParser and OptionParser are created at import time, as
# module-level names, so that `import *` of this module hands callers these
# very objects.
cfg, optparser = ConfigParser(), OptionParser()
24
class ConfigResults:
    """Plain attribute bag for processed configuration values.

    The attributes are stored directly in the instance ``__dict__``, which is
    what gets passed to ConfigParser as interpolation ``vars`` elsewhere in
    this module.
    """

    def __init__(self, d=None):
        # Each instance needs its own dict: a `{}` default would be a single
        # dict shared by every ConfigResults() created without an argument.
        # A dict supplied by the caller is adopted as-is (not copied).
        self.__dict__ = {} if d is None else d

    def __repr__(self):
        return 'ConfigResults(' + repr(self.__dict__) + ')'


c = ConfigResults()
32
33 #---------- packet parsing ----------
34
def packet_addrs(packet):
    """Return ``(source, destination)`` addresses of a raw IP packet.

    The IP version is taken from the high nibble of the first byte. For IPv4
    the addresses are the 4-byte fields at offsets 12 and 16. For IPv6 they
    are the 16-byte fields at offsets 8 and 24.

    Raises ValueError for any other IP version. Also raises ValueError (via
    ipaddress) if the address bytes are truncated.
    """
    ip_version = packet[0] >> 4
    if ip_version == 4:
        width, src_off, make_addr = 4, 3*4, ipaddress.IPv4Address
    elif ip_version == 6:
        width, src_off, make_addr = 16, 2*4, ipaddress.IPv6Address
    else:
        raise ValueError('unsupported IP version %d' % ip_version)
    # The destination field immediately follows the source field.
    dst_off = src_off + width
    return (make_addr(packet[src_off:dst_off]),
            make_addr(packet[dst_off:dst_off + width]))
50
51 #---------- address handling ----------
52
def ipaddr(input):
    """Parse *input* as an IPv4 address, falling back to IPv6.

    Returns an ipaddress.IPv4Address or ipaddress.IPv6Address. If *input* is
    neither, the AddressValueError from the IPv6 attempt propagates.
    """
    try:
        return ipaddress.IPv4Address(input)
    except AddressValueError:
        # Not valid IPv4; let IPv6 parsing decide (and raise if it fails).
        return ipaddress.IPv6Address(input)
59
def ipnetwork(input):
    """Parse *input* as an IPv4 network, falling back to IPv6.

    Returns an ipaddress.IPv4Network or ipaddress.IPv6Network. Only an
    AddressValueError (the address part is not IPv4 syntax) triggers the IPv6
    attempt. Other IPv4 errors propagate unchanged, e.g. NetmaskValueError for
    a bad prefix length or ValueError for host bits set.
    """
    try:
        return ipaddress.IPv4Network(input)
    except AddressValueError:
        # ipaddress reports "not an IPv4 address" with AddressValueError;
        # there is no NetworkValueError in the module.
        return ipaddress.IPv6Network(input)
66
67 #---------- ipif (SLIP) subprocess ----------
68
class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
    """Process protocol for the ipif child, which speaks SLIP on its stdio.

    Bytes read from the child's stdout are SLIP-decoded. Each non-empty
    packet is passed to ``router(packet, saddr, daddr)``.
    """

    def __init__(self, router):
        self._router = router
        # Undecoded leftover carried between reads.
        self._buffer = b''

    def connectionMade(self):
        pass

    def outReceived(self, data):
        self._buffer += data
        frames = slip.decode(self._buffer)
        # The last element of slip.decode's result is kept as the carry-over
        # buffer (presumably the trailing, not-yet-terminated frame).
        self._buffer = frames.pop()
        for packet in filter(len, frames):
            saddr, daddr = packet_addrs(packet)
            self._router(packet, saddr, daddr)

    def processEnded(self, status):
        # Surface the child's exit reason as an exception, not silently.
        status.raiseException()
85
def start_ipif(command, router):
    """Spawn the ipif helper and store its protocol in the global ``ipif``.

    *command* is run as ``/bin/sh -xc command``. The child's fd 0 is a pipe
    we write to and its fd 1 is a pipe we read from. Its fd 2 is our own
    fd 2. Packets read back are handed to *router*.
    """
    global ipif
    ipif = _IpifProcessProtocol(router)
    argv = ['sh', '-xc', command]
    child_fds = {0: 'w', 1: 'r', 2: 2}
    reactor.spawnProcess(ipif, '/bin/sh', argv, childFDs=child_fds)
92
def queue_inbound(packet):
    """Write *packet* to the ipif child as a SLIP frame.

    The encoded packet is bracketed by a SLIP delimiter on each side.
    """
    send = ipif.transport.write
    send(slip.delimiter)
    send(slip.encode(packet))
    send(slip.delimiter)
97
98 #---------- packet queue ----------
99
class PacketQueue:
    """FIFO of packets that drops entries older than *max_queue_time*.

    Ages are measured with time.monotonic(), so *max_queue_time* is in
    seconds. Expired packets are discarded lazily, by nonempty().
    """

    def __init__(self, max_queue_time):
        self._max_queue_time = max_queue_time
        self._pq = collections.deque()  # (enqueue time, packet), oldest first

    def append(self, packet):
        """Enqueue *packet*, stamped with the current monotonic time."""
        self._pq.append((time.monotonic(), packet))

    def nonempty(self):
        """Drop expired packets from the front; return True if any remain."""
        now = time.monotonic()
        while self._pq:
            queuetime, _ = self._pq[0]
            if now - queuetime <= self._max_queue_time:
                return True
            # The oldest packet waited too long: discard it and look again.
            self._pq.popleft()
        return False

    def popleft(self):
        """Remove and return the oldest packet, or None if the queue is empty.

        No expiry check is done here; callers should call nonempty() first.
        """
        try:
            (dummy, packet) = self._pq.popleft()
        except IndexError:
            return None
        return packet
126
127 #---------- error handling ----------
128
def crash(err):
    """Report *err* on stderr and stop the reactor if it is running."""
    print('CRASH ', err, file=sys.stderr)
    try:
        reactor.stop()
    except twisted.internet.error.ReactorNotRunning:
        # Reactor not started or already stopped: nothing further to do.
        pass
133
def crash_on_defer(defer):
    """Make any failure of the Deferred *defer* fatal by passing it to crash()."""
    def _fatal(failure):
        return crash(failure)
    defer.addErrback(_fatal)
136
def crash_on_critical(event):
    """Log observer that calls crash() for events at LogLevel.critical or above.

    *event* is a twisted.logger event dict. An event with no ``log_level``
    key is ignored. Comparing None with a LogLevel constant would raise
    TypeError inside the log publisher.
    """
    level = event.get('log_level')
    if level is not None and level >= LogLevel.critical:
        crash(twisted.logger.formatEvent(event))
140
141 #---------- config processing ----------
142
def process_cfg_common_always():
    """Read settings every mode needs: currently ``[virtual] mtu``.

    The value is stored on ``c`` (not in a module global) as the string
    returned by ConfigParser.get.
    """
    # NOTE(review): mtu is kept as a str here; confirm consumers convert it.
    c.mtu = cfg.get('virtual', 'mtu')
146
def process_cfg_ipif(section, varmap):
    """Derive ipif settings on ``c`` and expand the ipif command line.

    varmap: iterable of ``(dest, src)`` attribute-name pairs. Each existing
    ``c.<src>`` is copied to ``c.<dest>``; pairs whose source attribute does
    not exist are skipped.

    The ``ipif`` option of *section* is then read, with ``c``'s attributes
    supplied as interpolation vars, and stored as ``c.ipif_command``.
    """
    for d, s in varmap:
        try:
            v = getattr(c, s)
        except AttributeError:
            # getattr reports a missing attribute with AttributeError; skip
            # this pair rather than setting a value left over from before.
            continue
        setattr(c, d, v)

    # NOTE(review): looks like leftover debug output to stdout.
    print(repr(c))

    c.ipif_command = cfg.get(section, 'ipif', vars=c.__dict__)
156
def process_cfg_network():
    """Parse ``[virtual] network`` into ``c.network`` and check its size.

    Raises ValueError if the network holds fewer than 8 addresses.
    """
    c.network = ipnetwork(cfg.get('virtual', 'network'))
    # num_addresses is always a power of two for an ipaddress network, so
    # this rejects exactly the networks with 1, 2 or 4 addresses.
    if c.network.num_addresses < 2 ** 3:
        raise ValueError('network needs at least 2^3 addresses')
161
def process_cfg_server():
    """Set ``c.server`` from ``[virtual] server``.

    If that option is absent, the virtual network is processed first. The
    first address yielded by its hosts() then becomes the server address.
    """
    try:
        server = cfg.get('virtual', 'server')
    except NoOptionError:
        process_cfg_network()
        server = next(c.network.hosts())
    # NOTE(review): this is a str when configured explicitly but an
    # IPv4Address/IPv6Address when derived -- confirm callers accept both.
    c.server = server
168
class ServerAddr:
    """One address/port pair the server listens on.

    Attributes:
        port: TCP port number.
        addr: listening address, an ipaddress.IPv4Address or IPv6Address.
    """

    def __init__(self, port, addrspec):
        # Load the endpoints submodule before it is used as an attribute
        # below; importing the parent package alone does not guarantee it.
        import twisted.internet.endpoints
        self.port = port
        try:
            self.addr = ipaddress.IPv4Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
            self._inurl = '%s'
        except AddressValueError:
            self.addr = ipaddress.IPv6Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
            self._inurl = '[%s]'  # IPv6 literals are bracketed in URLs

    def make_endpoint(self):
        """Return a Twisted TCP server endpoint for this port and address."""
        # The third positional parameter of TCP4ServerEndpoint and
        # TCP6ServerEndpoint is `backlog`; the address belongs in `interface`,
        # which takes a string.
        return self._endpointfactory(reactor, self.port,
                                     interface=str(self.addr))

    def url(self):
        """Return ``http://ADDR[:PORT]/``, omitting the port when it is 80."""
        url = 'http://' + (self._inurl % self.addr)
        if self.port != 80:
            url += ':%d' % self.port
        url += '/'
        return url
188
def process_cfg_saddrs():
    """Build ``c.saddrs``: one ServerAddr per entry of ``[server] addrs``.

    ``addrs`` is split on whitespace. Every address shares the single
    ``[server] port``.
    """
    port = cfg.getint('server', 'port')
    c.saddrs = listeners = []
    for spec in cfg.get('server', 'addrs').split():
        listeners.append(ServerAddr(port, spec))
196
def process_cfg_clients(constructor):
    """Call ``constructor(addr, section, password)`` for each client section.

    A section is treated as a client if its name contains ':' or '.'. This is
    a cheap IPv6/IPv4 test; ipaddr() then parses the name and raises if it is
    not an address.
    """
    c.clients = []
    # NOTE(review): c.clients is reset here but not filled by this function;
    # presumably constructor appends to it -- confirm against callers.
    for section in (s for s in cfg.sections() if ':' in s or '.' in s):
        constructor(ipaddr(section), section, cfg.get(section, 'password'))
204
205 #---------- startup ----------
206
def common_startup(defcfg):
    """Common start-up: crash-on-critical logging, option parsing, config.

    *defcfg* is a config string read before the file named by
    ``-c/--config`` (default ``/etc/hippotat/config``), so values in that file
    take precedence. Positional arguments are rejected via optparser.error().
    """
    twisted.logger.globalLogPublisher.addObserver(crash_on_critical)

    optparser.add_option('-c', '--config', dest='configfile',
                         default='/etc/hippotat/config')
    opts, args = optparser.parse_args()
    if args:
        optparser.error('no non-option arguments please')

    cfg.read_string(defcfg)
    # NOTE(review): ConfigParser.read() silently ignores files it cannot
    # open, so a missing config file is not reported here.
    cfg.read(opts.configfile)
217
def common_run():
    """Run the Twisted reactor until it stops, then report that on stderr."""
    reactor.run()
    # reactor.run() only returns once the reactor has stopped (e.g. crash()).
    print('CRASHED (end)', file=sys.stderr)