Use a Karatsuba-based reduction for large moduli.
authormdw <mdw>
Sat, 11 Dec 1999 01:51:14 +0000 (01:51 +0000)
committermdw <mdw>
Sat, 11 Dec 1999 01:51:14 +0000 (01:51 +0000)
mpmont.c
mpmont.h

index d061934..a5b5e66 100644 (file)
--- a/mpmont.c
+++ b/mpmont.c
@@ -1,6 +1,6 @@
 /* -*-c-*-
  *
- * $Id: mpmont.c,v 1.6 1999/12/10 23:18:39 mdw Exp $
+ * $Id: mpmont.c,v 1.7 1999/12/11 01:51:14 mdw Exp $
  *
  * Montgomery reduction
  *
@@ -30,6 +30,9 @@
 /*----- Revision history --------------------------------------------------* 
  *
  * $Log: mpmont.c,v $
+ * Revision 1.7  1999/12/11 01:51:14  mdw
+ * Use a Karatsuba-based reduction for large moduli.
+ *
  * Revision 1.6  1999/12/10 23:18:39  mdw
  * Change interface for suggested destinations.
  *
@@ -91,12 +94,17 @@ void mpmont_create(mpmont *mm, mp *m)
   mm->m = MP_COPY(m);
   mm->r = MP_ONE;
   mm->r2 = MP_ONE;
+  mm->mi = MP_ONE;
 }
 
 #else
 
 void mpmont_create(mpmont *mm, mp *m)
 {
+  size_t n = MP_LEN(m);
+  mp *r2 = mp_create(2 * n + 1);
+  mp r;
+
   /* --- Validate the arguments --- */
 
   assert(((void)"Montgomery modulus must be positive",
@@ -108,41 +116,25 @@ void mpmont_create(mpmont *mm, mp *m)
   mp_shrink(m);
   mm->m = MP_COPY(m);
 
-  /* --- Find the magic value @mi@ --- *
-   *
-   * This is a slightly grungy way of solving the problem, but it does work.
-   */
+  /* --- Determine %$R^2$% --- */
 
-  {
-    mpw av[2] = { 0, 1 };
-    mp a, b;
-    mp *i = MP_NEW;
-    mpw mi;
+  mm->n = n;
+  MPX_ZERO(r2->v, r2->vl - 1);
+  r2->vl[-1] = 1;
 
-    mp_build(&a, av, av + 2);
-    mp_build(&b, m->v, m->v + 1);
-    mp_gcd(0, 0, &i, &a, &b);
-    mi = i->v[0];
-    if (!(i->f & MP_NEG))
-      mi = MPW(-mi);
-    mm->mi = mi;
-    MP_DROP(i);
-  }
+  /* --- Find the magic value @mi@ --- */
+
+  mp_build(&r, r2->v + n, r2->vl);
+  mm->mi = MP_NEW;
+  mp_gcd(0, 0, &mm->mi, &r, m);
+  mm->mi = mp_sub(mm->mi, &r, mm->mi);
 
   /* --- Discover the values %$R \bmod m$% and %$R^2 \bmod m$% --- */
 
-  {
-    size_t l = MP_LEN(m);
-    mp *r = mp_create(2 * l + 1);
-
-    mm->shift = l * MPW_BITS;
-    MPX_ZERO(r->v, r->vl - 1);
-    r->vl[-1] = 1;
-    mm->r2 = MP_NEW;
-    mp_div(0, &mm->r2, r, m);
-    mm->r = mpmont_reduce(mm, MP_NEW, mm->r2);
-    MP_DROP(r);
-  }
+  mm->r2 = MP_NEW;
+  mp_div(0, &mm->r2, r2, m);
+  mm->r = mpmont_reduce(mm, MP_NEW, mm->r2);
+  MP_DROP(r2);
 }
 
 #endif
@@ -162,6 +154,7 @@ void mpmont_destroy(mpmont *mm)
   MP_DROP(mm->m);
   MP_DROP(mm->r);
   MP_DROP(mm->r2);
+  MP_DROP(mm->mi);
 }
 
 /* --- @mpmont_reduce@ --- *
@@ -185,41 +178,66 @@ mp *mpmont_reduce(mpmont *mm, mp *d, mp *a)
 
 mp *mpmont_reduce(mpmont *mm, mp *d, mp *a)
 {
-  mpw *dv, *dvl;
-  mpw *mv, *mvl;
-  size_t n;
-
-  /* --- Initial conditioning of the arguments --- */
+  size_t n = mm->n;
+
+  /* --- Check for serious Karatsuba reduction --- */
+
+  if (n > KARATSUBA_CUTOFF * 2) {
+    mp al;
+    mpw *vl;
+    mp *u;
+
+    if (MP_LEN(a) >= n)
+      vl = a->v + n;
+    else
+      vl = a->vl;
+    mp_build(&al, a->v, vl);
+    u = mp_mul(MP_NEW, &al, mm->mi);
+    if (MP_LEN(u) > n)
+      u->vl = u->v + n;
+    u = mp_mul(u, u, mm->m);
+    d = mp_add(d, a, u);
+    mp_drop(u);
+  }
 
-  n = MP_LEN(mm->m);
+  /* --- Otherwise do it the hard way --- */
 
-  if (d == a)
-    MP_MODIFY(d, 2 * n + 1);
   else {
-    MP_MODIFY(d, 2 * n + 1);
-    memcpy(d->v, a->v, MPWS(MP_LEN(a)));
-    memset(d->v + MP_LEN(a), 0, MPWS(MP_LEN(d) - MP_LEN(a)));
-  }
+    mpw *dv, *dvl;
+    mpw *mv, *mvl;
+    mpw mi;
+    size_t k = n;
+
+    /* --- Initial conditioning of the arguments --- */
+
+    if (d == a)
+      MP_MODIFY(d, 2 * n + 1);
+    else {
+      MP_MODIFY(d, 2 * n + 1);
+      MPX_COPY(d->v, d->vl, a->v, a->vl);
+    }
     
-  dv = d->v; dvl = d->vl;
-  mv = mm->m->v; mvl = mm->m->vl;
+    dv = d->v; dvl = d->vl;
+    mv = mm->m->v; mvl = mm->m->vl;
 
-  /* --- Let's go to work --- */
+    /* --- Let's go to work --- */
 
-  while (n--) {
-    mpw u = MPW(*dv * mm->mi);
-    MPX_UMLAN(dv, dvl, mv, mvl, u);
-    dv++;
+    mi = mm->mi->v[0];
+    while (k--) {
+      mpw u = MPW(*dv * mi);
+      MPX_UMLAN(dv, dvl, mv, mvl, u);
+      dv++;
+    }
   }
 
-  /* --- Done --- */
+  /* --- Wrap everything up --- */
 
-  memmove(d->v, dv, MPWS(dvl - dv));
-  d->vl -= dv - d->v;
-  MP_SHRINK(d);
   d->f = a->f & MP_BURN;
+  memmove(d->v, d->v + n, MPWS(MP_LEN(d) - n));
+  d->vl -= n;
   if (MP_CMP(d, >=, mm->m))
     d = mp_sub(d, d, mm->m);
+  MP_SHRINK(d);
   return (d);
 }
 
@@ -247,7 +265,7 @@ mp *mpmont_mul(mpmont *mm, mp *d, mp *a, mp *b)
 
 mp *mpmont_mul(mpmont *mm, mp *d, mp *a, mp *b)
 {
-  if (MP_LEN(a) > KARATSUBA_CUTOFF && MP_LEN(b) > KARATSUBA_CUTOFF) {
+  if (mm->n > KARATSUBA_CUTOFF * 2) {
     d = mp_mul(d, a, b);
     d = mpmont_reduce(mm, d, d);
   } else {
@@ -257,6 +275,7 @@ mp *mpmont_mul(mpmont *mm, mp *d, mp *a, mp *b)
     mpw *mv, *mvl;
     mpw y;
     size_t n, i;
+    mpw mi;
 
     /* --- Initial conditioning of the arguments --- */
 
@@ -278,9 +297,10 @@ mp *mpmont_mul(mpmont *mm, mp *d, mp *a, mp *b)
     /* --- Montgomery multiplication phase --- */
 
     i = 0;
+    mi = mm->mi->v[0];
     while (i < n && av < avl) {
       mpw x = *av++;
-      mpw u = MPW((*dv + x * y) * mm->mi);
+      mpw u = MPW((*dv + x * y) * mi);
       MPX_UMLAN(dv, dvl, bv, bvl, x);
       MPX_UMLAN(dv, dvl, mv, mvl, u);
       dv++;
@@ -290,7 +310,7 @@ mp *mpmont_mul(mpmont *mm, mp *d, mp *a, mp *b)
     /* --- Simpler Montgomery reduction phase --- */
 
     while (i < n) {
-      mpw u = MPW(*dv * mm->mi);
+      mpw u = MPW(*dv * mi);
       MPX_UMLAN(dv, dvl, mv, mvl, u);
       dv++;
       i++;
@@ -392,9 +412,9 @@ static int tcreate(dstr *v)
 
   mpmont_create(&mm, m);
 
-  if (mm.mi != mi->v[0]) {
+  if (mm.mi->v[0] != mi->v[0]) {
     fprintf(stderr, "\n*** bad mi: found %lu, expected %lu",
-           (unsigned long)mm.mi, (unsigned long)mi->v[0]);
+           (unsigned long)mm.mi->v[0], (unsigned long)mi->v[0]);
     fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
     fputc('\n', stderr);
     ok = 0;
index 6b2b9bd..120fa92 100644 (file)
--- a/mpmont.h
+++ b/mpmont.h
@@ -1,6 +1,6 @@
 /* -*-c-*-
  *
- * $Id: mpmont.h,v 1.3 1999/12/10 23:29:48 mdw Exp $
+ * $Id: mpmont.h,v 1.4 1999/12/11 01:51:14 mdw Exp $
  *
  * Montgomery reduction
  *
@@ -30,6 +30,9 @@
 /*----- Revision history --------------------------------------------------* 
  *
  * $Log: mpmont.h,v $
+ * Revision 1.4  1999/12/11 01:51:14  mdw
+ * Use a Karatsuba-based reduction for large moduli.
+ *
  * Revision 1.3  1999/12/10 23:29:48  mdw
  * Change header file guard names.
  *
@@ -95,8 +98,8 @@
 
 typedef struct mpmont {
   mp *m;                               /* Modulus */
-  mpw mi;                              /* %$-m^{-1} \bmod b$% */
-  size_t shift;                                /* %$\log_2 R$% */
+  mp *mi;                              /* %$-m^{-1} \bmod R$% */
+  size_t n;                            /* %$\log_b R$% */
   mp *r, *r2;                          /* %$R \bmod m$%, %$R^2 \bmod m$% */
 } mpmont;