6f0d3e87d8a2ba2dcf5f7d4e1f7ea154b716f22e
4 signal
.signal(signal
.SIGINT
, signal
.SIG_DFL
)
7 from twisted
.internet
import reactor
8 from twisted
.logger
import LogLevel
11 from ipaddress
import AddressValueError
13 import hippotat
.slip
as slip
15 from optparse
import OptionParser
16 from configparser
import ConfigParser
17 from configparser
import NoOptionError
21 # these need to be defined here so that they can be imported by import *
23 optparser
= OptionParser()
26 def __init__(self
, d
= { }):
29 return 'ConfigResults('+repr(self
.__dict__
)+')'
33 #---------- packet parsing ----------
35 def packet_addrs(packet
):
36 version
= packet
[0] >> 4
40 factory
= ipaddress
.IPv4Address
44 factory
= ipaddress
.IPv6Address
46 raise ValueError('unsupported IP version %d' % version
)
47 saddr
= factory(packet
[ saddroff
: saddroff
+ addrlen
])
48 daddr
= factory(packet
[ saddroff
+ addrlen
: saddroff
+ addrlen
*2 ])
51 #---------- address handling ----------
55 r
= ipaddress
.IPv4Address(input)
56 except AddressValueError
:
57 r
= ipaddress
.IPv6Address(input)
62 r
= ipaddress
.IPv4Network(input)
63 except NetworkValueError
:
64 r
= ipaddress
.IPv6Network(input)
67 #---------- ipif (SLIP) subprocess ----------
69 class _IpifProcessProtocol(twisted
.internet
.protocol
.ProcessProtocol
):
70 def __init__(self
, router
):
73 def connectionMade(self
): pass
74 def outReceived(self
, data
):
75 #print('RECV ', repr(data))
77 packets
= slip
.decode(self
._buffer
)
78 self
._buffer
= packets
.pop()
79 for packet
in packets
:
80 if not len(packet
): continue
81 (saddr
, daddr
) = packet_addrs(packet
)
82 self
._router(packet
, saddr
, daddr
)
83 def processEnded(self
, status
):
84 status
.raiseException()
86 def start_ipif(command
, router
):
88 ipif
= _IpifProcessProtocol(router
)
89 reactor
.spawnProcess(ipif
,
90 '/bin/sh',['sh','-xc', command
],
91 childFDs
={0:'w', 1:'r', 2:2})
93 def queue_inbound(packet
):
94 ipif
.transport
.write(slip
.delimiter
)
95 ipif
.transport
.write(slip
.encode(packet
))
96 ipif
.transport
.write(slip
.delimiter
)
98 #---------- packet queue ----------
101 def __init__(self
, max_queue_time
):
102 self
._max_queue_time
= max_queue_time
103 self
._pq
= collections
.deque() # packets
105 def append(self
, packet
):
106 self
._pq
.append((time
.monotonic(), packet
))
110 try: (queuetime
, packet
) = self
._pq
[0]
111 except IndexError: return False
113 age
= time
.monotonic() - queuetime
114 if age
> self
.max_queue_time
:
115 # strip old packets off the front
122 # caller must have checked nonempty
123 try: (dummy
, packet
) = self
._pq
[0]
124 except IndexError: return None
127 #---------- error handling ----------
130 print('CRASH ', err
, file=sys
.stderr
)
132 except twisted
.internet
.error
.ReactorNotRunning
: pass
134 def crash_on_defer(defer
):
135 defer
.addErrback(lambda err
: crash(err
))
137 def crash_on_critical(event
):
138 if event
.get('log_level') >= LogLevel
.critical
:
139 crash(twisted
.logger
.formatEvent(event
))
141 #---------- config processing ----------
143 def process_cfg_common_always():
145 c
.mtu
= cfg
.get('virtual','mtu')
147 def process_cfg_ipif(section
, varmap
):
149 try: v
= getattr(c
, s
)
150 except KeyError: pass
155 c
.ipif_command
= cfg
.get(section
,'ipif', vars=c
.__dict__
)
157 def process_cfg_network():
158 c
.network
= ipnetwork(cfg
.get('virtual','network'))
159 if c
.network
.num_addresses
< 3 + 2:
160 raise ValueError('network needs at least 2^3 addresses')
162 def process_cfg_server():
164 c
.server
= cfg
.get('virtual','server')
165 except NoOptionError
:
166 process_cfg_network()
167 c
.server
= next(c
.network
.hosts())
170 def __init__(self
, port
, addrspec
):
174 self
.addr
= ipaddress
.IPv4Address(addrspec
)
175 self
._endpointfactory
= twisted
.internet
.endpoints
.TCP4ServerEndpoint
177 except AddressValueError
:
178 self
.addr
= ipaddress
.IPv6Address(addrspec
)
179 self
._endpointfactory
= twisted
.internet
.endpoints
.TCP6ServerEndpoint
181 def make_endpoint(self
):
182 return self
._endpointfactory(reactor
, self
.port
, self
.addr
)
184 url
= 'http://' + (self
._inurl % self
.addr
)
185 if self
.port
!= 80: url
+= ':%d' % self
.port
189 def process_cfg_saddrs():
190 port
= cfg
.getint('server','port')
193 for addrspec
in cfg
.get('server','addrs').split():
194 sa
= ServerAddr(port
, addrspec
)
197 def process_cfg_clients(constructor
):
199 for cs
in cfg
.sections():
200 if not (':' in cs
or '.' in cs
): continue
202 pw
= cfg
.get(cs
, 'password')
203 constructor(ci
,cs
,pw
)
205 #---------- startup ----------
207 def common_startup(defcfg
):
208 twisted
.logger
.globalLogPublisher
.addObserver(crash_on_critical
)
210 optparser
.add_option('-c', '--config', dest
='configfile',
211 default
='/etc/hippotat/config')
212 (opts
, args
) = optparser
.parse_args()
213 if len(args
): optparser
.error('no non-option arguments please')
215 cfg
.read_string(defcfg
)
216 cfg
.read(opts
.configfile
)
220 print('CRASHED (end)', file=sys
.stderr
)