math/mpreduce.c: Add extensive commentary.
authorMark Wooding <mdw@distorted.org.uk>
Mon, 5 Aug 2013 20:13:48 +0000 (21:13 +0100)
committerMark Wooding <mdw@distorted.org.uk>
Mon, 5 Aug 2013 20:13:48 +0000 (21:13 +0100)
The behaviour of this code must have been something of a mystery.  It's
not arbitrary, but it is a little subtle in places.  Add a full
explanation of the whole thing.

math/mpreduce.c

index 669f516..b4543a6 100644 (file)
 
 DA_DECL(instr_v, mpreduce_instr);
 
 
 DA_DECL(instr_v, mpreduce_instr);
 
+/*----- Theory ------------------------------------------------------------*
+ *
+ * We're given a modulus %$p = 2^n - d$%, where %$d < 2^n$%, and some %$x$%,
+ * and we want to compute %$x \bmod p$%.  We work in base %$2^w$%, for some
+ * appropriate %$w$%.  The important observation is that
+ * %$d \equiv 2^n \pmod p$%.
+ *
+ * Suppose %$x = x' + z 2^k$%, where %$k \ge n$%; then
+ * %$x \equiv x' + d z 2^{k-n} \pmod p$%.  We can use this to trim the
+ * representation of %$x$%; each time, we reduce %$x$% by a mutliple of
+ * %$2^{k-n} p$%.  We can do this in two passes: firstly by taking whole
+ * words off the top, and then (if necessary) by trimming the top word.
+ * Finally, if %$p \le x < 2^n$% then %$0 \le x - p < p$% and we're done.
+ *
+ * A common trick, apparently, is to choose %$d$% such that it has a very
+ * sparse non-adjacent form, and, moreover, that this form is nicely aligned
+ * with common word sizes.  (That is, write %$d = \sum_{0\le i<m} d_i 2^i$%,
+ * with %$d_i \in \{ -1, 0, +1 \}$% and most %$d_i = 0$%.)  Then adding
+ * %$z d$% is a matter of adding and subtracting appropriately shifted copies
+ * of %$z$%.
+ *
+ * Most libraries come with hardwired code for doing this for a few
+ * well-known values of %$p$%.  We take a different approach, for two
+ * reasons.
+ *
+ *   * Firstly, it privileges built-in numbers over user-selected ones, even
+ *     if the latter have the right (or better) properties.
+ *
+ *   * Secondly, writing appropriately optimized reduction functions when we
+ *     don't know the exact characteristics of the target machine is rather
+ *     difficult.
+ *
+ * Our solution, then, is to `compile' the value %$p$% at runtime, into a
+ * sequence of simple instructions for doing the reduction.
+ */
+
 /*----- Main code ---------------------------------------------------------*/
 
 /* --- @mpreduce_create@ --- *
 /*----- Main code ---------------------------------------------------------*/
 
 /* --- @mpreduce_create@ --- *
@@ -84,8 +120,64 @@ int mpreduce_create(mpreduce *r, mp *p)
   /* --- Main loop --- *
    *
    * A simple state machine decomposes @p@ conveniently into positive and
   /* --- Main loop --- *
    *
    * A simple state machine decomposes @p@ conveniently into positive and
-   * negative powers of 2.  The pure form of the state machine is left below
-   * for reference (and in case I need inspiration for a NAF exponentiator).
+   * negative powers of 2.
+   *
+   * Here's the relevant theory.  The important observation is that
+   * %$2^i = 2^{i+1} - 2^i$%, and hence
+   *
+   *   * %$\sum_{a\le i<b} 2^i = 2^b - 2^a$%, and
+   *
+   *   * %$2^c - 2^{b+1} + 2^b - 2^a = 2^c - 2^b - 2^a$%.
+   *
+   * The first of these gives us a way of combining a run of several one
+   * bits, and the second gives us a way of handling a single-bit
+   * interruption in such a run.
+   *
+   * We start with a number %$p = \sum_{0\le i<n} p_i 2^i$%, and scan
+   * right-to-left using a simple state-machine keeping (approximate) track
+   * of the two previous bits.  The @Z@ states denote that we're in a string
+   * of zeros; @Z1@ means that we just saw a 1 bit after a sequence of zeros.
+   * Similarly, the @X@ states denote that we're in a string of ones; and
+   * @X0@ means that we just saw a zero bit after a sequence of ones.  The
+   * state machine lets us delay decisions about what to do when we've seen a
+   * change to the status quo (a one after a run of zeros, or vice-versa)
+   * until we've seen the next bit, so we can tell whether this is an
+   * isolated bit or the start of a new sequence.
+   *
+   * More formally: we define two functions %$Z^b_i$% and %$X^b_i$% as
+   * follows.
+   *
+   *   * %$Z^0_i(S, 0) = S$%
+   *   * %$Z^0_i(S, n) = Z^0_{i+1}(S, n)$%
+   *   * %$Z^0_i(S, n + 2^i) = Z^1_{i+1}(S, n)$%
+   *   * %$Z^1_i(S, n) = Z^0_{i+1}(S \cup \{ 2^{i-1} \}, n)$%
+   *   * %$Z^1_i(S, n + 2^i) = X^1_{i+1}(S \cup \{ -2^{i-1} \}, n)$%
+   *   * %$X^0_i(S, n) = Z^0_{i+1}(S, \{ 2^{i-1} \})$%
+   *   * %$X^0_i(S, n + 2^i) = X^1_{i+1}(S \cup \{ -2^{i-1} \}, n)$%
+   *   * %$X^1_i(S, n) = X^0_{i+1}(S, n)$%
+   *   * %$X^1_i(S, n + 2^i) = X^1_{i+1}(S, n)$%
+   *
+   * The reader may verify (by induction on %$n$%) that the following
+   * properties hold.
+   *
+   *   * %$Z^0_0(\emptyset, n)$% is well-defined for all %$n \ge 0$%
+   *   * %$\sum Z^b_i(S, n) = \sum S + n + b 2^{i-1}$%
+   *   * %$\sum X^b_i(S, n) = \sum S + n + (b + 1) 2^{i-1}$%
+   *
+   * From these, of course, we can deduce %$\sum Z^0_0(\emptyset, n) = n$%.
+   *
+   * We apply the above recurrence to build a simple instruction sequence for
+   * adding an appropriate multiple of %$d$% to a given number.  Suppose that
+   * %$2^{w(N-1)} \le 2^{n-1} \le p < 2^n \le 2^{wN}$%.  The machine which
+   * interprets these instructions does so in the context of a
+   * single-precision multiplicand @z@ and a pointer @v@ to the
+   * %%\emph{most}%% significant word of an %$N + 1$%-word integer, and the
+   * instruction sequence should add %$z d$% to this integer.
+   *
+   * The available instructions are named @MPRI_{ADD,SUB}{,LSL}@; they add
+   * (or subtract) %$z$% (shifted left by some amount, in the @LSL@ variants)
+   * to some word earlier than @v@.  The relevant quantities are encoded in
+   * the instruction's immediate operands.
    */
 
 #ifdef DEBUG
    */
 
 #ifdef DEBUG
@@ -118,6 +210,15 @@ int mpreduce_create(mpreduce *r, mp *p)
        INSTR(op | !!b, w, b);
     }
   }
        INSTR(op | !!b, w, b);
     }
   }
+
+  /* --- This doesn't always work --- *
+   *
+   * If %$d \ge 2^{n-1}$% then the above recurrence will output a subtraction
+   * as the final instruction, which may sometimes underflow.  (It interprets
+   * such numbers as being in the form %$2^{n-1} + d$%.)  This is clearly
+   * bad, so detect the situation and fail gracefully.
+   */
+
   if (DA_LEN(&iv) && (DA(&iv)[DA_LEN(&iv) - 1].op & ~1u) == MPRI_SUB) {
     mp_drop(r->p);
     DA_DESTROY(&iv);
   if (DA_LEN(&iv) && (DA(&iv)[DA_LEN(&iv) - 1].op & ~1u) == MPRI_SUB) {
     mp_drop(r->p);
     DA_DESTROY(&iv);
@@ -126,7 +227,30 @@ int mpreduce_create(mpreduce *r, mp *p)
 
 #undef INSTR
 
 
 #undef INSTR
 
-  /* --- Wrap up --- */
+  /* --- Wrap up --- *
+   *
+   * Store the generated instruction sequence in our context structure.  If
+   * %$p$%'s bit length %$n$% is a multiple of the word size %$w$% then
+   * there's nothing much else to do here.  Otherwise, we have an additional
+   * job.
+   *
+   * The reduction operation has three phases.  The first trims entire words
+   * from the argument, and the instruction sequence constructed above does
+   * this well; the second phase reduces an integer which has the same number
+   * of words as %$p$%, but strictly more bits.  (The third phase is a single
+   * conditional subtraction of %$p$%, in case the argument is the same bit
+   * length as %$p$% but greater; this doesn't concern us here.)  To handle
+   * the second phase, we create another copy of the instruction stream, with
+   * all of the target shifts adjusted upwards by %$s = n \bmod w$%.
+   *
+   * In this case, we are acting on an %$(N - 1)$%-word operand, and so
+   * (given the remarks above) must check that this is still valid, but a
+   * moment's reflection shows that this must be fine: the most distant
+   * target must be the bit %$s$% from the top of the least-significant word;
+   * but since we shift all of the targets up by %$s$%, this now addresses
+   * the bottom bit of the next most significant word, and there is no
+   * underflow.
+   */
 
   r->in = DA_LEN(&iv);
   if (!r->in)
 
   r->in = DA_LEN(&iv);
   if (!r->in)
@@ -272,7 +396,7 @@ mp *mpreduce_do(mpreduce *r, mp *d, mp *x)
   if (d) MP_DROP(d);
   MP_DEST(x, MP_LEN(x), x->f);
 
   if (d) MP_DROP(d);
   MP_DEST(x, MP_LEN(x), x->f);
 
-  /* --- Do the reduction --- */
+  /* --- Stage one: trim excess words from the most significant end --- */
 
 #ifdef DEBUG
   _r = MP_NEW;
 
 #ifdef DEBUG
   _r = MP_NEW;
@@ -297,6 +421,9 @@ mp *mpreduce_do(mpreduce *r, mp *d, mp *x)
 #endif
       }
     }
 #endif
       }
     }
+
+    /* --- Stage two: trim excess bits from the most significant word --- */
+
     if (r->s) {
       while (*vl >> r->s) {
        z = *vl >> r->s;
     if (r->s) {
       while (*vl >> r->s) {
        z = *vl >> r->s;
@@ -311,7 +438,7 @@ mp *mpreduce_do(mpreduce *r, mp *d, mp *x)
     }
   }
 
     }
   }
 
-  /* --- Finishing touches --- */
+  /* --- Stage three: conditional subtraction --- */
 
   MP_SHRINK(x);
   if (MP_CMP(x, >=, r->p))
 
   MP_SHRINK(x);
   if (MP_CMP(x, >=, r->p))