X-Git-Url: https://git.distorted.org.uk/u/mdw/catacomb/blobdiff_plain/0f00dc4c8eb47e67bc0f148c2dd109f73a451e0a..5937940fc6c45b93eb95554996d1cd1b5a1bf307:/math/mpreduce.c diff --git a/math/mpreduce.c b/math/mpreduce.c index 669f516..b148dd5 100644 --- a/math/mpreduce.c +++ b/math/mpreduce.c @@ -38,6 +38,42 @@ DA_DECL(instr_v, mpreduce_instr); +/*----- Theory ------------------------------------------------------------* + * + * We're given a modulus %$p = 2^n - d$%, where %$d < 2^n$%, and some %$x$%, + * and we want to compute %$x \bmod p$%. We work in base %$2^w$%, for some + * appropriate %$w$%. The important observation is that + * %$d \equiv 2^n \pmod p$%. + * + * Suppose %$x = x' + z 2^k$%, where %$k \ge n$%; then + * %$x \equiv x' + d z 2^{k-n} \pmod p$%. We can use this to trim the + * representation of %$x$%; each time, we reduce %$x$% by a multiple of + * %$2^{k-n} p$%. We can do this in two passes: firstly by taking whole + * words off the top, and then (if necessary) by trimming the top word. + * Finally, if %$p \le x < 2^n$% then %$0 \le x - p < p$% and we're done. + * + * A common trick, apparently, is to choose %$d$% such that it has a very + * sparse non-adjacent form, and, moreover, that this form is nicely aligned + * with common word sizes. (That is, write %$d = \sum_{0\le i= X) printf("+ %lu\n", i - 1); - st = Z; -#endif - bb = MPW_BITS - (d + 1)%MPW_BITS; for (i = 0, mp_scan(&sc, p); i < d && mp_step(&sc); i++) { switch (st | mp_bit(&sc)) { @@ -118,15 +198,56 @@ int mpreduce_create(mpreduce *r, mp *p) INSTR(op | !!b, w, b); } } - if (DA_LEN(&iv) && (DA(&iv)[DA_LEN(&iv) - 1].op & ~1u) == MPRI_SUB) { - mp_drop(r->p); - DA_DESTROY(&iv); - return (-1); + + /* --- Fix up wrong-sided decompositions --- * + * + * At this point, we haven't actually finished up the state machine + * properly. We stopped scanning just after bit %$n - 1$% -- the most + * significant one, which we know in advance must be set (since @x@ is + * strictly positive). 
Therefore we are either in state @X@ or @Z1@. In + * the former case, we have nothing to do. In the latter, there are two + * subcases to deal with. If there are no other instructions, then @x@ is + * a perfect power of two, and %$d = 0$%, so again there is nothing to do. + * + * In the remaining case, we have decomposed @x@ as %$2^{n-1} + d$%, for + * some positive %$d$%, which is unfortunate: if we're asked to reduce + * %$2^n$%, say, we'll end up with %$-d$% (or would do, if we weren't + * sticking to unsigned arithmetic for good performance). So instead, we + * rewrite this as %$2^n - 2^{n-1} + d$% and everything will be good. + */ + + if (st == Z1 && DA_LEN(&iv)) { + w = 1; + b = (bb + d)%MPW_BITS; + INSTR(MPRI_ADD | !!b, w, b); } #undef INSTR - /* --- Wrap up --- */ + /* --- Wrap up --- * + * + * Store the generated instruction sequence in our context structure. If + * %$p$%'s bit length %$n$% is a multiple of the word size %$w$% then + * there's nothing much else to do here. Otherwise, we have an additional + * job. + * + * The reduction operation has three phases. The first trims entire words + * from the argument, and the instruction sequence constructed above does + * this well; the second phase reduces an integer which has the same number + * of words as %$p$%, but strictly more bits. (The third phase is a single + * conditional subtraction of %$p$%, in case the argument is the same bit + * length as %$p$% but greater; this doesn't concern us here.) To handle + * the second phase, we create another copy of the instruction stream, with + * all of the target shifts adjusted upwards by %$s = n \bmod w$%. 
+ * + * In this case, we are acting on an %$(N - 1)$%-word operand, and so + * (given the remarks above) must check that this is still valid, but a + * moment's reflection shows that this must be fine: the most distant + * target must be the bit %$s$% from the top of the least-significant word; + * but since we shift all of the targets up by %$s$%, this now addresses + * the bottom bit of the next most significant word, and there is no + * underflow. + */ r->in = DA_LEN(&iv); if (!r->in) @@ -154,9 +275,6 @@ int mpreduce_create(mpreduce *r, mp *p) } DA_DESTROY(&iv); -#ifdef DEBUG - mpreduce_dump(r, stdout); -#endif return (0); } @@ -224,15 +342,6 @@ static void run(const mpreduce_instr *i, const mpreduce_instr *il, mpw *v, mpw z) { for (; i < il; i++) { -#ifdef DEBUG - mp vv; - mp_build(&vv, v - i->argx, v + 1); - printf(" 0x"); mp_writefile(&vv, stdout, 16); - printf(" %c (0x%lx << %u) == 0x", - (i->op & ~1u) == MPRI_ADD ? '+' : '-', - (unsigned long)z, - i->argy); -#endif switch (i->op) { case MPRI_ADD: MPX_UADDN(v - i->argx, v + 1, z); break; case MPRI_ADDLSL: mpx_uaddnlsl(v - i->argx, v + 1, z, i->argy); break; @@ -241,11 +350,6 @@ static void run(const mpreduce_instr *i, const mpreduce_instr *il, default: abort(); } -#ifdef DEBUG - mp_build(&vv, v - i->argx, v + 1); - mp_writefile(&vv, stdout, 16); - printf("\n"); -#endif } } @@ -255,10 +359,6 @@ mp *mpreduce_do(mpreduce *r, mp *d, mp *x) const mpreduce_instr *il; mpw z; -#ifdef DEBUG - mp *_r = 0, *_rr = 0; -#endif - /* --- If source is negative, divide --- */ if (MP_NEGP(x)) { @@ -272,14 +372,7 @@ mp *mpreduce_do(mpreduce *r, mp *d, mp *x) if (d) MP_DROP(d); MP_DEST(x, MP_LEN(x), x->f); - /* --- Do the reduction --- */ - -#ifdef DEBUG - _r = MP_NEW; - mp_div(0, &_r, x, r->p); - MP_PRINTX("x", x); - _rr = 0; -#endif + /* --- Stage one: trim excess words from the most significant end --- */ il = r->iv + r->in; if (MP_LEN(x) >= r->lim) { @@ -290,28 +383,21 @@ mp *mpreduce_do(mpreduce *r, mp *d, mp *x) z = *vl; 
*vl = 0; run(r->iv, il, vl, z); -#ifdef DEBUG - MP_PRINTX("x", x); - mp_div(0, &_rr, x, r->p); - assert(MP_EQ(_r, _rr)); -#endif } } + + /* --- Stage two: trim excess bits from the most significant word --- */ + if (r->s) { while (*vl >> r->s) { z = *vl >> r->s; *vl &= ((1 << r->s) - 1); run(r->iv + r->in, il + r->in, vl, z); -#ifdef DEBUG - MP_PRINTX("x", x); - mp_div(0, &_rr, x, r->p); - assert(MP_EQ(_r, _rr)); -#endif } } } - /* --- Finishing touches --- */ + /* --- Stage three: conditional subtraction --- */ MP_SHRINK(x); if (MP_CMP(x, >=, r->p)) @@ -319,11 +405,6 @@ mp *mpreduce_do(mpreduce *r, mp *d, mp *x) /* --- Done --- */ -#ifdef DEBUG - assert(MP_EQ(_r, x)); - mp_drop(_r); - mp_drop(_rr); -#endif return (x); }