pub/, progs/: Implement Bernstein's Ed25519 signature scheme.
[catacomb] / math / mpx.c
index 5f7ffab..e759c5f 100644 (file)
@@ -27,6 +27,8 @@
 
 /*----- Header files ------------------------------------------------------*/
 
+#include "config.h"
+
 #include <assert.h>
 #include <stdio.h>
 #include <stdlib.h>
 #include <mLib/bits.h>
 #include <mLib/macros.h>
 
+#include "dispatch.h"
 #include "mptypes.h"
 #include "mpx.h"
 #include "bitops.h"
 
 /*----- Loading and storing -----------------------------------------------*/
 
+/* --- These are all variations on a theme --- *
+ *
+ * Essentially we want to feed bits into a shift register, @ibits@ bits at a
+ * time, and extract them @obits@ bits at a time whenever there are enough.
+ * Of course, @i@ and @o@ will, in general, be different sizes, and we don't
+ * necessarily know which is larger.
+ *
+ * During an operation, we have a shift register @w@ and a most-recent input
+ * @t@.  Together, these hold @bits@ significant bits of input.  We arrange
+ * that @bits < ibits + obits <= 2*MPW_BITS@, so we can get away with using
+ * an @mpw@ for both of these quantitities.
+ */
+
+/* --- @MPX_GETBITS@ --- *
+ *
+ * Arguments:  @ibits@ = width of input units, in bits
+ *             @obits@ = width of output units, in bits
+ *             @iavail@ = condition expression: is input data available?
+ *             @getbits@ = function or macro: set argument to next input
+ *
+ * Use:                Read an input unit into @t@ and update the necessary
+ *             variables.
+ *
+ *             It is assumed on entry that @bits < obits@.  On exit, we have
+ *             @bits < ibits + obits@, and @t@ is live.
+ */
+
+#define MPX_GETBITS(ibits, obits, iavail, getbits) do {                        \
+  if (!iavail) goto flush;                                             \
+  if (bits >= ibits) w |= t << (bits - ibits);                         \
+  getbits(t);                                                          \
+  bits += ibits;                                                       \
+} while (0)
+
+/* --- @MPX_PUTBITS@ --- *
+ *
+ * Arguments:  @ibits@ = width of input units, in bits
+ *             @obits@ = width of output units, in bits
+ *             @oavail@ = condition expression: is output space available?
+ *             @putbits@ = function or macro: write its argument to output
+ *
+ * Use:                Emit an output unit, and update the necessary variables.  If
+ *             the output buffer is full, then force an immediate return.
+ *
+ *             We assume that @bits < ibits + obits@, and that @t@ is only
+ *             relevant if @bits >= ibits@.  (The @MPX_GETBITS@ macro
+ *             ensures that this is true.)
+ */
+
+#define SHRW(w, b) ((b) < MPW_BITS ? (w) >> (b) : 0)
+
+#define MPX_PUTBITS(ibits, obits, oavail, putbits) do {                        \
+  if (!oavail) return;                                                 \
+  if (bits < ibits) {                                                  \
+    putbits(w);                                                                \
+    bits -= obits;                                                     \
+    w = SHRW(w, obits);                                                        \
+  } else {                                                             \
+    putbits(w | (t << (bits - ibits)));                                        \
+    bits -= obits;                                                     \
+    if (bits >= ibits) w = SHRW(w, obits) | (t << (bits - ibits));     \
+    else w = SHRW(w, obits) | (t >> (ibits - bits));                   \
+    t = 0;                                                             \
+  }                                                                    \
+} while (0)
+
+/* --- @MPX_LOADSTORE@ --- *
+ *
+ * Arguments:  @name@ = name of function to create, without @mpx_@ prefix
+ *             @wconst@ = qualifiers for @mpw *@ arguments
+ *             @oconst@ = qualifiers for octet pointers
+ *             @decls@ = additional declarations needed
+ *             @ibits@ = width of input units, in bits
+ *             @iavail@ = condition expression: is input data available?
+ *             @getbits@ = function or macro: set argument to next input
+ *             @obits@ = width of output units, in bits
+ *             @oavail@ = condition expression: is output space available?
+ *             @putbits@ = function or macro: write its argument to output
+ *             @clear@ = statements to clear remainder of output
+ *
+ * Use:                Generates a function to convert between a sequence of
+ *             multiprecision words and a vector of octets.
+ *
+ *             The arguments @ibits@, @iavail@ and @getbits@ are passed on
+ *             to @MPX_GETBITS@; similarly, @obits@, @oavail@, and @putbits@
+ *             are passed on to @MPX_PUTBITS@.
+ *
+ *             The following variables are in scope: @v@ and @vl are the
+ *             current base and limit of the word vector; @p@ and @q@ are
+ *             the base and limit of the octet vector; @w@ and @t@ form the
+ *             shift register used during the conversion (see commentary
+ *             above); and @bits@ tracks the number of live bits in the
+ *             shift register.
+ */
+
+#define MPX_LOADSTORE(name, wconst, oconst, decls,                     \
+                     ibits, iavail, getbits, obits, oavail, putbits,   \
+                     clear)                                            \
+                                                                       \
+void mpx_##name(wconst mpw *v, wconst mpw *vl,                         \
+               oconst void *pp, size_t sz)                             \
+{                                                                      \
+  mpw t = 0, w = 0;                                                    \
+  oconst octet *p = pp, *q = p + sz;                                   \
+  int bits = 0;                                                                \
+  decls                                                                        \
+                                                                       \
+  for (;;) {                                                           \
+    while (bits < obits) MPX_GETBITS(ibits, obits, iavail, getbits);   \
+    while (bits >= obits) MPX_PUTBITS(ibits, obits, oavail, putbits);  \
+  }                                                                    \
+                                                                       \
+flush:                                                                 \
+  while (bits > 0) MPX_PUTBITS(ibits, obits, oavail, putbits);         \
+  clear;                                                               \
+}
+
+#define EMPTY
+
+/* --- Macros for @getbits@ and @putbits@ --- */
+
+#define GETMPW(t) do { t = *v++; } while (0)
+#define PUTMPW(x) do { *v++ = MPW(x); } while (0)
+
+#define GETOCTETI(t) do { t = *p++; } while (0)
+#define PUTOCTETD(x) do { *--q = U8(x); } while (0)
+
+#define PUTOCTETI(x) do { *p++ = U8(x); } while (0)
+#define GETOCTETD(t) do { t = *--q; } while (0)
+
+/* --- Machinery for two's complement I/O --- */
+
+#define DECL_2CN                                                       \
+  unsigned c = 1;
+
+#define GETMPW_2CN(t) do {                                             \
+  t = MPW(~*v++ + c);                                                  \
+  c = c && !t;                                                         \
+} while (0)
+
+#define PUTMPW_2CN(t) do {                                             \
+  mpw _t = MPW(~(t) + c);                                              \
+  c = c && !_t;                                                                \
+  *v++ = _t;                                                           \
+} while (0)
+
+#define FLUSHW_2CN do {                                                        \
+  if (c) MPX_ONE(v, vl);                                               \
+  else MPX_ZERO(v, vl);                                                        \
+} while (0)
+
+#define FLUSHO_2CN do {                                                        \
+  memset(p, c ? 0xff : 0, q - p);                                      \
+} while (0)
+
 /* --- @mpx_storel@ --- *
  *
  * Arguments:  @const mpw *v, *vl@ = base and limit of source vector
  *             isn't enough space for them.
  */
 
-void mpx_storel(const mpw *v, const mpw *vl, void *pp, size_t sz)
-{
-  mpw n, w = 0;
-  octet *p = pp, *q = p + sz;
-  unsigned bits = 0;
-
-  while (p < q) {
-    if (bits < 8) {
-      if (v >= vl) {
-       *p++ = U8(w);
-       break;
-      }
-      n = *v++;
-      *p++ = U8(w | n << bits);
-      w = n >> (8 - bits);
-      bits += MPW_BITS - 8;
-    } else {
-      *p++ = U8(w);
-      w >>= 8;
-      bits -= 8;
-    }
-  }
-  memset(p, 0, q - p);
-}
+MPX_LOADSTORE(storel, const, EMPTY, EMPTY,
+             MPW_BITS, (v < vl), GETMPW,
+             8, (p < q), PUTOCTETI,
+             { memset(p, 0, q - p); })
 
 /* --- @mpx_loadl@ --- *
  *
@@ -92,30 +230,11 @@ void mpx_storel(const mpw *v, const mpw *vl, void *pp, size_t sz)
  *             space for them.
  */
 
-void mpx_loadl(mpw *v, mpw *vl, const void *pp, size_t sz)
-{
-  unsigned n;
-  mpw w = 0;
-  const octet *p = pp, *q = p + sz;
-  unsigned bits = 0;
+MPX_LOADSTORE(loadl, EMPTY, const, EMPTY,
+             8, (p < q), GETOCTETI,
+             MPW_BITS, (v < vl), PUTMPW,
+             { MPX_ZERO(v, vl); })
 
-  if (v >= vl)
-    return;
-  while (p < q) {
-    n = U8(*p++);
-    w |= n << bits;
-    bits += 8;
-    if (bits >= MPW_BITS) {
-      *v++ = MPW(w);
-      w = n >> (MPW_BITS - bits + 8);
-      bits -= MPW_BITS;
-      if (v >= vl)
-       return;
-    }
-  }
-  *v++ = w;
-  MPX_ZERO(v, vl);
-}
 
 /* --- @mpx_storeb@ --- *
  *
@@ -130,30 +249,10 @@ void mpx_loadl(mpw *v, mpw *vl, const void *pp, size_t sz)
  *             isn't enough space for them.
  */
 
-void mpx_storeb(const mpw *v, const mpw *vl, void *pp, size_t sz)
-{
-  mpw n, w = 0;
-  octet *p = pp, *q = p + sz;
-  unsigned bits = 0;
-
-  while (q > p) {
-    if (bits < 8) {
-      if (v >= vl) {
-       *--q = U8(w);
-       break;
-      }
-      n = *v++;
-      *--q = U8(w | n << bits);
-      w = n >> (8 - bits);
-      bits += MPW_BITS - 8;
-    } else {
-      *--q = U8(w);
-      w >>= 8;
-      bits -= 8;
-    }
-  }
-  memset(p, 0, q - p);
-}
+MPX_LOADSTORE(storeb, const, EMPTY, EMPTY,
+             MPW_BITS, (v < vl), GETMPW,
+             8, (p < q), PUTOCTETD,
+             { memset(p, 0, q - p); })
 
 /* --- @mpx_loadb@ --- *
  *
@@ -168,30 +267,10 @@ void mpx_storeb(const mpw *v, const mpw *vl, void *pp, size_t sz)
  *             space for them.
  */
 
-void mpx_loadb(mpw *v, mpw *vl, const void *pp, size_t sz)
-{
-  unsigned n;
-  mpw w = 0;
-  const octet *p = pp, *q = p + sz;
-  unsigned bits = 0;
-
-  if (v >= vl)
-    return;
-  while (q > p) {
-    n = U8(*--q);
-    w |= n << bits;
-    bits += 8;
-    if (bits >= MPW_BITS) {
-      *v++ = MPW(w);
-      w = n >> (MPW_BITS - bits + 8);
-      bits -= MPW_BITS;
-      if (v >= vl)
-       return;
-    }
-  }
-  *v++ = w;
-  MPX_ZERO(v, vl);
-}
+MPX_LOADSTORE(loadb, EMPTY, const, EMPTY,
+             8, (p < q), GETOCTETD,
+             MPW_BITS, (v < vl), PUTMPW,
+             { MPX_ZERO(v, vl); })
 
 /* --- @mpx_storel2cn@ --- *
  *
@@ -207,40 +286,10 @@ void mpx_loadb(mpw *v, mpw *vl, const void *pp, size_t sz)
  *             This obviously makes the output bad.
  */
 
-void mpx_storel2cn(const mpw *v, const mpw *vl, void *pp, size_t sz)
-{
-  unsigned c = 1;
-  unsigned b = 0;
-  mpw n, w = 0;
-  octet *p = pp, *q = p + sz;
-  unsigned bits = 0;
-
-  while (p < q) {
-    if (bits < 8) {
-      if (v >= vl) {
-       b = w;
-       break;
-      }
-      n = *v++;
-      b = w | n << bits;
-      w = n >> (8 - bits);
-      bits += MPW_BITS - 8;
-    } else {
-      b = w;
-      w >>= 8;
-      bits -= 8;
-    }
-    b = U8(~b + c);
-    c = c && !b;
-    *p++ = b;
-  }
-  while (p < q) {
-    b = U8(~b + c);
-    c = c && !b;
-    *p++ = b;
-    b = 0;
-  }
-}
+MPX_LOADSTORE(storel2cn, const, EMPTY, DECL_2CN,
+             MPW_BITS, (v < vl), GETMPW_2CN,
+             8, (p < q), PUTOCTETI,
+             { FLUSHO_2CN; })
 
 /* --- @mpx_loadl2cn@ --- *
  *
@@ -256,32 +305,10 @@ void mpx_storel2cn(const mpw *v, const mpw *vl, void *pp, size_t sz)
  *             means you made the wrong choice coming here.
  */
 
-void mpx_loadl2cn(mpw *v, mpw *vl, const void *pp, size_t sz)
-{
-  unsigned n;
-  unsigned c = 1;
-  mpw w = 0;
-  const octet *p = pp, *q = p + sz;
-  unsigned bits = 0;
-
-  if (v >= vl)
-    return;
-  while (p < q) {
-    n = U8(~(*p++) + c);
-    c = c && !n;
-    w |= n << bits;
-    bits += 8;
-    if (bits >= MPW_BITS) {
-      *v++ = MPW(w);
-      w = n >> (MPW_BITS - bits + 8);
-      bits -= MPW_BITS;
-      if (v >= vl)
-       return;
-    }
-  }
-  *v++ = w;
-  MPX_ZERO(v, vl);
-}
+MPX_LOADSTORE(loadl2cn, EMPTY, const, DECL_2CN,
+             8, (p < q), GETOCTETI,
+             MPW_BITS, (v < vl), PUTMPW_2CN,
+             { FLUSHW_2CN; })
 
 /* --- @mpx_storeb2cn@ --- *
  *
@@ -297,40 +324,10 @@ void mpx_loadl2cn(mpw *v, mpw *vl, const void *pp, size_t sz)
  *             which probably isn't what you meant.
  */
 
-void mpx_storeb2cn(const mpw *v, const mpw *vl, void *pp, size_t sz)
-{
-  mpw n, w = 0;
-  unsigned b = 0;
-  unsigned c = 1;
-  octet *p = pp, *q = p + sz;
-  unsigned bits = 0;
-
-  while (q > p) {
-    if (bits < 8) {
-      if (v >= vl) {
-       b = w;
-       break;
-      }
-      n = *v++;
-      b = w | n << bits;
-      w = n >> (8 - bits);
-      bits += MPW_BITS - 8;
-    } else {
-      b = w;
-      w >>= 8;
-      bits -= 8;
-    }
-    b = U8(~b + c);
-    c = c && !b;
-    *--q = b;
-  }
-  while (q > p) {
-    b = ~b + c;
-    c = c && !(b & 0xff);
-    *--q = b;
-    b = 0;
-  }
-}
+MPX_LOADSTORE(storeb2cn, const, EMPTY, DECL_2CN,
+             MPW_BITS, (v < vl), GETMPW_2CN,
+             8, (p < q), PUTOCTETD,
+             { FLUSHO_2CN; })
 
 /* --- @mpx_loadb2cn@ --- *
  *
@@ -346,32 +343,10 @@ void mpx_storeb2cn(const mpw *v, const mpw *vl, void *pp, size_t sz)
  *             chose this function wrongly.
  */
 
-void mpx_loadb2cn(mpw *v, mpw *vl, const void *pp, size_t sz)
-{
-  unsigned n;
-  unsigned c = 1;
-  mpw w = 0;
-  const octet *p = pp, *q = p + sz;
-  unsigned bits = 0;
-
-  if (v >= vl)
-    return;
-  while (q > p) {
-    n = U8(~(*--q) + c);
-    c = c && !n;
-    w |= n << bits;
-    bits += 8;
-    if (bits >= MPW_BITS) {
-      *v++ = MPW(w);
-      w = n >> (MPW_BITS - bits + 8);
-      bits -= MPW_BITS;
-      if (v >= vl)
-       return;
-    }
-  }
-  *v++ = w;
-  MPX_ZERO(v, vl);
-}
+MPX_LOADSTORE(loadb2cn, EMPTY, const, DECL_2CN,
+             8, (p < q), GETOCTETD,
+             MPW_BITS, (v < vl), PUTMPW_2CN,
+             { FLUSHW_2CN; })
 
 /*----- Logical shifting --------------------------------------------------*/
 
@@ -873,8 +848,13 @@ void mpx_usubnlsl(mpw *dv, mpw *dvl, mpw a, unsigned o)
  *             vectors in any way.
  */
 
-void mpx_umul(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl,
-             const mpw *bv, const mpw *bvl)
+CPU_DISPATCH(EMPTY, (void), void, mpx_umul,
+            (mpw *dv, mpw *dvl, const mpw *av, const mpw *avl,
+             const mpw *bv, const mpw *bvl),
+            (dv, dvl, av, avl, bv, bvl), pick_umul, simple_umul);
+
+static void simple_umul(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl,
+                       const mpw *bv, const mpw *bvl)
 {
   /* --- This is probably worthwhile on a multiply --- */
 
@@ -913,6 +893,44 @@ void mpx_umul(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl,
   }
 }
 
+#define MAYBE_UMUL4(impl)                                              \
+  extern void mpx_umul4_##impl(mpw */*dv*/,                            \
+                              const mpw */*av*/, const mpw */*avl*/,   \
+                              const mpw */*bv*/, const mpw */*bvl*/);  \
+  static void maybe_umul4_##impl(mpw *dv, mpw *dvl,                    \
+                                const mpw *av, const mpw *avl,         \
+                                const mpw *bv, const mpw *bvl)         \
+  {                                                                    \
+    size_t an = avl - av, bn = bvl - bv, dn = dvl - dv;                        \
+    if (!an || an%4 != 0 || !bn || bn%4 != 0 || dn < an + bn)          \
+      simple_umul(dv, dvl, av, avl, bv, bvl);                          \
+    else {                                                             \
+      mpx_umul4_##impl(dv, av, avl, bv, bvl);                          \
+      MPX_ZERO(dv + an + bn, dvl);                                     \
+    }                                                                  \
+  }
+
+#if CPUFAM_X86
+  MAYBE_UMUL4(x86_sse2)
+#endif
+
+#if CPUFAM_AMD64
+  MAYBE_UMUL4(amd64_sse2)
+#endif
+
+static mpx_umul__functype *pick_umul(void)
+{
+#if CPUFAM_X86
+  DISPATCH_PICK_COND(mpx_umul, maybe_umul4_x86_sse2,
+                    cpu_feature_p(CPUFEAT_X86_SSE2));
+#endif
+#if CPUFAM_AMD64
+  DISPATCH_PICK_COND(mpx_umul, maybe_umul4_amd64_sse2,
+                    cpu_feature_p(CPUFEAT_X86_SSE2));
+#endif
+  DISPATCH_PICK_FALLBACK(mpx_umul, simple_umul);
+}
+
 /* --- @mpx_umuln@ --- *
  *
  * Arguments:  @mpw *dv, *dvl@ = destination vector base and limit
@@ -1248,6 +1266,7 @@ mpw mpx_udivn(mpw *qv, mpw *qvl, const mpw *rv, const mpw *rvl, mpw d)
   size_t _sz = (sz);                                                   \
   mpw *_vv = xmalloc(MPWS(_sz));                                       \
   mpw *_vvl = _vv + _sz;                                               \
+  memset(_vv, 0xa5, MPWS(_sz));                                                \
   (v) = _vv;                                                           \
   (vl) = _vvl;                                                         \
 } while (0)