udpkey.c: Refactor the client side of the protocol.
[udpkey] / udpkey.c
index b3dca9b..9c23ee0 100644 (file)
--- a/udpkey.c
+++ b/udpkey.c
@@ -801,6 +801,7 @@ static int dolisten(int argc, char *argv[])
 
 struct query {
   struct query *next;
+  const char *tag;
   octet *k;
   size_t sz;
   struct server *s;
@@ -810,16 +811,151 @@ struct server {
   struct server *next;
   struct sockaddr_in sin;
   struct kinfo k;
+  const struct client_protocol *proto;
   mp *u;
   ge *U;
   octet *h;
 };
 
+struct client_protocol {
+  const char *name;
+  int (*setup)(struct query *, struct server *);
+  int (*receive)(struct query *, struct server *, buf *, buf *);
+  int (*retransmit)(struct query *, struct server *, buf *);
+};
+
 /* Record a successful fetch of key material for a query Q.  The data starts
  * at K and is SZ bytes long.  The data is copied: it's safe to overwrite it.
  */
-static void donequery(struct query *q, const void *k, size_t sz)
-  { q->k = xmalloc(sz); memcpy(q->k, k, sz); q->sz = sz; nq--; }
+static int donequery(struct query *q, struct server *s,
+                     const void *k, size_t sz)
+{
+  octet *tt;
+  ghash *h = 0;
+  int diffp;
+
+  /* If we have a hash, check that the fragment matches it. */
+  if (s && s->h) {
+    h = GH_INIT(s->k.hc);
+    GH_HASH(h, k, sz);
+    tt = GH_DONE(h, 0);
+    diffp = memcmp(tt, s->h, h->ops->c->hashsz);
+    GH_DESTROY(h);
+    if (diffp) {
+      moan("response from %s:%d doesn't match hash",
+          inet_ntoa(s->sin.sin_addr), ntohs(s->sin.sin_port));
+      return (-1);
+    }
+  }
+
+  /* Stash a copy of the key fragment for later. */
+  q->k = xmalloc(sz);
+  memcpy(q->k, k, sz);
+  q->sz = sz; nq--;
+
+  /* All good. */
+  return (0);
+}
+
+static int setup_v0(struct query *q, struct server *s)
+{
+  /* Choose an ephemeral private key u.  Let x be our private key.  We
+   * compute U = u P and transmit this.
+   */
+  s->u = mprand_range(MP_NEW, s->k.g->r, &rand_global, 0);
+  s->U = G_CREATE(s->k.g);
+  G_EXP(s->k.g, s->U, s->k.g->g, s->u);
+  D( debug_mp("u", s->u); debug_ge("U", s->k.g, s->U); )
+
+  return (0);
+}
+
+static int retransmit_v0(struct query *q, struct server *s, buf *bout)
+{
+  buf_putstrz(bout, q->tag);
+  G_TOBUF(s->k.g, bout, s->U);
+  return (0);
+}
+
+static int receive_v0(struct query *q, struct server *s, buf *bin, buf *bout)
+{
+  ge *R, *V = 0, *W = 0, *Y = 0, *Z = 0;
+  octet *kk, *t, *tt;
+  gcipher *c = 0;
+  gmac *m = 0;
+  ghash *h = 0;
+  size_t n, ksz;
+  octet *p;
+  int rc = -1;
+
+  R = G_CREATE(s->k.g);
+  V = G_CREATE(s->k.g); W = G_CREATE(s->k.g);
+  Y = G_CREATE(s->k.g); Z = G_CREATE(s->k.g);
+  if (G_FROMBUF(s->k.g, bin, V)) {
+    moan("invalid Diffie--Hellman vector from %s:%d",
+        inet_ntoa(s->sin.sin_addr), ntohs(s->sin.sin_port));
+    goto done;
+  }
+  if (G_FROMBUF(s->k.g, bin, W)) {
+    moan("invalid clue vector from %s:%d",
+        inet_ntoa(s->sin.sin_addr), ntohs(s->sin.sin_port));
+    goto done;
+  }
+  D( debug_ge("V", s->k.g, V); debug_ge("W", s->k.g, W); )
+
+  /* We have V and W from the server; determine Y = u V, R = W + Y and
+   * Z = x R, and then derive the symmetric keys.
+   */
+  G_EXP(s->k.g, Y, V, s->u);
+  G_MUL(s->k.g, R, W, Y);
+  G_EXP(s->k.g, Z, R, s->k.x);
+  D( debug_ge("R", s->k.g, R);
+     debug_ge("Y", s->k.g, Y);
+     debug_ge("Z", s->k.g, Z); )
+  derive(&s->k, R, Z, "cipher", s->k.cc->name, s->k.cc->keysz, &kk, &ksz);
+  c = GC_INIT(s->k.cc, kk, ksz);
+  derive(&s->k, R, Z, "mac", s->k.cc->name, s->k.cc->keysz, &kk, &ksz);
+  m = GM_KEY(s->k.mc, kk, ksz);
+
+  /* Find where the MAC tag is. */
+  if ((t = buf_get(bin, s->k.tagsz)) == 0) {
+    moan("missing tag from %s:%d",
+        inet_ntoa(s->sin.sin_addr), ntohs(s->sin.sin_port));
+    goto done;
+  }
+
+  /* Check the integrity of the ciphertext against the tag. */
+  p = BCUR(bin); n = BLEFT(bin);
+  h = GM_INIT(m);
+  GH_HASH(h, p, n);
+  tt = GH_DONE(h, 0);
+  if (!ct_memeq(t, tt, s->k.tagsz)) {
+    moan("incorrect tag from %s:%d",
+        inet_ntoa(s->sin.sin_addr), ntohs(s->sin.sin_port));
+    goto done;
+  }
+
+  /* Decrypt the result and declare this server done. */
+  GC_DECRYPT(c, p, p, n);
+  rc = donequery(q, s, p, n);
+
+done:
+  /* Clear up and go home. */
+  if (R) G_DESTROY(s->k.g, R);
+  if (V) G_DESTROY(s->k.g, V);
+  if (W) G_DESTROY(s->k.g, W);
+  if (Y) G_DESTROY(s->k.g, Y);
+  if (Z) G_DESTROY(s->k.g, Z);
+  if (c) GC_DESTROY(c);
+  if (m) GM_DESTROY(m);
+  if (h) GH_DESTROY(h);
+  return (rc);
+}
+
+static const struct client_protocol prototab[] = {
+  { "v0", setup_v0, receive_v0, retransmit_v0 },
+  { 0 }
+};
 
 /* Initialize a query to a remote server. */
 static struct query *qinit_net(const char *tag, const char *spec)
@@ -827,11 +963,13 @@ static struct query *qinit_net(const char *tag, const char *spec)
   struct query *q;
   struct server *s, **stail;
   dstr d = DSTR_INIT, dd = DSTR_INIT;
+  const struct client_protocol *proto;
   hex_ctx hc;
   char *p, *pp, ch;
 
   /* Allocate the query block. */
   q = CREATE(struct query);
+  q->tag = tag;
   stail = &q->s;
 
   /* Put the spec somewhere we can hack at it. */
@@ -859,6 +997,20 @@ static struct query *qinit_net(const char *tag, const char *spec)
     ch = *pp; *pp++ = 0;
     s->sin.sin_port = htons(getport(p));
 
+    /* See if there's a protocol name. */
+    if (ch != '?')
+      p = "v0";
+    else {
+      p = pp;
+      pp += strcspn(pp, ";#=");
+      ch = *pp; *pp++ = 0;
+    }
+    for (proto = prototab; proto->name; proto++)
+      if (strcmp(proto->name, p) == 0) goto found_proto;
+    die(1, "unknown protocol name `%s'", p);
+  found_proto:
+    s->proto = proto;
+
     /* If there's a key tag then extract that; otherwise use a default. */
     if (ch != '=')
       p = "udpkey-kem";
@@ -870,14 +1022,6 @@ static struct query *qinit_net(const char *tag, const char *spec)
     if (loadkey(p, &s->k, 1)) exit(1);
     D( debug_mp("x", s->k.x); debug_ge("X", s->k.g, s->k.X); )
 
-    /* Choose an ephemeral private key u.  Let x be our private key.  We
-     * compute U = u P and transmit this.
-     */
-    s->u = mprand_range(MP_NEW, s->k.g->r, &rand_global, 0);
-    s->U = G_CREATE(s->k.g);
-    G_EXP(s->k.g, s->U, s->k.g->g, s->u);
-    D( debug_mp("u", s->u); debug_ge("U", s->k.g, s->U); )
-
     /* Link the server on. */
     *stail = s; stail = &s->next;
 
@@ -896,6 +1040,9 @@ static struct query *qinit_net(const char *tag, const char *spec)
       ch = *pp++;
     }
 
+    /* Initialize the protocol. */
+    if (s->proto->setup(q, s)) die(1, "failed to initialize protocol");
+
     /* If there are more servers, then continue parsing. */
     if (!ch) break;
     else if (ch != ';') die(1, "invalid syntax: expected `;'");
@@ -922,7 +1069,7 @@ static struct query *qinit_file(const char *tag, const char *file)
   if (snarf(file, &k, &sz))
     die(1, "failed to read `%s': %s", file, strerror(errno));
   q->s = 0;
-  donequery(q, k, sz);
+  donequery(q, 0, k, sz);
   return (q);
 }
 
@@ -941,15 +1088,10 @@ static int doquery(int argc, char *argv[])
   fd_set fdin;
   struct timeval now, when, tv;
   struct sockaddr_in sin;
-  ge *R, *V = 0, *W = 0, *Y = 0, *Z = 0;
-  octet *kk, *t, *tt;
-  gcipher *c = 0;
-  gmac *m = 0;
-  ghash *h = 0;
   socklen_t len;
   unsigned next = 0;
   buf bin, bout;
-  size_t n, j, ksz;
+  size_t n, j;
   ssize_t nn;
 
   /* Create a socket.  We just use the one socket for everything.  We don't
@@ -989,8 +1131,7 @@ static int doquery(int argc, char *argv[])
        if (q->k) continue;
        for (s = q->s; s; s = s->next) {
          buf_init(&bout, obuf, sizeof(obuf));
-         buf_putstrz(&bout, tag);
-         G_TOBUF(s->k.g, &bout, s->U);
+         if (s->proto->retransmit(q, s, &bout)) continue;
          if (BBAD(&bout)) {
            moan("overflow while constructing request!");
            continue;
@@ -1022,7 +1163,7 @@ static int doquery(int argc, char *argv[])
          else if (errno == EINTR) continue;
          else {
            moan("error receiving reply: %s", strerror(errno));
-           goto again;
+           continue;
          }
        }
 
@@ -1039,89 +1180,27 @@ static int doquery(int argc, char *argv[])
        }
        moan("received reply from unexpected source %s:%d",
             inet_ntoa(sin.sin_addr), ntohs(sin.sin_port));
-       goto again;
+       continue;
 
       found:
        /* If the query we found has now been satisfied, ignore this packet.
         */
-       if (q->k) goto again;
-
-       /* Start parsing the reply. */
-       buf_init(&bin, ibuf, nn);
-       R = G_CREATE(s->k.g);
-       V = G_CREATE(s->k.g); W = G_CREATE(s->k.g);
-       Y = G_CREATE(s->k.g); Z = G_CREATE(s->k.g);
-       if (G_FROMBUF(s->k.g, &bin, V)) {
-         moan("invalid Diffie--Hellman vector from %s:%d",
-              inet_ntoa(sin.sin_addr), ntohs(sin.sin_port));
-         goto again;
-       }
-       if (G_FROMBUF(s->k.g, &bin, W)) {
-         moan("invalid clue vector from %s:%d",
-              inet_ntoa(sin.sin_addr), ntohs(sin.sin_port));
-         goto again;
-       }
-       D( debug_ge("V", s->k.g, V); debug_ge("W", s->k.g, W); )
+       if (q->k) continue;
 
-       /* We have V and W from the server; determine Y = u V, R = W + Y and
-        * Z = x R, and then derive the symmetric keys.
+       /* Parse the reply, and either finish the job or get a message to
+        * send back to the server.
         */
-       G_EXP(s->k.g, Y, V, s->u);
-       G_MUL(s->k.g, R, W, Y);
-       G_EXP(s->k.g, Z, R, s->k.x);
-       D( debug_ge("R", s->k.g, R);
-          debug_ge("Y", s->k.g, Y);
-          debug_ge("Z", s->k.g, Z); )
-       derive(&s->k, R, Z, "cipher", s->k.cc->name, s->k.cc->keysz,
-              &kk, &ksz);
-       c = GC_INIT(s->k.cc, kk, ksz);
-       derive(&s->k, R, Z, "mac", s->k.cc->name, s->k.cc->keysz,
-              &kk, &ksz);
-       m = GM_KEY(s->k.mc, kk, ksz);
-
-       /* Find where the MAC tag is. */
-       if ((t = buf_get(&bin, s->k.tagsz)) == 0) {
-         moan("missing tag from %s:%d",
-              inet_ntoa(sin.sin_addr), ntohs(sin.sin_port));
-         goto again;
-       }
-
-       /* Check the integrity of the ciphertext against the tag. */
-       p = BCUR(&bin); n = BLEFT(&bin);
-       h = GM_INIT(m);
-       GH_HASH(h, p, n);
-       tt = GH_DONE(h, 0);
-       if (!ct_memeq(t, tt, s->k.tagsz)) {
-         moan("incorrect tag from %s:%d",
-              inet_ntoa(sin.sin_addr), ntohs(sin.sin_port));
-         goto again;
-       }
-
-       /* Decrypt the result and declare this server done. */
-       GC_DECRYPT(c, p, p, n);
-       if (s->h) {
-         GH_DESTROY(h);
-         h = GH_INIT(s->k.hc);
-         GH_HASH(h, p, n);
-         tt = GH_DONE(h, 0);
-         if (memcmp(tt, s->h, h->ops->c->hashsz) != 0) {
-           moan("response from %s:%d doesn't match hash",
-                inet_ntoa(sin.sin_addr), ntohs(sin.sin_port));
-           goto again;
-         }
+       buf_init(&bin, ibuf, nn);
+       buf_init(&bout, obuf, sizeof(obuf));
+       if (s->proto->receive(q, s, &bin, &bout)) continue;
+       if (q->k) continue;
+       if (!BLEN(&bout) && s->proto->retransmit(q, s, &bout)) continue;
+       if (BBAD(&bout)) {
+         moan("overflow while constructing request!");
+         continue;
        }
-       donequery(q, p, n);
-
-      again:
-       /* Tidy things up for the next run through. */
-       if (R) { G_DESTROY(s->k.g, R); R = 0; }
-       if (V) { G_DESTROY(s->k.g, V); V = 0; }
-       if (W) { G_DESTROY(s->k.g, W); W = 0; }
-       if (Y) { G_DESTROY(s->k.g, Y); Y = 0; }
-       if (Z) { G_DESTROY(s->k.g, Z); Z = 0; }
-       if (c) { GC_DESTROY(c); c = 0; }
-       if (m) { GM_DESTROY(m); m = 0; }
-       if (h) { GH_DESTROY(h); h = 0; }
+       sendto(sk, BBASE(&bout), BLEN(&bout), 0,
+              (struct sockaddr *)&s->sin, sizeof(s->sin));
       }
     }
   }