Rearrange the file tree.
[u/mdw/catacomb] / pub / rsa-test.c
diff --git a/pub/rsa-test.c b/pub/rsa-test.c
new file mode 100644 (file)
index 0000000..34b2a1f
--- /dev/null
@@ -0,0 +1,512 @@
+/* -*-c-*-
+ *
+ * Testing RSA padding operations
+ *
+ * (c) 2004 Straylight/Edgeware
+ */
+
+/*----- Licensing notice --------------------------------------------------*
+ *
+ * This file is part of Catacomb.
+ *
+ * Catacomb is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU Library General Public License as
+ * published by the Free Software Foundation; either version 2 of the
+ * License, or (at your option) any later version.
+ *
+ * Catacomb is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU Library General Public License for more details.
+ *
+ * You should have received a copy of the GNU Library General Public
+ * License along with Catacomb; if not, write to the Free
+ * Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
+ * MA 02111-1307, USA.
+ */
+
+/*----- Header files ------------------------------------------------------*/
+
+#include "fibrand.h"
+#include "rsa.h"
+
+/*----- Main code ---------------------------------------------------------*/
+
+/* --- @tencpad@ --- *
+ *
+ * Arguments:  @int nbits@ = number of padding bits; handed to @e@ and
+ *                     used to size its scratch buffer
+ *             @dstr *p@ = message to be padded
+ *             @int rc@ = zero if padding should succeed, nonzero if it
+ *                     should fail
+ *             @mp *c@ = expected padded value; dropped before returning
+ *             @const char *ename@ = name of the scheme, for diagnostics
+ *             @dstr *eparam@ = encoding parameters to dump, or null
+ *             @rsa_pad *e@ = padding function under test
+ *             @void *earg@ = context pointer handed to @e@
+ *
+ * Returns:    Nonzero if the test passed, zero if it failed.
+ *
+ * Use:        Runs a padding function over a scratch buffer of
+ *             @(nbits + 7)/8@ bytes and checks that it succeeds or
+ *             fails as expected and, on success, that it produces the
+ *             expected value.  A null result is treated as failure, so
+ *             the result is checked before it is printed or dropped.
+ */
+
+static int tencpad(int nbits,
+                  dstr *p, int rc, mp *c,
+                  const char *ename, dstr *eparam, rsa_pad *e, void *earg)
+{
+  size_t n = (nbits + 7)/8;
+  void *q = xmalloc(n);
+  mp *d;
+  int ok = 1;
+
+  d = e(MP_NEW, p->buf, p->len, q, n, nbits, earg);
+
+  /* --- Fail if success/failure disagrees with @rc@, or if the padding
+   *     succeeded as expected but gave the wrong answer --- */
+
+  if (!d == !rc || (!rc && !MP_EQ(d, c))) {
+    ok = 0;
+    fprintf(stderr, "*** %s padding failed!\n", ename);
+    fprintf(stderr, "*** padding bits = %d\n", nbits);
+    if (eparam) {
+      fprintf(stderr, "*** encoding parameters = ");
+      type_hex.dump(eparam, stderr);
+      fputc('\n', stderr);
+    }
+    fprintf(stderr, "*** input message = "); type_hex.dump(p, stderr);
+    if (rc)
+      fprintf(stderr, "\n*** expected failure\n");
+    else {
+      MP_EPRINTX("\n*** expected", c);
+
+      /* --- We get here with a null @d@ if padding failed unexpectedly;
+       *     don't hand a null pointer to the printing macro --- */
+
+      if (d) {
+       MP_EPRINTX("*** computed", d);
+      } else {
+       fprintf(stderr, "*** computed failure\n");
+      }
+    }
+  }
+
+  /* --- @d@ is null whenever padding failed, expected or not --- */
+
+  if (d)
+    mp_drop(d);
+  mp_drop(c);
+  xfree(q);
+  assert(mparena_count(MPARENA_GLOBAL) == 0);
+  return (ok);
+}
+/* The `pad'/`sig' tests expand (via FUNCS) to calls to @tsigpad@; signature padding functions are driven exactly like encryption ones. */
+#define tsigpad tencpad
+
+/* --- @DSTR_EQ@ --- *
+ *
+ * Arguments:  @x@, @y@ = pointers to dynamic strings (anything with
+ *                     @len@ and @buf@ members)
+ *
+ * Returns:    Nonzero if the two strings have the same length and
+ *             contents.
+ *
+ * Use:        Compares two strings for equality.  Empty strings compare
+ *             equal without calling @memcmp@: an empty string's buffer
+ *             may be null, and passing a null pointer to @memcmp@ is
+ *             undefined even when the length is zero.
+ */
+
+#define DSTR_EQ(x, y)                                                  \
+  ((x)->len == (y)->len &&                                             \
+   (!(x)->len || !memcmp((x)->buf, (y)->buf, (x)->len)))
+
+/* --- @tdecpad@ --- *
+ *
+ * Arguments:  @int nbits@ = number of padding bits; handed to @e@ and
+ *                     used to size the output buffer
+ *             @mp *c@ = padded value to be decoded; dropped before
+ *                     returning
+ *             @int rc@ = expected return value from @e@; negative if
+ *                     decoding should fail
+ *             @dstr *p@ = expected recovered message
+ *             @const char *ename@ = name of the scheme, for diagnostics
+ *             @dstr *eparam@ = encoding parameters to dump, or null
+ *             @rsa_decunpad *e@ = unpadding function under test
+ *             @void *earg@ = context pointer handed to @e@
+ *
+ * Returns:    Nonzero if the test passed, zero if it failed.
+ *
+ * Use:        Runs an encryption-unpadding function into a buffer of
+ *             @(nbits + 7)/8@ bytes and checks its return value and,
+ *             when success is expected, the recovered message.
+ */
+
+static int tdecpad(int nbits,
+                  mp *c, int rc, dstr *p,
+                  const char *ename, dstr *eparam,
+                  rsa_decunpad *e, void *earg)
+{
+  dstr out = DSTR_INIT;
+  int cap = (nbits + 7)/8;
+  int got, ok;
+
+  /* --- Unpad, and note how much output there was --- */
+
+  dstr_ensure(&out, cap);
+  got = e(c, (octet *)out.buf, cap, nbits, earg);
+  if (got >= 0) out.len += got;
+
+  /* --- The return value must match, and so must the message if we
+   *     expected to get one --- */
+
+  ok = got == rc && (rc < 0 || DSTR_EQ(&out, p));
+
+  if (!ok) {
+    fprintf(stderr, "*** %s encryption unpadding failed!\n", ename);
+    fprintf(stderr, "*** padding bits = %d\n", nbits);
+    if (eparam) {
+      fputs("*** encoding parameters = ", stderr);
+      type_hex.dump(eparam, stderr);
+      fputc('\n', stderr);
+    }
+    MP_EPRINTX("*** input", c);
+    if (rc >= 0) {
+      fprintf(stderr, "*** expected: %d = ", rc);
+      type_hex.dump(p, stderr);
+      fprintf(stderr, "\n*** computed: %d = ", got);
+      type_hex.dump(&out, stderr);
+      fputs("\n", stderr);
+    } else
+      fputs("*** expected failure\n", stderr);
+  }
+
+  mp_drop(c);
+  dstr_destroy(&out);
+  assert(mparena_count(MPARENA_GLOBAL) == 0);
+  return ok;
+}
+
+/* --- @tvrfpad@ --- *
+ *
+ * Arguments:  @int nbits@ = number of padding bits; handed to @e@ and
+ *                     used to size the output buffer
+ *             @mp *c@ = padded signature value; dropped before returning
+ *             @dstr *m@ = message, passed to @e@ as a null pointer if
+ *                     it is empty
+ *             @int rc@ = expected return value from @e@; negative if
+ *                     verification should fail
+ *             @dstr *p@ = expected recovered data
+ *             @const char *ename@ = name of the scheme, for diagnostics
+ *             @dstr *eparam@ = encoding parameters to dump, or null
+ *             @rsa_vrfunpad *e@ = unpadding function under test
+ *             @void *earg@ = context pointer handed to @e@
+ *
+ * Returns:    Nonzero if the test passed, zero if it failed.
+ *
+ * Use:        Runs a signature-unpadding function into a buffer of
+ *             @(nbits + 7)/8@ bytes and checks its return value and,
+ *             when success is expected, the recovered data.
+ */
+
+static int tvrfpad(int nbits,
+                  mp *c, dstr *m, int rc, dstr *p,
+                  const char *ename, dstr *eparam,
+                  rsa_vrfunpad *e, void *earg)
+{
+  dstr out = DSTR_INIT;
+  int cap = (nbits + 7)/8;
+  octet *msg;
+  int got, ok;
+
+  /* --- Unpad, and note how much output there was --- */
+
+  dstr_ensure(&out, cap);
+  msg = m->len ? (octet *)m->buf : 0;
+  got = e(c, msg, m->len, (octet *)out.buf, cap, nbits, earg);
+  if (got >= 0) out.len += got;
+
+  /* --- The return value must match, and so must the output if we
+   *     expected to get any --- */
+
+  ok = got == rc && (rc < 0 || DSTR_EQ(&out, p));
+
+  if (!ok) {
+    fprintf(stderr, "*** %s signature unpadding failed!\n", ename);
+    fprintf(stderr, "*** padding bits = %d\n", nbits);
+    MP_EPRINTX("*** input", c);
+    if (eparam) {
+      fputs("*** encoding parameters = ", stderr);
+      type_hex.dump(eparam, stderr);
+      fputc('\n', stderr);
+    }
+    fputs("*** message = ", stderr);
+    type_hex.dump(m, stderr);
+    if (rc >= 0) {
+      fprintf(stderr, "\n*** expected = %d: ", rc);
+      type_hex.dump(p, stderr);
+      fprintf(stderr, "\n*** computed = %d: ", got);
+      type_hex.dump(&out, stderr);
+      fputs("\n", stderr);
+    } else
+      fputs("\n*** expected failure\n", stderr);
+  }
+
+  mp_drop(c);
+  dstr_destroy(&out);
+  assert(mparena_count(MPARENA_GLOBAL) == 0);
+  return ok;
+}
+
+/* --- @tencpub@ --- *
+ *
+ * Arguments:  @rsa_pub *rp@ = public key; its contents are freed before
+ *                     returning
+ *             @dstr *p@ = message to be encrypted
+ *             @int rc@ = zero if encryption should succeed, nonzero if
+ *                     it should fail
+ *             @mp *c@ = expected ciphertext; dropped before returning
+ *             @const char *ename@ = name of the scheme, for diagnostics
+ *             @dstr *eparam@ = encoding parameters to dump, or null
+ *             @rsa_pad *e@ = padding function to use
+ *             @void *earg@ = context pointer handed on with @e@
+ *
+ * Returns:    Nonzero if the test passed, zero if it failed.
+ *
+ * Use:        Encrypts a message with @rsa_encrypt@ and checks that it
+ *             succeeds or fails as expected and, on success, that it
+ *             produces the expected ciphertext.  A null result is
+ *             treated as failure, so the result is checked before it is
+ *             printed or dropped.
+ */
+
+static int tencpub(rsa_pub *rp,
+                  dstr *p, int rc, mp *c,
+                  const char *ename, dstr *eparam, rsa_pad *e, void *earg)
+{
+  mp *d;
+  rsa_pubctx rpc;
+  int ok = 1;
+
+  rsa_pubcreate(&rpc, rp);
+  d = rsa_encrypt(&rpc, MP_NEW, p->buf, p->len, e, earg);
+
+  /* --- Fail if success/failure disagrees with @rc@, or if encryption
+   *     succeeded as expected but gave the wrong answer --- */
+
+  if (!d == !rc || (!rc && !MP_EQ(d, c))) {
+    ok = 0;
+    fprintf(stderr, "*** encrypt with %s padding failed!\n", ename);
+    MP_EPRINTX("*** key.n", rp->n);
+    MP_EPRINTX("*** key.e", rp->e);
+    if (eparam) {
+      fprintf(stderr, "*** encoding parameters = ");
+      type_hex.dump(eparam, stderr);
+      fputc('\n', stderr);
+    }
+    fprintf(stderr, "*** input message = "); type_hex.dump(p, stderr);
+    if (rc)
+      fprintf(stderr, "\n*** expected failure\n");
+    else {
+      MP_EPRINTX("\n*** expected", c);
+
+      /* --- We get here with a null @d@ if encryption failed
+       *     unexpectedly; don't hand a null pointer to the printing
+       *     macro --- */
+
+      if (d) {
+       MP_EPRINTX("*** computed", d);
+      } else {
+       fprintf(stderr, "*** computed failure\n");
+      }
+    }
+  }
+  rsa_pubdestroy(&rpc);
+  rsa_pubfree(rp);
+
+  /* --- @d@ is null whenever encryption failed, expected or not --- */
+
+  if (d)
+    mp_drop(d);
+  mp_drop(c);
+  assert(mparena_count(MPARENA_GLOBAL) == 0);
+  return (ok);
+}
+
+/* --- @tsigpriv@ --- *
+ *
+ * Arguments:  @rsa_priv *rp@ = private key; its contents are freed
+ *                     before returning
+ *             @dstr *p@ = message to be signed
+ *             @int rc@ = zero if signing should succeed, nonzero if it
+ *                     should fail
+ *             @mp *c@ = expected signature; dropped before returning
+ *             @const char *ename@ = name of the scheme, for diagnostics
+ *             @dstr *eparam@ = encoding parameters to dump, or null
+ *             @rsa_pad *e@ = padding function to use
+ *             @void *earg@ = context pointer handed on with @e@
+ *
+ * Returns:    Nonzero if the test passed, zero if it failed.
+ *
+ * Use:        Signs a message with @rsa_sign@, using a private-key
+ *             context set up with a fresh @fibrand@ generator, and
+ *             checks that it succeeds or fails as expected and, on
+ *             success, that it produces the expected signature.  A null
+ *             result is treated as failure, so the result is checked
+ *             before it is printed or dropped.
+ */
+
+static int tsigpriv(rsa_priv *rp,
+                   dstr *p, int rc, mp *c,
+                   const char *ename, dstr *eparam, rsa_pad *e, void *earg)
+{
+  mp *d;
+  grand *r = fibrand_create(0);
+  rsa_privctx rpc;
+  int ok = 1;
+
+  rsa_privcreate(&rpc, rp, r);
+  d = rsa_sign(&rpc, MP_NEW, p->buf, p->len, e, earg);
+
+  /* --- Fail if success/failure disagrees with @rc@, or if signing
+   *     succeeded as expected but gave the wrong answer --- */
+
+  if (!d == !rc || (!rc && !MP_EQ(d, c))) {
+    ok = 0;
+    fprintf(stderr, "*** sign with %s padding failed!\n", ename);
+    MP_EPRINTX("*** key.n", rp->n);
+    MP_EPRINTX("*** key.d", rp->d);
+    MP_EPRINTX("*** key.e", rp->e);
+    if (eparam) {
+      fprintf(stderr, "*** encoding parameters = ");
+      type_hex.dump(eparam, stderr);
+      fputc('\n', stderr);
+    }
+    fprintf(stderr, "*** input message = "); type_hex.dump(p, stderr);
+    if (rc)
+      fprintf(stderr, "\n*** expected failure\n");
+    else {
+      MP_EPRINTX("\n*** expected", c);
+
+      /* --- No leading newline here: the `expected' line is already
+       *     complete, as in the other encoding tests.  We get here with
+       *     a null @d@ if signing failed unexpectedly; don't hand a
+       *     null pointer to the printing macro --- */
+
+      if (d) {
+       MP_EPRINTX("*** computed", d);
+      } else {
+       fprintf(stderr, "*** computed failure\n");
+      }
+    }
+  }
+  rsa_privdestroy(&rpc);
+  rsa_privfree(rp);
+
+  /* --- @d@ is null whenever signing failed, expected or not --- */
+
+  if (d)
+    mp_drop(d);
+  mp_drop(c);
+  GR_DESTROY(r);
+  assert(mparena_count(MPARENA_GLOBAL) == 0);
+  return (ok);
+}
+
+/* --- @tdecpriv@ --- *
+ *
+ * Arguments:  @rsa_priv *rp@ = private key; its contents are freed
+ *                     before returning
+ *             @mp *c@ = ciphertext; dropped before returning
+ *             @int rc@ = expected return value from @rsa_decrypt@;
+ *                     negative if decryption should fail
+ *             @dstr *p@ = expected plaintext
+ *             @const char *ename@ = name of the scheme, for diagnostics
+ *             @dstr *eparam@ = encoding parameters to dump, or null
+ *             @rsa_decunpad *e@ = unpadding function to use
+ *             @void *earg@ = context pointer handed on with @e@
+ *
+ * Returns:    Nonzero if the test passed, zero if it failed.
+ *
+ * Use:        Decrypts a ciphertext with @rsa_decrypt@, using a
+ *             private-key context set up with a fresh @fibrand@
+ *             generator, and checks the return value and, when success
+ *             is expected, the recovered plaintext.
+ */
+
+static int tdecpriv(rsa_priv *rp,
+                   mp *c, int rc, dstr *p,
+                   const char *ename, dstr *eparam,
+                   rsa_decunpad *e, void *earg)
+{
+  rsa_privctx ctx;
+  dstr out = DSTR_INIT;
+  grand *rng = fibrand_create(0);
+  int got, ok;
+
+  rsa_privcreate(&ctx, rp, rng);
+  got = rsa_decrypt(&ctx, c, &out, e, earg);
+
+  /* --- The return value must match, and so must the plaintext if we
+   *     expected to get one --- */
+
+  ok = got == rc && (rc < 0 || DSTR_EQ(&out, p));
+
+  if (!ok) {
+    fprintf(stderr, "*** decryption with %s padding failed!\n", ename);
+    MP_EPRINTX("*** key.n", rp->n);
+    MP_EPRINTX("*** key.d", rp->d);
+    MP_EPRINTX("*** key.e", rp->e);
+    if (eparam) {
+      fputs("*** encoding parameters = ", stderr);
+      type_hex.dump(eparam, stderr);
+      fputc('\n', stderr);
+    }
+    MP_EPRINTX("*** input", c);
+    if (rc >= 0) {
+      fprintf(stderr, "*** expected = %d: ", rc);
+      type_hex.dump(p, stderr);
+      fprintf(stderr, "\n*** computed = %d: ", got);
+      type_hex.dump(&out, stderr);
+      fputs("\n", stderr);
+    } else
+      fputs("*** expected failure\n", stderr);
+  }
+
+  rsa_privdestroy(&ctx);
+  rsa_privfree(rp);
+  mp_drop(c);
+  dstr_destroy(&out);
+  GR_DESTROY(rng);
+  assert(mparena_count(MPARENA_GLOBAL) == 0);
+  return ok;
+}
+
+/* --- @tvrfpub@ --- *
+ *
+ * Arguments:  @rsa_pub *rp@ = public key; its contents are freed before
+ *                     returning
+ *             @mp *c@ = signature; dropped before returning
+ *             @dstr *m@ = message, passed on as a null pointer if it is
+ *                     empty
+ *             @int rc@ = expected return value from @rsa_verify@;
+ *                     negative if verification should fail
+ *             @dstr *p@ = expected recovered data
+ *             @const char *ename@ = name of the scheme, for diagnostics
+ *             @dstr *eparam@ = encoding parameters to dump, or null
+ *             @rsa_vrfunpad *e@ = unpadding function to use
+ *             @void *earg@ = context pointer handed on with @e@
+ *
+ * Returns:    Nonzero if the test passed, zero if it failed.
+ *
+ * Use:        Verifies a signature with @rsa_verify@ and checks the
+ *             return value and, when success is expected, the recovered
+ *             data.
+ */
+
+static int tvrfpub(rsa_pub *rp,
+                  mp *c, dstr *m, int rc, dstr *p,
+                  const char *ename, dstr *eparam,
+                  rsa_vrfunpad *e, void *earg)
+{
+  rsa_pubctx ctx;
+  dstr out = DSTR_INIT;
+  int got, ok;
+
+  rsa_pubcreate(&ctx, rp);
+  got = rsa_verify(&ctx, c, m->len ? m->buf : 0, m->len, &out, e, earg);
+
+  /* --- The return value must match, and so must the output if we
+   *     expected to get any --- */
+
+  ok = got == rc && (rc < 0 || DSTR_EQ(&out, p));
+
+  if (!ok) {
+    fprintf(stderr, "*** verification with %s padding failed!\n", ename);
+    MP_EPRINTX("*** key.n", rp->n);
+    MP_EPRINTX("*** key.e", rp->e);
+    if (eparam) {
+      fputs("*** encoding parameters = ", stderr);
+      type_hex.dump(eparam, stderr);
+      fputc('\n', stderr);
+    }
+    MP_EPRINTX("*** input", c);
+    fputs("*** message = ", stderr);
+    type_hex.dump(m, stderr);
+    if (rc >= 0) {
+      fprintf(stderr, "\n*** expected = %d: ", rc);
+      type_hex.dump(p, stderr);
+      fprintf(stderr, "\n*** computed = %d: ", got);
+      type_hex.dump(&out, stderr);
+      fputs("\n", stderr);
+    } else
+      fputs("\n*** expected failure\n", stderr);
+  }
+
+  rsa_pubdestroy(&ctx);
+  rsa_pubfree(rp);
+  mp_drop(c);
+  dstr_destroy(&out);
+  assert(mparena_count(MPARENA_GLOBAL) == 0);
+  return ok;
+}
+
+/*----- Deep magic --------------------------------------------------------*
+ *
+ * Wahey!  Whacko macro programming on curry and lager.  There's nothing like
+ * it.
+ */
+/* Each test is assembled from three selectors -- a key source, an
+ * operation and an encoding -- and each selector <x> supplies four
+ * fragments: DECL_<x> (local declarations), FUNC_<x> (statements which
+ * fetch its fields from successive entries of the test vector `v'),
+ * ARG_<x> (its share of the driver's argument list) and TAB_<x> (the
+ * field types, in the order FUNC_<x> consumes them).  The key-source
+ * fragments come first in every expansion, so their ARG_ and TAB_ lists
+ * end with a comma.
+ *
+ * priv: an RSA private key, read as n, e, d and then handed to
+ * rsa_recover (presumably to fill in the remaining components of the
+ * zero-initialized key -- TODO confirm against rsa.h).
+ */
+#define DECL_priv                                                      \
+  rsa_priv rp = { 0 };
+#define FUNC_priv                                                      \
+  rp.n = *(mp **)v++->buf;                                             \
+  rp.e = *(mp **)v++->buf;                                             \
+  rp.d = *(mp **)v++->buf;                                             \
+  rsa_recover(&rp);
+#define ARG_priv                                                       \
+  &rp,
+#define TAB_priv                                                       \
+  &type_mp, &type_mp, &type_mp,
+/* pub: an RSA public key, read as n then e. */
+#define DECL_pub                                                       \
+  rsa_pub rp;
+#define FUNC_pub                                                       \
+  rp.n = *(mp **)v++->buf;                                             \
+  rp.e = *(mp **)v++->buf;
+#define ARG_pub                                                                \
+  &rp,
+#define TAB_pub                                                                \
+  &type_mp, &type_mp,
+/* pad: no key at all, just the number of padding bits, read as an int. */
+#define DECL_pad                                                       \
+  int nbits;
+#define FUNC_pad                                                       \
+  nbits = *(int *)v++->buf;
+#define ARG_pad                                                                \
+  nbits,
+#define TAB_pad                                                                \
+  &type_int,
+/* Operation fragments (same DECL_/FUNC_/ARG_/TAB_ scheme as the key fragments).  enc: the drivers' p (hex), rc (int) and c (mp) arguments, read in that order. */
+#define DECL_enc                                                       \
+  dstr *p;                                                             \
+  int rc;                                                              \
+  mp *c;
+#define FUNC_enc                                                       \
+  p = v++;                                                             \
+  rc = *(int *)v++->buf;                                               \
+  c = *(mp **)v++->buf;
+#define ARG_enc                                                                \
+  p, rc, c,
+#define TAB_enc                                                                \
+  &type_hex, &type_int, &type_mp,
+/* sig: exactly the same fields as enc. */
+#define DECL_sig DECL_enc
+#define FUNC_sig FUNC_enc
+#define ARG_sig ARG_enc
+#define TAB_sig TAB_enc
+/* dec: the same three values as enc, but in the order c (mp), rc (int), p (hex). */
+#define DECL_dec                                                       \
+  mp *c;                                                               \
+  int rc;                                                              \
+  dstr *p;
+#define FUNC_dec                                                       \
+  c = *(mp **)v++->buf;                                                        \
+  rc = *(int *)v++->buf;                                               \
+  p = v++;
+#define ARG_dec                                                                \
+  c, rc, p,
+#define TAB_dec                                                                \
+  &type_mp, &type_int, &type_hex,
+/* vrf: c (mp), the message m (hex), rc (int) and p (hex), in that order. */
+#define DECL_vrf                                                       \
+  mp *c;                                                               \
+  dstr *m;                                                             \
+  int rc;                                                              \
+  dstr *p;
+#define FUNC_vrf                                                       \
+  c = *(mp **)v++->buf;                                                        \
+  m = v++;                                                             \
+  rc = *(int *)v++->buf;                                               \
+  p = v++;
+#define ARG_vrf                                                                \
+  c, m, rc, p,
+#define TAB_vrf                                                                \
+  &type_mp, &type_hex, &type_int, &type_hex,
+/* Encoding fragments; these come last, so ARG_ and TAB_ have no trailing comma.  ARG_ supplies the drivers' ename, eparam, e and earg.  p1enc: a pkcs1 context using the generator fib, with encoding parameters taken from one hex field. */
+#define DECL_p1enc                                                     \
+  pkcs1 p1;                                                            \
+  dstr *ep;
+#define FUNC_p1enc                                                     \
+  p1.r = fib;                                                          \
+  ep = v++;                                                            \
+  p1.ep = ep->buf;                                                     \
+  p1.epsz = ep->len;
+#define ARG_p1enc                                                      \
+  "pkcs1", ep, pkcs1_cryptencode, &p1
+#define TAB_p1enc                                                      \
+  &type_hex
+/* p1sig: as p1enc, but passing pkcs1_sigencode. */
+#define DECL_p1sig DECL_p1enc
+#define FUNC_p1sig FUNC_p1enc
+#define ARG_p1sig                                                      \
+  "pkcs1", ep, pkcs1_sigencode, &p1
+#define TAB_p1sig TAB_p1enc
+/* p1dec: as p1enc, but passing pkcs1_cryptdecode. */
+#define DECL_p1dec DECL_p1enc
+#define FUNC_p1dec FUNC_p1enc
+#define ARG_p1dec                                                      \
+  "pkcs1", ep, pkcs1_cryptdecode, &p1
+#define TAB_p1dec TAB_p1enc
+/* p1vrf: as p1enc, but passing pkcs1_sigdecode. */
+#define DECL_p1vrf DECL_p1enc
+#define FUNC_p1vrf FUNC_p1enc
+#define ARG_p1vrf                                                      \
+  "pkcs1", ep, pkcs1_sigdecode, &p1
+#define TAB_p1vrf TAB_p1enc
+/* oaepenc: an oaep context using fib; two string fields are looked up with gcipher_byname and ghash_byname (results not checked here), then one hex field gives the encoding parameters. */
+#define DECL_oaepenc                                                   \
+  oaep o;                                                              \
+  dstr *ep;
+#define FUNC_oaepenc                                                   \
+  o.r = fib;                                                           \
+  o.cc = gcipher_byname(v++->buf);                                     \
+  o.ch = ghash_byname(v++->buf);                                       \
+  ep = v++;                                                            \
+  o.ep = ep->buf;                                                      \
+  o.epsz = ep->len;
+#define ARG_oaepenc                                                    \
+  "oaep", ep, oaep_encode, &o
+#define TAB_oaepenc                                                    \
+  &type_string, &type_string, &type_hex
+/* oaepdec: as oaepenc, but passing oaep_decode. */
+#define DECL_oaepdec DECL_oaepenc
+#define FUNC_oaepdec FUNC_oaepenc
+#define ARG_oaepdec                                                    \
+  "oaep", ep, oaep_decode, &o
+#define TAB_oaepdec TAB_oaepenc
+/* psssig: a pss context using fib; cipher and hash looked up by name as for oaep, then an int stored in ssz (presumably the salt size -- TODO confirm).  No encoding parameters: eparam is passed as 0. */
+#define DECL_psssig                                                    \
+  pss pp;
+#define FUNC_psssig                                                    \
+  pp.r = fib;                                                          \
+  pp.cc = gcipher_byname(v++->buf);                                    \
+  pp.ch = ghash_byname(v++->buf);                                      \
+  pp.ssz = *(int *)v++->buf;
+#define ARG_psssig                                                     \
+  "pss", 0, pss_encode, &pp
+#define TAB_psssig                                                     \
+  &type_string, &type_string, &type_int
+/* pssvrf: as psssig, but passing pss_decode. */
+#define DECL_pssvrf DECL_psssig
+#define FUNC_pssvrf FUNC_psssig
+#define ARG_pssvrf                                                     \
+  "pss", 0, pss_decode, &pp
+#define TAB_pssvrf TAB_psssig
+/* The master list of tests.  Each DO(key, op, enc) entry names a key source, an operation and an encoding; the list is expanded once with FUNCS to define the test functions and once with TAB to build the test table. */
+#define TESTS(DO)                                                      \
+  DO(pad, enc, p1enc)                                                  \
+  DO(pad, dec, p1dec)                                                  \
+  DO(pad, sig, p1sig)                                                  \
+  DO(pad, vrf, p1vrf)                                                  \
+  DO(pub, enc, p1enc)                                                  \
+  DO(priv, dec, p1dec)                                                 \
+  DO(priv, sig, p1sig)                                                 \
+  DO(pub, vrf, p1vrf)                                                  \
+  DO(pad, enc, oaepenc)                                                        \
+  DO(pad, dec, oaepdec)                                                        \
+  DO(pub, enc, oaepenc)                                                        \
+  DO(priv, dec, oaepdec)                                               \
+  DO(pad, sig, psssig)                                                 \
+  DO(pad, vrf, pssvrf)                                                 \
+  DO(priv, sig, psssig)                                                        \
+  DO(pub, vrf, pssvrf)
+/* Defines the test function t_<key>_<enc>(v): declares the locals for all
+ * three selectors, invokes fib's GRAND_SEEDINT operation with the fixed
+ * value 14 (so every test starts from the same generator seed,
+ * presumably), unpacks the fields of `v' in key, op, enc order, and
+ * calls the driver t<op><key> -- e.g. tencpad, tdecpriv -- with the
+ * concatenated argument fragments.
+ */
+#define FUNCS(key, op, enc)                                            \
+  int t_##key##_##enc(dstr *v)                                         \
+  {                                                                    \
+    DECL_##key                                                         \
+    DECL_##op                                                          \
+    DECL_##enc                                                         \
+    fib->ops->misc(fib, GRAND_SEEDINT, 14);                            \
+    FUNC_##key                                                         \
+    FUNC_##op                                                          \
+    FUNC_##enc                                                         \
+    return (t##op##key(ARG_##key ARG_##op ARG_##enc));                 \
+  }
+/* Builds one test-table entry: the name "<enc>-<key>", the function t_<key>_<enc>, and the field types in key, op, enc order -- matching the order in which FUNCS unpacks them. */
+#define TAB(key, op, enc)                                              \
+  { #enc "-" #key, t_##key##_##enc, { TAB_##key TAB_##op TAB_##enc } },
+/* Generator shared by all tests: created in main, stored into the encoding contexts by the FUNC_<enc> fragments, and given the GRAND_SEEDINT operation by each test function. */
+static grand *fib;
+/* Define one test function for each entry in the master list. */
+TESTS(FUNCS)
+/* The table handed to test_run: one entry for each test, ended by a zero entry. */
+static const test_chunk tests[] = {
+  TESTS(TAB)
+  { 0 }
+};
+/* Test program entry point: sets up the shared generator, runs the tests in the table against the vector file SRCDIR "/t/rsa", and tears the generator down again. */
+int main(int argc, char *argv[])
+{
+  sub_init();
+  fib = fibrand_create(0); /* must exist before any test function runs */
+  test_run(argc, argv, tests, SRCDIR "/t/rsa");
+  GR_DESTROY(fib);
+  return (0);
+}
+
+/*----- That's all, folks -------------------------------------------------*/