1baaed33c68703758d160ee2d32f9e1b0059b471
4 signal
.signal(signal
.SIGINT
, signal
.SIG_DFL
)
7 from twisted
.internet
import reactor
8 from twisted
.logger
import LogLevel
9 import twisted
.internet
.endpoints
12 from ipaddress
import AddressValueError
14 import hippotat
.slip
as slip
16 from optparse
import OptionParser
17 from configparser
import ConfigParser
18 from configparser
import NoOptionError
22 # these need to be defined here so that they can be imported by import *
24 optparser
= OptionParser()
27 def __init__(self
, d
= { }):
30 return 'ConfigResults('+repr(self
.__dict__
)+')'
34 #---------- packet parsing ----------
36 def packet_addrs(packet
):
37 version
= packet
[0] >> 4
41 factory
= ipaddress
.IPv4Address
45 factory
= ipaddress
.IPv6Address
47 raise ValueError('unsupported IP version %d' % version
)
48 saddr
= factory(packet
[ saddroff
: saddroff
+ addrlen
])
49 daddr
= factory(packet
[ saddroff
+ addrlen
: saddroff
+ addrlen
*2 ])
52 #---------- address handling ----------
56 r
= ipaddress
.IPv4Address(input)
57 except AddressValueError
:
58 r
= ipaddress
.IPv6Address(input)
63 r
= ipaddress
.IPv4Network(input)
64 except NetworkValueError
:
65 r
= ipaddress
.IPv6Network(input)
68 #---------- ipif (SLIP) subprocess ----------
70 class _IpifProcessProtocol(twisted
.internet
.protocol
.ProcessProtocol
):
71 def __init__(self
, router
):
74 def connectionMade(self
): pass
75 def outReceived(self
, data
):
76 #print('IPIF-GOT ', repr(data))
78 packets
= slip
.decode(self
._buffer
)
79 self
._buffer
= packets
.pop()
80 for packet
in packets
:
81 if not len(packet
): continue
82 (saddr
, daddr
) = packet_addrs(packet
)
83 self
._router(packet
, saddr
, daddr
)
84 def processEnded(self
, status
):
85 status
.raiseException()
87 def start_ipif(command
, router
):
89 ipif
= _IpifProcessProtocol(router
)
90 reactor
.spawnProcess(ipif
,
91 '/bin/sh',['sh','-xc', command
],
92 childFDs
={0:'w', 1:'r', 2:2})
94 def queue_inbound(packet
):
95 ipif
.transport
.write(slip
.delimiter
)
96 ipif
.transport
.write(slip
.encode(packet
))
97 ipif
.transport
.write(slip
.delimiter
)
99 #---------- packet queue ----------
102 def __init__(self
, max_queue_time
):
103 self
._max_queue_time
= max_queue_time
104 self
._pq
= collections
.deque() # packets
106 def append(self
, packet
):
107 self
._pq
.append((time
.monotonic(), packet
))
111 try: (queuetime
, packet
) = self
._pq
[0]
112 except IndexError: return False
114 age
= time
.monotonic() - queuetime
115 if age
> self
.max_queue_time
:
116 # strip old packets off the front
123 # caller must have checked nonempty
124 try: (dummy
, packet
) = self
._pq
[0]
125 except IndexError: return None
128 #---------- error handling ----------
131 print('CRASH ', err
, file=sys
.stderr
)
133 except twisted
.internet
.error
.ReactorNotRunning
: pass
135 def crash_on_defer(defer
):
136 defer
.addErrback(lambda err
: crash(err
))
138 def crash_on_critical(event
):
139 if event
.get('log_level') >= LogLevel
.critical
:
140 crash(twisted
.logger
.formatEvent(event
))
142 #---------- config processing ----------
144 def process_cfg_common_always():
146 c
.mtu
= cfg
.get('virtual','mtu')
148 def process_cfg_ipif(section
, varmap
):
150 try: v
= getattr(c
, s
)
151 except AttributeError: continue
156 c
.ipif_command
= cfg
.get(section
,'ipif', vars=c
.__dict__
)
158 def process_cfg_network():
159 c
.network
= ipnetwork(cfg
.get('virtual','network'))
160 if c
.network
.num_addresses
< 3 + 2:
161 raise ValueError('network needs at least 2^3 addresses')
163 def process_cfg_server():
165 c
.server
= cfg
.get('virtual','server')
166 except NoOptionError
:
167 process_cfg_network()
168 c
.server
= next(c
.network
.hosts())
171 def __init__(self
, port
, addrspec
):
175 self
.addr
= ipaddress
.IPv4Address(addrspec
)
176 self
._endpointfactory
= twisted
.internet
.endpoints
.TCP4ServerEndpoint
178 except AddressValueError
:
179 self
.addr
= ipaddress
.IPv6Address(addrspec
)
180 self
._endpointfactory
= twisted
.internet
.endpoints
.TCP6ServerEndpoint
182 def make_endpoint(self
):
183 return self
._endpointfactory(reactor
, self
.port
, self
.addr
)
185 url
= 'http://' + (self
._inurl % self
.addr
)
186 if self
.port
!= 80: url
+= ':%d' % self
.port
190 def process_cfg_saddrs():
191 try: port
= cfg
.getint('server','port')
192 except NoOptionError
: port
= 80
195 for addrspec
in cfg
.get('server','addrs').split():
196 sa
= ServerAddr(port
, addrspec
)
199 def process_cfg_clients(constructor
):
201 for cs
in cfg
.sections():
202 if not (':' in cs
or '.' in cs
): continue
204 pw
= cfg
.get(cs
, 'password')
205 constructor(ci
,cs
,pw
)
207 #---------- startup ----------
209 def common_startup(defcfg
):
210 twisted
.logger
.globalLogPublisher
.addObserver(crash_on_critical
)
212 optparser
.add_option('-c', '--config', dest
='configfile',
213 default
='/etc/hippotat/config')
214 (opts
, args
) = optparser
.parse_args()
215 if len(args
): optparser
.error('no non-option arguments please')
217 cfg
.read_string(defcfg
)
218 cfg
.read(opts
.configfile
)
222 print('CRASHED (end)', file=sys
.stderr
)