wip, config reorg
[hippotat] / hippotat / __init__.py
... / ...
CommitLineData
1# -*- python -*-
2
import signal
signal.signal(signal.SIGINT, signal.SIG_DFL)

import collections
import ipaddress
import sys
import time
from configparser import ConfigParser
from configparser import NoOptionError
from ipaddress import AddressValueError
from optparse import OptionParser

import twisted
from twisted.internet import reactor
from twisted.logger import LogLevel
# Submodules referenced below as attributes of the twisted package; import
# them explicitly rather than relying on other imports having loaded them.
import twisted.internet.endpoints
import twisted.internet.error
import twisted.internet.protocol

import hippotat.slip as slip
20
# These must be module-level objects defined here so that they are picked
# up by "from hippotat import *".
cfg = ConfigParser()
optparser = OptionParser()


class ConfigResults:
    """Attribute-style container for processed configuration values.

    The instance's attribute namespace *is* the dict ``d``: attributes set
    on the object appear in ``d`` and vice versa.  When ``d`` is omitted a
    fresh empty dict is used.
    """

    def __init__(self, d=None):
        # A new dict per instance: the previous ``d = { }`` default was
        # evaluated once, so every argument-less instance shared (and
        # mutated) the same namespace.
        self.__dict__ = {} if d is None else d

    def __repr__(self):
        return 'ConfigResults(' + repr(self.__dict__) + ')'


# The global results object that the process_cfg_* functions fill in.
c = ConfigResults()
32
33#---------- packet parsing ----------
34
def packet_addrs(packet):
    """Extract the (source, destination) addresses from a raw IP packet.

    The IP version is taken from the high nibble of the first byte.  IPv4
    and IPv6 headers are supported; any other version raises ValueError.
    A header too short to hold both addresses makes the ipaddress
    constructors raise AddressValueError (a ValueError subclass).

    Returns a tuple (saddr, daddr) of ipaddress.IPv4Address or
    ipaddress.IPv6Address objects.
    """
    ver = packet[0] >> 4
    # version -> (address length, header offset of source address, class)
    layouts = {
        4: (4, 12, ipaddress.IPv4Address),
        6: (16, 8, ipaddress.IPv6Address),
    }
    layout = layouts.get(ver)
    if layout is None:
        raise ValueError('unsupported IP version %d' % ver)
    size, src_off, make_addr = layout
    dst_off = src_off + size  # destination immediately follows the source
    return (make_addr(packet[src_off:dst_off]),
            make_addr(packet[dst_off:dst_off + size]))
50
51#---------- address handling ----------
52
def ipaddr(input):
    """Parse *input* as an IPv4 address, falling back to IPv6.

    Returns an ipaddress.IPv4Address or ipaddress.IPv6Address.  If *input*
    is valid as neither, the AddressValueError raised by the IPv6 attempt
    propagates.
    """
    try:
        return ipaddress.IPv4Address(input)
    except AddressValueError:
        return ipaddress.IPv6Address(input)
59
def ipnetwork(input):
    """Parse *input* as an IPv4 network, falling back to IPv6.

    Returns an ipaddress.IPv4Network or ipaddress.IPv6Network.  Only an
    AddressValueError from the IPv4 attempt (address part is not IPv4)
    triggers the IPv6 retry; other ValueErrors from the IPv4 parse, such as
    "host bits set", propagate unchanged.
    """
    try:
        return ipaddress.IPv4Network(input)
    except AddressValueError:
        # ipaddress defines no NetworkValueError: naming it here made every
        # IPv6 network fail with NameError.  Mirror ipaddr() instead.
        return ipaddress.IPv6Network(input)
66
67#---------- ipif (SLIP) subprocess ----------
68
class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
    """Twisted process protocol for the ipif (SLIP) subprocess.

    Bytes read from the child's stdout are accumulated and split with
    slip.decode(); each complete, non-empty packet is handed to the router
    callback as router(packet, saddr, daddr).
    """

    def __init__(self, router):
        self._buffer = b''
        self._router = router

    def connectionMade(self):
        pass

    def outReceived(self, data):
        self._buffer += data
        frames = slip.decode(self._buffer)
        # The final element is retained as leftover input for the next
        # call (presumably an incomplete frame).
        self._buffer = frames.pop()
        for frame in filter(len, frames):
            src, dst = packet_addrs(frame)
            self._router(frame, src, dst)

    def processEnded(self, status):
        # Re-raise whatever ended the subprocess.
        status.raiseException()
85
def start_ipif(command, router):
    """Spawn *command* via /bin/sh and route the SLIP packets it emits.

    The protocol instance is stored in the module-global ``ipif`` so that
    queue_inbound() can write packets to the child's stdin.
    """
    global ipif
    argv = ['sh', '-xc', command]
    # child stdin: we write; child stdout: we read; child stderr: inherited
    fds = {0: 'w', 1: 'r', 2: 2}
    ipif = _IpifProcessProtocol(router)
    reactor.spawnProcess(ipif, '/bin/sh', argv, childFDs=fds)
92
def queue_inbound(packet):
    """Send one packet to the ipif subprocess as a SLIP frame.

    The encoded packet is written between two SLIP delimiters.  Requires
    start_ipif() to have been called first (it sets the global ``ipif``).
    """
    transport = ipif.transport
    transport.write(slip.delimiter)
    transport.write(slip.encode(packet))
    transport.write(slip.delimiter)
97
98#---------- packet queue ----------
99
class PacketQueue:
    """FIFO of packets that discards entries older than a time limit.

    Packets are timestamped with time.monotonic() (seconds) when appended.
    A packet whose age exceeds ``max_queue_time`` seconds is dropped the
    next time nonempty() is called.
    """

    def __init__(self, max_queue_time):
        self._max_queue_time = max_queue_time
        self._pq = collections.deque()  # (enqueue time, packet) pairs

    def append(self, packet):
        """Enqueue *packet*, stamped with the current monotonic time."""
        self._pq.append((time.monotonic(), packet))

    def nonempty(self):
        """Drop expired packets from the front; return True if any remain."""
        now = time.monotonic()
        while self._pq:
            queuetime, _packet = self._pq[0]
            # Previously read self.max_queue_time, which does not exist.
            if now - queuetime <= self._max_queue_time:
                return True
            self._pq.popleft()  # too old: strip it off the front
        return False

    def popleft(self):
        """Remove and return the oldest packet, or None if the queue is empty.

        Expiry is not checked here; call nonempty() first to discard stale
        packets.
        """
        try:
            # Actually remove the entry (the old code only peeked at [0],
            # so repeated calls returned the same packet forever).
            _queuetime, packet = self._pq.popleft()
        except IndexError:
            return None
        return packet
126
127#---------- error handling ----------
128
def crash(err):
    """Report a fatal error on stderr and stop the reactor.

    Stopping a reactor that is not running is tolerated silently.
    """
    print('CRASH ', err, file=sys.stderr)
    try:
        reactor.stop()
    except twisted.internet.error.ReactorNotRunning:
        pass  # already stopped (or never started): nothing more to do
133
def crash_on_defer(defer):
    """Arrange for any failure of the Deferred *defer* to call crash()."""
    def _on_error(failure):
        # crash() is looked up when the errback fires, as before.
        return crash(failure)
    defer.addErrback(_on_error)
136
def crash_on_critical(event):
    """twisted.logger observer: crash() on any event at critical level or above.

    Events lacking a ``log_level`` key are ignored; previously they made
    the comparison ``None >= LogLevel.critical`` raise TypeError inside
    the observer.
    """
    level = event.get('log_level')
    if level is None:
        return
    if level >= LogLevel.critical:
        crash(twisted.logger.formatEvent(event))
140
141#---------- config processing ----------
142
def process_cfg_common_always():
    """Store the [virtual] mtu setting in c.mtu (as the raw config string)."""
    c.mtu = cfg.get('virtual', 'mtu')
146
def process_cfg_ipif(section, varmap):
    """Compute c.ipif_command from the 'ipif' option of config *section*.

    *varmap* is an iterable of (dest, src) attribute-name pairs.  For each
    pair where c has attribute *src*, its value is copied to c.<dest> so
    that it is available for interpolation in the 'ipif' option.  Pairs
    whose *src* attribute is absent are skipped.
    """
    for dest, src in varmap:
        try:
            value = getattr(c, src)
        except AttributeError:
            # getattr reports a missing attribute with AttributeError; the
            # old "except KeyError: pass" never matched, and had it matched
            # it would have stored a stale/unbound value.
            continue
        setattr(c, dest, value)

    print(repr(c))  # NOTE(review): debug output; consider removing

    c.ipif_command = cfg.get(section, 'ipif', vars=c.__dict__)
156
def process_cfg_network():
    """Parse [virtual] network into c.network and check it is large enough.

    Raises ValueError if the network has fewer than 3 + 2 addresses
    (presumably three usable hosts plus network and broadcast).  Since a
    network's size is always a power of two, this is equivalent to
    requiring at least 2^3 = 8 addresses, as the message says.
    """
    network = ipnetwork(cfg.get('virtual', 'network'))
    c.network = network
    minimum = 3 + 2
    if network.num_addresses < minimum:
        raise ValueError('network needs at least 2^3 addresses')
161
def process_cfg_server():
    """Set c.server to the server's virtual address.

    Uses [virtual] server when present; otherwise parses [virtual] network
    (via process_cfg_network) and takes its first host address.  In both
    cases c.server is an ipaddress address object.  Previously the
    configured value was left as a str, so it never compared equal to
    parsed addresses.
    """
    try:
        spec = cfg.get('virtual', 'server')
    except NoOptionError:
        process_cfg_network()
        c.server = next(c.network.hosts())
    else:
        c.server = ipaddr(spec)
168
class ServerAddr:
    """An address (IPv4 or IPv6) and TCP port on which the server listens.

    Attributes: ``port`` (formatted with %d in url()) and ``addr`` (an
    ipaddress.IPv4Address or IPv6Address parsed from *addrspec*).
    """

    def __init__(self, port, addrspec):
        self.port = port
        try:
            self.addr = ipaddress.IPv4Address(addrspec)
        except AddressValueError:
            self.addr = ipaddress.IPv6Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
            self._inurl = '[%s]'  # IPv6 literals are bracketed in URLs
        else:
            self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
            self._inurl = '%s'

    def make_endpoint(self):
        """Return a Twisted TCP server endpoint bound to this address/port."""
        # The third positional parameter of TCP4ServerEndpoint and
        # TCP6ServerEndpoint is the listen backlog, not the interface, so
        # the address must be passed by keyword, as a string.
        return self._endpointfactory(reactor, self.port,
                                     interface=str(self.addr))

    def url(self):
        """Return the http:// URL for this address; port 80 is implicit."""
        url = 'http://' + (self._inurl % self.addr)
        if self.port != 80:
            url += ':%d' % self.port
        return url + '/'
188
def process_cfg_saddrs():
    """Build c.saddrs: one ServerAddr per entry of [server] addrs.

    [server] addrs is a whitespace-separated list of addresses; all of them
    share the integer [server] port.
    """
    port = cfg.getint('server', 'port')
    saddrs = c.saddrs = []
    for spec in cfg.get('server', 'addrs').split():
        saddrs.append(ServerAddr(port, spec))
196
def process_cfg_clients(constructor):
    """Call constructor(addr, section, password) for each client section.

    A client section is any config section whose name contains ':' or '.'
    (i.e. looks like an IP address); its name is parsed with ipaddr() and
    its 'password' option is read.  c.clients is reset to an empty list
    first; this function does not append to it itself.
    """
    c.clients = []
    for section in cfg.sections():
        if ':' not in section and '.' not in section:
            continue
        constructor(ipaddr(section), section, cfg.get(section, 'password'))
204
205#---------- startup ----------
206
def common_startup(defcfg):
    """Shared start-up: install the crash-on-critical log observer, parse
    the command line, and load configuration.

    *defcfg* is a string of default configuration.  It is read before the
    file named by -c/--config (default /etc/hippotat/config), so the file's
    settings override the defaults.  Any positional argument is a usage
    error.

    NOTE(review): ConfigParser.read() silently skips a missing or
    unreadable config file.
    """
    twisted.logger.globalLogPublisher.addObserver(crash_on_critical)

    optparser.add_option('-c', '--config', dest='configfile',
                         default='/etc/hippotat/config')
    opts, args = optparser.parse_args()
    if args:
        optparser.error('no non-option arguments please')

    cfg.read_string(defcfg)
    cfg.read(opts.configfile)
217
def common_run():
    """Run the Twisted reactor and report on stderr once it returns."""
    reactor.run()
    # reactor.run() returns only after the reactor has been stopped
    # (e.g. by crash()), which is reported as a crash.
    print('CRASHED (end)', file=sys.stderr)