Multiprecision routines finished and tested.
authormdw <mdw>
Sat, 13 Nov 1999 01:50:59 +0000 (01:50 +0000)
committermdw <mdw>
Sat, 13 Nov 1999 01:50:59 +0000 (01:50 +0000)
mpx.c

diff --git a/mpx.c b/mpx.c
index 7ac5a8b..b8bc8bf 100644 (file)
--- a/mpx.c
+++ b/mpx.c
@@ -1,6 +1,6 @@
 /* -*-c-*-
  *
- * $Id: mpx.c,v 1.1 1999/09/03 08:41:12 mdw Exp $
+ * $Id: mpx.c,v 1.2 1999/11/13 01:50:59 mdw Exp $
  *
  * Low-level multiprecision arithmetic
  *
@@ -30,6 +30,9 @@
 /*----- Revision history --------------------------------------------------* 
  *
  * $Log: mpx.c,v $
+ * Revision 1.2  1999/11/13 01:50:59  mdw
+ * Multiprecision routines finished and tested.
+ *
  * Revision 1.1  1999/09/03 08:41:12  mdw
  * Initial import.
  *
@@ -37,6 +40,7 @@
 
 /*----- Header files ------------------------------------------------------*/
 
+#include <assert.h>
 #include <stdio.h>
 #include <stdlib.h>
 #include <string.h>
@@ -51,7 +55,7 @@
 /* --- @mpx_storel@ --- *
  *
  * Arguments:  @const mpw *v, *vl@ = base and limit of source vector
- *             @octet *p@ = pointer to octet array
+ *             @void *pp@ = pointer to octet array
  *             @size_t sz@ = size of octet array
  *
  * Returns:    ---
  *             isn't enough space for them.
  */
 
-void mpx_storel(const mpw *v, const mpw *vl, octet *p, size_t sz)
+void mpx_storel(const mpw *v, const mpw *vl, void *pp, size_t sz)
 {
   mpw n, w = 0;
-  octet *q = p + sz;
+  octet *p = pp, *q = p + sz;
   unsigned bits = 0;
 
   while (p < q) {
@@ -89,7 +93,7 @@ void mpx_storel(const mpw *v, const mpw *vl, octet *p, size_t sz)
 /* --- @mpx_loadl@ --- *
  *
  * Arguments:  @mpw *v, *vl@ = base and limit of destination vector
- *             @const octet *p@ = pointer to octet array
+ *             @const void *pp@ = pointer to octet array
  *             @size_t sz@ = size of octet array
  *
  * Returns:    ---
@@ -99,10 +103,11 @@ void mpx_storel(const mpw *v, const mpw *vl, octet *p, size_t sz)
  *             space for them.
  */
 
-void mpx_loadl(mpw *v, const mpw *vl, const octet *p, size_t sz)
+void mpx_loadl(mpw *v, mpw *vl, const void *pp, size_t sz)
 {
   unsigned n;
-  const octet *q = p + sz;
+  mpw w = 0;
+  const octet *p = pp, *q = p + sz;
   unsigned bits = 0;
 
   if (v >= vl)
@@ -126,7 +131,7 @@ void mpx_loadl(mpw *v, const mpw *vl, const octet *p, size_t sz)
 /* --- @mpx_storeb@ --- *
  *
  * Arguments:  @const mpw *v, *vl@ = base and limit of source vector
- *             @octet *p@ = pointer to octet array
+ *             @void *pp@ = pointer to octet array
  *             @size_t sz@ = size of octet array
  *
  * Returns:    ---
@@ -136,10 +141,10 @@ void mpx_loadl(mpw *v, const mpw *vl, const octet *p, size_t sz)
  *             isn't enough space for them.
  */
 
-void mpx_storeb(const mpw *v, const mpw *vl, octet *p, size_t sz);
+void mpx_storeb(const mpw *v, const mpw *vl, void *pp, size_t sz)
 {
   mpw n, w = 0;
-  octet *q = p + sz;
+  octet *p = pp, *q = p + sz;
   unsigned bits = 0;
 
   while (q > p) {
@@ -164,7 +169,7 @@ void mpx_storeb(const mpw *v, const mpw *vl, octet *p, size_t sz);
 /* --- @mpx_loadb@ --- *
  *
  * Arguments:  @mpw *v, *vl@ = base and limit of destination vector
- *             @const octet *p@ = pointer to octet array
+ *             @const void *pp@ = pointer to octet array
  *             @size_t sz@ = size of octet array
  *
  * Returns:    ---
@@ -174,10 +179,11 @@ void mpx_storeb(const mpw *v, const mpw *vl, octet *p, size_t sz);
  *             space for them.
  */
 
-void mpx_loadb(mpw *v, const mpw *vl, const octet *p, size_t sz)
+void mpx_loadb(mpw *v, mpw *vl, const void *pp, size_t sz)
 {
   unsigned n;
-  const octet *q = p + sz;
+  mpw w = 0;
+  const octet *p = pp, *q = p + sz;
   unsigned bits = 0;
 
   if (v >= vl)
@@ -237,6 +243,7 @@ void mpx_lsl(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, size_t n)
       goto done;
     *dv++ = MPW(w);
     MPX_ZERO(dv, dvl);
+    goto done;
   }
 
   /* --- Break out word and bit shifts for more sophisticated work --- */
@@ -251,33 +258,42 @@ void mpx_lsl(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, size_t n)
     memset(dv, 0, MPWS(nw));
   }
 
-  /* --- And finally the difficult case --- */
+  /* --- And finally the difficult case --- *
+   *
+   * This is a little convoluted, because I have to start from the end and
+   * work backwards to avoid overwriting the source, if they're both the same
+   * block of memory.
+   */
 
   else {
     mpw w;
     size_t nr = MPW_BITS - nb;
+    size_t dvn = dvl - dv;
+    size_t avn = avl - av;
 
-    if (dv + nw >= dvl) {
+    if (dvn <= nw) {
       MPX_ZERO(dv, dvl);
       goto done;
     }
-    memset(dv, 0, MPWS(nw));
-    dv += nw;
-    w = *av++;
 
-    while (av < avl) {
-      mpw t;
-      if (dv >= dvl)
-       goto done;
-      t = *av++;
-      *dv++ = MPW((w >> nr) | (t << nb));
-      w = t;
+    if (dvn > avn + nw) {
+      size_t off = avn + nw + 1;
+      MPX_ZERO(dv + off, dvl);
+      dvl = dv + off;
+      w = 0;
+    } else {
+      avl = av + dvn - nw;
+      w = *--avl << nb;
     }
 
-    if (dv < dvl) {
-      *dv++ = MPW(w >> nr);
-      MPX_ZERO(dv, dvl);
+    while (avl > av) {
+      mpw t = *--avl;
+      *--dvl = (t >> nr) | w;
+      w = t << nb;
     }
+
+    *--dvl = w;
+    MPX_ZERO(dv, dvl);
   }
 
 done:;
@@ -320,6 +336,7 @@ void mpx_lsr(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, size_t n)
       goto done;
     *dv++ = MPW(w);
     MPX_ZERO(dv, dvl);
+    goto done;
   }
 
   /* --- Break out word and bit shifts for more sophisticated work --- */
@@ -457,14 +474,17 @@ void mpx_usub(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl,
       return;
     a = (av < avl) ? *av++ : 0;
     b = (bv < bvl) ? *bv++ : 0;
-    x = (mpd)a - (mpd)b + c;
+    x = (mpd)a - (mpd)b - c;
     *dv++ = MPW(x);
-    if (c >> MPW_BITS)
-      c = MPW(~0u);
+    if (x >> MPW_BITS)
+      c = 1;
+    else
+      c = 0;
   }
-  c = c ? ~0u : 0;
+  if (c)
+    c = MPW_MAX;
   while (dv < dvl)
-    *dv++ = c
+    *dv++ = c;
 }
 
 /* --- @mpx_umul@ --- *
@@ -492,7 +512,7 @@ void mpx_umul(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl,
   /* --- Deal with a multiply by zero --- */
   
   if (bv == bvl) {
-    MPX_COPY(dv, dvl, bv, bvl);
+    MPX_ZERO(dv, dvl);
     return;
   }
 
@@ -502,9 +522,9 @@ void mpx_umul(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl,
 
   /* --- Do the remaining multiply/accumulates --- */
 
-  while (bv < bvl) {
+  while (dv < dvl && bv < bvl) {
     mpw m = *bv++;
-    mpw c = ;
+    mpw c = 0;
     const mpw *avv = av;
     mpw *dvv = ++dv;
 
@@ -512,21 +532,81 @@ void mpx_umul(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl,
       mpd x;
       if (dvv >= dvl)
        goto next;
-      x = *dvv + m * *av++ + c;
-      *dv++ = MPW(x);
+      x = (mpd)*dvv + (mpd)m * (mpd)*avv++ + c;
+      *dvv++ = MPW(x);
       c = x >> MPW_BITS;
     }
-    if (dvv < dvl)
-      *dvv++ = MPW(c);
+    MPX_UADDN(dvv, dvl, c);
   next:;
   }
 }
 
+/* --- @mpx_usqr@ --- *
+ *
+ * Arguments:  @mpw *dv, *dvl@ = destination vector base and limit
+ *             @const mpw *av, *av@ = source vector base and limit
+ *
+ * Returns:    ---
+ *
+ * Use:                Performs unsigned integer squaring.  The result vector must
+ *             not overlap the source vector in any way.
+ */
+
+void mpx_usqr(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl)
+{
+  MPX_ZERO(dv, dvl);
+
+  /* --- Main loop --- */
+
+  while (av < avl) {
+    const mpw *avv = av;
+    mpw *dvv = dv;
+    mpw a = *av;
+    mpd c;
+
+    /* --- Stop if I've run out of destination --- */
+
+    if (dvv >= dvl)
+      break;
+
+    /* --- Work out the square at this point in the proceedings --- */
+
+    {
+      mpw d = *dvv;
+      mpd x = (mpd)a * (mpd)a + *dvv;
+      *dvv++ = MPW(x);
+      c = MPW(x >> MPW_BITS);
+    }
+
+    /* --- Now fix up the rest of the vector upwards --- */
+
+    avv++;
+    while (dvv < dvl && avv < avl) {
+      mpw aa = *avv;
+      mpd x = (mpd)a * (mpd)*avv++;
+      mpd y = ((x << 1) & MPW_MAX) + c + *dvv;
+      c = (x >> (MPW_BITS - 1)) + (y >> MPW_BITS);
+      *dvv++ = MPW(y);
+    }
+    while (dvv < dvl && c) {
+      mpd x = c + *dvv;
+      *dvv++ = MPW(x);
+      c = x >> MPW_BITS;
+    }
+
+    /* --- Get ready for the next round --- */
+
+    av++;
+    dv += 2;
+  }
+}
+
 /* --- @mpx_udiv@ --- *
  *
  * Arguments:  @mpw *qv, *qvl@ = quotient vector base and limit
  *             @mpw *rv, *rvl@ = dividend/remainder vector base and limit
  *             @const mpw *dv, *dvl@ = divisor vector base and limit
+ *             @mpw *sv, *svl@ = scratch workspace
  *
  * Returns:    ---
  *
@@ -536,13 +616,15 @@ void mpx_umul(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl,
  *             may not overlap in any way.  Yes, I know it's a bit odd
  *             requiring the dividend to be in the result position but it
  *             does make some sense really.  The remainder must have
- *             headroom for at least two extra words.
+ *             headroom for at least two extra words.  The scratch space
+ *             must be at least two words larger than twice the size of the
+ *             divisor.
  */
 
 void mpx_udiv(mpw *qv, mpw *qvl, mpw *rv, mpw *rvl,
-             const mpw *dv, const mpw *dvl)
+             const mpw *dv, const mpw *dvl,
+             mpw *sv, mpw *svl)
 {
-  mpw spare[2];
   unsigned norm = 0;
   size_t scale;
   mpw d, dd;
@@ -551,41 +633,52 @@ void mpx_udiv(mpw *qv, mpw *qvl, mpw *rv, mpw *rvl,
 
   MPX_ZERO(qv, qvl);
 
+  /* --- Perform some sanity checks --- */
+
+  MPX_SHRINK(dv, dvl);
+  assert(((void)"division by zero in mpx_udiv", dv < dvl));
+
   /* --- Normalize the divisor --- *
    *
    * The algorithm requires that the divisor be at least two digits long.
    * This is easy to fix.
    */
 
-  MPX_SHRINK(dv, dvl);
-
-  assert(((void)"division by zero in mpx_udiv", dv < dvl));
-
-  d = dvl[-1];
-  if (dv + 1 == dvl) {
-    spare[0] = 0;
-    spare[1] = d;
-    dv = spare;
-    dvl = spare + 2;
-    norm += MPW_BITS;
-  }
+  {
+    unsigned b;
 
-  while (d < MPW_MAX / 2) {
-    d <<= 1;
-    norm += 1;
+    d = dvl[-1];
+    for (b = MPW_BITS / 2; b; b >>= 1) {
+      if (d < (MPW_MAX >> b)) {
+       d <<= b;
+       norm += b;
+      }
+    }
+    if (dv + 1 == dvl)
+      norm += MPW_BITS;
   }
-  dd = dvl[-2];
 
   /* --- Normalize the dividend/remainder to match --- */
 
-  mpx_lsl(rv, rvl, rv, rvl, norm);
+  if (norm) {
+    mpw *svvl = sv + (dvl - dv) + 1;
+    mpx_lsl(rv, rvl, rv, rvl, norm);
+    mpx_lsl(sv, svvl, dv, dvl, norm);
+    dv = sv;
+    sv = svvl;
+    dvl = svvl;
+    MPX_SHRINK(dv, dvl);
+  }
+
   MPX_SHRINK(rv, rvl);
+  d = dvl[-1];
+  dd = dvl[-2];
 
   /* --- Work out the relative scales --- */
 
   {
     size_t rvn = rvl - rv;
-    size_t dvn = dvn - dv;
+    size_t dvn = dvl - dv;
 
     /* --- If the divisor is clearly larger, notice this --- */
 
@@ -613,32 +706,89 @@ void mpx_udiv(mpw *qv, mpw *qvl, mpw *rv, mpw *rvl,
   /* --- Now for the main loop --- */
 
   {
-    mpw *rvv;
-    mpw r;
-
-    scale--;
-    rvv = rvl - 2;
-    r = rvv[1];
+    mpw *rvv = rvl - 2;
 
     while (scale) {
-      mpw q, rr;
+      mpw q;
+      mpd rh;
 
       /* --- Get an estimate for the next quotient digit --- */
 
-      rr = *rvv--;
+      mpw r = rvv[1];
+      mpw rr = rvv[0];
+      mpw rrr = *--rvv;
+
+      scale--;
+      rh = ((mpd)r << MPW_BITS) | rr;
       if (r == d)
        q = MPW_MAX;
-      else {
-       mpd rx = (r << MPW_BITS) | rr;
-       q = MPW(rx / d);
-      }
+      else
+       q = MPW(rh / d);
 
       /* --- Refine the estimate --- */
 
       {
        mpd yh = (mpd)d * q;
        mpd yl = (mpd)dd * q;
-       
+
+       if (yl > MPW_MAX) {
+         yh += yl >> MPW_BITS;
+         yl &= MPW_MAX;
+       }
+
+       while (yh > rh || (yh == rh && yl > rrr)) {
+         q--;
+         yh -= d;
+         if (yl < dd) {
+           yh++;
+           yl += MPW_MAX;
+         }
+         yl -= dd;
+       }
+      }
+
+      /* --- Remove a chunk from the dividend --- */
+
+      {
+       mpw *svv;
+       const mpw *dvv;
+       mpw c = 0;
+
+       /* --- Calculate the size of the chunk --- */
+
+       for (svv = sv, dvv = dv; dvv < dvl; svv++, dvv++) {
+         mpd x = (mpd)*dvv * (mpd)q + c;
+         *svv = MPW(x);
+         c = x >> MPW_BITS;
+       }
+       if (c)
+         *svv++ = c;
+
+       /* --- Now make sure that we can cope with the difference --- *
+        *
+        * Take advantage of the fact that subtraction works two's-
+        * complement.
+        */
+
+       mpx_usub(rv + scale, rvl, rv + scale, rvl, sv, svv);
+       if (rvl[-1] > MPW_MAX / 2) {
+         mpx_uadd(rv + scale, rvl, rv + scale, rvl, dv, dvl);
+         q--;
+       }
+      }
+
+      /* --- Done for another iteration --- */
+
+      if (qvl - qv > scale)
+       qv[scale] = q;
+      r = rr;
+      rr = rrr;
+    }
+  }
+
+  /* --- Now fiddle with unnormalizing and things --- */
+
+  mpx_lsr(rv, rvl, rv, rvl, norm);
 }
 
 /*----- That's all, folks -------------------------------------------------*/