From 04a5abaece151705e9bd7026653f79938a7a2fbc Mon Sep 17 00:00:00 2001 From: Mark Wooding Date: Wed, 13 Jun 2018 01:55:56 +0100 Subject: [PATCH] More work in progress. --- Makefile | 6 +- jni.c | 210 ++++++++++++++++++++++++++++++++++---- keys.scala | 312 ++++++++++++++++++++++++++++++++++++--------------------- progress.scala | 253 +++++++++++----------------------------------- sys.scala | 92 +++++++++++++---- terminal.scala | 177 ++++++++++++++++++++++++++++++++ util.scala | 180 +++++++++++++++++++++++++++------ 7 files changed, 854 insertions(+), 376 deletions(-) create mode 100644 terminal.scala diff --git a/Makefile b/Makefile index e00445f..c4f471a 100644 --- a/Makefile +++ b/Makefile @@ -198,7 +198,9 @@ CLASSES += util CLASSES += sys:util CLASSES += admin:sys,util CLASSES += tar:util -CLASSES += keys:tar,sys,util +CLASSES += progress:sys,util +CLASSES += keys:progress,tar,sys,util +CLASSES += terminal:progress,sys,util CLASSES += main:sys ## Machinery for parsing the `CLASSES' list. @@ -249,7 +251,7 @@ all:: $(TARGETS) clean::; rm -f $(CLEANFILES) $(TARGETS) -repl: $(TARGETS) +repl: all $(SCALA_REPL) -cp $(CLASSDIR) -Djava.library.path=$(OUTDIR) .PHONY: repl diff --git a/jni.c b/jni.c index f65e464..9aa79f8 100644 --- a/jni.c +++ b/jni.c @@ -39,6 +39,7 @@ #include #include +#include #include #include #include @@ -192,6 +193,31 @@ static void except_syserror(JNIEnv *jni, const char *clsname, va_end(ap); } +static int set_nonblocking(JNIEnv *jni, int fd, int nb) +{ + int f0 = fcntl(fd, F_GETFL), f1; + if (f0 < 0) goto err; + if (nb) f1 = f0 | O_NONBLOCK; + else f1 = f0&~O_NONBLOCK; + if (fcntl(fd, F_SETFL, f1)) goto err; + return (f0 & O_NONBLOCK); +err: + except_syserror(jni, SYSERR, errno, + "failed to set descriptor nonblocking"); + return (-1); +} + +static int set_closeonexec(JNIEnv *jni, int fd) +{ + int f = fcntl(fd, F_GETFD); + if (f < 0 || fcntl(fd, F_SETFD, f | FD_CLOEXEC)) { + except_syserror(jni, SYSERR, errno, + "failed to set descriptor close-on-exec"); + return (-1); + } + return (0); +} + /*----- Wrapping native types ---------------------------------------------*/ /* There's no way defined in the JNI to stash a C pointer in a Java object. @@ -808,7 +834,7 @@ JNIEXPORT jboolean JNIFUNC(isatty)(JNIEnv *jni, jobject cls, jobject jfd) /* Java has these already, as methods on `java.io.File' objects. Alas, these * methods are useless at reporting errors: they tend to return a `boolean' - * success/ fail indicator, and throw away any more detailed information. + * success/fail indicator, and throw away any more detailed information. * There's better functionality in `java.nio.file.Files', but that only turns * up in Android API 26 (in 7.0 Nougat). There's `android.system.Os', which * has a bunch of POSIX-shaped functions -- but they're only in Android API @@ -1102,6 +1128,96 @@ end: put_cstring(jni, path, pathstr); } +/*----- Triggers ----------------------------------------------------------*/ + +/* A trigger is a gadget for waking up a thread which is blocking on I/O, + * and it's used to implement interruptability. + * + * Really, a trigger is a pipe. A `blocking' I/O operation secretly uses + * select(2) to block on the descriptor of interest /and/ the read side of + * the trigger pipe. To wake up a thread that's blocked, we just write a + * byte (nobody cares /which/ byte) to the write end. + */ + +struct trigger { + struct native_base _base; + int rfd, wfd; +}; +static const struct native_type trigger_type = + { "trigger", sizeof(struct trigger), 0x65ffd8b4 }; + +JNIEXPORT wrapper JNICALL JNIFUNC(makeTrigger)(JNIEnv *jni, jobject cls) +{ + struct trigger trig; + int fd[2]; + int i; + wrapper ret = 0; + + fd[0] = fd[1] = -1; + if (pipe(fd)) { + except_syserror(jni, SYSERR, errno, "failed to create pipe"); + goto end; + } + for (i = 0; i < 2; i++) { + if (set_nonblocking(jni, fd[i], 1) < 0 || set_closeonexec(jni, fd[i])) + goto end; + } + + INIT_NATIVE(trigger, &trig); + trig.rfd = fd[0]; fd[0] = -1; + trig.wfd = fd[1]; fd[1] = -1; + ret = wrap(jni, &trigger_type, &trig); + +end: + for (i = 0; i < 2; i++) + if (fd[i] != -1) close(fd[i]); + return (ret); +} + +JNIEXPORT void JNICALL JNIFUNC(destroyTrigger)(JNIEnv *jni, jobject cls, + wrapper wtrig) +{ + struct trigger trig; + + if (unwrap(jni, &trig, &trigger_type, wtrig)) return; + if (trig.rfd != -1) { close(trig.rfd); trig.rfd = -1; } + if (trig.wfd != -1) { close(trig.wfd); trig.wfd = -1; } + update_wrapper(jni, &trigger_type, wtrig, &trig); +} + +JNIEXPORT void JNICALL JNIFUNC(resetTrigger)(JNIEnv *jni, jobject cls, + wrapper wtrig) +{ + struct trigger trig; + char buf[64]; + ssize_t n; + + if (unwrap(jni, &trig, &trigger_type, wtrig)) return; + for (;;) { + n = read(trig.rfd, buf, sizeof(buf)); + if (n > 0) continue; + assert(n < 0); + if (errno == EAGAIN || errno == EWOULDBLOCK) break; + else { + except_syserror(jni, SYSERR, errno, "failed to reset trigger"); + break; + } + } +} + +JNIEXPORT void JNICALL JNIFUNC(trigger)(JNIEnv *jni, jobject cls, + wrapper wtrig) +{ + struct trigger trig; + ssize_t n; + char c = 0; + + if (unwrap(jni, &trig, &trigger_type, wtrig)) return; + n = write(trig.wfd, &c, 1); + if (n < 0 && errno != EAGAIN && errno != EWOULDBLOCK) + except_syserror(jni, SYSERR, errno, "failed to pull trigger"); +} + /*----- A server connection, using a Unix-domain socket -------------------*/ struct conn { @@ -1116,14 +1232,21 @@ static const struct native_type conn_type = { "conn", sizeof(struct conn), 0xed030167 }; JNIEXPORT wrapper JNICALL JNIFUNC(connect)(JNIEnv *jni, jobject cls, - jobject path) + jobject path, wrapper wtrig) { struct conn conn; + struct trigger trig; struct sockaddr_un sun; + int rc, maxfd; + fd_set rfds, wfds; const char *pathstr = 0; - jobject ret = 0; + int err; + socklen_t sz; + wrapper ret = 0; + int nb; int fd = -1; + if (unwrap(jni, &trig, &trigger_type, wtrig)) goto end; pathstr = get_cstring(jni, path); if (!pathstr) goto end; if (strlen(pathstr) >= sizeof(sun.sun_path)) { except(jni, ARGERR, @@ -1132,12 +1255,33 @@ JNIEXPORT wrapper JNICALL JNIFUNC(connect)(JNIEnv *jni, jobject cls, } INIT_NATIVE(conn, &conn); - fd = socket(SOCK_STREAM, PF_UNIX, 0); if (fd < 0) goto err; + fd = socket(PF_UNIX, SOCK_STREAM, 0); if (fd < 0) goto err; + nb = set_nonblocking(jni, fd, 1); if (nb < 0) goto end; sun.sun_family = AF_UNIX; strcpy(sun.sun_path, (char *)pathstr); - if (connect(fd, (struct sockaddr *)&sun, sizeof(sun))) goto err; + if (!connect(fd, (struct sockaddr *)&sun, sizeof(sun))) goto connected; + else if (errno != EINPROGRESS) goto err; + + maxfd = trig.rfd; + if (maxfd < fd) maxfd = fd; + for (;;) { + FD_ZERO(&rfds); FD_SET(trig.rfd, &rfds); + FD_ZERO(&wfds); FD_SET(fd, &wfds); + rc = select(maxfd + 1, &rfds, &wfds, 0, 0); if (rc < 0) goto err; + if (FD_ISSET(trig.rfd, &rfds)) goto end; + if (FD_ISSET(fd, &wfds)) { + sz = sizeof(sun); + if (!getpeername(fd, (struct sockaddr *)&sun, &sz)) goto connected; + else if (errno != ENOTCONN) goto err; + sz = sizeof(err); + if (!getsockopt(fd, SOL_SOCKET, SO_ERROR, &err, &sz)) errno = err; + goto err; + } + } +connected: + if (set_nonblocking(jni, fd, nb) < 0) goto end; conn.fd = fd; fd = -1; conn.f = 0; ret = wrap(jni, &conn_type, &conn); @@ -1147,7 +1291,7 @@ err: except_syserror(jni, SYSERR, errno, "failed to connect to Unix-domain socket `%s'", pathstr); end: - if (fd == -1) close(fd); + if (fd != -1) close(fd); put_cstring(jni, path, pathstr); return (ret); } @@ -1181,28 +1325,40 @@ static int check_buffer_bounds(JNIEnv *jni, const char *what, JNIEXPORT void JNICALL JNIFUNC(send)(JNIEnv *jni, jobject cls, wrapper wconn, jbyteArray buf, - jint start, jint len) + jint start, jint len, + wrapper wtrig) { struct conn conn; + struct trigger trig; + int rc, maxfd; ssize_t n; + fd_set rfds, wfds; jbyte *p = 0; if (unwrap(jni, &conn, &conn_type, wconn)) goto end; + if (unwrap(jni, &trig, &trigger_type, wtrig)) goto end; if (check_buffer_bounds(jni, "send", buf, start, len)) goto end; p = (*jni)->GetByteArrayElements(jni, buf, 0); if (!p) goto end; + maxfd = trig.rfd; + if (maxfd < conn.fd) maxfd = conn.fd; while (len) { - n = send(conn.fd, p + start, len, 0); - if (n < 0) { - except_syserror(jni, SYSERR, - errno, "failed to send on connection"); - goto end; + FD_ZERO(&rfds); FD_SET(trig.rfd, &rfds); + FD_ZERO(&wfds); FD_SET(conn.fd, &wfds); + rc = select(maxfd + 1, &rfds, &wfds, 0, 0); if (rc < 0) goto err; + if (FD_ISSET(trig.rfd, &rfds)) break; + if (FD_ISSET(conn.fd, &wfds)) { + n = send(conn.fd, p + start, len, 0); + if (n >= 0) { start += n; len -= n; } + else if (errno != EAGAIN && errno != EWOULDBLOCK) goto err; } - start += n; len -= n; } + goto end; +err: + except_syserror(jni, SYSERR, errno, "failed to send on connection"); end: if (p) (*jni)->ReleaseByteArrayElements(jni, buf, p, JNI_ABORT); return; @@ -1210,26 +1366,42 @@ end: JNIEXPORT jint JNICALL JNIFUNC(recv)(JNIEnv *jni, jobject cls, wrapper wconn, jbyteArray buf, - jint start, jint len) + jint start, jint len, + wrapper wtrig) { struct conn conn; + struct trigger trig; + int maxfd; + fd_set rfds; jbyte *p = 0; jint rc = -1; if (unwrap(jni, &conn, &conn_type, wconn)) goto end; + if (unwrap(jni, &trig, &trigger_type, wtrig)) goto end; if (check_buffer_bounds(jni, "send", buf, start, len)) goto end; p = (*jni)->GetByteArrayElements(jni, buf, 0); if (!p) goto end; - rc = recv(conn.fd, p + start, len, 0); - if (rc < 0) { - except_syserror(jni, SYSERR, - errno, "failed to read from connection"); - goto end; + maxfd = trig.rfd; + if (maxfd < conn.fd) maxfd = conn.fd; + for (;;) { + FD_ZERO(&rfds); FD_SET(trig.rfd, &rfds); FD_SET(conn.fd, &rfds); + rc = select(maxfd + 1, &rfds, 0, 0, 0); if (rc < 0) goto err; + if (FD_ISSET(trig.rfd, &rfds)) { + break; + } + if (FD_ISSET(conn.fd, &rfds)) { + rc = recv(conn.fd, p + start, len, 0); + if (rc >= 0) break; + else if (errno != EAGAIN && errno != EWOULDBLOCK) goto err; + } } if (!rc) rc = -1; + goto end; +err: + except_syserror(jni, SYSERR, errno, "failed to read from connection"); end: if (p) (*jni)->ReleaseByteArrayElements(jni, buf, p, 0); return (rc); diff --git a/keys.scala b/keys.scala index b49e334..cec56a9 100644 --- a/keys.scala +++ b/keys.scala @@ -37,15 +37,17 @@ import sys.{SystemError, hashsz, runCommand}; import sys.Errno.EEXIST; import sys.FileImplicits._; +import progress.{Eyecandy, SimpleModel, DataModel}; + /*----- Useful regular expressions ----------------------------------------*/ -private val RX_COMMENT = """(?x) ^ \s* (?: \# .* )? $""".r; -private val RX_KEYVAL = """(?x) ^ \s* +private final val RX_COMMENT = """(?x) ^ \s* (?: \# .* )? $""".r; +private final val RX_KEYVAL = """(?x) ^ \s* ([-\w]+) (?:\s+(?!=)|\s*=\s*) (|\S|\S.*\S) \s* $""".r; -private val RX_DOLLARSUBST = """(?x) \$ \{ ([-\w]+) \}""".r; +private final val RX_DOLLARSUBST = """(?x) \$ \{ ([-\w]+) \}""".r; /*----- Things that go wrong ----------------------------------------------*/ @@ -106,26 +108,70 @@ private val DEFAULTS: Seq[(String, Config => String)] = "sig-fresh" -> { _ => "always" }, "fingerprint-hash" -> { _("hash") }); -/*----- Managing a key repository -----------------------------------------*/ +private def readConfig(file: File): Config = { -def downloadToFile(file: File, url: URL, maxlen: Long = Long.MaxValue) { - fetchURL(url, new URLFetchCallbacks { - val out = file.openForOutput(); - private def toobig() { - throw new KeyConfigException(s"remote file `$url' is " + - "suspiciously large"); - } - var totlen: Long = 0; - override def preflight(conn: URLConnection) { - totlen = conn.getContentLength; - if (totlen > maxlen) toobig(); + /* Build the new configuration in a temporary place. */ + var m = HashMap[String, String](); + + /* Read the config file into our map. */ + file.withReader { in => + var lno = 1; + for (line <- lines(in)) { + line match { + case RX_COMMENT() => ok; + case RX_KEYVAL(key, value) => m += key -> value; + case _ => + throw new ConfigSyntaxError(file.getPath, lno, + "failed to parse line"); + } + lno += 1; } - override def done(win: Boolean) { out.close(); } - def write(buf: Array[Byte], n: Int, len: Long) { - if (len + n > maxlen) toobig(); - out.write(buf, 0, n); + } + + /* Fill in defaults where things have been missed out. */ + for ((key, dflt) <- DEFAULTS) { + if (!(m contains key)) { + try { m += key -> dflt(m); } + catch { + case e: DefaultFailed => + throw new ConfigDefaultFailed(file.getPath, key, + e.key, m(e.key)); + } } - }); + } + + /* And we're done. */ + m +} + +/*----- Managing a key repository -----------------------------------------*/ + +def downloadToFile(file: File, url: URL, + maxlen: Long = Long.MaxValue, + ic: Eyecandy) { + ic.job(new SimpleModel(s"connecting to `$url'", -1)) { jr => + fetchURL(url, new URLFetchCallbacks { + val out = file.openForOutput(); + private def toobig() { + throw new KeyConfigException( + s"remote file `$url' is suspiciously large"); + } + var totlen: Long = 0; + override def preflight(conn: URLConnection) { + totlen = conn.getContentLength; + if (totlen > maxlen) toobig(); + jr.change(new SimpleModel(s"downloading `$url'", totlen) + with DataModel, + 0); + } + override def done(win: Boolean) { out.close(); } + def write(buf: Array[Byte], n: Int, len: Long) { + if (len + n > maxlen) toobig(); + out.write(buf, 0, n); + jr.step(len + n); + } + }) + } } /* Lifecycle notes @@ -161,34 +207,65 @@ def downloadToFile(file: File, url: URL, maxlen: Long = Long.MaxValue) { * (delete old/) */ +class RepositoryStateException(val state: Repository.State.Value, + msg: String) + extends Exception(msg); + +class KeyConfigException(msg: String) extends Exception(msg); + +private def launderFingerprint(fp: String): String = + fp filter { _.isLetterOrDigit }; + +private def fingerprintsEqual(a: String, b: String) = + launderFingerprint(a) == launderFingerprint(b); + +private def keyFingerprint(kr: File, tag: String, hash: String): String = { + val (out, _) = runCommand("key", "-k", kr.getPath, "fingerprint", + "-a", hash, "-f", "-secret", tag); + nextToken(out) match { + case Some((fp, _)) => fp + case _ => + throw new java.io.IOException("unexpected output from `key fingerprint"); + } +} + object Repository { object State extends Enumeration { val Empty, Pending, Confirmed, Updating, Committing, Live = Value; } } -class RepositoryStateException(val state: Repository.State.Value, - msg: String) - extends Exception(msg); +def checkConfigSanity(file: File, ic: Eyecandy) { + ic.operation("checking new configuration") { _ => -class KeyConfigException(msg: String) extends Exception(msg); + /* Make sure we can read and understand the file. */ + val conf = readConfig(file); + + /* Make sure there are entries which we can use to update. This won't + * guarantee that we can reliably update, but it will help. + */ + conf("repos-url"); conf("sig-url"); + conf("fingerprint-hash"); conf("sig-fresh"); + conf("master-sequence"); conf("hk-master"); + } +} class Repository(val root: File) extends Closeable { import Repository.State.{Value => State, _}; /* Important directories and files. */ - private[this] val livedir = root + "live"; - private[this] val livereposdir = livedir + "repos"; - private[this] val newdir = root + "new"; - private[this] val olddir = root + "old"; - private[this] val pendingdir = root + "pending"; - private[this] val tmpdir = root + "tmp"; + private[this] val livedir = root/"live"; + private[this] val livereposdir = livedir/"repos"; + private[this] val newdir = root/"new"; + private[this] val olddir = root/"old"; + private[this] val pendingdir = root/"pending"; + private[this] val tmpdir = root/"tmp"; /* Take out a lock in case of other instances. */ private[this] val lock = { try { root.mkdir_!(); } catch { case SystemError(EEXIST, _) => ok; } - (root + "lk").lock_!() + (root/"lk").lock_!() } def close() { lock.close(); } @@ -269,9 +346,10 @@ class Repository(val root: File) extends Closeable { } } } - def destroy() { + def destroy(ic: Eyecandy) { /* Clear out the entire repository. Everything. It's all gone. */ - root.foreachFile { f => if (f.getName != "lk") f.rmTree(); } + ic.operation("clearing configuration") + { _ => root.foreachFile { f => if (f.getName != "lk") f.rmTree(); } } } def clearTmp() { @@ -293,66 +371,40 @@ class Repository(val root: File) extends Closeable { case Empty => throw new RepositoryStateException(Empty, "repository is Empty"); } - val file = dir + "tripe-keys.conf"; - - /* Build the new configuration in a temporary place. */ - var m = HashMap[String, String](); - - /* Read the config file into our map. */ - file.withReader { in => - var lno = 1; - for (line <- lines(in)) { - line match { - case RX_COMMENT() => ok; - case RX_KEYVAL(key, value) => m += key -> value; - case _ => - throw new ConfigSyntaxError(file.getPath, lno, - "failed to parse line"); - } - lno += 1; - } - } - - /* Fill in defaults where things have been missed out. */ - for ((key, dflt) <- DEFAULTS) { - if (!(m contains key)) { - try { m += key -> dflt(m); } - catch { - case e: DefaultFailed => - throw new ConfigDefaultFailed(file.getPath, key, - e.key, m(e.key)); - } - } - } - /* All done. */ - _config = m; + /* And then read the configuration. */ + _config = readConfig(dir/"tripe-keys.conf"); } _config } - def fetchConfig(url: URL) { + def fetchConfig(url: URL, ic: Eyecandy) { /* Fetch an initial configuration file from a given URL. */ checkState(Empty); clearTmp(); - downloadToFile(tmpdir + "tripe-keys.conf", url); - tmpdir.rename_!(pendingdir); + + val conffile = tmpdir/"tripe-keys.conf"; + downloadToFile(conffile, url, 16*1024, ic); + checkConfigSanity(conffile, ic); + ic.operation("committing configuration") + { _ => tmpdir.rename_!(pendingdir); } invalidate(); // should move to `Pending' } - def confirm() { + def confirm(ic: Eyecandy) { /* The user has approved the master key fingerprint in the `Pending' * configuration. Advance to `Confirmed'. */ checkState(Pending); - pendingdir.rename_!(livedir); + ic.operation("confirming configuration") + { _ => pendingdir.rename_!(livedir); } invalidate(); // should move to `Confirmed' } - def update() { + def update(ic: Eyecandy) { /* Update the repository from the master. * * Fetch a (possibly new) archive; unpack it; verify the master key @@ -364,66 +416,98 @@ class Repository(val root: File) extends Closeable { clearTmp(); /* First thing is to download the tarball and signature. */ - val tarfile = tmpdir + "tripe-keys.tar.gz"; - downloadToFile(tarfile, new URL(conf("repos-url"))); - val sigfile = tmpdir + "tripe-keys.sig"; + val tarfile = tmpdir/"tripe-keys.tar.gz"; + downloadToFile(tarfile, new URL(conf("repos-url")), 256*1024, ic); + val sigfile = tmpdir/"tripe-keys.sig"; val seq = conf("master-sequence"); downloadToFile(sigfile, new URL(conf("sig-url").replaceAllLiterally("", - seq))); + seq)), + 4*1024, ic); /* Unpack the tarball. Carefully. */ - val unpkdir = tmpdir + "unpk"; - unpkdir.mkdir_!(); - withCleaner { clean => - val tar = new TarFile(new GZIPInputStream(tarfile.open())); - clean { tar.close(); } - for (e <- tar) { - - /* Check the filename to make sure it's not evil. */ - if (e.name.split('/').exists { _ == ".." }) - throw new KeyConfigException("invalid path in tarball"); - - /* Find out where this file points. */ - val f = unpkdir + e.name; - - /* Unpack it. */ - if (e.isdir) { - /* A directory. Create it if it doesn't exist already. */ - - try { f.mkdir_!(); } - catch { case SystemError(EEXIST, _) => ok; } - } else if (e.isreg) { - /* A regular file. Write stuff to it. */ - - e.withStream { in => - f.withOutput { out => - for ((b, n) <- blocks(in)) out.write(b, 0, n); + val unpkdir = tmpdir/"unpk"; + ic.operation("unpacking archive") { or => + unpkdir.mkdir_!(); + withCleaner { clean => + val tar = new TarFile(new GZIPInputStream(tarfile.open())); + clean { tar.close(); } + for (e <- tar) { + + /* Check the filename to make sure it's not evil. */ + if (e.name(0) == '/' || e.name.split('/').exists { _ == ".." }) + throw new KeyConfigException("invalid path in tarball"); + + /* Report on progress. */ + or.step(s"entry `${e.name}'"); + + /* Find out where this file points. */ + val f = unpkdir/e.name; + + /* Unpack it. */ + if (e.isdir) { + /* A directory. Create it if it doesn't exist already. */ + + try { f.mkdir_!(); } + catch { case SystemError(EEXIST, _) => ok; } + } else if (e.isreg) { + /* A regular file. Write stuff to it. */ + + e.withStream { in => + f.withOutput { out => + for ((b, n) <- blocks(in)) out.write(b, 0, n); + } } - } - } else { - /* Something else. Be paranoid and reject it. */ + } else { + /* Something else. Be paranoid and reject it. */ - throw new KeyConfigException("unexpected object type in tarball"); + throw new KeyConfigException( + s"entry `${e.name}' has unexpected object type"); + } } } } /* There ought to be a file in here called `repos/master.pub'. */ - val reposdir = unpkdir + "repos"; + val reposdir = unpkdir/"repos"; + val masterfile = reposdir/"master.pub"; + if (!reposdir.isdir_!) throw new KeyConfigException("missing `repos/' directory"); - val masterfile = reposdir + "master.pub"; if (!masterfile.isreg_!) throw new KeyConfigException("missing `repos/master.pub' file"); + val mastertag = s"master-$seq"; /* Fetch the master key's fingerprint. */ - val (out, _) = runCommand("key", "-k", masterfile.getPath, - "fingerprint", - "-f", "-secret", - "-a", conf("fingerprint-hash"), - s"master-$seq"); - println(s";; $out"); + ic.operation("checking master key fingerprint") { _ => + val foundfp = keyFingerprint(masterfile, mastertag, + conf("fingerprint-hash")); + val wantfp = conf("hk-master"); + if (!fingerprintsEqual(wantfp, foundfp)) { + throw new KeyConfigException( + s"master key #$seq has wrong fingerprint: " + + s"expected $wantfp but found $foundfp"); + } + } + + /* Check the archive signature. */ + ic.operation("verifying archive signature") { or => + runCommand("catsign", "-k", masterfile.getPath, "verify", "-aqC", + "-k", mastertag, "-t", conf("sig-fresh"), + sigfile.getPath, tarfile.getPath); + } + + /* Confirm that the configuration in the new archive is sane. */ + checkConfigSanity(unpkdir/"tripe-keys.conf", ic); + + /* Now we just have to juggle the files about. */ + ic.operation("committing new configuration") { _ => + unpkdir.rename_!(newdir); + livedir.rename_!(olddir); + newdir.rename_!(livedir); + } + + invalidate(); // should move to `Live' } } diff --git a/progress.scala b/progress.scala index 9bfbc64..ab40c5d 100644 --- a/progress.scala +++ b/progress.scala @@ -29,226 +29,95 @@ package uk.org.distorted.tripe; package object progress { import scala.collection.mutable.{Publisher, Subscriber}; -import java.lang.Math.ceil; import java.lang.System.currentTimeMillis; -/*----- Main code ---------------------------------------------------------*/ +/*----- Progress displays -------------------------------------------------*/ -def formatTime(t: Int): String = - if (t < -1) "???" - else { - val (s, t1) = (t%60, t/60); - val (m, h) = (t1%60, t1/60); - if (h > 0) f"$h%d:$m%02d:$s%02d" - else f"$m%02d:$s%02d" - } - -private val UDATA = Seq("kB", "MB", "GB", "TB", "PB", "EB"); -def formatBytes(n: Long): String = { - val (x, u) = ((n.toDouble, "B ") /: UDATA) { (xu, n) => (xu, n) match { - case ((x, u), name) if x >= 1024.0 => (x/1024.0, name) - case (xu, _) => xu - } } - f"$x%6.1f$u%s" -} - -trait Eyecandy { - def set(line: String); - def clear(); - def commit(); - def commit(line: String) { commit(); set(line); commit(); } - def begin(job: Job); -} - -abstract class Event; // other subclasses can be added! -abstract class Progress extends Event { def cur: Long; } // it changed -object Progress { - def unapply(p: Progress) = - if (p == null) None - else Some(p.cur); -} -case class Update(override val cur: Long) extends Progress; // progress has been made -case class Changed(override val cur: Long) extends Progress; // what or max changed -abstract class Stopped extends Event; // job has stopped -case object Done extends Stopped; // job completed successfuly -final case class Failed(why: String) extends Stopped; // job failed -case object Cancelled extends Stopped; // job was cancelled - -trait Job extends Publisher[Event] { - def what: String; // imperative for what we're doing - def cur: Long; // current position in work - def max: Long; // maximum work to do - def format: String = { // describe progress in useful terms - val c = cur; - val m = max; - if (m >= 0) { - val fm = m.formatted("%d"); - s"%${fm.length}d/%s".format(c, fm) // ugh! - } else if (c > 0) s"$c" - else "" - } - def cancel(); +trait Model { + protected val t0 = currentTimeMillis; - private[this] val t0 = currentTimeMillis; - type Pub = Job; + def what: String; + def max: Long; - def taken: Double = (currentTimeMillis - t0)/1000.0; - def eta: Double = + def eta(cur: Long): Double = { /* Report the estimated time remaining in seconds, or -1 if no idea. * - * The model here is very stupid. Weird jobs should override this and do - * something more sensible. + * The model here is very stupid. Weird jobs should override this and + * do something more sensible. */ - if (max < 0 || cur <= 0) -1 - else taken*(max - cur)/cur.toDouble; -} - -/*----- Terminal eyecandy (FIXME: split this out) -------------------------*/ - -import java.io.FileDescriptor; -import java.lang.System.{out => stdout}; -import sys.isatty; - -object TerminalEyecandy extends Eyecandy with Subscriber[Event, Job] { - private var last = ""; - var eyecandyp = isatty(FileDescriptor.out); - - /* Assume that characters take up one cell each. This is going to fail - * badly for combining characters, zero-width characters, wide Asian - * characters, and lots of other Unicode characters. The problem is that - * Java doesn't have any way to query the display width of a character, - * and, honestly, I don't care enough to do the (substantial) work required - * to do this properly. - */ - - def set(line: String) { - if (eyecandyp) { - - /* If the old line is longer than the new one, then we must overprint - * the end part. - */ - if (line.length < last.length) { - val n = last.length - line.length; - for (_ <- 0 until n) stdout.write('\b'); - for (_ <- 0 until n) stdout.write(' '); - } - - /* Figure out the length of the common prefix between what we had - * before and what we have now. - */ - val m = (0 until (last.length min line.length)) prefixLength - { i => last(i) == line(i) }; - - /* Delete the tail from the old line and print the new version. */ - for (_ <- m until last.length) stdout.write('\b'); - stdout.print(line.substring(m)); - stdout.flush(); - } - - /* Update the state. */ - last = line; + val max = this.max; + val delta = currentTimeMillis - t0 + if (max < 0 || cur <= 0) -1 else delta*(max - cur)/cur.toDouble } - def clear() { set(""); } + protected def fmt1(n: Long): String = n.toString; - def commit() { - if (last != "") { - if (eyecandyp) stdout.write('\n'); - else stdout.println(last); - last = ""; - } + def format(cur: Long): String = { + val max = this.max; + val fc = fmt1(cur); + if (max >= 0) { val fm = fmt1(max); s"%${fm.length}s/%s".format(fc, fm) } + else if (cur > 0) fc + else "" } +} - private final val spinner = """/-\|"""; - private var step: Int = 0; - private final val width = 40; - - def begin(job: Job) { job.subscribe(this); } - - def notify(job: Job, ev: Event) { - ev match { - case Progress(cur) => - /* Redraw the status line. */ - - val max = job.max; - - val sb = new StringBuilder; - sb ++= job.what; sb += ' '; - - /* Step the spinner. */ - step += 1; if (step >= spinner.length) step = 0; - sb += spinner(step); sb += ' '; - - /* Progress bar. */ - if (max < 0) - sb ++= "[unknown progress]"; - else { - val n = (width*cur/max).toInt; - sb += '['; - for (_ <- 0 until n) sb += '='; - for (_ <- n until 40) sb += ' '; - sb += ']'; - - val f = job.format; - if (f != "") { sb += ' '; sb ++= f; } - sb ++= (100*cur/max).formatted(" %3d%%"); - - val eta = job.eta; - if (eta >= 0) { - sb += ' '; sb += '('; - sb ++= formatTime(ceil(eta).toInt); - sb += ')'; - } - } - - /* Done. */ - set(sb.result); - - case Done => - val t = formatTime(ceil(job.taken).toInt); - set(s"${job.what} done ($t)"); commit(); - - case Cancelled => - set(s"${job.what} CANCELLED"); commit(); +class SimpleModel(val what: String, val max: Long) extends Model; - case Failed(msg) => - set(s"${job.what} FAILED: $msg"); commit(); +private val UDATA = Seq("kB", "MB", "GB", "TB", "PB", "EB"); - case _ => ok; - } +trait DataModel extends Model { + override def fmt1(n: Long): String = { + val (x, u) = ((n.toDouble, "B ") /: UDATA) { (xu, n) => (xu, n) match { + case ((x, u), name) if x >= 1024.0 => (x/1024.0, name) + case (xu, _) => xu + } } + f"$x%6.1f$u%s" } } -/*----- Testing cruft -----------------------------------------------------*/ - -trait AsyncJob extends Job { - protected def run(); - private var _cur: Long = 0; override def cur = _cur; +trait BaseReporter { + def done(); + def failed(e: Exception); +} - +trait JobReporter extends BaseReporter { + def step(cur: Long); + def change(model: Model, cur: Long); } +trait OperationReporter extends BaseReporter { + def step(detail: String); +} +def withReporter[T, R <: BaseReporter] + (rep: R, body: R => T): T = { + val ret = try { body(rep) } + catch { case e: Exception => rep.failed(e); throw e; } + rep.done(); + ret +} +trait Eyecandy { + def note(msg: String); + def clear(); + def commit(); + def record(msg: String) { note(msg); commit(); } + def done(); + def cancelled() { failed("cancelled"); } + def failed(msg: String); -import Thread.sleep; + def beginJob(model: Model): JobReporter + // = new JobReporter(model); -class ToyJob(val max: Long) extends Job { - val what = "Dummy job"; - private var _i: Long = 0; def cur = _i; + def beginOperation(what: String): OperationReporter + // = new OperationReporter(what); - def cancel() { ??? } - def run() { - for (i <- 1l until max) { _i = i; publish(Update(i)); sleep(100); } - publish(Done); - } -} + def job[T](model: Model)(body: JobReporter => T): T = + withReporter(beginJob(model), body); -def testjob(n: Long) { - val j = new ToyJob(n); - TerminalEyecandy.begin(j); - j.run(); + def operation[T](what: String)(body: OperationReporter => T): T = + withReporter(beginOperation(what), body); } /*----- That's all, folks -------------------------------------------------*/ diff --git a/sys.scala b/sys.scala index cf0a72d..51ac170 100644 --- a/sys.scala +++ b/sys.scala @@ -175,23 +175,23 @@ object Errno extends Enumeration { private var wrong = -256; // next synthetic code private def nextwrong: Int = { val w = wrong; wrong -= 1; w } - class Type private[Errno](tag: String, val code: Int, id: Int) - extends Val(id, tag) { + class Val private[Errno](tag: String, val code: Int, id: Int) + extends super.Val(id, tag) { /* Our augmented error type. */ def message: String = strerror(code).toJString; } private class UnknownError(code: Int) - extends Type("", code, code); + extends Val("", code, code); - private def err(tag: String, code: Int): Type = { + private def err(tag: String, code: Int): Val = { /* Construct an error symbol given its tag string and a code number. */ - if (code < 0) new Type(tag, code, code) - else if (seen contains code) new Type(tag, code, nextwrong) - else { seen += code; new Type(tag, code, code) } + if (code < 0) new Val(tag, code, code) + else if (seen contains code) new Val(tag, code, nextwrong) + else { seen += code; new Val(tag, code, code) } } - private def err(tag: String): Type = + private def err(tag: String): Val = err(tag, tagmap.getOrElse(tag, nextwrong)); def byid(id: Int): Value = { @@ -377,7 +377,7 @@ object Errno extends Enumeration { val EHWPOISON = err("EHWPOISON"); /***end***/ } -import Errno.{Type => Errno, EEXIST, EISDIR, ENOENT, ENOTDIR}; +import Errno.{Val => Errno, EEXIST, EISDIR, ENOENT, ENOTDIR}; object SystemError { /* Pattern matching for `SystemError', below. */ @@ -596,7 +596,7 @@ object FileImplicits { /* Constructing names of files in a directory. Honestly, I'm surprised * there isn't a method for this already. */ - def +(sub: String): File = new File(file, sub); + def /(sub: String): File = new File(file, sub); /* Simple file operations. */ def unlink_!() { unlink(file.getPath); } @@ -739,7 +739,7 @@ def freshFile(d: File): File = { /* Make the filename, and try to create the file. If we succeed, we * win. */ - val f = new File(d, b.result); b.clear(); + val f = d/b.result; b.clear(); try { f.mkfile_!(); exit(f); } catch { case SystemError(EEXIST, _) => ok; } } @@ -798,23 +798,81 @@ def runCommand(cmd: String*): (String, String) = { } } +/*----- Interrupt triggers ------------------------------------------------*/ + +private val triggerLock = new Object; +private final val maxTriggers = 2; +private var nTriggers = 0; +private var triggers: List[Wrapper] = Nil; + +@native protected def makeTrigger(): Wrapper; +@native protected def destroyTrigger(trig: Wrapper); +@native protected def resetTrigger(trig: Wrapper); +@native protected def trigger(trig: Wrapper); + +private def getTrigger(): Wrapper = { + triggerLock synchronized { + if (nTriggers == 0) + makeTrigger() + else { + val trig = triggers.head; + triggers = triggers.tail; + nTriggers -= 1; + trig + } + } +} + +private def putTrigger(trig: Wrapper) { + resetTrigger(trig); + triggerLock synchronized { + if (nTriggers >= maxTriggers) + destroyTrigger(trig); + else { + triggers ::= trig; + nTriggers += 1; + } + } +} + +private def withTrigger[T](body: Wrapper => T): T = { + val trig = getTrigger(); + try { body(trig) } + finally { putTrigger(trig); } +} + +def interruptWithTrigger[T](body: Wrapper => T): T = { + /* interruptWithTrigger { TRIG => BODY } + * + * Execute BODY and return its result. If the thread receives an + * interrupt, the trigger TRIG will be pulled. See `interruptably' for the + * full semantics. + */ + + withTrigger { trig => + interruptably { body(trig) } onInterrupt { trigger(trig); } + }; +} + /*----- Connecting to a server --------------------------------------------*/ /* Primitive operations. */ final val CF_CLOSERD = 1; final val CF_CLOSEWR = 2; final val CF_CLOSEMASK = CF_CLOSERD | CF_CLOSEWR; -@native protected def connect(path: CString): Wrapper; +@native protected def connect(path: CString, trig: Wrapper): Wrapper; @native protected def send(conn: Wrapper, buf: CString, - start: Int, len: Int); + start: Int, len: Int, trig: Wrapper); @native protected def recv(conn: Wrapper, buf: CString, - start: Int, len: Int): Int; + start: Int, len: Int, trig: Wrapper): Int; @native def closeconn(conn: Wrapper, how: Int); class Connection(path: String) extends Closeable { /* The underlying primitive connection. */ - private[this] val conn = connect(path.toCString); + private[this] val conn = interruptWithTrigger { trig => + connect(path.toCString, trig); + }; /* Alternative constructors. */ def this(file: File) { this(file.getPath); } @@ -834,7 +892,7 @@ class Connection(path: String) extends Closeable { override def read(buf: Array[Byte]): Int = read(buf, 0, buf.length); override def read(buf: Array[Byte], start: Int, len: Int) = - recv(conn, buf, start, len); + interruptWithTrigger { trig => recv(conn, buf, start, len, trig); }; override def close() { closeconn(conn, CF_CLOSERD); } } lazy val input = new Input; @@ -845,7 +903,7 @@ class Connection(path: String) extends Closeable { override def write(b: Int) { write(Array[Byte](b.toByte), 0, 1); } override def write(buf: Array[Byte]) { write(buf, 0, buf.length); } override def write(buf: Array[Byte], start: Int, len: Int) - { send(conn, buf, start, len); } + { interruptWithTrigger { trig => send(conn, buf, start, len, trig); } } override def close() { closeconn(conn, CF_CLOSEWR); } } lazy val output = new Output; diff --git a/terminal.scala b/terminal.scala new file mode 100644 index 0000000..9722e2b --- /dev/null +++ b/terminal.scala @@ -0,0 +1,177 @@ +/* -*-scala-*- + * + * Terminal-based progress eyecandy + * + * (c) 2018 Straylight/Edgeware + */ + +/*----- Licensing notice --------------------------------------------------* + * + * This file is part of the Trivial IP Encryption (TrIPE) Android app. + * + * TrIPE is free software: you can redistribute it and/or modify it under + * the terms of the GNU General Public License as published by the Free + * Software Foundation; either version 3 of the License, or (at your + * option) any later version. + * + * TrIPE is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * for more details. + * + * You should have received a copy of the GNU General Public License + * along with TrIPE. If not, see . + */ + +package uk.org.distorted.tripe; package progress; + +/*----- Imports -----------------------------------------------------------*/ + +import java.io.FileDescriptor; +import java.lang.Math.ceil; +import java.lang.System.{currentTimeMillis, out => stdout}; + +import sys.isatty; + +/*----- Main code ---------------------------------------------------------*/ + +object TerminalEyecandy extends Eyecandy { + private var last = ""; + var eyecandyp = isatty(FileDescriptor.out); + + /* Assume that characters take up one cell each. This is going to fail + * badly for combining characters, zero-width characters, wide Asian + * characters, and lots of other Unicode characters. The problem is that + * Java doesn't have any way to query the display width of a character, + * and, honestly, I don't care enough to do the (substantial) work required + * to do this properly. + */ + + def note(line: String) { + if (eyecandyp) { + + /* If the old line is longer than the new one, then we must overprint + * the end part. + */ + if (line.length < last.length) { + val n = last.length - line.length; + for (_ <- 0 until n) stdout.write('\b'); + for (_ <- 0 until n) stdout.write(' '); + } + + /* Figure out the length of the common prefix between what we had + * before and what we have now. + */ + val m = (0 until (last.length min line.length)) prefixLength + { i => last(i) == line(i) }; + + /* Delete the tail from the old line and print the new version. */ + for (_ <- m until last.length) stdout.write('\b'); + stdout.print(line.substring(m)); + stdout.flush(); + } + + /* Update the state. */ + last = line; + } + + def clear() { note(""); } + + def commit() { + if (last != "") { + if (eyecandyp) stdout.write('\n'); + else stdout.println(last); + last = ""; + } + } + + def done() { clear(); } + def failed(msg: String) { record(s"FAILED! $msg"); } + + def beginJob(model: Model): progress.JobReporter = + new JobReporter(model); + + def beginOperation(what: String): progress.OperationReporter = + new OperationReporter(what); + + private[this] class JobReporter(private[this] var model: Model) + extends progress.JobReporter { + private final val width = 40; + private final val spinner = """/-\|"""; + private final val mingap = 100; + private[this] var step: Int = 0; + private[this] var sweep: Int = 0; + private[this] val t0 = currentTimeMillis; + private[this] var last: Long = -1; + + def change(model: Model, cur: Long) + { last = -1; this.model = model; step(cur); } + + def step(cur: Long) { + val now = currentTimeMillis; + if (last >= 0 && now - last < mingap) return; + last = now; + + val max = model.max; + val sb = new StringBuilder; + sb ++= model.what; sb += ' '; + + /* Step the spinner. */ + sb += spinner(step); sb += ' '; + step += 1; if (step >= spinner.length) step = 0; + + /* Progress bar. */ + sb += '['; + if (max <= 0) { + val l = sweep; val r = width - 1 - sweep; + val (lo, hi, x, y) = if (l < r) (l, r, '>', '<') + else (r, l, '<', '>'); + for (_ <- 0 until lo) sb += ' '; + sb += x; + for (_ <- lo + 1 until hi) sb += ' '; + sb += y; + for (_ <- hi + 1 until width) sb += ' '; + sweep += 1; if (sweep >= width) sweep = 0; + } else { + val n = (width*cur/max).toInt; + for (_ <- 0 until n) sb += '='; + for (_ <- n until width) sb += ' '; + } + sb += ']'; + + /* Quantitative progress. */ + val f = model.format(cur); if (f != "") { sb += ' '; sb ++= f; } + if (max > 0) sb ++= (100*cur/max).formatted(" %3d%%"); + + /* Estimated time to completion. */ + val eta = model.eta(cur); + if (eta >= 0) { + sb += ' '; sb += '('; + sb ++= formatTime(ceil(eta/1000.0).toInt); + sb += ')'; + } + + /* Done. */ + note(sb.result); + } + + def done() { + val t = formatTime(ceil((currentTimeMillis - t0)/1000.0).toInt); + record(s"${model.what} done ($t)"); + } + + def failed(e: Exception) + { record(s"${model.what} FAILED: ${e.getMessage}"); } + + step(0); + } + + class OperationReporter(what: String) extends progress.OperationReporter { + def step(detail: String) { note(s"$what: $detail"); } + def done() { record(s"$what: ok"); } + def failed(e: Exception) { record(s"$what: ${e.getMessage}"); } + step("..."); + } +} + +/*----- That's all, folks -------------------------------------------------*/ diff --git a/util.scala b/util.scala index 75b3677..ea776e7 100644 --- a/util.scala +++ b/util.scala @@ -31,9 +31,12 @@ import scala.concurrent.duration.{Deadline, Duration}; import scala.util.control.{Breaks, ControlThrowable}; import java.io.{BufferedReader, Closeable, File, InputStream, Reader}; -import java.net.{URL, URLConnection}; +import java.net.{HttpURLConnection, URL, URLConnection}; import java.nio.{ByteBuffer, CharBuffer}; +import java.nio.channels.{SelectionKey, Selector}; +import java.nio.channels.spi.{AbstractSelector, AbstractSelectableChannel}; import java.nio.charset.Charset; +import java.util.{Set => JSet}; import java.util.concurrent.locks.{Lock, ReentrantLock}; /*----- Miscellaneous useful things ---------------------------------------*/ @@ -194,6 +197,61 @@ def loopUnit(body: (=> Nothing) => Unit): Unit = { val BREAKS = new Breaks; import BREAKS.{breakable, break}; +/*----- Interruptably doing things ----------------------------------------*/ + +private class InterruptCatcher[T](body: => T, onWakeup: => Unit) + extends AbstractSelector(null) { + /* Hook onto the VM's thread interruption machinery. + * + * The `run' method is the only really interesting one. It will run the + * BODY, returning its result; if the thread is interrupted during this + * time, ONWAKEUP is invoked for effect. The expectation is that ONWAKEUP + * will somehow cause BODY to stop early. + * + * Credit for this hack goes to Nicholas Wilson: see + * . + */ + + private def nope: Nothing = + { throw new UnsupportedOperationException("can't do that"); } + protected def implCloseSelector() { } + protected def register(chan: AbstractSelectableChannel, + ops: Int, att: Any): SelectionKey = nope; + def keys(): JSet[SelectionKey] = nope; + def selectedKeys(): JSet[SelectionKey] = nope; + def select(): Int = nope; + def select(millis: Long): Int = nope; + def selectNow(): Int = nope; + + def run(): T = try { + begin(); + val ret = body; + if (Thread.interrupted()) throw new InterruptedException; + ret + } finally { + end(); + } + def wakeup(): Selector = { onWakeup; this } +} + +class PendingInterruptable[T] private[tripe](body: => T) { + /* This class exists to provide the `onInterrupt THUNK' syntax. */ + + def onInterrupt(thunk: => Unit): T = + new InterruptCatcher(body, thunk).run(); +} +def interruptably[T](body: => T) = { + /* interruptably { BODY } onInterrupt { THUNK } + * + * Execute BODY and return its result. If the thread receives an + * interrupt -- or is already in an interrupted state -- execute THUNK for + * effect; it is expected to cause BODY to return expeditiously, and when + * the BODY completes, an `InterruptedException' is thrown. + */ + + new PendingInterruptable(body); +} + /*----- A gadget for fetching URLs ----------------------------------------*/ class URLFetchException(msg: String) extends Exception(msg); @@ -208,47 +266,74 @@ def fetchURL(url: URL, cb: URLFetchCallbacks) { /* Fetch the URL, feeding the data through the callbacks CB. */ withCleaner { clean => - var win: Boolean = false; - clean { cb.done(win); } + var win: Boolean = false; clean { cb.done(win); } - /* Set up the connection, and run a preflight check. */ + /* Set up the connection. This isn't going to block, I think, and we + * need to use it in the interrupt handler. + */ val c = url.openConnection(); - cb.preflight(c); - - /* Start fetching data. */ - val in = c.getInputStream; clean { in.close(); } - val explen = c.getContentLength; - /* Read a buffer at a time, and give it to the callback. Maintain a - * running total. + /* Java's default URL handlers don't respond to interrupts, so we have to + * take over this duty. */ - var len: Long = 0; - blockUnit { exit => - for ((buf, n) <- blocks(in)) { - cb.write(buf, n, len); - len += n; - if (explen != -1 && len > explen) exit; + interruptably { + /* Run the caller's preflight check. This must be done here, since it + * might well block while it discovers things like the content length. + */ + cb.preflight(c); + + /* Start fetching data. */ + val in = c.getInputStream; clean { in.close(); } + val explen = c.getContentLength; + + /* Read a buffer at a time, and give it to the callback. Maintain a + * running total. + */ + var len: Long = 0; + blockUnit { exit => + for ((buf, n) <- blocks(in)) { + cb.write(buf, n, len); + len += n; + if (explen != -1 && len > explen) exit; + } } - } - /* I can't find it documented anywhere that the existing machinery - * checks the received stream against the advertised content length. - * It doesn't hurt to check again, anyway. - */ - if (explen != -1 && explen != len) { - throw new URLFetchException( - s"received $len /= $explen bytes from `$url'"); - } + /* I can't find it documented anywhere that the existing machinery + * checks the received stream against the advertised content length. + * It doesn't hurt to check again, anyway. + */ + if (explen != -1 && explen != len) { + throw new URLFetchException( + s"received $len /= $explen bytes from `$url'"); + } - /* Glorious success is ours. */ - win = true; + /* Glorious success is ours. */ + win = true; + } onInterrupt { + /* Oh. How do we do this? */ + + c match { + case c: HttpURLConnection => + /* It's an HTTP connection (what happened to the case here?). + * HTTPS connections match too because they're a subclass. Getting + * the input stream will block, but there's an easier way. + */ + c.disconnect(); + + case _ => + /* It's something else. Let's hope that getting the input stream + * doesn't block. + */ + c.getInputStream.close(); + } + } } } /*----- Threading things --------------------------------------------------*/ -def thread[T](name: String, run: Boolean = true, daemon: Boolean = true) - (f: => T): Thread = { +def thread(name: String, run: Boolean = true, daemon: Boolean = true) + (f: => Unit): Thread = { /* Make a thread with a given name, and maybe start running it. */ val t = new Thread(new Runnable { def run() { f; } }, name); @@ -257,6 +342,28 @@ def thread[T](name: String, run: Boolean = true, daemon: Boolean = true) t } +class ValueThread[T](name: String, group: ThreadGroup = null, + stacksz: Long = 0)(body: => T) + extends Thread(group, null, name, stacksz) { + private[this] var exc: Throwable = _; + private[this] var ret: T = _; + + override def run() { + try { ret = body; } + catch { case e: Throwable => exc = e; } + } + def get: T = + if (isAlive) throw new IllegalArgumentException("still running"); + else if (exc != null) throw exc; + else ret; +} +def valueThread[T](name: String, run: Boolean = true) + (body: => T): ValueThread[T] = { + val t = new ValueThread(name)(body); + if (run) t.start(); + t +} + /*----- Quoting and parsing tokens ----------------------------------------*/ def quoteTokens(v: Seq[String]): String = { @@ -416,7 +523,7 @@ def blocks(in: InputStream, blksz: Int): } } def blocks(in: InputStream): - BufferedIterator[(Array[Byte], Int)] = blocks(in, 4096); + BufferedIterator[(Array[Byte], Int)] = blocks(in, 65536); def blocks(in: BufferedReader, blksz: Int): BufferedIterator[(Array[Char], Int)] = @@ -430,7 +537,7 @@ def blocks(in: BufferedReader, blksz: Int): } } def blocks(in: BufferedReader): - BufferedIterator[(Array[Char], Int)] = blocks(in, 4096); + BufferedIterator[(Array[Char], Int)] = blocks(in, 65536); def blocks(r: Reader, blksz: Int): BufferedIterator[(Array[Char], Int)] = blocks(bufferedReader(r), blksz); def blocks(r: Reader): BufferedIterator[(Array[Char], Int)] = @@ -454,6 +561,15 @@ def oxford(conj: String, things: Seq[String]): String = things match { sb.result } +def formatTime(t: Int): String = + if (t < -1) "???" + else { + val (s, t1) = (t%60, t/60); + val (m, h) = (t1%60, t1/60); + if (h > 0) f"$h%d:$m%02d:$s%02d" + else f"$m%02d:$s%02d" + } + /*----- That's all, folks -------------------------------------------------*/ } -- 2.11.0