X-Git-Url: https://git.distorted.org.uk/~mdw/catacomb/blobdiff_plain/23bbea75793621e6b21fbb13c00d8223113cf7b5..ea1b3cec199052eda3a0054d86c70e948c6e7580:/math/mpx.c diff --git a/math/mpx.c b/math/mpx.c index 5a9a1760..d3d0a04a 100644 --- a/math/mpx.c +++ b/math/mpx.c @@ -27,6 +27,8 @@ /*----- Header files ------------------------------------------------------*/ +#include "config.h" + #include #include #include @@ -35,12 +37,173 @@ #include #include +#include "dispatch.h" #include "mptypes.h" #include "mpx.h" #include "bitops.h" /*----- Loading and storing -----------------------------------------------*/ +/* --- These are all variations on a theme --- * + * + * Essentially we want to feed bits into a shift register, @ibits@ bits at a + * time, and extract them @obits@ bits at a time whenever there are enough. + * Of course, @ibits@ and @obits@ will, in general, differ, and we don't + * necessarily know which is larger. + * + * During an operation, we have a shift register @w@ and a most-recent input + * @t@. Together, these hold @bits@ significant bits of input. We arrange + * that @bits < ibits + obits <= 2*MPW_BITS@, so we can get away with using + * an @mpw@ for both of these quantities. + */ + +/* --- @MPX_GETBITS@ --- * + * + * Arguments: @ibits@ = width of input units, in bits + * @obits@ = width of output units, in bits + * @iavail@ = condition expression: is input data available? + * @getbits@ = function or macro: set argument to next input + * + * Use: Read an input unit into @t@ and update the necessary + * variables. + * + * It is assumed on entry that @bits < obits@. On exit, we have + * @bits < ibits + obits@, and @t@ is live.
+ */ + +#define MPX_GETBITS(ibits, obits, iavail, getbits) do { \ + if (!iavail) goto flush; \ + if (bits >= ibits) w |= t << (bits - ibits); \ + getbits(t); \ + bits += ibits; \ +} while (0) + +/* --- @MPX_PUTBITS@ --- * + * + * Arguments: @ibits@ = width of input units, in bits + * @obits@ = width of output units, in bits + * @oavail@ = condition expression: is output space available? + * @putbits@ = function or macro: write its argument to output + * + * Use: Emit an output unit, and update the necessary variables. If + * the output buffer is full, then force an immediate return. + * + * We assume that @bits < ibits + obits@, and that @t@ is only + * relevant if @bits >= ibits@. (The @MPX_GETBITS@ macro + * ensures that this is true.) + */ + +#define SHRW(w, b) ((b) < MPW_BITS ? (w) >> (b) : 0) + +#define MPX_PUTBITS(ibits, obits, oavail, putbits) do { \ + if (!oavail) return; \ + if (bits < ibits) { \ + putbits(w); \ + bits -= obits; \ + w = SHRW(w, obits); \ + } else { \ + putbits(w | (t << (bits - ibits))); \ + bits -= obits; \ + if (bits >= ibits) w = SHRW(w, obits) | (t << (bits - ibits)); \ + else w = SHRW(w, obits) | (t >> (ibits - bits)); \ + t = 0; \ + } \ +} while (0) + +/* --- @MPX_LOADSTORE@ --- * + * + * Arguments: @name@ = name of function to create, without @mpx_@ prefix + * @wconst@ = qualifiers for @mpw *@ arguments + * @oconst@ = qualifiers for octet pointers + * @decls@ = additional declarations needed + * @ibits@ = width of input units, in bits + * @iavail@ = condition expression: is input data available? + * @getbits@ = function or macro: set argument to next input + * @obits@ = width of output units, in bits + * @oavail@ = condition expression: is output space available? 
+ * @putbits@ = function or macro: write its argument to output + * @fixfinal@ = statements to fix shift register at the end + * @clear@ = statements to clear remainder of output + * + * Use: Generates a function to convert between a sequence of + * multiprecision words and a vector of octets. + * + * The arguments @ibits@, @iavail@ and @getbits@ are passed on + * to @MPX_GETBITS@; similarly, @obits@, @oavail@, and @putbits@ + * are passed on to @MPX_PUTBITS@. + * + * The following variables are in scope: @v@ and @vl@ are the + * current base and limit of the word vector; @p@ and @q@ are + * the base and limit of the octet vector; @w@ and @t@ form the + * shift register used during the conversion (see commentary + * above); and @bits@ tracks the number of live bits in the + * shift register. + */ + +#define MPX_LOADSTORE(name, wconst, oconst, decls, \ + ibits, iavail, getbits, obits, oavail, putbits, \ + fixfinal, clear) \ + \ +void mpx_##name(wconst mpw *v, wconst mpw *vl, \ + oconst void *pp, size_t sz) \ +{ \ + mpw t = 0, w = 0; \ + oconst octet *p = pp, *q = p + sz; \ + int bits = 0; \ + decls \ + \ + for (;;) { \ + while (bits < obits) MPX_GETBITS(ibits, obits, iavail, getbits); \ + while (bits >= obits) MPX_PUTBITS(ibits, obits, oavail, putbits); \ + } \ + \ +flush: \ + if (bits) { \ + fixfinal; \ + while (bits > 0) MPX_PUTBITS(ibits, obits, oavail, putbits); \ + } \ + clear; \ +} + +#define EMPTY + +/* --- Macros for @getbits@ and @putbits@ --- */ + +#define GETMPW(t) do { t = *v++; } while (0) +#define PUTMPW(x) do { *v++ = MPW(x); } while (0) + +#define GETOCTETI(t) do { t = *p++; } while (0) +#define PUTOCTETD(x) do { *--q = U8(x); } while (0) + +#define PUTOCTETI(x) do { *p++ = U8(x); } while (0) +#define GETOCTETD(t) do { t = *--q; } while (0) + +/* --- Machinery for two's complement I/O --- */ + +#define DECL_2CN \ + unsigned c = 1; + +#define GETMPW_2CN(t) do { \ + t = MPW(~*v++ + c); \ + c = c && !t; \ +} while (0) + +#define PUTMPW_2CN(t) do { \ +
mpw _t = MPW(~(t) + c); \ + c = c && !_t; \ + *v++ = _t; \ +} while (0) + +#define FIXFINALW_2CN do { \ + if (c && !w && !t); \ + else if (bits == 8) t ^= ~(mpw)0xffu; \ + else t ^= ((mpw)1 << (MPW_BITS - bits + 8)) - 256u; \ +} while (0) + +#define FLUSHO_2CN do { \ + memset(p, c ? 0 : 0xff, q - p); \ +} while (0) + /* --- @mpx_storel@ --- * * * Arguments: @const mpw *v, *vl@ = base and limit of source vector @@ -54,30 +217,10 @@ * isn't enough space for them. */ -void mpx_storel(const mpw *v, const mpw *vl, void *pp, size_t sz) -{ - mpw n, w = 0; - octet *p = pp, *q = p + sz; - unsigned bits = 0; - - while (p < q) { - if (bits < 8) { - if (v >= vl) { - *p++ = U8(w); - break; - } - n = *v++; - *p++ = U8(w | n << bits); - w = n >> (8 - bits); - bits += MPW_BITS - 8; - } else { - *p++ = U8(w); - w >>= 8; - bits -= 8; - } - } - memset(p, 0, q - p); -} +MPX_LOADSTORE(storel, const, EMPTY, EMPTY, + MPW_BITS, (v < vl), GETMPW, + 8, (p < q), PUTOCTETI, + EMPTY, { memset(p, 0, q - p); }) /* --- @mpx_loadl@ --- * * @@ -92,30 +235,11 @@ void mpx_storel(const mpw *v, const mpw *vl, void *pp, size_t sz) * space for them. */ -void mpx_loadl(mpw *v, mpw *vl, const void *pp, size_t sz) -{ - unsigned n; - mpw w = 0; - const octet *p = pp, *q = p + sz; - unsigned bits = 0; +MPX_LOADSTORE(loadl, EMPTY, const, EMPTY, + 8, (p < q), GETOCTETI, + MPW_BITS, (v < vl), PUTMPW, + EMPTY, { MPX_ZERO(v, vl); }) - if (v >= vl) - return; - while (p < q) { - n = U8(*p++); - w |= n << bits; - bits += 8; - if (bits >= MPW_BITS) { - *v++ = MPW(w); - w = n >> (MPW_BITS - bits + 8); - bits -= MPW_BITS; - if (v >= vl) - return; - } - } - *v++ = w; - MPX_ZERO(v, vl); -} /* --- @mpx_storeb@ --- * * @@ -130,30 +254,10 @@ void mpx_loadl(mpw *v, mpw *vl, const void *pp, size_t sz) * isn't enough space for them. 
*/ -void mpx_storeb(const mpw *v, const mpw *vl, void *pp, size_t sz) -{ - mpw n, w = 0; - octet *p = pp, *q = p + sz; - unsigned bits = 0; - - while (q > p) { - if (bits < 8) { - if (v >= vl) { - *--q = U8(w); - break; - } - n = *v++; - *--q = U8(w | n << bits); - w = n >> (8 - bits); - bits += MPW_BITS - 8; - } else { - *--q = U8(w); - w >>= 8; - bits -= 8; - } - } - memset(p, 0, q - p); -} +MPX_LOADSTORE(storeb, const, EMPTY, EMPTY, + MPW_BITS, (v < vl), GETMPW, + 8, (p < q), PUTOCTETD, + EMPTY, { memset(p, 0, q - p); }) /* --- @mpx_loadb@ --- * * @@ -168,30 +272,10 @@ void mpx_storeb(const mpw *v, const mpw *vl, void *pp, size_t sz) * space for them. */ -void mpx_loadb(mpw *v, mpw *vl, const void *pp, size_t sz) -{ - unsigned n; - mpw w = 0; - const octet *p = pp, *q = p + sz; - unsigned bits = 0; - - if (v >= vl) - return; - while (q > p) { - n = U8(*--q); - w |= n << bits; - bits += 8; - if (bits >= MPW_BITS) { - *v++ = MPW(w); - w = n >> (MPW_BITS - bits + 8); - bits -= MPW_BITS; - if (v >= vl) - return; - } - } - *v++ = w; - MPX_ZERO(v, vl); -} +MPX_LOADSTORE(loadb, EMPTY, const, EMPTY, + 8, (p < q), GETOCTETD, + MPW_BITS, (v < vl), PUTMPW, + EMPTY, { MPX_ZERO(v, vl); }) /* --- @mpx_storel2cn@ --- * * @@ -207,40 +291,10 @@ void mpx_loadb(mpw *v, mpw *vl, const void *pp, size_t sz) * This obviously makes the output bad. 
*/ -void mpx_storel2cn(const mpw *v, const mpw *vl, void *pp, size_t sz) -{ - unsigned c = 1; - unsigned b = 0; - mpw n, w = 0; - octet *p = pp, *q = p + sz; - unsigned bits = 0; - - while (p < q) { - if (bits < 8) { - if (v >= vl) { - b = w; - break; - } - n = *v++; - b = w | n << bits; - w = n >> (8 - bits); - bits += MPW_BITS - 8; - } else { - b = w; - w >>= 8; - bits -= 8; - } - b = U8(~b + c); - c = c && !b; - *p++ = b; - } - while (p < q) { - b = U8(~b + c); - c = c && !b; - *p++ = b; - b = 0; - } -} +MPX_LOADSTORE(storel2cn, const, EMPTY, DECL_2CN, + MPW_BITS, (v < vl), GETMPW_2CN, + 8, (p < q), PUTOCTETI, + EMPTY, { FLUSHO_2CN; }) /* --- @mpx_loadl2cn@ --- * * @@ -256,32 +310,10 @@ void mpx_storel2cn(const mpw *v, const mpw *vl, void *pp, size_t sz) * means you made the wrong choice coming here. */ -void mpx_loadl2cn(mpw *v, mpw *vl, const void *pp, size_t sz) -{ - unsigned n; - unsigned c = 1; - mpw w = 0; - const octet *p = pp, *q = p + sz; - unsigned bits = 0; - - if (v >= vl) - return; - while (p < q) { - n = U8(~(*p++) + c); - c = c && !n; - w |= n << bits; - bits += 8; - if (bits >= MPW_BITS) { - *v++ = MPW(w); - w = n >> (MPW_BITS - bits + 8); - bits -= MPW_BITS; - if (v >= vl) - return; - } - } - *v++ = w; - MPX_ZERO(v, vl); -} +MPX_LOADSTORE(loadl2cn, EMPTY, const, DECL_2CN, + 8, (p < q), GETOCTETI, + MPW_BITS, (v < vl), PUTMPW_2CN, + { FIXFINALW_2CN; }, { MPX_ZERO(v, vl); }) /* --- @mpx_storeb2cn@ --- * * @@ -297,40 +329,10 @@ void mpx_loadl2cn(mpw *v, mpw *vl, const void *pp, size_t sz) * which probably isn't what you meant. 
*/ -void mpx_storeb2cn(const mpw *v, const mpw *vl, void *pp, size_t sz) -{ - mpw n, w = 0; - unsigned b = 0; - unsigned c = 1; - octet *p = pp, *q = p + sz; - unsigned bits = 0; - - while (q > p) { - if (bits < 8) { - if (v >= vl) { - b = w; - break; - } - n = *v++; - b = w | n << bits; - w = n >> (8 - bits); - bits += MPW_BITS - 8; - } else { - b = w; - w >>= 8; - bits -= 8; - } - b = U8(~b + c); - c = c && !b; - *--q = b; - } - while (q > p) { - b = ~b + c; - c = c && !(b & 0xff); - *--q = b; - b = 0; - } -} +MPX_LOADSTORE(storeb2cn, const, EMPTY, DECL_2CN, + MPW_BITS, (v < vl), GETMPW_2CN, + 8, (p < q), PUTOCTETD, + EMPTY, { FLUSHO_2CN; }) /* --- @mpx_loadb2cn@ --- * * @@ -346,131 +348,166 @@ void mpx_storeb2cn(const mpw *v, const mpw *vl, void *pp, size_t sz) * chose this function wrongly. */ -void mpx_loadb2cn(mpw *v, mpw *vl, const void *pp, size_t sz) -{ - unsigned n; - unsigned c = 1; - mpw w = 0; - const octet *p = pp, *q = p + sz; - unsigned bits = 0; - - if (v >= vl) - return; - while (q > p) { - n = U8(~(*--q) + c); - c = c && !n; - w |= n << bits; - bits += 8; - if (bits >= MPW_BITS) { - *v++ = MPW(w); - w = n >> (MPW_BITS - bits + 8); - bits -= MPW_BITS; - if (v >= vl) - return; - } - } - *v++ = w; - MPX_ZERO(v, vl); -} +MPX_LOADSTORE(loadb2cn, EMPTY, const, DECL_2CN, + 8, (p < q), GETOCTETD, + MPW_BITS, (v < vl), PUTMPW_2CN, + { FIXFINALW_2CN; }, { MPX_ZERO(v, vl); }) /*----- Logical shifting --------------------------------------------------*/ -/* --- @mpx_lsl@ --- * +/* --- @MPX_SHIFT1@ --- * * - * Arguments: @mpw *dv, *dvl@ = destination vector base and limit - * @const mpw *av, *avl@ = source vector base and limit - * @size_t n@ = number of bit positions to shift by + * Arguments: @init@ = initial accumulator value + * @out@ = expression to store in each output word + * @next@ = expression for next accumulator value * - * Returns: --- + * Use: Performs a single-position shift. The input is scanned + * right-to-left. 
In the expressions @out@ and @next@, the + * accumulator is available in @w@ and the current input word is + * in @t@. * - * Use: Performs a logical shift left operation on an integer. + * This macro is intended to be used in the @shift1@ argument of + * @MPX_SHIFTOP@, and expects variables describing the operation + * to be set up accordingly. */ -void mpx_lsl(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, size_t n) -{ - size_t nw; - unsigned nb; - - /* --- Trivial special case --- */ - - if (n == 0) - MPX_COPY(dv, dvl, av, avl); - - /* --- Single bit shifting --- */ - - else if (n == 1) { - mpw w = 0; - while (av < avl) { - mpw t; - if (dv >= dvl) - goto done; - t = *av++; - *dv++ = MPW((t << 1) | w); - w = t >> (MPW_BITS - 1); - } - if (dv >= dvl) - goto done; - *dv++ = MPW(w); - MPX_ZERO(dv, dvl); - goto done; - } - - /* --- Break out word and bit shifts for more sophisticated work --- */ - - nw = n / MPW_BITS; - nb = n % MPW_BITS; - - /* --- Handle a shift by a multiple of the word size --- */ +#define MPX_SHIFT1(init, out, next) do { \ + mpw t, w = (init); \ + while (av < avl) { \ + if (dv >= dvl) break; \ + t = MPW(*av++); \ + *dv++ = (out); \ + w = (next); \ + } \ + if (dv < dvl) { *dv++ = MPW(w); MPX_ZERO(dv, dvl); } \ +} while (0) - if (nb == 0) { - if (nw >= dvl - dv) - MPX_ZERO(dv, dvl); - else { - MPX_COPY(dv + nw, dvl, av, avl); - memset(dv, 0, MPWS(nw)); - } - } +/* --- @MPX_SHIFTW@ --- * + * + * Arguments: @max@ = the maximum shift (in words) which is nontrivial + * @clear@ = function (or macro) to clear low-order output words + * @copy@ = statement to copy words from input to output + * + * Use: Performs a shift by a whole number of words. If the shift + * amount is @max@ or more words, then the destination is + * @clear@ed entirely; otherwise, @copy@ is executed. + * + * This macro is intended to be used in the @shiftw@ argument of + * @MPX_SHIFTOP@, and expects variables describing the operation + * to be set up accordingly. 
+ */ - /* --- And finally the difficult case --- * - * - * This is a little convoluted, because I have to start from the end and - * work backwards to avoid overwriting the source, if they're both the same - * block of memory. - */ +#define MPX_SHIFTW(max, clear, copy) do { \ + if (nw >= (max)) clear(dv, dvl); \ + else copy \ +} while (0) - else { - mpw w; - size_t nr = MPW_BITS - nb; - size_t dvn = dvl - dv; - size_t avn = avl - av; +/* --- @MPX_SHIFTOP@ --- * + * + * Arguments: @name@ = name of function to define (without `@mpx_@' prefix) + * @shift1@ = statement to shift by a single bit + * @shiftw@ = statement to shift by a whole number of words + * @shift@ = statement to perform a general shift + * + * Use: Emits a shift operation. The input is @av@..@avl@; the + * output is @dv@..@dvl@; and the shift amount (in bits) is + * @n@. In @shiftw@ and @shift@, @nw@ and @nb@ are set up such + * that @n = nw*MPW_BITS + nb@ and @nb < MPW_BITS@. + */ - if (dvn <= nw) { - MPX_ZERO(dv, dvl); - goto done; - } +#define MPX_SHIFTOP(name, shift1, shiftw, shift) \ + \ +void mpx_##name(mpw *dv, mpw *dvl, \ + const mpw *av, const mpw *avl, \ + size_t n) \ +{ \ + \ + if (n == 0) \ + MPX_COPY(dv, dvl, av, avl); \ + else if (n == 1) \ + do shift1 while (0); \ + else { \ + size_t nw = n/MPW_BITS; \ + unsigned nb = n%MPW_BITS; \ + if (!nb) do shiftw while (0); \ + else do shift while (0); \ + } \ +} - if (dvn > avn + nw) { - size_t off = avn + nw + 1; - MPX_ZERO(dv + off, dvl); - dvl = dv + off; - w = 0; - } else { - avl = av + dvn - nw; - w = *--avl << nb; - } +/* --- @MPX_SHIFT_LEFT@ --- * + * + * Arguments: @name@ = name of function to define (without `@mpx_@' prefix) + * @init1@ = initializer for single-bit shift accumulator + * @clear@ = function (or macro) to clear low-order output words + * @flush@ = expression for low-order nontrivial output word + * + * Use: Emits a left-shift operation. 
This expands to a call on + * @MPX_SHIFTOP@, but implements the complicated @shift@ + * statement. + * + * The @init1@ argument is as for @MPX_SHIFT1@, and @clear@ is + * as for @MPX_SHIFTW@ (though is used elsewhere). In a general + * shift, @nw@ whole low-order output words are set using + * @clear@; high-order words are zeroed; and the remaining words + * set with a left-to-right pass across the input; at the end of + * the operation, the least significant output word above those + * @clear@ed is set using @flush@, which may use the accumulator + * @w@ = @av[0] << nb@. + */ - while (avl > av) { - mpw t = *--avl; - *--dvl = MPW((t >> nr) | w); - w = t << nb; - } +#define MPX_SHIFT_LEFT(name, init1, clear, flush) \ +MPX_SHIFTOP(name, { \ + MPX_SHIFT1(init1, \ + w | (t << 1), \ + t >> (MPW_BITS - 1)); \ +}, { \ + MPX_SHIFTW(dvl - dv, clear, { \ + MPX_COPY(dv + nw, dvl, av, avl); \ + clear(dv, dv + nw); \ + }); \ +}, { \ + size_t nr = MPW_BITS - nb; \ + size_t dvn = dvl - dv; \ + size_t avn = avl - av; \ + mpw w; \ + \ + if (dvn <= nw) { \ + clear(dv, dvl); \ + break; \ + } \ + \ + if (dvn <= avn + nw) { \ + avl = av + dvn - nw; \ + w = *--avl << nb; \ + } else { \ + size_t off = avn + nw + 1; \ + MPX_ZERO(dv + off, dvl); \ + dvl = dv + off; \ + w = 0; \ + } \ + \ + while (avl > av) { \ + mpw t = *--avl; \ + *--dvl = MPW(w | (t >> nr)); \ + w = t << nb; \ + } \ + \ + *--dvl = MPW(flush); \ + clear(dv, dvl); \ +}) - *--dvl = MPW(w); - MPX_ZERO(dv, dvl); - } +/* --- @mpx_lsl@ --- * + * + * Arguments: @mpw *dv, *dvl@ = destination vector base and limit + * @const mpw *av, *avl@ = source vector base and limit + * @size_t n@ = number of bit positions to shift by + * + * Returns: --- + * + * Use: Performs a logical shift left operation on an integer. + */ -done:; -} +MPX_SHIFT_LEFT(lsl, 0, MPX_ZERO, w) /* --- @mpx_lslc@ --- * * @@ -484,91 +521,7 @@ done:; * it fills in the bits with ones instead of zeroes. 
*/ -void mpx_lslc(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, size_t n) -{ - size_t nw; - unsigned nb; - - /* --- Trivial special case --- */ - - if (n == 0) - MPX_COPY(dv, dvl, av, avl); - - /* --- Single bit shifting --- */ - - else if (n == 1) { - mpw w = 1; - while (av < avl) { - mpw t; - if (dv >= dvl) - goto done; - t = *av++; - *dv++ = MPW((t << 1) | w); - w = t >> (MPW_BITS - 1); - } - if (dv >= dvl) - goto done; - *dv++ = MPW(w); - MPX_ZERO(dv, dvl); - goto done; - } - - /* --- Break out word and bit shifts for more sophisticated work --- */ - - nw = n / MPW_BITS; - nb = n % MPW_BITS; - - /* --- Handle a shift by a multiple of the word size --- */ - - if (nb == 0) { - if (nw >= dvl - dv) - MPX_ONE(dv, dvl); - else { - MPX_COPY(dv + nw, dvl, av, avl); - MPX_ONE(dv, dv + nw); - } - } - - /* --- And finally the difficult case --- * - * - * This is a little convoluted, because I have to start from the end and - * work backwards to avoid overwriting the source, if they're both the same - * block of memory. - */ - - else { - mpw w; - size_t nr = MPW_BITS - nb; - size_t dvn = dvl - dv; - size_t avn = avl - av; - - if (dvn <= nw) { - MPX_ONE(dv, dvl); - goto done; - } - - if (dvn > avn + nw) { - size_t off = avn + nw + 1; - MPX_ZERO(dv + off, dvl); - dvl = dv + off; - w = 0; - } else { - avl = av + dvn - nw; - w = *--avl << nb; - } - - while (avl > av) { - mpw t = *--avl; - *--dvl = MPW((t >> nr) | w); - w = t << nb; - } - - *--dvl = MPW((MPW_MAX >> nr) | w); - MPX_ONE(dv, dvl); - } - -done:; -} +MPX_SHIFT_LEFT(lslc, 1, MPX_ONE, w | (MPW_MAX >> nr)) /* --- @mpx_lsr@ --- * * @@ -581,73 +534,38 @@ done:; * Use: Performs a logical shift right operation on an integer. */ -void mpx_lsr(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, size_t n) -{ - size_t nw; - unsigned nb; - - /* --- Trivial special case --- */ - - if (n == 0) - MPX_COPY(dv, dvl, av, avl); - - /* --- Single bit shifting --- */ - - else if (n == 1) { - mpw w = av < avl ? 
*av++ >> 1 : 0; - while (av < avl) { - mpw t; - if (dv >= dvl) - goto done; - t = *av++; - *dv++ = MPW((t << (MPW_BITS - 1)) | w); - w = t >> 1; - } - if (dv >= dvl) - goto done; - *dv++ = MPW(w); - MPX_ZERO(dv, dvl); - goto done; - } - - /* --- Break out word and bit shifts for more sophisticated work --- */ - - nw = n / MPW_BITS; - nb = n % MPW_BITS; - - /* --- Handle a shift by a multiple of the word size --- */ - - if (nb == 0) { - if (nw >= avl - av) - MPX_ZERO(dv, dvl); - else - MPX_COPY(dv, dvl, av + nw, avl); - } - - /* --- And finally the difficult case --- */ - +MPX_SHIFTOP(lsr, { + MPX_SHIFT1(av < avl ? *av++ >> 1 : 0, + w | (t << (MPW_BITS - 1)), + t >> 1); +}, { + MPX_SHIFTW(avl - av, MPX_ZERO, + { MPX_COPY(dv, dvl, av + nw, avl); }); +}, { + size_t nr = MPW_BITS - nb; + mpw w; + + if (nw >= avl - av) + w = 0; else { - mpw w; - size_t nr = MPW_BITS - nb; - av += nw; - w = av < avl ? *av++ : 0; + w = *av++; + while (av < avl) { mpw t; - if (dv >= dvl) - goto done; + if (dv >= dvl) goto done; t = *av++; *dv++ = MPW((w >> nb) | (t << nr)); w = t; } - if (dv < dvl) { - *dv++ = MPW(w >> nb); - MPX_ZERO(dv, dvl); - } } + if (dv < dvl) { + *dv++ = MPW(w >> nb); + MPX_ZERO(dv, dvl); + } done:; -} +}) /*----- Bitwise operations ------------------------------------------------*/ @@ -900,7 +818,7 @@ void mpx_usub(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, void mpx_usubn(mpw *dv, mpw *dvl, mpw n) { MPX_USUBN(dv, dvl, n); } -/* --- @mpx_uaddnlsl@ --- * +/* --- @mpx_usubnlsl@ --- * * * Arguments: @mpw *dv, *dvl@ = destination and first argument vector * @mpw a@ = second argument @@ -941,8 +859,13 @@ void mpx_usubnlsl(mpw *dv, mpw *dvl, mpw a, unsigned o) * vectors in any way. 
*/ -void mpx_umul(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, - const mpw *bv, const mpw *bvl) +CPU_DISPATCH(EMPTY, (void), void, mpx_umul, + (mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, + const mpw *bv, const mpw *bvl), + (dv, dvl, av, avl, bv, bvl), pick_umul, simple_umul); + +static void simple_umul(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, + const mpw *bv, const mpw *bvl) { /* --- This is probably worthwhile on a multiply --- */ @@ -981,6 +904,65 @@ void mpx_umul(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, } } +#define MAYBE_UMUL4(impl) \ + extern void mpx_umul4_##impl(mpw */*dv*/, \ + const mpw */*av*/, const mpw */*avl*/, \ + const mpw */*bv*/, const mpw */*bvl*/); \ + static void maybe_umul4_##impl(mpw *dv, mpw *dvl, \ + const mpw *av, const mpw *avl, \ + const mpw *bv, const mpw *bvl) \ + { \ + size_t an = avl - av, bn = bvl - bv, dn = dvl - dv; \ + if (!an || an%4 != 0 || !bn || bn%4 != 0 || dn < an + bn) \ + simple_umul(dv, dvl, av, avl, bv, bvl); \ + else { \ + mpx_umul4_##impl(dv, av, avl, bv, bvl); \ + MPX_ZERO(dv + an + bn, dvl); \ + } \ + } + +#if CPUFAM_X86 + MAYBE_UMUL4(x86_sse2) + MAYBE_UMUL4(x86_avx) +#endif + +#if CPUFAM_AMD64 + MAYBE_UMUL4(amd64_sse2) + MAYBE_UMUL4(amd64_avx) +#endif + +#if CPUFAM_ARMEL + MAYBE_UMUL4(arm_neon) +#endif + +#if CPUFAM_ARM64 + MAYBE_UMUL4(arm64_simd) +#endif + +static mpx_umul__functype *pick_umul(void) +{ +#if CPUFAM_X86 + DISPATCH_PICK_COND(mpx_umul, maybe_umul4_x86_avx, + cpu_feature_p(CPUFEAT_X86_AVX)); + DISPATCH_PICK_COND(mpx_umul, maybe_umul4_x86_sse2, + cpu_feature_p(CPUFEAT_X86_SSE2)); +#endif +#if CPUFAM_AMD64 + DISPATCH_PICK_COND(mpx_umul, maybe_umul4_amd64_avx, + cpu_feature_p(CPUFEAT_X86_AVX)); + DISPATCH_PICK_COND(mpx_umul, maybe_umul4_amd64_sse2, + cpu_feature_p(CPUFEAT_X86_SSE2)); +#endif +#if CPUFAM_ARMEL + DISPATCH_PICK_COND(mpx_umul, maybe_umul4_arm_neon, + cpu_feature_p(CPUFEAT_ARM_NEON)); +#endif +#if CPUFAM_ARM64 + DISPATCH_PICK_COND(mpx_umul, maybe_umul4_arm64_simd, 
1); +#endif + DISPATCH_PICK_FALLBACK(mpx_umul, simple_umul); +} + /* --- @mpx_umuln@ --- * * * Arguments: @mpw *dv, *dvl@ = destination vector base and limit @@ -1307,15 +1289,21 @@ mpw mpx_udivn(mpw *qv, mpw *qvl, const mpw *rv, const mpw *rvl, mpw d) #include #include +#include #include #include +#ifdef ENABLE_ASM_DEBUG +# include "regdump.h" +#endif + #include "mpscan.h" #define ALLOC(v, vl, sz) do { \ size_t _sz = (sz); \ mpw *_vv = xmalloc(MPWS(_sz)); \ mpw *_vvl = _vv + _sz; \ + memset(_vv, 0xa5, MPWS(_sz)); \ (v) = _vv; \ (vl) = _vvl; \ } while (0) @@ -1398,7 +1386,7 @@ static int loadstore(dstr *v) ok = 0; MPX_OCTETS(oct, m, ml); mpx_storel(m, ml, d.buf, d.sz); - if (memcmp(d.buf, v->buf, oct) != 0) { + if (MEMCMP(d.buf, !=, v->buf, oct)) { dumpbits("\n*** storel failed", d.buf, d.sz); ok = 0; } @@ -1408,7 +1396,7 @@ static int loadstore(dstr *v) ok = 0; MPX_OCTETS(oct, m, ml); mpx_storeb(m, ml, d.buf, d.sz); - if (memcmp(d.buf + d.sz - oct, v->buf + v->len - oct, oct) != 0) { + if (MEMCMP(d.buf + d.sz - oct, !=, v->buf + v->len - oct, oct)) { dumpbits("\n*** storeb failed", d.buf, d.sz); ok = 0; } @@ -1425,29 +1413,34 @@ static int loadstore(dstr *v) static int twocl(dstr *v) { dstr d = DSTR_INIT; - mpw *m, *ml; - size_t sz; + mpw *m, *ml0, *ml1; + size_t sz0, sz1, szmax; int ok = 1; + int i; - sz = v[0].len; if (v[1].len > sz) sz = v[1].len; - dstr_ensure(&d, sz); + sz0 = MPW_RQ(v[0].len); sz1 = MPW_RQ(v[1].len); + dstr_ensure(&d, v[0].len > v[1].len ? v[0].len : v[1].len); - sz = MPW_RQ(sz); - m = xmalloc(MPWS(sz)); - ml = m + sz; + szmax = sz0 > sz1 ? 
sz0 : sz1; + m = xmalloc(MPWS(szmax)); + ml0 = m + sz0; ml1 = m + sz1; - mpx_loadl(m, ml, v[0].buf, v[0].len); - mpx_storel2cn(m, ml, d.buf, v[1].len); - if (memcmp(d.buf, v[1].buf, v[1].len)) { - dumpbits("\n*** storel2cn failed", d.buf, v[1].len); - ok = 0; - } + for (i = 0; i < 2; i++) { + if (i) ml0 = ml1 = m + szmax; - mpx_loadl2cn(m, ml, v[1].buf, v[1].len); - mpx_storel(m, ml, d.buf, v[0].len); - if (memcmp(d.buf, v[0].buf, v[0].len)) { - dumpbits("\n*** loadl2cn failed", d.buf, v[0].len); - ok = 0; + mpx_loadl(m, ml0, v[0].buf, v[0].len); + mpx_storel2cn(m, ml0, d.buf, v[1].len); + if (MEMCMP(d.buf, !=, v[1].buf, v[1].len)) { + dumpbits("\n*** storel2cn failed", d.buf, v[1].len); + ok = 0; + } + + mpx_loadl2cn(m, ml1, v[1].buf, v[1].len); + mpx_storel(m, ml1, d.buf, v[0].len); + if (MEMCMP(d.buf, !=, v[0].buf, v[0].len)) { + dumpbits("\n*** loadl2cn failed", d.buf, v[0].len); + ok = 0; + } } if (!ok) { @@ -1464,29 +1457,34 @@ static int twocl(dstr *v) static int twocb(dstr *v) { dstr d = DSTR_INIT; - mpw *m, *ml; - size_t sz; + mpw *m, *ml0, *ml1; + size_t sz0, sz1, szmax; int ok = 1; + int i; - sz = v[0].len; if (v[1].len > sz) sz = v[1].len; - dstr_ensure(&d, sz); + sz0 = MPW_RQ(v[0].len); sz1 = MPW_RQ(v[1].len); + dstr_ensure(&d, v[0].len > v[1].len ? v[0].len : v[1].len); - sz = MPW_RQ(sz); - m = xmalloc(MPWS(sz)); - ml = m + sz; + szmax = sz0 > sz1 ? 
sz0 : sz1; + m = xmalloc(MPWS(szmax)); + ml0 = m + sz0; ml1 = m + sz1; - mpx_loadb(m, ml, v[0].buf, v[0].len); - mpx_storeb2cn(m, ml, d.buf, v[1].len); - if (memcmp(d.buf, v[1].buf, v[1].len)) { - dumpbits("\n*** storeb2cn failed", d.buf, v[1].len); - ok = 0; - } + for (i = 0; i < 2; i++) { + if (i) ml0 = ml1 = m + szmax; - mpx_loadb2cn(m, ml, v[1].buf, v[1].len); - mpx_storeb(m, ml, d.buf, v[0].len); - if (memcmp(d.buf, v[0].buf, v[0].len)) { - dumpbits("\n*** loadb2cn failed", d.buf, v[0].len); - ok = 0; + mpx_loadb(m, ml0, v[0].buf, v[0].len); + mpx_storeb2cn(m, ml0, d.buf, v[1].len); + if (MEMCMP(d.buf, !=, v[1].buf, v[1].len)) { + dumpbits("\n*** storeb2cn failed", d.buf, v[1].len); + ok = 0; + } + + mpx_loadb2cn(m, ml1, v[1].buf, v[1].len); + mpx_storeb(m, ml1, d.buf, v[0].len); + if (MEMCMP(d.buf, !=, v[0].buf, v[0].len)) { + dumpbits("\n*** loadb2cn failed", d.buf, v[0].len); + ok = 0; + } } if (!ok) { @@ -1730,6 +1728,9 @@ static test_chunk defs[] = { int main(int argc, char *argv[]) { +#ifdef ENABLE_ASM_DEBUG + regdump_init(); +#endif test_run(argc, argv, defs, SRCDIR"/t/mpx"); return (0); }