math/gfreduce.[ch]: Fix out-of-bounds memory access.
[u/mdw/catacomb] / math / mpreduce.c
index b4543a6..b148dd5 100644 (file)
@@ -81,7 +81,10 @@ DA_DECL(instr_v, mpreduce_instr);
  * Arguments:  @gfreduce *r@ = structure to fill in
  *             @mp *x@ = an integer
  *
- * Returns:    Zero if successful; nonzero on failure.
+ * Returns:    Zero if successful; nonzero on failure.  The current
+ *             algorithm always succeeds when given positive @x@.  Earlier
+ *             versions used to fail on particular kinds of integers, but
+ *             this is guaranteed not to happen any more.
  *
  * Use:                Initializes a context structure for reduction.
  */
@@ -180,21 +183,6 @@ int mpreduce_create(mpreduce *r, mp *p)
    * the instruction's immediate operands.
    */
 
-#ifdef DEBUG
-  for (i = 0, mp_scan(&sc, p); mp_step(&sc); i++) {
-    switch (st | mp_bit(&sc)) {
-      case  Z | 1: st = Z1; break;
-      case Z1 | 0: st =         Z; printf("+ %lu\n", i - 1); break;
-      case Z1 | 1: st =         X; printf("- %lu\n", i - 1); break;
-      case  X | 0: st = X0; break;
-      case X0 | 1: st =         X; printf("- %lu\n", i - 1); break;
-      case X0 | 0: st =         Z; printf("+ %lu\n", i - 1); break;
-    }
-  }
-  if (st >= X) printf("+ %lu\n", i - 1);
-  st = Z;
-#endif
-
   bb = MPW_BITS - (d + 1)%MPW_BITS;
   for (i = 0, mp_scan(&sc, p); i < d && mp_step(&sc); i++) {
     switch (st | mp_bit(&sc)) {
@@ -211,18 +199,27 @@ int mpreduce_create(mpreduce *r, mp *p)
     }
   }
 
-  /* --- This doesn't always work --- *
+  /* --- Fix up wrong-sided decompositions --- *
+   *
+   * At this point, we haven't actually finished up the state machine
+   * properly.  We stopped scanning just after bit %$n - 1$% -- the most
+   * significant one, which we know in advance must be set (since @x@ is
+   * strictly positive).  Therefore we are either in state @X@ or @Z1@.  In
+   * the former case, we have nothing to do.  In the latter, there are two
+   * subcases to deal with.  If there are no other instructions, then @x@ is
+   * a perfect power of two, and %$d = 0$%, so again there is nothing to do.
    *
-   * If %$d \ge 2^{n-1}$% then the above recurrence will output a subtraction
-   * as the final instruction, which may sometimes underflow.  (It interprets
-   * such numbers as being in the form %$2^{n-1} + d$%.)  This is clearly
-   * bad, so detect the situation and fail gracefully.
+   * In the remaining case, we have decomposed @x@ as %$2^{n-1} + d$%, for
+   * some positive %$d%, which is unfortuante: if we're asked to reduce
+   * %$2^n$%, say, we'll end up with %$-d$% (or would do, if we weren't
+   * sticking to unsigned arithmetic for good performance).  So instead, we
+   * rewrite this as %$2^n - 2^{n-1} + d$% and everything will be good.
    */
 
-  if (DA_LEN(&iv) && (DA(&iv)[DA_LEN(&iv) - 1].op & ~1u) == MPRI_SUB) {
-    mp_drop(r->p);
-    DA_DESTROY(&iv);
-    return (-1);
+  if (st == Z1 && DA_LEN(&iv)) {
+    w = 1;
+    b = (bb + d)%MPW_BITS;
+    INSTR(MPRI_ADD | !!b, w, b);
   }
 
 #undef INSTR
@@ -278,9 +275,6 @@ int mpreduce_create(mpreduce *r, mp *p)
   }
   DA_DESTROY(&iv);
 
-#ifdef DEBUG
-  mpreduce_dump(r, stdout);
-#endif
   return (0);
 }
 
@@ -348,15 +342,6 @@ static void run(const mpreduce_instr *i, const mpreduce_instr *il,
                mpw *v, mpw z)
 {
   for (; i < il; i++) {
-#ifdef DEBUG
-    mp vv;
-    mp_build(&vv, v - i->argx, v + 1);
-    printf("  0x"); mp_writefile(&vv, stdout, 16);
-    printf(" %c (0x%lx << %u) == 0x",
-          (i->op & ~1u) == MPRI_ADD ? '+' : '-',
-          (unsigned long)z,
-          i->argy);
-#endif
     switch (i->op) {
       case MPRI_ADD: MPX_UADDN(v - i->argx, v + 1, z); break;
       case MPRI_ADDLSL: mpx_uaddnlsl(v - i->argx, v + 1, z, i->argy); break;
@@ -365,11 +350,6 @@ static void run(const mpreduce_instr *i, const mpreduce_instr *il,
       default:
        abort();
     }
-#ifdef DEBUG
-    mp_build(&vv, v - i->argx, v + 1);
-    mp_writefile(&vv, stdout, 16);
-    printf("\n");
-#endif
   }
 }
 
@@ -379,10 +359,6 @@ mp *mpreduce_do(mpreduce *r, mp *d, mp *x)
   const mpreduce_instr *il;
   mpw z;
 
-#ifdef DEBUG
-  mp *_r = 0, *_rr = 0;
-#endif
-
   /* --- If source is negative, divide --- */
 
   if (MP_NEGP(x)) {
@@ -398,13 +374,6 @@ mp *mpreduce_do(mpreduce *r, mp *d, mp *x)
 
   /* --- Stage one: trim excess words from the most significant end --- */
 
-#ifdef DEBUG
-  _r = MP_NEW;
-  mp_div(0, &_r, x, r->p);
-  MP_PRINTX("x", x);
-  _rr = 0;
-#endif
-
   il = r->iv + r->in;
   if (MP_LEN(x) >= r->lim) {
     v = x->v + r->lim;
@@ -414,11 +383,6 @@ mp *mpreduce_do(mpreduce *r, mp *d, mp *x)
        z = *vl;
        *vl = 0;
        run(r->iv, il, vl, z);
-#ifdef DEBUG
-       MP_PRINTX("x", x);
-       mp_div(0, &_rr, x, r->p);
-       assert(MP_EQ(_r, _rr));
-#endif
       }
     }
 
@@ -429,11 +393,6 @@ mp *mpreduce_do(mpreduce *r, mp *d, mp *x)
        z = *vl >> r->s;
        *vl &= ((1 << r->s) - 1);
        run(r->iv + r->in, il + r->in, vl, z);
-#ifdef DEBUG
-       MP_PRINTX("x", x);
-       mp_div(0, &_rr, x, r->p);
-       assert(MP_EQ(_r, _rr));
-#endif
       }
     }
   }
@@ -446,11 +405,6 @@ mp *mpreduce_do(mpreduce *r, mp *d, mp *x)
 
   /* --- Done --- */
 
-#ifdef DEBUG
-  assert(MP_EQ(_r, x));
-  mp_drop(_r);
-  mp_drop(_rr);
-#endif
   return (x);
 }