wip, config reorg
[hippotat] / hippotat / __init__.py
CommitLineData
b0cfbfce
IJ
1# -*- python -*-
2
37ab4cdc
IJ
import signal
# Restore the default SIGINT action (terminate) instead of Python's
# KeyboardInterrupt handler; done first, before anything else is set up.
signal.signal(signal.SIGINT, signal.SIG_DFL)

import collections
import ipaddress
import sys
import time
from configparser import ConfigParser
from configparser import NoOptionError
from ipaddress import AddressValueError
from optparse import OptionParser

import twisted
from twisted.internet import reactor
from twisted.logger import LogLevel

import hippotat.slip as slip
20
# these need to be defined here so that they can be imported by import *
cfg = ConfigParser()
optparser = OptionParser()


class ConfigResults:
    """Attribute bag holding processed configuration values.

    The instance's attribute namespace *is* the dict passed in (it is used
    directly, not copied), so ``ConfigResults(d).x`` reads ``d['x']`` and
    assignments write through to ``d``.  With no argument, a fresh empty
    dict is used.
    """

    def __init__(self, d=None):
        # A literal ``{}`` default is evaluated once and would be shared by
        # every ConfigResults() constructed without an argument.
        self.__dict__ = {} if d is None else d

    def __repr__(self):
        return 'ConfigResults(' + repr(self.__dict__) + ')'


c = ConfigResults()
32
b0cfbfce
IJ
33#---------- packet parsing ----------
34
def packet_addrs(packet):
    """Return ``(source, destination)`` addresses from a raw IP packet.

    The IP version is read from the high nibble of the first byte; the
    addresses are returned as ipaddress.IPv4Address / IPv6Address objects.
    Raises ValueError for any version other than 4 or 6.
    """
    # version -> (address class, byte offset of source address, address length)
    layouts = {
        4: (ipaddress.IPv4Address, 12, 4),
        6: (ipaddress.IPv6Address, 8, 16),
    }
    version = packet[0] >> 4
    try:
        factory, start, width = layouts[version]
    except KeyError:
        raise ValueError('unsupported IP version %d' % version) from None
    middle = start + width  # destination address follows the source directly
    return (factory(packet[start:middle]),
            factory(packet[middle:middle + width]))
50
51#---------- address handling ----------
52
def ipaddr(input):
    """Parse *input* as an IPv4 address, falling back to IPv6.

    Raises AddressValueError if it is neither.
    """
    try:
        return ipaddress.IPv4Address(input)
    except AddressValueError:
        pass
    return ipaddress.IPv6Address(input)
59
def ipnetwork(input):
    """Parse *input* as an IPv4 network, falling back to IPv6.

    Only AddressValueError (the address part is not IPv4 at all) triggers
    the IPv6 attempt; other errors, such as NetmaskValueError for an invalid
    IPv4 prefix length, propagate unchanged.
    """
    try:
        r = ipaddress.IPv4Network(input)
    except AddressValueError:
        # ipaddress has no NetworkValueError; a non-IPv4 string makes
        # IPv4Network raise AddressValueError.
        r = ipaddress.IPv6Network(input)
    return r
040ff511
IJ
66
67#---------- ipif (SLIP) subprocess ----------
68
class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
    """Process protocol for the ipif (SLIP) subprocess.

    Bytes read from the child's stdout are accumulated and split with
    slip.decode; every non-empty frame is handed to the router callback as
    ``router(packet, saddr, daddr)``.  When the process ends, its exit
    status is re-raised as an exception.
    """

    def __init__(self, router):
        self._router = router
        self._buffer = b''  # bytes received but not yet returned as frames

    def connectionMade(self):
        pass

    def outReceived(self, data):
        self._buffer = self._buffer + data
        frames = slip.decode(self._buffer)
        # presumably the final element is the incomplete tail still awaiting
        # a delimiter -- keep it for the next call
        self._buffer = frames.pop()
        for frame in frames:
            if len(frame) == 0:
                continue
            saddr, daddr = packet_addrs(frame)
            self._router(frame, saddr, daddr)

    def processEnded(self, status):
        status.raiseException()
85
def start_ipif(command, router):
    """Spawn *command* under /bin/sh as the ipif subprocess.

    The protocol instance is kept in the module global ``ipif`` (used by
    queue_inbound to write to the child); decoded packets from the child
    are passed to *router*.
    """
    global ipif
    ipif = _IpifProcessProtocol(router)
    argv = ['sh', '-xc', command]  # -x: the shell traces commands to stderr
    # child stdin: written by us; child stdout: read by us; stderr: our fd 2
    fdmap = {0: 'w', 1: 'r', 2: 2}
    reactor.spawnProcess(ipif, '/bin/sh', argv, childFDs=fdmap)
92
def queue_inbound(packet):
    """Write *packet* to the ipif subprocess's stdin as one SLIP frame.

    The frame is sent as: delimiter, slip.encode(packet), delimiter.
    """
    write = ipif.transport.write
    write(slip.delimiter)
    write(slip.encode(packet))
    write(slip.delimiter)
97
650a3251
IJ
98#---------- packet queue ----------
99
class PacketQueue():
    """FIFO of packets that discards entries older than a time limit.

    Each packet is stored with the time.monotonic() timestamp (seconds) at
    which it was appended.  nonempty() drops packets from the front whose
    age exceeds *max_queue_time* (also in seconds).
    """

    def __init__(self, max_queue_time):
        self._max_queue_time = max_queue_time
        self._pq = collections.deque()  # (queue time, packet), oldest first

    def append(self, packet):
        """Enqueue *packet*, stamped with the current monotonic time."""
        self._pq.append((time.monotonic(), packet))

    def nonempty(self):
        """Expire stale packets from the front; return True if any remain."""
        now = time.monotonic()
        while self._pq:
            queuetime, _ = self._pq[0]
            if now - queuetime <= self._max_queue_time:
                return True
            # strip old packets off the front
            self._pq.popleft()
        return False

    def popleft(self):
        """Remove and return the oldest packet, or None if the queue is empty.

        Callers should check nonempty() first so that expired packets are
        discarded rather than returned.
        """
        try:
            _, packet = self._pq.popleft()
        except IndexError:
            return None
        return packet
ae7c7784
IJ
126
127#---------- error handling ----------
128
def crash(err):
    """Print *err* to stderr and stop the reactor.

    If the reactor is not running, reactor.stop() raises ReactorNotRunning;
    that is ignored.
    """
    print('CRASH ', err, file=sys.stderr)
    try:
        reactor.stop()
    except twisted.internet.error.ReactorNotRunning:
        pass  # nothing to stop
133
def crash_on_defer(defer):
    """Arrange for any failure of the Deferred *defer* to call crash()."""
    def _errback(failure):
        crash(failure)
    defer.addErrback(_errback)
136
def crash_on_critical(event):
    """Log observer: call crash() for any event at critical level or above.

    Events without a 'log_level' key are ignored; comparing None against a
    LogLevel would otherwise raise TypeError inside the observer.
    """
    level = event.get('log_level')
    if level is not None and level >= LogLevel.critical:
        crash(twisted.logger.formatEvent(event))
140
87a7c0c7
IJ
141#---------- config processing ----------
142
def process_cfg_common_always():
    """Copy [virtual] mtu (as the raw config string) into c.mtu."""
    # The value lives on the shared ConfigResults object ``c``, not in a
    # module global, so no ``global`` declaration is needed.
    c.mtu = cfg.get('virtual', 'mtu')
146
88487243
IJ
def process_cfg_ipif(section, varmap):
    """Compute c.ipif_command from the 'ipif' option of *section*.

    For each ``(dest, src)`` pair in *varmap*, copy ``c.<src>`` to
    ``c.<dest>``; pairs whose source attribute is not set are skipped.
    The 'ipif' option is then read with c's attributes supplied as
    interpolation variables.
    """
    for d, s in varmap:
        # c is a plain attribute bag, so a missing name raises AttributeError
        # (not KeyError).  Skip the pair instead of falling through to
        # setattr with the previous iteration's value (or an unbound v).
        try:
            v = getattr(c, s)
        except AttributeError:
            continue
        setattr(c, d, v)

    print(repr(c))

    c.ipif_command = cfg.get(section, 'ipif', vars=c.__dict__)
156
def process_cfg_network():
    """Parse [virtual] network into c.network and check it is large enough.

    Raises ValueError if the network has fewer than 3 + 2 addresses.
    """
    net = ipnetwork(cfg.get('virtual', 'network'))
    c.network = net
    # Network sizes are powers of two, so "fewer than 3 + 2" is the same as
    # "fewer than 2^3", which is how the error message states it.
    # NOTE(review): 3 + 2 presumably means 3 usable hosts plus the network
    # and broadcast addresses -- confirm.
    if net.num_addresses < 3 + 2:
        raise ValueError('network needs at least 2^3 addresses')
161
def process_cfg_server():
    """Set c.server, the server's address on the virtual network.

    Uses [virtual] server when configured; otherwise processes
    [virtual] network and takes its first host address.  Both paths now
    produce an ipaddress address object.  The configured value used to be
    left as a str, so comparisons with packet addresses or network
    membership tests behaved differently depending on which path was taken.
    """
    try:
        spec = cfg.get('virtual', 'server')
    except NoOptionError:
        process_cfg_network()
        c.server = next(c.network.hosts())
    else:
        c.server = ipaddr(spec)
168
class ServerAddr():
    """One address/port the HTTP server listens on.

    Attributes:
      port -- TCP port, as passed to the constructor
      addr -- ipaddress.IPv4Address or IPv6Address parsed from addrspec
    """

    def __init__(self, port, addrspec):
        self.port = port
        try:
            self.addr = ipaddress.IPv4Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
            self._inurl = '%s'
        except AddressValueError:
            self.addr = ipaddress.IPv6Address(addrspec)
            self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
            # IPv6 literals must be bracketed in a URL's host part
            self._inurl = '[%s]'

    def make_endpoint(self):
        """Return a Twisted server endpoint for self.addr, self.port."""
        # TCP4/TCP6ServerEndpoint take (reactor, port, backlog, interface):
        # the third positional parameter is the listen backlog.  Pass the
        # address by keyword, as the string form the endpoint expects.
        return self._endpointfactory(reactor, self.port,
                                     interface=str(self.addr))

    def url(self):
        """Return ``http://HOST[:PORT]/`` for this address (port 80 omitted)."""
        url = 'http://' + (self._inurl % self.addr)
        if self.port != 80:
            url += ':%d' % self.port
        url += '/'
        return url
188
def process_cfg_saddrs():
    """Build c.saddrs with one ServerAddr per entry in [server] addrs.

    Every entry shares the [server] port; entries are whitespace-separated.
    """
    port = cfg.getint('server', 'port')
    saddrs = c.saddrs = []
    for spec in cfg.get('server', 'addrs').split():
        saddrs.append(ServerAddr(port, spec))
196
def process_cfg_clients(constructor):
    """Call ``constructor(addr, section, password)`` for each client section.

    A section counts as a client when its name contains ':' or '.'; the name
    is parsed with ipaddr.  c.clients is reset to an empty list first.
    """
    c.clients = []
    # NOTE(review): nothing here appends to c.clients and the constructor's
    # return value is discarded; presumably the constructor registers the
    # client itself -- confirm.
    for section in cfg.sections():
        if ':' not in section and '.' not in section:
            continue
        client_addr = ipaddr(section)
        password = cfg.get(section, 'password')
        constructor(client_addr, section, password)
204
ae7c7784
IJ
205#---------- startup ----------
206
def common_startup(defcfg):
    """Common start-of-day setup.

    Registers crash_on_critical as a global log observer, parses the
    command line (only -c/--config; no positional arguments allowed), then
    loads the built-in defaults from the *defcfg* string followed by the
    configuration file.
    """
    twisted.logger.globalLogPublisher.addObserver(crash_on_critical)

    optparser.add_option('-c', '--config', dest='configfile',
                         default='/etc/hippotat/config')
    opts, leftover = optparser.parse_args()
    if leftover:
        optparser.error('no non-option arguments please')

    cfg.read_string(defcfg)
    # ConfigParser.read silently skips files that cannot be opened
    cfg.read(opts.configfile)
217
def common_run():
    """Run the Twisted reactor and report on stderr once it returns."""
    reactor.run()
    # reactor.run() returns only after the reactor has been stopped
    # (e.g. by crash()); this program treats that as a crash.
    print('CRASHED (end)', file=sys.stderr)