Indentation fix.
[u/mdw/catacomb] / mptext.c
index a55f1c0..f6bc2e3 100644 (file)
--- a/mptext.c
+++ b/mptext.c
@@ -1,6 +1,6 @@
 /* -*-c-*-
  *
- * $Id: mptext.c,v 1.8 2000/12/06 20:32:42 mdw Exp $
+ * $Id: mptext.c,v 1.11 2001/06/16 23:42:17 mdw Exp $
  *
  * Textual representation of multiprecision numbers
  *
 /*----- Revision history --------------------------------------------------* 
  *
  * $Log: mptext.c,v $
+ * Revision 1.11  2001/06/16 23:42:17  mdw
+ * Typesetting fixes.
+ *
+ * Revision 1.10  2001/06/16 13:22:39  mdw
+ * Added fast-track code for binary output bases, and tests.
+ *
+ * Revision 1.9  2001/02/03 16:05:17  mdw
+ * Make flags be unsigned.  Improve the write algorithm: recurse until the
+ * parts are one word long and use single-precision arithmetic from there.
+ * Fix off-by-one bug when breaking the number apart.
+ *
  * Revision 1.8  2000/12/06 20:32:42  mdw
  * Reduce binary bytes (to allow marker bits to be ignored).  Fix error
  * message string a bit.  Allow leading `+' signs.
@@ -74,7 +85,7 @@
  *
  * This is the number of bits in a @size_t@ object.  Why? 
  *
- * To see this, let %$b = \mathit{MPW\_MAX} + 1$% and let %$Z$% be the
+ * To see this, let %$b = \textit{MPW\_MAX} + 1$% and let %$Z$% be the
  * largest @size_t@ value.  Then the largest possible @mp@ is %$M - 1$% where
  * %$M = b^Z$%.  Let %$r$% be a radix to read or write.  Since the recursion
  * squares the radix at each step, the highest number reached by the
@@ -147,10 +158,9 @@ mp *mp_read(mp *m, int radix, const mptext_ops *ops, void *p)
 
   /* --- Flags --- */
 
-  enum {
-    f_neg = 1u,
-    f_ok = 2u
-  };
+#define f_neg 1u
+#define f_ok 2u
+#define f_start 4u
 
   /* --- Initialize the stacks --- */
 
@@ -204,10 +214,119 @@ mp *mp_read(mp *m, int radix, const mptext_ops *ops, void *p)
     r = -1;
   }
 
+  /* --- Use fast algorithm for binary radix --- *
+   *
+   * This is the restart point after having parsed a radix number from the
+   * input.  We check whether the radix is binary, and if so use a fast
+   * algorithm which just stacks the bits up in the right order.
+   */
+
+restart:
+  switch (rd) {
+    unsigned bit;
+
+    case   2: bit = 1; goto bin;
+    case   4: bit = 2; goto bin;
+    case   8: bit = 3; goto bin;
+    case  16: bit = 4; goto bin;
+    case  32: bit = 5; goto bin;
+    case  64: bit = 6; goto bin;
+    case 128: bit = 7; goto bin;
+    default:
+      break;
+
+  /* --- The fast binary algorithm --- *
+   *
+   * We stack bits up starting at the top end of a word.  When one word is
+   * full, we write it to the integer, and start another with the left-over
+   * bits.  When the array in the integer is full, we resize using low-level
+   * calls and copy the current data to the top end.  Finally, we do a single
+   * bit-shift when we know where the end of the number is.
+   */
+
+  bin: {
+    mpw a = 0;
+    unsigned b = MPW_BITS;
+    size_t len, n;
+    mpw *v;
+
+    m = mp_dest(MP_NEW, 1, nf);
+    len = n = m->sz;
+    n = len;
+    v = m->v + n;
+    for (;; ch = ops->get(p)) {
+      unsigned x;
+
+      if (ch < 0)
+       break;
+
+      /* --- Check that the character is a digit and in range --- */
+
+      if (radix < 0)
+       x = ch % rd;
+      else {
+       if (!isalnum(ch))
+         break;
+       if (ch >= '0' && ch <= '9')
+         x = ch - '0';
+       else {
+         ch = tolower(ch);
+         if (ch >= 'a' && ch <= 'z')   /* ASCII dependent! */
+           x = ch - 'a' + 10;
+         else
+           break;
+       }
+      }
+      if (x >= rd)
+       break;
+
+      /* --- Feed the digit into the accumulator --- */
+
+      f |= f_ok;
+      if (!x && !(f & f_start))
+       continue;
+      f |= f_start;
+      if (b > bit) {
+       b -= bit;
+       a |= MPW(x) << b;
+      } else {
+       a |= MPW(x) >> (bit - b);
+       b += MPW_BITS - bit;
+       *--v = MPW(a);
+       n--;
+       if (!n) {
+         n = len;
+         len <<= 1;
+         v = mpalloc(m->a, len);
+         memcpy(v + n, m->v, MPWS(n));
+         mpfree(m->a, m->v);
+         m->v = v;
+         v = m->v + n;
+       }
+       a = (b < MPW_BITS) ? MPW(x) << b : 0;
+      }
+    }
+
+    /* --- Finish up --- */
+
+    if (!(f & f_ok)) {
+      mp_drop(m);
+      m = 0;
+    } else {
+      *--v = MPW(a);
+      n--;
+      m->sz = len;
+      m->vl = m->v + len;
+      m->f &= ~MP_UNDEF;
+      m = mp_lsr(m, m, (unsigned long)n * MPW_BITS + b);
+    }
+    goto done;
+  }}
+
   /* --- Time to start --- */
 
   for (;; ch = ops->get(p)) {
-    int x;
+    unsigned x;
 
     if (ch < 0)
       break;
@@ -231,7 +350,8 @@ mp *mp_read(mp *m, int radix, const mptext_ops *ops, void *p)
       rd = r;
       r = -1;
       f &= ~f_ok;
-      continue;
+      ch = ops->get(p);
+      goto restart;
     }
 
     /* --- Check that the character is a digit and in range --- */
@@ -345,6 +465,7 @@ mp *mp_read(mp *m, int radix, const mptext_ops *ops, void *p)
 
   /* --- Bail out if the number was bad --- */
 
+done:
   if (!(f & f_ok))    
     return (0);
 
@@ -353,6 +474,10 @@ mp *mp_read(mp *m, int radix, const mptext_ops *ops, void *p)
   if (f & f_neg)
     m->f |= MP_NEG;
   return (m);
+
+#undef f_start
+#undef f_neg
+#undef f_ok
 }
 
 /* --- @mp_write@ --- *
@@ -369,14 +494,14 @@ mp *mp_read(mp *m, int radix, const mptext_ops *ops, void *p)
 
 /* --- Simple case --- *
  *
- * Use a fixed-sized buffer and the simple single-precision division
- * algorithm to pick off low-order digits.  Put each digit in a buffer,
- * working backwards from the end.  If the buffer becomes full, recurse to
- * get another one.  Ensure that there are at least @z@ digits by writing
- * leading zeroes if there aren't enough real digits.
+ * Use a fixed-sized buffer and single-precision arithmetic to pick off
+ * low-order digits.  Put each digit in a buffer, working backwards from the
+ * end.  If the buffer becomes full, recurse to get another one.  Ensure that
+ * there are at least @z@ digits by writing leading zeroes if there aren't
+ * enough real digits.
  */
 
-static int simple(mp *m, int radix, unsigned z,
+static int simple(mpw n, int radix, unsigned z,
                  const mptext_ops *ops, void *p)
 {
   int rc = 0;
@@ -388,36 +513,34 @@ static int simple(mp *m, int radix, unsigned z,
     int ch;
     mpw x;
 
-    x = mpx_udivn(m->v, m->vl, m->v, m->vl, rd);
-    MP_SHRINK(m);
+    x = n % rd;
+    n /= rd;
     if (radix < 0)
       ch = x;
-    else {
-      if (x < 10)
-       ch = '0' + x;
-      else
-       ch = 'a' + x - 10;
-    }
+    else if (x < 10)
+      ch = '0' + x;
+    else
+      ch = 'a' + x - 10;
     buf[--i] = ch;
     if (z)
       z--;
-  } while (i && MP_LEN(m));
+  } while (i && n);
 
-  if (MP_LEN(m))
-    rc = simple(m, radix, z, ops, p);
+  if (n)
+    rc = simple(n, radix, z, ops, p);
   else {
-    static const char zero[32] = "00000000000000000000000000000000";
-    while (!rc && z >= sizeof(zero)) {
-      rc = ops->put(zero, sizeof(zero), p);
-      z -= sizeof(zero);
+    char zbuf[32];
+    memset(zbuf, (radix < 0) ? 0 : '0', sizeof(zbuf));
+    while (!rc && z >= sizeof(zbuf)) {
+      rc = ops->put(zbuf, sizeof(zbuf), p);
+      z -= sizeof(zbuf);
     }
     if (!rc && z)
-      rc = ops->put(zero, z, p);
+      rc = ops->put(zbuf, z, p);
   }
   if (!rc)
-    ops->put(buf + i, sizeof(buf) - i, p);
-  if (m->f & MP_BURN)
-    BURN(buf);
+    rc = ops->put(buf + i, sizeof(buf) - i, p);
+  BURN(buf);
   return (rc);
 }
 
@@ -436,9 +559,10 @@ static int complicated(mp *m, int radix, mp **pr, unsigned i, unsigned z,
   mp *q = MP_NEW;
   unsigned d = 1 << i;
 
-  if (MP_LEN(m) < 8)
-    return (simple(m, radix, z, ops, p));
+  if (MP_LEN(m) < 2)
+    return (simple(MP_LEN(m) ? m->v[0] : 0, radix, z, ops, p));
 
+  assert(i);
   mp_div(&q, &m, m, pr[i]);
   if (!MP_LEN(q))
     d = z;
@@ -455,6 +579,94 @@ static int complicated(mp *m, int radix, mp **pr, unsigned i, unsigned z,
   return (rc);
 }
 
+/* --- Binary case --- *
+ *
+ * Special case for binary output.  Goes much faster.
+ */
+
+static int binary(mp *m, int bit, int radix, const mptext_ops *ops, void *p)
+{
+  mpw *v;
+  mpw a;
+  int rc = 0;
+  unsigned b;
+  unsigned mask;
+  unsigned long n;
+  unsigned f = 0;
+  char buf[8], *q;
+  unsigned x;
+  int ch;
+
+#define f_out 1u
+
+  /* --- Work out where to start --- */
+
+  n = mp_bits(m);
+  n += bit - (n % bit);
+  b = n % MPW_BITS;
+  n /= MPW_BITS;
+  
+  if (n > MP_LEN(m)) {
+    n--;
+    b += MPW_BITS;
+  }
+
+  v = m->v + n;
+  a = *v;
+  mask = (1 << bit) - 1;
+  q = buf;
+
+  /* --- Main code --- */
+
+  for (;;) {
+    if (b > bit) {
+      b -= bit;
+      x = a >> b;
+    } else {
+      x = a << (bit - b);
+      b += MPW_BITS - bit;
+      if (v == m->v)
+       break;
+      a = *--v;
+      if (b < MPW_BITS)
+       x |= a >> b;
+    }
+    x &= mask;
+    if (!x && !(f & f_out))
+      continue;
+
+    if (radix < 0)
+      ch = x;
+    else if (x < 10)
+      ch = '0' + x;
+    else
+      ch = 'a' + x - 10;
+    *q++ = ch;
+    if (q >= buf + sizeof(buf)) {
+      if ((rc = ops->put(buf, sizeof(buf), p)) != 0)
+       goto done;
+      q = buf;
+    }
+    f |= f_out;
+  }
+
+  x &= mask;
+  if (radix < 0)
+    ch = x;
+  else if (x < 10)
+    ch = '0' + x;
+  else
+    ch = 'a' + x - 10;
+  *q++ = ch;
+  rc = ops->put(buf, q - buf, p);
+
+done:
+  mp_drop(m);
+  return (rc);
+
+#undef f_out
+}
+
 /* --- Main driver code --- */
 
 int mp_write(mp *m, int radix, const mptext_ops *ops, void *p)
@@ -483,10 +695,22 @@ int mp_write(mp *m, int radix, const mptext_ops *ops, void *p)
     m->f &= ~MP_NEG;
   }
 
+  /* --- Handle binary radix --- */
+
+  switch (radix) {
+    case   2: case   -2: return (binary(m, 1, radix, ops, p));
+    case   4: case   -4: return (binary(m, 2, radix, ops, p));
+    case   8: case   -8: return (binary(m, 3, radix, ops, p));
+    case  16: case  -16: return (binary(m, 4, radix, ops, p));
+    case  32: case  -32: return (binary(m, 5, radix, ops, p));
+              case  -64: return (binary(m, 6, radix, ops, p));
+              case -128: return (binary(m, 7, radix, ops, p));
+  }
+
   /* --- If the number is small, do it the easy way --- */
 
-  if (MP_LEN(m) < 8)
-    rc = simple(m, radix, 0, ops, p);
+  if (MP_LEN(m) < 2)
+    rc = simple(MP_LEN(m) ? m->v[0] : 0, radix, 0, ops, p);
 
   /* --- Use a clever algorithm --- *
    *
@@ -499,7 +723,7 @@ int mp_write(mp *m, int radix, const mptext_ops *ops, void *p)
 
   else {
     mp *pr[DEPTH];
-    size_t target = MP_LEN(m) / 2;
+    size_t target = (MP_LEN(m) + 1) / 2;
     unsigned i = 0;
     mp *z = mp_new(1, 0);
 
@@ -546,7 +770,7 @@ static int verify(dstr *v)
   if (m) {
     if (!ob) {
       fprintf(stderr, "*** unexpected successful parse\n"
-                     "*** input [%i] = ", ib);
+                     "*** input [%2i] =     ", ib);
       if (ib < 0)
        type_hex.dump(&v[1], stderr);
       else
@@ -558,17 +782,17 @@ static int verify(dstr *v)
       mp_writedstr(m, &d, ob);
       if (d.len != v[3].len || memcmp(d.buf, v[3].buf, d.len) != 0) {
        fprintf(stderr, "*** failed read or write\n"
-                       "*** input [%i]    = ", ib);
+                       "*** input [%2i]      = ", ib);
        if (ib < 0)
          type_hex.dump(&v[1], stderr);
        else
          fputs(v[1].buf, stderr);
-       fprintf(stderr, "\n*** output [%i]   = ", ob);
+       fprintf(stderr, "\n*** output [%2i]     = ", ob);
        if (ob < 0)
          type_hex.dump(&d, stderr);
        else
          fputs(d.buf, stderr);
-       fprintf(stderr, "\n*** expected [%i]   = ", ob);
+       fprintf(stderr, "\n*** expected [%2i]   = ", ob);
        if (ob < 0)
          type_hex.dump(&v[3], stderr);
        else