symm/salsa20.c, symm/salsa20-core.h: Permute input matrix for SIMD.
authorMark Wooding <mdw@distorted.org.uk>
Sat, 2 May 2015 16:05:20 +0000 (17:05 +0100)
committerMark Wooding <mdw@distorted.org.uk>
Mon, 20 Jul 2015 12:54:22 +0000 (13:54 +0100)
Maintain the input matrix in the Salsa20 context structure in a permuted
form which makes SIMD implementations of the core function rather more
efficient.

symm/chacha-core.h
symm/salsa20-core.h
symm/salsa20.c

index ad6b05f..1c0efcd 100644 (file)
 
 /* The ChaCha feedforward step, used at the end of the core function.  Here,
  * @y@ contains the original input matrix; @z@ contains the final one, and is
- * updated.  This is the same as Salsa20.
+ * updated.  This is the same as Salsa20, only without the final permutation.
  */
-#define CHACHA_FFWD(z, y) SALSA20_FFWD(z, y)
+#define CHACHA_FFWD(z, y) do {                                         \
+  int _i;                                                              \
+  for (_i = 0; _i < 16; _i++) (z)[_i] += (y)[_i];                      \
+} while (0)
 
 /* Various numbers of rounds, unrolled.  Read from @y@, and write to @z@. */
 #define CHACHA_4R(z, y)                                                        \
index 98efa72..b27f222 100644 (file)
 
 /*----- The Salsa20 core function -----------------------------------------*/
 
+/* It makes life somewhat easier if we don't actually store and maintain the
+ * input matrix in the textbook order.  Instead, we rotate the columns other
+ * than the leftmost one upwards, so that the constants which were originally
+ * along the diagonal end up on the top row.  We'll need to undo this
+ * permutation on output, but that's not too terrible an imposition.
+ *
+ * The permutation we're applying to the matrix elements is this:
+ *
+ * [  0  1  2  3 ]      [  0  5 10 15 ]
+ * [  4  5  6  7 ]  -->  [  4  9 14  3 ]
+ * [  8  9 10 11 ]      [  8 13  2  7 ]
+ * [ 12 13 14 15 ]      [ 12  1  6 11 ]
+ *
+ * and as a result, we need to apply this inverse permutation to figure out
+ * which indices to use in the doublerow function and elsewhere.
+ *
+ * [  0 13 10  7 ]
+ * [  4  1 14 11 ]
+ * [  8  5  2 15 ]
+ * [ 12  9  6  3 ]
+ */
+
 /* The Salsa20 quarter-round.  Read from the matrix @y@ at indices @a@, @b@,
  * @c@, and @d@; and write to the corresponding elements of @z@.
  */
  */
 #define SALSA20_DR(z, y) do {                                          \
   SALSA20_QR(z, y,  0,  4,  8, 12);                                    \
-  SALSA20_QR(z, y,  5,  9, 13,  1);                                    \
-  SALSA20_QR(z, y, 10, 14,  2,  6);                                    \
-  SALSA20_QR(z, y, 15,  3,  7, 11);                                    \
-  SALSA20_QR(z, z,  0,  1,  2,  3);                                    \
-  SALSA20_QR(z, z,  5,  6,  7,  4);                                    \
-  SALSA20_QR(z, z, 10, 11,  8,  9);                                    \
-  SALSA20_QR(z, z, 15, 12, 13, 14);                                    \
+  SALSA20_QR(z, y,  1,  5,  9, 13);                                    \
+  SALSA20_QR(z, y,  2,  6, 10, 14);                                    \
+  SALSA20_QR(z, y,  3,  7, 11, 15);                                    \
+  SALSA20_QR(z, z,  0, 13, 10,  7);                                    \
+  SALSA20_QR(z, z,  1, 14, 11,  4);                                    \
+  SALSA20_QR(z, z,  2, 15,  8,  5);                                    \
+  SALSA20_QR(z, z,  3, 12,  9,  6);                                    \
 } while (0)
 
 /* The Salsa20 feedforward step, used at the end of the core function.  Here,
  * @y@ contains the original input matrix; @z@ contains the final one, and is
- * updated.
+ * updated.  The output is rendered in canonical order, ready for output.
  */
 #define SALSA20_FFWD(z, y) do {                                                \
-  int _i;                                                              \
-  for (_i = 0; _i < 16; _i++) (z)[_i] += (y)[_i];                      \
+  const uint32 *_y = (y);                                              \
+  uint32 *_z = (z);                                                    \
+  int _t;                                                              \
+  _z[ 0] = _z[ 0] + _y[ 0]; _z[ 4] = _z[ 4] + _y[ 4];                  \
+  _z[ 8] = _z[ 8] + _y[ 8]; _z[12] = _z[12] + _y[12];                  \
+      _t = _z[ 1] + _y[ 1]; _z[ 1] = _z[13] + _y[13];                  \
+  _z[13] = _z[ 9] + _y[ 9]; _z[ 9] = _z[ 5] + _y[ 5]; _z[ 5] = _t;     \
+      _t = _z[ 2] + _y[ 2]; _z[ 2] = _z[10] + _y[10]; _z[10] = _t;     \
+      _t = _z[ 6] + _y[ 6]; _z[ 6] = _z[14] + _y[14]; _z[14] = _t;     \
+      _t = _z[ 3] + _y[ 3]; _z[ 3] = _z[ 7] + _y[ 7];                  \
+  _z[ 7] = _z[11] + _y[11]; _z[11] = _z[15] + _y[15]; _z[15] = _t;     \
 } while (0)
 
 /* Various numbers of rounds, unrolled.  Read from @y@, and write to @z@. */
 
 /* Step the counter in the Salsa20 state matrix @a@. */
 #define SALSA20_STEP(a)                                                        \
-  do { (a)[8] = U32((a)[8] + 1); (a)[9] += !(a)[8]; } while (0)
+  do { (a)[8] = U32((a)[8] + 1); (a)[5] += !(a)[8]; } while (0)
 
 /*----- Buffering and output ----------------------------------------------*
  *
index c12945e..0104e5e 100644 (file)
@@ -61,33 +61,42 @@ static void populate(salsa20_matrix a, const void *key, size_t ksz)
 
   KSZ_ASSERT(salsa20, ksz);
 
-  a[ 1] = LOAD32_L(k +  0);
-  a[ 2] = LOAD32_L(k +  4);
+  /* Here's the pattern of key, constant, nonce, and counter pieces in the
+   * matrix, before and after our permutation.
+   *
+   * [ C0 K0 K1 K2 ]      [ C0 C1 C2 C3 ]
+   * [ K3 C1 N0 N1 ]  -->  [ K3 T1 K7 K2 ]
+   * [ T0 T1 C2 K4 ]      [ T0 K6 K1 N1 ]
+   * [ K5 K6 K7 C3 ]      [ K5 K0 N0 K4 ]
+   */
+
+  a[13] = LOAD32_L(k +  0);
+  a[10] = LOAD32_L(k +  4);
   if (ksz == 10) {
-    a[ 3] = LOAD16_L(k +  8);
+    a[ 7] = LOAD16_L(k +  8);
     a[ 4] = 0;
   } else {
-    a[ 3] = LOAD32_L(k +  8);
+    a[ 7] = LOAD32_L(k +  8);
     a[ 4] = LOAD32_L(k + 12);
   }
   if (ksz <= 16) {
-    a[11] = a[ 1];
-    a[12] = a[ 2];
-    a[13] = a[ 3];
-    a[14] = a[ 4];
+    a[15] = a[13];
+    a[12] = a[10];
+    a[ 9] = a[ 7];
+    a[ 6] = a[ 4];
     a[ 0] = SALSA20_A128;
-    a[ 5] = SALSA20_B128;
-    a[10] = ksz == 10 ? SALSA20_C80 : SALSA20_C128;
-    a[15] = SALSA20_D128;
+    a[ 1] = SALSA20_B128;
+    a[ 2] = ksz == 10 ? SALSA20_C80 : SALSA20_C128;
+    a[ 3] = SALSA20_D128;
   } else {
-    a[11] = LOAD32_L(k + 16);
+    a[15] = LOAD32_L(k + 16);
     a[12] = LOAD32_L(k + 20);
-    a[13] = LOAD32_L(k + 24);
-    a[14] = LOAD32_L(k + 28);
+    a[ 9] = LOAD32_L(k + 24);
+    a[ 6] = LOAD32_L(k + 28);
     a[ 0] = SALSA20_A256;
-    a[ 5] = SALSA20_B256;
-    a[10] = SALSA20_C256;
-    a[15] = SALSA20_D256;
+    a[ 1] = SALSA20_B256;
+    a[ 2] = SALSA20_C256;
+    a[ 3] = SALSA20_D256;
   }
 }
 
@@ -130,8 +139,8 @@ void salsa20_setnonce(salsa20_ctx *ctx, const void *nonce)
 {
   const octet *n = nonce;
 
-  ctx->a[6] = LOAD32_L(n + 0);
-  ctx->a[7] = LOAD32_L(n + 4);
+  ctx->a[14] = LOAD32_L(n + 0);
+  ctx->a[11] = LOAD32_L(n + 4);
   salsa20_seek(ctx, 0);
 }
 
@@ -153,7 +162,7 @@ void salsa20_seek(salsa20_ctx *ctx, unsigned long i)
 
 void salsa20_seeku64(salsa20_ctx *ctx, kludge64 i)
 {
-  ctx->a[8] = LO64(i); ctx->a[9] = HI64(i);
+  ctx->a[8] = LO64(i); ctx->a[5] = HI64(i);
   ctx->bufi = SALSA20_OUTSZ;
 }
 
@@ -169,7 +178,7 @@ unsigned long salsa20_tell(salsa20_ctx *ctx)
   { kludge64 i = salsa20_tellu64(ctx); return (GET64(unsigned long, i)); }
 
 kludge64 salsa20_tellu64(salsa20_ctx *ctx)
-  { kludge64 i; SET64(i, ctx->a[9], ctx->a[8]); return (i); }
+  { kludge64 i; SET64(i, ctx->a[5], ctx->a[8]); return (i); }
 
 /* --- @salsa20{,12,8}_encrypt@ --- *
  *
@@ -272,10 +281,10 @@ SALSA20_VARS(DEFENCRYPT)
      * speed critical, so we do it the harder way.                     \
      */                                                                        \
                                                                        \
-    for (i = 0; i < 4; i++) k[i + 6] = src[i];                         \
+    for (i = 0; i < 4; i++) k[14 - 3*i] = src[i];                      \
     core(r, k, a);                                                     \
-    for (i = 0; i < 4; i++) dest[i] = a[5*i] - k[5*i];                 \
-    for (i = 4; i < 8; i++) dest[i] = a[i + 2] - k[i + 2];             \
+    for (i = 0; i < 4; i++) dest[i] = a[5*i] - k[i];                   \
+    for (i = 4; i < 8; i++) dest[i] = a[i + 2] - k[26 - 3*i];          \
   }                                                                    \
                                                                        \
   void HSALSA20_PRF(r, salsa20_ctx *ctx, const void *src, void *dest)  \
@@ -340,9 +349,9 @@ SALSA20_VARS(DEFHSALSA20)
                                                                        \
     populate(ctx->k, key, ksz);                                                \
     ctx->s.a[ 0] = SALSA20_A256;                                       \
-    ctx->s.a[ 5] = SALSA20_B256;                                       \
-    ctx->s.a[10] = SALSA20_C256;                                       \
-    ctx->s.a[15] = SALSA20_D256;                                       \
+    ctx->s.a[ 1] = SALSA20_B256;                                       \
+    ctx->s.a[ 2] = SALSA20_C256;                                       \
+    ctx->s.a[ 3] = SALSA20_D256;                                       \
     XSALSA20_SETNONCE(r, ctx, nonce ? nonce : zerononce);              \
   }
 SALSA20_VARS(DEFXINIT)
@@ -371,8 +380,8 @@ SALSA20_VARS(DEFXINIT)
                                                                        \
     for (i = 0; i < 4; i++) in[i] = LOAD32_L(n + 4*i);                 \
     HSALSA20_RAW(r, ctx->k, in, out);                                  \
-    for (i = 0; i < 4; i++) ctx->s.a[i + 1] = out[i];                  \
-    for (i = 4; i < 8; i++) ctx->s.a[i + 7] = out[i];                  \
+    for (i = 0; i < 4; i++) ctx->s.a[13 - 3*i] = out[i];               \
+    for (i = 4; i < 8; i++) ctx->s.a[27 - 3*i] = out[i];               \
     salsa20_setnonce(&ctx->s, n + 16);                                 \
   }
 SALSA20_VARS(DEFXNONCE)
@@ -730,23 +739,31 @@ SALSA20_VARS(DEFXGRAND)
 #include <mLib/quis.h>
 #include <mLib/testrig.h>
 
+static const int perm[] = {
+   0, 13, 10,  7,
+   4,  1, 14, 11,
+   8,  5,  2, 15,
+  12,  9,  6,  3
+};
+
 #define DEFVCORE(r)                                                    \
   static int v_core_##r(dstr *v)                                       \
   {                                                                    \
     salsa20_matrix a, b;                                               \
     dstr d = DSTR_INIT;                                                        \
-    int i, n;                                                          \
+    int i, j, n;                                                       \
     int ok = 1;                                                                \
                                                                        \
     DENSURE(&d, SALSA20_OUTSZ); d.len = SALSA20_OUTSZ;                 \
     n = *(int *)v[0].buf;                                              \
     for (i = 0; i < SALSA20_OUTSZ/4; i++)                              \
-      a[i] = LOAD32_L(v[1].buf + 4*i);                                 \
+      b[i] = LOAD32_L(v[1].buf + 4*i);                                 \
     for (i = 0; i < n; i++) {                                          \
+      for (j = 0; j < 16; j++) a[perm[j]] = b[j];                      \
       core(r, a, b);                                                   \
       memcpy(a, b, sizeof(a));                                         \
     }                                                                  \
-    for (i = 0; i < SALSA20_OUTSZ/4; i++) STORE32_L(d.buf + 4*i, a[i]);        \
+    for (i = 0; i < SALSA20_OUTSZ/4; i++) STORE32_L(d.buf + 4*i, b[i]);        \
                                                                        \
     if (d.len != v[2].len || memcmp(d.buf, v[2].buf, v[2].len) != 0) { \
       ok = 0;                                                          \