math/mpx.c: Eliminate clone-and-hack of shifting primitives.
authorMark Wooding <mdw@distorted.org.uk>
Thu, 27 Mar 2014 03:18:32 +0000 (03:18 +0000)
committerMark Wooding <mdw@distorted.org.uk>
Sat, 29 Mar 2014 10:39:07 +0000 (10:39 +0000)
Replace with some fancy macros.

math/mpx.c

index 5a9a176..5f7ffab 100644 (file)
@@ -375,102 +375,159 @@ void mpx_loadb2cn(mpw *v, mpw *vl, const void *pp, size_t sz)
 
 /*----- Logical shifting --------------------------------------------------*/
 
-/* --- @mpx_lsl@ --- *
+/* --- @MPX_SHIFT1@ --- *
  *
- * Arguments:  @mpw *dv, *dvl@ = destination vector base and limit
- *             @const mpw *av, *avl@ = source vector base and limit
- *             @size_t n@ = number of bit positions to shift by
+ * Arguments:  @init@ = initial accumulator value
+ *             @out@ = expression to store in each output word
+ *             @next@ = expression for next accumulator value
  *
- * Returns:    ---
+ * Use:                Performs a single-position shift.  The input is scanned
+ *             right-to-left.  In the expressions @out@ and @next@, the
+ *             accumulator is available in @w@ and the current input word is
+ *             in @t@.
  *
- * Use:                Performs a logical shift left operation on an integer.
+ *             This macro is intended to be used in the @shift1@ argument of
+ *             @MPX_SHIFTOP@, and expects variables describing the operation
+ *             to be set up accordingly.
  */
 
-void mpx_lsl(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, size_t n)
-{
-  size_t nw;
-  unsigned nb;
-
-  /* --- Trivial special case --- */
-
-  if (n == 0)
-    MPX_COPY(dv, dvl, av, avl);
-
-  /* --- Single bit shifting --- */
-
-  else if (n == 1) {
-    mpw w = 0;
-    while (av < avl) {
-      mpw t;
-      if (dv >= dvl)
-       goto done;
-      t = *av++;
-      *dv++ = MPW((t << 1) | w);
-      w = t >> (MPW_BITS - 1);
-    }
-    if (dv >= dvl)
-      goto done;
-    *dv++ = MPW(w);
-    MPX_ZERO(dv, dvl);
-    goto done;
-  }
-
-  /* --- Break out word and bit shifts for more sophisticated work --- */
-
-  nw = n / MPW_BITS;
-  nb = n % MPW_BITS;
-
-  /* --- Handle a shift by a multiple of the word size --- */
+#define MPX_SHIFT1(init, out, next) do {                               \
+  mpw t, w = (init);                                                   \
+  while (av < avl) {                                                   \
+    if (dv >= dvl) break;                                              \
+    t = MPW(*av++);                                                    \
+    *dv++ = (out);                                                     \
+    w = (next);                                                                \
+  }                                                                    \
+  if (dv < dvl) { *dv++ = MPW(w); MPX_ZERO(dv, dvl); }                 \
+} while (0)
 
-  if (nb == 0) {
-    if (nw >= dvl - dv)
-      MPX_ZERO(dv, dvl);
-    else {
-      MPX_COPY(dv + nw, dvl, av, avl);
-      memset(dv, 0, MPWS(nw));
-    }
-  }
+/* --- @MPX_SHIFTW@ --- *
+ *
+ * Arguments:  @max@ = the maximum shift (in words) which is nontrivial
+ *             @clear@ = function (or macro) to clear low-order output words
+ *             @copy@ = statement to copy words from input to output
+ *
+ * Use:                Performs a shift by a whole number of words.  If the shift
+ *             amount is @max@ or more words, then the destination is
+ *             @clear@ed entirely; otherwise, @copy@ is executed.
+ *
+ *             This macro is intended to be used in the @shiftw@ argument of
+ *             @MPX_SHIFTOP@, and expects variables describing the operation
+ *             to be set up accordingly.
+ */
 
-  /* --- And finally the difficult case --- *
-   *
-   * This is a little convoluted, because I have to start from the end and
-   * work backwards to avoid overwriting the source, if they're both the same
-   * block of memory.
-   */
+#define MPX_SHIFTW(max, clear, copy) do {                              \
+  if (nw >= (max)) clear(dv, dvl);                                     \
+  else copy                                                            \
+} while (0)
 
-  else {
-    mpw w;
-    size_t nr = MPW_BITS - nb;
-    size_t dvn = dvl - dv;
-    size_t avn = avl - av;
+/* --- @MPX_SHIFTOP@ --- *
+ *
+ * Arguments:  @name@ = name of function to define (without `@mpx_@' prefix)
+ *             @shift1@ = statement to shift by a single bit
+ *             @shiftw@ = statement to shift by a whole number of words
+ *             @shift@ = statement to perform a general shift
+ *
+ * Use:                Emits a shift operation.  The input is @av@..@avl@; the
+ *             output is @dv@..@dvl@; and the shift amount (in bits) is
+ *             @n@.  In @shiftw@ and @shift@, @nw@ and @nb@ are set up such
+ *             that @n = nw*MPW_BITS + nb@ and @nb < MPW_BITS@.
+ */
 
-    if (dvn <= nw) {
-      MPX_ZERO(dv, dvl);
-      goto done;
-    }
+#define MPX_SHIFTOP(name, shift1, shiftw, shift)                       \
+                                                                       \
+void mpx_##name(mpw *dv, mpw *dvl,                                     \
+               const mpw *av, const mpw *avl,                          \
+               size_t n)                                               \
+{                                                                      \
+                                                                       \
+  if (n == 0)                                                          \
+    MPX_COPY(dv, dvl, av, avl);                                                \
+  else if (n == 1)                                                     \
+    do shift1 while (0);                                               \
+  else {                                                               \
+    size_t nw = n/MPW_BITS;                                            \
+    unsigned nb = n%MPW_BITS;                                          \
+    if (!nb) do shiftw while (0);                                      \
+    else do shift while (0);                                           \
+  }                                                                    \
+}
 
-    if (dvn > avn + nw) {
-      size_t off = avn + nw + 1;
-      MPX_ZERO(dv + off, dvl);
-      dvl = dv + off;
-      w = 0;
-    } else {
-      avl = av + dvn - nw;
-      w = *--avl << nb;
-    }
+/* --- @MPX_SHIFT_LEFT@ --- *
+ *
+ * Arguments:  @name@ = name of function to define (without `@mpx_@' prefix)
+ *             @init1@ = initializer for single-bit shift accumulator
+ *             @clear@ = function (or macro) to clear low-order output words
+ *             @flush@ = expression for low-order nontrivial output word
+ *
+ * Use:                Emits a left-shift operation.  This expands to a call on
+ *             @MPX_SHIFTOP@, but implements the complicated @shift@
+ *             statement.
+ *
+ *             The @init1@ argument is as for @MPX_SHIFT1@, and @clear@ is
+ *             as for @MPX_SHIFTW@ (though is used elsewhere).  In a general
+ *             shift, @nw@ whole low-order output words are set using
+ *             @clear@; high-order words are zeroed; and the remaining words
+ *             set with a left-to-right pass across the input; at the end of
+ *             the operation, the least significant output word above those
+ *             @clear@ed is set using @flush@, which may use the accumulator
+ *             @w@ = @av[0] << nb@.
+ */
 
-    while (avl > av) {
-      mpw t = *--avl;
-      *--dvl = MPW((t >> nr) | w);
-      w = t << nb;
-    }
+#define MPX_SHIFT_LEFT(name, init1, clear, flush)                      \
+MPX_SHIFTOP(name, {                                                    \
+  MPX_SHIFT1(init1,                                                    \
+            w | (t << 1),                                              \
+            t >> (MPW_BITS - 1));                                      \
+}, {                                                                   \
+  MPX_SHIFTW(dvl - dv, clear, {                                                \
+    MPX_COPY(dv + nw, dvl, av, avl);                                   \
+    clear(dv, dv + nw);                                                        \
+  });                                                                  \
+}, {                                                                   \
+  size_t nr = MPW_BITS - nb;                                           \
+  size_t dvn = dvl - dv;                                               \
+  size_t avn = avl - av;                                               \
+  mpw w;                                                               \
+                                                                       \
+  if (dvn <= nw) {                                                     \
+    clear(dv, dvl);                                                    \
+    break;                                                             \
+  }                                                                    \
+                                                                       \
+  if (dvn <= avn + nw) {                                               \
+    avl = av + dvn - nw;                                               \
+    w = *--avl << nb;                                                  \
+  } else {                                                             \
+    size_t off = avn + nw + 1;                                         \
+    MPX_ZERO(dv + off, dvl);                                           \
+    dvl = dv + off;                                                    \
+    w = 0;                                                             \
+  }                                                                    \
+                                                                       \
+  while (avl > av) {                                                   \
+    mpw t = *--avl;                                                    \
+    *--dvl = MPW(w | (t >> nr));                                       \
+    w = t << nb;                                                       \
+  }                                                                    \
+                                                                       \
+  *--dvl = MPW(flush);                                                 \
+  clear(dv, dvl);                                                      \
+})
 
-    *--dvl = MPW(w);
-    MPX_ZERO(dv, dvl);
-  }
+/* --- @mpx_lsl@ --- *
+ *
+ * Arguments:  @mpw *dv, *dvl@ = destination vector base and limit
+ *             @const mpw *av, *avl@ = source vector base and limit
+ *             @size_t n@ = number of bit positions to shift by
+ *
+ * Returns:    ---
+ *
+ * Use:                Performs a logical shift left operation on an integer.
+ */
 
-done:;
-}
+MPX_SHIFT_LEFT(lsl, 0, MPX_ZERO, w)
 
 /* --- @mpx_lslc@ --- *
  *
@@ -484,91 +541,7 @@ done:;
  *             it fills in the bits with ones instead of zeroes.
  */
 
-void mpx_lslc(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, size_t n)
-{
-  size_t nw;
-  unsigned nb;
-
-  /* --- Trivial special case --- */
-
-  if (n == 0)
-    MPX_COPY(dv, dvl, av, avl);
-
-  /* --- Single bit shifting --- */
-
-  else if (n == 1) {
-    mpw w = 1;
-    while (av < avl) {
-      mpw t;
-      if (dv >= dvl)
-       goto done;
-      t = *av++;
-      *dv++ = MPW((t << 1) | w);
-      w = t >> (MPW_BITS - 1);
-    }
-    if (dv >= dvl)
-      goto done;
-    *dv++ = MPW(w);
-    MPX_ZERO(dv, dvl);
-    goto done;
-  }
-
-  /* --- Break out word and bit shifts for more sophisticated work --- */
-
-  nw = n / MPW_BITS;
-  nb = n % MPW_BITS;
-
-  /* --- Handle a shift by a multiple of the word size --- */
-
-  if (nb == 0) {
-    if (nw >= dvl - dv)
-      MPX_ONE(dv, dvl);
-    else {
-      MPX_COPY(dv + nw, dvl, av, avl);
-      MPX_ONE(dv, dv + nw);
-    }
-  }
-
-  /* --- And finally the difficult case --- *
-   *
-   * This is a little convoluted, because I have to start from the end and
-   * work backwards to avoid overwriting the source, if they're both the same
-   * block of memory.
-   */
-
-  else {
-    mpw w;
-    size_t nr = MPW_BITS - nb;
-    size_t dvn = dvl - dv;
-    size_t avn = avl - av;
-
-    if (dvn <= nw) {
-      MPX_ONE(dv, dvl);
-      goto done;
-    }
-
-    if (dvn > avn + nw) {
-      size_t off = avn + nw + 1;
-      MPX_ZERO(dv + off, dvl);
-      dvl = dv + off;
-      w = 0;
-    } else {
-      avl = av + dvn - nw;
-      w = *--avl << nb;
-    }
-
-    while (avl > av) {
-      mpw t = *--avl;
-      *--dvl = MPW((t >> nr) | w);
-      w = t << nb;
-    }
-
-    *--dvl = MPW((MPW_MAX >> nr) | w);
-    MPX_ONE(dv, dvl);
-  }
-
-done:;
-}
+MPX_SHIFT_LEFT(lslc, 1, MPX_ONE, w | (MPW_MAX >> nr))
 
 /* --- @mpx_lsr@ --- *
  *
@@ -581,73 +554,32 @@ done:;
  * Use:                Performs a logical shift right operation on an integer.
  */
 
-void mpx_lsr(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, size_t n)
-{
-  size_t nw;
-  unsigned nb;
-
-  /* --- Trivial special case --- */
-
-  if (n == 0)
-    MPX_COPY(dv, dvl, av, avl);
-
-  /* --- Single bit shifting --- */
-
-  else if (n == 1) {
-    mpw w = av < avl ? *av++ >> 1 : 0;
-    while (av < avl) {
-      mpw t;
-      if (dv >= dvl)
-       goto done;
-      t = *av++;
-      *dv++ = MPW((t << (MPW_BITS - 1)) | w);
-      w = t >> 1;
-    }
-    if (dv >= dvl)
-      goto done;
-    *dv++ = MPW(w);
-    MPX_ZERO(dv, dvl);
-    goto done;
-  }
-
-  /* --- Break out word and bit shifts for more sophisticated work --- */
-
-  nw = n / MPW_BITS;
-  nb = n % MPW_BITS;
-
-  /* --- Handle a shift by a multiple of the word size --- */
-
-  if (nb == 0) {
-    if (nw >= avl - av)
-      MPX_ZERO(dv, dvl);
-    else
-      MPX_COPY(dv, dvl, av + nw, avl);
+MPX_SHIFTOP(lsr, {
+  MPX_SHIFT1(av < avl ? *av++ >> 1 : 0,
+            w | (t << (MPW_BITS - 1)),
+            t >> 1);
+}, {
+  MPX_SHIFTW(avl - av, MPX_ZERO,
+            { MPX_COPY(dv, dvl, av + nw, avl); });
+}, {
+  size_t nr = MPW_BITS - nb;
+  mpw w;
+
+  av += nw;
+  w = av < avl ? *av++ : 0;
+  while (av < avl) {
+    mpw t;
+    if (dv >= dvl) goto done;
+    t = *av++;
+    *dv++ = MPW((w >> nb) | (t << nr));
+    w = t;
   }
-
-  /* --- And finally the difficult case --- */
-
-  else {
-    mpw w;
-    size_t nr = MPW_BITS - nb;
-
-    av += nw;
-    w = av < avl ? *av++ : 0;
-    while (av < avl) {
-      mpw t;
-      if (dv >= dvl)
-       goto done;
-      t = *av++;
-      *dv++ = MPW((w >> nb) | (t << nr));
-      w = t;
-    }
-    if (dv < dvl) {
-      *dv++ = MPW(w >> nb);
-      MPX_ZERO(dv, dvl);
-    }
+  if (dv < dvl) {
+    *dv++ = MPW(w >> nb);
+    MPX_ZERO(dv, dvl);
   }
-
 done:;
-}
+})
 
 /*----- Bitwise operations ------------------------------------------------*/