From 3bb2303d42adb3f37420f168b009ecfe64f888cd Mon Sep 17 00:00:00 2001 From: Mark Wooding Date: Sat, 16 Jun 2018 19:32:22 +0100 Subject: [PATCH] Integrate the TrIPE server into the Java edifice. And probably other things too. We're still in broad brushstrokes mode here. --- Makefile | 36 ++-- admin.scala | 22 +-- jni.c | 558 +++++++++++++++++++++++++++++++++++++++++++++++------------- keys.scala | 6 +- sys.scala | 120 +++++++------ 5 files changed, 543 insertions(+), 199 deletions(-) diff --git a/Makefile b/Makefile index 6e5ca0e..5edb29a 100644 --- a/Makefile +++ b/Makefile @@ -40,7 +40,7 @@ CC = gcc CFLAGS = -O2 -g -Wall -pedantic -Werror ## Native linker. -LD = gcc +LD = gcc -Wl,-z,defs LDFLAGS.so = -shared ## External `pkg-config' packages required. @@ -107,14 +107,30 @@ V_AT_0 = @ ###-------------------------------------------------------------------------- ### External native packages. -PKGS_CFLAGS := $(foreach p,$(PKGS),$(shell pkg-config --cflags $p)) -PKGS_LIBS := $(foreach p,$(PKGS),$(shell pkg-config --libs $p)) +EXTPREFIX = $(abs_builddir)/$(OUTDIR)/inst + +join-paths = $(if $(filter /%,$2),$2,$1/$2) +ext-srcdir = $(or $($1_SRCDIR),../$1) + +PKG_CONFIG = PKG_CONFIG_LIBDIR=$(OUTDIR)/inst/lib/pkgconfig \ + pkg-config --static + +PKGS_CFLAGS := $(foreach p,$(PKGS),$(shell $(PKG_CONFIG) --cflags $p)) +PKGS_LIBS := $(foreach p,$(PKGS),$(shell $(PKG_CONFIG) --libs $p)) ALL_CFLAGS = $(CFLAGS) -fPIC \ $(addprefix -I,$(JNI_INCLUDES)) \ + -I$(OUTDIR)/inst/include \ + -I$(call ext-srcdir,tripe)/common \ + -I$(call ext-srcdir,tripe)/priv \ + -I$(call ext-srcdir,tripe)/server \ + -I$(OUTDIR)/build/tripe/config \ $(PKGS_CFLAGS) -LIBS = $(PKGS_LIBS) +LIBS = $(OUTDIR)/build/tripe/server/libtripe.a \ + $(OUTDIR)/build/tripe/priv/libpriv.a \ + $(OUTDIR)/build/tripe/common/libcommon.a \ + -L$(OUTDIR)/inst/lib $(PKGS_LIBS) ###-------------------------------------------------------------------------- ### Various other tweaks and overrides. @@ -177,14 +193,16 @@ $(OUTDIR)/%.class-stamp: %.scala ###-------------------------------------------------------------------------- ### Native-code libraries. -SHLIBS += toy -toy_SOURCES = jni.c +SHLIBS += tripe +tripe_SOURCES = jni.c shlibfile = $(patsubst %,$(OUTDIR)/lib%.so,$1) SHLIBFILES = $(call shlibfile,$(SHLIBS)) TARGETS += $(SHLIBFILES) ALL_SOURCES += $(foreach l,$(SHLIBS),$($l_SOURCES)) +$(call objects,$(tripe_SOURCES),.o): $(call stamps,ext,tripe) + $(SHLIBFILES): $(OUTDIR)/lib%.so: $$(call objects,$$($$*_SOURCES),.o) $(call v_tag,LD)$(LD) $(LDFLAGS.so) -o$@ $^ $(LIBS) @@ -203,7 +221,6 @@ CLASSES += tar:util CLASSES += progress:sys,util CLASSES += keys:progress,tar,sys,util CLASSES += terminal:progress,sys,util -CLASSES += main:sys ## Machinery for parsing the `CLASSES' list. COMMA = , @@ -224,11 +241,6 @@ DISTFILES += $(foreach c,$(CLASSES),\ ###-------------------------------------------------------------------------- ### External packages. -EXTPREFIX = $(abs_builddir)/$(OUTDIR)/inst - -join-paths = $(if $(filter /%,$2),$2,$1/$2) -ext-srcdir = $(or $($1_SRCDIR),../$1) - EXTERNALS += adns adns_CONFIG = --disable-dynamic diff --git a/admin.scala b/admin.scala index fab8305..52a2912 100644 --- a/admin.scala +++ b/admin.scala @@ -27,7 +27,7 @@ package uk.org.distorted.tripe; package object admin { /*----- Imports -----------------------------------------------------------*/ -import java.io.{BufferedReader, Reader, Writer}; +import java.io.{BufferedReader, InputStreamReader, OutputStreamWriter}; import java.util.concurrent.locks.{Condition, ReentrantLock => Lock}; import scala.collection.mutable.{HashMap, Publisher}; @@ -35,6 +35,7 @@ import scala.concurrent.Channel; import scala.util.control.Breaks; import Implicits._; +import sys.{serverInput, serverOutput}; /*----- Classification of server messages ---------------------------------*/ @@ -77,8 +78,7 @@ class CommandFailed(val msg: Seq[String]) extends Exception { class ConnectionLostException extends Exception; -class Connection(val in: Reader, val out: Writer) - extends Publisher[AsyncMessage] +object Connection extends Publisher[AsyncMessage] { /* Synchronization. * @@ -87,12 +87,15 @@ class Connection(val in: Reader, val out: Writer) * hold the `Connection' lock before locking any individual `Job' objects. */ - var livep: Boolean = true; // Is this connection still alive? - var fgjob: Option[this.Job] = None; // Foreground job, if there is one. - val jobmap = new HashMap[String, this.Job]; // Maps tags to extant jobs. - var bgseq = 0; // Next background job tag. + private var livep: Boolean = true; // Is this connection still alive? + private var fgjob: Option[this.Job] = None; // Foreground job, if there is one. + private val jobmap = new HashMap[String, this.Job]; // Maps tags to extant jobs. + private var bgseq = 0; // Next background job tag. - type Pub = Connection; + private val in = new BufferedReader(new InputStreamReader(serverInput)); + private val out = new OutputStreamWriter(serverOutput); + + type Pub = Connection.type; class Job extends Iterator[Seq[String]] { private[Connection] val ch = new Channel[JobMessage]; @@ -183,8 +186,6 @@ println(";; write command"); def submit(toks: String*): this.Job = submit(false, toks: _*); - def close() { synchronized { out.close(); } } - /* These two expect the connection lock to be held. */ def foregroundJob: Job = fgjob.getOrElse { throw new ServerFailed("no foreground job"); } @@ -267,7 +268,6 @@ println(s";; line: $line"); } } publish(ConnectionLost); - in.close(); out.close(); } } } diff --git a/jni.c b/jni.c index 9aa79f8..9ca3651 100644 --- a/jni.c +++ b/jni.c @@ -36,17 +36,22 @@ #include #include -#include - -#include +#include +#include +#include +#include +#include #include #include #include #include +#include #include -#include -#include -#include + +#include + +//#include +#include #include #include @@ -55,6 +60,9 @@ #include +#define TUN_INTERNALS +#include + #undef sun /*----- Magic class names and similar -------------------------------------*/ @@ -63,22 +71,37 @@ #define JNIFUNC(f) Java_uk_org_distorted_tripe_sys_package_00024_##f /* The little class for bundling up error codes. */ -#define ERRENTRY "uk/org/distorted/tripe/sys/package$ErrorEntry" +#define ERRENTCLS "uk/org/distorted/tripe/sys/package$ErrorEntry" + +/* The `sys' package class. */ +#define SYSCLS "uk/org/distorted/tripe/sys/package" + +/* The server lock class. */ +#define LOCKCLS "uk/org/distorted/tripe/sys/package$ServerLock" /* The `stat' class. */ -#define STAT "uk/org/distorted/tripe/sys/package$FileInfo" +#define STATCLS "uk/org/distorted/tripe/sys/package$FileInfo" + +/* Standard Java classes. */ +#define FDCLS "java/io/FileDescriptor" +#define STRCLS "java/lang/String" +#define RANDCLS "java/security/SecureRandom" /* Exception class names. */ #define NULLERR "java/lang/NullPointerException" #define TYPEERR "uk/org/distorted/tripe/sys/package$NativeObjectTypeException" #define SYSERR "uk/org/distorted/tripe/sys/package$SystemError" +#define NAMEERR "uk/org/distorted/tripe/sys/package$NameResolutionException" +#define INITERR "uk/org/distorted/tripe/sys/package$InitializationException" #define ARGERR "java/lang/IllegalArgumentException" +#define STERR "java/lang/IllegalStateException" #define BOUNDSERR "java/lang/IndexOutOfBoundsException" -/*----- Miscellaneous utilities -------------------------------------------*/ +/*----- Essential state ---------------------------------------------------*/ -static void put_cstring(JNIEnv *jni, jbyteArray v, const char *p) - { if (p) (*jni)->ReleaseByteArrayElements(jni, v, (jbyte *)p, JNI_ABORT); } +static JNIEnv *jni_tripe = 0; + +/*----- Miscellaneous utilities -------------------------------------------*/ static void vexcept(JNIEnv *jni, const char *clsname, const char *msg, va_list *ap) @@ -164,6 +187,9 @@ static const char *get_cstring(JNIEnv *jni, jbyteArray v) return ((const char *)(*jni)->GetByteArrayElements(jni, v, 0)); } +static void put_cstring(JNIEnv *jni, jbyteArray v, const char *p) + { if (p) (*jni)->ReleaseByteArrayElements(jni, v, (jbyte *)p, JNI_ABORT); } + static void vexcept_syserror(JNIEnv *jni, const char *clsname, int err, const char *msg, va_list *ap) { @@ -769,11 +795,11 @@ JNIEXPORT jobject JNIFUNC(errtab)(JNIEnv *jni, jobject cls) jobject e; eltcls = - (*jni)->FindClass(jni, ERRENTRY); + (*jni)->FindClass(jni, ERRENTCLS); assert(eltcls); v = (*jni)->NewObjectArray(jni, N(errtab), eltcls, 0); if (!v) return (0); init = (*jni)->GetMethodID(jni, eltcls, "", - "(Ljava/lang/String;I)V"); + "(L"STRCLS";I)V"); assert(init); for (i = 0; i < N(errtab); i++) { @@ -792,7 +818,7 @@ JNIEXPORT jobject JNIFUNC(strerror)(JNIEnv *jni, jobject cls, jint err) static void fdguts(JNIEnv *jni, jclass *cls, jfieldID *fid) { - *cls = (*jni)->FindClass(jni, "java/io/FileDescriptor"); assert(cls); + *cls = (*jni)->FindClass(jni, FDCLS); assert(cls); *fid = (*jni)->GetFieldID(jni, *cls, "fd", "I"); // OpenJDK if (!*fid) *fid = (*jni)->GetFieldID(jni, *cls, "descriptor", "I"); // Android assert(*fid); @@ -1005,7 +1031,7 @@ static jobject xltstat(JNIEnv *jni, const struct stat *st) else if (S_ISLNK(st->st_mode)) modehack |= 0120000; else if (S_ISSOCK(st->st_mode)) modehack |= 0140000; - cls = (*jni)->FindClass(jni, STAT); assert(cls); + cls = (*jni)->FindClass(jni, STATCLS); assert(cls); init = (*jni)->GetMethodID(jni, cls, "", "(IIJIIIIIIJIJJJJ)V"); assert(init); return ((*jni)->NewObject(jni, cls, init, @@ -1146,7 +1172,7 @@ struct trigger { static const struct native_type trigger_type = { "trigger", sizeof(struct trigger), 0x65ffd8b4 }; -JNIEXPORT wrapper JNICALL JNIFUNC(makeTrigger)(JNIEnv *jni, jobject cls) +JNIEXPORT wrapper JNICALL JNIFUNC(make_1trigger)(JNIEnv *jni, jobject cls) { struct trigger trig; int fd[2]; @@ -1174,8 +1200,8 @@ end: return (ret); } -JNIEXPORT void JNICALL JNIFUNC(destroyTrigger)(JNIEnv *jni, jobject cls, - wrapper wtrig) +JNIEXPORT void JNICALL JNIFUNC(destroy_1trigger)(JNIEnv *jni, jobject cls, + wrapper wtrig) { struct trigger trig; @@ -1185,8 +1211,8 @@ JNIEXPORT void JNICALL JNIFUNC(destroyTrigger)(JNIEnv *jni, jobject cls, update_wrapper(jni, &trigger_type, wtrig, &trig); } -JNIEXPORT void JNICALL JNIFUNC(resetTrigger)(JNIEnv *jni, jobject cls, - wrapper wtrig) +JNIEXPORT void JNICALL JNIFUNC(reset_1trigger)(JNIEnv *jni, jobject cls, + wrapper wtrig) { struct trigger trig; char buf[64]; @@ -1218,84 +1244,398 @@ JNIEXPORT void JNICALL JNIFUNC(trigger)(JNIEnv *jni, jobject cls, except_syserror(jni, SYSERR, errno, "failed to pull trigger"); } -/*----- A server connection, using a Unix-domain socket -------------------*/ +/*----- A tunnel supplied by Java -----------------------------------------*/ -struct conn { - struct native_base _base; - int fd; - unsigned f; -#define CF_CLOSERD 1u -#define CF_CLOSEWR 2u -#define CF_CLOSEMASK (CF_CLOSERD | CF_CLOSEWR) +struct tunnel { + const tunnel_ops *ops; + sel_file f; + struct peer *p; }; -static const struct native_type conn_type = - { "conn", sizeof(struct conn), 0xed030167 }; -JNIEXPORT wrapper JNICALL JNIFUNC(connect)(JNIEnv *jni, jobject cls, - jobject path, wrapper wtrig) +static const struct tunnel_ops tun_java; + +static int t_init(void) { return (0); } + +static void t_read(int fd, unsigned mode, void *v) { - struct conn conn; - struct trigger trig; - struct sockaddr_un sun; - int rc, maxfd; - fd_set rfds, wfds; - const char *pathstr = 0; - int err; - socklen_t sz; - wrapper ret = 0; - int nb; - int fd = -1; + tunnel *t = v; + ssize_t n; + buf b; - if (unwrap(jni, &trig, &trigger_type, wtrig)) goto end; - pathstr = get_cstring(jni, path); if (!pathstr) goto end; - if (strlen(pathstr) >= sizeof(sun.sun_path)) { - except(jni, ARGERR, - "Unix-domain socket path `%s' too long", pathstr); + n = read(fd, buf_i, sizeof(buf_i)); + if (n < 0) { + a_warn("TUN", "%s", p_ifname(t->p), "java", + "read-error", "?ERRNO", A_END); + return; + } + IF_TRACING(T_TUNNEL, { + trace(T_TUNNEL, "tun-java: packet arrived"); + trace_block(T_PACKET, "tunnel: packet contents", buf_i, n); + }) + buf_init(&b, buf_i, n); + p_tun(t->p, &b); +} + +static tunnel *t_create(peer *p, int fd, char **ifn) +{ + JNIEnv *jni = jni_tripe; + tunnel *t = 0; + const char *name = p_name(p); + jbyteArray jname; + size_t n = strlen(p_name(p)); + jclass cls, metacls; + jstring jclsname, jexcmsg; + const char *clsname, *excmsg; + jmethodID mid; + jthrowable exc; + + assert(jni); + + jname = wrap_cstring(jni, name); + cls = (*jni)->FindClass(jni, SYSCLS); assert(cls); + mid = (*jni)->GetStaticMethodID(jni, cls, "getTunnelFd", "([B)I"); + assert(mid); + fd = (*jni)->CallStaticIntMethod(jni, cls, mid, jname); + + exc = (*jni)->ExceptionOccurred(jni); + if (exc) { + cls = (*jni)->GetObjectClass(jni, exc); + metacls = (*jni)->GetObjectClass(jni, cls); + mid = (*jni)->GetMethodID(jni, metacls, + "getName", "()L"STRCLS";"); + assert(mid); + jclsname = (*jni)->CallObjectMethod(jni, cls, mid); + clsname = (*jni)->GetStringUTFChars(jni, jclsname, 0); + mid = (*jni)->GetMethodID(jni, cls, + "getMessage", "()L"STRCLS";"); + jexcmsg = (*jni)->CallObjectMethod(jni, exc, mid); + excmsg = (*jni)->GetStringUTFChars(jni, jexcmsg, 0); + a_warn("TUN", "-", "java", "get-tunnel-fd-failed", + "%s", clsname, "%s", excmsg, A_END); + (*jni)->ReleaseStringUTFChars(jni, jclsname, clsname); + (*jni)->ReleaseStringUTFChars(jni, jexcmsg, excmsg); + (*jni)->ExceptionClear(jni); goto end; } - INIT_NATIVE(conn, &conn); - fd = socket(PF_UNIX, SOCK_STREAM, 0); if (fd < 0) goto err; - nb = set_nonblocking(jni, fd, 1); if (nb < 0) goto end; + t = CREATE(tunnel); + t->ops = &tun_java; + t->p = p; + sel_initfile(&sel, &t->f, fd, SEL_READ, t_read, t); - sun.sun_family = AF_UNIX; - strcpy(sun.sun_path, (char *)pathstr); - if (!connect(fd, (struct sockaddr *)&sun, sizeof(sun))) goto connected; - else if (errno != EINPROGRESS) goto err; + if (!*ifn) { + *ifn = xmalloc(n + 5); + sprintf(*ifn, "vpn-%s", name); + } - maxfd = trig.rfd; - if (maxfd < fd) maxfd = fd; - for (;;) { - FD_ZERO(&rfds); FD_SET(trig.rfd, &rfds); - FD_ZERO(&wfds); FD_SET(fd, &wfds); - rc = select(maxfd + 1, &rfds, &wfds, 0, 0); if (rc < 0) goto err; - if (FD_ISSET(trig.rfd, &rfds)) goto end; - if (FD_ISSET(fd, &wfds)) { - sz = sizeof(sun); - if (!getpeername(fd, (struct sockaddr *)&sun, &sz)) goto connected; - else if (errno != ENOTCONN) goto err; - sz = sizeof(err); - if (!getsockopt(fd, SOL_SOCKET, SO_ERROR, &err, &sz)) errno = err; - goto err; - } +end: + return (t); +} + +static void t_inject(tunnel *t, buf *b) +{ + IF_TRACING(T_TUNNEL, { + trace(T_TUNNEL, "tun-java: inject decrypted packet"); + trace_block(T_PACKET, "tunnel: packet contents", BBASE(b), BLEN(b)); + }) + DISCARD(write(t->f.fd, BBASE(b), BLEN(b))); +} + +static void t_destroy(tunnel *t) + { sel_rmfile(&t->f); close(t->f.fd); DESTROY(t); } + +static const struct tunnel_ops tun_java = { + "java", 0, + /* init */ t_init, + /* create */ t_create, + /* setifname */ 0, + /* inject */ t_inject, + /* destroy */ t_destroy +}; + + +JNIEXPORT jint JNICALL JNIFUNC(open_1tun)(JNIEnv *jni, jobject cls) +{ + int ret = -1; + int fd = -1; + struct ifreq iff; + + if ((fd = open("/dev/net/tun", O_RDWR)) < 0) { + except_syserror(jni, SYSERR, errno, "failed to open tunnel device"); + goto end; } -connected: - if (set_nonblocking(jni, fd, nb) < 0) goto end; - conn.fd = fd; fd = -1; - conn.f = 0; - ret = wrap(jni, &conn_type, &conn); - goto end; + if (set_nonblocking(jni, fd, 1) || set_closeonexec(jni, fd)) goto end; + + memset(&iff, 0, sizeof(iff)); + iff.ifr_name[0] = 0; + iff.ifr_flags = IFF_TUN | IFF_NO_PI; + if (ioctl(fd, TUNSETIFF, &iff) < 0) { + except_syserror(jni, SYSERR, errno, "failed to configure tunnel device"); + goto end; + } + + ret = fd; fd = -1; -err: - except_syserror(jni, SYSERR, errno, - "failed to connect to Unix-domain socket `%s'", pathstr); end: if (fd != -1) close(fd); - put_cstring(jni, path, pathstr); return (ret); } +/*----- A custom noise source ---------------------------------------------*/ + +static void javanoise(rand_pool *r) +{ + JNIEnv *jni = jni_tripe; + jclass cls; + jmethodID mid; + jbyteArray v; + jbyte *p; + jsize n; + + noise_devrandom(r); + + assert(jni); + cls = (*jni)->FindClass(jni, RANDCLS); assert(cls); + mid = (*jni)->GetStaticMethodID(jni, cls, "getSeed", "(I)[B"); assert(mid); + v = (*jni)->CallStaticObjectMethod(jni, cls, mid, 32); + if (v) { + n = (*jni)->GetArrayLength(jni, v); + p = (*jni)->GetByteArrayElements(jni, v, 0); + rand_add(r, p, n, n); + (*jni)->ReleaseByteArrayElements(jni, v, p, JNI_ABORT); + } + if ((*jni)->ExceptionOccurred(jni)) { + (*jni)->ExceptionDescribe(jni); + (*jni)->ExceptionClear(jni); + } +} + +static const rand_source javasource = { javanoise, noise_timer }; + +/*----- Embedding the TrIPE server ----------------------------------------*/ + +static void lock_tripe(JNIEnv *jni) +{ + jclass cls = (*jni)->FindClass(jni, LOCKCLS); assert(cls); + (*jni)->MonitorEnter(jni, cls); +} + +static void unlock_tripe(JNIEnv *jni) +{ + jclass cls = (*jni)->FindClass(jni, LOCKCLS); assert(cls); + (*jni)->MonitorExit(jni, cls); +} + +#define STATES(_) \ + _(INIT) \ + _(RESOLVE) \ + _(KEYS) \ + _(BIND) \ + _(READY) \ + _(RUNNING) + +enum { +#define DEFTAG(st) st, + STATES(DEFTAG) +#undef DEFTAG + MAXSTATE +}; + +static const char *statetab[] = { +#define DEFNAME(st) #st, + STATES(DEFNAME) +#undef DEFNAME +}; + +static unsigned state = INIT; +static int clientsk = -1; + +static const char *statename(unsigned st) +{ + if (st >= MAXSTATE) return (""); + else return (statetab[st]); +} + +static int ensure_state(JNIEnv *jni, unsigned want) +{ + unsigned cur; + + lock_tripe(jni); + cur = state; + unlock_tripe(jni); + + if (cur != want) { + except(jni, STERR, "server is in state %s (%u), not %s (%u)", + statename(cur), cur, statename(want), want); + return (-1); + } + return (0); +} + +JNIEXPORT void JNICALL JNIFUNC(base_1init)(JNIEnv *jni, jobject cls) +{ + int fd[2]; + int i; + + for (i = 0; i < N(fd); i++) fd[i] = -1; + + lock_tripe(jni); + jni_tripe = jni; + if (ensure_state(jni, INIT)) goto end; + + if (socketpair(PF_UNIX, SOCK_STREAM, 0, fd)) { + except_syserror(jni, SYSERR, errno, "failed to create socket pair"); + goto end; + } + + clientsk = fd[0]; fd[0] = -1; + + rand_noisesrc(RAND_GLOBAL, &javasource); + rand_seed(RAND_GLOBAL, MAXHASHSZ); + lp_init(); + a_create(fd[1], fd[1], AF_NOTE | AF_WARN | AF_TRACE); fd[1] = -1; + a_switcherr(); + p_addtun(&tun_java); p_setdflttun(&tun_java); + p_init(); + kx_init(); + + state++; + +end: + for (i = 0; i < N(fd); i++) if (fd[i] != -1) close(fd[i]); + jni_tripe = 0; + unlock_tripe(jni); +} + +JNIEXPORT void JNICALL JNIFUNC(setup_1resolver)(JNIEnv *jni, jobject cls) +{ + lock_tripe(jni); + if (ensure_state(jni, RESOLVE)) goto end; + + if (a_init()) + { except(jni, INITERR, "failed to initialize resolver"); return; } + + state++; + +end: + unlock_tripe(jni); +} + +JNIEXPORT void JNICALL JNIFUNC(load_1keys)(JNIEnv *jni, jobject cls, + jobject privstr, jobject pubstr, + jobject tagstr) +{ + const char *priv = 0, *pub = 0, *tag = 0; + + lock_tripe(jni); + if (ensure_state(jni, KEYS)) return; + + priv = get_cstring(jni, privstr); if (!priv) goto end; + pub = get_cstring(jni, pubstr); if (!pub) goto end; + tag = get_cstring(jni, tagstr); if (!tag) goto end; + + if (km_init(priv, pub, tag)) + { except(jni, INITERR, "failed to load initial keys"); goto end; } + + state++; + +end: + put_cstring(jni, privstr, priv); + put_cstring(jni, pubstr, pub); + put_cstring(jni, tagstr, tag); + unlock_tripe(jni); +} + +JNIEXPORT void JNICALL JNIFUNC(unload_1keys)(JNIEnv *jni, jobject cls) +{ + lock_tripe(jni); + if (ensure_state(jni, KEYS + 1)) goto end; + + km_clear(); + + state--; + +end: + unlock_tripe(jni); +} + +JNIEXPORT void JNICALL JNIFUNC(bind)(JNIEnv *jni, jobject cls, + jbyteArray hoststr, jbyteArray svcstr) +{ + const char *host = 0, *svc = 0; + struct addrinfo hint, *ai = 0; + int err; + + lock_tripe(jni); + if (ensure_state(jni, BIND)) goto end; + + if (hoststr) { host = get_cstring(jni, hoststr); if (!host) goto end; } + svc = get_cstring(jni, svcstr); if (!svc) goto end; + + hint.ai_socktype = SOCK_DGRAM; + hint.ai_family = AF_UNSPEC; + hint.ai_protocol = IPPROTO_UDP; + hint.ai_flags = AI_PASSIVE | AI_ADDRCONFIG; + err = getaddrinfo(host, svc, &hint, &ai); + if (err) { + except(jni, NAMEERR, "failed to resolve %c%s%c, port `%s': %s", + host ? '`' : '<', host ? host : "nil", host ? '\'' : '>', + svc, gai_strerror(err)); + goto end; + } + + if (p_bind(ai)) + { except(jni, INITERR, "failed to bind master socket"); goto end; } + + state++; + +end: + if (ai) freeaddrinfo(ai); + put_cstring(jni, hoststr, host); + put_cstring(jni, svcstr, svc); + unlock_tripe(jni); +} + +JNIEXPORT void JNICALL JNIFUNC(unbind)(JNIEnv *jni, jobject cls) +{ + lock_tripe(jni); + if (ensure_state(jni, BIND + 1)) goto end; + + p_unbind(); + + state--; + +end: + unlock_tripe(jni); +} + +JNIEXPORT void JNICALL JNIFUNC(mark)(JNIEnv *jni, jobject cls, jint seq) +{ + lock_tripe(jni); + a_notify("MARK", "%d", seq, A_END); + unlock_tripe(jni); +} + +JNIEXPORT void JNICALL JNIFUNC(run)(JNIEnv *jni, jobject cls) +{ + lock_tripe(jni); + if (ensure_state(jni, READY)) goto end; + assert(!jni_tripe); + jni_tripe = jni; + state = RUNNING; + unlock_tripe(jni); + + lp_run(); + + lock_tripe(jni); + jni_tripe = 0; + state = READY; + +end: + unlock_tripe(jni); +} + static int check_buffer_bounds(JNIEnv *jni, const char *what, jbyteArray buf, jint start, jint len) { @@ -1324,18 +1664,18 @@ static int check_buffer_bounds(JNIEnv *jni, const char *what, } JNIEXPORT void JNICALL JNIFUNC(send)(JNIEnv *jni, jobject cls, - wrapper wconn, jbyteArray buf, + jbyteArray buf, jint start, jint len, wrapper wtrig) { - struct conn conn; struct trigger trig; int rc, maxfd; ssize_t n; fd_set rfds, wfds; jbyte *p = 0; - if (unwrap(jni, &conn, &conn_type, wconn)) goto end; + if (ensure_state(jni, RUNNING)) goto end; + if (unwrap(jni, &trig, &trigger_type, wtrig)) goto end; if (check_buffer_bounds(jni, "send", buf, start, len)) goto end; @@ -1343,14 +1683,14 @@ JNIEXPORT void JNICALL JNIFUNC(send)(JNIEnv *jni, jobject cls, if (!p) goto end; maxfd = trig.rfd; - if (maxfd < conn.fd) maxfd = conn.fd; + if (maxfd < clientsk) maxfd = clientsk; while (len) { FD_ZERO(&rfds); FD_SET(trig.rfd, &rfds); - FD_ZERO(&wfds); FD_SET(conn.fd, &wfds); + FD_ZERO(&wfds); FD_SET(clientsk, &wfds); rc = select(maxfd + 1, &rfds, &wfds, 0, 0); if (rc < 0) goto err; if (FD_ISSET(trig.rfd, &rfds)) break; - if (FD_ISSET(conn.fd, &wfds)) { - n = send(conn.fd, p + start, len, 0); + if (FD_ISSET(clientsk, &wfds)) { + n = send(clientsk, p + start, len, 0); if (n >= 0) { start += n; len -= n; } else if (errno != EAGAIN && errno != EWOULDBLOCK) goto err; } @@ -1365,18 +1705,24 @@ end: } JNIEXPORT jint JNICALL JNIFUNC(recv)(JNIEnv *jni, jobject cls, - wrapper wconn, jbyteArray buf, + jbyteArray buf, jint start, jint len, wrapper wtrig) { - struct conn conn; struct trigger trig; int maxfd; fd_set rfds; jbyte *p = 0; jint rc = -1; - if (unwrap(jni, &conn, &conn_type, wconn)) goto end; + lock_tripe(jni); + if (clientsk == -1) { + except(jni, STERR, "client connection not established"); + unlock_tripe(jni); + goto end; + } + unlock_tripe(jni); + if (unwrap(jni, &trig, &trigger_type, wtrig)) goto end; if (check_buffer_bounds(jni, "send", buf, start, len)) goto end; @@ -1384,15 +1730,15 @@ JNIEXPORT jint JNICALL JNIFUNC(recv)(JNIEnv *jni, jobject cls, if (!p) goto end; maxfd = trig.rfd; - if (maxfd < conn.fd) maxfd = conn.fd; + if (maxfd < clientsk) maxfd = clientsk; for (;;) { - FD_ZERO(&rfds); FD_SET(trig.rfd, &rfds); FD_SET(conn.fd, &rfds); + FD_ZERO(&rfds); FD_SET(trig.rfd, &rfds); FD_SET(clientsk, &rfds); rc = select(maxfd + 1, &rfds, 0, 0, 0); if (rc < 0) goto err; if (FD_ISSET(trig.rfd, &rfds)) { break; } - if (FD_ISSET(conn.fd, &rfds)) { - rc = recv(conn.fd, p + start, len, 0); + if (FD_ISSET(clientsk, &rfds)) { + rc = recv(clientsk, p + start, len, 0); if (rc >= 0) break; else if (errno != EAGAIN && errno != EWOULDBLOCK) goto err; } @@ -1407,28 +1753,4 @@ end: return (rc); } -JNIEXPORT void JNICALL JNIFUNC(closeconn)(JNIEnv *jni, jobject cls, - wrapper wconn, jint how) -{ - struct conn conn; - int rc; - - if (unwrap(jni, &conn, &conn_type, wconn)) goto end; - if (conn.fd == -1) goto end; - - how &= CF_CLOSEMASK&~conn.f; - conn.f |= how; - if ((conn.f&CF_CLOSEMASK) == CF_CLOSEMASK) { - close(conn.fd); - conn.fd = -1; - } else { - if (how&CF_CLOSERD) shutdown(conn.fd, SHUT_RD); - if (how&CF_CLOSEWR) shutdown(conn.fd, SHUT_WR); - } - rc = update_wrapper(jni, &conn_type, wconn, &conn); assert(!rc); - -end: - return; -} - /*----- That's all, folks -------------------------------------------------*/ diff --git a/keys.scala b/keys.scala index b9595ec..bdbede9 100644 --- a/keys.scala +++ b/keys.scala @@ -120,7 +120,7 @@ private val DEFAULTS: Seq[(String, Config => String)] = "sig-fresh" -> { _ => "always" }, "fingerprint-hash" -> { _("hash") }); -private def parseConfig(file: File): Config = { +private def parseConfig(file: File): HashMap[String, String] = { /* Build the new configuration in a temporary place. */ var m = HashMap[String, String](); @@ -131,7 +131,7 @@ private def parseConfig(file: File): Config = { for (line <- lines(in)) { line match { case RX_COMMENT() => ok; - case RX_KEYVAL(key, value) => m += key -> value; + case RX_KEYVAL(key, value) => m(key) = value; case _ => throw new ConfigSyntaxError(file.getPath, lno, "failed to parse line"); @@ -150,7 +150,7 @@ private def readConfig(file: File): Config = { /* Fill in defaults where things have been missed out. */ for ((key, dflt) <- DEFAULTS) { if (!(m contains key)) { - try { m += key -> dflt(m); } + try { m(key) = dflt(m); } catch { case e: DefaultFailed => throw new ConfigDefaultFailed(file.getPath, key, diff --git a/sys.scala b/sys.scala index 6931431..402bf1d 100644 --- a/sys.scala +++ b/sys.scala @@ -28,7 +28,7 @@ package uk.org.distorted.tripe; package object sys { /*----- Imports -----------------------------------------------------------*/ import scala.collection.convert.decorateAsJava._; -import scala.collection.mutable.HashSet; +import scala.collection.mutable.{HashMap, HashSet}; import java.io.{BufferedReader, BufferedWriter, Closeable, File, FileDescriptor, FileInputStream, FileOutputStream, @@ -124,7 +124,7 @@ import StringImplicits._; /*----- Main code ---------------------------------------------------------*/ /* Import the native code library. */ -System.loadLibrary("toy"); +System.loadLibrary("tripe"); /* Native types. * @@ -810,15 +810,15 @@ private final val maxTriggers = 2; private var nTriggers = 0; private var triggers: List[Wrapper] = Nil; -@native protected def makeTrigger(): Wrapper; -@native protected def destroyTrigger(trig: Wrapper); -@native protected def resetTrigger(trig: Wrapper); +@native protected def make_trigger(): Wrapper; +@native protected def destroy_trigger(trig: Wrapper); +@native protected def reset_trigger(trig: Wrapper); @native protected def trigger(trig: Wrapper); private def getTrigger(): Wrapper = { triggerLock synchronized { if (nTriggers == 0) - makeTrigger() + make_trigger() else { val trig = triggers.head; triggers = triggers.tail; @@ -829,10 +829,10 @@ private def getTrigger(): Wrapper = { } private def putTrigger(trig: Wrapper) { - resetTrigger(trig); + reset_trigger(trig); triggerLock synchronized { if (nTriggers >= maxTriggers) - destroyTrigger(trig); + destroy_trigger(trig); else { triggers ::= trig; nTriggers += 1; @@ -859,59 +859,69 @@ def interruptWithTrigger[T](body: Wrapper => T): T = { }; } -/*----- Connecting to a server --------------------------------------------*/ +/*----- Glue for the VPN server -------------------------------------------*/ -/* Primitive operations. */ -final val CF_CLOSERD = 1; -final val CF_CLOSEWR = 2; -final val CF_CLOSEMASK = CF_CLOSERD | CF_CLOSEWR; -@native protected def connect(path: CString, trig: Wrapper): Wrapper; -@native protected def send(conn: Wrapper, buf: CString, - start: Int, len: Int, trig: Wrapper); -@native protected def recv(conn: Wrapper, buf: CString, - start: Int, len: Int, trig: Wrapper): Int; -@native def closeconn(conn: Wrapper, how: Int); - -class Connection(path: String) extends Closeable { - - /* The underlying primitive connection. */ - private[this] val conn = interruptWithTrigger { trig => - connect(path.toCString, trig); - }; - - /* Alternative constructors. */ - def this(file: File) { this(file.getPath); } +/* The lock class. This is only a class because they're much easier to find + * than loose objects through JNI. + */ +private class ServerLock; - /* Cleanup.*/ - override def close() { closeconn(conn, CF_CLOSEMASK); } - override protected def finalize() { super.finalize(); close(); } +/* Exceptions. */ +class NameResolutionException(msg: String) extends Exception(msg); +class InitializationException(msg: String) extends Exception(msg); - class Input private[Connection] extends InputStream { - /* An input stream which reads from the connection. */ +/* Primitive operations. */ +@native protected def open_tun(): Int; +@native protected def base_init(); +@native protected def setup_resolver(); +@native def load_keys(priv: CString, pub: CString, tag: CString); +@native def unload_keys(); +@native def bind(host: CString, svc: CString); +@native def unbind(); +@native def mark(seq: Int); +@native def run(); +@native protected def send(buf: CString, start: Int, len: Int, + trig: Wrapper); +@native protected def recv(buf: CString, start: Int, len: Int, + trig: Wrapper): Int; + +base_init(); +setup_resolver(); + +/* Tunnel descriptor plumbing. */ +val pending = HashMap[String, Int](); + +def getTunnelFd(peer: CString): Int = + pending synchronized { pending(peer.toJString) }; +def storeTunnelFd(peer: String, fd: Int) + { pending synchronized { pending(peer) = fd; } } +def withdrawTunnelFd(peer: String) + { pending synchronized { pending -= peer; } } +def withTunnelFd[T](peer: String, fd: Int)(body: => T): T = { + storeTunnelFd(peer, fd); + try { body } finally { withdrawTunnelFd(peer); } +} - override def read(): Int = { - val buf = new Array[Byte](1); - val n = read(buf, 0, 1); - if (n < 0) -1 else buf(0)&0xff; - } - override def read(buf: Array[Byte]): Int = - read(buf, 0, buf.length); - override def read(buf: Array[Byte], start: Int, len: Int) = - interruptWithTrigger { trig => recv(conn, buf, start, len, trig); }; - override def close() { closeconn(conn, CF_CLOSERD); } +/* Server I/O. */ +lazy val serverInput: InputStream = new InputStream { + override def read(): Int = { + val buf = new Array[Byte](1); + val n = read(buf, 0, 1); + if (n < 0) -1 else buf(0)&0xff; } - lazy val input = new Input; - - class Output private[Connection] extends OutputStream { - /* An output stream which writes to the connection. */ + override def read(buf: Array[Byte]): Int = + read(buf, 0, buf.length); + override def read(buf: Array[Byte], start: Int, len: Int) = + interruptWithTrigger { trig => recv(buf, start, len, trig); }; + override def close() { } +} - override def write(b: Int) { write(Array[Byte](b.toByte), 0, 1); } - override def write(buf: Array[Byte]) { write(buf, 0, buf.length); } - override def write(buf: Array[Byte], start: Int, len: Int) - { interruptWithTrigger { trig => send(conn, buf, start, len, trig); } } - override def close() { closeconn(conn, CF_CLOSEWR); } - } - lazy val output = new Output; +lazy val serverOutput: OutputStream = new OutputStream { + override def write(b: Int) { write(Array[Byte](b.toByte), 0, 1); } + override def write(buf: Array[Byte]) { write(buf, 0, buf.length); } + override def write(buf: Array[Byte], start: Int, len: Int) + { interruptWithTrigger { trig => send(buf, start, len, trig); } } + override def close() { } } /*----- Crypto-library hacks ----------------------------------------------*/ -- 2.11.0