vppinfra: native AES-CTR implementation

Type: feature
Change-Id: I7ef3277edaeb266fbd3c8c9355d4443002ed2311
Signed-off-by: Damjan Marion <damarion@cisco.com>
This commit is contained in:
Damjan Marion
2024-01-08 19:05:40 +00:00
committed by Mohammed HAWARI
parent bf40da413f
commit 9caef2a351
7 changed files with 882 additions and 210 deletions

View File

@ -26,21 +26,18 @@
#endif
#if defined(__VAES__) && defined(__AVX512F__)
#define N 16
#define u8xN u8x64
#define u32xN u32x16
#define u32xN_min_scalar u32x16_min_scalar
#define u32xN_is_all_zero u32x16_is_all_zero
#define u32xN_splat u32x16_splat
#elif defined(__VAES__)
#define N 8
#define u8xN u8x32
#define u32xN u32x8
#define u32xN_min_scalar u32x8_min_scalar
#define u32xN_is_all_zero u32x8_is_all_zero
#define u32xN_splat u32x8_splat
#else
#define N 4
#define u8xN u8x16
#define u32xN u32x4
#define u32xN_min_scalar u32x4_min_scalar
@ -58,17 +55,17 @@ aes_ops_enc_aes_cbc (vlib_main_t * vm, vnet_crypto_op_t * ops[],
u32 i, j, count, n_left = n_ops;
u32xN placeholder_mask = { };
u32xN len = { };
vnet_crypto_key_index_t key_index[N];
u8 *src[N] = { };
u8 *dst[N] = { };
vnet_crypto_key_index_t key_index[N_AES_BYTES];
u8 *src[N_AES_BYTES] = {};
u8 *dst[N_AES_BYTES] = {};
u8xN r[4] = {};
u8xN k[15][4] = {};
for (i = 0; i < N; i++)
for (i = 0; i < N_AES_BYTES; i++)
key_index[i] = ~0;
more:
for (i = 0; i < N; i++)
for (i = 0; i < N_AES_BYTES; i++)
if (len[i] == 0)
{
if (n_left == 0)
@ -160,16 +157,16 @@ more:
for (j = 1; j < rounds; j++)
{
r[0] = aes_enc_round (r[0], k[j][0]);
r[1] = aes_enc_round (r[1], k[j][1]);
r[2] = aes_enc_round (r[2], k[j][2]);
r[3] = aes_enc_round (r[3], k[j][3]);
r[0] = aes_enc_round_x1 (r[0], k[j][0]);
r[1] = aes_enc_round_x1 (r[1], k[j][1]);
r[2] = aes_enc_round_x1 (r[2], k[j][2]);
r[3] = aes_enc_round_x1 (r[3], k[j][3]);
}
r[0] = aes_enc_last_round (r[0], k[j][0]);
r[1] = aes_enc_last_round (r[1], k[j][1]);
r[2] = aes_enc_last_round (r[2], k[j][2]);
r[3] = aes_enc_last_round (r[3], k[j][3]);
r[0] = aes_enc_last_round_x1 (r[0], k[j][0]);
r[1] = aes_enc_last_round_x1 (r[1], k[j][1]);
r[2] = aes_enc_last_round_x1 (r[2], k[j][2]);
r[3] = aes_enc_last_round_x1 (r[3], k[j][3]);
aes_block_store (dst[0] + i, r[0]);
aes_block_store (dst[1] + i, r[1]);
@ -201,7 +198,7 @@ more:
len -= u32xN_splat (count);
for (i = 0; i < N; i++)
for (i = 0; i < N_AES_BYTES; i++)
{
src[i] += count;
dst[i] += count;

View File

@ -132,6 +132,7 @@ set(VPPINFRA_HEADERS
crypto/ghash.h
crypto/aes.h
crypto/aes_cbc.h
crypto/aes_ctr.h
crypto/aes_gcm.h
crypto/poly1305.h
dlist.h
@ -285,6 +286,7 @@ endif(VPP_BUILD_VPPINFRA_TESTS)
set(test_files
test/aes_cbc.c
test/aes_ctr.c
test/aes_gcm.c
test/poly1305.c
test/array_mask.c

View File

@ -15,8 +15,8 @@
*------------------------------------------------------------------
*/
#ifndef __aesni_h__
#define __aesni_h__
#ifndef __aes_h__
#define __aes_h__
typedef enum
{
@ -35,7 +35,7 @@ aes_block_load (u8 * p)
}
static_always_inline u8x16
aes_enc_round (u8x16 a, u8x16 k)
aes_enc_round_x1 (u8x16 a, u8x16 k)
{
#if defined (__AES__)
return (u8x16) _mm_aesenc_si128 ((__m128i) a, (__m128i) k);
@ -97,7 +97,7 @@ aes_dec_last_round_x2 (u8x32 a, u8x32 k)
#endif
static_always_inline u8x16
aes_enc_last_round (u8x16 a, u8x16 k)
aes_enc_last_round_x1 (u8x16 a, u8x16 k)
{
#if defined (__AES__)
return (u8x16) _mm_aesenclast_si128 ((__m128i) a, (__m128i) k);
@ -109,13 +109,13 @@ aes_enc_last_round (u8x16 a, u8x16 k)
#ifdef __x86_64__
static_always_inline u8x16
aes_dec_round (u8x16 a, u8x16 k)
aes_dec_round_x1 (u8x16 a, u8x16 k)
{
return (u8x16) _mm_aesdec_si128 ((__m128i) a, (__m128i) k);
}
static_always_inline u8x16
aes_dec_last_round (u8x16 a, u8x16 k)
aes_dec_last_round_x1 (u8x16 a, u8x16 k)
{
return (u8x16) _mm_aesdeclast_si128 ((__m128i) a, (__m128i) k);
}
@ -133,8 +133,8 @@ aes_encrypt_block (u8x16 block, const u8x16 * round_keys, aes_key_size_t ks)
int rounds = AES_KEY_ROUNDS (ks);
block ^= round_keys[0];
for (int i = 1; i < rounds; i += 1)
block = aes_enc_round (block, round_keys[i]);
return aes_enc_last_round (block, round_keys[rounds]);
block = aes_enc_round_x1 (block, round_keys[i]);
return aes_enc_last_round_x1 (block, round_keys[rounds]);
}
static_always_inline u8x16
@ -427,13 +427,67 @@ aes_key_enc_to_dec (u8x16 * ke, u8x16 * kd, aes_key_size_t ks)
kd[rounds / 2] = aes_inv_mix_column (ke[rounds / 2]);
}
/*
 * Compile-time selection of the vector width used by the generic AES
 * helpers:
 *   N_AES_LANES        number of 16-byte lanes held in one aes_data_t
 *   aes_load_partial   load the first n bytes at p into a vector
 *   aes_store_partial  store the first n bytes of vector v to p
 *   aes_reflect        reflect helper matching the selected width
 *   aes_data_t         vector type used for data held in variables
 *   aes_mem_t          "u"-suffixed variant used for accesses through
 *                      src/dst pointers
 *   aes_counter_t      same width, viewed as u32 elements
 */
#if defined(__VAES__) && defined(__AVX512F__)
/* VAES with AVX-512: 64-byte vectors, four lanes */
#define N_AES_LANES 4
#define aes_load_partial(p, n) u8x64_load_partial ((u8 *) (p), n)
#define aes_store_partial(v, p, n) u8x64_store_partial (v, (u8 *) (p), n)
#define aes_reflect(r) u8x64_reflect_u8x16 (r)
typedef u8x64 aes_data_t;
typedef u8x64u aes_mem_t;
typedef u32x16 aes_counter_t;
#elif defined(__VAES__)
/* VAES without AVX-512: 32-byte vectors, two lanes */
#define N_AES_LANES 2
#define aes_load_partial(p, n) u8x32_load_partial ((u8 *) (p), n)
#define aes_store_partial(v, p, n) u8x32_store_partial (v, (u8 *) (p), n)
#define aes_reflect(r) u8x32_reflect_u8x16 (r)
typedef u8x32 aes_data_t;
typedef u8x32u aes_mem_t;
typedef u32x8 aes_counter_t;
#else
/* fallback: 16-byte vectors, a single lane */
#define N_AES_LANES 1
#define aes_load_partial(p, n) u8x16_load_partial ((u8 *) (p), n)
#define aes_store_partial(v, p, n) u8x16_store_partial (v, (u8 *) (p), n)
#define aes_reflect(r) u8x16_reflect (r)
typedef u8x16 aes_data_t;
typedef u8x16u aes_mem_t;
typedef u32x4 aes_counter_t;
#endif
#endif /* __aesni_h__ */
#define N_AES_BYTES (N_AES_LANES * 16)
/*
* fd.io coding-style-patch-verification: ON
*
* Local Variables:
* eval: (c-set-style "gnu")
* End:
*/
/*
 * One AES round key, viewable at every supported vector width.  All members
 * overlay the same storage; aes_enc_round () / aes_enc_last_round () pick
 * the member matching N_AES_LANES.
 */
typedef union
{
/* 16-byte view, used when N_AES_LANES == 1 */
u8x16 x1;
/* 32-byte view, used when N_AES_LANES == 2 */
u8x32 x2;
/* 64-byte view, used when N_AES_LANES == 4 */
u8x64 x4;
/* the same storage addressed as four individual 16-byte lanes */
u8x16 lanes[4];
} aes_expaned_key_t;
/*
 * Run one (non-final) AES encryption round, in place, over the first
 * n_blocks vectors of r, using round key ek.  The intrinsic wrapper and the
 * round-key view are chosen at compile time from N_AES_LANES.
 */
static_always_inline void
aes_enc_round (aes_data_t *r, const aes_expaned_key_t *ek, uword n_blocks)
{
  for (uword blk = 0; blk < n_blocks; blk++)
    {
      aes_data_t state = r[blk];

#if N_AES_LANES == 4
      state = aes_enc_round_x4 (state, ek->x4);
#elif N_AES_LANES == 2
      state = aes_enc_round_x2 (state, ek->x2);
#else
      state = aes_enc_round_x1 (state, ek->x1);
#endif

      r[blk] = state;
    }
}
/*
 * Run the final AES encryption round over the first n_blocks vectors of r
 * using round key ek.  Each result is written back to r[] and also XORed
 * into the matching vector of d[].
 */
static_always_inline void
aes_enc_last_round (aes_data_t *r, aes_data_t *d, const aes_expaned_key_t *ek,
		    uword n_blocks)
{
  for (uword blk = 0; blk < n_blocks; blk++)
    {
      aes_data_t out;

#if N_AES_LANES == 4
      out = aes_enc_last_round_x4 (r[blk], ek->x4);
#elif N_AES_LANES == 2
      out = aes_enc_last_round_x2 (r[blk], ek->x2);
#else
      out = aes_enc_last_round_x1 (r[blk], ek->x1);
#endif

      r[blk] = out;
      d[blk] ^= out;
    }
}
#endif /* __aes_h__ */

View File

@ -30,8 +30,8 @@ clib_aes_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *src, uword len,
#if __x86_64__
r = u8x16_xor3 (r, *(u8x16u *) (src + i), k[0]);
for (j = 1; j < rounds; j++)
r = aes_enc_round (r, k[j]);
r = aes_enc_last_round (r, k[rounds]);
r = aes_enc_round_x1 (r, k[j]);
r = aes_enc_last_round_x1 (r, k[rounds]);
#else
r ^= *(u8x16u *) (src + i);
for (j = 1; j < rounds - 1; j++)
@ -85,16 +85,16 @@ aes_cbc_dec (const u8x16 *k, u8x16u *src, u8x16u *dst, u8x16u *iv, int count,
for (int i = 1; i < rounds; i++)
{
r[0] = aes_dec_round (r[0], k[i]);
r[1] = aes_dec_round (r[1], k[i]);
r[2] = aes_dec_round (r[2], k[i]);
r[3] = aes_dec_round (r[3], k[i]);
r[0] = aes_dec_round_x1 (r[0], k[i]);
r[1] = aes_dec_round_x1 (r[1], k[i]);
r[2] = aes_dec_round_x1 (r[2], k[i]);
r[3] = aes_dec_round_x1 (r[3], k[i]);
}
r[0] = aes_dec_last_round (r[0], k[rounds]);
r[1] = aes_dec_last_round (r[1], k[rounds]);
r[2] = aes_dec_last_round (r[2], k[rounds]);
r[3] = aes_dec_last_round (r[3], k[rounds]);
r[0] = aes_dec_last_round_x1 (r[0], k[rounds]);
r[1] = aes_dec_last_round_x1 (r[1], k[rounds]);
r[2] = aes_dec_last_round_x1 (r[2], k[rounds]);
r[3] = aes_dec_last_round_x1 (r[3], k[rounds]);
#else
for (int i = 0; i < rounds - 1; i++)
{
@ -125,8 +125,8 @@ aes_cbc_dec (const u8x16 *k, u8x16u *src, u8x16u *dst, u8x16u *iv, int count,
#if __x86_64__
r[0] ^= k[0];
for (int i = 1; i < rounds; i++)
r[0] = aes_dec_round (r[0], k[i]);
r[0] = aes_dec_last_round (r[0], k[rounds]);
r[0] = aes_dec_round_x1 (r[0], k[i]);
r[0] = aes_dec_last_round_x1 (r[0], k[rounds]);
#else
c[0] = r[0] = src[0];
for (int i = 0; i < rounds - 1; i++)
@ -469,8 +469,8 @@ aes2_cbc_dec (const u8x16 *k, u8x32u *src, u8x32u *dst, u8x16u *iv, int count,
{
u8x16 rl = *(u8x16u *) src ^ k[0];
for (i = 1; i < rounds; i++)
rl = aes_dec_round (rl, k[i]);
rl = aes_dec_last_round (rl, k[i]);
rl = aes_dec_round_x1 (rl, k[i]);
rl = aes_dec_last_round_x1 (rl, k[i]);
*(u8x16u *) dst = rl ^ u8x32_extract_hi (f);
}
}

View File

@ -0,0 +1,190 @@
/* SPDX-License-Identifier: Apache-2.0
* Copyright(c) 2024 Cisco Systems, Inc.
*/
#ifndef __crypto_aes_ctr_h__
#define __crypto_aes_ctr_h__
#include <vppinfra/clib.h>
#include <vppinfra/vector.h>
#include <vppinfra/cache.h>
#include <vppinfra/string.h>
#include <vppinfra/crypto/aes.h>
/*
 * Expanded AES key for CTR mode, filled in by clib_aes_ctr_key_expand ().
 * The array is sized for AES-256 (AES_KEY_ROUNDS (AES_KEY_256) + 1 entries);
 * for other key sizes only the first AES_KEY_ROUNDS (ks) + 1 entries are
 * written.
 *
 * NOTE(review): exp_key is declared const, yet clib_aes_ctr_key_expand ()
 * writes it through a cast that drops the qualifier - confirm this is
 * intended.
 */
typedef struct
{
const aes_expaned_key_t exp_key[AES_KEY_ROUNDS (AES_KEY_256) + 1];
} aes_ctr_key_data_t;
/*
 * Running state of one AES-CTR stream.  Set up by clib_aes_ctr_init () and
 * advanced by clib_aes_ctr_transform (), which may be called repeatedly on
 * consecutive pieces of the stream: keystream generated but not consumed by
 * one call is kept in keystream_bytes and used first by the next call.
 *
 * NOTE(review): exp_key is declared const, yet clib_aes_ctr_init () writes
 * it through a cast that drops the qualifier - confirm this is intended.
 */
typedef struct
{
/* private copy of the round keys taken from aes_ctr_key_data_t */
const aes_expaned_key_t exp_key[AES_KEY_ROUNDS (AES_KEY_256) + 1];
aes_counter_t ctr; /* counter (reflected) */
/* unused bytes are the last n_keystream_bytes bytes of this array */
u8 keystream_bytes[N_AES_BYTES]; /* keystream leftovers */
u32 n_keystream_bytes; /* number of keystream leftovers */
} aes_ctr_ctx_t;
/*
 * Process n_parallel (at most 4) vectors of N_AES_BYTES each: generate the
 * keystream for them from ctr, XOR it with the data read from src and write
 * the result to dst.
 *
 * ctx         supplies the round keys; when last is set it also receives the
 *             keystream of the final vector and the count of its unused
 *             bytes
 * ctr         counter for the first vector
 * n_bytes     total number of bytes covered by this call
 * rounds      number of AES rounds for the key size in use
 * last        non-zero when the final vector is loaded/stored with the
 *             partial helpers, using only the bytes that remain for it
 *
 * Returns the counter advanced by one step per vector processed.
 */
static_always_inline aes_counter_t
aes_ctr_one_block (aes_ctr_ctx_t *ctx, aes_counter_t ctr, const u8 *src,
		   u8 *dst, u32 n_parallel, u32 n_bytes, int rounds, int last)
{
  /* counter step: first u32 of every 16-byte lane grows by N_AES_LANES */
  u32 __clib_aligned (N_AES_BYTES)
  inc[] = { N_AES_LANES, 0, 0, 0, N_AES_LANES, 0, 0, 0,
	    N_AES_LANES, 0, 0, 0, N_AES_LANES, 0, 0, 0 };
  const aes_expaned_key_t *round_key = ctx->exp_key;
  const aes_mem_t *in = (aes_mem_t *) src;
  aes_mem_t *out = (aes_mem_t *) dst;
  const u32 n_whole = n_parallel - last; /* vectors moved in full */
  const u32 tail = n_parallel - 1;	 /* index of the final vector */
  aes_data_t data[4], stream[4];
  u32 rnd, v;

  /* from here on n_bytes counts the bytes of the final vector only */
  n_bytes -= tail * N_AES_BYTES;

  /* first round: whiten the reflected counter with round key 0 */
  for (v = 0; v < n_parallel; v++)
    {
#if N_AES_LANES == 4
      stream[v] = round_key[0].x4 ^ (u8x64) aes_reflect ((u8x64) ctr);
#elif N_AES_LANES == 2
      stream[v] = round_key[0].x2 ^ (u8x32) aes_reflect ((u8x32) ctr);
#else
      stream[v] = round_key[0].x1 ^ (u8x16) aes_reflect ((u8x16) ctr);
#endif
      ctr += *(aes_counter_t *) inc;
    }

  /* fetch all input before anything is written to dst */
  for (v = 0; v < n_whole; v++)
    data[v] = in[v];

  if (last)
    data[tail] = aes_load_partial ((u8 *) (in + tail), n_bytes);

  /* intermediate rounds */
  rnd = 1;
  while (rnd < rounds)
    {
      aes_enc_round (stream, round_key + rnd, n_parallel);
      rnd++;
    }

  /* last round: stream[] becomes the keystream and is XORed into data[] */
  aes_enc_last_round (stream, data, round_key + rnd, n_parallel);

  /* write results */
  for (v = 0; v < n_whole; v++)
    out[v] = data[v];

  if (!last)
    return ctr;

  aes_store_partial (data[tail], out + tail, n_bytes);

  /* keep the final keystream vector; its unused bytes serve the next call */
  *(aes_data_t *) ctx->keystream_bytes = stream[tail];
  ctx->n_keystream_bytes = N_AES_BYTES - n_bytes;

  return ctr;
}
/*
 * Prepare ctx for a new AES-CTR stream.
 *
 * kd   expanded key produced by clib_aes_ctr_key_expand ()
 * iv   16-byte initial counter block
 * ks   key size; selects how many round keys are copied into ctx
 *
 * The counter is stored reflected.  On multi-lane builds the 16-byte lane j
 * starts at the initial value plus j in its first u32 element.
 */
static_always_inline void
clib_aes_ctr_init (aes_ctr_ctx_t *ctx, const aes_ctr_key_data_t *kd,
		   const u8 *iv, aes_key_size_t ks)
{
  aes_expaned_key_t *ctx_keys = (aes_expaned_key_t *) ctx->exp_key;
  const int n_keys = AES_KEY_ROUNDS (ks) + 1;
  u32x4 first = (u32x4) u8x16_reflect (*(u8x16u *) iv);

  /* no keystream is left over from an earlier call */
  ctx->n_keystream_bytes = 0;

#if N_AES_LANES == 4
  ctx->ctr = (aes_counter_t) u32x16_splat_u32x4 (first) +
	     (u32x16){ 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0 };
#elif N_AES_LANES == 2
  ctx->ctr = (aes_counter_t) u32x8_splat_u32x4 (first) +
	     (u32x8){ 0, 0, 0, 0, 1, 0, 0, 0 };
#else
  ctx->ctr = first;
#endif

  /* private copy of the round keys actually used for this key size */
  for (int idx = 0; idx < n_keys; idx++)
    ctx_keys[idx] = kd->exp_key[idx];
}
/*
 * XOR n_bytes of src with the AES-CTR keystream of ctx and write the result
 * to dst.  May be called repeatedly on consecutive pieces of a stream:
 * keystream bytes left unused by the previous call are consumed first, and
 * any left unused by this call are saved in ctx.
 *
 * ctx      stream state set up by clib_aes_ctr_init ()
 * ks       key size, selects the number of AES rounds
 */
static_always_inline void
clib_aes_ctr_transform (aes_ctr_ctx_t *ctx, const u8 *src, u8 *dst,
			u32 n_bytes, aes_key_size_t ks)
{
  const int rounds = AES_KEY_ROUNDS (ks);
  const u32 stride = 4 * N_AES_BYTES;
  const u32 avail = ctx->n_keystream_bytes;
  aes_counter_t ctr = ctx->ctr;

  if (avail)
    {
      /* leftovers occupy the tail of the saved keystream vector */
      const u8 *leftover = ctx->keystream_bytes + N_AES_BYTES - avail;
      const u32 take = avail < n_bytes ? avail : n_bytes;

      for (u32 i = 0; i < take; i++)
	dst[i] = src[i] ^ leftover[i];

      ctx->n_keystream_bytes = avail - take;

      /* request fully served from leftovers, counter is untouched */
      if (take == n_bytes)
	return;

      src += take;
      dst += take;
      n_bytes -= take;
    }

  /* main loop: four full vectors per iteration */
  while (n_bytes >= stride)
    {
      ctr = aes_ctr_one_block (ctx, ctr, src, dst, 4, stride, rounds, 0);
      src += stride;
      dst += stride;
      n_bytes -= stride;
    }

  if (n_bytes == 0)
    ctx->n_keystream_bytes = 0;
  else
    /* vector count is passed as a literal so each call site stays constant */
    switch ((n_bytes - 1) / N_AES_BYTES)
      {
      case 3:
	ctr = aes_ctr_one_block (ctx, ctr, src, dst, 4, n_bytes, rounds, 1);
	break;
      case 2:
	ctr = aes_ctr_one_block (ctx, ctr, src, dst, 3, n_bytes, rounds, 1);
	break;
      case 1:
	ctr = aes_ctr_one_block (ctx, ctr, src, dst, 2, n_bytes, rounds, 1);
	break;
      default:
	ctr = aes_ctr_one_block (ctx, ctr, src, dst, 1, n_bytes, rounds, 1);
	break;
      }

  ctx->ctr = ctr;
}
/*
 * Expand the raw AES key into kd.  Each round key produced by
 * aes_key_expand () is replicated into all four 16-byte lanes of the
 * corresponding aes_expaned_key_t entry.
 *
 * kd    destination key data
 * key   raw key bytes
 * ks    key size; AES_KEY_ROUNDS (ks) + 1 entries of kd are written
 */
static_always_inline void
clib_aes_ctr_key_expand (aes_ctr_key_data_t *kd, const u8 *key,
			 aes_key_size_t ks)
{
  u8x16 round_keys[AES_KEY_ROUNDS (AES_KEY_256) + 1];
  aes_expaned_key_t *out = (aes_expaned_key_t *) kd->exp_key;
  const int n_keys = AES_KEY_ROUNDS (ks) + 1;

  /* expand AES key */
  aes_key_expand (round_keys, key, ks);

  /* broadcast every round key to all lanes */
  for (int idx = 0; idx < n_keys; idx++)
    {
      for (int lane = 0; lane < 4; lane++)
	out[idx].lanes[lane] = round_keys[idx];
    }
}
/*
 * One-shot AES-128-CTR: initialise a temporary context from kd and iv, then
 * transform n_bytes from src into dst.
 */
static_always_inline void
clib_aes128_ctr (const aes_ctr_key_data_t *kd, const u8 *src, u32 n_bytes,
		 const u8 *iv, u8 *dst)
{
  const aes_key_size_t key_size = AES_KEY_128;
  aes_ctr_ctx_t state;

  clib_aes_ctr_init (&state, kd, iv, key_size);
  clib_aes_ctr_transform (&state, src, dst, n_bytes, key_size);
}
/*
 * One-shot AES-192-CTR: initialise a temporary context from kd and iv, then
 * transform n_bytes from src into dst.
 */
static_always_inline void
clib_aes192_ctr (const aes_ctr_key_data_t *kd, const u8 *src, u32 n_bytes,
		 const u8 *iv, u8 *dst)
{
  const aes_key_size_t key_size = AES_KEY_192;
  aes_ctr_ctx_t state;

  clib_aes_ctr_init (&state, kd, iv, key_size);
  clib_aes_ctr_transform (&state, src, dst, n_bytes, key_size);
}
/*
 * One-shot AES-256-CTR: initialise a temporary context from kd and iv, then
 * transform n_bytes from src into dst.
 */
static_always_inline void
clib_aes256_ctr (const aes_ctr_key_data_t *kd, const u8 *src, u32 n_bytes,
		 const u8 *iv, u8 *dst)
{
  const aes_key_size_t key_size = AES_KEY_256;
  aes_ctr_ctx_t state;

  clib_aes_ctr_init (&state, kd, iv, key_size);
  clib_aes_ctr_transform (&state, src, dst, n_bytes, key_size);
}
#endif /* __crypto_aes_ctr_h__ */

File diff suppressed because it is too large Load Diff

481
src/vppinfra/test/aes_ctr.c Normal file

File diff suppressed because it is too large Load Diff