math/mptext.c: Radically refactor `mp_read'.
authorMark Wooding <mdw@distorted.org.uk>
Wed, 14 Oct 2015 10:00:51 +0000 (11:00 +0100)
committerMark Wooding <mdw@distorted.org.uk>
Wed, 14 Oct 2015 16:08:17 +0000 (17:08 +0100)
It used to be the largest function in the library -- possibly in my
codebase.

  * Split it into three main pieces: the special-purpose binary reader,
    an efficient stack-based general-radix reader, and a high-level
    syntax parser which picks out signs and base indicators.  This
    removes the complicated entangling of the base indicator parsing
    with the general-radix reader which was the worst feature of the old
    version.

  * Split commonly-used functionality out into separate functions,
    notably `char_digit' and `read_digit'.

The result is code which is easier to understand and actually shorter.

math/mptext.c

index ddf5013..c902264 100644 (file)
  * bizarre syntax.
  */
 
-mp *mp_read(mp *m, int radix, const mptext_ops *ops, void *p)
+static int char_digit(int ch, int radix)
 {
-  int ch;                              /* Current char being considered */
-  unsigned f = 0;                      /* Flags about the current number */
-  int r;                               /* Radix to switch over to */
-  mpw rd;                              /* Radix as an @mp@ digit */
-  mp rr;                               /* The @mp@ for the radix */
-  unsigned nf = m ? m->f & MP_BURN : 0;        /* New @mp@ flags */
-
-  /* --- Stacks --- */
-
-  mp *pow[DEPTH];                      /* List of powers */
-  unsigned pows;                       /* Next index to fill */
-  struct { unsigned i; mp *m; } s[DEPTH]; /* Main stack */
-  unsigned sp;                         /* Current stack pointer */
-
-  /* --- Flags --- */
-
-#define f_neg 1u
-#define f_ok 2u
-#define f_start 4u
-
-  /* --- Initialize the stacks --- */
-
-  mp_build(&rr, &rd, &rd + 1);
-  pow[0] = &rr;
-  pows = 1;
-
-  sp = 0;
-
-  /* --- Initialize the destination number --- */
-
-  if (m)
-    MP_DROP(m);
-
-  /* --- Read an initial character --- */
-
-  ch = ops->get(p);
-  if (radix >= 0) {
-    while (isspace(ch))
-      ch = ops->get(p);
-  }
-
-  /* --- Handle an initial sign --- */
-
-  if (radix >= 0 && (ch == '-' || ch == '+')) {
-    if (ch == '-')
-      f |= f_neg;
-    do ch = ops->get(p); while isspace(ch);
-  }
-
-  /* --- If the radix is zero, look for leading zeros --- */
-
-  if (radix > 0) {
-    assert(((void)"ascii radix must be <= 62", radix <= 62));
-    rd = radix;
-    r = -1;
-  } else if (radix < 0) {
-    rd = -radix;
-    assert(((void)"binary radix must fit in a byte", rd <= UCHAR_MAX));
-    r = -1;
-  } else if (ch != '0') {
-    rd = 10;
-    r = 0;
-  } else {
-    ch = ops->get(p);
-    switch (ch) {
-      case 'x':
-       rd = 16;
-       goto prefix;
-      case 'o':
-       rd = 8;
-       goto prefix;
-      case 'b':
-       rd = 2;
-       goto prefix;
-      prefix:
-       ch = ops->get(p);
-       break;
-      default:
-       rd = 8;
-       f |= f_ok;
-    }
-    r = -1;
-  }
-
-  /* --- Use fast algorithm for binary radix --- *
-   *
-   * This is the restart point after having parsed a radix number from the
-   * input.  We check whether the radix is binary, and if so use a fast
-   * algorithm which just stacks the bits up in the right order.
-   */
+  int r = radix < 0 ? -radix : radix;
+  int d;
+
+  if (ch < 0) return (-1);
+  if (radix < 0) d = ch;
+  else if ('0' <= ch && ch <= '9') d = ch - '0';
+  else if ('a' <= ch && ch <= 'z') d = ch - 'a' + 10;
+  else if ('A' <= ch && ch <= 'Z') d = ch - 'A' + (radix > 36 ? 36 : 10);
+  else return (-1);
+  if (d >= r) return (-1);
+  return (d);
+}
 
-restart:
-  switch (rd) {
-    unsigned bit;
-
-    case   2: bit = 1; goto bin;
-    case   4: bit = 2; goto bin;
-    case   8: bit = 3; goto bin;
-    case  16: bit = 4; goto bin;
-    case  32: bit = 5; goto bin;
-    case  64: bit = 6; goto bin;
-    case 128: bit = 7; goto bin;
-    default:
-      break;
+static mp *read_binary(int radix, unsigned bit, unsigned nf,
+                      const mptext_ops *ops, void *p)
+{
+  mpw a = 0;
+  unsigned b = MPW_BITS;
+  int any = 0, nz = 0;
+  int ch, d;
+  size_t len, n;
+  mpw *v;
+  mp *m;
 
   /* --- The fast binary algorithm --- *
    *
@@ -212,246 +131,293 @@ restart:
    * bit-shift when we know where the end of the number is.
    */
 
-  bin: {
-    mpw a = 0;
-    unsigned b = MPW_BITS;
-    size_t len, n;
-    mpw *v;
-
-    m = mp_dest(MP_NEW, 1, nf);
-    len = n = m->sz;
-    n = len;
-    v = m->v + n;
-    for (;; ch = ops->get(p)) {
-      unsigned x;
+  m = mp_dest(MP_NEW, 1, nf);
+  len = n = m->sz;
+  n = len;
+  v = m->v + n;
 
-      if (ch < 0)
-       break;
+  for (;;) {
+    ch = ops->get(p);
+    if ((d = char_digit(ch, radix)) < 0) break;
 
-      /* --- Check that the character is a digit and in range --- */
+    /* --- Ignore leading zeroes, but notice that the number is valid --- */
 
-      if (radix < 0)
-       x = ch % rd;
-      else {
-       if (!isalnum(ch))
-         break;
-       if (ch >= '0' && ch <= '9')
-         x = ch - '0';
-       else {
-         if (rd <= 36)
-           ch = tolower(ch);
-         if (ch >= 'a' && ch <= 'z')   /* ASCII dependent! */
-           x = ch - 'a' + 10;
-         else if (ch >= 'A' && ch <= 'Z')
-           x = ch - 'A' + 36;
-         else
-           break;
-       }
-      }
-      if (x >= rd)
-       break;
-
-      /* --- Feed the digit into the accumulator --- */
-
-      f |= f_ok;
-      if (!x && !(f & f_start))
-       continue;
-      f |= f_start;
-      if (b > bit) {
-       b -= bit;
-       a |= MPW(x) << b;
-      } else {
-       a |= MPW(x) >> (bit - b);
-       b += MPW_BITS - bit;
-       *--v = MPW(a);
-       n--;
-       if (!n) {
-         n = len;
-         len <<= 1;
-         v = mpalloc(m->a, len);
-         memcpy(v + n, m->v, MPWS(n));
-         mpfree(m->a, m->v);
-         m->v = v;
-         v = m->v + n;
-       }
-       a = (b < MPW_BITS) ? MPW(x) << b : 0;
-      }
-    }
+    any = 1;
+    if (!d && !nz) continue;
+    nz = 1;
 
-    /* --- Finish up --- */
+    /* --- Feed the digit into the accumulator --- */
 
-    if (!(f & f_ok)) {
-      mp_drop(m);
-      m = 0;
+    if (b > bit) {
+      b -= bit;
+      a |= MPW(d) << b;
     } else {
-      *--v = MPW(a);
-      n--;
-      m->sz = len;
-      m->vl = m->v + len;
-      m->f &= ~MP_UNDEF;
-      m = mp_lsr(m, m, (unsigned long)n * MPW_BITS + b);
+      a |= MPW(d) >> (bit - b);
+      b += MPW_BITS - bit;
+      *--v = MPW(a); n--;
+      if (!n) {
+       n = len; len <<= 1;
+       v = mpalloc(m->a, len);
+       memcpy(v + n, m->v, MPWS(n));
+       mpfree(m->a, m->v);
+       m->v = v; v = m->v + n;
+      }
+      a = (b < MPW_BITS) ? MPW(d) << b : 0;
     }
-    ops->unget(ch, p);
-    goto done;
-  }}
+  }
 
-  /* --- Time to start --- */
+  /* --- Finish up --- */
 
-  for (;; ch = ops->get(p)) {
-    unsigned x;
+  ops->unget(ch, p);
+  if (!any) { mp_drop(m); return (0); }
 
-    if (ch < 0)
-      break;
+  *--v = MPW(a); n--;
+  m->sz = len;
+  m->vl = m->v + len;
+  m->f &= ~MP_UNDEF;
+  m = mp_lsr(m, m, (unsigned long)n * MPW_BITS + b);
 
-    /* --- An underscore indicates a numbered base --- */
+  return (m);
+}
 
-    if (ch == '_' && r > 0 && r <= 62) {
-      unsigned i;
+struct readstate {
 
-      /* --- Clear out the stacks --- */
+  /* --- State for the general-base reader --- *
+   *
+   * There are two arrays.  The @pow@ array is set so that @pow[i]@ contains
+   * %$R^{2^i}$% for @i < pows@.  The stack @s@ contains partial results:
+   * each entry contains a value @m@ corresponding to %$2^i$% digits.
+   * Inductively, an empty stack represents zero; if a stack represents %$x$%
+   * then pushing a new entry on the top causes the stack to represent
+   * %$R^{2^i} x + m$%.
+   *
+   * It is an invariant that each entry has a strictly smaller @i@ than the
+   * items beneath it.  This is achieved by coaslescing entries at the top if
+   * they have equal %$i$% values: if the top items are %$(m, i)$%, and
+   * %$(M', i)$%, and the rest of the stack represents the integer %$x$%,
+   * then %$R^{2^i} (R^{2^i} x + M) + m = R^{2^{i+1}} x + (R^{2^i} M + m)$%,
+   * so we replace the top two items by %$((R^{2^i} M + m), i + 1)$%, and
+   * repeat if necessary.
+   */
 
-      for (i = 1; i < pows; i++)
-       MP_DROP(pow[i]);
-      pows = 1;
-      for (i = 0; i < sp; i++)
-       MP_DROP(s[i].m);
-      sp = 0;
+  unsigned pows, sp;
+  struct { unsigned i; mp *m; } s[DEPTH];
+  mp *pow[DEPTH];
+};
 
-      /* --- Restart the search --- */
+static void ensure_power(struct readstate *rs)
+{
+  /* --- Make sure we have the necessary %$R^{2^i}$% computed --- */
 
-      rd = r;
-      r = -1;
-      f &= ~f_ok;
-      ch = ops->get(p);
-      goto restart;
-    }
+  if (rs->s[rs->sp].i >= rs->pows) {
+    assert(rs->pows < DEPTH);
+    rs->pow[rs->pows] = mp_sqr(MP_NEW, rs->pow[rs->pows - 1]);
+    rs->pows++;
+  }
+}
 
-    /* --- Check that the character is a digit and in range --- */
+static void read_digit(struct readstate *rs, unsigned nf, int d)
+{
+  mp *m = mp_new(1, nf);
+  m->v[0] = d;
 
-    if (radix < 0)
-      x = ch % rd;
-    else {
-      if (!isalnum(ch))
-       break;
-      if (ch >= '0' && ch <= '9')
-       x = ch - '0';
-      else {
-       if (rd <= 36)
-         ch = tolower(ch);
-       if (ch >= 'a' && ch <= 'z')     /* ASCII dependent! */
-         x = ch - 'a' + 10;
-       else if (ch >= 'A' && ch <= 'Z')
-         x = ch - 'A' + 36;
-       else
-         break;
-      }
-    }
+  /* --- Put the new digit on top --- */
 
-    /* --- Sort out what to do with the character --- */
+  assert(rs->sp < DEPTH);
+  rs->s[rs->sp].m = m;
+  rs->s[rs->sp].i = 0;
 
-    if (x >= 10 && r >= 0)
-      r = -1;
-    if (x >= rd)
-      break;
+  /* --- Restore the stack invariant --- */
 
-    if (r >= 0)
-      r = r * 10 + x;
+  while (rs->sp && rs->s[rs->sp - 1].i <= rs->s[rs->sp].i) {
+    assert(rs->sp > 0);
+    ensure_power(rs);
+    rs->sp--;
 
-    /* --- Stick the character on the end of my integer --- */
+    m = rs->s[rs->sp].m;
+    m = mp_mul(m, m, rs->pow[rs->s[rs->sp + 1].i]);
+    m = mp_add(m, m, rs->s[rs->sp + 1].m);
+    MP_DROP(rs->s[rs->sp + 1].m);
+    rs->s[rs->sp].m = m;
+    rs->s[rs->sp].i++;
+  }
 
-    assert(((void)"Number is too unimaginably huge", sp < DEPTH));
-    s[sp].m = m = mp_new(1, nf);
-    m->v[0] = x;
-    s[sp].i = 0;
+  /* --- Leave the stack pointer at an empty item --- */
 
-    /* --- Now grind through the stack --- */
+  rs->sp++;
+}
 
-    while (sp > 0 && s[sp - 1].i == s[sp].i) {
+static mp *read_general(int radix, unsigned t, unsigned nf,
+                       const mptext_ops *ops, void *p)
+{
+  struct readstate rs;
+  unsigned char v[4];
+  unsigned i;
+  mpw r;
+  int any = 0;
+  int ch, d;
+  mp rr;
+  mp *m, *z, *n;
+
+  /* --- Prepare the stack --- */
+
+  r = radix < 0 ? -radix : radix;
+  mp_build(&rr, &r, &r + 1);
+  rs.pow[0] = &rr;
+  rs.pows = 1;
+  rs.sp = 0;
+
+  /* --- If we've partially parsed some input then feed it in --- *
+   *
+   * Unfortunately, what we've got is backwards.  Fortunately there's a
+   * fairly tight upper bound on how many digits @t@ might be, since we
+   * aborted that loop once it got too large.
+   */
 
-      /* --- Combine the top two items --- */
+  if (t) {
+    i = 0;
+    while (t) { assert(i < sizeof(v)); v[i++] = t%r; t /= r; }
+    while (i) read_digit(&rs, nf, v[--i]);
+    any = 1;
+  }
 
-      sp--;
-      m = s[sp].m;
-      m = mp_mul(m, m, pow[s[sp].i]);
-      m = mp_add(m, m, s[sp + 1].m);
-      s[sp].m = m;
-      MP_DROP(s[sp + 1].m);
-      s[sp].i++;
+  /* --- Read more stuff --- */
 
-      /* --- Make a new radix power if necessary --- */
+  for (;;) {
+    ch = ops->get(p);
+    if ((d = char_digit(ch, radix)) < 0) break;
+    read_digit(&rs, nf, d); any = 1;
+  }
+  ops->unget(ch, p);
 
-      if (s[sp].i >= pows) {
-       assert(((void)"Number is too unimaginably huge", pows < DEPTH));
-       pow[pows] = mp_sqr(MP_NEW, pow[pows - 1]);
-       pows++;
-      }
-    }
-    f |= f_ok;
-    sp++;
+  /* --- Stitch all of the numbers together --- *
+   *
+   * This is not the same code as @read_digit@.  In particular, here we must
+   * cope with the partial result being some inconvenient power of %$R$%,
+   * rather than %$R^{2^i}$%.
+   */
+
+  if (!any) return (0);
+  m = MP_ZERO; z = MP_ONE;
+  while (rs.sp) {
+    rs.sp--;
+    ensure_power(&rs);
+    n = rs.s[rs.sp].m;
+    n = mp_mul(n, n, z);
+    m = mp_add(m, m, n);
+    z = mp_mul(z, z, rs.pow[rs.s[rs.sp].i]);
+    MP_DROP(n);
   }
+  for (i = 0; i < rs.pows; i++) MP_DROP(rs.pow[i]);
+  MP_DROP(z);
+  return (m);
+}
 
-  ops->unget(ch, p);
+mp *mp_read(mp *m, int radix, const mptext_ops *ops, void *p)
+{
+  unsigned t = 0;
+  unsigned nf = 0;
+  int ch, d, rd;
 
-  /* --- If we're done, compute the rest of the number --- */
+  unsigned f = 0;
+#define f_neg 1u
+#define f_ok 2u
 
-  if (f & f_ok) {
-    if (!sp)
-      return (MP_ZERO);
-    else {
-      mp *z = MP_ONE;
-      sp--;
+  /* --- We don't actually need a destination so throw it away --- *
+   *
+   * But note the flags before we lose it entirely.
+   */
 
-      while (sp > 0) {
+  if (m) {
+    nf = m->f & MP_BURN;
+    MP_DROP(m);
+  }
 
-       /* --- Combine the top two items --- */
+  /* --- Maintain a lookahead character --- */
 
-       sp--;
-       m = s[sp].m;
-       z = mp_mul(z, z, pow[s[sp + 1].i]);
-       m = mp_mul(m, m, z);
-       m = mp_add(m, m, s[sp + 1].m);
-       s[sp].m = m;
-       MP_DROP(s[sp + 1].m);
+  ch = ops->get(p);
 
-       /* --- Make a new radix power if necessary --- */
+  /* --- If we're reading text, skip leading space, and maybe a sign --- */
 
-       if (s[sp].i >= pows) {
-         assert(((void)"Number is too unimaginably huge", pows < DEPTH));
-         pow[pows] = mp_sqr(MP_NEW, pow[pows - 1]);
-         pows++;
-       }
+  if (radix >= 0) {
+    while (isspace(ch)) ch = ops->get(p);
+    switch (ch) {
+      case '-': f |= f_neg; /* and on */
+      case '+': do ch = ops->get(p); while (isspace(ch));
+    }
+  }
+
+  /* --- If we don't have a fixed radix, then parse one from the input --- *
+   *
+   * This is moderately easy if the input starts with `0x' or similar.  If it
+   * starts with `0' and something else, then it might be octal, or just a
+   * plain old zero.  Finally, it might start with a leading `NN_', in which
+   * case we carefully collect the decimal number until we're sure it's
+   * either a radix prefix (in which case we accept it and start over) or it
+   * isn't (in which case it's actually the start of a large number we need
+   * to read).
+   */
+
+  if (radix == 0) {
+    if (ch == '0') {
+      ch = ops->get(p);
+      switch (ch) {
+       case 'x': case 'X': radix = 16; goto fetch;
+       case 'o': case 'O': radix = 8; goto fetch;
+       case 'b': case 'B': radix = 2; goto fetch;
+       fetch: ch = ops->get(p); break;
+       default: radix = 8; f |= f_ok; break;
+      }
+    } else {
+      if ((d = char_digit(ch, 10)) < 0) { ops->unget(ch, p); return (0); }
+      for (;;) {
+       t = 10*t + d;
+       ch = ops->get(p);
+       if (t > 52) break;
+       if ((d = char_digit(ch, 10)) < 0) break;
+      }
+      if (ch != '_' || t > 52) radix = 10;
+      else {
+       radix = t; t = 0;
+       ch = ops->get(p);
       }
-      MP_DROP(z);
-      m = s[0].m;
     }
-  } else {
-    unsigned i;
-    for (i = 0; i < sp; i++)
-      MP_DROP(s[i].m);
   }
 
-  /* --- Clear the radix power list --- */
+  /* --- We're now ready to dispatch to the correct handler --- */
 
-  {
-    unsigned i;
-    for (i = 1; i < pows; i++)
-      MP_DROP(pow[i]);
+  rd = radix < 0 ? -radix : radix;
+  ops->unget(ch, p);
+  switch (rd) {
+    case   2: m = read_binary(radix,  1, nf, ops, p); break;
+    case   4: m = read_binary(radix,  2, nf, ops, p); break;
+    case   8: m = read_binary(radix,  3, nf, ops, p); break;
+    case  16: m = read_binary(radix,  4, nf, ops, p); break;
+    case  32: m = read_binary(radix,  5, nf, ops, p); break;
+    case  64: m = read_binary(radix,  6, nf, ops, p); break;
+    case 128: m = read_binary(radix,  7, nf, ops, p); break;
+    default:  m = read_general(radix, t, nf, ops, p); break;
   }
 
-  /* --- Bail out if the number was bad --- */
+  /* --- That didn't work --- *
+   *
+   * If we've already read something then return that.  Otherwise it's an
+   * error.
+   */
 
-done:
-  if (!(f & f_ok))
-    return (0);
+  if (!m) {
+    if (f & f_ok) return (MP_ZERO);
+    else return (0);
+  }
+
+  /* --- Negate the result if we should do that --- */
+
+  if (f & f_neg) m = mp_neg(m, m);
 
-  /* --- Set the sign and return --- */
+  /* --- And we're all done --- */
 
-  if (f & f_neg)
-    m->f |= MP_NEG;
-  MP_SHRINK(m);
   return (m);
 
-#undef f_start
 #undef f_neg
 #undef f_ok
 }