@@@ crypto-test
[secnet] / crypto-test.c
diff --git a/crypto-test.c b/crypto-test.c
new file mode 100644 (file)
index 0000000..137f7a3
--- /dev/null
@@ -0,0 +1,420 @@
+/*
+ * crypto-test.c: common test vector processing
+ */
+/*
+ * This file is Free Software.  It was originally written for secnet.
+ *
+ * Copyright 2017 Mark Wooding
+ *
+ * You may redistribute secnet as a whole and/or modify it under the
+ * terms of the GNU General Public License as published by the Free
+ * Software Foundation; either version 3, or (at your option) any
+ * later version.
+ *
+ * You may redistribute this file and/or modify it under the terms of
+ * the GNU General Public License as published by the Free Software
+ * Foundation; either version 2, or (at your option) any later
+ * version.
+ *
+ * This software is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this software; if not, see
+ * https://www.gnu.org/licenses/gpl.html.
+ */
+
+#include <assert.h>
+#include <errno.h>
+#include <ctype.h>
+#include <stdarg.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include "secnet.h"
+#include "util.h"
+
+#include "crypto-test.h"
+
+/*----- Utilities ---------------------------------------------------------*/
+
+static void *xmalloc(size_t sz)
+{
+    void *p;
+
+    if (!sz) return 0;
+    p = malloc(sz);
+    if (!p) {
+       fprintf(stderr, "out of memory!\n");
+       exit(2);
+    }
+    return p;
+}
+
+static void *xrealloc(void *p, size_t sz)
+{
+    void *q;
+
+    if (!sz) { free(p); return 0; }
+    else if (!p) return xmalloc(sz);
+    q = realloc(p, sz);
+    if (!q) {
+       fprintf(stderr, "out of memory!\n");
+       exit(2);
+    }
+    return q;
+}
+
+static int lno;
+
+void bail(const char *msg, ...)
+{
+    va_list ap;
+    va_start(ap, msg);
+    fprintf(stderr, "unexpected error (line %d): ", lno);
+    vfprintf(stderr, msg, ap);
+    va_end(ap);
+    fputc('\n', stderr);
+    exit(2);
+}
+
+struct linebuf {
+    char *p;
+    size_t sz;
+};
+#define LINEBUF_INIT { 0, 0 };
+
+static int read_line(struct linebuf *b, FILE *fp)
+{
+    size_t n = 0;
+    int ch;
+
+    ch = getc(fp); if (ch == EOF) return EOF;
+    for (;;) {
+       if (n >= b->sz) {
+           b->sz = b->sz ? 2*b->sz : 64;
+           b->p = xrealloc(b->p, b->sz);
+       }
+       if (ch == EOF || ch == '\n') { b->p[n++] = 0; return 0; }
+       b->p[n++] = ch;
+       ch = getc(fp);
+    }
+}
+
+void parse_hex(uint8_t *b, size_t sz, char *p)
+{
+    size_t n = strlen(p);
+    unsigned i;
+    char bb[3];
+
+    if (n%2) bail("bad hex (odd number of nibbles)");
+    else if (n/2 != sz) bail("bad hex (want %zu bytes, found %zu)", sz, n/2);
+    while (sz) {
+       for (i = 0; i < 2; i++) {
+           if (!isxdigit((unsigned char)p[i]))
+               bail("bad hex digit `%c'", p[i]);
+           bb[i] = p[i];
+       }
+       bb[2] = 0;
+       p += 2;
+       *b++ = strtoul(bb, 0, 16); sz--;
+    }
+}
+
+void dump_hex(FILE *fp, const uint8_t *b, size_t sz)
+    { while (sz--) fprintf(fp, "%02x", *b++); fputc('\n', fp); }
+
+void trivial_regty_init(union regval *v) { ; }
+void trivial_regty_release(union regval *v) { ; }
+
+/* Define some global variables we shouldn't need.
+ *
+ * Annoyingly, `secnet.h' declares static pointers and initializes them to
+ * point to some external variables.  At `-O0', GCC doesn't optimize these
+ * away, so there's a link-time dependency on these variables.  Define them
+ * here, so that `f25519.c' and `f448.c' can find them.
+ *
+ * (Later GCC has `-Og', which optimizes without making debugging a
+ * nightmare, but I'm not running that version here.  Note that `serpent.c'
+ * doesn't have this problem because it defines its own word load and store
+ * operations to cope with its endian weirdness, whereas the field arithmetic
+ * uses `unaligned.h' which manages to include `secnet.h'.)
+ */
+uint64_t now_global;
+struct timeval tv_now_global;
+
+/* Bletch.  util.c is a mess of layers. */
+int consttime_memeq(const void *s1in, const void *s2in, size_t n)
+{
+    const uint8_t *s1=s1in, *s2=s2in;
+    register volatile uint8_t accumulator=0;
+
+    while (n-- > 0) {
+       accumulator |= (*s1++ ^ *s2++);
+    }
+    accumulator |= accumulator >> 4; /* constant-time             */
+    accumulator |= accumulator >> 2; /*  boolean canonicalisation */
+    accumulator |= accumulator >> 1;
+    accumulator &= 1;
+    accumulator ^= 1;
+    return accumulator;
+}
+
+/*----- Built-in types ----------------------------------------------------*/
+
+/* Signed integer. */
+
+static void parse_int(union regval *v, char *p)
+{
+    char *q;
+
+    errno = 0;
+    v->i = strtol(p, &q, 0);
+    if (*q || errno) bail("bad integer `%s'", p);
+}
+
+static void dump_int(FILE *fp, const union regval *v)
+    { fprintf(fp, "%ld\n", v->i); }
+
+static int eq_int(const union regval *v0, const union regval *v1)
+    { return (v0->i == v1->i); }
+
+const struct regty regty_int = {
+    trivial_regty_init,
+    parse_int,
+    dump_int,
+    eq_int,
+    trivial_regty_release
+};
+
+/* Unsigned integer. */
+
+static void parse_uint(union regval *v, char *p)
+{
+    char *q;
+
+    errno = 0;
+    v->u = strtoul(p, &q, 0);
+    if (*q || errno) bail("bad integer `%s'", p);
+}
+
+static void dump_uint(FILE *fp, const union regval *v)
+    { fprintf(fp, "%lu\n", v->u); }
+
+static int eq_uint(const union regval *v0, const union regval *v1)
+    { return (v0->u == v1->u); }
+
+const struct regty regty_uint = {
+    trivial_regty_init,
+    parse_uint,
+    dump_uint,
+    eq_uint,
+    trivial_regty_release
+};
+
+/* Byte string, as hex. */
+
+void allocate_bytes(union regval *v, size_t sz)
+    { v->bytes.p = xmalloc(sz); v->bytes.sz = sz; }
+
+static void init_bytes(union regval *v) { v->bytes.p = 0; v->bytes.sz = 0; }
+
+static void parse_bytes(union regval *v, char *p)
+{
+    size_t n = strlen(p);
+
+    allocate_bytes(v, n/2);
+    parse_hex(v->bytes.p, v->bytes.sz, p);
+}
+
+static void dump_bytes(FILE *fp, const union regval *v)
+    { dump_hex(fp, v->bytes.p, v->bytes.sz); }
+
+static int eq_bytes(const union regval *v0, const union regval *v1)
+{
+    return v0->bytes.sz == v1->bytes.sz &&
+       !memcmp(v0->bytes.p, v1->bytes.p, v0->bytes.sz);
+}
+
+static void release_bytes(union regval *v) { free(v->bytes.p); }
+
+const struct regty regty_bytes = {
+    init_bytes,
+    parse_bytes,
+    dump_bytes,
+    eq_bytes,
+    release_bytes
+};
+
+/*----- Core test machinery -----------------------------------------------*/
+
+/* Say that a register is `reset' by releasing and then re-initializing it.
+ * While there is a current test, all of that test's registers are
+ * initialized.  The input registers are reset at the end of `check', ready
+ * for the next test to load new values.  The output registers are reset at
+ * the end of `check_test_output', so that a test runner can run a test
+ * multiple times against the same test input, but with different context
+ * data.
+ */
+
+#define REG(rvec, i)                                                   \
+    ((struct reg *)((unsigned char *)state->rvec + (i)*state->regsz))
+
+void check_test_output(struct test_state *state, const struct test *test)
+{
+    const struct regdef *def;
+    struct reg *reg, *in, *out;
+    int ok = 1;
+    int match;
+
+    for (def = test->regs; def->name; def++) {
+       if (def->i >= state->nrout) continue;
+       in = REG(in, def->i); out = REG(out, def->i);
+       if (!def->ty->eq(&in->v, &out->v)) ok = 0;
+    }
+    if (ok)
+       state->win++;
+    else {
+       printf("failed test `%s'\n", test->name);
+       for (def = test->regs; def->name; def++) {
+           in = REG(in, def->i);
+           if (def->i >= state->nrout) {
+               printf("\t   input `%s' = ", def->name);
+               def->ty->dump(stdout, &in->v);
+           } else {
+               out = REG(out, def->i);
+               match = def->ty->eq(&in->v, &out->v);
+               printf("\t%s `%s' = ",
+                      match ? "  output" : "expected", def->name);
+               def->ty->dump(stdout, &in->v);
+               if (!match) {
+                   printf("\tcomputed `%s' = ", def->name);
+                   def->ty->dump(stdout, &out->v);
+               }
+           }
+       }
+       state->lose++;
+    }
+
+    for (def = test->regs; def->name; def++) {
+       if (def->i >= state->nrout) continue;
+       reg = REG(out, def->i);
+       def->ty->release(&reg->v); def->ty->init(&reg->v);
+    }
+}
+
+void run_test(struct test_state *state, const struct test *test)
+{
+    test->fn(state->out, state->in, 0);
+    check_test_output(state, test);
+}
+
+static void check(struct test_state *state, const struct test *test)
+{
+    const struct regdef *def, *miss = 0;
+    struct reg *reg;
+    int any = 0;
+
+    if (!test) return;
+    for (def = test->regs; def->name; def++) {
+       reg = REG(in, def->i);
+       if (reg->f&REGF_LIVE) any = 1;
+       else if (!miss && !(def->f&REGF_OPT)) miss = def;
+    }
+    if (!any) return;
+    if (miss)
+       bail("register `%s' not set in test `%s'", def->name, test->name);
+
+    test->run(state, test);
+
+    for (def = test->regs; def->name; def++) {
+       reg = REG(in, def->i);
+       reg->f = 0; def->ty->release(&reg->v); def->ty->init(&reg->v);
+    }
+}
+
+int run_test_suite(unsigned nrout, unsigned nreg, size_t regsz,
+                  const struct test *tests, FILE *fp)
+{
+    struct linebuf buf = LINEBUF_INIT;
+    struct test_state state[1];
+    const struct test *test;
+    const struct regdef *def;
+    struct reg *reg;
+    char *p;
+    const char *q;
+    int total;
+    size_t n;
+
+    for (test = tests; test->name; test++)
+       for (def = test->regs; def->name; def++)
+           assert(def->i < nreg);
+
+    state->in = xmalloc(nreg*regsz);
+    state->out = xmalloc(nrout*regsz);
+    state->nrout = nrout;
+    state->nreg = nreg;
+    state->regsz = regsz;
+    state->win = state->lose = 0;
+
+    test = 0;
+    lno = 0;
+    while (!read_line(&buf, fp)) {
+       lno++;
+       p = buf.p; n = strlen(buf.p);
+
+       while (isspace((unsigned char)*p)) p++;
+       if (*p == '#') continue;
+       if (!*p) { check(state, test); continue; }
+
+       q = p;
+       while (*p && !isspace((unsigned char)*p)) p++;
+       if (*p) *p++ = 0;
+
+       if (!strcmp(q, "test")) {
+           if (!*p) bail("missing argument");
+           check(state, test);
+           if (test) {
+               for (def = test->regs; def->name; def++) {
+                   def->ty->release(&REG(in, def->i)->v);
+                   if (def->i < state->nrout)
+                       def->ty->release(&REG(out, def->i)->v);
+               }
+           }
+           for (test = tests; test->name; test++)
+               if (!strcmp(p, test->name)) goto found_test;
+           bail("unknown test `%s'", p);
+       found_test:
+           for (def = test->regs; def->name; def++) {
+               reg = REG(in, def->i);
+               reg->f = 0; def->ty->init(&reg->v);
+               if (def->i < state->nrout) {
+                   reg = REG(out, def->i);
+                   reg->f = 0; def->ty->init(&reg->v);
+               }
+           }
+           continue;
+       }
+
+       if (!test) bail("no current test");
+       for (def = test->regs; def->name; def++)
+           if (!strcmp(q, def->name)) goto found_reg;
+       bail("unknown register `%s' in test `%s'", q, test->name);
+    found_reg:
+       reg = REG(in, def->i);
+       if (reg->f&REGF_LIVE) bail("register `%s' already set", def->name);
+       def->ty->parse(&reg->v, p); reg->f |= REGF_LIVE;
+    }
+    check(state, test);
+
+    total = state->win + state->lose;
+    if (!state->lose)
+       printf("PASSED all %d test%s\n", state->win, total == 1 ? "" : "s");
+    else
+       printf("FAILED %d of %d test%s\n", state->lose, total,
+              total == 1 ? "" : "s");
+    return state->lose ? 1 : 0;
+}