X-Git-Url: https://git.distorted.org.uk/~mdw/catacomb/blobdiff_plain/0c9ebe471cfa8343f2ac5d8bd206870f82e87837..a90d420cbe87490c844ae422c966e746d3134b07:/math/mpx.c diff --git a/math/mpx.c b/math/mpx.c index 2745fe0f..07a6c20f 100644 --- a/math/mpx.c +++ b/math/mpx.c @@ -27,6 +27,8 @@ /*----- Header files ------------------------------------------------------*/ +#include "config.h" + #include #include #include @@ -35,6 +37,7 @@ #include #include +#include "dispatch.h" #include "mptypes.h" #include "mpx.h" #include "bitops.h" @@ -119,6 +122,7 @@ * @obits@ = width of output units, in bits * @oavail@ = condition expression: is output space available? * @putbits@ = function or macro: write its argument to output + * @fixfinal@ = statements to fix shift register at the end * @clear@ = statements to clear remainder of output * * Use: Generates a function to convert between a sequence of @@ -138,7 +142,7 @@ #define MPX_LOADSTORE(name, wconst, oconst, decls, \ ibits, iavail, getbits, obits, oavail, putbits, \ - clear) \ + fixfinal, clear) \ \ void mpx_##name(wconst mpw *v, wconst mpw *vl, \ oconst void *pp, size_t sz) \ @@ -154,7 +158,10 @@ void mpx_##name(wconst mpw *v, wconst mpw *vl, \ } \ \ flush: \ - while (bits > 0) MPX_PUTBITS(ibits, obits, oavail, putbits); \ + if (bits) { \ + fixfinal; \ + while (bits > 0) MPX_PUTBITS(ibits, obits, oavail, putbits); \ + } \ clear; \ } @@ -187,13 +194,14 @@ flush: \ *v++ = _t; \ } while (0) -#define FLUSHW_2CN do { \ - if (c) MPX_ONE(v, vl); \ - else MPX_ZERO(v, vl); \ +#define FIXFINALW_2CN do { \ + if (c && !w && !t); \ + else if (bits == 8) t ^= ~(mpw)0xffu; \ + else t ^= ((mpw)1 << (MPW_BITS - bits + 8)) - 256u; \ } while (0) #define FLUSHO_2CN do { \ - memset(p, c ? 0xff : 0, q - p); \ + memset(p, c ? 0 : 0xff, q - p); \ } while (0) /* --- @mpx_storel@ --- * @@ -212,7 +220,7 @@ flush: \ MPX_LOADSTORE(storel, const, EMPTY, EMPTY, MPW_BITS, (v < vl), GETMPW, 8, (p < q), PUTOCTETI, - { memset(p, 0, q - p); }) + EMPTY, { memset(p, 0, q - p); }) /* --- @mpx_loadl@ --- * * @@ -230,7 +238,7 @@ MPX_LOADSTORE(storel, const, EMPTY, EMPTY, MPX_LOADSTORE(loadl, EMPTY, const, EMPTY, 8, (p < q), GETOCTETI, MPW_BITS, (v < vl), PUTMPW, - { MPX_ZERO(v, vl); }) + EMPTY, { MPX_ZERO(v, vl); }) /* --- @mpx_storeb@ --- * @@ -249,7 +257,7 @@ MPX_LOADSTORE(loadl, EMPTY, const, EMPTY, MPX_LOADSTORE(storeb, const, EMPTY, EMPTY, MPW_BITS, (v < vl), GETMPW, 8, (p < q), PUTOCTETD, - { memset(p, 0, q - p); }) + EMPTY, { memset(p, 0, q - p); }) /* --- @mpx_loadb@ --- * * @@ -267,7 +275,7 @@ MPX_LOADSTORE(storeb, const, EMPTY, EMPTY, MPX_LOADSTORE(loadb, EMPTY, const, EMPTY, 8, (p < q), GETOCTETD, MPW_BITS, (v < vl), PUTMPW, - { MPX_ZERO(v, vl); }) + EMPTY, { MPX_ZERO(v, vl); }) /* --- @mpx_storel2cn@ --- * * @@ -286,7 +294,7 @@ MPX_LOADSTORE(loadb, EMPTY, const, EMPTY, MPX_LOADSTORE(storel2cn, const, EMPTY, DECL_2CN, MPW_BITS, (v < vl), GETMPW_2CN, 8, (p < q), PUTOCTETI, - { FLUSHO_2CN; }) + EMPTY, { FLUSHO_2CN; }) /* --- @mpx_loadl2cn@ --- * * @@ -305,7 +313,7 @@ MPX_LOADSTORE(storel2cn, const, EMPTY, DECL_2CN, MPX_LOADSTORE(loadl2cn, EMPTY, const, DECL_2CN, 8, (p < q), GETOCTETI, MPW_BITS, (v < vl), PUTMPW_2CN, - { FLUSHW_2CN; }) + { FIXFINALW_2CN; }, { MPX_ZERO(v, vl); }) /* --- @mpx_storeb2cn@ --- * * @@ -324,7 +332,7 @@ MPX_LOADSTORE(loadl2cn, EMPTY, const, DECL_2CN, MPX_LOADSTORE(storeb2cn, const, EMPTY, DECL_2CN, MPW_BITS, (v < vl), GETMPW_2CN, 8, (p < q), PUTOCTETD, - { FLUSHO_2CN; }) + EMPTY, { FLUSHO_2CN; }) /* --- @mpx_loadb2cn@ --- * * @@ -343,7 +351,7 @@ MPX_LOADSTORE(storeb2cn, const, EMPTY, DECL_2CN, MPX_LOADSTORE(loadb2cn, EMPTY, const, DECL_2CN, 8, (p < q), GETOCTETD, MPW_BITS, (v < vl), PUTMPW_2CN, - { FLUSHW_2CN; }) + { FIXFINALW_2CN; }, { MPX_ZERO(v, vl); }) /*----- Logical shifting --------------------------------------------------*/ @@ -537,15 +545,21 @@ MPX_SHIFTOP(lsr, { size_t nr = MPW_BITS - nb; mpw w; - av += nw; - w = av < avl ? *av++ : 0; - while (av < avl) { - mpw t; - if (dv >= dvl) goto done; - t = *av++; - *dv++ = MPW((w >> nb) | (t << nr)); - w = t; + if (nw >= avl - av) + w = 0; + else { + av += nw; + w = *av++; + + while (av < avl) { + mpw t; + if (dv >= dvl) goto done; + t = *av++; + *dv++ = MPW((w >> nb) | (t << nr)); + w = t; + } } + if (dv < dvl) { *dv++ = MPW(w >> nb); MPX_ZERO(dv, dvl); @@ -804,7 +818,7 @@ void mpx_usub(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, void mpx_usubn(mpw *dv, mpw *dvl, mpw n) { MPX_USUBN(dv, dvl, n); } -/* --- @mpx_uaddnlsl@ --- * +/* --- @mpx_usubnlsl@ --- * * * Arguments: @mpw *dv, *dvl@ = destination and first argument vector * @mpw a@ = second argument @@ -845,8 +859,13 @@ void mpx_usubnlsl(mpw *dv, mpw *dvl, mpw a, unsigned o) * vectors in any way. */ -void mpx_umul(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, - const mpw *bv, const mpw *bvl) +CPU_DISPATCH(EMPTY, (void), void, mpx_umul, + (mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, + const mpw *bv, const mpw *bvl), + (dv, dvl, av, avl, bv, bvl), pick_umul, simple_umul); + +static void simple_umul(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, + const mpw *bv, const mpw *bvl) { /* --- This is probably worthwhile on a multiply --- */ @@ -885,6 +904,50 @@ void mpx_umul(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, } } +#define MAYBE_UMUL4(impl) \ + extern void mpx_umul4_##impl(mpw */*dv*/, \ + const mpw */*av*/, const mpw */*avl*/, \ + const mpw */*bv*/, const mpw */*bvl*/); \ + static void maybe_umul4_##impl(mpw *dv, mpw *dvl, \ + const mpw *av, const mpw *avl, \ + const mpw *bv, const mpw *bvl) \ + { \ + size_t an = avl - av, bn = bvl - bv, dn = dvl - dv; \ + if (!an || an%4 != 0 || !bn || bn%4 != 0 || dn < an + bn) \ + simple_umul(dv, dvl, av, avl, bv, bvl); \ + else { \ + mpx_umul4_##impl(dv, av, avl, bv, bvl); \ + MPX_ZERO(dv + an + bn, dvl); \ + } \ + } + +#if CPUFAM_X86 + MAYBE_UMUL4(x86_sse2) + MAYBE_UMUL4(x86_avx) +#endif + +#if CPUFAM_AMD64 + MAYBE_UMUL4(amd64_sse2) + MAYBE_UMUL4(amd64_avx) +#endif + +static mpx_umul__functype *pick_umul(void) +{ +#if CPUFAM_X86 + DISPATCH_PICK_COND(mpx_umul, maybe_umul4_x86_avx, + cpu_feature_p(CPUFEAT_X86_AVX)); + DISPATCH_PICK_COND(mpx_umul, maybe_umul4_x86_sse2, + cpu_feature_p(CPUFEAT_X86_SSE2)); +#endif +#if CPUFAM_AMD64 + DISPATCH_PICK_COND(mpx_umul, maybe_umul4_amd64_avx, + cpu_feature_p(CPUFEAT_X86_AVX)); + DISPATCH_PICK_COND(mpx_umul, maybe_umul4_amd64_sse2, + cpu_feature_p(CPUFEAT_X86_SSE2)); +#endif + DISPATCH_PICK_FALLBACK(mpx_umul, simple_umul); +} + /* --- @mpx_umuln@ --- * * * Arguments: @mpw *dv, *dvl@ = destination vector base and limit @@ -1211,6 +1274,7 @@ mpw mpx_udivn(mpw *qv, mpw *qvl, const mpw *rv, const mpw *rvl, mpw d) #include #include +#include #include #include @@ -1220,6 +1284,7 @@ mpw mpx_udivn(mpw *qv, mpw *qvl, const mpw *rv, const mpw *rvl, mpw d) size_t _sz = (sz); \ mpw *_vv = xmalloc(MPWS(_sz)); \ mpw *_vvl = _vv + _sz; \ + memset(_vv, 0xa5, MPWS(_sz)); \ (v) = _vv; \ (vl) = _vvl; \ } while (0) @@ -1302,7 +1367,7 @@ static int loadstore(dstr *v) ok = 0; MPX_OCTETS(oct, m, ml); mpx_storel(m, ml, d.buf, d.sz); - if (memcmp(d.buf, v->buf, oct) != 0) { + if (MEMCMP(d.buf, !=, v->buf, oct)) { dumpbits("\n*** storel failed", d.buf, d.sz); ok = 0; } @@ -1312,7 +1377,7 @@ static int loadstore(dstr *v) ok = 0; MPX_OCTETS(oct, m, ml); mpx_storeb(m, ml, d.buf, d.sz); - if (memcmp(d.buf + d.sz - oct, v->buf + v->len - oct, oct) != 0) { + if (MEMCMP(d.buf + d.sz - oct, !=, v->buf + v->len - oct, oct)) { dumpbits("\n*** storeb failed", d.buf, d.sz); ok = 0; } @@ -1329,29 +1394,34 @@ static int loadstore(dstr *v) static int twocl(dstr *v) { dstr d = DSTR_INIT; - mpw *m, *ml; - size_t sz; + mpw *m, *ml0, *ml1; + size_t sz0, sz1, szmax; int ok = 1; + int i; - sz = v[0].len; if (v[1].len > sz) sz = v[1].len; - dstr_ensure(&d, sz); + sz0 = MPW_RQ(v[0].len); sz1 = MPW_RQ(v[1].len); + dstr_ensure(&d, v[0].len > v[1].len ? v[0].len : v[1].len); - sz = MPW_RQ(sz); - m = xmalloc(MPWS(sz)); - ml = m + sz; + szmax = sz0 > sz1 ? sz0 : sz1; + m = xmalloc(MPWS(szmax)); + ml0 = m + sz0; ml1 = m + sz1; - mpx_loadl(m, ml, v[0].buf, v[0].len); - mpx_storel2cn(m, ml, d.buf, v[1].len); - if (memcmp(d.buf, v[1].buf, v[1].len)) { - dumpbits("\n*** storel2cn failed", d.buf, v[1].len); - ok = 0; - } + for (i = 0; i < 2; i++) { + if (i) ml0 = ml1 = m + szmax; - mpx_loadl2cn(m, ml, v[1].buf, v[1].len); - mpx_storel(m, ml, d.buf, v[0].len); - if (memcmp(d.buf, v[0].buf, v[0].len)) { - dumpbits("\n*** loadl2cn failed", d.buf, v[0].len); - ok = 0; + mpx_loadl(m, ml0, v[0].buf, v[0].len); + mpx_storel2cn(m, ml0, d.buf, v[1].len); + if (MEMCMP(d.buf, !=, v[1].buf, v[1].len)) { + dumpbits("\n*** storel2cn failed", d.buf, v[1].len); + ok = 0; + } + + mpx_loadl2cn(m, ml1, v[1].buf, v[1].len); + mpx_storel(m, ml1, d.buf, v[0].len); + if (MEMCMP(d.buf, !=, v[0].buf, v[0].len)) { + dumpbits("\n*** loadl2cn failed", d.buf, v[0].len); + ok = 0; + } } if (!ok) { @@ -1368,29 +1438,34 @@ static int twocl(dstr *v) static int twocb(dstr *v) { dstr d = DSTR_INIT; - mpw *m, *ml; - size_t sz; + mpw *m, *ml0, *ml1; + size_t sz0, sz1, szmax; int ok = 1; + int i; - sz = v[0].len; if (v[1].len > sz) sz = v[1].len; - dstr_ensure(&d, sz); + sz0 = MPW_RQ(v[0].len); sz1 = MPW_RQ(v[1].len); + dstr_ensure(&d, v[0].len > v[1].len ? v[0].len : v[1].len); - sz = MPW_RQ(sz); - m = xmalloc(MPWS(sz)); - ml = m + sz; + szmax = sz0 > sz1 ? sz0 : sz1; + m = xmalloc(MPWS(szmax)); + ml0 = m + sz0; ml1 = m + sz1; - mpx_loadb(m, ml, v[0].buf, v[0].len); - mpx_storeb2cn(m, ml, d.buf, v[1].len); - if (memcmp(d.buf, v[1].buf, v[1].len)) { - dumpbits("\n*** storeb2cn failed", d.buf, v[1].len); - ok = 0; - } + for (i = 0; i < 2; i++) { + if (i) ml0 = ml1 = m + szmax; - mpx_loadb2cn(m, ml, v[1].buf, v[1].len); - mpx_storeb(m, ml, d.buf, v[0].len); - if (memcmp(d.buf, v[0].buf, v[0].len)) { - dumpbits("\n*** loadb2cn failed", d.buf, v[0].len); - ok = 0; + mpx_loadb(m, ml0, v[0].buf, v[0].len); + mpx_storeb2cn(m, ml0, d.buf, v[1].len); + if (MEMCMP(d.buf, !=, v[1].buf, v[1].len)) { + dumpbits("\n*** storeb2cn failed", d.buf, v[1].len); + ok = 0; + } + + mpx_loadb2cn(m, ml1, v[1].buf, v[1].len); + mpx_storeb(m, ml1, d.buf, v[0].len); + if (MEMCMP(d.buf, !=, v[0].buf, v[0].len)) { + dumpbits("\n*** loadb2cn failed", d.buf, v[0].len); + ok = 0; + } } if (!ok) {