Initial checkin: beta 0.43
[u/mdw/putty] / sshrsa.c
1 /*
2 * RSA implementation just sufficient for ssh client-side
3 * initialisation step
4 */
5
6 /*#include <windows.h>
7 #define RSADEBUG
8 #define DLVL 2
9 #include "stel.h"*/
10
11 #include <stdio.h>
12 #include <stdlib.h>
13 #include <string.h>
14
15 #include "ssh.h"
16
/*
 * Multi-precision integer: an array of unsigned shorts.  Element [0]
 * holds the number of 16-bit digit words that follow; elements
 * [1] .. [b[0]] hold the value, least significant word first.
 */
typedef unsigned short *Bignum;

/* The value 0 as a zero-length Bignum; add(x, Zero, y) copies x into y. */
static unsigned short Zero[] = { 0 };
20
#if defined TESTMODE || defined RSADEBUG
#ifndef DLVL
#define DLVL 10000
#endif
#ifdef TESTMODE
/*
 * Standalone test build: no debug sink is provided, and on POSIX systems
 * <stdio.h> may declare an unrelated dprintf(int fd, ...).  Route all
 * debug output to stdout instead.
 */
#undef dprintf
#define dprintf printf
#endif
/* debug(b) prints Bignum b labelled with its own source text. */
#define debug(x) bndebug(#x,x)
/* Current indentation (columns) of debug output; enter()/leave() adjust it. */
static int level = 0;
/*
 * Print `name`, padded so the digits roughly line up, followed by the
 * words of `b` from most to least significant as 4-digit hex.  Nothing
 * is printed once `level` reaches DLVL.
 */
static void bndebug(char *name, Bignum b) {
    int i;
    /* Cast strlen() so the padding arithmetic stays signed. */
    int w = 50 - level - (int)strlen(name) - 5*b[0];
    if (level >= DLVL)
        return;
    if (w < 0)
        w = 0;
    dprintf("%*s%s%*s", level, "", name, w, "");
    for (i = b[0]; i > 0; i--)
        dprintf(" %04x", b[i]);
    dprintf("\n");
}
/* Indentation and message both go to the same sink (dprintf). */
#define dmsg(x) do {if(level<DLVL){dprintf("%*s",level,"");dprintf x;}} while(0)
#define enter(x) do { dmsg(x); level += 4; } while(0)
#define leave(x) do { level -= 4; dmsg(x); } while(0)
#else
#define debug(x)
#define dmsg(x)
#define enter(x)
#define leave(x)
#endif
47
48 static Bignum newbn(int length) {
49 Bignum b = malloc((length+1)*sizeof(unsigned short));
50 if (!b)
51 abort(); /* FIXME */
52 b[0] = length;
53 return b;
54 }
55
56 static void freebn(Bignum b) {
57 free(b);
58 }
59
60 static int msb(Bignum r) {
61 int i;
62 int j;
63 unsigned short n;
64
65 for (i=r[0]; i>0; i--)
66 if (r[i])
67 break;
68
69 j = (i-1)*16;
70 n = r[i];
71 if (n & 0xFF00) j += 8, n >>= 8;
72 if (n & 0x00F0) j += 4, n >>= 4;
73 if (n & 0x000C) j += 2, n >>= 2;
74 if (n & 0x0002) j += 1, n >>= 1;
75
76 return j;
77 }
78
79 static void add(Bignum r1, Bignum r2, Bignum result) {
80 int i;
81 long stuff = 0;
82
83 enter((">add\n"));
84 debug(r1);
85 debug(r2);
86
87 for (i = 1 ;; i++) {
88 if (i <= r1[0])
89 stuff += r1[i];
90 if (i <= r2[0])
91 stuff += r2[i];
92 if (i <= result[0])
93 result[i] = stuff & 0xFFFFU;
94 if (i > r1[0] && i > r2[0] && i >= result[0])
95 break;
96 stuff >>= 16;
97 }
98
99 debug(result);
100 leave(("<add\n"));
101 }
102
103 static void sub(Bignum r1, Bignum r2, Bignum result) {
104 int i;
105 long stuff = 0;
106
107 enter((">sub\n"));
108 debug(r1);
109 debug(r2);
110
111 for (i = 1 ;; i++) {
112 if (i <= r1[0])
113 stuff += r1[i];
114 if (i <= r2[0])
115 stuff -= r2[i];
116 if (i <= result[0])
117 result[i] = stuff & 0xFFFFU;
118 if (i > r1[0] && i > r2[0] && i >= result[0])
119 break;
120 stuff = stuff<0 ? -1 : 0;
121 }
122
123 debug(result);
124 leave(("<sub\n"));
125 }
126
127 static int ge(Bignum r1, Bignum r2) {
128 int i;
129
130 enter((">ge\n"));
131 debug(r1);
132 debug(r2);
133
134 if (r1[0] < r2[0])
135 i = r2[0];
136 else
137 i = r1[0];
138
139 while (i > 0) {
140 unsigned short n1 = (i > r1[0] ? 0 : r1[i]);
141 unsigned short n2 = (i > r2[0] ? 0 : r2[i]);
142
143 if (n1 > n2) {
144 dmsg(("greater\n"));
145 leave(("<ge\n"));
146 return 1; /* r1 > r2 */
147 } else if (n1 < n2) {
148 dmsg(("less\n"));
149 leave(("<ge\n"));
150 return 0; /* r1 < r2 */
151 }
152
153 i--;
154 }
155
156 dmsg(("equal\n"));
157 leave(("<ge\n"));
158 return 1; /* r1 = r2 */
159 }
160
161 static void modmult(Bignum r1, Bignum r2, Bignum modulus, Bignum result) {
162 Bignum temp = newbn(modulus[0]+1);
163 Bignum tmp2 = newbn(modulus[0]+1);
164 int i;
165 int bit, bits, digit, smallbit;
166
167 enter((">modmult\n"));
168 debug(r1);
169 debug(r2);
170 debug(modulus);
171
172 for (i=1; i<=result[0]; i++)
173 result[i] = 0; /* result := 0 */
174 for (i=1; i<=temp[0]; i++)
175 temp[i] = (i > r2[0] ? 0 : r2[i]); /* temp := r2 */
176
177 bits = 1+msb(r1);
178
179 for (bit = 0; bit < bits; bit++) {
180 digit = 1 + bit / 16;
181 smallbit = bit % 16;
182
183 debug(temp);
184 if (digit <= r1[0] && (r1[digit] & (1<<smallbit))) {
185 dmsg(("bit %d\n", bit));
186 add(temp, result, tmp2);
187 if (ge(tmp2, modulus))
188 sub(tmp2, modulus, result);
189 else
190 add(tmp2, Zero, result);
191 debug(result);
192 }
193
194 add(temp, temp, tmp2);
195 if (ge(tmp2, modulus))
196 sub(tmp2, modulus, temp);
197 else
198 add(tmp2, Zero, temp);
199 }
200
201 freebn(temp);
202 freebn(tmp2);
203
204 debug(result);
205 leave(("<modmult\n"));
206 }
207
208 static void modpow(Bignum r1, Bignum r2, Bignum modulus, Bignum result) {
209 Bignum temp = newbn(modulus[0]+1);
210 Bignum tmp2 = newbn(modulus[0]+1);
211 int i;
212 int bit, bits, digit, smallbit;
213
214 enter((">modpow\n"));
215 debug(r1);
216 debug(r2);
217 debug(modulus);
218
219 for (i=1; i<=result[0]; i++)
220 result[i] = (i==1); /* result := 1 */
221 for (i=1; i<=temp[0]; i++)
222 temp[i] = (i > r1[0] ? 0 : r1[i]); /* temp := r1 */
223
224 bits = 1+msb(r2);
225
226 for (bit = 0; bit < bits; bit++) {
227 digit = 1 + bit / 16;
228 smallbit = bit % 16;
229
230 debug(temp);
231 if (digit <= r2[0] && (r2[digit] & (1<<smallbit))) {
232 dmsg(("bit %d\n", bit));
233 modmult(temp, result, modulus, tmp2);
234 add(tmp2, Zero, result);
235 debug(result);
236 }
237
238 modmult(temp, temp, modulus, tmp2);
239 add(tmp2, Zero, temp);
240 }
241
242 freebn(temp);
243 freebn(tmp2);
244
245 debug(result);
246 leave(("<modpow\n"));
247 }
248
249 int makekey(unsigned char *data, struct RSAKey *result,
250 unsigned char **keystr) {
251 unsigned char *p = data;
252 Bignum bn[2];
253 int i, j;
254 int w, b;
255
256 result->bits = 0;
257 for (i=0; i<4; i++)
258 result->bits = (result->bits << 8) + *p++;
259
260 for (j=0; j<2; j++) {
261
262 w = 0;
263 for (i=0; i<2; i++)
264 w = (w << 8) + *p++;
265
266 result->bytes = b = (w+7)/8; /* bits -> bytes */
267 w = (w+15)/16; /* bits -> words */
268
269 bn[j] = newbn(w);
270
271 if (keystr) *keystr = p; /* point at key string, second time */
272
273 for (i=1; i<=w; i++)
274 bn[j][i] = 0;
275 for (i=0; i<b; i++) {
276 unsigned char byte = *p++;
277 if ((b-i) & 1)
278 bn[j][w-i/2] |= byte;
279 else
280 bn[j][w-i/2] |= byte<<8;
281 }
282
283 debug(bn[j]);
284
285 }
286
287 result->exponent = bn[0];
288 result->modulus = bn[1];
289
290 return p - data;
291 }
292
293 void rsaencrypt(unsigned char *data, int length, struct RSAKey *key) {
294 Bignum b1, b2;
295 int w, i;
296 unsigned char *p;
297
298 debug(key->exponent);
299
300 memmove(data+key->bytes-length, data, length);
301 data[0] = 0;
302 data[1] = 2;
303
304 for (i = 2; i < key->bytes-length-1; i++) {
305 do {
306 data[i] = random_byte();
307 } while (data[i] == 0);
308 }
309 data[key->bytes-length-1] = 0;
310
311 w = (key->bytes+1)/2;
312
313 b1 = newbn(w);
314 b2 = newbn(w);
315
316 p = data;
317 for (i=1; i<=w; i++)
318 b1[i] = 0;
319 for (i=0; i<key->bytes; i++) {
320 unsigned char byte = *p++;
321 if ((key->bytes-i) & 1)
322 b1[w-i/2] |= byte;
323 else
324 b1[w-i/2] |= byte<<8;
325 }
326
327 debug(b1);
328
329 modpow(b1, key->exponent, key->modulus, b2);
330
331 debug(b2);
332
333 p = data;
334 for (i=0; i<key->bytes; i++) {
335 unsigned char b;
336 if (i & 1)
337 b = b2[w-i/2] & 0xFF;
338 else
339 b = b2[w-i/2] >> 8;
340 *p++ = b;
341 }
342
343 freebn(b1);
344 freebn(b2);
345 }
346
347 int rsastr_len(struct RSAKey *key) {
348 Bignum md, ex;
349
350 md = key->modulus;
351 ex = key->exponent;
352 return 4 * (ex[0]+md[0]) + 10;
353 }
354
355 void rsastr_fmt(char *str, struct RSAKey *key) {
356 Bignum md, ex;
357 int len = 0, i;
358
359 md = key->modulus;
360 ex = key->exponent;
361
362 for (i=1; i<=ex[0]; i++) {
363 sprintf(str+len, "%04x", ex[i]);
364 len += strlen(str+len);
365 }
366 str[len++] = '/';
367 for (i=1; i<=md[0]; i++) {
368 sprintf(str+len, "%04x", md[i]);
369 len += strlen(str+len);
370 }
371 str[len] = '\0';
372 }
373
#ifdef TESTMODE

#ifndef NODDY
#define p1 10007
#define p2 10069
#define p3 10177
#else
#define p1 3
#define p2 7
#define p3 13
#endif

unsigned short P1[2] = { 1, p1 };
unsigned short P2[2] = { 1, p2 };
unsigned short P3[2] = { 1, p3 };
unsigned short bigmod[5] = { 4, 0, 0, 0, 32768U };   /* 2^63 */
unsigned short mod[5] = { 4, 0, 0, 0, 0 };
unsigned short a[5] = { 4, 0, 0, 0, 0 };
unsigned short b[5] = { 4, 0, 0, 0, 0 };
unsigned short c[5] = { 4, 0, 0, 0, 0 };
unsigned short One[2] = { 1, 1 };
unsigned short Two[2] = { 1, 2 };

/*
 * Self-test of the arithmetic.  With p1, p2, p3 distinct odd primes,
 * n = p1*p2*p3 and phi = (p1-1)(p2-1)(p3-1), Euler's theorem gives
 * 2^phi mod n == 1.  Products are formed modulo bigmod = 2^63, which is
 * large enough that they do not wrap.  Prints PASS or FAIL; the exit
 * status is 0 on success and 1 on failure.
 */
int main(void) {
    int i, ok;

    modmult(P1, P2, bigmod, a); debug(a);
    modmult(a, P3, bigmod, mod); debug(mod);      /* mod := n */

    sub(P1, One, a); debug(a);
    sub(P2, One, b); debug(b);
    modmult(a, b, bigmod, c); debug(c);
    sub(P3, One, a); debug(a);
    modmult(a, c, bigmod, b); debug(b);           /* b := phi */

    modpow(Two, b, mod, a); debug(a);             /* a := 2^phi mod n */

    /* Expect a == 1: low word 1, every higher word 0. */
    ok = (a[1] == 1);
    for (i = 2; i <= a[0]; i++)
        if (a[i] != 0)
            ok = 0;
    printf("%s\n", ok ? "PASS" : "FAIL");

    return ok ? 0 : 1;
}

#endif