X-Git-Url: https://git.distorted.org.uk/u/mdw/putty/blobdiff_plain/4644b0ce3adad5efe574dc125bc3b0cd8f6c2aa7..4017be6d5375063f59b63474490ac072e7f09b1a:/sshrsa.c diff --git a/sshrsa.c b/sshrsa.c index 71e1d634..bd92c7bd 100644 --- a/sshrsa.c +++ b/sshrsa.c @@ -9,12 +9,6 @@ #include #include -#include "ssh.h" - -typedef unsigned short *Bignum; - -static unsigned short Zero[1] = { 0 }; - #if defined TESTMODE || defined RSADEBUG #ifndef DLVL #define DLVL 10000 @@ -42,245 +36,38 @@ static void bndebug(char *name, Bignum b) { #define leave(x) #endif -static Bignum newbn(int length) { - Bignum b = malloc((length+1)*sizeof(unsigned short)); - if (!b) - abort(); /* FIXME */ - b[0] = length; - return b; -} - -static void freebn(Bignum b) { - free(b); -} - -/* - * Compute c = a * b. - * Input is in the first len words of a and b. - * Result is returned in the first 2*len words of c. - */ -static void bigmul(unsigned short *a, unsigned short *b, unsigned short *c, - int len) -{ - int i, j; - unsigned long ai, t; - - for (j = len - 1; j >= 0; j--) - c[j+len] = 0; - - for (i = len - 1; i >= 0; i--) { - ai = a[i]; - t = 0; - for (j = len - 1; j >= 0; j--) { - t += ai * (unsigned long) b[j]; - t += (unsigned long) c[i+j+1]; - c[i+j+1] = (unsigned short)t; - t = t >> 16; - } - c[i] = (unsigned short)t; - } -} - -/* - * Compute a = a % m. - * Input in first 2*len words of a and first len words of m. - * Output in first 2*len words of a (of which first len words will be zero). - * The MSW of m MUST have its high bit set. - */ -static void bigmod(unsigned short *a, unsigned short *m, int len) -{ - unsigned short m0, m1; - unsigned int h; - int i, k; - - /* Special case for len == 1 */ - if (len == 1) { - a[1] = (((long) a[0] << 16) + a[1]) % m[0]; - a[0] = 0; - return; - } - - m0 = m[0]; - m1 = m[1]; - - for (i = 0; i <= len; i++) { - unsigned long t; - unsigned int q, r, c; - - if (i == 0) { - h = 0; - } else { - h = a[i-1]; - a[i-1] = 0; - } - - /* Find q = h:a[i] / m0 */ - t = ((unsigned long) h << 16) + a[i]; - q = t / m0; - r = t % m0; - - /* Refine our estimate of q by looking at - h:a[i]:a[i+1] / m0:m1 */ - t = (long) m1 * (long) q; - if (t > ((unsigned long) r << 16) + a[i+1]) { - q--; - t -= m1; - r = (r + m0) & 0xffff; /* overflow? */ - if (r >= m0 && t > ((unsigned long) r << 16) + a[i+1]) - q--; - } - - /* Substract q * m from a[i...] */ - c = 0; - for (k = len - 1; k >= 0; k--) { - t = (long) q * (long) m[k]; - t += c; - c = t >> 16; - if ((unsigned short) t > a[i+k]) c++; - a[i+k] -= (unsigned short) t; - } - - /* Add back m in case of borrow */ - if (c != h) { - t = 0; - for (k = len - 1; k >= 0; k--) { - t += m[k]; - t += a[i+k]; - a[i+k] = (unsigned short)t; - t = t >> 16; - } - } - } -} - -/* - * Compute (base ^ exp) % mod. - * The base MUST be smaller than the modulus. - * The most significant word of mod MUST be non-zero. - * We assume that the result array is the same size as the mod array. - */ -static void modpow(Bignum base, Bignum exp, Bignum mod, Bignum result) -{ - unsigned short *a, *b, *n, *m; - int mshift; - int mlen, i, j; - - /* Allocate m of size mlen, copy mod to m */ - /* We use big endian internally */ - mlen = mod[0]; - m = malloc(mlen * sizeof(unsigned short)); - for (j = 0; j < mlen; j++) m[j] = mod[mod[0] - j]; - - /* Shift m left to make msb bit set */ - for (mshift = 0; mshift < 15; mshift++) - if ((m[0] << mshift) & 0x8000) break; - if (mshift) { - for (i = 0; i < mlen - 1; i++) - m[i] = (m[i] << mshift) | (m[i+1] >> (16-mshift)); - m[mlen-1] = m[mlen-1] << mshift; - } - - /* Allocate n of size mlen, copy base to n */ - n = malloc(mlen * sizeof(unsigned short)); - i = mlen - base[0]; - for (j = 0; j < i; j++) n[j] = 0; - for (j = 0; j < base[0]; j++) n[i+j] = base[base[0] - j]; - - /* Allocate a and b of size 2*mlen. Set a = 1 */ - a = malloc(2 * mlen * sizeof(unsigned short)); - b = malloc(2 * mlen * sizeof(unsigned short)); - for (i = 0; i < 2*mlen; i++) a[i] = 0; - a[2*mlen-1] = 1; - - /* Skip leading zero bits of exp. */ - i = 0; j = 15; - while (i < exp[0] && (exp[exp[0] - i] & (1 << j)) == 0) { - j--; - if (j < 0) { i++; j = 15; } - } - - /* Main computation */ - while (i < exp[0]) { - while (j >= 0) { - bigmul(a + mlen, a + mlen, b, mlen); - bigmod(b, m, mlen); - if ((exp[exp[0] - i] & (1 << j)) != 0) { - bigmul(b + mlen, n, a, mlen); - bigmod(a, m, mlen); - } else { - unsigned short *t; - t = a; a = b; b = t; - } - j--; - } - i++; j = 15; - } - - /* Fixup result in case the modulus was shifted */ - if (mshift) { - for (i = mlen - 1; i < 2*mlen - 1; i++) - a[i] = (a[i] << mshift) | (a[i+1] >> (16-mshift)); - a[2*mlen-1] = a[2*mlen-1] << mshift; - bigmod(a, m, mlen); - for (i = 2*mlen - 1; i >= mlen; i--) - a[i] = (a[i] >> mshift) | (a[i-1] << (16-mshift)); - } - - /* Copy result to buffer */ - for (i = 0; i < mlen; i++) - result[result[0] - i] = a[i+mlen]; - - /* Free temporary arrays */ - for (i = 0; i < 2*mlen; i++) a[i] = 0; free(a); - for (i = 0; i < 2*mlen; i++) b[i] = 0; free(b); - for (i = 0; i < mlen; i++) m[i] = 0; free(m); - for (i = 0; i < mlen; i++) n[i] = 0; free(n); -} +#include "ssh.h" int makekey(unsigned char *data, struct RSAKey *result, - unsigned char **keystr) { + unsigned char **keystr, int order) { unsigned char *p = data; - Bignum bn[2]; - int i, j; - int w, b; + int i; result->bits = 0; for (i=0; i<4; i++) result->bits = (result->bits << 8) + *p++; - for (j=0; j<2; j++) { - - w = 0; - for (i=0; i<2; i++) - w = (w << 8) + *p++; + /* + * order=0 means exponent then modulus (the keys sent by the + * server). order=1 means modulus then exponent (the keys + * stored in a keyfile). + */ - result->bytes = b = (w+7)/8; /* bits -> bytes */ - w = (w+15)/16; /* bits -> words */ - - bn[j] = newbn(w); - - if (keystr) *keystr = p; /* point at key string, second time */ - - for (i=1; i<=w; i++) - bn[j][i] = 0; - for (i=0; iexponent = bn[0]; - result->modulus = bn[1]; + if (order == 0) + p += ssh1_read_bignum(p, &result->exponent); + result->bytes = (((p[0] << 8) + p[1]) + 7) / 8; + if (keystr) *keystr = p+2; + p += ssh1_read_bignum(p, &result->modulus); + if (order == 1) + p += ssh1_read_bignum(p, &result->exponent); return p - data; } +int makeprivate(unsigned char *data, struct RSAKey *result) { + return ssh1_read_bignum(data, &result->private_exponent); +} + void rsaencrypt(unsigned char *data, int length, struct RSAKey *key) { Bignum b1, b2; int w, i; @@ -307,12 +94,12 @@ void rsaencrypt(unsigned char *data, int length, struct RSAKey *key) { p = data; for (i=1; i<=w; i++) b1[i] = 0; - for (i=0; ibytes; i++) { + for (i=key->bytes; i-- ;) { unsigned char byte = *p++; - if ((key->bytes-i) & 1) - b1[w-i/2] |= byte; + if (i & 1) + b1[1+i/2] |= byte<<8; else - b1[w-i/2] |= byte<<8; + b1[1+i/2] |= byte; } debug(b1); @@ -322,12 +109,12 @@ void rsaencrypt(unsigned char *data, int length, struct RSAKey *key) { debug(b2); p = data; - for (i=0; ibytes; i++) { + for (i=key->bytes; i-- ;) { unsigned char b; if (i & 1) - b = b2[w-i/2] & 0xFF; + b = b2[1+i/2] >> 8; else - b = b2[w-i/2] >> 8; + b = b2[1+i/2] & 0xFF; *p++ = b; } @@ -335,6 +122,13 @@ void rsaencrypt(unsigned char *data, int length, struct RSAKey *key) { freebn(b2); } +Bignum rsadecrypt(Bignum input, struct RSAKey *key) { + Bignum ret; + ret = newbn(key->modulus[0]); + modpow(input, key->private_exponent, key->modulus, ret); + return ret; +} + int rsastr_len(struct RSAKey *key) { Bignum md, ex;