Merge branch 'mdw/cpu-dispatch'
[catacomb] / symm / salsa20.c
index 4b35cbd..40f28fc 100644 (file)
@@ -7,11 +7,14 @@
 
 /*----- Header files ------------------------------------------------------*/
 
+#include "config.h"
+
 #include <stdarg.h>
 
 #include <mLib/bits.h>
 
 #include "arena.h"
+#include "dispatch.h"
 #include "gcipher.h"
 #include "grand.h"
 #include "keysz.h"
@@ -39,9 +42,29 @@ const octet salsa20_keysz[] = { KSZ_SET, 32, 16, 10, 0 };
  *             the feedforward step.
  */
 
-static void core(unsigned r, const salsa20_matrix src, salsa20_matrix dest)
+CPU_DISPATCH(static, (void),
+            void, core, (unsigned r, const salsa20_matrix src,
+                         salsa20_matrix dest),
+            (r, src, dest),
+            pick_core, simple_core);
+
+static void simple_core(unsigned r, const salsa20_matrix src,
+                       salsa20_matrix dest)
   { SALSA20_nR(dest, src, r); SALSA20_FFWD(dest, src); }
 
+#if CPUFAM_X86 || CPUFAM_AMD64
+extern core__functype salsa20_core_x86ish_sse2;
+#endif
+
+static core__functype *pick_core(void)
+{
+#if CPUFAM_X86 || CPUFAM_AMD64
+  DISPATCH_PICK_COND(salsa20_core, salsa20_core_x86ish_sse2,
+                    cpu_feature_p(CPUFEAT_X86_SSE2));
+#endif
+  DISPATCH_PICK_FALLBACK(salsa20_core, simple_core);
+}
+
 /* --- @populate@ --- *
  *
  * Arguments:  @salsa20_matrix a@ = a matrix to fill in
@@ -61,33 +84,42 @@ static void populate(salsa20_matrix a, const void *key, size_t ksz)
 
   KSZ_ASSERT(salsa20, ksz);
 
-  a[ 1] = LOAD32_L(k +  0);
-  a[ 2] = LOAD32_L(k +  4);
+  /* Here's the pattern of key, constant, nonce, and counter pieces in the
+   * matrix, before and after our permutation.
+   *
+   * [ C0 K0 K1 K2 ]      [ C0 C1 C2 C3 ]
+   * [ K3 C1 N0 N1 ]  -->  [ K3 T1 K7 K2 ]
+   * [ T0 T1 C2 K4 ]      [ T0 K6 K1 N1 ]
+   * [ K5 K6 K7 C3 ]      [ K5 K0 N0 K4 ]
+   */
+
+  a[13] = LOAD32_L(k +  0);
+  a[10] = LOAD32_L(k +  4);
   if (ksz == 10) {
-    a[ 3] = LOAD16_L(k +  8);
+    a[ 7] = LOAD16_L(k +  8);
     a[ 4] = 0;
   } else {
-    a[ 3] = LOAD32_L(k +  8);
+    a[ 7] = LOAD32_L(k +  8);
     a[ 4] = LOAD32_L(k + 12);
   }
   if (ksz <= 16) {
-    a[11] = a[ 1];
-    a[12] = a[ 2];
-    a[13] = a[ 3];
-    a[14] = a[ 4];
+    a[15] = a[13];
+    a[12] = a[10];
+    a[ 9] = a[ 7];
+    a[ 6] = a[ 4];
     a[ 0] = SALSA20_A128;
-    a[ 5] = SALSA20_B128;
-    a[10] = ksz == 10 ? SALSA20_C80 : SALSA20_C128;
-    a[15] = SALSA20_D128;
+    a[ 1] = SALSA20_B128;
+    a[ 2] = ksz == 10 ? SALSA20_C80 : SALSA20_C128;
+    a[ 3] = SALSA20_D128;
   } else {
-    a[11] = LOAD32_L(k + 16);
+    a[15] = LOAD32_L(k + 16);
     a[12] = LOAD32_L(k + 20);
-    a[13] = LOAD32_L(k + 24);
-    a[14] = LOAD32_L(k + 28);
+    a[ 9] = LOAD32_L(k + 24);
+    a[ 6] = LOAD32_L(k + 28);
     a[ 0] = SALSA20_A256;
-    a[ 5] = SALSA20_B256;
-    a[10] = SALSA20_C256;
-    a[15] = SALSA20_D256;
+    a[ 1] = SALSA20_B256;
+    a[ 2] = SALSA20_C256;
+    a[ 3] = SALSA20_D256;
   }
 }
 
@@ -130,8 +162,8 @@ void salsa20_setnonce(salsa20_ctx *ctx, const void *nonce)
 {
   const octet *n = nonce;
 
-  ctx->a[6] = LOAD32_L(n + 0);
-  ctx->a[7] = LOAD32_L(n + 4);
+  ctx->a[14] = LOAD32_L(n + 0);
+  ctx->a[11] = LOAD32_L(n + 4);
   salsa20_seek(ctx, 0);
 }
 
@@ -153,7 +185,7 @@ void salsa20_seek(salsa20_ctx *ctx, unsigned long i)
 
 void salsa20_seeku64(salsa20_ctx *ctx, kludge64 i)
 {
-  ctx->a[8] = LO64(i); ctx->a[9] = HI64(i);
+  ctx->a[8] = LO64(i); ctx->a[5] = HI64(i);
   ctx->bufi = SALSA20_OUTSZ;
 }
 
@@ -169,7 +201,7 @@ unsigned long salsa20_tell(salsa20_ctx *ctx)
   { kludge64 i = salsa20_tellu64(ctx); return (GET64(unsigned long, i)); }
 
 kludge64 salsa20_tellu64(salsa20_ctx *ctx)
-  { kludge64 i; SET64(i, ctx->a[9], ctx->a[8]); return (i); }
+  { kludge64 i; SET64(i, ctx->a[5], ctx->a[8]); return (i); }
 
 /* --- @salsa20{,12,8}_encrypt@ --- *
  *
@@ -272,10 +304,10 @@ SALSA20_VARS(DEFENCRYPT)
      * speed critical, so we do it the harder way.                     \
      */                                                                        \
                                                                        \
-    for (i = 0; i < 4; i++) k[i + 6] = src[i];                         \
+    for (i = 0; i < 4; i++) k[14 - 3*i] = src[i];                      \
     core(r, k, a);                                                     \
-    for (i = 0; i < 4; i++) dest[i] = a[5*i] - k[5*i];                 \
-    for (i = 4; i < 8; i++) dest[i] = a[i + 2] - k[i + 2];             \
+    for (i = 0; i < 4; i++) dest[i] = a[5*i] - k[i];                   \
+    for (i = 4; i < 8; i++) dest[i] = a[i + 2] - k[26 - 3*i];          \
   }                                                                    \
                                                                        \
   void HSALSA20_PRF(r, salsa20_ctx *ctx, const void *src, void *dest)  \
@@ -340,9 +372,9 @@ SALSA20_VARS(DEFHSALSA20)
                                                                        \
     populate(ctx->k, key, ksz);                                                \
     ctx->s.a[ 0] = SALSA20_A256;                                       \
-    ctx->s.a[ 5] = SALSA20_B256;                                       \
-    ctx->s.a[10] = SALSA20_C256;                                       \
-    ctx->s.a[15] = SALSA20_D256;                                       \
+    ctx->s.a[ 1] = SALSA20_B256;                                       \
+    ctx->s.a[ 2] = SALSA20_C256;                                       \
+    ctx->s.a[ 3] = SALSA20_D256;                                       \
     XSALSA20_SETNONCE(r, ctx, nonce ? nonce : zerononce);              \
   }
 SALSA20_VARS(DEFXINIT)
@@ -371,8 +403,8 @@ SALSA20_VARS(DEFXINIT)
                                                                        \
     for (i = 0; i < 4; i++) in[i] = LOAD32_L(n + 4*i);                 \
     HSALSA20_RAW(r, ctx->k, in, out);                                  \
-    for (i = 0; i < 4; i++) ctx->s.a[i + 1] = out[i];                  \
-    for (i = 4; i < 8; i++) ctx->s.a[i + 7] = out[i];                  \
+    for (i = 0; i < 4; i++) ctx->s.a[13 - 3*i] = out[i];               \
+    for (i = 4; i < 8; i++) ctx->s.a[27 - 3*i] = out[i];               \
     salsa20_setnonce(&ctx->s, n + 16);                                 \
   }
 SALSA20_VARS(DEFXNONCE)
@@ -730,23 +762,31 @@ SALSA20_VARS(DEFXGRAND)
 #include <mLib/quis.h>
 #include <mLib/testrig.h>
 
+static const int perm[] = {
+   0, 13, 10,  7,
+   4,  1, 14, 11,
+   8,  5,  2, 15,
+  12,  9,  6,  3
+};
+
 #define DEFVCORE(r)                                                    \
   static int v_core_##r(dstr *v)                                       \
   {                                                                    \
     salsa20_matrix a, b;                                               \
     dstr d = DSTR_INIT;                                                        \
-    int i, n;                                                          \
+    int i, j, n;                                                       \
     int ok = 1;                                                                \
                                                                        \
     DENSURE(&d, SALSA20_OUTSZ); d.len = SALSA20_OUTSZ;                 \
     n = *(int *)v[0].buf;                                              \
     for (i = 0; i < SALSA20_OUTSZ/4; i++)                              \
-      a[i] = LOAD32_L(v[1].buf + 4*i);                                 \
+      b[i] = LOAD32_L(v[1].buf + 4*i);                                 \
     for (i = 0; i < n; i++) {                                          \
+      for (j = 0; j < 16; j++) a[perm[j]] = b[j];                      \
       core(r, a, b);                                                   \
       memcpy(a, b, sizeof(a));                                         \
     }                                                                  \
-    for (i = 0; i < SALSA20_OUTSZ/4; i++) STORE32_L(d.buf + 4*i, a[i]);        \
+    for (i = 0; i < SALSA20_OUTSZ/4; i++) STORE32_L(d.buf + 4*i, b[i]);        \
                                                                        \
     if (d.len != v[2].len || memcmp(d.buf, v[2].buf, v[2].len) != 0) { \
       ok = 0;                                                          \