Split mode macros into interface and implementation.
[u/mdw/catacomb] / mpmont.c
1 /* -*-c-*-
2 *
3 * $Id: mpmont.c,v 1.5 1999/11/22 13:58:40 mdw Exp $
4 *
5 * Montgomery reduction
6 *
7 * (c) 1999 Straylight/Edgeware
8 */
9
10 /*----- Licensing notice --------------------------------------------------*
11 *
12 * This file is part of Catacomb.
13 *
14 * Catacomb is free software; you can redistribute it and/or modify
15 * it under the terms of the GNU Library General Public License as
16 * published by the Free Software Foundation; either version 2 of the
17 * License, or (at your option) any later version.
18 *
19 * Catacomb is distributed in the hope that it will be useful,
20 * but WITHOUT ANY WARRANTY; without even the implied warranty of
21 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
22 * GNU Library General Public License for more details.
23 *
24 * You should have received a copy of the GNU Library General Public
25 * License along with Catacomb; if not, write to the Free
26 * Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
27 * MA 02111-1307, USA.
28 */
29
30 /*----- Revision history --------------------------------------------------*
31 *
32 * $Log: mpmont.c,v $
33 * Revision 1.5 1999/11/22 13:58:40 mdw
34 * Add an option to disable Montgomery reduction, so that performance
35 * comparisons can be done.
36 *
37 * Revision 1.4 1999/11/21 12:27:06 mdw
38 * Remove a division from the Montgomery setup by calculating
39 * %$R^2 \bmod m$% first and then %$R \bmod m$% by Montgomery reduction of
40 * %$R^2$%.
41 *
42 * Revision 1.3 1999/11/21 11:35:10 mdw
43 * Performance improvement: use @mp_sqr@ and @mpmont_reduce@ instead of
44 * @mpmont_mul@ for squaring in exponentiation.
45 *
46 * Revision 1.2 1999/11/19 13:17:26 mdw
47 * Add extra interface to exponentiation which returns a Montgomerized
48 * result.
49 *
50 * Revision 1.1 1999/11/17 18:02:16 mdw
51 * New multiprecision integer arithmetic suite.
52 *
53 */
54
55 /*----- Header files ------------------------------------------------------*/
56
57 #include "mp.h"
58 #include "mpmont.h"
59
60 /*----- Tweakables --------------------------------------------------------*/
61
62 /* --- @MPMONT_DISABLE@ --- *
63 *
64 * Replace all the clever Montgomery reduction with good old-fashioned long
65 * division.
66 */
67
68 /* #define MPMONT_DISABLE */
69
70 /*----- Main code ---------------------------------------------------------*/
71
72 /* --- @mpmont_create@ --- *
73 *
74 * Arguments: @mpmont *mm@ = pointer to Montgomery reduction context
75 * @mp *m@ = modulus to use
76 *
77 * Returns: ---
78 *
79 * Use: Initializes a Montgomery reduction context ready for use.
80 */
81
82 #ifdef MPMONT_DISABLE
83
84 void mpmont_create(mpmont *mm, mp *m)
85 {
86 mp_shrink(m);
87 mm->m = MP_COPY(m);
88 mm->r = MP_ONE;
89 mm->r2 = MP_ONE;
90 }
91
92 #else
93
94 void mpmont_create(mpmont *mm, mp *m)
95 {
96 /* --- Take a copy of the modulus --- */
97
98 mp_shrink(m);
99 mm->m = MP_COPY(m);
100
101 /* --- Find the magic value @mi@ --- *
102 *
103 * This is a slightly grungy way of solving the problem, but it does work.
104 */
105
106 {
107 mpw av[2] = { 0, 1 };
108 mp a, b;
109 mp *i;
110 mpw mi;
111
112 mp_build(&a, av, av + 2);
113 mp_build(&b, m->v, m->v + 1);
114 mp_gcd(0, 0, &i, &a, &b);
115 mi = i->v[0];
116 if (!(i->f & MP_NEG))
117 mi = MPW(-mi);
118 mm->mi = mi;
119 MP_DROP(i);
120 }
121
122 /* --- Discover the values %$R \bmod m$% and %$R^2 \bmod m$% --- */
123
124 {
125 size_t l = MP_LEN(m);
126 mp *r = mp_create(2 * l + 1);
127
128 mm->shift = l * MPW_BITS;
129 MPX_ZERO(r->v, r->vl - 1);
130 r->vl[-1] = 1;
131 mm->r2 = MP_NEW;
132 mp_div(0, &mm->r2, r, m);
133 mm->r = mpmont_reduce(mm, MP_NEW, mm->r2);
134 MP_DROP(r);
135 }
136 }
137
138 #endif
139
140 /* --- @mpmont_destroy@ --- *
141 *
142 * Arguments: @mpmont *mm@ = pointer to a Montgomery reduction context
143 *
144 * Returns: ---
145 *
146 * Use: Disposes of a context when it's no longer of any use to
147 * anyone.
148 */
149
150 void mpmont_destroy(mpmont *mm)
151 {
152 MP_DROP(mm->m);
153 MP_DROP(mm->r);
154 MP_DROP(mm->r2);
155 }
156
157 /* --- @mpmont_reduce@ --- *
158 *
159 * Arguments: @mpmont *mm@ = pointer to Montgomery reduction context
160 * @mp *d@ = destination
161 * @const mp *a@ = source, assumed positive
162 *
163 * Returns: Result, %$a R^{-1} \bmod m$%.
164 */
165
166 #ifdef MPMONT_DISABLE
167
168 mp *mpmont_reduce(mpmont *mm, mp *d, const mp *a)
169 {
170 mp_div(0, &d, a, mm->m);
171 return (d);
172 }
173
174 #else
175
176 mp *mpmont_reduce(mpmont *mm, mp *d, const mp *a)
177 {
178 mpw *dv, *dvl;
179 const mpw *mv, *mvl;
180 size_t n;
181
182 /* --- Initial conditioning of the arguments --- */
183
184 n = MP_LEN(mm->m);
185
186 if (d == a)
187 MP_MODIFY(d, 2 * n + 1);
188 else {
189 MP_MODIFY(d, 2 * n + 1);
190 memcpy(d->v, a->v, MPWS(MP_LEN(a)));
191 memset(d->v + MP_LEN(a), 0, MPWS(MP_LEN(d) - MP_LEN(a)));
192 }
193
194 dv = d->v; dvl = d->vl;
195 mv = mm->m->v; mvl = mm->m->vl;
196
197 /* --- Let's go to work --- */
198
199 while (n--) {
200 mpw u = MPW(*dv * mm->mi);
201 MPX_UMLAN(dv, dvl, mv, mvl, u);
202 dv++;
203 }
204
205 /* --- Done --- */
206
207 memmove(d->v, dv, MPWS(dvl - dv));
208 d->vl -= dv - d->v;
209 MP_SHRINK(d);
210 d->f = a->f & MP_BURN;
211 if (MP_CMP(d, >=, mm->m))
212 d = mp_sub(d, d, mm->m);
213 return (d);
214 }
215
216 #endif
217
218 /* --- @mpmont_mul@ --- *
219 *
220 * Arguments: @mpmont *mm@ = pointer to Montgomery reduction context
221 * @mp *d@ = destination
222 * @const mp *a, *b@ = sources, assumed positive
223 *
224 * Returns: Result, %$a b R^{-1} \bmod m$%.
225 */
226
227 #ifdef MPMONT_DISABLE
228
229 mp *mpmont_mul(mpmont *mm, mp *d, const mp *a, const mp *b)
230 {
231 d = mp_mul(d, a, b);
232 mp_div(0, &d, d, mm->m);
233 return (d);
234 }
235
236 #else
237
238 mp *mpmont_mul(mpmont *mm, mp *d, const mp *a, const mp *b)
239 {
240 mpw *dv, *dvl;
241 const mpw *av, *avl;
242 const mpw *bv, *bvl;
243 const mpw *mv, *mvl;
244 mpw y;
245 size_t n, i;
246
247 /* --- Initial conditioning of the arguments --- */
248
249 if (MP_LEN(a) > MP_LEN(b)) {
250 const mp *t = a; a = b; b = t;
251 }
252 n = MP_LEN(mm->m);
253
254 MP_MODIFY(d, 2 * n + 1);
255 dv = d->v; dvl = d->vl;
256 MPX_ZERO(dv, dvl);
257 av = a->v; avl = a->vl;
258 bv = b->v; bvl = b->vl;
259 mv = mm->m->v; mvl = mm->m->vl;
260 y = *bv;
261
262 /* --- Montgomery multiplication phase --- */
263
264 i = 0;
265 while (i < n && av < avl) {
266 mpw x = *av++;
267 mpw u = MPW((*dv + x * y) * mm->mi);
268 MPX_UMLAN(dv, dvl, bv, bvl, x);
269 MPX_UMLAN(dv, dvl, mv, mvl, u);
270 dv++;
271 i++;
272 }
273
274 /* --- Simpler Montgomery reduction phase --- */
275
276 while (i < n) {
277 mpw u = MPW(*dv * mm->mi);
278 MPX_UMLAN(dv, dvl, mv, mvl, u);
279 dv++;
280 i++;
281 }
282
283 /* --- Done --- */
284
285 memmove(d->v, dv, MPWS(dvl - dv));
286 d->vl -= dv - d->v;
287 MP_SHRINK(d);
288 d->f = (a->f | b->f) & MP_BURN;
289 if (MP_CMP(d, >=, mm->m))
290 d = mp_sub(d, d, mm->m);
291 return (d);
292 }
293
294 #endif
295
296 /* --- @mpmont_expr@ --- *
297 *
298 * Arguments: @mpmont *mm@ = pointer to Montgomery reduction context
299 * @const mp *a@ = base
300 * @const mp *e@ = exponent
301 *
302 * Returns: Result, %$a^e R \bmod m$%.
303 */
304
305 mp *mpmont_expr(mpmont *mm, const mp *a, const mp *e)
306 {
307 mpscan sc;
308 mp *ar = mpmont_mul(mm, MP_NEW, a, mm->r2);
309 mp *d = MP_COPY(mm->r);
310 mp *spare = MP_NEW;
311
312 mp_scan(&sc, e);
313
314 if (MP_STEP(&sc)) {
315 size_t sq = 0;
316 for (;;) {
317 mp *dd;
318 if (MP_BIT(&sc)) {
319 while (sq) {
320 dd = mp_sqr(spare, ar);
321 dd = mpmont_reduce(mm, dd, dd);
322 spare = ar; ar = dd;
323 sq--;
324 }
325 dd = mpmont_mul(mm, spare, d, ar);
326 spare = d; d = dd;
327 }
328 sq++;
329 if (!MP_STEP(&sc))
330 break;
331 }
332 }
333 MP_DROP(ar);
334 if (spare != MP_NEW)
335 MP_DROP(spare);
336 return (d);
337 }
338
339 /* --- @mpmont_exp@ --- *
340 *
341 * Arguments: @mpmont *mm@ = pointer to Montgomery reduction context
342 * @const mp *a@ = base
343 * @const mp *e@ = exponent
344 *
345 * Returns: Result, %$a^e \bmod m$%.
346 */
347
348 mp *mpmont_exp(mpmont *mm, const mp *a, const mp *e)
349 {
350 mp *d = mpmont_expr(mm, a, e);
351 d = mpmont_reduce(mm, d, d);
352 return (d);
353 }
354
355 /*----- Test rig ----------------------------------------------------------*/
356
357 #ifdef TEST_RIG
358
359 static int tcreate(dstr *v)
360 {
361 mp *m = *(mp **)v[0].buf;
362 mp *mi = *(mp **)v[1].buf;
363 mp *r = *(mp **)v[2].buf;
364 mp *r2 = *(mp **)v[3].buf;
365
366 mpmont mm;
367 int ok = 1;
368
369 mpmont_create(&mm, m);
370
371 if (mm.mi != mi->v[0]) {
372 fprintf(stderr, "\n*** bad mi: found %lu, expected %lu",
373 (unsigned long)mm.mi, (unsigned long)mi->v[0]);
374 fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
375 fputc('\n', stderr);
376 ok = 0;
377 }
378
379 if (MP_CMP(mm.r, !=, r)) {
380 fputs("\n*** bad r", stderr);
381 fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
382 fputs("\nexpected ", stderr); mp_writefile(r, stderr, 10);
383 fputs("\n found ", stderr); mp_writefile(mm.r, stderr, 10);
384 fputc('\n', stderr);
385 ok = 0;
386 }
387
388 if (MP_CMP(mm.r2, !=, r2)) {
389 fputs("\n*** bad r2", stderr);
390 fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
391 fputs("\nexpected ", stderr); mp_writefile(r2, stderr, 10);
392 fputs("\n found ", stderr); mp_writefile(mm.r2, stderr, 10);
393 fputc('\n', stderr);
394 ok = 0;
395 }
396
397 MP_DROP(m);
398 MP_DROP(mi);
399 MP_DROP(r);
400 MP_DROP(r2);
401 mpmont_destroy(&mm);
402 return (ok);
403 }
404
405 static int tmul(dstr *v)
406 {
407 mp *m = *(mp **)v[0].buf;
408 mp *a = *(mp **)v[1].buf;
409 mp *b = *(mp **)v[2].buf;
410 mp *r = *(mp **)v[3].buf;
411 int ok = 1;
412
413 mpmont mm;
414 mpmont_create(&mm, m);
415
416 {
417 mp *qr = mp_mul(MP_NEW, a, b);
418 mp_div(0, &qr, qr, m);
419
420 if (MP_CMP(qr, !=, r)) {
421 fputs("\n*** classical modmul failed", stderr);
422 fputs("\n m = ", stderr); mp_writefile(m, stderr, 10);
423 fputs("\n a = ", stderr); mp_writefile(a, stderr, 10);
424 fputs("\n b = ", stderr); mp_writefile(b, stderr, 10);
425 fputs("\n r = ", stderr); mp_writefile(r, stderr, 10);
426 fputs("\nqr = ", stderr); mp_writefile(qr, stderr, 10);
427 fputc('\n', stderr);
428 ok = 0;
429 }
430
431 mp_drop(qr);
432 }
433
434 {
435 mp *ar = mpmont_mul(&mm, MP_NEW, a, mm.r2);
436 mp *br = mpmont_mul(&mm, MP_NEW, b, mm.r2);
437 mp *mr = mpmont_mul(&mm, MP_NEW, ar, br);
438 mr = mpmont_reduce(&mm, mr, mr);
439 if (MP_CMP(mr, !=, r)) {
440 fputs("\n*** montgomery modmul failed", stderr);
441 fputs("\n m = ", stderr); mp_writefile(m, stderr, 10);
442 fputs("\n a = ", stderr); mp_writefile(a, stderr, 10);
443 fputs("\n b = ", stderr); mp_writefile(b, stderr, 10);
444 fputs("\n r = ", stderr); mp_writefile(r, stderr, 10);
445 fputs("\nmr = ", stderr); mp_writefile(mr, stderr, 10);
446 fputc('\n', stderr);
447 ok = 0;
448 }
449 MP_DROP(ar); MP_DROP(br);
450 mp_drop(mr);
451 }
452
453
454 MP_DROP(m);
455 MP_DROP(a);
456 MP_DROP(b);
457 MP_DROP(r);
458 mpmont_destroy(&mm);
459 return ok;
460 }
461
462 static int texp(dstr *v)
463 {
464 mp *m = *(mp **)v[0].buf;
465 mp *a = *(mp **)v[1].buf;
466 mp *b = *(mp **)v[2].buf;
467 mp *r = *(mp **)v[3].buf;
468 mp *mr;
469 int ok = 1;
470
471 mpmont mm;
472 mpmont_create(&mm, m);
473
474 mr = mpmont_exp(&mm, a, b);
475
476 if (MP_CMP(mr, !=, r)) {
477 fputs("\n*** montgomery modexp failed", stderr);
478 fputs("\n m = ", stderr); mp_writefile(m, stderr, 10);
479 fputs("\n a = ", stderr); mp_writefile(a, stderr, 10);
480 fputs("\n e = ", stderr); mp_writefile(b, stderr, 10);
481 fputs("\n r = ", stderr); mp_writefile(r, stderr, 10);
482 fputs("\nmr = ", stderr); mp_writefile(mr, stderr, 10);
483 fputc('\n', stderr);
484 ok = 0;
485 }
486
487 MP_DROP(m);
488 MP_DROP(a);
489 MP_DROP(b);
490 MP_DROP(r);
491 MP_DROP(mr);
492 mpmont_destroy(&mm);
493 return ok;
494 }
495
496
497 static test_chunk tests[] = {
498 { "create", tcreate, { &type_mp, &type_mp, &type_mp, &type_mp } },
499 { "mul", tmul, { &type_mp, &type_mp, &type_mp, &type_mp } },
500 { "exp", texp, { &type_mp, &type_mp, &type_mp, &type_mp } },
501 { 0, 0, { 0 } },
502 };
503
504 int main(int argc, char *argv[])
505 {
506 sub_init();
507 test_run(argc, argv, tests, SRCDIR "/tests/mpmont");
508 return (0);
509 }
510
511 #endif
512
513 /*----- That's all, folks -------------------------------------------------*/