/* -*-c-*-
*
- * $Id: mpmont.c,v 1.6 1999/12/10 23:18:39 mdw Exp $
+ * $Id: mpmont.c,v 1.11 2000/10/08 12:04:27 mdw Exp $
*
* Montgomery reduction
*
/*----- Revision history --------------------------------------------------*
*
* $Log: mpmont.c,v $
+ * Revision 1.11 2000/10/08 12:04:27 mdw
+ * (mpmont_reduce, mpmont_mul): Cope with negative numbers.
+ *
+ * Revision 1.10 2000/07/29 17:05:43 mdw
+ * (mpmont_expr): Use sliding window exponentiation, with a drop-through
+ * for small exponents to use a simple left-to-right bitwise routine. This
+ * can reduce modexp times by up to a quarter.
+ *
+ * Revision 1.9 2000/06/17 11:45:09 mdw
+ * Major memory management overhaul. Added arena support. Use the secure
+ * arena for secret integers. Replace and improve the MP management macros
+ * (e.g., replace MP_MODIFY by MP_DEST).
+ *
+ * Revision 1.8 1999/12/22 15:55:00 mdw
+ * Adjust Karatsuba parameters.
+ *
+ * Revision 1.7 1999/12/11 01:51:14 mdw
+ * Use a Karatsuba-based reduction for large moduli.
+ *
* Revision 1.6 1999/12/10 23:18:39 mdw
* Change interface for suggested destinations.
*
mm->m = MP_COPY(m);
mm->r = MP_ONE;
mm->r2 = MP_ONE;
+ mm->mi = MP_ONE;
}
#else
void mpmont_create(mpmont *mm, mp *m)
{
+ size_t n = MP_LEN(m);
+ mp *r2 = mp_new(2 * n + 1, 0);
+ mp r;
+
/* --- Validate the arguments --- */
assert(((void)"Montgomery modulus must be positive",
mp_shrink(m);
mm->m = MP_COPY(m);
- /* --- Find the magic value @mi@ --- *
- *
- * This is a slightly grungy way of solving the problem, but it does work.
- */
+ /* --- Determine %$R^2$% --- */
- {
- mpw av[2] = { 0, 1 };
- mp a, b;
- mp *i = MP_NEW;
- mpw mi;
+ mm->n = n;
+ MPX_ZERO(r2->v, r2->vl - 1);
+ r2->vl[-1] = 1;
- mp_build(&a, av, av + 2);
- mp_build(&b, m->v, m->v + 1);
- mp_gcd(0, 0, &i, &a, &b);
- mi = i->v[0];
- if (!(i->f & MP_NEG))
- mi = MPW(-mi);
- mm->mi = mi;
- MP_DROP(i);
- }
+ /* --- Find the magic value @mi@ --- */
+
+ mp_build(&r, r2->v + n, r2->vl);
+ mm->mi = MP_NEW;
+ mp_gcd(0, 0, &mm->mi, &r, m);
+ mm->mi = mp_sub(mm->mi, &r, mm->mi);
/* --- Discover the values %$R \bmod m$% and %$R^2 \bmod m$% --- */
- {
- size_t l = MP_LEN(m);
- mp *r = mp_create(2 * l + 1);
-
- mm->shift = l * MPW_BITS;
- MPX_ZERO(r->v, r->vl - 1);
- r->vl[-1] = 1;
- mm->r2 = MP_NEW;
- mp_div(0, &mm->r2, r, m);
- mm->r = mpmont_reduce(mm, MP_NEW, mm->r2);
- MP_DROP(r);
- }
+ mm->r2 = MP_NEW;
+ mp_div(0, &mm->r2, r2, m);
+ mm->r = mpmont_reduce(mm, MP_NEW, mm->r2);
+ MP_DROP(r2);
}
#endif
MP_DROP(mm->m);
MP_DROP(mm->r);
MP_DROP(mm->r2);
+ MP_DROP(mm->mi);
}
/* --- @mpmont_reduce@ --- *
mp *mpmont_reduce(mpmont *mm, mp *d, mp *a)
{
- mpw *dv, *dvl;
- mpw *mv, *mvl;
- size_t n;
-
- /* --- Initial conditioning of the arguments --- */
+ size_t n = mm->n;
+
+ /* --- Check for serious Karatsuba reduction --- *
+ *
+ * When the modulus is longer than @KARATSUBA_CUTOFF * 3@ words, the
+ * reduction is done with whole-number multiplies: take the low @n@ words
+ * of @a@ (or all of @a@ if it is shorter), multiply by @mm->mi@ and keep
+ * at most the low @n@ words of the product, multiply that by the modulus,
+ * and add the result to @a@. The low @n@ words of the sum are thrown
+ * away by the common wrap-up code at the end of the function.
+ */
+
+ if (n > KARATSUBA_CUTOFF * 3) {
+ mp al;
+ mpw *vl;
+ mp *u;
+
+ if (MP_LEN(a) >= n)
+ vl = a->v + n;
+ else
+ vl = a->vl;
+ mp_build(&al, a->v, vl);
+ u = mp_mul(MP_NEW, &al, mm->mi);
+ if (MP_LEN(u) > n)
+ u->vl = u->v + n;
+ u = mp_mul(u, u, mm->m);
+ d = mp_add(d, a, u);
+ mp_drop(u);
+ }
- n = MP_LEN(mm->m);
+ /* --- Otherwise do it the hard way --- *
+ *
+ * Word-at-a-time reduction. Each of the @n@ passes forms a multiplier
+ * @u@ from the current low word and the least significant word of
+ * @mm->mi@, hands it to @MPX_UMLAN@ together with the modulus, and then
+ * moves up one word. (Presumably @MPX_UMLAN@ multiplies the modulus by
+ * @u@ and accumulates into @dv@ -- confirm against its definition.)
+ */
- if (d == a)
- MP_MODIFY(d, 2 * n + 1);
else {
- MP_MODIFY(d, 2 * n + 1);
- memcpy(d->v, a->v, MPWS(MP_LEN(a)));
- memset(d->v + MP_LEN(a), 0, MPWS(MP_LEN(d) - MP_LEN(a)));
- }
-
- dv = d->v; dvl = d->vl;
- mv = mm->m->v; mvl = mm->m->vl;
+ mpw *dv, *dvl;
+ mpw *mv, *mvl;
+ mpw mi;
+ size_t k = n;
- /* --- Let's go to work --- */
+ /* --- Initial conditioning of the arguments --- *
+ *
+ * The caller's suggested destination is dropped, and the work is done
+ * through a fresh reference to @a@, which is passed to @MP_DEST@ with a
+ * size of @2 * n + 1@ words and @a@'s own flags. (Presumably @MP_DEST@
+ * yields a writable buffer of at least that size -- confirm.)
+ */
+
+ a = MP_COPY(a);
+ if (d)
+ MP_DROP(d);
+ d = a;
+ MP_DEST(d, 2 * n + 1, a->f);
- while (n--) {
- mpw u = MPW(*dv * mm->mi);
- MPX_UMLAN(dv, dvl, mv, mvl, u);
- dv++;
+ dv = d->v; dvl = d->vl;
+ mv = mm->m->v; mvl = mm->m->vl;
+
+ /* --- Let's go to work --- */
+
+ mi = mm->mi->v[0];
+ while (k--) {
+ mpw u = MPW(*dv * mi);
+ MPX_UMLAN(dv, dvl, mv, mvl, u);
+ dv++;
+ }
}
- /* --- Done --- */
+ /* --- Wrap everything up --- *
+ *
+ * Discard the low @n@ words by sliding the remaining words down. If what
+ * is left is at least the modulus (compared as unsigned magnitudes),
+ * subtract the modulus once. If the value is flagged negative, replace
+ * it by the modulus minus its magnitude and clear the sign flag.
+ *
+ * NOTE(review): @n@ is a @size_t@, so @MP_LEN(d) - n@ would wrap if @d@
+ * had fewer than @n@ words at this point. Worth confirming that this
+ * cannot happen on the Karatsuba branch when @a@ is short.
+ */
- memmove(d->v, dv, MPWS(dvl - dv));
- d->vl -= dv - d->v;
+ memmove(d->v, d->v + n, MPWS(MP_LEN(d) - n));
+ d->vl -= n;
+ if (MPX_UCMP(d->v, d->vl, >=, mm->m->v, mm->m->vl))
+ mpx_usub(d->v, d->vl, d->v, d->vl, mm->m->v, mm->m->vl);
+ if (d->f & MP_NEG) {
+ mpx_usub(d->v, d->vl, mm->m->v, mm->m->vl, d->v, d->vl);
+ d->f &= ~MP_NEG;
+ }
MP_SHRINK(d);
- d->f = a->f & MP_BURN;
- if (MP_CMP(d, >=, mm->m))
- d = mp_sub(d, d, mm->m);
return (d);
}
mp *mpmont_mul(mpmont *mm, mp *d, mp *a, mp *b)
{
- if (MP_LEN(a) > KARATSUBA_CUTOFF && MP_LEN(b) > KARATSUBA_CUTOFF) {
+ if (mm->n > KARATSUBA_CUTOFF * 3) {
d = mp_mul(d, a, b);
d = mpmont_reduce(mm, d, d);
} else {
mpw *mv, *mvl;
mpw y;
size_t n, i;
+ mpw mi;
/* --- Initial conditioning of the arguments --- */
a = MP_COPY(a);
b = MP_COPY(b);
- MP_MODIFY(d, 2 * n + 1);
+ MP_DEST(d, 2 * n + 1, a->f | b->f | MP_UNDEF);
dv = d->v; dvl = d->vl;
MPX_ZERO(dv, dvl);
av = a->v; avl = a->vl;
/* --- Montgomery multiplication phase --- */
i = 0;
+ mi = mm->mi->v[0];
while (i < n && av < avl) {
mpw x = *av++;
- mpw u = MPW((*dv + x * y) * mm->mi);
+ mpw u = MPW((*dv + x * y) * mi);
MPX_UMLAN(dv, dvl, bv, bvl, x);
MPX_UMLAN(dv, dvl, mv, mvl, u);
dv++;
/* --- Simpler Montgomery reduction phase --- */
while (i < n) {
- mpw u = MPW(*dv * mm->mi);
+ mpw u = MPW(*dv * mi);
MPX_UMLAN(dv, dvl, mv, mvl, u);
dv++;
i++;
memmove(d->v, dv, MPWS(dvl - dv));
d->vl -= dv - d->v;
+ if (MPX_UCMP(d->v, d->vl, >=, mm->m->v, mm->m->vl))
+ mpx_usub(d->v, d->vl, d->v, d->vl, mm->m->v, mm->m->vl);
+ if ((a->f ^ b->f) & MP_NEG)
+ mpx_usub(d->v, d->vl, mm->m->v, mm->m->vl, d->v, d->vl);
MP_SHRINK(d);
d->f = (a->f | b->f) & MP_BURN;
- if (MP_CMP(d, >=, mm->m))
- d = mp_sub(d, d, mm->m);
MP_DROP(a);
MP_DROP(b);
}
* Returns: Result, %$a^e R \bmod m$%.
*/
-mp *mpmont_expr(mpmont *mm, mp *d, mp *a, mp *e)
+#define WINSZ 5
+#define TABSZ (1 << (WINSZ - 1))
+
+#define THRESH (((MPW_BITS / WINSZ) << 2) + 1)
+
+static mp *exp_simple(mpmont *mm, mp *d, mp *a, mp *e)
{
mpscan sc;
- mp *ar = mpmont_mul(mm, MP_NEW, a, mm->r2);
+ mp *ar;
mp *x = MP_COPY(mm->r);
- mp *spare = MP_NEW;
-
- mp_scan(&sc, e);
-
- if (MP_STEP(&sc)) {
- size_t sq = 0;
+ mp *spare = (e->f & MP_BURN) ? MP_NEWSEC : MP_NEW;
+ unsigned sq = 0; /* squarings still owed to @x@ */
+
+ mp_rscan(&sc, e);
+ if (!MP_RSTEP(&sc))
+ goto exit; /* nothing to scan: result stays at @mm->r@ */
+ while (!MP_RBIT(&sc)) /* step until a set bit is found */
+ MP_RSTEP(&sc); /* NOTE(review): step result unchecked; relies on @e@ having a set bit */
+
+ /* --- Do the main body of the work --- *
+ *
+ * At the top of each pass the scanner is sitting on a set bit. @sq@
+ * counts the squarings owed: one for this bit, plus one for each clear
+ * bit stepped over since the previous set bit. Once they are done,
+ * multiply by @ar@, then step along counting clear bits until the next
+ * set bit turns up or the exponent runs out.
+ *
+ * Every square and multiply passes @spare@ as the destination and then
+ * swaps, so the old @x@ becomes the next @spare@.
+ */
+
+ ar = mpmont_mul(mm, MP_NEW, a, mm->r2);
+ for (;;) {
+ sq++;
+ while (sq) {
+ mp *y;
+ y = mp_sqr(spare, x);
+ y = mpmont_reduce(mm, y, y);
+ spare = x; x = y;
+ sq--;
+ }
+ { mp *y = mpmont_mul(mm, spare, x, ar); spare = x; x = y; }
+ sq = 0;
for (;;) {
- mp *dd;
- if (MP_BIT(&sc)) {
- while (sq) {
- dd = mp_sqr(spare, ar);
- dd = mpmont_reduce(mm, dd, dd);
- spare = ar; ar = dd;
- sq--;
- }
- dd = mpmont_mul(mm, spare, x, ar);
- spare = x; x = dd;
- }
- sq++;
- if (!MP_STEP(&sc))
+ if (!MP_RSTEP(&sc))
+ goto done;
+ if (MP_RBIT(&sc))
break;
+ sq++;
}
}
+
+ /* --- Do a final round of squaring --- *
+ *
+ * Clear bits counted after the last set bit have not been squared for
+ * yet; each of them still needs one squaring.
+ */
+
+done:
+ while (sq) {
+ mp *y;
+ y = mp_sqr(spare, x);
+ y = mpmont_reduce(mm, y, y);
+ spare = x; x = y;
+ sq--;
+ }
+
+ /* --- Done --- *
+ *
+ * @ar@ is only assigned once the main body has been entered, which is
+ * why the early @exit@ path lands after its release.
+ */
+
MP_DROP(ar);
+exit:
if (spare != MP_NEW)
MP_DROP(spare);
if (d != MP_NEW) /* NOTE(review): as excerpted this test governs the return below; a statement releasing @d@ appears to be missing from view -- confirm against the full file */
return (x);
}
+mp *mpmont_expr(mpmont *mm, mp *d, mp *a, mp *e)
+{
+ mp **tab; /* table of odd powers of @ar@ */
+ mp *ar, *a2;
+ mp *spare = (e->f & MP_BURN) ? MP_NEWSEC : MP_NEW;
+ mp *x = MP_COPY(mm->r); /* running result, starts as @mm->r@ */
+ unsigned i, sq = 0; /* table index; squarings still owed */
+ mpscan sc;
+
+ /* --- Do we bother? --- *
+ *
+ * A zero-length exponent (after shrinking) returns @x@ exactly as
+ * initialized, i.e. a reference to @mm->r@. Exponents shorter than
+ * @THRESH@ words are handed to @exp_simple@; the reference just taken on
+ * @mm->r@ is undone by hand first, because @exp_simple@ takes its own.
+ */
+
+ MP_SHRINK(e);
+ if (MP_LEN(e) == 0)
+ goto exit;
+ if (MP_LEN(e) < THRESH) {
+ x->ref--;
+ return (exp_simple(mm, d, a, e));
+ }
+
+ /* --- Do the precomputation --- *
+ *
+ * @ar@ is the @mpmont_mul@ product of @a@ and @mm->r2@, and @a2@ is its
+ * reduced square. Each table entry is the previous one times @a2@, so
+ * @tab[i]@ stands for the %$(2i + 1)$%th power of @ar@, and the table
+ * holds @TABSZ@ odd powers in all. @a2@ is only needed to build the
+ * table and is dropped straight afterwards.
+ */
+
+ ar = mpmont_mul(mm, MP_NEW, a, mm->r2);
+ a2 = mp_sqr(MP_NEW, ar);
+ a2 = mpmont_reduce(mm, a2, a2);
+ tab = xmalloc(TABSZ * sizeof(mp *));
+ tab[0] = ar;
+ for (i = 1; i < TABSZ; i++)
+ tab[i] = mpmont_mul(mm, MP_NEW, tab[i - 1], a2);
+ mp_drop(a2);
+ mp_rscan(&sc, e);
+
+ /* --- Skip top-end zero bits --- *
+ *
+ * If the initial step worked, there must be a set bit somewhere, so keep
+ * stepping until I find it.
+ *
+ * NOTE(review): the result of the initial step is not tested here; the
+ * code relies on the shrink and zero-length check above to guarantee that
+ * @e@ has a set bit.
+ */
+
+ MP_RSTEP(&sc);
+ while (!MP_RBIT(&sc)) {
+ MP_RSTEP(&sc);
+ }
+
+ /* --- Now for the main work --- */
+
+ for (;;) {
+ unsigned l = 0; /* bits consumed by this window so far */
+ unsigned z = 0; /* clear bits read since the last set bit */
+
+ /* --- The next bit is set, so read a window index --- *
+ *
+ * Reset @i@ to zero and increment @sq@. Then, until either I read
+ * @WINSZ@ bits or I run out of bits, scan in a bit: if it's clear, bump
+ * the @z@ counter; if it's set, push a set bit into @i@, shift it over
+ * by @z@ bits, bump @sq@ by @z + 1@ and clear @z@. By the end of this
+ * palaver, @i@ is an index to the precomputed value in @tab@.
+ */
+
+ i = 0;
+ sq++;
+ for (;;) {
+ l++;
+ if (l >= WINSZ || !MP_RSTEP(&sc))
+ break;
+ if (!MP_RBIT(&sc))
+ z++;
+ else {
+ i = ((i << 1) | 1) << z;
+ sq += z + 1;
+ z = 0;
+ }
+ }
+
+ /* --- Do the squaring --- *
+ *
+ * Remember that @sq@ carries over from the zero-skipping stuff below.
+ */
+
+ while (sq) {
+ mp *y;
+ y = mp_sqr(spare, x);
+ y = mpmont_reduce(mm, y, y);
+ spare = x; x = y;
+ sq--;
+ }
+
+ /* --- Do the multiply --- *
+ *
+ * Multiply by the table entry the window selected. As with the
+ * squarings, @spare@ is passed as the destination and then swapped, so
+ * the old @x@ becomes the next @spare@.
+ */
+
+ { mp *y = mpmont_mul(mm, spare, x, tab[i]); spare = x; x = y; }
+
+ /* --- Now grind along through the rest of the bits --- *
+ *
+ * @sq@ restarts at @z@: the clear bits read at the tail of the window
+ * after its last set bit, which have not been squared for yet. Each
+ * further clear bit adds one more, until the next set bit starts a new
+ * window or the exponent runs out.
+ */
+
+ sq = z;
+ for (;;) {
+ if (!MP_RSTEP(&sc))
+ goto done;
+ if (MP_RBIT(&sc))
+ break;
+ sq++;
+ }
+ }
+
+ /* --- Do a final round of squaring --- *
+ *
+ * Squarings still owed for clear bits at the bottom of the exponent.
+ */
+
+done:
+ while (sq) {
+ mp *y;
+ y = mp_sqr(spare, x);
+ y = mpmont_reduce(mm, y, y);
+ spare = x; x = y;
+ sq--;
+ }
+
+ /* --- Done --- *
+ *
+ * Dropping every table entry also releases @ar@, which is @tab[0]@. The
+ * suggested destination @d@ is never used for the result in this
+ * function; it is simply released if one was supplied, and @x@ is
+ * returned.
+ */
+
+ for (i = 0; i < TABSZ; i++)
+ mp_drop(tab[i]);
+ xfree(tab);
+exit:
+ if (d != MP_NEW)
+ mp_drop(d);
+ if (spare)
+ mp_drop(spare);
+ return (x);
+}
+
/* --- @mpmont_exp@ --- *
*
* Arguments: @mpmont *mm@ = pointer to Montgomery reduction context
mpmont_create(&mm, m);
- if (mm.mi != mi->v[0]) {
+ if (mm.mi->v[0] != mi->v[0]) {
fprintf(stderr, "\n*** bad mi: found %lu, expected %lu",
- (unsigned long)mm.mi, (unsigned long)mi->v[0]);
+ (unsigned long)mm.mi->v[0], (unsigned long)mi->v[0]);
fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
fputc('\n', stderr);
ok = 0;
}
- if (MP_CMP(mm.r, !=, r)) {
+ if (!MP_EQ(mm.r, r)) {
fputs("\n*** bad r", stderr);
fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
fputs("\nexpected ", stderr); mp_writefile(r, stderr, 10);
ok = 0;
}
- if (MP_CMP(mm.r2, !=, r2)) {
+ if (!MP_EQ(mm.r2, r2)) {
fputs("\n*** bad r2", stderr);
fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
fputs("\nexpected ", stderr); mp_writefile(r2, stderr, 10);
mp *qr = mp_mul(MP_NEW, a, b);
mp_div(0, &qr, qr, m);
- if (MP_CMP(qr, !=, r)) {
+ if (!MP_EQ(qr, r)) {
fputs("\n*** classical modmul failed", stderr);
fputs("\n m = ", stderr); mp_writefile(m, stderr, 10);
fputs("\n a = ", stderr); mp_writefile(a, stderr, 10);
mp *br = mpmont_mul(&mm, MP_NEW, b, mm.r2);
mp *mr = mpmont_mul(&mm, MP_NEW, ar, br);
mr = mpmont_reduce(&mm, mr, mr);
- if (MP_CMP(mr, !=, r)) {
+ if (!MP_EQ(mr, r)) {
fputs("\n*** montgomery modmul failed", stderr);
fputs("\n m = ", stderr); mp_writefile(m, stderr, 10);
fputs("\n a = ", stderr); mp_writefile(a, stderr, 10);
mr = mpmont_exp(&mm, MP_NEW, a, b);
- if (MP_CMP(mr, !=, r)) {
+ if (!MP_EQ(mr, r)) {
fputs("\n*** montgomery modexp failed", stderr);
fputs("\n m = ", stderr); mp_writefile(m, stderr, 10);
fputs("\n a = ", stderr); mp_writefile(a, stderr, 10);