Add support for DSA authentication in SSH2, following clever ideas
[u/mdw/putty] / sshprime.c
index e6d3b3f..8bbfb87 100644 (file)
@@ -2,6 +2,7 @@
  * Prime generation.
  */
 
+#include <assert.h>
 #include "ssh.h"
 
 /*
@@ -1182,14 +1183,25 @@ static const unsigned short primes[] = {
 #define NPRIMES (sizeof(primes) / sizeof(*primes))
 
 /*
- * Generate a prime. We arrange to select a prime with the property
- * (prime % modulus) != residue (to speed up use in RSA).
+ * Generate a prime. We can deal with various extra properties of
+ * the prime:
+ * 
+ *  - to speed up use in RSA, we can arrange to select a prime with
+ *    the property (prime % modulus) != residue.
+ * 
+ *  - for use in DSA, we can arrange to select a prime which is one
+ *    more than a multiple of a dirty great bignum. In this case
+ *    `bits' gives the size of the factor by which we _multiply_
+ *    that bignum, rather than the size of the whole number.
  */
-Bignum primegen(int bits, int modulus, int residue,
+Bignum primegen(int bits, int modulus, int residue, Bignum factor,
                int phase, progfn_t pfn, void *pfnparam)
 {
     int i, k, v, byte, bitsleft, check, checks;
-    unsigned long delta, moduli[NPRIMES + 1], residues[NPRIMES + 1];
+    unsigned long delta;
+    unsigned long moduli[NPRIMES + 1];
+    unsigned long residues[NPRIMES + 1];
+    unsigned long multipliers[NPRIMES + 1];
     Bignum p, pm1, q, wqp, wqp2;
     int progress = 0;
 
@@ -1198,15 +1210,18 @@ Bignum primegen(int bits, int modulus, int residue,
 
   STARTOVER:
 
-    pfn(pfnparam, phase, ++progress);
+    pfn(pfnparam, PROGFN_PROGRESS, phase, ++progress);
 
     /*
      * Generate a k-bit random number with top and bottom bits set.
+     * Alternatively, if `factor' is nonzero, generate a k-bit
+     * random number with the top bit set and the bottom bit clear,
+     * multiply it by `factor', and add one.
      */
     p = bn_power_2(bits - 1);
     for (i = 0; i < bits; i++) {
        if (i == 0 || i == bits - 1)
-           v = 1;
+           v = (i != 0 || !factor) ? 1 : 0;
        else {
            if (bitsleft <= 0)
                bitsleft = 8, byte = random_byte();
@@ -1216,14 +1231,26 @@ Bignum primegen(int bits, int modulus, int residue,
        }
        bignum_set_bit(p, i, v);
     }
+    if (factor) {
+       Bignum tmp = p;
+       p = bigmul(tmp, factor);
+       freebn(tmp);
+       assert(bignum_bit(p, 0) == 0);
+       bignum_set_bit(p, 0, 1);
+    }
 
     /*
      * Ensure this random number is coprime to the first few
-     * primes, by repeatedly adding 2 to it until it is.
+     * primes, by repeatedly adding either 2 or 2*factor to it
+     * until it is.
      */
     for (i = 0; i < NPRIMES; i++) {
        moduli[i] = primes[i];
        residues[i] = bignum_mod_short(p, primes[i]);
+       if (factor)
+           multipliers[i] = bignum_mod_short(factor, primes[i]);
+       else
+           multipliers[i] = 1;
     }
     moduli[NPRIMES] = modulus;
     residues[NPRIMES] = (bignum_mod_short(p, (unsigned short) modulus)
@@ -1231,11 +1258,11 @@ Bignum primegen(int bits, int modulus, int residue,
     delta = 0;
     while (1) {
        for (i = 0; i < (sizeof(moduli) / sizeof(*moduli)); i++)
-           if (!((residues[i] + delta) % moduli[i]))
+           if (!((residues[i] + delta * multipliers[i]) % moduli[i]))
                break;
        if (i < (sizeof(moduli) / sizeof(*moduli))) {   /* we broke */
            delta += 2;
-           if (delta < 2) {
+           if (delta > 65536) {
                freebn(p);
                goto STARTOVER;
            }
@@ -1244,7 +1271,14 @@ Bignum primegen(int bits, int modulus, int residue,
        break;
     }
     q = p;
-    p = bignum_add_long(q, delta);
+    if (factor) {
+       Bignum tmp;
+       tmp = bignum_from_long(delta);
+       p = bigmuladd(tmp, factor, q);
+       freebn(tmp);
+    } else {
+       p = bignum_add_long(q, delta);
+    }
     freebn(q);
 
     /*
@@ -1311,7 +1345,7 @@ Bignum primegen(int bits, int modulus, int residue,
            break;
        }
 
-       pfn(pfnparam, phase, ++progress);
+       pfn(pfnparam, PROGFN_PROGRESS, phase, ++progress);
 
        /*
         * Compute w^q mod p.