#include "ssh.h"
-static unsigned short Zero[1] = { 0 };
-
-#if defined TESTMODE || defined RSADEBUG
-#ifndef DLVL
-#define DLVL 10000
-#endif
-#define debug(x) bndebug(#x,x)
-static int level = 0;
-static void bndebug(char *name, Bignum b) {
- int i;
- int w = 50-level-strlen(name)-5*b[0];
- if (level >= DLVL)
- return;
- if (w < 0) w = 0;
- dprintf("%*s%s%*s", level, "", name, w, "");
- for (i=b[0]; i>0; i--)
- dprintf(" %04x", b[i]);
- dprintf("\n");
-}
-#define dmsg(x) do {if(level<DLVL){dprintf("%*s",level,"");printf x;}} while(0)
-#define enter(x) do { dmsg(x); level += 4; } while(0)
-#define leave(x) do { level -= 4; dmsg(x); } while(0)
-#else
-#define debug(x)
-#define dmsg(x)
-#define enter(x)
-#define leave(x)
-#endif
+/*
+ * Backing storage for the two constant bignums. Word 0 of a bignum
+ * array is its length in 16-bit words; the value words follow it.
+ * bnZero therefore has no value words, bnOne has a single word of 1.
+ */
+unsigned short bnZero[1] = { 0 };
+unsigned short bnOne[2] = { 1, 1 };
+
+/* Bignum handles referring to the constant arrays above. */
+Bignum Zero = bnZero;
+Bignum One = bnOne;
/*
 * Allocate a bignum with room for `length' value words plus the
 * leading length word. The length word b[0] is initialised to
 * `length'; the value words b[1..length] are left uninitialised and
 * must be filled in by the caller. Aborts if the allocation fails.
 */
Bignum newbn(int length) {
    Bignum b = malloc((length+1)*sizeof(unsigned short));
    if (!b)
        abort();                       /* FIXME */
    /* Every other routine reads the word count from element 0. */
    b[0] = length;
    return b;
}
+/*
+ * Duplicate a bignum. The copy is freshly malloc'ed and contains the
+ * length word followed by all orig[0] value words of the original.
+ * Aborts if the allocation fails.
+ */
+Bignum copybn(Bignum orig) {
+    int nwords = orig[0] + 1;          /* length word + value words */
+    Bignum dup = malloc(nwords * sizeof(unsigned short));
+
+    if (!dup)
+        abort();                       /* FIXME */
+    memcpy(dup, orig, nwords * sizeof(*dup));
+    return dup;
+}
+
void freebn(Bignum b) {
/*
* Burn the evidence, just in case.
/*
* Compute a = a % m.
- * Input in first 2*len words of a and first len words of m.
- * Output in first 2*len words of a (of which first len words will be zero).
+ * Input in first len2 words of a and first len words of m.
+ * Output in first len2 words of a
+ * (of which first len2-len words will be zero).
* The MSW of m MUST have its high bit set.
*/
-static void bigmod(unsigned short *a, unsigned short *m, int len)
+static void bigmod(unsigned short *a, unsigned short *m,
+ int len, int len2)
{
unsigned short m0, m1;
unsigned int h;
m0 = m[0];
m1 = m[1];
- for (i = 0; i <= len; i++) {
+ for (i = 0; i <= len2-len; i++) {
unsigned long t;
unsigned int q, r, c;
while (i < exp[0]) {
while (j >= 0) {
bigmul(a + mlen, a + mlen, b, mlen);
- bigmod(b, m, mlen);
+ bigmod(b, m, mlen, mlen*2);
if ((exp[exp[0] - i] & (1 << j)) != 0) {
bigmul(b + mlen, n, a, mlen);
- bigmod(a, m, mlen);
+ bigmod(a, m, mlen, mlen*2);
} else {
unsigned short *t;
t = a; a = b; b = t;
for (i = mlen - 1; i < 2*mlen - 1; i++)
a[i] = (a[i] << mshift) | (a[i+1] >> (16-mshift));
a[2*mlen-1] = a[2*mlen-1] << mshift;
- bigmod(a, m, mlen);
+ bigmod(a, m, mlen, mlen*2);
for (i = 2*mlen - 1; i >= mlen; i--)
a[i] = (a[i] >> mshift) | (a[i-1] << (16-mshift));
}
for (i = 0; i < mlen; i++) m[i] = 0; free(m);
for (i = 0; i < mlen; i++) n[i] = 0; free(n);
}
+
+/*
+ * Compute (p * q) % mod.
+ * The most significant word of mod MUST be non-zero.
+ * We assume that the result array is the same size as the mod array:
+ * exactly mod[0] words are stored, ending at result[1].
+ *
+ * Working copies of the operands are kept most-significant-word
+ * first. The modulus is shifted left until its top bit is set, as
+ * bigmod requires, and the remainder is corrected for that shift
+ * afterwards. All temporaries are zeroed before being freed. Aborts
+ * if an allocation fails.
+ */
+void modmul(Bignum p, Bignum q, Bignum mod, Bignum result)
+{
+    unsigned short *a, *n, *m, *o;
+    int mshift;
+    int pqlen, mlen, i, j;
+
+    /* Allocate m of size mlen, copy mod to m */
+    /* We use big endian internally */
+    mlen = mod[0];
+    m = malloc(mlen * sizeof(unsigned short));
+    if (!m)
+        abort();                       /* FIXME */
+    for (j = 0; j < mlen; j++)
+        m[j] = mod[mod[0] - j];
+
+    /* Shift m left to make msb bit set */
+    for (mshift = 0; mshift < 15; mshift++)
+        if ((m[0] << mshift) & 0x8000)
+            break;
+    if (mshift) {
+        for (i = 0; i < mlen - 1; i++)
+            m[i] = (m[i] << mshift) | (m[i+1] >> (16-mshift));
+        m[mlen-1] = m[mlen-1] << mshift;
+    }
+
+    pqlen = (p[0] > q[0] ? p[0] : q[0]);
+
+    /*
+     * Make sure the product buffer is strictly longer than the
+     * modulus. The fix-up shifts and the final copy below index a[]
+     * at 2*pqlen - mlen - 1, which would fall before the start of
+     * the buffer if both operands were much shorter than mod.
+     * Widening pqlen only adds leading zero words to n and o.
+     */
+    if (2*pqlen <= mlen)
+        pqlen = mlen/2 + 1;
+
+    /* Allocate n of size pqlen, copy p to n */
+    n = malloc(pqlen * sizeof(unsigned short));
+    if (!n)
+        abort();                       /* FIXME */
+    i = pqlen - p[0];
+    for (j = 0; j < i; j++)
+        n[j] = 0;
+    for (j = 0; j < p[0]; j++)
+        n[i+j] = p[p[0] - j];
+
+    /* Allocate o of size pqlen, copy q to o */
+    o = malloc(pqlen * sizeof(unsigned short));
+    if (!o)
+        abort();                       /* FIXME */
+    i = pqlen - q[0];
+    for (j = 0; j < i; j++)
+        o[j] = 0;
+    for (j = 0; j < q[0]; j++)
+        o[i+j] = q[q[0] - j];
+
+    /* Allocate a of size 2*pqlen for result */
+    a = malloc(2 * pqlen * sizeof(unsigned short));
+    if (!a)
+        abort();                       /* FIXME */
+
+    /* Main computation */
+    bigmul(n, o, a, pqlen);
+    bigmod(a, m, mlen, 2*pqlen);
+
+    /*
+     * Fixup result in case the modulus was shifted: shift the
+     * remainder left by the same amount, reduce once more against
+     * the shifted modulus, then shift back down.
+     */
+    if (mshift) {
+        for (i = 2*pqlen - mlen - 1; i < 2*pqlen - 1; i++)
+            a[i] = (a[i] << mshift) | (a[i+1] >> (16-mshift));
+        a[2*pqlen-1] = a[2*pqlen-1] << mshift;
+        bigmod(a, m, mlen, pqlen*2);
+        for (i = 2*pqlen - 1; i >= 2*pqlen - mlen; i--)
+            a[i] = (a[i] >> mshift) | (a[i-1] << (16-mshift));
+    }
+
+    /* Copy result to buffer: the remainder is the last mlen words of a */
+    for (i = 0; i < mlen; i++)
+        result[result[0] - i] = a[i+2*pqlen-mlen];
+
+    /* Free temporary arrays, wiping their contents first */
+    for (i = 0; i < 2*pqlen; i++)
+        a[i] = 0;
+    free(a);
+    for (i = 0; i < mlen; i++)
+        m[i] = 0;
+    free(m);
+    for (i = 0; i < pqlen; i++)
+        n[i] = 0;
+    free(n);
+    for (i = 0; i < pqlen; i++)
+        o[i] = 0;
+    free(o);
+}
+
+/*
+ * Decrement a number in place.
+ *
+ * The borrow is propagated upwards from the least significant word
+ * (bn[1]): every zero word passed on the way becomes 0xFFFF and the
+ * first non-zero word (or, failing that, the top word) is reduced by
+ * one. A bignum with no value words (bn[0] == 0) is left untouched,
+ * since there is no word to decrement.
+ */
+void decbn(Bignum bn) {
+    int i = 1;
+
+    if (bn[0] == 0)
+        return;                        /* bn[1] would be out of bounds */
+    while (i < bn[0] && bn[i] == 0)
+        bn[i++] = 0xFFFF;
+    bn[i]--;
+}
+
+/*
+ * Read an ssh1-format bignum from a data buffer: a two-byte
+ * big-endian bit count followed by the value bytes, most significant
+ * first. Return the number of bytes consumed.
+ *
+ * If `result' is NULL nothing is allocated and only the encoded
+ * length is returned; otherwise *result receives a bignum obtained
+ * from newbn().
+ */
+int ssh1_read_bignum(unsigned char *data, Bignum *result) {
+    unsigned char *src = data;
+    Bignum num;
+    int bits, nbytes, nwords, k;
+
+    /* Header: bit count, high byte first. */
+    bits = (src[0] << 8) + src[1];
+    src += 2;
+
+    nbytes = (bits + 7) / 8;           /* bits -> bytes */
+    nwords = (bits + 15) / 16;         /* bits -> words */
+
+    if (!result)                       /* just return length */
+        return nbytes + 2;
+
+    num = newbn(nwords);
+
+    for (k = 1; k <= nwords; k++)
+        num[k] = 0;
+
+    /*
+     * k is the significance of the byte being read (0 = lowest).
+     * Odd-numbered bytes form the upper half of their 16-bit word.
+     */
+    for (k = nbytes - 1; k >= 0; k--) {
+        unsigned char byte = *src++;
+        int shift = (k & 1) ? 8 : 0;
+        num[1 + k/2] |= byte << shift;
+    }
+
+    *result = num;
+
+    return src - data;
+}
+
+/*
+ * Return the bit count of a bignum, for ssh1 encoding: one more than
+ * the index of the highest set bit, or 0 if no bit is set.
+ */
+int ssh1_bignum_bitcount(Bignum bn) {
+    int nbits = bn[0] * 16;            /* upper bound: every word full */
+
+    /* Trim candidate top bits until a non-zero one is reached. */
+    while (nbits > 0) {
+        int top = nbits - 1;
+        if ((bn[top/16 + 1] >> (top % 16)) != 0)
+            break;
+        nbits--;
+    }
+    return nbits;
+}
+
+/*
+ * Return the byte length of a bignum when ssh1 encoded: two bytes of
+ * bit count plus the value rounded up to whole bytes.
+ */
+int ssh1_bignum_length(Bignum bn) {
+    int bits = ssh1_bignum_bitcount(bn);
+    int value_bytes = (bits + 7) / 8;
+
+    return 2 + value_bytes;
+}
+
+/*
+ * Return a byte from a bignum; 0 is least significant, etc.
+ * Indices at or past the end of the number yield 0.
+ */
+int bignum_byte(Bignum bn, int i) {
+    unsigned short word;
+
+    if (i >= 2*bn[0])
+        return 0;                      /* beyond the end */
+
+    /* Two bytes per word: odd indices select the upper half. */
+    word = bn[i/2 + 1];
+    return ((i & 1) ? (word >> 8) : word) & 0xFF;
+}
+
+/*
+ * Write a ssh1-format bignum into a buffer: a two-byte big-endian
+ * bit count, then the value bytes from most to least significant.
+ * It is assumed the buffer is big enough. Returns the number of
+ * bytes used.
+ */
+int ssh1_write_bignum(void *data, Bignum bn) {
+    unsigned char *out = data;
+    int total = ssh1_bignum_length(bn);
+    int idx;
+    int bits = ssh1_bignum_bitcount(bn);
+
+    out[0] = (bits >> 8) & 0xFF;
+    out[1] = (bits     ) & 0xFF;
+    out += 2;
+
+    /* total - 2 value bytes follow the header, highest index first. */
+    for (idx = total - 3; idx >= 0; idx--)
+        *out++ = bignum_byte(bn, idx);
+
+    return total;
+}