More work in progress.
authorMark Wooding <mdw@distorted.org.uk>
Wed, 13 Jun 2018 00:55:56 +0000 (01:55 +0100)
committerMark Wooding <mdw@distorted.org.uk>
Wed, 13 Jun 2018 00:55:56 +0000 (01:55 +0100)
Makefile
jni.c
keys.scala
progress.scala
sys.scala
terminal.scala [new file with mode: 0644]
util.scala

index e00445f..c4f471a 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -198,7 +198,9 @@ CLASSES                     += util
 CLASSES                        += sys:util
 CLASSES                        += admin:sys,util
 CLASSES                        += tar:util
-CLASSES                        += keys:tar,sys,util
+CLASSES                        += progress:sys,util
+CLASSES                        += keys:progress,tar,sys,util
+CLASSES                        += terminal:progress,sys,util
 CLASSES                        += main:sys
 
 ## Machinery for parsing the `CLASSES' list.
@@ -249,7 +251,7 @@ all:: $(TARGETS)
 
 clean::; rm -f $(CLEANFILES) $(TARGETS)
 
-repl: $(TARGETS)
+repl: all
        $(SCALA_REPL) -cp $(CLASSDIR) -Djava.library.path=$(OUTDIR)
 .PHONY: repl
 
diff --git a/jni.c b/jni.c
index f65e464..9aa79f8 100644 (file)
--- a/jni.c
+++ b/jni.c
@@ -39,6 +39,7 @@
 #include <jni.h>
 
 #include <sys/types.h>
+#include <sys/select.h>
 #include <sys/socket.h>
 #include <sys/stat.h>
 #include <sys/sysmacros.h>
@@ -192,6 +193,31 @@ static void except_syserror(JNIEnv *jni, const char *clsname,
   va_end(ap);
 }
 
+static int set_nonblocking(JNIEnv *jni, int fd, int nb)
+{
+  int f0 = fcntl(fd, F_GETFL), f1;
+  if (f0 < 0) goto err;
+  if (nb) f1 = f0 | O_NONBLOCK;
+  else f1 = f0&~O_NONBLOCK;
+  if (fcntl(fd, F_SETFL, f1)) goto err;
+  return (f0 & O_NONBLOCK);
+err:
+    except_syserror(jni, SYSERR, errno,
+                   "failed to set descriptor nonblocking");
+    return (-1);
+}
+
+static int set_closeonexec(JNIEnv *jni, int fd)
+{
+  int f = fcntl(fd, F_GETFD);
+  if (f < 0 || fcntl(fd, F_SETFD, f | FD_CLOEXEC)) {
+    except_syserror(jni, SYSERR, errno,
+                   "failed to set descriptor close-on-exec");
+    return (-1);
+  }
+  return (0);
+}
+
 /*----- Wrapping native types ---------------------------------------------*/
 
 /* There's no way defined in the JNI to stash a C pointer in a Java object.
@@ -808,7 +834,7 @@ JNIEXPORT jboolean JNIFUNC(isatty)(JNIEnv *jni, jobject cls, jobject jfd)
 
 /* Java has these already, as methods on `java.io.File' objects.  Alas, these
  * methods are useless at reporting errors: they tend to return a `boolean'
- * success/ fail indicator, and throw away any more detailed information.
+ * success/fail indicator, and throw away any more detailed information.
  * There's better functionality in `java.nio.file.Files', but that only turns
  * up in Android API 26 (in 7.0 Nougat).  There's `android.system.Os', which
  * has a bunch of POSIX-shaped functions -- but they're only in Android API
@@ -1102,6 +1128,96 @@ end:
   put_cstring(jni, path, pathstr);
 }
 
+/*----- Triggers ----------------------------------------------------------*/
+
+/* A trigger is a gadget for waking up a thread which is blocking on I/O,
+ * and it's used to implement interruptability.
+ *
+ * Really, a trigger is a pipe.  A `blocking' I/O operation secretly uses
+ * select(2) to block on the descriptor of interest /and/ the read side of
+ * the trigger pipe.  To wake up a thread that's blocked, we just write a
+ * byte (nobody cares /which/ byte) to the write end.
+ */
+
+struct trigger {
+  struct native_base _base;
+  int rfd, wfd;
+};
+static const struct native_type trigger_type =
+       { "trigger", sizeof(struct trigger), 0x65ffd8b4 };
+
+JNIEXPORT wrapper JNICALL JNIFUNC(makeTrigger)(JNIEnv *jni, jobject cls)
+{
+  struct trigger trig;
+  int fd[2];
+  int i;
+  wrapper ret = 0;
+
+  fd[0] = fd[1] = -1;
+  if (pipe(fd)) {
+    except_syserror(jni, SYSERR, errno, "failed to create pipe");
+    goto end;
+  }
+  for (i = 0; i < 2; i++) {
+    if (set_nonblocking(jni, fd[i], 1) < 0 || set_closeonexec(jni, fd[i]))
+      goto end;
+  }
+
+  INIT_NATIVE(trigger, &trig);
+  trig.rfd = fd[0]; fd[0] = -1;
+  trig.wfd = fd[1]; fd[1] = -1;
+  ret = wrap(jni, &trigger_type, &trig);
+
+end:
+  for (i = 0; i < 2; i++)
+    if (fd[i] != -1) close(fd[i]);
+  return (ret);
+}
+
+JNIEXPORT void JNICALL JNIFUNC(destroyTrigger)(JNIEnv *jni, jobject cls,
+                                              wrapper wtrig)
+{
+  struct trigger trig;
+
+  if (unwrap(jni, &trig, &trigger_type, wtrig)) return;
+  if (trig.rfd != -1) { close(trig.rfd); trig.rfd = -1; }
+  if (trig.wfd != -1) { close(trig.wfd); trig.wfd = -1; }
+  update_wrapper(jni, &trigger_type, wtrig, &trig);
+}
+
+JNIEXPORT void JNICALL JNIFUNC(resetTrigger)(JNIEnv *jni, jobject cls,
+                                            wrapper wtrig)
+{
+  struct trigger trig;
+  char buf[64];
+  ssize_t n;
+
+  if (unwrap(jni, &trig, &trigger_type, wtrig)) return;
+  for (;;) {
+    n = read(trig.rfd, buf, sizeof(buf));
+    if (n > 0) continue;
+    assert(n < 0);
+    if (errno == EAGAIN || errno == EWOULDBLOCK) break;
+    else {
+      except_syserror(jni, SYSERR, errno, "failed to reset trigger");
+      break;
+    }
+  }
+}
+
+JNIEXPORT void JNICALL JNIFUNC(trigger)(JNIEnv *jni, jobject cls,
+                                       wrapper wtrig)
+{
+  struct trigger trig;
+  ssize_t n;
+  char c = 0;
+
+  if (unwrap(jni, &trig, &trigger_type, wtrig)) return;
+  n = write(trig.wfd, &c, 1);
+  if (n < 0 && errno != EAGAIN && errno != EWOULDBLOCK)
+    except_syserror(jni, SYSERR, errno, "failed to pull trigger");
+}
+
 /*----- A server connection, using a Unix-domain socket -------------------*/
 
 struct conn {
@@ -1116,14 +1232,21 @@ static const struct native_type conn_type =
        { "conn", sizeof(struct conn), 0xed030167 };
 
 JNIEXPORT wrapper JNICALL JNIFUNC(connect)(JNIEnv *jni, jobject cls,
-                                          jobject path)
+                                          jobject path, wrapper wtrig)
 {
   struct conn conn;
+  struct trigger trig;
   struct sockaddr_un sun;
+  int rc, maxfd;
+  fd_set rfds, wfds;
   const char *pathstr = 0;
-  jobject ret = 0;
+  int err;
+  socklen_t sz;
+  wrapper ret = 0;
+  int nb;
   int fd = -1;
 
+  if (unwrap(jni, &trig, &trigger_type, wtrig)) goto end;
   pathstr = get_cstring(jni, path); if (!pathstr) goto end;
   if (strlen(pathstr) >= sizeof(sun.sun_path)) {
     except(jni, ARGERR,
@@ -1132,12 +1255,33 @@ JNIEXPORT wrapper JNICALL JNIFUNC(connect)(JNIEnv *jni, jobject cls,
   }
 
   INIT_NATIVE(conn, &conn);
-  fd = socket(SOCK_STREAM, PF_UNIX, 0); if (fd < 0) goto err;
+  fd = socket(PF_UNIX, SOCK_STREAM, 0); if (fd < 0) goto err;
+  nb = set_nonblocking(jni, fd, 1); if (nb < 0) goto end;
 
   sun.sun_family = AF_UNIX;
   strcpy(sun.sun_path, (char *)pathstr);
-  if (connect(fd, (struct sockaddr *)&sun, sizeof(sun))) goto err;
+  if (!connect(fd, (struct sockaddr *)&sun, sizeof(sun))) goto connected;
+  else if (errno != EINPROGRESS) goto err;
+
+  maxfd = trig.rfd;
+  if (maxfd < fd) maxfd = fd;
+  for (;;) {
+    FD_ZERO(&rfds); FD_SET(trig.rfd, &rfds);
+    FD_ZERO(&wfds); FD_SET(fd, &wfds);
+    rc = select(maxfd + 1, &rfds, &wfds, 0, 0); if (rc < 0) goto err;
+    if (FD_ISSET(trig.rfd, &rfds)) goto end;
+    if (FD_ISSET(fd, &wfds)) {
+      sz = sizeof(sun);
+      if (!getpeername(fd, (struct sockaddr *)&sun, &sz)) goto connected;
+      else if (errno != ENOTCONN) goto err;
+      sz = sizeof(err);
+      if (!getsockopt(fd, SOL_SOCKET, SO_ERROR, &err, &sz)) errno = err;
+      goto err;
+    }
+  }
 
+connected:
+  if (set_nonblocking(jni, fd, nb) < 0) goto end;
   conn.fd = fd; fd = -1;
   conn.f = 0;
   ret = wrap(jni, &conn_type, &conn);
@@ -1147,7 +1291,7 @@ err:
   except_syserror(jni, SYSERR, errno,
                  "failed to connect to Unix-domain socket `%s'", pathstr);
 end:
-  if (fd == -1) close(fd);
+  if (fd != -1) close(fd);
   put_cstring(jni, path, pathstr);
   return (ret);
 }
@@ -1181,28 +1325,40 @@ static int check_buffer_bounds(JNIEnv *jni, const char *what,
 
 JNIEXPORT void JNICALL JNIFUNC(send)(JNIEnv *jni, jobject cls,
                                     wrapper wconn, jbyteArray buf,
-                                    jint start, jint len)
+                                    jint start, jint len,
+                                    wrapper wtrig)
 {
   struct conn conn;
+  struct trigger trig;
+  int rc, maxfd;
   ssize_t n;
+  fd_set rfds, wfds;
   jbyte *p = 0;
 
   if (unwrap(jni, &conn, &conn_type, wconn)) goto end;
+  if (unwrap(jni, &trig, &trigger_type, wtrig)) goto end;
   if (check_buffer_bounds(jni, "send", buf, start, len)) goto end;
 
   p = (*jni)->GetByteArrayElements(jni, buf, 0);
   if (!p) goto end;
 
+  maxfd = trig.rfd;
+  if (maxfd < conn.fd) maxfd = conn.fd;
   while (len) {
-    n = send(conn.fd, p + start, len, 0);
-    if (n < 0) {
-      except_syserror(jni, SYSERR,
-                     errno, "failed to send on connection");
-      goto end;
+    FD_ZERO(&rfds); FD_SET(trig.rfd, &rfds);
+    FD_ZERO(&wfds); FD_SET(conn.fd, &wfds);
+    rc = select(maxfd + 1, &rfds, &wfds, 0, 0); if (rc < 0) goto err;
+    if (FD_ISSET(trig.rfd, &rfds)) break;
+    if (FD_ISSET(conn.fd, &wfds)) {
+      n = send(conn.fd, p + start, len, 0);
+      if (n >= 0) { start += n; len -= n; }
+      else if (errno != EAGAIN && errno != EWOULDBLOCK) goto err;
     }
-    start += n; len -= n;
   }
+  goto end;
 
+err:
+  except_syserror(jni, SYSERR, errno, "failed to send on connection");
 end:
   if (p) (*jni)->ReleaseByteArrayElements(jni, buf, p, JNI_ABORT);
   return;
@@ -1210,26 +1366,42 @@ end:
 
 JNIEXPORT jint JNICALL JNIFUNC(recv)(JNIEnv *jni, jobject cls,
                                     wrapper wconn, jbyteArray buf,
-                                    jint start, jint len)
+                                    jint start, jint len,
+                                    wrapper wtrig)
 {
   struct conn conn;
+  struct trigger trig;
+  int maxfd;
+  fd_set rfds;
   jbyte *p = 0;
   jint rc = -1;
 
   if (unwrap(jni, &conn, &conn_type, wconn)) goto end;
+  if (unwrap(jni, &trig, &trigger_type, wtrig)) goto end;
   if (check_buffer_bounds(jni, "send", buf, start, len)) goto end;
 
   p = (*jni)->GetByteArrayElements(jni, buf, 0);
   if (!p) goto end;
 
-  rc = recv(conn.fd, p + start, len, 0);
-  if (rc < 0) {
-    except_syserror(jni, SYSERR,
-                   errno, "failed to read from connection");
-    goto end;
+  maxfd = trig.rfd;
+  if (maxfd < conn.fd) maxfd = conn.fd;
+  for (;;) {
+    FD_ZERO(&rfds); FD_SET(trig.rfd, &rfds); FD_SET(conn.fd, &rfds);
+    rc = select(maxfd + 1, &rfds, 0, 0, 0); if (rc < 0) goto err;
+    if (FD_ISSET(trig.rfd, &rfds)) {
+      break;
+    }
+    if (FD_ISSET(conn.fd, &rfds)) {
+      rc = recv(conn.fd, p + start, len, 0);
+      if (rc >= 0) break;
+      else if (errno != EAGAIN && errno != EWOULDBLOCK) goto err;
+    }
   }
   if (!rc) rc = -1;
+  goto end;
 
+err:
+  except_syserror(jni, SYSERR, errno, "failed to read from connection");
 end:
   if (p) (*jni)->ReleaseByteArrayElements(jni, buf, p, 0);
   return (rc);
index b49e334..cec56a9 100644 (file)
@@ -37,15 +37,17 @@ import sys.{SystemError, hashsz, runCommand};
 import sys.Errno.EEXIST;
 import sys.FileImplicits._;
 
+import progress.{Eyecandy, SimpleModel, DataModel};
+
 /*----- Useful regular expressions ----------------------------------------*/
 
-private val RX_COMMENT = """(?x) ^ \s* (?: \# .* )? $""".r;
-private val RX_KEYVAL = """(?x) ^ \s*
+private final val RX_COMMENT = """(?x) ^ \s* (?: \# .* )? $""".r;
+private final val RX_KEYVAL = """(?x) ^ \s*
       ([-\w]+)
       (?:\s+(?!=)|\s*=\s*)
       (|\S|\S.*\S)
       \s* $""".r;
-private val RX_DOLLARSUBST = """(?x) \$ \{ ([-\w]+) \}""".r;
+private final val RX_DOLLARSUBST = """(?x) \$ \{ ([-\w]+) \}""".r;
 
 /*----- Things that go wrong ----------------------------------------------*/
 
@@ -106,26 +108,70 @@ private val DEFAULTS: Seq[(String, Config => String)] =
       "sig-fresh" -> { _ => "always" },
       "fingerprint-hash" -> { _("hash") });
 
-/*----- Managing a key repository -----------------------------------------*/
+private def readConfig(file: File): Config = {
 
-def downloadToFile(file: File, url: URL, maxlen: Long = Long.MaxValue) {
-  fetchURL(url, new URLFetchCallbacks {
-    val out = file.openForOutput();
-    private def toobig() {
-      throw new KeyConfigException(s"remote file `$url' is " +
-                                  "suspiciously large");
-    }
-    var totlen: Long = 0;
-    override def preflight(conn: URLConnection) {
-      totlen = conn.getContentLength;
-      if (totlen > maxlen) toobig();
+  /* Build the new configuration in a temporary place. */
+  var m = HashMap[String, String]();
+
+  /* Read the config file into our map. */
+  file.withReader { in =>
+    var lno = 1;
+    for (line <- lines(in)) {
+      line match {
+       case RX_COMMENT() => ok;
+       case RX_KEYVAL(key, value) => m += key -> value;
+       case _ =>
+         throw new ConfigSyntaxError(file.getPath, lno,
+                                     "failed to parse line");
+      }
+      lno += 1;
     }
-    override def done(win: Boolean) { out.close(); }
-    def write(buf: Array[Byte], n: Int, len: Long) {
-      if (len + n > maxlen) toobig();
-      out.write(buf, 0, n);
+  }
+
+  /* Fill in defaults where things have been missed out. */
+  for ((key, dflt) <- DEFAULTS) {
+    if (!(m contains key)) {
+      try { m += key -> dflt(m); }
+      catch {
+       case e: DefaultFailed =>
+         throw new ConfigDefaultFailed(file.getPath, key,
+                                       e.key, m(e.key));
+      }
     }
-  });
+  }
+
+  /* And we're done. */
+  m
+}
+
+/*----- Managing a key repository -----------------------------------------*/
+
+def downloadToFile(file: File, url: URL,
+                  maxlen: Long = Long.MaxValue,
+                  ic: Eyecandy) {
+  ic.job(new SimpleModel(s"connecting to `$url'", -1)) { jr =>
+    fetchURL(url, new URLFetchCallbacks {
+      val out = file.openForOutput();
+      private def toobig() {
+       throw new KeyConfigException(
+         s"remote file `$url' is suspiciously large");
+      }
+      var totlen: Long = 0;
+      override def preflight(conn: URLConnection) {
+       totlen = conn.getContentLength;
+       if (totlen > maxlen) toobig();
+       jr.change(new SimpleModel(s"downloading `$url'", totlen)
+                   with DataModel,
+                 0);
+      }
+      override def done(win: Boolean) { out.close(); }
+      def write(buf: Array[Byte], n: Int, len: Long) {
+       if (len + n > maxlen) toobig();
+       out.write(buf, 0, n);
+       jr.step(len + n);
+      }
+    })
+  }
 }
 
 /* Lifecycle notes
@@ -161,34 +207,65 @@ def downloadToFile(file: File, url: URL, maxlen: Long = Long.MaxValue) {
  * (delete old/)
  */
 
+class RepositoryStateException(val state: Repository.State.Value,
+                              msg: String)
+       extends Exception(msg);
+
+class KeyConfigException(msg: String) extends Exception(msg);
+
+private def launderFingerprint(fp: String): String =
+  fp filter { _.isLetterOrDigit };
+
+private def fingerprintsEqual(a: String, b: String) =
+  launderFingerprint(a) == launderFingerprint(b);
+
+private def keyFingerprint(kr: File, tag: String, hash: String): String = {
+  val (out, _) = runCommand("key", "-k", kr.getPath, "fingerprint",
+                           "-a", hash, "-f", "-secret", tag);
+  nextToken(out) match {
+    case Some((fp, _)) => fp
+    case _ =>
+      throw new java.io.IOException("unexpected output from `key fingerprint");
+  }
+}
+
 object Repository {
   object State extends Enumeration {
     val Empty, Pending, Confirmed, Updating, Committing, Live = Value;
   }
 }
 
-class RepositoryStateException(val state: Repository.State.Value,
-                              msg: String)
-       extends Exception(msg);
+def checkConfigSanity(file: File, ic: Eyecandy) {
+  ic.operation("checking new configuration") { _ =>
 
-class KeyConfigException(msg: String) extends Exception(msg);
+    /* Make sure we can read and understand the file. */
+    val conf = readConfig(file);
+
+    /* Make sure there are entries which we can use to update.  This won't
+     * guarantee that we can reliably update, but it will help.
+     */
+    conf("repos-url"); conf("sig-url");
+    conf("fingerprint-hash"); conf("sig-fresh");
+    conf("master-sequence"); conf("hk-master");
+  }
+}
 
 class Repository(val root: File) extends Closeable {
   import Repository.State.{Value => State, _};
 
   /* Important directories and files. */
-  private[this] val livedir = root + "live";
-  private[this] val livereposdir = livedir + "repos";
-  private[this] val newdir = root + "new";
-  private[this] val olddir = root + "old";
-  private[this] val pendingdir = root + "pending";
-  private[this] val tmpdir = root + "tmp";
+  private[this] val livedir = root/"live";
+  private[this] val livereposdir = livedir/"repos";
+  private[this] val newdir = root/"new";
+  private[this] val olddir = root/"old";
+  private[this] val pendingdir = root/"pending";
+  private[this] val tmpdir = root/"tmp";
 
   /* Take out a lock in case of other instances. */
   private[this] val lock = {
     try { root.mkdir_!(); }
     catch { case SystemError(EEXIST, _) => ok; }
-    (root + "lk").lock_!()
+    (root/"lk").lock_!()
   }
   def close() { lock.close(); }
 
@@ -269,9 +346,10 @@ class Repository(val root: File) extends Closeable {
     }
   } }
 
-  def destroy() {
+  def destroy(ic: Eyecandy) {
     /* Clear out the entire repository.  Everything.  It's all gone. */
-    root.foreachFile { f => if (f.getName != "lk") f.rmTree(); }
+    ic.operation("clearing configuration")
+      { _ => root.foreachFile { f => if (f.getName != "lk") f.rmTree(); } }
   }
 
   def clearTmp() {
@@ -293,66 +371,40 @@ class Repository(val root: File) extends Closeable {
        case Empty =>
          throw new RepositoryStateException(Empty, "repository is Empty");
       }
-      val file = dir + "tripe-keys.conf";
-
-      /* Build the new configuration in a temporary place. */
-      var m = HashMap[String, String]();
-
-      /* Read the config file into our map. */
-      file.withReader { in =>
-       var lno = 1;
-       for (line <- lines(in)) {
-         line match {
-           case RX_COMMENT() => ok;
-           case RX_KEYVAL(key, value) => m += key -> value;
-           case _ =>
-             throw new ConfigSyntaxError(file.getPath, lno,
-                                         "failed to parse line");
-         }
-         lno += 1;
-       }
-      }
-
-      /* Fill in defaults where things have been missed out. */
-      for ((key, dflt) <- DEFAULTS) {
-       if (!(m contains key)) {
-         try { m += key -> dflt(m); }
-         catch {
-           case e: DefaultFailed =>
-             throw new ConfigDefaultFailed(file.getPath, key,
-                                           e.key, m(e.key));
-         }
-       }
-      }
 
-      /* All done. */
-      _config = m;
+      /* And then read the configuration. */
+      _config = readConfig(dir/"tripe-keys.conf");
     }
 
     _config
   }
 
-  def fetchConfig(url: URL) {
+  def fetchConfig(url: URL, ic: Eyecandy) {
     /* Fetch an initial configuration file from a given URL. */
 
     checkState(Empty);
     clearTmp();
-    downloadToFile(tmpdir + "tripe-keys.conf", url);
-    tmpdir.rename_!(pendingdir);
+
+    val conffile = tmpdir/"tripe-keys.conf";
+    downloadToFile(conffile, url, 16*1024, ic);
+    checkConfigSanity(conffile, ic);
+    ic.operation("committing configuration")
+      { _ => tmpdir.rename_!(pendingdir); }
     invalidate();                      // should move to `Pending'
   }
 
-  def confirm() {
+  def confirm(ic: Eyecandy) {
     /* The user has approved the master key fingerprint in the `Pending'
      * configuration.  Advance to `Confirmed'.
      */
 
     checkState(Pending);
-    pendingdir.rename_!(livedir);
+    ic.operation("confirming configuration")
+      { _ => pendingdir.rename_!(livedir); }
     invalidate();                      // should move to `Confirmed'
   }
 
-  def update() {
+  def update(ic: Eyecandy) {
     /* Update the repository from the master.
      *
      * Fetch a (possibly new) archive; unpack it; verify the master key
@@ -364,66 +416,98 @@ class Repository(val root: File) extends Closeable {
     clearTmp();
 
     /* First thing is to download the tarball and signature. */
-    val tarfile = tmpdir + "tripe-keys.tar.gz";
-    downloadToFile(tarfile, new URL(conf("repos-url")));
-    val sigfile = tmpdir + "tripe-keys.sig";
+    val tarfile = tmpdir/"tripe-keys.tar.gz";
+    downloadToFile(tarfile, new URL(conf("repos-url")), 256*1024, ic);
+    val sigfile = tmpdir/"tripe-keys.sig";
     val seq = conf("master-sequence");
     downloadToFile(sigfile,
                   new URL(conf("sig-url").replaceAllLiterally("<SEQ>",
-                                                              seq)));
+                                                              seq)),
+                  4*1024, ic);
 
     /* Unpack the tarball.  Carefully. */
-    val unpkdir = tmpdir + "unpk";
-    unpkdir.mkdir_!();
-    withCleaner { clean =>
-      val tar = new TarFile(new GZIPInputStream(tarfile.open()));
-      clean { tar.close(); }
-      for (e <- tar) {
-
-       /* Check the filename to make sure it's not evil. */
-       if (e.name.split('/').exists { _ == ".." })
-         throw new KeyConfigException("invalid path in tarball");
-
-       /* Find out where this file points. */
-       val f = unpkdir + e.name;
-
-       /* Unpack it. */
-       if (e.isdir) {
-         /* A directory.  Create it if it doesn't exist already. */
-
-         try { f.mkdir_!(); }
-         catch { case SystemError(EEXIST, _) => ok; }
-       } else if (e.isreg) {
-         /* A regular file.  Write stuff to it. */
-
-         e.withStream { in =>
-           f.withOutput { out =>
-             for ((b, n) <- blocks(in)) out.write(b, 0, n);
+    val unpkdir = tmpdir/"unpk";
+    ic.operation("unpacking archive") { or =>
+      unpkdir.mkdir_!();
+      withCleaner { clean =>
+       val tar = new TarFile(new GZIPInputStream(tarfile.open()));
+       clean { tar.close(); }
+       for (e <- tar) {
+
+         /* Check the filename to make sure it's not evil. */
+         if (e.name(0) == '/' || e.name.split('/').exists { _ == ".." })
+           throw new KeyConfigException("invalid path in tarball");
+
+         /* Report on progress. */
+         or.step(s"entry `${e.name}'");
+
+         /* Find out where this file points. */
+         val f = unpkdir/e.name;
+
+         /* Unpack it. */
+         if (e.isdir) {
+           /* A directory.  Create it if it doesn't exist already. */
+
+           try { f.mkdir_!(); }
+           catch { case SystemError(EEXIST, _) => ok; }
+         } else if (e.isreg) {
+           /* A regular file.  Write stuff to it. */
+
+           e.withStream { in =>
+             f.withOutput { out =>
+               for ((b, n) <- blocks(in)) out.write(b, 0, n);
+             }
            }
-         }
-       } else {
-         /* Something else.  Be paranoid and reject it. */
+         } else {
+           /* Something else.  Be paranoid and reject it. */
 
-         throw new KeyConfigException("unexpected object type in tarball");
+           throw new KeyConfigException(
+             s"entry `${e.name}' has unexpected object type");
+         }
        }
       }
     }
 
     /* There ought to be a file in here called `repos/master.pub'. */
-    val reposdir = unpkdir + "repos";
+    val reposdir = unpkdir/"repos";
+    val masterfile = reposdir/"master.pub";
+
     if (!reposdir.isdir_!)
       throw new KeyConfigException("missing `repos/' directory");
-    val masterfile = reposdir + "master.pub";
     if (!masterfile.isreg_!)
       throw new KeyConfigException("missing `repos/master.pub' file");
+    val mastertag = s"master-$seq";
 
     /* Fetch the master key's fingerprint. */
-    val (out, _) = runCommand("key", "-k", masterfile.getPath,
-                             "fingerprint",
-                             "-f", "-secret",
-                             "-a", conf("fingerprint-hash"),
-                             s"master-$seq");
-    println(s";; $out");
+    ic.operation("checking master key fingerprint") { _ =>
+      val foundfp = keyFingerprint(masterfile, mastertag,
+                                  conf("fingerprint-hash"));
+      val wantfp = conf("hk-master");
+      if (!fingerprintsEqual(wantfp, foundfp)) {
+       throw new KeyConfigException(
+         s"master key #$seq has wrong fingerprint: " +
+         s"expected $wantfp but found $foundfp");
+      }
+    }
+
+    /* Check the archive signature. */
+    ic.operation("verifying archive signature") { or =>
+      runCommand("catsign", "-k", masterfile.getPath, "verify", "-aqC",
+                "-k", mastertag, "-t", conf("sig-fresh"),
+                sigfile.getPath, tarfile.getPath);
+    }
+
+    /* Confirm that the configuration in the new archive is sane. */
+    checkConfigSanity(unpkdir/"tripe-keys.conf", ic);
+
+    /* Now we just have to juggle the files about. */
+    ic.operation("committing new configuration") { _ =>
+      unpkdir.rename_!(newdir);
+      livedir.rename_!(olddir);
+      newdir.rename_!(livedir);
+    }
+
+    invalidate();                      // should move to `Live'
   }
 }
 
index 9bfbc64..ab40c5d 100644 (file)
@@ -29,226 +29,95 @@ package uk.org.distorted.tripe; package object progress {
 
 import scala.collection.mutable.{Publisher, Subscriber};
 
-import java.lang.Math.ceil;
 import java.lang.System.currentTimeMillis;
 
-/*----- Main code ---------------------------------------------------------*/
+/*----- Progress displays -------------------------------------------------*/
 
-def formatTime(t: Int): String =
-  if (t < -1) "???"
-  else {
-    val (s, t1) = (t%60, t/60);
-    val (m, h) = (t1%60, t1/60);
-    if (h > 0) f"$h%d:$m%02d:$s%02d"
-    else f"$m%02d:$s%02d"
-  }
-
-private val UDATA = Seq("kB", "MB", "GB", "TB", "PB", "EB");
-def formatBytes(n: Long): String = {
-  val (x, u) = ((n.toDouble, "B ") /: UDATA) { (xu, n) => (xu, n) match {
-    case ((x, u), name) if x >= 1024.0 => (x/1024.0, name)
-    case (xu, _) => xu
-  } }
-  f"$x%6.1f$u%s"
-}
-
-trait Eyecandy {
-  def set(line: String);
-  def clear();
-  def commit();
-  def commit(line: String) { commit(); set(line); commit(); }
-  def begin(job: Job);
-}
-
-abstract class Event;                  // other subclasses can be added!
-abstract class Progress extends Event { def cur: Long; } // it changed
-object Progress {
-  def unapply(p: Progress) =
-    if (p == null) None
-    else Some(p.cur);
-}
-case class Update(override val cur: Long) extends Progress; // progress has been made
-case class Changed(override val cur: Long) extends Progress; // what or max changed
-abstract class Stopped extends Event;  // job has stopped
-case object Done extends Stopped;      // job completed successfuly
-final case class Failed(why: String) extends Stopped; // job failed
-case object Cancelled extends Stopped; // job was cancelled
-
-trait Job extends Publisher[Event] {
-  def what: String;                    // imperative for what we're doing
-  def cur: Long;                       // current position in work
-  def max: Long;                       // maximum work to do
-  def format: String = {               // describe progress in useful terms
-    val c = cur;
-    val m = max;
-    if (m >= 0) {
-      val fm = m.formatted("%d");
-      s"%${fm.length}d/%s".format(c, fm) // ugh!
-    } else if (c > 0) s"$c"
-    else ""
-  }
-  def cancel();
+trait Model {
+  protected val t0 = currentTimeMillis;
 
-  private[this] val t0 = currentTimeMillis;
-  type Pub = Job;
+  def what: String;
+  def max: Long;
 
-  def taken: Double = (currentTimeMillis - t0)/1000.0;
-  def eta: Double =
+  def eta(cur: Long): Double = {
     /* Report the estimated time remaining in seconds, or -1 if no idea.
      *
-     * The model here is very stupid.  Weird jobs should override this and do
-     * something more sensible.
+     * The model here is very stupid.  Weird jobs should override this and
+     * do something more sensible.
      */
 
-    if (max < 0 || cur <= 0) -1
-    else taken*(max - cur)/cur.toDouble;
-}
-
-/*----- Terminal eyecandy (FIXME: split this out) -------------------------*/
-
-import java.io.FileDescriptor;
-import java.lang.System.{out => stdout};
-import sys.isatty;
-
-object TerminalEyecandy extends Eyecandy with Subscriber[Event, Job] {
-  private var last = "";
-  var eyecandyp = isatty(FileDescriptor.out);
-
-  /* Assume that characters take up one cell each.  This is going to fail
-   * badly for combining characters, zero-width characters, wide Asian
-   * characters, and lots of other Unicode characters.  The problem is that
-   * Java doesn't have any way to query the display width of a character,
-   * and, honestly, I don't care enough to do the (substantial) work required
-   * to do this properly.
-   */
-
-  def set(line: String) {
-    if (eyecandyp) {
-
-      /* If the old line is longer than the new one, then we must overprint
-       * the end part.
-       */
-      if (line.length < last.length) {
-       val n = last.length - line.length;
-       for (_ <- 0 until n) stdout.write('\b');
-       for (_ <- 0 until n) stdout.write(' ');
-      }
-
-      /* Figure out the length of the common prefix between what we had
-       * before and what we have now.
-       */
-      val m = (0 until (last.length min line.length)) prefixLength
-       { i => last(i) == line(i) };
-
-      /* Delete the tail from the old line and print the new version. */
-      for (_ <- m until last.length) stdout.write('\b');
-      stdout.print(line.substring(m));
-      stdout.flush();
-    }
-
-    /* Update the state. */
-    last = line;
+    val max = this.max;
+    val delta = currentTimeMillis - t0
+    if (max < 0 || cur <= 0) -1 else delta*(max - cur)/cur.toDouble
   }
 
-  def clear() { set(""); }
+  protected def fmt1(n: Long): String = n.toString;
 
-  def commit() {
-    if (last != "") {
-      if (eyecandyp) stdout.write('\n');
-      else stdout.println(last);
-      last = "";
-    }
+  def format(cur: Long): String = {
+    val max = this.max;
+    val fc = fmt1(cur);
+    if (max >= 0) { val fm = fmt1(max); s"%${fm.length}s/%s".format(fc, fm) }
+    else if (cur > 0) fc
+    else ""
   }
+}
 
-  private final val spinner = """/-\|""";
-  private var step: Int = 0;
-  private final val width = 40;
-
-  def begin(job: Job) { job.subscribe(this); }
-
-  def notify(job: Job, ev: Event) {
-    ev match {
-      case Progress(cur) =>
-       /* Redraw the status line. */
-
-       val max = job.max;
-
-       val sb = new StringBuilder;
-       sb ++= job.what; sb += ' ';
-
-       /* Step the spinner. */
-       step += 1; if (step >= spinner.length) step = 0;
-       sb += spinner(step); sb += ' ';
-
-       /* Progress bar. */
-       if (max < 0)
-         sb ++= "[unknown progress]";
-       else {
-         val n = (width*cur/max).toInt;
-         sb += '[';
-         for (_ <- 0 until n) sb += '=';
-         for (_ <- n until 40) sb += ' ';
-         sb += ']';
-
-         val f = job.format;
-         if (f != "") { sb += ' '; sb ++= f; }
-         sb ++= (100*cur/max).formatted(" %3d%%");
-
-         val eta = job.eta;
-         if (eta >= 0) {
-           sb += ' '; sb += '(';
-           sb ++= formatTime(ceil(eta).toInt);
-           sb += ')';
-         }
-       }
-
-       /* Done. */
-       set(sb.result);
-
-      case Done =>
-       val t = formatTime(ceil(job.taken).toInt);
-       set(s"${job.what} done ($t)"); commit();
-
-      case Cancelled =>
-       set(s"${job.what} CANCELLED"); commit();
+class SimpleModel(val what: String, val max: Long) extends Model;
 
-      case Failed(msg) =>
-       set(s"${job.what} FAILED: $msg"); commit();
+private val UDATA = Seq("kB", "MB", "GB", "TB", "PB", "EB");
 
-      case _ => ok;
-    }
+trait DataModel extends Model {
+  override def fmt1(n: Long): String = {
+    val (x, u) = ((n.toDouble, "B ") /: UDATA) { (xu, n) => (xu, n) match {
+      case ((x, u), name) if x >= 1024.0 => (x/1024.0, name)
+      case (xu, _) => xu
+    } }
+    f"$x%6.1f$u%s"
   }
 }
 
-/*----- Testing cruft -----------------------------------------------------*/
-
-trait AsyncJob extends Job {
-  protected def run();
-  private var _cur: Long = 0; override def cur = _cur;
+trait BaseReporter {
+  def done();
+  def failed(e: Exception);
+}
 
-  
+trait JobReporter extends BaseReporter {
+  def step(cur: Long);
+  def change(model: Model, cur: Long);
 }
 
+trait OperationReporter extends BaseReporter {
+  def step(detail: String);
+}
 
+def withReporter[T, R <: BaseReporter]
+       (rep: R, body: R => T): T = {
+  val ret = try { body(rep) }
+  catch { case e: Exception => rep.failed(e); throw e; }
+  rep.done();
+  ret
+}
 
+trait Eyecandy {
+  def note(msg: String);
+  def clear();
+  def commit();
+  def record(msg: String) { note(msg); commit(); }
+  def done();
+  def cancelled() { failed("cancelled"); }
+  def failed(msg: String);
 
-import Thread.sleep;
+  def beginJob(model: Model): JobReporter
+    // = new JobReporter(model);
 
-class ToyJob(val max: Long) extends Job {
-  val what = "Dummy job";
-  private var _i: Long = 0; def cur = _i;
+  def beginOperation(what: String): OperationReporter
+    // = new OperationReporter(what);
 
-  def cancel() { ??? }
-  def run() {
-    for (i <- 1l until max) { _i = i; publish(Update(i)); sleep(100); }
-    publish(Done);
-  }
-}
+  def job[T](model: Model)(body: JobReporter => T): T =
+    withReporter(beginJob(model), body);
 
-def testjob(n: Long) {
-  val j = new ToyJob(n);
-  TerminalEyecandy.begin(j);
-  j.run();
+  def operation[T](what: String)(body: OperationReporter => T): T =
+    withReporter(beginOperation(what), body);
 }
 
 /*----- That's all, folks -------------------------------------------------*/
index cf0a72d..51ac170 100644 (file)
--- a/sys.scala
+++ b/sys.scala
@@ -175,23 +175,23 @@ object Errno extends Enumeration {
   private var wrong = -256;            // next synthetic code
   private def nextwrong: Int = { val w = wrong; wrong -= 1; w }
 
-  class Type private[Errno](tag: String, val code: Int, id: Int)
-         extends Val(id, tag) {
+  class Val private[Errno](tag: String, val code: Int, id: Int)
+         extends super.Val(id, tag) {
     /* Our augmented error type. */
 
     def message: String = strerror(code).toJString;
   }
   private class UnknownError(code: Int)
-         extends Type("<unknown>", code, code);
+         extends Val("<unknown>", code, code);
 
-  private def err(tag: String, code: Int): Type = {
+  private def err(tag: String, code: Int): Val = {
     /* Construct an error symbol given its tag string and a code number. */
 
-    if (code < 0) new Type(tag, code, code)
-    else if (seen contains code) new Type(tag, code, nextwrong)
-    else { seen += code; new Type(tag, code, code) }
+    if (code < 0) new Val(tag, code, code)
+    else if (seen contains code) new Val(tag, code, nextwrong)
+    else { seen += code; new Val(tag, code, code) }
   }
-  private def err(tag: String): Type =
+  private def err(tag: String): Val =
     err(tag, tagmap.getOrElse(tag, nextwrong));
 
   def byid(id: Int): Value = {
@@ -377,7 +377,7 @@ object Errno extends Enumeration {
   val EHWPOISON = err("EHWPOISON");
   /***end***/
 }
-import Errno.{Type => Errno, EEXIST, EISDIR, ENOENT, ENOTDIR};
+import Errno.{Val => Errno, EEXIST, EISDIR, ENOENT, ENOTDIR};
 
 object SystemError {
   /* Pattern matching for `SystemError', below. */
@@ -596,7 +596,7 @@ object FileImplicits {
     /* Constructing names of files in a directory.  Honestly, I'm surprised
      * there isn't a method for this already.
      */
-    def +(sub: String): File = new File(file, sub);
+    def /(sub: String): File = new File(file, sub);
 
     /* Simple file operations. */
     def unlink_!() { unlink(file.getPath); }
@@ -739,7 +739,7 @@ def freshFile(d: File): File = {
     /* Make the filename, and try to create the file.  If we succeed, we
      * win.
      */
-    val f = new File(d, b.result); b.clear();
+    val f = d/b.result; b.clear();
     try { f.mkfile_!(); exit(f); }
     catch { case SystemError(EEXIST, _) => ok; }
   }
@@ -798,23 +798,81 @@ def runCommand(cmd: String*): (String, String) = {
   }
 }
 
+/*----- Interrupt triggers ------------------------------------------------*/
+
+private val triggerLock = new Object;
+private final val maxTriggers = 2;
+private var nTriggers = 0;
+private var triggers: List[Wrapper] = Nil;
+
+@native protected def makeTrigger(): Wrapper;
+@native protected def destroyTrigger(trig: Wrapper);
+@native protected def resetTrigger(trig: Wrapper);
+@native protected def trigger(trig: Wrapper);
+
+private def getTrigger(): Wrapper = {
+  triggerLock synchronized {
+    if (nTriggers == 0)
+      makeTrigger()
+    else {
+      val trig = triggers.head;
+      triggers = triggers.tail;
+      nTriggers -= 1;
+      trig
+    }
+  }
+}
+
+private def putTrigger(trig: Wrapper) {
+  resetTrigger(trig);
+  triggerLock synchronized {
+    if (nTriggers >= maxTriggers)
+      destroyTrigger(trig);
+    else {
+      triggers ::= trig;
+      nTriggers += 1;
+    }
+  }
+}
+
+private def withTrigger[T](body: Wrapper => T): T = {
+  val trig = getTrigger();
+  try { body(trig) }
+  finally { putTrigger(trig); }
+}
+
+def interruptWithTrigger[T](body: Wrapper => T): T = {
+  /* interruptWithTrigger { TRIG => BODY }
+   *
+   * Execute BODY and return its result.  If the thread receives an
+   * interrupt, the trigger TRIG will be pulled.  See `interruptably' for the
+   * full semantics.
+   */
+
+  withTrigger { trig =>
+    interruptably { body(trig) } onInterrupt { trigger(trig); }
+  };
+}
+
 /*----- Connecting to a server --------------------------------------------*/
 
 /* Primitive operations. */
 final val CF_CLOSERD = 1;
 final val CF_CLOSEWR = 2;
 final val CF_CLOSEMASK = CF_CLOSERD | CF_CLOSEWR;
-@native protected def connect(path: CString): Wrapper;
+@native protected def connect(path: CString, trig: Wrapper): Wrapper;
 @native protected def send(conn: Wrapper, buf: CString,
-                          start: Int, len: Int);
+                          start: Int, len: Int, trig: Wrapper);
 @native protected def recv(conn: Wrapper, buf: CString,
-                          start: Int, len: Int): Int;
+                          start: Int, len: Int, trig: Wrapper): Int;
 @native def closeconn(conn: Wrapper, how: Int);
 
 class Connection(path: String) extends Closeable {
 
   /* The underlying primitive connection. */
-  private[this] val conn = connect(path.toCString);
+  private[this] val conn = interruptWithTrigger { trig =>
+    connect(path.toCString, trig);
+  };
 
   /* Alternative constructors. */
   def this(file: File) { this(file.getPath); }
@@ -834,7 +892,7 @@ class Connection(path: String) extends Closeable {
     override def read(buf: Array[Byte]): Int =
       read(buf, 0, buf.length);
     override def read(buf: Array[Byte], start: Int, len: Int) =
-      recv(conn, buf, start, len);
+      interruptWithTrigger { trig => recv(conn, buf, start, len, trig); };
     override def close() { closeconn(conn, CF_CLOSERD); }
   }
   lazy val input = new Input;
@@ -845,7 +903,7 @@ class Connection(path: String) extends Closeable {
     override def write(b: Int) { write(Array[Byte](b.toByte), 0, 1); }
     override def write(buf: Array[Byte]) { write(buf, 0, buf.length); }
     override def write(buf: Array[Byte], start: Int, len: Int)
-      { send(conn, buf, start, len); }
+      { interruptWithTrigger { trig => send(conn, buf, start, len, trig); } }
     override def close() { closeconn(conn, CF_CLOSEWR); }
   }
   lazy val output = new Output;
diff --git a/terminal.scala b/terminal.scala
new file mode 100644 (file)
index 0000000..9722e2b
--- /dev/null
@@ -0,0 +1,177 @@
+/* -*-scala-*-
+ *
+ * Terminal-based progress eyecandy
+ *
+ * (c) 2018 Straylight/Edgeware
+ */
+
+/*----- Licensing notice --------------------------------------------------*
+ *
+ * This file is part of the Trivial IP Encryption (TrIPE) Android app.
+ *
+ * TrIPE is free software: you can redistribute it and/or modify it under
+ * the terms of the GNU General Public License as published by the Free
+ * Software Foundation; either version 3 of the License, or (at your
+ * option) any later version.
+ *
+ * TrIPE is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with TrIPE.  If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package uk.org.distorted.tripe; package progress;
+
+/*----- Imports -----------------------------------------------------------*/
+
+import java.io.FileDescriptor;
+import java.lang.Math.ceil;
+import java.lang.System.{currentTimeMillis, out => stdout};
+
+import sys.isatty;
+
+/*----- Main code ---------------------------------------------------------*/
+
+object TerminalEyecandy extends Eyecandy {
+  private var last = "";
+  var eyecandyp = isatty(FileDescriptor.out);
+
+  /* Assume that characters take up one cell each.  This is going to fail
+   * badly for combining characters, zero-width characters, wide Asian
+   * characters, and lots of other Unicode characters.  The problem is that
+   * Java doesn't have any way to query the display width of a character,
+   * and, honestly, I don't care enough to do the (substantial) work required
+   * to do this properly.
+   */
+
+  def note(line: String) {
+    if (eyecandyp) {
+
+      /* If the old line is longer than the new one, then we must overprint
+       * the end part.
+       */
+      if (line.length < last.length) {
+       val n = last.length - line.length;
+       for (_ <- 0 until n) stdout.write('\b');
+       for (_ <- 0 until n) stdout.write(' ');
+      }
+
+      /* Figure out the length of the common prefix between what we had
+       * before and what we have now.
+       */
+      val m = (0 until (last.length min line.length)) prefixLength
+       { i => last(i) == line(i) };
+
+      /* Delete the tail from the old line and print the new version. */
+      for (_ <- m until last.length) stdout.write('\b');
+      stdout.print(line.substring(m));
+      stdout.flush();
+    }
+
+    /* Update the state. */
+    last = line;
+  }
+
+  def clear() { note(""); }
+
+  def commit() {
+    if (last != "") {
+      if (eyecandyp) stdout.write('\n');
+      else stdout.println(last);
+      last = "";
+    }
+  }
+
+  def done() { clear(); }
+  def failed(msg: String) { record(s"FAILED!  $msg"); }
+
+  def beginJob(model: Model): progress.JobReporter =
+    new JobReporter(model);
+
+  def beginOperation(what: String): progress.OperationReporter =
+    new OperationReporter(what);
+
+  private[this] class JobReporter(private[this] var model: Model)
+         extends progress.JobReporter {
+    private final val width = 40;
+    private final val spinner = """/-\|""";
+    private final val mingap = 100;
+    private[this] var step: Int = 0;
+    private[this] var sweep: Int = 0;
+    private[this] val t0 = currentTimeMillis;
+    private[this] var last: Long = -1;
+
+    def change(model: Model, cur: Long)
+      { last = -1; this.model = model; step(cur); }
+
+    def step(cur: Long) {
+      val now = currentTimeMillis;
+      if (last >= 0 && now - last < mingap) return;
+      last = now;
+
+      val max = model.max;
+      val sb = new StringBuilder;
+      sb ++= model.what; sb += ' ';
+
+      /* Step the spinner. */
+      sb += spinner(step); sb += ' ';
+      step += 1; if (step >= spinner.length) step = 0;
+
+      /* Progress bar. */
+      sb += '[';
+      if (max <= 0) {
+       val l = sweep; val r = width - 1 - sweep;
+       val (lo, hi, x, y) = if (l < r) (l, r, '>', '<')
+                            else (r, l, '<', '>');
+       for (_ <- 0 until lo) sb += ' ';
+       sb += x;
+       for (_ <- lo + 1 until hi) sb += ' ';
+       sb += y;
+       for (_ <- hi + 1 until width) sb += ' ';
+       sweep += 1; if (sweep >= width) sweep = 0;
+      } else {
+       val n = (width*cur/max).toInt;
+       for (_ <- 0 until n) sb += '=';
+       for (_ <- n until width) sb += ' ';
+      }
+      sb += ']';
+
+      /* Quantitative progress. */
+      val f = model.format(cur); if (f != "") { sb += ' '; sb ++= f; }
+      if (max > 0) sb ++= (100*cur/max).formatted(" %3d%%");
+
+      /* Estimated time to completion. */
+      val eta = model.eta(cur);
+      if (eta >= 0) {
+       sb += ' '; sb += '(';
+       sb ++= formatTime(ceil(eta/1000.0).toInt);
+       sb += ')';
+      }
+
+      /* Done. */
+      note(sb.result);
+    }
+
+    def done() {
+      val t = formatTime(ceil((currentTimeMillis - t0)/1000.0).toInt);
+      record(s"${model.what} done ($t)");
+    }
+
+    def failed(e: Exception)
+      { record(s"${model.what} FAILED: ${e.getMessage}"); }
+
+    step(0);
+  }
+
+  class OperationReporter(what: String) extends progress.OperationReporter {
+    def step(detail: String) { note(s"$what: $detail"); }
+    def done() { record(s"$what: ok"); }
+    def failed(e: Exception) { record(s"$what: ${e.getMessage}"); }
+    step("...");
+  }
+}
+
+/*----- That's all, folks -------------------------------------------------*/
index 75b3677..ea776e7 100644 (file)
@@ -31,9 +31,12 @@ import scala.concurrent.duration.{Deadline, Duration};
 import scala.util.control.{Breaks, ControlThrowable};
 
 import java.io.{BufferedReader, Closeable, File, InputStream, Reader};
-import java.net.{URL, URLConnection};
+import java.net.{HttpURLConnection, URL, URLConnection};
 import java.nio.{ByteBuffer, CharBuffer};
+import java.nio.channels.{SelectionKey, Selector};
+import java.nio.channels.spi.{AbstractSelector, AbstractSelectableChannel};
 import java.nio.charset.Charset;
+import java.util.{Set => JSet};
 import java.util.concurrent.locks.{Lock, ReentrantLock};
 
 /*----- Miscellaneous useful things ---------------------------------------*/
@@ -194,6 +197,61 @@ def loopUnit(body: (=> Nothing) => Unit): Unit = {
 val BREAKS = new Breaks;
 import BREAKS.{breakable, break};
 
+/*----- Interruptably doing things ----------------------------------------*/
+
+private class InterruptCatcher[T](body: => T, onWakeup: => Unit)
+       extends AbstractSelector(null) {
+  /* Hook onto the VM's thread interruption machinery.
+   *
+   * The `run' method is the only really interesting one.  It will run the
+   * BODY, returning its result; if the thread is interrupted during this
+   * time, ONWAKEUP is invoked for effect.  The expectation is that ONWAKEUP
+   * will somehow cause BODY to stop early.
+   *
+   * Credit for this hack goes to Nicholas Wilson: see
+   * <https://github.com/NWilson/javaInterruptHook>.
+   */
+
+  private def nope: Nothing =
+    { throw new UnsupportedOperationException("can't do that"); }
+  protected def implCloseSelector() { }
+  protected def register(chan: AbstractSelectableChannel,
+                                 ops: Int, att: Any): SelectionKey = nope;
+  def keys(): JSet[SelectionKey] = nope;
+  def selectedKeys(): JSet[SelectionKey] = nope;
+  def select(): Int = nope;
+  def select(millis: Long): Int = nope;
+  def selectNow(): Int = nope;
+
+  def run(): T = try {
+    begin();
+    val ret = body;
+    if (Thread.interrupted()) throw new InterruptedException;
+    ret
+  } finally {
+    end();
+  }
+  def wakeup(): Selector = { onWakeup; this }
+}
+
+class PendingInterruptable[T] private[tripe](body: => T) {
+  /* This class exists to provide the `onInterrupt THUNK' syntax. */
+
+  def onInterrupt(thunk: => Unit): T =
+    new InterruptCatcher(body, thunk).run();
+}
+def interruptably[T](body: => T) = {
+  /* interruptably { BODY } onInterrupt { THUNK }
+   *
+   * Execute BODY and return its result.  If the thread receives an
+   * interrupt -- or is already in an interrupted state -- execute THUNK for
+   * effect; it is expected to cause BODY to return expeditiously, and when
+   * the BODY completes, an `InterruptedException' is thrown.
+   */
+
+  new PendingInterruptable(body);
+}
+
 /*----- A gadget for fetching URLs ----------------------------------------*/
 
 class URLFetchException(msg: String) extends Exception(msg);
@@ -208,47 +266,74 @@ def fetchURL(url: URL, cb: URLFetchCallbacks) {
   /* Fetch the URL, feeding the data through the callbacks CB. */
 
   withCleaner { clean =>
-    var win: Boolean = false;
-    clean { cb.done(win); }
+    var win: Boolean = false; clean { cb.done(win); }
 
-    /* Set up the connection, and run a preflight check. */
+    /* Set up the connection.  This isn't going to block, I think, and we
+     * need to use it in the interrupt handler.
+     */
     val c = url.openConnection();
-    cb.preflight(c);
-
-    /* Start fetching data. */
-    val in = c.getInputStream; clean { in.close(); }
-    val explen = c.getContentLength;
 
-    /* Read a buffer at a time, and give it to the callback.  Maintain a
-     * running total.
+    /* Java's default URL handlers don't respond to interrupts, so we have to
+     * take over this duty.
      */
-    var len: Long = 0;
-    blockUnit { exit =>
-      for ((buf, n) <- blocks(in)) {
-       cb.write(buf, n, len);
-       len += n;
-       if (explen != -1 && len > explen) exit;
+    interruptably {
+      /* Run the caller's preflight check.  This must be done here, since it
+       * might well block while it discovers things like the content length.
+       */
+      cb.preflight(c);
+
+      /* Start fetching data. */
+      val in = c.getInputStream; clean { in.close(); }
+      val explen = c.getContentLength;
+
+      /* Read a buffer at a time, and give it to the callback.  Maintain a
+       * running total.
+       */
+      var len: Long = 0;
+      blockUnit { exit =>
+       for ((buf, n) <- blocks(in)) {
+         cb.write(buf, n, len);
+         len += n;
+         if (explen != -1 && len > explen) exit;
+       }
       }
-    }
 
-    /* I can't find it documented anywhere that the existing machinery
-     * checks the received stream against the advertised content length.
-     * It doesn't hurt to check again, anyway.
-     */
-    if (explen != -1 && explen != len) {
-      throw new URLFetchException(
-       s"received $len /= $explen bytes from `$url'");
-    }
+      /* I can't find it documented anywhere that the existing machinery
+       * checks the received stream against the advertised content length.
+       * It doesn't hurt to check again, anyway.
+       */
+      if (explen != -1 && explen != len) {
+       throw new URLFetchException(
+         s"received $len /= $explen bytes from `$url'");
+      }
 
-    /* Glorious success is ours. */
-    win = true;
+      /* Glorious success is ours. */
+      win = true;
+    } onInterrupt {
+      /* Oh.  How do we do this? */
+
+      c match {
+       case c: HttpURLConnection =>
+         /* It's an HTTP connection (what happened to the case here?).
+          * HTTPS connections match too because they're a subclass.  Getting
+          * the input stream will block, but there's an easier way.
+          */
+         c.disconnect();
+
+       case _ =>
+         /* It's something else.  Let's hope that getting the input stream
+          * doesn't block.
+          */
+       c.getInputStream.close();
+      }
+    }
   }
 }
 
 /*----- Threading things --------------------------------------------------*/
 
-def thread[T](name: String, run: Boolean = true, daemon: Boolean = true)
-            (f: => T): Thread = {
+def thread(name: String, run: Boolean = true, daemon: Boolean = true)
+         (f: => Unit): Thread = {
   /* Make a thread with a given name, and maybe start running it. */
 
   val t = new Thread(new Runnable { def run() { f; } }, name);
@@ -257,6 +342,28 @@ def thread[T](name: String, run: Boolean = true, daemon: Boolean = true)
   t
 }
 
+class ValueThread[T](name: String, group: ThreadGroup = null,
+                    stacksz: Long = 0)(body: => T)
+       extends Thread(group, null, name, stacksz) {
+  private[this] var exc: Throwable = _;
+  private[this] var ret: T = _;
+
+  override def run() {
+    try { ret = body; }
+    catch { case e: Throwable => exc = e; }
+  }
+  def get: T =
+    if (isAlive) throw new IllegalArgumentException("still running");
+    else if (exc != null) throw exc;
+    else ret;
+}
+def valueThread[T](name: String, run: Boolean = true)
+                 (body: => T): ValueThread[T] = {
+  val t = new ValueThread(name)(body);
+  if (run) t.start();
+  t
+}
+
 /*----- Quoting and parsing tokens ----------------------------------------*/
 
 def quoteTokens(v: Seq[String]): String = {
@@ -416,7 +523,7 @@ def blocks(in: InputStream, blksz: Int):
     }
   }
 def blocks(in: InputStream):
-       BufferedIterator[(Array[Byte], Int)] = blocks(in, 4096);
+       BufferedIterator[(Array[Byte], Int)] = blocks(in, 65536);
 
 def blocks(in: BufferedReader, blksz: Int):
        BufferedIterator[(Array[Char], Int)] =
@@ -430,7 +537,7 @@ def blocks(in: BufferedReader, blksz: Int):
     }
   }
 def blocks(in: BufferedReader):
-       BufferedIterator[(Array[Char], Int)] = blocks(in, 4096);
+       BufferedIterator[(Array[Char], Int)] = blocks(in, 65536);
 def blocks(r: Reader, blksz: Int): BufferedIterator[(Array[Char], Int)] =
   blocks(bufferedReader(r), blksz);
 def blocks(r: Reader): BufferedIterator[(Array[Char], Int)] =
@@ -454,6 +561,15 @@ def oxford(conj: String, things: Seq[String]): String = things match {
     sb.result
 }
 
+def formatTime(t: Int): String =
+  if (t < -1) "???"
+  else {
+    val (s, t1) = (t%60, t/60);
+    val (m, h) = (t1%60, t1/60);
+    if (h > 0) f"$h%d:$m%02d:$s%02d"
+    else f"$m%02d:$s%02d"
+  }
+
 /*----- That's all, folks -------------------------------------------------*/
 
 }