X-Git-Url: https://git.distorted.org.uk/~mdw/tripe-android/blobdiff_plain/c8292b34485a2e00e676023d4164dd5841e4659f..0157de026e802e94a2d0db0421b02ffca986c616:/sys.scala diff --git a/sys.scala b/sys.scala index cf0a72d..c449414 100644 --- a/sys.scala +++ b/sys.scala @@ -28,7 +28,7 @@ package uk.org.distorted.tripe; package object sys { /*----- Imports -----------------------------------------------------------*/ import scala.collection.convert.decorateAsJava._; -import scala.collection.mutable.HashSet; +import scala.collection.mutable.{HashMap, HashSet}; import java.io.{BufferedReader, BufferedWriter, Closeable, File, FileDescriptor, FileInputStream, FileOutputStream, @@ -38,6 +38,8 @@ import java.nio.{ByteBuffer, CharBuffer}; import java.nio.charset.Charset; import java.util.Date; +import Implicits.truish; + /*----- Some magic for C strings ------------------------------------------*/ type CString = Array[Byte]; @@ -124,7 +126,7 @@ import StringImplicits._; /*----- Main code ---------------------------------------------------------*/ /* Import the native code library. */ -System.loadLibrary("toy"); +System.loadLibrary("tripe"); /* Native types. * @@ -175,23 +177,23 @@ object Errno extends Enumeration { private var wrong = -256; // next synthetic code private def nextwrong: Int = { val w = wrong; wrong -= 1; w } - class Type private[Errno](tag: String, val code: Int, id: Int) - extends Val(id, tag) { + class Val private[Errno](tag: String, val code: Int, id: Int) + extends super.Val(id, tag) { /* Our augmented error type. */ def message: String = strerror(code).toJString; } private class UnknownError(code: Int) - extends Type("", code, code); + extends Val("", code, code); - private def err(tag: String, code: Int): Type = { + private def err(tag: String, code: Int): Val = { /* Construct an error symbol given its tag string and a code number. */ - if (code < 0) new Type(tag, code, code) - else if (seen contains code) new Type(tag, code, nextwrong) - else { seen += code; new Type(tag, code, code) } + if (code < 0) new Val(tag, code, code) + else if (seen contains code) new Val(tag, code, nextwrong) + else { seen += code; new Val(tag, code, code) } } - private def err(tag: String): Type = + private def err(tag: String): Val = err(tag, tagmap.getOrElse(tag, nextwrong)); def byid(id: Int): Value = { @@ -377,7 +379,7 @@ object Errno extends Enumeration { val EHWPOISON = err("EHWPOISON"); /***end***/ } -import Errno.{Type => Errno, EEXIST, EISDIR, ENOENT, ENOTDIR}; +import Errno.{Val => Errno, EEXIST, EISDIR, ENOENT, ENOTDIR}; object SystemError { /* Pattern matching for `SystemError', below. */ @@ -438,8 +440,11 @@ def stat(path: String): sys.FileInfo = stat(path.toCString); def lstat(path: String): sys.FileInfo = lstat(path.toCString); object FileInfo extends Enumeration { - /* A simple enumeration of things a file might be. */ - val FIFO, CHR, DIR, BLK, REG, LNK, SOCK, UNK = Value; + /* A simple enumeration of things a file might be. + * + * `HDLNK' is a hard link, used in `tar' files. + */ + val FIFO, CHR, DIR, BLK, REG, LNK, SOCK, HDLNK, UNK = Value; type Type = Value; } import FileInfo._; @@ -596,7 +601,7 @@ object FileImplicits { /* Constructing names of files in a directory. Honestly, I'm surprised * there isn't a method for this already. */ - def +(sub: String): File = new File(file, sub); + def /(sub: String): File = new File(file, sub); /* Simple file operations. */ def unlink_!() { unlink(file.getPath); } @@ -637,19 +642,15 @@ object FileImplicits { def islnk_! : Boolean = statish(lstat _, _.ftype == LNK, false); def issock_! : Boolean = statish(stat _, _.ftype == SOCK, false); + /* Slightly more cooked file operations. */ def remove_!() { /* Delete a file, or directory, whatever it is. */ - while (true) { - try { unlink_!(); return; } - catch { - case SystemError(ENOENT, _) => return; - case SystemError(EISDIR, _) => ok; - } - try { rmdir_!(); return; } - catch { - case SystemError(ENOENT, _) => return; - case SystemError(ENOTDIR, _) => ok; - } + try { unlink_!(); return; } + catch { + case SystemError(ENOENT, _) => return; + case SystemError(EISDIR, _) => + try { rmdir_!(); return; } + catch { case SystemError(ENOENT, _) => return; } } } @@ -662,6 +663,12 @@ object FileImplicits { walk(file); } + def mkdirNew_!() { + /* Make a directory if there's nothing there already. */ + try { mkdir_!(); } + catch { case SystemError(EEXIST, _) => ok; } + } + /* File locking. */ def lock_!(flags: Int): FileLock = new FileLock(file.getPath, flags); def lock_!(): FileLock = lock_!(LKF_EXCL | 0x1b6); @@ -739,7 +746,7 @@ def freshFile(d: File): File = { /* Make the filename, and try to create the file. If we succeed, we * win. */ - val f = new File(d, b.result); b.clear(); + val f = d/b.result; b.clear(); try { f.mkfile_!(); exit(f); } catch { case SystemError(EEXIST, _) => ok; } } @@ -768,8 +775,9 @@ def runCommand(cmd: String*): (String, String) = { withCleaner { clean => /* Create the child process and pick up the ends of its streams. */ - val pb = new ProcessBuilder(cmd.asJava).redirectInput(devnull); + val pb = new ProcessBuilder(cmd.asJava); val kid = pb.start(); clean { kid.destroy(); } + kid.getOutputStream.close(); val out = kid.getInputStream(); clean { out.close(); } val err = kid.getErrorStream(); clean { err.close(); } @@ -791,64 +799,132 @@ def runCommand(cmd: String*): (String, String) = { /* Check the exit status. */ val rc = kid.exitValue; - if (rc != 0) throw new SubprocessFailed(cmd, rc, berr.result); + if (rc) throw new SubprocessFailed(cmd, rc, berr.result); /* We're all done. */ return (bout.result, berr.result); } } -/*----- Connecting to a server --------------------------------------------*/ - -/* Primitive operations. */ -final val CF_CLOSERD = 1; -final val CF_CLOSEWR = 2; -final val CF_CLOSEMASK = CF_CLOSERD | CF_CLOSEWR; -@native protected def connect(path: CString): Wrapper; -@native protected def send(conn: Wrapper, buf: CString, - start: Int, len: Int); -@native protected def recv(conn: Wrapper, buf: CString, - start: Int, len: Int): Int; -@native def closeconn(conn: Wrapper, how: Int); - -class Connection(path: String) extends Closeable { - - /* The underlying primitive connection. */ - private[this] val conn = connect(path.toCString); - - /* Alternative constructors. */ - def this(file: File) { this(file.getPath); } - - /* Cleanup.*/ - override def close() { closeconn(conn, CF_CLOSEMASK); } - override protected def finalize() { super.finalize(); close(); } - - class Input private[Connection] extends InputStream { - /* An input stream which reads from the connection. */ +/*----- Interrupt triggers ------------------------------------------------*/ + +private val triggerLock = new Object; +private final val maxTriggers = 2; +private var nTriggers = 0; +private var triggers: List[Wrapper] = Nil; + +@native protected def make_trigger(): Wrapper; +@native protected def destroy_trigger(trig: Wrapper); +@native protected def reset_trigger(trig: Wrapper); +@native protected def trigger(trig: Wrapper); + +private def getTrigger(): Wrapper = { + triggerLock synchronized { + if (!nTriggers) + make_trigger() + else { + val trig = triggers.head; + triggers = triggers.tail; + nTriggers -= 1; + trig + } + } +} - override def read(): Int = { - val buf = new Array[Byte](1); - val n = read(buf, 0, 1); - if (n < 0) -1 else buf(0)&0xff; +private def putTrigger(trig: Wrapper) { + reset_trigger(trig); + triggerLock synchronized { + if (nTriggers >= maxTriggers) + destroy_trigger(trig); + else { + triggers ::= trig; + nTriggers += 1; } - override def read(buf: Array[Byte]): Int = - read(buf, 0, buf.length); - override def read(buf: Array[Byte], start: Int, len: Int) = - recv(conn, buf, start, len); - override def close() { closeconn(conn, CF_CLOSERD); } } - lazy val input = new Input; +} - class Output private[Connection] extends OutputStream { - /* An output stream which writes to the connection. */ +private def withTrigger[T](body: Wrapper => T): T = { + val trig = getTrigger(); + try { body(trig) } + finally { putTrigger(trig); } +} + +def interruptWithTrigger[T](body: Wrapper => T): T = { + /* interruptWithTrigger { TRIG => BODY } + * + * Execute BODY and return its result. If the thread receives an + * interrupt, the trigger TRIG will be pulled. See `interruptably' for the + * full semantics. + */ - override def write(b: Int) { write(Array[Byte](b.toByte), 0, 1); } - override def write(buf: Array[Byte]) { write(buf, 0, buf.length); } - override def write(buf: Array[Byte], start: Int, len: Int) - { send(conn, buf, start, len); } - override def close() { closeconn(conn, CF_CLOSEWR); } + withTrigger { trig => + interruptably { body(trig) } onInterrupt { trigger(trig); } + }; +} + +/*----- Glue for the VPN server -------------------------------------------*/ + +/* The lock class. This is only a class because they're much easier to find + * than loose objects through JNI. + */ +private class ServerLock; + +/* Exceptions. */ +class NameResolutionException(msg: String) extends Exception(msg); +class InitializationException(msg: String) extends Exception(msg); + +/* Primitive operations. */ +@native protected def open_tun(): Int; +@native protected def base_init(); +@native protected def setup_resolver(); +@native def load_keys(priv: CString, pub: CString, tag: CString); +@native def unload_keys(); +@native def bind(host: CString, svc: CString); +@native def unbind(); +@native def mark(seq: Int); +@native def run(); +@native protected def send(buf: CString, start: Int, len: Int, + trig: Wrapper); +@native protected def recv(buf: CString, start: Int, len: Int, + trig: Wrapper): Int; + +base_init(); +setup_resolver(); + +/* Tunnel descriptor plumbing. */ +val pending = HashMap[String, Int](); + +def getTunnelFd(peer: CString): Int = + pending synchronized { pending(peer.toJString) }; +def storeTunnelFd(peer: String, fd: Int) + { pending synchronized { pending(peer) = fd; } } +def withdrawTunnelFd(peer: String) + { pending synchronized { pending -= peer; } } +def withTunnelFd[T](peer: String, fd: Int)(body: => T): T = { + storeTunnelFd(peer, fd); + try { body } finally { withdrawTunnelFd(peer); } +} + +/* Server I/O. */ +lazy val serverInput: InputStream = new InputStream { + override def read(): Int = { + val buf = new Array[Byte](1); + val n = read(buf, 0, 1); + if (n < 0) -1 else buf(0)&0xff; } - lazy val output = new Output; + override def read(buf: Array[Byte]): Int = + read(buf, 0, buf.length); + override def read(buf: Array[Byte], start: Int, len: Int) = + interruptWithTrigger { trig => recv(buf, start, len, trig); }; + override def close() { } +} + +lazy val serverOutput: OutputStream = new OutputStream { + override def write(b: Int) { write(Array[Byte](b.toByte), 0, 1); } + override def write(buf: Array[Byte]) { write(buf, 0, buf.length); } + override def write(buf: Array[Byte], start: Int, len: Int) + { interruptWithTrigger { trig => send(buf, start, len, trig); } } + override def close() { } } /*----- Crypto-library hacks ----------------------------------------------*/