Joris van Rantwijk's RSA speedup patch
authorsimon <simon@cda61777-01e9-0310-a592-d414129be87e>
Mon, 5 Jul 1999 16:31:57 +0000 (16:31 +0000)
committersimon <simon@cda61777-01e9-0310-a592-d414129be87e>
Mon, 5 Jul 1999 16:31:57 +0000 (16:31 +0000)
git-svn-id: svn://svn.tartarus.org/sgt/putty@171 cda61777-01e9-0310-a592-d414129be87e

sshrsa.c

index 70cf9bb..bac360f 100644 (file)
--- a/sshrsa.c
+++ b/sshrsa.c
@@ -1,8 +1,11 @@
 /*
  * RSA implementation just sufficient for ssh client-side
  * initialisation step
+ * Modified by Joris, Jun 1999.
  */
 
+#define JORIS_RSA
+
 /*#include <windows.h>
 #define RSADEBUG
 #define DLVL 2
@@ -57,6 +60,195 @@ static void freebn(Bignum b) {
     free(b);
 }
 
+#ifdef JORIS_RSA
+
+/*
+ *  Compute c = a * b.
+ *  Input is in the first len words of a and b.
+ *  Result is returned in the first 2*len words of c.
+ */
+static void bigmul(unsigned short *a, unsigned short *b, unsigned short *c,
+                   int len)
+{
+       int i, j;
+       unsigned long ai, t;
+
+       for (j = len - 1; j >= 0; j--)
+               c[j+len] = 0;
+       
+       for (i = len - 1; i >= 0; i--) {
+               ai = a[i];
+               t = 0;
+               for (j = len - 1; j >= 0; j--) {
+                       t += ai * (unsigned long) b[j];
+                       t += (unsigned long) c[i+j+1];
+                       c[i+j+1] = t;
+                       t = t >> 16;
+               }
+               c[i] = t;
+       }       
+}
+
+
+/*
+ *  Compute a = a % m.
+ *  Input in first 2*len words of a and first len words of m.
+ *  Output in first 2*len words of a (of which first len words will be zero).
+ *  The MSW of m MUST have its high bit set.
+ */
+static void bigmod(unsigned short *a, unsigned short *m, int len)
+{
+       unsigned short m0, m1;
+       unsigned int h;
+       int i, k;
+
+       /* Special case for len == 1 */
+       if (len == 1) {
+               a[1] = (((long) a[0] << 16) + a[1]) % m[0];
+               a[0] = 0;
+               return;
+       }
+
+       m0 = m[0];
+       m1 = m[1];
+
+       for (i = 0; i <= len; i++) {
+               unsigned long t;
+               unsigned int q, r, c;
+
+               if (i == 0) {
+                       h = 0;
+               } else {
+                       h = a[i-1];
+                       a[i-1] = 0;
+               }
+
+               /* Find q = h:a[i] / m0 */
+               t = ((unsigned long) h << 16) + a[i];
+               q = t / m0;
+               r = t % m0;
+
+               /* Refine our estimate of q by looking at 
+                  h:a[i]:a[i+1] / m0:m1 */
+               t = (long) m1 * (long) q;               
+               if (t > ((unsigned long) r << 16) + a[i+1]) {
+                       q--;
+                       t -= m1;
+                       r = (r + m0) & 0xffff; /* overflow? */
+                       if (r >= m0 && t > ((unsigned long) r << 16) + a[i+1])
+                               q--;
+               }
+
+               /* Substract q * m from a[i...] */
+               c = 0;
+               for (k = len - 1; k >= 0; k--) {
+                       t = (long) q * (long) m[k];
+                       t += c;
+                       c = t >> 16;
+                       if ((unsigned short) t > a[i+k]) c++;
+                       a[i+k] -= (unsigned short) t;
+               }
+
+               /* Add back m in case of borrow */
+               if (c != h) {
+                       t = 0;
+                       for (k = len - 1; k >= 0; k--) {
+                               t += m[k];
+                               t += a[i+k];
+                               a[i+k] = t;
+                               t = t >> 16;
+                       }
+               }
+       }
+}
+
+
+/*
+ *  Compute (base ^ exp) % mod.
+ *  The base MUST be smaller than the modulus.
+ *  The most significant word of mod MUST be non-zero.
+ *  We assume that the result array is the same size as the mod array.
+ */
+static void modpow(Bignum base, Bignum exp, Bignum mod, Bignum result)
+{
+       unsigned short *a, *b, *n, *m;
+       int mshift;
+       int mlen, i, j;
+
+       /* Allocate m of size mlen, copy mod to m */
+       /* We use big endian internally */
+       mlen = mod[0];
+       m = malloc(mlen * sizeof(unsigned short));
+       for (j = 0; j < mlen; j++) m[j] = mod[mod[0] - j];
+
+       /* Shift m left to make msb bit set */
+       for (mshift = 0; mshift < 15; mshift++)
+               if ((m[0] << mshift) & 0x8000) break;
+       if (mshift) {
+               for (i = 0; i < mlen - 1; i++)
+                       m[i] = (m[i] << mshift) | (m[i+1] >> (16-mshift));
+               m[mlen-1] = m[mlen-1] << mshift;
+       }
+
+       /* Allocate n of size mlen, copy base to n */
+       n = malloc(mlen * sizeof(unsigned short));
+       i = mlen - base[0];
+       for (j = 0; j < i; j++) n[j] = 0;
+       for (j = 0; j < base[0]; j++) n[i+j] = base[base[0] - j];
+
+       /* Allocate a and b of size 2*mlen. Set a = 1 */
+       a = malloc(2 * mlen * sizeof(unsigned short));
+       b = malloc(2 * mlen * sizeof(unsigned short));
+       for (i = 0; i < 2*mlen; i++) a[i] = 0;
+       a[2*mlen-1] = 1;
+
+       /* Skip leading zero bits of exp. */
+       i = 0; j = 15;
+       while (i < exp[0] && (exp[exp[0] - i] & (1 << j)) == 0) {
+               j--;
+               if (j < 0) { i++; j = 15; }
+       }
+
+       /* Main computation */
+       while (i < exp[0]) {
+               while (j >= 0) {
+                       bigmul(a + mlen, a + mlen, b, mlen);
+                       bigmod(b, m, mlen);
+                       if ((exp[exp[0] - i] & (1 << j)) != 0) {
+                               bigmul(b + mlen, n, a, mlen);
+                               bigmod(a, m, mlen);
+                       } else {
+                               unsigned short *t;
+                               t = a;  a = b;  b = t;
+                       }
+                       j--;
+               }
+               i++; j = 15;
+       }
+
+       /* Fixup result in case the modulus was shifted */
+       if (mshift) {
+               for (i = mlen - 1; i < 2*mlen - 1; i++)
+                       a[i] = (a[i] << mshift) | (a[i+1] >> (16-mshift));
+               a[2*mlen-1] = a[2*mlen-1] << mshift;
+               bigmod(a, m, mlen);
+               for (i = 2*mlen - 1; i >= mlen; i--)
+                       a[i] = (a[i] >> mshift) | (a[i-1] << (16-mshift));
+       }
+       
+       /* Copy result to buffer */
+       for (i = 0; i < mlen; i++)
+               result[result[0] - i] = a[i+mlen];
+
+       /* Free temporary arrays */
+       for (i = 0; i < 2*mlen; i++) a[i] = 0; free(a);
+       for (i = 0; i < 2*mlen; i++) b[i] = 0; free(b);
+       for (i = 0; i < mlen; i++) m[i] = 0; free(m);
+       for (i = 0; i < mlen; i++) n[i] = 0; free(n);
+}
+
+#else
+
 static int msb(Bignum r) {
     int i;
     int j;
@@ -246,6 +438,8 @@ static void modpow(Bignum r1, Bignum r2, Bignum modulus, Bignum result) {
     leave(("<modpow\n"));
 }
 
+#endif
+
 int makekey(unsigned char *data, struct RSAKey *result,
            unsigned char **keystr) {
     unsigned char *p = data;