RSA key authentication in ssh1 works; SSH2 is nearly there
[u/mdw/putty] / sshbn.c
diff --git a/sshbn.c b/sshbn.c
index a9eae85..97ae357 100644 (file)
--- a/sshbn.c
+++ b/sshbn.c
@@ -6,36 +6,17 @@
 #include <stdlib.h>
 #include <string.h>
 
+#include <stdio.h> /* FIXME */
+#include <stdarg.h> /* FIXME */
+#include <windows.h> /* FIXME */
+#include "putty.h" /* FIXME */
+
 #include "ssh.h"
 
-static unsigned short Zero[1] = { 0 };
+unsigned short bnZero[1] = { 0 };
+unsigned short bnOne[2] = { 1, 1 };
 
-#if defined TESTMODE || defined RSADEBUG
-#ifndef DLVL
-#define DLVL 10000
-#endif
-#define debug(x) bndebug(#x,x)
-static int level = 0;
-static void bndebug(char *name, Bignum b) {
-    int i;
-    int w = 50-level-strlen(name)-5*b[0];
-    if (level >= DLVL)
-       return;
-    if (w < 0) w = 0;
-    dprintf("%*s%s%*s", level, "", name, w, "");
-    for (i=b[0]; i>0; i--)
-       dprintf(" %04x", b[i]);
-    dprintf("\n");
-}
-#define dmsg(x) do {if(level<DLVL){dprintf("%*s",level,"");printf x;}} while(0)
-#define enter(x) do { dmsg(x); level += 4; } while(0)
-#define leave(x) do { level -= 4; dmsg(x); } while(0)
-#else
-#define debug(x)
-#define dmsg(x)
-#define enter(x)
-#define leave(x)
-#endif
+Bignum Zero = bnZero, One = bnOne;
 
 Bignum newbn(int length) {
     Bignum b = malloc((length+1)*sizeof(unsigned short));
@@ -46,6 +27,14 @@ Bignum newbn(int length) {
     return b;
 }
 
+Bignum copybn(Bignum orig) {
+    Bignum b = malloc((orig[0]+1)*sizeof(unsigned short));
+    if (!b)
+       abort();                       /* FIXME */
+    memcpy(b, orig, (orig[0]+1)*sizeof(*b));
+    return b;
+}
+
 void freebn(Bignum b) {
     /*
      * Burn the evidence, just in case.
@@ -83,11 +72,13 @@ static void bigmul(unsigned short *a, unsigned short *b, unsigned short *c,
 
 /*
  * Compute a = a % m.
- * Input in first 2*len words of a and first len words of m.
- * Output in first 2*len words of a (of which first len words will be zero).
+ * Input in first len2 words of a and first len words of m.
+ * Output in first len2 words of a
+ * (of which first len2-len words will be zero).
  * The MSW of m MUST have its high bit set.
  */
-static void bigmod(unsigned short *a, unsigned short *m, int len)
+static void bigmod(unsigned short *a, unsigned short *m,
+                   int len, int len2)
 {
     unsigned short m0, m1;
     unsigned int h;
@@ -103,7 +94,7 @@ static void bigmod(unsigned short *a, unsigned short *m, int len)
     m0 = m[0];
     m1 = m[1];
 
-    for (i = 0; i <= len; i++) {
+    for (i = 0; i <= len2-len; i++) {
        unsigned long t;
        unsigned int q, r, c;
 
@@ -204,10 +195,10 @@ void modpow(Bignum base, Bignum exp, Bignum mod, Bignum result)
     while (i < exp[0]) {
        while (j >= 0) {
            bigmul(a + mlen, a + mlen, b, mlen);
-           bigmod(b, m, mlen);
+           bigmod(b, m, mlen, mlen*2);
            if ((exp[exp[0] - i] & (1 << j)) != 0) {
                bigmul(b + mlen, n, a, mlen);
-               bigmod(a, m, mlen);
+               bigmod(a, m, mlen, mlen*2);
            } else {
                unsigned short *t;
                t = a;  a = b;  b = t;
@@ -222,7 +213,7 @@ void modpow(Bignum base, Bignum exp, Bignum mod, Bignum result)
        for (i = mlen - 1; i < 2*mlen - 1; i++)
            a[i] = (a[i] << mshift) | (a[i+1] >> (16-mshift));
        a[2*mlen-1] = a[2*mlen-1] << mshift;
-       bigmod(a, m, mlen);
+       bigmod(a, m, mlen, mlen*2);
        for (i = 2*mlen - 1; i >= mlen; i--)
            a[i] = (a[i] >> mshift) | (a[i-1] << (16-mshift));
     }
@@ -237,3 +228,115 @@ void modpow(Bignum base, Bignum exp, Bignum mod, Bignum result)
     for (i = 0; i < mlen; i++) m[i] = 0; free(m);
     for (i = 0; i < mlen; i++) n[i] = 0; free(n);
 }
+
+/*
+ * Compute (p * q) % mod.
+ * The most significant word of mod MUST be non-zero.
+ * We assume that the result array is the same size as the mod array.
+ */
+void modmul(Bignum p, Bignum q, Bignum mod, Bignum result)
+{
+    unsigned short *a, *n, *m, *o;
+    int mshift;
+    int pqlen, mlen, i, j;
+
+    /* Allocate m of size mlen, copy mod to m */
+    /* We use big endian internally */
+    mlen = mod[0];
+    m = malloc(mlen * sizeof(unsigned short));
+    for (j = 0; j < mlen; j++) m[j] = mod[mod[0] - j];
+
+    /* Shift m left to make msb bit set */
+    for (mshift = 0; mshift < 15; mshift++)
+       if ((m[0] << mshift) & 0x8000) break;
+    if (mshift) {
+       for (i = 0; i < mlen - 1; i++)
+           m[i] = (m[i] << mshift) | (m[i+1] >> (16-mshift));
+       m[mlen-1] = m[mlen-1] << mshift;
+    }
+
+    pqlen = (p[0] > q[0] ? p[0] : q[0]);
+
+    /* Allocate n of size pqlen, copy p to n */
+    n = malloc(pqlen * sizeof(unsigned short));
+    i = pqlen - p[0];
+    for (j = 0; j < i; j++) n[j] = 0;
+    for (j = 0; j < p[0]; j++) n[i+j] = p[p[0] - j];
+
+    /* Allocate o of size pqlen, copy q to o */
+    o = malloc(pqlen * sizeof(unsigned short));
+    i = pqlen - q[0];
+    for (j = 0; j < i; j++) o[j] = 0;
+    for (j = 0; j < q[0]; j++) o[i+j] = q[q[0] - j];
+
+    /* Allocate a of size 2*pqlen for result */
+    a = malloc(2 * pqlen * sizeof(unsigned short));
+
+    /* Main computation */
+    bigmul(n, o, a, pqlen);
+    bigmod(a, m, mlen, 2*pqlen);
+
+    /* Fixup result in case the modulus was shifted */
+    if (mshift) {
+       for (i = 2*pqlen - mlen - 1; i < 2*pqlen - 1; i++)
+           a[i] = (a[i] << mshift) | (a[i+1] >> (16-mshift));
+       a[2*pqlen-1] = a[2*pqlen-1] << mshift;
+       bigmod(a, m, mlen, pqlen*2);
+       for (i = 2*pqlen - 1; i >= 2*pqlen - mlen; i--)
+           a[i] = (a[i] >> mshift) | (a[i-1] << (16-mshift));
+    }
+
+    /* Copy result to buffer */
+    for (i = 0; i < mlen; i++)
+       result[result[0] - i] = a[i+2*pqlen-mlen];
+
+    /* Free temporary arrays */
+    for (i = 0; i < 2*pqlen; i++) a[i] = 0; free(a);
+    for (i = 0; i < mlen; i++) m[i] = 0; free(m);
+    for (i = 0; i < pqlen; i++) n[i] = 0; free(n);
+    for (i = 0; i < pqlen; i++) o[i] = 0; free(o);
+}
+
+/*
+ * Decrement a number.
+ */
+void decbn(Bignum bn) {
+    int i = 1;
+    while (i < bn[0] && bn[i] == 0)
+        bn[i++] = 0xFFFF;
+    bn[i]--;
+}
+
+/*
+ * Read an ssh1-format bignum from a data buffer. Return the number
+ * of bytes consumed.
+ */
+int ssh1_read_bignum(unsigned char *data, Bignum *result) {
+    unsigned char *p = data;
+    Bignum bn;
+    int i;
+    int w, b;
+
+    w = 0;
+    for (i=0; i<2; i++)
+        w = (w << 8) + *p++;
+
+    b = (w+7)/8;                       /* bits -> bytes */
+    w = (w+15)/16;                    /* bits -> words */
+
+    bn = newbn(w);
+
+    for (i=1; i<=w; i++)
+        bn[i] = 0;
+    for (i=b; i-- ;) {
+        unsigned char byte = *p++;
+        if (i & 1)
+            bn[1+i/2] |= byte<<8;
+        else
+            bn[1+i/2] |= byte;
+    }
+
+    *result = bn;
+
+    return p - data;
+}