oaep.c, pkcs1.c: Use official constant-time operations.
authorMark Wooding <mdw@distorted.org.uk>
Mon, 27 May 2013 21:23:58 +0000 (22:23 +0100)
committerMark Wooding <mdw@distorted.org.uk>
Mon, 27 May 2013 22:06:43 +0000 (23:06 +0100)
The logic is a bit more contorted in places, but the security is better.

oaep.c
pkcs1.c

diff --git a/oaep.c b/oaep.c
index d7570de..dfcd41b 100644 (file)
--- a/oaep.c
+++ b/oaep.c
@@ -35,6 +35,7 @@
 #include <mLib/bits.h>
 #include <mLib/dstr.h>
 
+#include "ct.h"
 #include "gcipher.h"
 #include "ghash.h"
 #include "grand.h"
@@ -64,7 +65,7 @@ mp *oaep_encode(mp *d, const void *m, size_t msz, octet *b, size_t sz,
   oaep *o = p;
   size_t hsz = o->ch->hashsz;
   ghash *h;
-  octet *q, *mq, *qq;
+  octet *q, *mq;
   octet *pp;
   gcipher *c;
   size_t n;
@@ -79,7 +80,6 @@ mp *oaep_encode(mp *d, const void *m, size_t msz, octet *b, size_t sz,
   q = b;
   *q++ = 0; sz--;
   mq = q + hsz;
-  qq = q + sz;
   GR_FILL(o->r, q, hsz);
 
   /* --- Fill in the rest of the buffer --- */
@@ -126,17 +126,6 @@ mp *oaep_encode(mp *d, const void *m, size_t msz, octet *b, size_t sz,
  *             PKCS#1 v. 2.0 (RFC2437).
  */
 
-static int memeq(const void *xx, const void *yy, size_t sz)
-{
-  int eq = 1;
-  const octet *x = xx, *y = yy;
-  while (sz) {                         /* Always check every byte */
-    if (*x++ != *y++) eq = 0;
-    sz--;
-  }
-  return (eq);
-}
-
 int oaep_decode(mp *m, octet *b, size_t sz, unsigned long nbits, void *p)
 {
   oaep *o = p;
@@ -144,7 +133,7 @@ int oaep_decode(mp *m, octet *b, size_t sz, unsigned long nbits, void *p)
   ghash *h;
   octet *q, *mq, *qq;
   octet *pp;
-  unsigned bad = 0;
+  uint32 goodp = 1;
   size_t n;
   size_t hsz = o->ch->hashsz;
 
@@ -157,7 +146,7 @@ int oaep_decode(mp *m, octet *b, size_t sz, unsigned long nbits, void *p)
 
   mp_storeb(m, b, sz);
   q = b;
-  bad = *q;
+  goodp &= ct_inteq(*q, 0);
   q++; sz--;
   mq = q + hsz;
   qq = q + sz;
@@ -177,18 +166,19 @@ int oaep_decode(mp *m, octet *b, size_t sz, unsigned long nbits, void *p)
   GH_HASH(h, o->ep, o->epsz);
   GH_DONE(h, q);
   GH_DESTROY(h);
-  bad |= !memeq(q, mq, hsz);
+  goodp &= ct_memeq(q, mq, hsz);
 
   /* --- Now find the start of the actual message --- */
 
   pp = mq + hsz;
   while (*pp == 0 && pp < qq)
     pp++;
-  bad |= (pp >= qq) | (*pp != 1);
+  goodp &= ~ct_intle(qq - b, pp - b);
+  goodp &= ct_inteq(*pp, 1);
   pp++;
   n = qq - pp;
   memmove(q, pp, n);
-  return (bad ? -1 : n);
+  return (goodp ? n : -1);
 }
 
 /*----- That's all, folks -------------------------------------------------*/
diff --git a/pkcs1.c b/pkcs1.c
index 9241c45..47c135f 100644 (file)
--- a/pkcs1.c
+++ b/pkcs1.c
@@ -34,6 +34,7 @@
 #include <mLib/bits.h>
 #include <mLib/dstr.h>
 
+#include "ct.h"
 #include "grand.h"
 #include "rsa.h"
 
@@ -109,24 +110,13 @@ mp *pkcs1_cryptencode(mp *d, const void *m, size_t msz, octet *b, size_t sz,
  *             in PKCS#1 v. 2.0 (RFC2437).
  */
 
-static int memeq(const void *xx, const void *yy, size_t sz)
-{
-  int eq = 1;
-  const octet *x = xx, *y = yy;
-  while (sz) {                         /* Always check every byte */
-    if (*x++ != *y++) eq = 0;
-    sz--;
-  }
-  return (eq);
-}
-
 int pkcs1_cryptdecode(mp *m, octet *b, size_t sz,
                      unsigned long nbits, void *p)
 {
   pkcs1 *pp = p;
   const octet *q, *qq;
   size_t n, i;
-  int bad = 0;
+  uint32 goodp = 1;
 
   /* --- Check the size of the block looks sane --- */
 
@@ -138,26 +128,29 @@ int pkcs1_cryptdecode(mp *m, octet *b, size_t sz,
 
   /* --- Ensure that the block looks OK --- */
 
-  bad |= (*q++ != 0x00 || *q++ != 0x02);
+  goodp &= ct_inteq(*q++, 0);
+  goodp &= ct_inteq(*q++, 2);
 
   /* --- Check the nonzero padding --- */
 
   i = 0;
   while (*q != 0 && q < qq)
     i++, q++;
-  bad |= (i < 8 || qq - q < pp->epsz + 1);
+  goodp &= ct_intle(8, i);
+  goodp &= ~ct_intle(qq - q, pp->epsz + 1);
   q++;
 
   /* --- Check the encoding parameters --- */
 
-  bad |= (pp->ep && !memeq(bad ? b : q, pp->ep, pp->epsz));
+  if (pp->ep)
+    goodp &= ct_memeq(b + ct_pick(goodp, 0, q - b), pp->ep, pp->epsz);
   q += pp->epsz;
 
   /* --- Done --- */
 
   n = qq - q;
-  memmove(b, bad ? b + 1 : q, n);
-  return (bad ? -1 : n);
+  memmove(b, b + ct_pick(goodp, 1, q - b), n);
+  return (goodp ? n : -1);
 }
 
 /* --- @pkcs1_sigencode@ --- *