/*----- Header files ------------------------------------------------------*/
+#include "config.h"
+
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <mLib/bits.h>
#include <mLib/macros.h>
+#include "dispatch.h"
#include "mptypes.h"
#include "mpx.h"
#include "bitops.h"
/*----- Loading and storing -----------------------------------------------*/
+/* --- These are all variations on a theme --- *
+ *
+ * Essentially we want to feed bits into a shift register, @ibits@ bits at a
+ * time, and extract them @obits@ bits at a time whenever there are enough.
+ * Of course, @i@ and @o@ will, in general, be different sizes, and we don't
+ * necessarily know which is larger.
+ *
+ * During an operation, we have a shift register @w@ and a most-recent input
+ * @t@. Together, these hold @bits@ significant bits of input. We arrange
+ * that @bits < ibits + obits <= 2*MPW_BITS@, so we can get away with using
+ * an @mpw@ for both of these quantities.
+ */
+
+/* --- @MPX_GETBITS@ --- *
+ *
+ * Arguments: @ibits@ = width of input units, in bits
+ * @obits@ = width of output units, in bits
+ * @iavail@ = condition expression: is input data available?
+ * @getbits@ = function or macro: set argument to next input
+ *
+ * Use: Read an input unit into @t@ and update the necessary
+ * variables.
+ *
+ * It is assumed on entry that @bits < obits@. On exit, we have
+ * @bits < ibits + obits@, and @t@ is live.
+ *
+ * If @iavail@ is false then nothing is read: control jumps
+ * straight to the label @flush@, which the enclosing function
+ * must provide. Otherwise, if the previous input unit is still
+ * held in @t@ (i.e., @bits >= ibits@), it is merged into @w@ at
+ * bit position @bits - ibits@ before @t@ is overwritten.
+ *
+ * The variables @w@, @t@ and @bits@ must be in scope. The
+ * @obits@ argument is not referenced by the expansion.
+ */
+
+#define MPX_GETBITS(ibits, obits, iavail, getbits) do { \
+ if (!iavail) goto flush; \
+ if (bits >= ibits) w |= t << (bits - ibits); \
+ getbits(t); \
+ bits += ibits; \
+} while (0)
+
+/* --- @MPX_PUTBITS@ --- *
+ *
+ * Arguments: @ibits@ = width of input units, in bits
+ * @obits@ = width of output units, in bits
+ * @oavail@ = condition expression: is output space available?
+ * @putbits@ = function or macro: write its argument to output
+ *
+ * Use: Emit an output unit, and update the necessary variables. If
+ * the output buffer is full, then force an immediate return.
+ *
+ * We assume that @bits < ibits + obits@, and that @t@ is only
+ * relevant if @bits >= ibits@. (The @MPX_GETBITS@ macro
+ * ensures that this is true.)
+ *
+ * When the final partial unit is being flushed, @bits@ may be
+ * smaller than @obits@ on entry, so it can go negative here.
+ * The distance @ibits - bits@ can then be as large as
+ * @MPW_BITS@ (e.g., @ibits = 8@, @obits = MPW_BITS@ and
+ * @bits = 8@ on entry). Every right shift therefore goes
+ * through @SHRW@, which yields zero instead of shifting by the
+ * full width of the word.
+ */
+
+#define SHRW(w, b) ((b) < MPW_BITS ? (w) >> (b) : 0)
+
+#define MPX_PUTBITS(ibits, obits, oavail, putbits) do { \
+ if (!oavail) return; \
+ if (bits < ibits) { \
+ putbits(w); \
+ bits -= obits; \
+ w = SHRW(w, obits); \
+ } else { \
+ putbits(w | (t << (bits - ibits))); \
+ bits -= obits; \
+ if (bits >= ibits) w = SHRW(w, obits) | (t << (bits - ibits)); \
+ else w = SHRW(w, obits) | SHRW(t, ibits - bits); \
+ t = 0; \
+ } \
+} while (0)
+
+/* --- @MPX_LOADSTORE@ --- *
+ *
+ * Arguments: @name@ = name of function to create, without @mpx_@ prefix
+ * @wconst@ = qualifiers for @mpw *@ arguments
+ * @oconst@ = qualifiers for octet pointers
+ * @decls@ = additional declarations needed
+ * @ibits@ = width of input units, in bits
+ * @iavail@ = condition expression: is input data available?
+ * @getbits@ = function or macro: set argument to next input
+ * @obits@ = width of output units, in bits
+ * @oavail@ = condition expression: is output space available?
+ * @putbits@ = function or macro: write its argument to output
+ * @fixfinal@ = statements to fix shift register at the end
+ * @clear@ = statements to clear remainder of output
+ *
+ * Use: Generates a function to convert between a sequence of
+ * multiprecision words and a vector of octets.
+ *
+ * The arguments @ibits@, @iavail@ and @getbits@ are passed on
+ * to @MPX_GETBITS@; similarly, @obits@, @oavail@, and @putbits@
+ * are passed on to @MPX_PUTBITS@.
+ *
+ * The following variables are in scope: @v@ and @vl@ are the
+ * current base and limit of the word vector; @p@ and @q@ are
+ * the base and limit of the octet vector; @w@ and @t@ form the
+ * shift register used during the conversion (see commentary
+ * above); and @bits@ tracks the number of live bits in the
+ * shift register.
+ *
+ * The generated function is @void mpx_<name>(wconst mpw *v,
+ * wconst mpw *vl, oconst void *pp, size_t sz)@; @p@ starts at
+ * @pp@ and @q@ at @p + sz@, and @decls@ is expanded after the
+ * local variable declarations.
+ *
+ * The main loop alternately reads input until @bits >= obits@
+ * and writes output until @bits < obits@. It is left in one
+ * of two ways. When the input runs out, @MPX_GETBITS@ jumps to
+ * @flush@: if any bits remain, @fixfinal@ is run and they are
+ * written out, and then @clear@ is run. When the output fills
+ * up, @MPX_PUTBITS@ returns from the function at once, so
+ * @clear@ is skipped on that path.
+ */
+
+#define MPX_LOADSTORE(name, wconst, oconst, decls, \
+ ibits, iavail, getbits, obits, oavail, putbits, \
+ fixfinal, clear) \
+ \
+void mpx_##name(wconst mpw *v, wconst mpw *vl, \
+ oconst void *pp, size_t sz) \
+{ \
+ mpw t = 0, w = 0; \
+ oconst octet *p = pp, *q = p + sz; \
+ int bits = 0; \
+ decls \
+ \
+ for (;;) { \
+ while (bits < obits) MPX_GETBITS(ibits, obits, iavail, getbits); \
+ while (bits >= obits) MPX_PUTBITS(ibits, obits, oavail, putbits); \
+ } \
+ \
+flush: \
+ if (bits) { \
+ fixfinal; \
+ while (bits > 0) MPX_PUTBITS(ibits, obits, oavail, putbits); \
+ } \
+ clear; \
+}
+
+#define EMPTY /* expands to nothing: fills an unwanted macro argument */
+
+/* --- Macros for @getbits@ and @putbits@ --- *
+ *
+ * These all act on the variables of the function generated by
+ * @MPX_LOADSTORE@. @GETMPW@ and @PUTMPW@ read or write one word at @v@
+ * and advance @v@; the value written is reduced with @MPW@.
+ *
+ * The octet macros come in two directions. The @I@ forms read or write
+ * at @p@ and then increment it, so they walk up from the base of the
+ * octet vector. The @D@ forms decrement @q@ and then read or write
+ * there, so they walk down from its limit. Values written are reduced
+ * with @U8@.
+ */
+
+#define GETMPW(t) do { t = *v++; } while (0)
+#define PUTMPW(x) do { *v++ = MPW(x); } while (0)
+
+#define GETOCTETI(t) do { t = *p++; } while (0)
+#define PUTOCTETD(x) do { *--q = U8(x); } while (0)
+
+#define PUTOCTETI(x) do { *p++ = U8(x); } while (0)
+#define GETOCTETD(t) do { t = *--q; } while (0)
+
+/* --- Machinery for two's complement I/O --- *
+ *
+ * @DECL_2CN@ declares a carry @c@, initially 1. @GETMPW_2CN@ and
+ * @PUTMPW_2CN@ are word-at-a-time versions of @GETMPW@ and @PUTMPW@
+ * which negate the number as it passes through: each word is
+ * complemented and the carry added, and the carry stays set only while
+ * every word produced so far has been zero.
+ *
+ * @FIXFINALW_2CN@ adjusts @t@ before the final partial word is written.
+ * Nothing is done if the carry is still set and both @w@ and @t@ are
+ * zero. Otherwise the bits of @t@ above its low octet are inverted, up
+ * to (but not including) bit @MPW_BITS - bits + 8@. The case
+ * @bits == 8@ is handled separately because that bound would then be
+ * @MPW_BITS@ itself, which is not a usable shift distance.
+ * NOTE(review): the intent appears to be to set the unused high bits of
+ * the final word so that the complement in @PUTMPW_2CN@ leaves them
+ * clear -- confirm.
+ *
+ * @FLUSHO_2CN@ fills the octets from @p@ up to @q@ with zero if the
+ * carry is still set, and with @0xff@ otherwise.
+ */
+
+#define DECL_2CN \
+ unsigned c = 1;
+
+#define GETMPW_2CN(t) do { \
+ t = MPW(~*v++ + c); \
+ c = c && !t; \
+} while (0)
+
+#define PUTMPW_2CN(t) do { \
+ mpw _t = MPW(~(t) + c); \
+ c = c && !_t; \
+ *v++ = _t; \
+} while (0)
+
+#define FIXFINALW_2CN do { \
+ if (c && !w && !t); \
+ else if (bits == 8) t ^= ~(mpw)0xffu; \
+ else t ^= ((mpw)1 << (MPW_BITS - bits + 8)) - 256u; \
+} while (0)
+
+#define FLUSHO_2CN do { \
+ memset(p, c ? 0 : 0xff, q - p); \
+} while (0)
+
/* --- @mpx_storel@ --- *
*
* Arguments: @const mpw *v, *vl@ = base and limit of source vector
* isn't enough space for them.
*/
-void mpx_storel(const mpw *v, const mpw *vl, void *pp, size_t sz)
-{
- mpw n, w = 0;
- octet *p = pp, *q = p + sz;
- unsigned bits = 0;
-
- while (p < q) {
- if (bits < 8) {
- if (v >= vl) {
- *p++ = U8(w);
- break;
- }
- n = *v++;
- *p++ = U8(w | n << bits);
- w = n >> (8 - bits);
- bits += MPW_BITS - 8;
- } else {
- *p++ = U8(w);
- w >>= 8;
- bits -= 8;
- }
- }
- memset(p, 0, q - p);
-}
+MPX_LOADSTORE(storel, const, EMPTY, EMPTY,
+ MPW_BITS, (v < vl), GETMPW,
+ 8, (p < q), PUTOCTETI,
+ EMPTY, { memset(p, 0, q - p); })
/* --- @mpx_loadl@ --- *
*
* space for them.
*/
-void mpx_loadl(mpw *v, mpw *vl, const void *pp, size_t sz)
-{
- unsigned n;
- mpw w = 0;
- const octet *p = pp, *q = p + sz;
- unsigned bits = 0;
+MPX_LOADSTORE(loadl, EMPTY, const, EMPTY,
+ 8, (p < q), GETOCTETI,
+ MPW_BITS, (v < vl), PUTMPW,
+ EMPTY, { MPX_ZERO(v, vl); })
- if (v >= vl)
- return;
- while (p < q) {
- n = U8(*p++);
- w |= n << bits;
- bits += 8;
- if (bits >= MPW_BITS) {
- *v++ = MPW(w);
- w = n >> (MPW_BITS - bits + 8);
- bits -= MPW_BITS;
- if (v >= vl)
- return;
- }
- }
- *v++ = w;
- MPX_ZERO(v, vl);
-}
/* --- @mpx_storeb@ --- *
*
* isn't enough space for them.
*/
-void mpx_storeb(const mpw *v, const mpw *vl, void *pp, size_t sz)
-{
- mpw n, w = 0;
- octet *p = pp, *q = p + sz;
- unsigned bits = 0;
-
- while (q > p) {
- if (bits < 8) {
- if (v >= vl) {
- *--q = U8(w);
- break;
- }
- n = *v++;
- *--q = U8(w | n << bits);
- w = n >> (8 - bits);
- bits += MPW_BITS - 8;
- } else {
- *--q = U8(w);
- w >>= 8;
- bits -= 8;
- }
- }
- memset(p, 0, q - p);
-}
+MPX_LOADSTORE(storeb, const, EMPTY, EMPTY,
+ MPW_BITS, (v < vl), GETMPW,
+ 8, (p < q), PUTOCTETD,
+ EMPTY, { memset(p, 0, q - p); })
/* --- @mpx_loadb@ --- *
*
* space for them.
*/
-void mpx_loadb(mpw *v, mpw *vl, const void *pp, size_t sz)
-{
- unsigned n;
- mpw w = 0;
- const octet *p = pp, *q = p + sz;
- unsigned bits = 0;
-
- if (v >= vl)
- return;
- while (q > p) {
- n = U8(*--q);
- w |= n << bits;
- bits += 8;
- if (bits >= MPW_BITS) {
- *v++ = MPW(w);
- w = n >> (MPW_BITS - bits + 8);
- bits -= MPW_BITS;
- if (v >= vl)
- return;
- }
- }
- *v++ = w;
- MPX_ZERO(v, vl);
-}
+MPX_LOADSTORE(loadb, EMPTY, const, EMPTY,
+ 8, (p < q), GETOCTETD,
+ MPW_BITS, (v < vl), PUTMPW,
+ EMPTY, { MPX_ZERO(v, vl); })
/* --- @mpx_storel2cn@ --- *
*
* This obviously makes the output bad.
*/
-void mpx_storel2cn(const mpw *v, const mpw *vl, void *pp, size_t sz)
-{
- unsigned c = 1;
- unsigned b = 0;
- mpw n, w = 0;
- octet *p = pp, *q = p + sz;
- unsigned bits = 0;
-
- while (p < q) {
- if (bits < 8) {
- if (v >= vl) {
- b = w;
- break;
- }
- n = *v++;
- b = w | n << bits;
- w = n >> (8 - bits);
- bits += MPW_BITS - 8;
- } else {
- b = w;
- w >>= 8;
- bits -= 8;
- }
- b = U8(~b + c);
- c = c && !b;
- *p++ = b;
- }
- while (p < q) {
- b = U8(~b + c);
- c = c && !b;
- *p++ = b;
- b = 0;
- }
-}
+MPX_LOADSTORE(storel2cn, const, EMPTY, DECL_2CN,
+ MPW_BITS, (v < vl), GETMPW_2CN,
+ 8, (p < q), PUTOCTETI,
+ EMPTY, { FLUSHO_2CN; })
/* --- @mpx_loadl2cn@ --- *
*
* means you made the wrong choice coming here.
*/
-void mpx_loadl2cn(mpw *v, mpw *vl, const void *pp, size_t sz)
-{
- unsigned n;
- unsigned c = 1;
- mpw w = 0;
- const octet *p = pp, *q = p + sz;
- unsigned bits = 0;
-
- if (v >= vl)
- return;
- while (p < q) {
- n = U8(~(*p++) + c);
- c = c && !n;
- w |= n << bits;
- bits += 8;
- if (bits >= MPW_BITS) {
- *v++ = MPW(w);
- w = n >> (MPW_BITS - bits + 8);
- bits -= MPW_BITS;
- if (v >= vl)
- return;
- }
- }
- *v++ = w;
- MPX_ZERO(v, vl);
-}
+MPX_LOADSTORE(loadl2cn, EMPTY, const, DECL_2CN,
+ 8, (p < q), GETOCTETI,
+ MPW_BITS, (v < vl), PUTMPW_2CN,
+ { FIXFINALW_2CN; }, { MPX_ZERO(v, vl); })
/* --- @mpx_storeb2cn@ --- *
*
* which probably isn't what you meant.
*/
-void mpx_storeb2cn(const mpw *v, const mpw *vl, void *pp, size_t sz)
-{
- mpw n, w = 0;
- unsigned b = 0;
- unsigned c = 1;
- octet *p = pp, *q = p + sz;
- unsigned bits = 0;
-
- while (q > p) {
- if (bits < 8) {
- if (v >= vl) {
- b = w;
- break;
- }
- n = *v++;
- b = w | n << bits;
- w = n >> (8 - bits);
- bits += MPW_BITS - 8;
- } else {
- b = w;
- w >>= 8;
- bits -= 8;
- }
- b = U8(~b + c);
- c = c && !b;
- *--q = b;
- }
- while (q > p) {
- b = ~b + c;
- c = c && !(b & 0xff);
- *--q = b;
- b = 0;
- }
-}
+MPX_LOADSTORE(storeb2cn, const, EMPTY, DECL_2CN,
+ MPW_BITS, (v < vl), GETMPW_2CN,
+ 8, (p < q), PUTOCTETD,
+ EMPTY, { FLUSHO_2CN; })
/* --- @mpx_loadb2cn@ --- *
*
* chose this function wrongly.
*/
-void mpx_loadb2cn(mpw *v, mpw *vl, const void *pp, size_t sz)
-{
- unsigned n;
- unsigned c = 1;
- mpw w = 0;
- const octet *p = pp, *q = p + sz;
- unsigned bits = 0;
-
- if (v >= vl)
- return;
- while (q > p) {
- n = U8(~(*--q) + c);
- c = c && !n;
- w |= n << bits;
- bits += 8;
- if (bits >= MPW_BITS) {
- *v++ = MPW(w);
- w = n >> (MPW_BITS - bits + 8);
- bits -= MPW_BITS;
- if (v >= vl)
- return;
- }
- }
- *v++ = w;
- MPX_ZERO(v, vl);
-}
+MPX_LOADSTORE(loadb2cn, EMPTY, const, DECL_2CN,
+ 8, (p < q), GETOCTETD,
+ MPW_BITS, (v < vl), PUTMPW_2CN,
+ { FIXFINALW_2CN; }, { MPX_ZERO(v, vl); })
/*----- Logical shifting --------------------------------------------------*/
size_t nr = MPW_BITS - nb;
mpw w;
- av += nw;
- w = av < avl ? *av++ : 0;
- while (av < avl) {
- mpw t;
- if (dv >= dvl) goto done;
- t = *av++;
- *dv++ = MPW((w >> nb) | (t << nr));
- w = t;
+ if (nw >= avl - av)
+ w = 0;
+ else {
+ av += nw;
+ w = *av++;
+
+ while (av < avl) {
+ mpw t;
+ if (dv >= dvl) goto done;
+ t = *av++;
+ *dv++ = MPW((w >> nb) | (t << nr));
+ w = t;
+ }
}
+
if (dv < dvl) {
*dv++ = MPW(w >> nb);
MPX_ZERO(dv, dvl);
* vectors in any way.
*/
-void mpx_umul(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl,
- const mpw *bv, const mpw *bvl)
+CPU_DISPATCH(EMPTY, (void), void, mpx_umul,
+ (mpw *dv, mpw *dvl, const mpw *av, const mpw *avl,
+ const mpw *bv, const mpw *bvl),
+ (dv, dvl, av, avl, bv, bvl), pick_umul, simple_umul);
+
+static void simple_umul(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl,
+ const mpw *bv, const mpw *bvl)
{
/* --- This is probably worthwhile on a multiply --- */
}
}
+#define MAYBE_UMUL4(impl) /* define guarded wrapper maybe_umul4_<impl> */ \
+ extern void mpx_umul4_##impl(mpw */*dv*/, \
+ const mpw */*av*/, const mpw */*avl*/, \
+ const mpw */*bv*/, const mpw */*bvl*/); \
+ static void maybe_umul4_##impl(mpw *dv, mpw *dvl, \
+ const mpw *av, const mpw *avl, \
+ const mpw *bv, const mpw *bvl) \
+ { \
+ size_t an = avl - av, bn = bvl - bv, dn = dvl - dv; /* in words */ \
+ if (!an || an%4 != 0 || !bn || bn%4 != 0 || dn < an + bn) \
+ simple_umul(dv, dvl, av, avl, bv, bvl); /* operand empty or not a multiple of 4 words, or dest too short */ \
+ else { \
+ mpx_umul4_##impl(dv, av, avl, bv, bvl); /* takes no dest limit */ \
+ MPX_ZERO(dv + an + bn, dvl); /* clear the rest of the dest */ \
+ } \
+ }
+
+#if CPUFAM_X86
+ MAYBE_UMUL4(x86_sse2)
+ MAYBE_UMUL4(x86_avx)
+#endif
+
+#if CPUFAM_AMD64
+ MAYBE_UMUL4(amd64_sse2)
+ MAYBE_UMUL4(amd64_avx)
+#endif
+
+static mpx_umul__functype *pick_umul(void)
+{ /* Choose the implementation to stand behind @mpx_umul@. For each
+   * CPU family compiled in, the AVX wrapper is offered ahead of the
+   * SSE2 one, each conditional on the matching @cpu_feature_p@ test;
+   * @simple_umul@ is the fallback. NOTE(review): this presumably
+   * relies on @DISPATCH_PICK_COND@ returning from this function when
+   * its condition holds -- confirm against the dispatch header. */
+#if CPUFAM_X86
+ DISPATCH_PICK_COND(mpx_umul, maybe_umul4_x86_avx,
+ cpu_feature_p(CPUFEAT_X86_AVX));
+ DISPATCH_PICK_COND(mpx_umul, maybe_umul4_x86_sse2,
+ cpu_feature_p(CPUFEAT_X86_SSE2));
+#endif
+#if CPUFAM_AMD64
+ DISPATCH_PICK_COND(mpx_umul, maybe_umul4_amd64_avx,
+ cpu_feature_p(CPUFEAT_X86_AVX));
+ DISPATCH_PICK_COND(mpx_umul, maybe_umul4_amd64_sse2,
+ cpu_feature_p(CPUFEAT_X86_SSE2));
+#endif
+ DISPATCH_PICK_FALLBACK(mpx_umul, simple_umul);
+}
+
/* --- @mpx_umuln@ --- *
*
* Arguments: @mpw *dv, *dvl@ = destination vector base and limit
size_t _sz = (sz); \
mpw *_vv = xmalloc(MPWS(_sz)); \
mpw *_vvl = _vv + _sz; \
+ memset(_vv, 0xa5, MPWS(_sz)); \
(v) = _vv; \
(vl) = _vvl; \
} while (0)
static int twocl(dstr *v)
{
dstr d = DSTR_INIT;
- mpw *m, *ml;
- size_t sz;
+ mpw *m, *ml0, *ml1;
+ size_t sz0, sz1, szmax;
int ok = 1;
+ int i;
- sz = v[0].len; if (v[1].len > sz) sz = v[1].len;
- dstr_ensure(&d, sz);
+ sz0 = MPW_RQ(v[0].len); sz1 = MPW_RQ(v[1].len);
+ dstr_ensure(&d, v[0].len > v[1].len ? v[0].len : v[1].len);
- sz = MPW_RQ(sz);
- m = xmalloc(MPWS(sz));
- ml = m + sz;
+ szmax = sz0 > sz1 ? sz0 : sz1;
+ m = xmalloc(MPWS(szmax));
+ ml0 = m + sz0; ml1 = m + sz1;
- mpx_loadl(m, ml, v[0].buf, v[0].len);
- mpx_storel2cn(m, ml, d.buf, v[1].len);
- if (memcmp(d.buf, v[1].buf, v[1].len)) {
- dumpbits("\n*** storel2cn failed", d.buf, v[1].len);
- ok = 0;
- }
+ for (i = 0; i < 2; i++) {
+ if (i) ml0 = ml1 = m + szmax;
- mpx_loadl2cn(m, ml, v[1].buf, v[1].len);
- mpx_storel(m, ml, d.buf, v[0].len);
- if (memcmp(d.buf, v[0].buf, v[0].len)) {
- dumpbits("\n*** loadl2cn failed", d.buf, v[0].len);
- ok = 0;
+ mpx_loadl(m, ml0, v[0].buf, v[0].len);
+ mpx_storel2cn(m, ml0, d.buf, v[1].len);
+ if (memcmp(d.buf, v[1].buf, v[1].len)) {
+ dumpbits("\n*** storel2cn failed", d.buf, v[1].len);
+ ok = 0;
+ }
+
+ mpx_loadl2cn(m, ml1, v[1].buf, v[1].len);
+ mpx_storel(m, ml1, d.buf, v[0].len);
+ if (memcmp(d.buf, v[0].buf, v[0].len)) {
+ dumpbits("\n*** loadl2cn failed", d.buf, v[0].len);
+ ok = 0;
+ }
}
if (!ok) {
static int twocb(dstr *v)
{
dstr d = DSTR_INIT;
- mpw *m, *ml;
- size_t sz;
+ mpw *m, *ml0, *ml1;
+ size_t sz0, sz1, szmax;
int ok = 1;
+ int i;
- sz = v[0].len; if (v[1].len > sz) sz = v[1].len;
- dstr_ensure(&d, sz);
+ sz0 = MPW_RQ(v[0].len); sz1 = MPW_RQ(v[1].len);
+ dstr_ensure(&d, v[0].len > v[1].len ? v[0].len : v[1].len);
- sz = MPW_RQ(sz);
- m = xmalloc(MPWS(sz));
- ml = m + sz;
+ szmax = sz0 > sz1 ? sz0 : sz1;
+ m = xmalloc(MPWS(szmax));
+ ml0 = m + sz0; ml1 = m + sz1;
- mpx_loadb(m, ml, v[0].buf, v[0].len);
- mpx_storeb2cn(m, ml, d.buf, v[1].len);
- if (memcmp(d.buf, v[1].buf, v[1].len)) {
- dumpbits("\n*** storeb2cn failed", d.buf, v[1].len);
- ok = 0;
- }
+ for (i = 0; i < 2; i++) {
+ if (i) ml0 = ml1 = m + szmax;
- mpx_loadb2cn(m, ml, v[1].buf, v[1].len);
- mpx_storeb(m, ml, d.buf, v[0].len);
- if (memcmp(d.buf, v[0].buf, v[0].len)) {
- dumpbits("\n*** loadb2cn failed", d.buf, v[0].len);
- ok = 0;
+ mpx_loadb(m, ml0, v[0].buf, v[0].len);
+ mpx_storeb2cn(m, ml0, d.buf, v[1].len);
+ if (memcmp(d.buf, v[1].buf, v[1].len)) {
+ dumpbits("\n*** storeb2cn failed", d.buf, v[1].len);
+ ok = 0;
+ }
+
+ mpx_loadb2cn(m, ml1, v[1].buf, v[1].len);
+ mpx_storeb(m, ml1, d.buf, v[0].len);
+ if (memcmp(d.buf, v[0].buf, v[0].len)) {
+ dumpbits("\n*** loadb2cn failed", d.buf, v[0].len);
+ ok = 0;
+ }
}
if (!ok) {