rand/rand.c: Rearrange some comparisons to avoid arithmetic overflow.
[catacomb] / rand / rand.c
index 32605ac..0d0650e 100644 (file)
@@ -163,11 +163,14 @@ static int trivial_quick(rand_pool *r) { return (-1); }
 
 #if CPUFAM_X86 || CPUFAM_AMD64
 extern int rand_quick_x86ish_rdrand(rand_pool */*r*/);
+extern int rand_quick_x86ish_rdseed(rand_pool */*r*/);
 #endif
 
 static quick__functype *pick_quick(void)
 {
 #if CPUFAM_X86 || CPUFAM_AMD64
+  DISPATCH_PICK_COND(rand_quick, rand_quick_x86ish_rdseed,
+                    cpu_feature_p(CPUFEAT_X86_RDSEED));
   DISPATCH_PICK_COND(rand_quick, rand_quick_x86ish_rdrand,
                     cpu_feature_p(CPUFEAT_X86_RDRAND));
 #endif
@@ -304,6 +307,8 @@ void rand_gate(rand_pool *r)
   HASH_CTX hc;
   CIPHER_CTX cc;
 
+  STATIC_ASSERT(CIPHER_KEYSZ <= HASH_SZ, "rand cipher keysize too long");
+
   RAND_RESOLVE(r);
   QUICK(r);
 
@@ -319,7 +324,6 @@ void rand_gate(rand_pool *r)
 
   /* --- Now mangle all of the data based on the hash --- */
 
-  assert(CIPHER_KEYSZ <= HASH_SZ);
   CIPHER_INIT(&cc, h, CIPHER_KEYSZ, 0);
   CIPHER_ENCRYPT(&cc, r->pool, r->pool, RAND_POOLSZ);
   CIPHER_ENCRYPT(&cc, r->buf, r->buf, RAND_BUFSZ);
@@ -330,7 +334,7 @@ void rand_gate(rand_pool *r)
   r->o = RAND_SECSZ;
   r->obits += r->ibits;
   if (r->obits > RAND_OBITS) {
-    r->ibits = r->obits - r->ibits;
+    r->ibits = r->obits - RAND_OBITS;
     r->obits = RAND_OBITS;
   } else
     r->ibits = 0;
@@ -355,6 +359,8 @@ void rand_stretch(rand_pool *r)
   HASH_CTX hc;
   CIPHER_CTX cc;
 
+  STATIC_ASSERT(CIPHER_KEYSZ <= HASH_SZ, "rand cipher keysize too long");
+
   RAND_RESOLVE(r);
   QUICK(r);
 
@@ -370,7 +376,6 @@ void rand_stretch(rand_pool *r)
 
   /* --- Now mangle the buffer based on the hash --- */
 
-  assert(CIPHER_KEYSZ <= HASH_SZ);
   CIPHER_INIT(&cc, h, CIPHER_KEYSZ, 0);
   CIPHER_ENCRYPT(&cc, r->buf, r->buf, RAND_BUFSZ);
   BURN(cc);
@@ -408,7 +413,7 @@ void rand_get(rand_pool *r, void *p, size_t sz)
   if (!sz)
     return;
   for (;;) {
-    if (r->o + sz <= RAND_BUFSZ) {
+    if (sz <= RAND_BUFSZ - r->o) {
       memcpy(o, r->buf + r->o, sz);
       r->o += sz;
       break;
@@ -470,11 +475,15 @@ void rand_getgood(rand_pool *r, void *p, size_t sz)
        chunk = r->obits / 8;
     }
 
-    if (chunk + r->o > RAND_BUFSZ)
+    if (chunk <= RAND_BUFSZ - r->o) {
+      memcpy(o, r->buf + r->o, chunk);
+      r->o += chunk;
+    } else {
       chunk = RAND_BUFSZ - r->o;
+      memcpy(o, r->buf + r->o, chunk);
+      rand_stretch(r);
+    }
 
-    memcpy(o, r->buf + r->o, chunk);
-    r->o += chunk;
     r->obits -= chunk * 8;
     o += chunk;
     sz -= chunk;