66b7657bd581c7850891d0902355a62bfde81724
[u/mdw/catacomb] / mpmont.c
1 /* -*-c-*-
2 *
3 * $Id: mpmont.c,v 1.3 1999/11/21 11:35:10 mdw Exp $
4 *
5 * Montgomery reduction
6 *
7 * (c) 1999 Straylight/Edgeware
8 */
9
10 /*----- Licensing notice --------------------------------------------------*
11 *
12 * This file is part of Catacomb.
13 *
14 * Catacomb is free software; you can redistribute it and/or modify
15 * it under the terms of the GNU Library General Public License as
16 * published by the Free Software Foundation; either version 2 of the
17 * License, or (at your option) any later version.
18 *
19 * Catacomb is distributed in the hope that it will be useful,
20 * but WITHOUT ANY WARRANTY; without even the implied warranty of
21 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
22 * GNU Library General Public License for more details.
23 *
24 * You should have received a copy of the GNU Library General Public
25 * License along with Catacomb; if not, write to the Free
26 * Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
27 * MA 02111-1307, USA.
28 */
29
30 /*----- Revision history --------------------------------------------------*
31 *
32 * $Log: mpmont.c,v $
33 * Revision 1.3 1999/11/21 11:35:10 mdw
34 * Performance improvement: use @mp_sqr@ and @mpmont_reduce@ instead of
35 * @mpmont_mul@ for squaring in exponentiation.
36 *
37 * Revision 1.2 1999/11/19 13:17:26 mdw
38 * Add extra interface to exponentiation which returns a Montgomerized
39 * result.
40 *
41 * Revision 1.1 1999/11/17 18:02:16 mdw
42 * New multiprecision integer arithmetic suite.
43 *
44 */
45
46 /*----- Header files ------------------------------------------------------*/
47
48 #include "mp.h"
49 #include "mpmont.h"
50
51 /*----- Main code ---------------------------------------------------------*/
52
53 /* --- @mpmont_create@ --- *
54 *
55 * Arguments: @mpmont *mm@ = pointer to Montgomery reduction context
56 * @mp *m@ = modulus to use
57 *
58 * Returns: ---
59 *
60 * Use: Initializes a Montgomery reduction context ready for use.
61 */
62
63 void mpmont_create(mpmont *mm, mp *m)
64 {
65 /* --- Take a copy of the modulus --- */
66
67 mp_shrink(m);
68 mm->m = MP_COPY(m);
69
70 /* --- Find the magic value @mi@ --- *
71 *
72 * This is a slightly grungy way of solving the problem, but it does work.
73 */
74
75 {
76 mpw av[2] = { 0, 1 };
77 mp a, b;
78 mp *i;
79 mpw mi;
80
81 mp_build(&a, av, av + 2);
82 mp_build(&b, m->v, m->v + 1);
83 mp_gcd(0, 0, &i, &a, &b);
84 mi = i->v[0];
85 if (!(i->f & MP_NEG))
86 mi = MPW(-mi);
87 mm->mi = mi;
88 MP_DROP(i);
89 }
90
91 /* --- Discover the values %$R \bmod m$% and %$R^2 \bmod m$% --- */
92
93 {
94 size_t l = MP_LEN(m);
95 mp *r = mp_create(l + 1);
96
97 mm->shift = l * MPW_BITS;
98 MPX_ZERO(r->v, r->vl - 1);
99 r->vl[-1] = 1;
100 mm->r = mm->r2 = MP_NEW;
101 mp_div(0, &mm->r, r, m);
102 r = mp_sqr(r, mm->r);
103 mp_div(0, &mm->r2, r, m);
104 MP_DROP(r);
105 }
106 }
107
108 /* --- @mpmont_destroy@ --- *
109 *
110 * Arguments: @mpmont *mm@ = pointer to a Montgomery reduction context
111 *
112 * Returns: ---
113 *
114 * Use: Disposes of a context when it's no longer of any use to
115 * anyone.
116 */
117
118 void mpmont_destroy(mpmont *mm)
119 {
120 MP_DROP(mm->m);
121 MP_DROP(mm->r);
122 MP_DROP(mm->r2);
123 }
124
125 /* --- @mpmont_reduce@ --- *
126 *
127 * Arguments: @mpmont *mm@ = pointer to Montgomery reduction context
128 * @mp *d@ = destination
129 * @const mp *a@ = source, assumed positive
130 *
131 * Returns: Result, %$a R^{-1} \bmod m$%.
132 */
133
134 mp *mpmont_reduce(mpmont *mm, mp *d, const mp *a)
135 {
136 mpw *dv, *dvl;
137 const mpw *mv, *mvl;
138 size_t n;
139
140 /* --- Initial conditioning of the arguments --- */
141
142 n = MP_LEN(mm->m);
143
144 if (d == a)
145 MP_MODIFY(d, 2 * n + 1);
146 else {
147 MP_MODIFY(d, 2 * n + 1);
148 memcpy(d->v, a->v, MPWS(MP_LEN(a)));
149 memset(d->v + MP_LEN(a), 0, MPWS(MP_LEN(d) - MP_LEN(a)));
150 }
151
152 dv = d->v; dvl = d->vl;
153 mv = mm->m->v; mvl = mm->m->vl;
154
155 /* --- Let's go to work --- */
156
157 while (n--) {
158 mpw u = MPW(*dv * mm->mi);
159 MPX_UMLAN(dv, dvl, mv, mvl, u);
160 dv++;
161 }
162
163 /* --- Done --- */
164
165 memmove(d->v, dv, MPWS(dvl - dv));
166 d->vl -= dv - d->v;
167 MP_SHRINK(d);
168 d->f = a->f & MP_BURN;
169 if (MP_CMP(d, >=, mm->m))
170 d = mp_sub(d, d, mm->m);
171 return (d);
172 }
173
174 /* --- @mpmont_mul@ --- *
175 *
176 * Arguments: @mpmont *mm@ = pointer to Montgomery reduction context
177 * @mp *d@ = destination
178 * @const mp *a, *b@ = sources, assumed positive
179 *
180 * Returns: Result, %$a b R^{-1} \bmod m$%.
181 */
182
183 mp *mpmont_mul(mpmont *mm, mp *d, const mp *a, const mp *b)
184 {
185 mpw *dv, *dvl;
186 const mpw *av, *avl;
187 const mpw *bv, *bvl;
188 const mpw *mv, *mvl;
189 mpw y;
190 size_t n, i;
191
192 /* --- Initial conditioning of the arguments --- */
193
194 if (MP_LEN(a) > MP_LEN(b)) {
195 const mp *t = a; a = b; b = t;
196 }
197 n = MP_LEN(mm->m);
198
199 MP_MODIFY(d, 2 * n + 1);
200 dv = d->v; dvl = d->vl;
201 MPX_ZERO(dv, dvl);
202 av = a->v; avl = a->vl;
203 bv = b->v; bvl = b->vl;
204 mv = mm->m->v; mvl = mm->m->vl;
205 y = *bv;
206
207 /* --- Montgomery multiplication phase --- */
208
209 i = 0;
210 while (i < n && av < avl) {
211 mpw x = *av++;
212 mpw u = MPW((*dv + x * y) * mm->mi);
213 MPX_UMLAN(dv, dvl, bv, bvl, x);
214 MPX_UMLAN(dv, dvl, mv, mvl, u);
215 dv++;
216 i++;
217 }
218
219 /* --- Simpler Montgomery reduction phase --- */
220
221 while (i < n) {
222 mpw u = MPW(*dv * mm->mi);
223 MPX_UMLAN(dv, dvl, mv, mvl, u);
224 dv++;
225 i++;
226 }
227
228 /* --- Done --- */
229
230 memmove(d->v, dv, MPWS(dvl - dv));
231 d->vl -= dv - d->v;
232 MP_SHRINK(d);
233 d->f = (a->f | b->f) & MP_BURN;
234 if (MP_CMP(d, >=, mm->m))
235 d = mp_sub(d, d, mm->m);
236 return (d);
237 }
238
239 /* --- @mpmont_expr@ --- *
240 *
241 * Arguments: @mpmont *mm@ = pointer to Montgomery reduction context
242 * @const mp *a@ = base
243 * @const mp *e@ = exponent
244 *
245 * Returns: Result, %$a^e R \bmod m$%.
246 */
247
248 mp *mpmont_expr(mpmont *mm, const mp *a, const mp *e)
249 {
250 mpscan sc;
251 mp *ar = mpmont_mul(mm, MP_NEW, a, mm->r2);
252 mp *d = MP_COPY(mm->r);
253 mp *spare = MP_NEW;
254
255 mp_scan(&sc, e);
256
257 if (MP_STEP(&sc)) {
258 size_t sq = 0;
259 for (;;) {
260 mp *dd;
261 if (MP_BIT(&sc)) {
262 while (sq) {
263 dd = mp_sqr(spare, ar);
264 dd = mpmont_reduce(mm, dd, dd);
265 spare = ar; ar = dd;
266 sq--;
267 }
268 dd = mpmont_mul(mm, spare, d, ar);
269 spare = d; d = dd;
270 }
271 sq++;
272 if (!MP_STEP(&sc))
273 break;
274 }
275 }
276 MP_DROP(ar);
277 if (spare != MP_NEW)
278 MP_DROP(spare);
279 return (d);
280 }
281
282 /* --- @mpmont_exp@ --- *
283 *
284 * Arguments: @mpmont *mm@ = pointer to Montgomery reduction context
285 * @const mp *a@ = base
286 * @const mp *e@ = exponent
287 *
288 * Returns: Result, %$a^e \bmod m$%.
289 */
290
291 mp *mpmont_exp(mpmont *mm, const mp *a, const mp *e)
292 {
293 mp *d = mpmont_expr(mm, a, e);
294 d = mpmont_reduce(mm, d, d);
295 return (d);
296 }
297
298 /*----- Test rig ----------------------------------------------------------*/
299
300 #ifdef TEST_RIG
301
302 static int tcreate(dstr *v)
303 {
304 mp *m = *(mp **)v[0].buf;
305 mp *mi = *(mp **)v[1].buf;
306 mp *r = *(mp **)v[2].buf;
307 mp *r2 = *(mp **)v[3].buf;
308
309 mpmont mm;
310 int ok = 1;
311
312 mpmont_create(&mm, m);
313
314 if (mm.mi != mi->v[0]) {
315 fprintf(stderr, "\n*** bad mi: found %lu, expected %lu",
316 (unsigned long)mm.mi, (unsigned long)mi->v[0]);
317 fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
318 fputc('\n', stderr);
319 ok = 0;
320 }
321
322 if (MP_CMP(mm.r, !=, r)) {
323 fputs("\n*** bad r", stderr);
324 fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
325 fputs("\nexpected ", stderr); mp_writefile(r, stderr, 10);
326 fputs("\n found ", stderr); mp_writefile(mm.r, stderr, 10);
327 fputc('\n', stderr);
328 ok = 0;
329 }
330
331 if (MP_CMP(mm.r2, !=, r2)) {
332 fputs("\n*** bad r2", stderr);
333 fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
334 fputs("\nexpected ", stderr); mp_writefile(r2, stderr, 10);
335 fputs("\n found ", stderr); mp_writefile(mm.r2, stderr, 10);
336 fputc('\n', stderr);
337 ok = 0;
338 }
339
340 MP_DROP(m);
341 MP_DROP(mi);
342 MP_DROP(r);
343 MP_DROP(r2);
344 mpmont_destroy(&mm);
345 return (ok);
346 }
347
348 static int tmul(dstr *v)
349 {
350 mp *m = *(mp **)v[0].buf;
351 mp *a = *(mp **)v[1].buf;
352 mp *b = *(mp **)v[2].buf;
353 mp *r = *(mp **)v[3].buf;
354 int ok = 1;
355
356 mpmont mm;
357 mpmont_create(&mm, m);
358
359 {
360 mp *qr = mp_mul(MP_NEW, a, b);
361 mp_div(0, &qr, qr, m);
362
363 if (MP_CMP(qr, !=, r)) {
364 fputs("\n*** classical modmul failed", stderr);
365 fputs("\n m = ", stderr); mp_writefile(m, stderr, 10);
366 fputs("\n a = ", stderr); mp_writefile(a, stderr, 10);
367 fputs("\n b = ", stderr); mp_writefile(b, stderr, 10);
368 fputs("\n r = ", stderr); mp_writefile(r, stderr, 10);
369 fputs("\nqr = ", stderr); mp_writefile(qr, stderr, 10);
370 fputc('\n', stderr);
371 ok = 0;
372 }
373
374 mp_drop(qr);
375 }
376
377 {
378 mp *ar = mpmont_mul(&mm, MP_NEW, a, mm.r2);
379 mp *br = mpmont_mul(&mm, MP_NEW, b, mm.r2);
380 mp *mr = mpmont_mul(&mm, MP_NEW, ar, br);
381 mr = mpmont_reduce(&mm, mr, mr);
382 if (MP_CMP(mr, !=, r)) {
383 fputs("\n*** montgomery modmul failed", stderr);
384 fputs("\n m = ", stderr); mp_writefile(m, stderr, 10);
385 fputs("\n a = ", stderr); mp_writefile(a, stderr, 10);
386 fputs("\n b = ", stderr); mp_writefile(b, stderr, 10);
387 fputs("\n r = ", stderr); mp_writefile(r, stderr, 10);
388 fputs("\nmr = ", stderr); mp_writefile(mr, stderr, 10);
389 fputc('\n', stderr);
390 ok = 0;
391 }
392 MP_DROP(ar); MP_DROP(br);
393 mp_drop(mr);
394 }
395
396
397 MP_DROP(m);
398 MP_DROP(a);
399 MP_DROP(b);
400 MP_DROP(r);
401 mpmont_destroy(&mm);
402 return ok;
403 }
404
405 static int texp(dstr *v)
406 {
407 mp *m = *(mp **)v[0].buf;
408 mp *a = *(mp **)v[1].buf;
409 mp *b = *(mp **)v[2].buf;
410 mp *r = *(mp **)v[3].buf;
411 mp *mr;
412 int ok = 1;
413
414 mpmont mm;
415 mpmont_create(&mm, m);
416
417 mr = mpmont_exp(&mm, a, b);
418
419 if (MP_CMP(mr, !=, r)) {
420 fputs("\n*** montgomery modexp failed", stderr);
421 fputs("\n m = ", stderr); mp_writefile(m, stderr, 10);
422 fputs("\n a = ", stderr); mp_writefile(a, stderr, 10);
423 fputs("\n e = ", stderr); mp_writefile(b, stderr, 10);
424 fputs("\n r = ", stderr); mp_writefile(r, stderr, 10);
425 fputs("\nmr = ", stderr); mp_writefile(mr, stderr, 10);
426 fputc('\n', stderr);
427 ok = 0;
428 }
429
430 MP_DROP(m);
431 MP_DROP(a);
432 MP_DROP(b);
433 MP_DROP(r);
434 MP_DROP(mr);
435 mpmont_destroy(&mm);
436 return ok;
437 }
438
439
440 static test_chunk tests[] = {
441 { "create", tcreate, { &type_mp, &type_mp, &type_mp, &type_mp } },
442 { "mul", tmul, { &type_mp, &type_mp, &type_mp, &type_mp } },
443 { "exp", texp, { &type_mp, &type_mp, &type_mp, &type_mp } },
444 { 0, 0, { 0 } },
445 };
446
447 int main(int argc, char *argv[])
448 {
449 sub_init();
450 test_run(argc, argv, tests, SRCDIR "/tests/mpmont");
451 return (0);
452 }
453
454 #endif
455
456 /*----- That's all, folks -------------------------------------------------*/