math/gfreduce.[ch]: Fix out-of-bounds memory access.
[u/mdw/catacomb] / math / mpreduce.c
index 669f516..b148dd5 100644 (file)
 
 DA_DECL(instr_v, mpreduce_instr);
 
+/*----- Theory ------------------------------------------------------------*
+ *
+ * We're given a modulus %$p = 2^n - d$%, where %$d < 2^n$%, and some %$x$%,
+ * and we want to compute %$x \bmod p$%.  We work in base %$2^w$%, for some
+ * appropriate %$w$%.  The important observation is that
+ * %$d \equiv 2^n \pmod p$%.
+ *
+ * Suppose %$x = x' + z 2^k$%, where %$k \ge n$%; then
+ * %$x \equiv x' + d z 2^{k-n} \pmod p$%.  We can use this to trim the
+ * representation of %$x$%; each time, we reduce %$x$% by a mutliple of
+ * %$2^{k-n} p$%.  We can do this in two passes: firstly by taking whole
+ * words off the top, and then (if necessary) by trimming the top word.
+ * Finally, if %$p \le x < 2^n$% then %$0 \le x - p < p$% and we're done.
+ *
+ * A common trick, apparently, is to choose %$d$% such that it has a very
+ * sparse non-adjacent form, and, moreover, that this form is nicely aligned
+ * with common word sizes.  (That is, write %$d = \sum_{0\le i<m} d_i 2^i$%,
+ * with %$d_i \in \{ -1, 0, +1 \}$% and most %$d_i = 0$%.)  Then adding
+ * %$z d$% is a matter of adding and subtracting appropriately shifted copies
+ * of %$z$%.
+ *
+ * Most libraries come with hardwired code for doing this for a few
+ * well-known values of %$p$%.  We take a different approach, for two
+ * reasons.
+ *
+ *   * Firstly, it privileges built-in numbers over user-selected ones, even
+ *     if the latter have the right (or better) properties.
+ *
+ *   * Secondly, writing appropriately optimized reduction functions when we
+ *     don't know the exact characteristics of the target machine is rather
+ *     difficult.
+ *
+ * Our solution, then, is to `compile' the value %$p$% at runtime, into a
+ * sequence of simple instructions for doing the reduction.
+ */
+
 /*----- Main code ---------------------------------------------------------*/
 
 /* --- @mpreduce_create@ --- *
@@ -45,7 +81,10 @@ DA_DECL(instr_v, mpreduce_instr);
  * Arguments:  @gfreduce *r@ = structure to fill in
  *             @mp *x@ = an integer
  *
- * Returns:    Zero if successful; nonzero on failure.
+ * Returns:    Zero if successful; nonzero on failure.  The current
+ *             algorithm always succeeds when given positive @x@.  Earlier
+ *             versions used to fail on particular kinds of integers, but
+ *             this is guaranteed not to happen any more.
  *
  * Use:                Initializes a context structure for reduction.
  */
@@ -84,25 +123,66 @@ int mpreduce_create(mpreduce *r, mp *p)
   /* --- Main loop --- *
    *
    * A simple state machine decomposes @p@ conveniently into positive and
-   * negative powers of 2.  The pure form of the state machine is left below
-   * for reference (and in case I need inspiration for a NAF exponentiator).
+   * negative powers of 2.
+   *
+   * Here's the relevant theory.  The important observation is that
+   * %$2^i = 2^{i+1} - 2^i$%, and hence
+   *
+   *   * %$\sum_{a\le i<b} 2^i = 2^b - 2^a$%, and
+   *
+   *   * %$2^c - 2^{b+1} + 2^b - 2^a = 2^c - 2^b - 2^a$%.
+   *
+   * The first of these gives us a way of combining a run of several one
+   * bits, and the second gives us a way of handling a single-bit
+   * interruption in such a run.
+   *
+   * We start with a number %$p = \sum_{0\le i<n} p_i 2^i$%, and scan
+   * right-to-left using a simple state-machine keeping (approximate) track
+   * of the two previous bits.  The @Z@ states denote that we're in a string
+   * of zeros; @Z1@ means that we just saw a 1 bit after a sequence of zeros.
+   * Similarly, the @X@ states denote that we're in a string of ones; and
+   * @X0@ means that we just saw a zero bit after a sequence of ones.  The
+   * state machine lets us delay decisions about what to do when we've seen a
+   * change to the status quo (a one after a run of zeros, or vice-versa)
+   * until we've seen the next bit, so we can tell whether this is an
+   * isolated bit or the start of a new sequence.
+   *
+   * More formally: we define two functions %$Z^b_i$% and %$X^b_i$% as
+   * follows.
+   *
+   *   * %$Z^0_i(S, 0) = S$%
+   *   * %$Z^0_i(S, n) = Z^0_{i+1}(S, n)$%
+   *   * %$Z^0_i(S, n + 2^i) = Z^1_{i+1}(S, n)$%
+   *   * %$Z^1_i(S, n) = Z^0_{i+1}(S \cup \{ 2^{i-1} \}, n)$%
+   *   * %$Z^1_i(S, n + 2^i) = X^1_{i+1}(S \cup \{ -2^{i-1} \}, n)$%
+   *   * %$X^0_i(S, n) = Z^0_{i+1}(S, \{ 2^{i-1} \})$%
+   *   * %$X^0_i(S, n + 2^i) = X^1_{i+1}(S \cup \{ -2^{i-1} \}, n)$%
+   *   * %$X^1_i(S, n) = X^0_{i+1}(S, n)$%
+   *   * %$X^1_i(S, n + 2^i) = X^1_{i+1}(S, n)$%
+   *
+   * The reader may verify (by induction on %$n$%) that the following
+   * properties hold.
+   *
+   *   * %$Z^0_0(\emptyset, n)$% is well-defined for all %$n \ge 0$%
+   *   * %$\sum Z^b_i(S, n) = \sum S + n + b 2^{i-1}$%
+   *   * %$\sum X^b_i(S, n) = \sum S + n + (b + 1) 2^{i-1}$%
+   *
+   * From these, of course, we can deduce %$\sum Z^0_0(\emptyset, n) = n$%.
+   *
+   * We apply the above recurrence to build a simple instruction sequence for
+   * adding an appropriate multiple of %$d$% to a given number.  Suppose that
+   * %$2^{w(N-1)} \le 2^{n-1} \le p < 2^n \le 2^{wN}$%.  The machine which
+   * interprets these instructions does so in the context of a
+   * single-precision multiplicand @z@ and a pointer @v@ to the
+   * %%\emph{most}%% significant word of an %$N + 1$%-word integer, and the
+   * instruction sequence should add %$z d$% to this integer.
+   *
+   * The available instructions are named @MPRI_{ADD,SUB}{,LSL}@; they add
+   * (or subtract) %$z$% (shifted left by some amount, in the @LSL@ variants)
+   * to some word earlier than @v@.  The relevant quantities are encoded in
+   * the instruction's immediate operands.
    */
 
-#ifdef DEBUG
-  for (i = 0, mp_scan(&sc, p); mp_step(&sc); i++) {
-    switch (st | mp_bit(&sc)) {
-      case  Z | 1: st = Z1; break;
-      case Z1 | 0: st =         Z; printf("+ %lu\n", i - 1); break;
-      case Z1 | 1: st =         X; printf("- %lu\n", i - 1); break;
-      case  X | 0: st = X0; break;
-      case X0 | 1: st =         X; printf("- %lu\n", i - 1); break;
-      case X0 | 0: st =         Z; printf("+ %lu\n", i - 1); break;
-    }
-  }
-  if (st >= X) printf("+ %lu\n", i - 1);
-  st = Z;
-#endif
-
   bb = MPW_BITS - (d + 1)%MPW_BITS;
   for (i = 0, mp_scan(&sc, p); i < d && mp_step(&sc); i++) {
     switch (st | mp_bit(&sc)) {
@@ -118,15 +198,56 @@ int mpreduce_create(mpreduce *r, mp *p)
        INSTR(op | !!b, w, b);
     }
   }
-  if (DA_LEN(&iv) && (DA(&iv)[DA_LEN(&iv) - 1].op & ~1u) == MPRI_SUB) {
-    mp_drop(r->p);
-    DA_DESTROY(&iv);
-    return (-1);
+
+  /* --- Fix up wrong-sided decompositions --- *
+   *
+   * At this point, we haven't actually finished up the state machine
+   * properly.  We stopped scanning just after bit %$n - 1$% -- the most
+   * significant one, which we know in advance must be set (since @x@ is
+   * strictly positive).  Therefore we are either in state @X@ or @Z1@.  In
+   * the former case, we have nothing to do.  In the latter, there are two
+   * subcases to deal with.  If there are no other instructions, then @x@ is
+   * a perfect power of two, and %$d = 0$%, so again there is nothing to do.
+   *
+   * In the remaining case, we have decomposed @x@ as %$2^{n-1} + d$%, for
+   * some positive %$d%, which is unfortuante: if we're asked to reduce
+   * %$2^n$%, say, we'll end up with %$-d$% (or would do, if we weren't
+   * sticking to unsigned arithmetic for good performance).  So instead, we
+   * rewrite this as %$2^n - 2^{n-1} + d$% and everything will be good.
+   */
+
+  if (st == Z1 && DA_LEN(&iv)) {
+    w = 1;
+    b = (bb + d)%MPW_BITS;
+    INSTR(MPRI_ADD | !!b, w, b);
   }
 
 #undef INSTR
 
-  /* --- Wrap up --- */
+  /* --- Wrap up --- *
+   *
+   * Store the generated instruction sequence in our context structure.  If
+   * %$p$%'s bit length %$n$% is a multiple of the word size %$w$% then
+   * there's nothing much else to do here.  Otherwise, we have an additional
+   * job.
+   *
+   * The reduction operation has three phases.  The first trims entire words
+   * from the argument, and the instruction sequence constructed above does
+   * this well; the second phase reduces an integer which has the same number
+   * of words as %$p$%, but strictly more bits.  (The third phase is a single
+   * conditional subtraction of %$p$%, in case the argument is the same bit
+   * length as %$p$% but greater; this doesn't concern us here.)  To handle
+   * the second phase, we create another copy of the instruction stream, with
+   * all of the target shifts adjusted upwards by %$s = n \bmod w$%.
+   *
+   * In this case, we are acting on an %$(N - 1)$%-word operand, and so
+   * (given the remarks above) must check that this is still valid, but a
+   * moment's reflection shows that this must be fine: the most distant
+   * target must be the bit %$s$% from the top of the least-significant word;
+   * but since we shift all of the targets up by %$s$%, this now addresses
+   * the bottom bit of the next most significant word, and there is no
+   * underflow.
+   */
 
   r->in = DA_LEN(&iv);
   if (!r->in)
@@ -154,9 +275,6 @@ int mpreduce_create(mpreduce *r, mp *p)
   }
   DA_DESTROY(&iv);
 
-#ifdef DEBUG
-  mpreduce_dump(r, stdout);
-#endif
   return (0);
 }
 
@@ -224,15 +342,6 @@ static void run(const mpreduce_instr *i, const mpreduce_instr *il,
                mpw *v, mpw z)
 {
   for (; i < il; i++) {
-#ifdef DEBUG
-    mp vv;
-    mp_build(&vv, v - i->argx, v + 1);
-    printf("  0x"); mp_writefile(&vv, stdout, 16);
-    printf(" %c (0x%lx << %u) == 0x",
-          (i->op & ~1u) == MPRI_ADD ? '+' : '-',
-          (unsigned long)z,
-          i->argy);
-#endif
     switch (i->op) {
       case MPRI_ADD: MPX_UADDN(v - i->argx, v + 1, z); break;
       case MPRI_ADDLSL: mpx_uaddnlsl(v - i->argx, v + 1, z, i->argy); break;
@@ -241,11 +350,6 @@ static void run(const mpreduce_instr *i, const mpreduce_instr *il,
       default:
        abort();
     }
-#ifdef DEBUG
-    mp_build(&vv, v - i->argx, v + 1);
-    mp_writefile(&vv, stdout, 16);
-    printf("\n");
-#endif
   }
 }
 
@@ -255,10 +359,6 @@ mp *mpreduce_do(mpreduce *r, mp *d, mp *x)
   const mpreduce_instr *il;
   mpw z;
 
-#ifdef DEBUG
-  mp *_r = 0, *_rr = 0;
-#endif
-
   /* --- If source is negative, divide --- */
 
   if (MP_NEGP(x)) {
@@ -272,14 +372,7 @@ mp *mpreduce_do(mpreduce *r, mp *d, mp *x)
   if (d) MP_DROP(d);
   MP_DEST(x, MP_LEN(x), x->f);
 
-  /* --- Do the reduction --- */
-
-#ifdef DEBUG
-  _r = MP_NEW;
-  mp_div(0, &_r, x, r->p);
-  MP_PRINTX("x", x);
-  _rr = 0;
-#endif
+  /* --- Stage one: trim excess words from the most significant end --- */
 
   il = r->iv + r->in;
   if (MP_LEN(x) >= r->lim) {
@@ -290,28 +383,21 @@ mp *mpreduce_do(mpreduce *r, mp *d, mp *x)
        z = *vl;
        *vl = 0;
        run(r->iv, il, vl, z);
-#ifdef DEBUG
-       MP_PRINTX("x", x);
-       mp_div(0, &_rr, x, r->p);
-       assert(MP_EQ(_r, _rr));
-#endif
       }
     }
+
+    /* --- Stage two: trim excess bits from the most significant word --- */
+
     if (r->s) {
       while (*vl >> r->s) {
        z = *vl >> r->s;
        *vl &= ((1 << r->s) - 1);
        run(r->iv + r->in, il + r->in, vl, z);
-#ifdef DEBUG
-       MP_PRINTX("x", x);
-       mp_div(0, &_rr, x, r->p);
-       assert(MP_EQ(_r, _rr));
-#endif
       }
     }
   }
 
-  /* --- Finishing touches --- */
+  /* --- Stage three: conditional subtraction --- */
 
   MP_SHRINK(x);
   if (MP_CMP(x, >=, r->p))
@@ -319,11 +405,6 @@ mp *mpreduce_do(mpreduce *r, mp *d, mp *x)
 
   /* --- Done --- */
 
-#ifdef DEBUG
-  assert(MP_EQ(_r, x));
-  mp_drop(_r);
-  mp_drop(_rr);
-#endif
   return (x);
 }