From: Mark Wooding Date: Wed, 14 Oct 2015 10:00:51 +0000 (+0100) Subject: math/mptext.c: Radically refactor `mp_read'. X-Git-Tag: 2.2.1~6 X-Git-Url: https://git.distorted.org.uk/~mdw/catacomb/commitdiff_plain/a6b6ae6baba64bf876bac6e4d34b364035666183 math/mptext.c: Radically refactor `mp_read'. It used to be the largest function in the library -- possibly in my codebase. * Split it into three main pieces: the special-purpose binary reader, an efficient stack-based general-radix reader, and a high-level syntax parser which picks out signs and base indicators. This removes the complicated entangling of the base indicator parsing with the general-radix reader which was the worst feature of the old version. * Split commonly-used functionality out into separate functions, notably `char_digit' and `read_digit'. The result is code which is easier to understand and actually shorter. --- diff --git a/math/mptext.c b/math/mptext.c index ddf50137..c902264d 100644 --- a/math/mptext.c +++ b/math/mptext.c @@ -96,112 +96,31 @@ * bizarre syntax. */ -mp *mp_read(mp *m, int radix, const mptext_ops *ops, void *p) +static int char_digit(int ch, int radix) { - int ch; /* Current char being considered */ - unsigned f = 0; /* Flags about the current number */ - int r; /* Radix to switch over to */ - mpw rd; /* Radix as an @mp@ digit */ - mp rr; /* The @mp@ for the radix */ - unsigned nf = m ? 
m->f & MP_BURN : 0; /* New @mp@ flags */ - - /* --- Stacks --- */ - - mp *pow[DEPTH]; /* List of powers */ - unsigned pows; /* Next index to fill */ - struct { unsigned i; mp *m; } s[DEPTH]; /* Main stack */ - unsigned sp; /* Current stack pointer */ - - /* --- Flags --- */ - -#define f_neg 1u -#define f_ok 2u -#define f_start 4u - - /* --- Initialize the stacks --- */ - - mp_build(&rr, &rd, &rd + 1); - pow[0] = &rr; - pows = 1; - - sp = 0; - - /* --- Initialize the destination number --- */ - - if (m) - MP_DROP(m); - - /* --- Read an initial character --- */ - - ch = ops->get(p); - if (radix >= 0) { - while (isspace(ch)) - ch = ops->get(p); - } - - /* --- Handle an initial sign --- */ - - if (radix >= 0 && (ch == '-' || ch == '+')) { - if (ch == '-') - f |= f_neg; - do ch = ops->get(p); while isspace(ch); - } - - /* --- If the radix is zero, look for leading zeros --- */ - - if (radix > 0) { - assert(((void)"ascii radix must be <= 62", radix <= 62)); - rd = radix; - r = -1; - } else if (radix < 0) { - rd = -radix; - assert(((void)"binary radix must fit in a byte", rd <= UCHAR_MAX)); - r = -1; - } else if (ch != '0') { - rd = 10; - r = 0; - } else { - ch = ops->get(p); - switch (ch) { - case 'x': - rd = 16; - goto prefix; - case 'o': - rd = 8; - goto prefix; - case 'b': - rd = 2; - goto prefix; - prefix: - ch = ops->get(p); - break; - default: - rd = 8; - f |= f_ok; - } - r = -1; - } - - /* --- Use fast algorithm for binary radix --- * - * - * This is the restart point after having parsed a radix number from the - * input. We check whether the radix is binary, and if so use a fast - * algorithm which just stacks the bits up in the right order. - */ + int r = radix < 0 ? -radix : radix; + int d; + + if (ch < 0) return (-1); + if (radix < 0) d = ch; + else if ('0' <= ch && ch <= '9') d = ch - '0'; + else if ('a' <= ch && ch <= 'z') d = ch - 'a' + 10; + else if ('A' <= ch && ch <= 'Z') d = ch - 'A' + (radix > 36 ? 
36 : 10); + else return (-1); + if (d >= r) return (-1); + return (d); +} -restart: - switch (rd) { - unsigned bit; - - case 2: bit = 1; goto bin; - case 4: bit = 2; goto bin; - case 8: bit = 3; goto bin; - case 16: bit = 4; goto bin; - case 32: bit = 5; goto bin; - case 64: bit = 6; goto bin; - case 128: bit = 7; goto bin; - default: - break; +static mp *read_binary(int radix, unsigned bit, unsigned nf, + const mptext_ops *ops, void *p) +{ + mpw a = 0; + unsigned b = MPW_BITS; + int any = 0, nz = 0; + int ch, d; + size_t len, n; + mpw *v; + mp *m; /* --- The fast binary algorithm --- * * @@ -212,246 +131,293 @@ restart: * bit-shift when we know where the end of the number is. */ - bin: { - mpw a = 0; - unsigned b = MPW_BITS; - size_t len, n; - mpw *v; - - m = mp_dest(MP_NEW, 1, nf); - len = n = m->sz; - n = len; - v = m->v + n; - for (;; ch = ops->get(p)) { - unsigned x; + m = mp_dest(MP_NEW, 1, nf); + len = n = m->sz; + n = len; + v = m->v + n; - if (ch < 0) - break; + for (;;) { + ch = ops->get(p); + if ((d = char_digit(ch, radix)) < 0) break; - /* --- Check that the character is a digit and in range --- */ + /* --- Ignore leading zeroes, but notice that the number is valid --- */ - if (radix < 0) - x = ch % rd; - else { - if (!isalnum(ch)) - break; - if (ch >= '0' && ch <= '9') - x = ch - '0'; - else { - if (rd <= 36) - ch = tolower(ch); - if (ch >= 'a' && ch <= 'z') /* ASCII dependent! */ - x = ch - 'a' + 10; - else if (ch >= 'A' && ch <= 'Z') - x = ch - 'A' + 36; - else - break; - } - } - if (x >= rd) - break; - - /* --- Feed the digit into the accumulator --- */ - - f |= f_ok; - if (!x && !(f & f_start)) - continue; - f |= f_start; - if (b > bit) { - b -= bit; - a |= MPW(x) << b; - } else { - a |= MPW(x) >> (bit - b); - b += MPW_BITS - bit; - *--v = MPW(a); - n--; - if (!n) { - n = len; - len <<= 1; - v = mpalloc(m->a, len); - memcpy(v + n, m->v, MPWS(n)); - mpfree(m->a, m->v); - m->v = v; - v = m->v + n; - } - a = (b < MPW_BITS) ? 
MPW(x) << b : 0; - } - } + any = 1; + if (!d && !nz) continue; + nz = 1; - /* --- Finish up --- */ + /* --- Feed the digit into the accumulator --- */ - if (!(f & f_ok)) { - mp_drop(m); - m = 0; + if (b > bit) { + b -= bit; + a |= MPW(d) << b; } else { - *--v = MPW(a); - n--; - m->sz = len; - m->vl = m->v + len; - m->f &= ~MP_UNDEF; - m = mp_lsr(m, m, (unsigned long)n * MPW_BITS + b); + a |= MPW(d) >> (bit - b); + b += MPW_BITS - bit; + *--v = MPW(a); n--; + if (!n) { + n = len; len <<= 1; + v = mpalloc(m->a, len); + memcpy(v + n, m->v, MPWS(n)); + mpfree(m->a, m->v); + m->v = v; v = m->v + n; + } + a = (b < MPW_BITS) ? MPW(d) << b : 0; } - ops->unget(ch, p); - goto done; - }} + } - /* --- Time to start --- */ + /* --- Finish up --- */ - for (;; ch = ops->get(p)) { - unsigned x; + ops->unget(ch, p); + if (!any) { mp_drop(m); return (0); } - if (ch < 0) - break; + *--v = MPW(a); n--; + m->sz = len; + m->vl = m->v + len; + m->f &= ~MP_UNDEF; + m = mp_lsr(m, m, (unsigned long)n * MPW_BITS + b); - /* --- An underscore indicates a numbered base --- */ + return (m); +} - if (ch == '_' && r > 0 && r <= 62) { - unsigned i; +struct readstate { - /* --- Clear out the stacks --- */ + /* --- State for the general-base reader --- * + * + * There are two arrays. The @pow@ array is set so that @pow[i]@ contains + * %$R^{2^i}$% for @i < pows@. The stack @s@ contains partial results: + * each entry contains a value @m@ corresponding to %$2^i$% digits. + * Inductively, an empty stack represents zero; if a stack represents %$x$% + * then pushing a new entry on the top causes the stack to represent + * %$R^{2^i} x + m$%. + * + * It is an invariant that each entry has a strictly smaller @i@ than the + * items beneath it. 
This is achieved by coalescing entries at the top if + * they have equal %$i$% values: if the top items are %$(m, i)$%, and + * %$(M, i)$%, and the rest of the stack represents the integer %$x$%, + * then %$R^{2^i} (R^{2^i} x + M) + m = R^{2^{i+1}} x + (R^{2^i} M + m)$%, + * so we replace the top two items by %$((R^{2^i} M + m), i + 1)$%, and + * repeat if necessary. + */ - for (i = 1; i < pows; i++) - MP_DROP(pow[i]); - pows = 1; - for (i = 0; i < sp; i++) - MP_DROP(s[i].m); - sp = 0; + unsigned pows, sp; + struct { unsigned i; mp *m; } s[DEPTH]; + mp *pow[DEPTH]; +}; - /* --- Restart the search --- */ +static void ensure_power(struct readstate *rs) +{ + /* --- Make sure we have the necessary %$R^{2^i}$% computed --- */ - rd = r; - r = -1; - f &= ~f_ok; - ch = ops->get(p); - goto restart; - } + if (rs->s[rs->sp].i >= rs->pows) { + assert(rs->pows < DEPTH); + rs->pow[rs->pows] = mp_sqr(MP_NEW, rs->pow[rs->pows - 1]); + rs->pows++; + } +} - /* --- Check that the character is a digit and in range --- */ +static void read_digit(struct readstate *rs, unsigned nf, int d) +{ + mp *m = mp_new(1, nf); + m->v[0] = d; - if (radix < 0) - x = ch % rd; - else { - if (!isalnum(ch)) - break; - if (ch >= '0' && ch <= '9') - x = ch - '0'; - else { - if (rd <= 36) - ch = tolower(ch); - if (ch >= 'a' && ch <= 'z') /* ASCII dependent! 
*/ - x = ch - 'a' + 10; - else if (ch >= 'A' && ch <= 'Z') - x = ch - 'A' + 36; - else - break; - } - } + /* --- Put the new digit on top --- */ - /* --- Sort out what to do with the character --- */ + assert(rs->sp < DEPTH); + rs->s[rs->sp].m = m; + rs->s[rs->sp].i = 0; - if (x >= 10 && r >= 0) - r = -1; - if (x >= rd) - break; + /* --- Restore the stack invariant --- */ - if (r >= 0) - r = r * 10 + x; + while (rs->sp && rs->s[rs->sp - 1].i <= rs->s[rs->sp].i) { + assert(rs->sp > 0); + ensure_power(rs); + rs->sp--; - /* --- Stick the character on the end of my integer --- */ + m = rs->s[rs->sp].m; + m = mp_mul(m, m, rs->pow[rs->s[rs->sp + 1].i]); + m = mp_add(m, m, rs->s[rs->sp + 1].m); + MP_DROP(rs->s[rs->sp + 1].m); + rs->s[rs->sp].m = m; + rs->s[rs->sp].i++; + } - assert(((void)"Number is too unimaginably huge", sp < DEPTH)); - s[sp].m = m = mp_new(1, nf); - m->v[0] = x; - s[sp].i = 0; + /* --- Leave the stack pointer at an empty item --- */ - /* --- Now grind through the stack --- */ + rs->sp++; +} - while (sp > 0 && s[sp - 1].i == s[sp].i) { +static mp *read_general(int radix, unsigned t, unsigned nf, + const mptext_ops *ops, void *p) +{ + struct readstate rs; + unsigned char v[4]; + unsigned i; + mpw r; + int any = 0; + int ch, d; + mp rr; + mp *m, *z, *n; + + /* --- Prepare the stack --- */ + + r = radix < 0 ? -radix : radix; + mp_build(&rr, &r, &r + 1); + rs.pow[0] = &rr; + rs.pows = 1; + rs.sp = 0; + + /* --- If we've partially parsed some input then feed it in --- * + * + * Unfortunately, what we've got is backwards. Fortunately there's a + * fairly tight upper bound on how many digits @t@ might be, since we + * aborted that loop once it got too large. 
+ */ - /* --- Combine the top two items --- */ + if (t) { + i = 0; + while (t) { assert(i < sizeof(v)); v[i++] = t%r; t /= r; } + while (i) read_digit(&rs, nf, v[--i]); + any = 1; + } - sp--; - m = s[sp].m; - m = mp_mul(m, m, pow[s[sp].i]); - m = mp_add(m, m, s[sp + 1].m); - s[sp].m = m; - MP_DROP(s[sp + 1].m); - s[sp].i++; + /* --- Read more stuff --- */ - /* --- Make a new radix power if necessary --- */ + for (;;) { + ch = ops->get(p); + if ((d = char_digit(ch, radix)) < 0) break; + read_digit(&rs, nf, d); any = 1; + } + ops->unget(ch, p); - if (s[sp].i >= pows) { - assert(((void)"Number is too unimaginably huge", pows < DEPTH)); - pow[pows] = mp_sqr(MP_NEW, pow[pows - 1]); - pows++; - } - } - f |= f_ok; - sp++; + /* --- Stitch all of the numbers together --- * + * + * This is not the same code as @read_digit@. In particular, here we must + * cope with the partial result being some inconvenient power of %$R$%, + * rather than %$R^{2^i}$%. + */ + + if (!any) return (0); + m = MP_ZERO; z = MP_ONE; + while (rs.sp) { + rs.sp--; + ensure_power(&rs); + n = rs.s[rs.sp].m; + n = mp_mul(n, n, z); + m = mp_add(m, m, n); + z = mp_mul(z, z, rs.pow[rs.s[rs.sp].i]); + MP_DROP(n); } + for (i = 0; i < rs.pows; i++) MP_DROP(rs.pow[i]); + MP_DROP(z); + return (m); +} - ops->unget(ch, p); +mp *mp_read(mp *m, int radix, const mptext_ops *ops, void *p) +{ + unsigned t = 0; + unsigned nf = 0; + int ch, d, rd; - /* --- If we're done, compute the rest of the number --- */ + unsigned f = 0; +#define f_neg 1u +#define f_ok 2u - if (f & f_ok) { - if (!sp) - return (MP_ZERO); - else { - mp *z = MP_ONE; - sp--; + /* --- We don't actually need a destination so throw it away --- * + * + * But note the flags before we lose it entirely. 
+ */ - while (sp > 0) { + if (m) { + nf = m->f & MP_BURN; + MP_DROP(m); + } - /* --- Combine the top two items --- */ + /* --- Maintain a lookahead character --- */ - sp--; - m = s[sp].m; - z = mp_mul(z, z, pow[s[sp + 1].i]); - m = mp_mul(m, m, z); - m = mp_add(m, m, s[sp + 1].m); - s[sp].m = m; - MP_DROP(s[sp + 1].m); + ch = ops->get(p); - /* --- Make a new radix power if necessary --- */ + /* --- If we're reading text, skip leading space, and maybe a sign --- */ - if (s[sp].i >= pows) { - assert(((void)"Number is too unimaginably huge", pows < DEPTH)); - pow[pows] = mp_sqr(MP_NEW, pow[pows - 1]); - pows++; - } + if (radix >= 0) { + while (isspace(ch)) ch = ops->get(p); + switch (ch) { + case '-': f |= f_neg; /* and on */ + case '+': do ch = ops->get(p); while (isspace(ch)); + } + } + + /* --- If we don't have a fixed radix, then parse one from the input --- * + * + * This is moderately easy if the input starts with `0x' or similar. If it + * starts with `0' and something else, then it might be octal, or just a + * plain old zero. Finally, it might start with a leading `NN_', in which + * case we carefully collect the decimal number until we're sure it's + * either a radix prefix (in which case we accept it and start over) or it + * isn't (in which case it's actually the start of a large number we need + * to read). 
+ */ + + if (radix == 0) { + if (ch == '0') { + ch = ops->get(p); + switch (ch) { + case 'x': case 'X': radix = 16; goto fetch; + case 'o': case 'O': radix = 8; goto fetch; + case 'b': case 'B': radix = 2; goto fetch; + fetch: ch = ops->get(p); break; + default: radix = 8; f |= f_ok; break; + } + } else { + if ((d = char_digit(ch, 10)) < 0) { ops->unget(ch, p); return (0); } + for (;;) { + t = 10*t + d; + ch = ops->get(p); + if (t > 52) break; + if ((d = char_digit(ch, 10)) < 0) break; + } + if (ch != '_' || t > 52) radix = 10; + else { + radix = t; t = 0; + ch = ops->get(p); } - MP_DROP(z); - m = s[0].m; } - } else { - unsigned i; - for (i = 0; i < sp; i++) - MP_DROP(s[i].m); } - /* --- Clear the radix power list --- */ + /* --- We're now ready to dispatch to the correct handler --- */ - { - unsigned i; - for (i = 1; i < pows; i++) - MP_DROP(pow[i]); + rd = radix < 0 ? -radix : radix; + ops->unget(ch, p); + switch (rd) { + case 2: m = read_binary(radix, 1, nf, ops, p); break; + case 4: m = read_binary(radix, 2, nf, ops, p); break; + case 8: m = read_binary(radix, 3, nf, ops, p); break; + case 16: m = read_binary(radix, 4, nf, ops, p); break; + case 32: m = read_binary(radix, 5, nf, ops, p); break; + case 64: m = read_binary(radix, 6, nf, ops, p); break; + case 128: m = read_binary(radix, 7, nf, ops, p); break; + default: m = read_general(radix, t, nf, ops, p); break; } - /* --- Bail out if the number was bad --- */ + /* --- That didn't work --- * + * + * If we've already read something then return that. Otherwise it's an + * error. + */ -done: - if (!(f & f_ok)) - return (0); + if (!m) { + if (f & f_ok) return (MP_ZERO); + else return (0); + } + + /* --- Negate the result if we should do that --- */ + + if (f & f_neg) m = mp_neg(m, m); - /* --- Set the sign and return --- */ + /* --- And we're all done --- */ - if (f & f_neg) - m->f |= MP_NEG; - MP_SHRINK(m); return (m); -#undef f_start #undef f_neg #undef f_ok }