[v3,4/4] hash: add SVE support for bulk key lookup

Message ID 20231107121845.2758454-5-yoan.picchi@arm.com (mailing list archive)
State Superseded, archived
Delegated to: Thomas Monjalon
Series: hash: add SVE support for bulk key lookup

Checks

Context Check Description
ci/checkpatch success coding style OK
ci/Intel-compilation warning apply issues

Commit Message

Yoan Picchi Nov. 7, 2023, 12:18 p.m. UTC
- Implemented SVE code for comparing signatures in bulk lookup
  (a scalar sketch of the dense hitmask layout follows below).
- Added defines for SVE code support.
- Optimised the NEON code path.
- The new SVE code is ~5% slower than the optimised NEON on an N2 processor.
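
For reference, the dense hitmask layout written by compare_signatures_dense() can be
illustrated with a plain C sketch equivalent to the scalar fallback in this patch:
primary-bucket matches set the low RTE_HASH_BUCKET_ENTRIES bits of a 16-bit hitmask,
and secondary-bucket matches set the next RTE_HASH_BUCKET_ENTRIES bits. The helper name
and the fallback define below are illustrative only, not part of the patch.

#include <stdint.h>

#ifndef RTE_HASH_BUCKET_ENTRIES
#define RTE_HASH_BUCKET_ENTRIES 8	/* value used by rte_cuckoo_hash.h */
#endif

/* Illustrative scalar equivalent of the dense hitmask packing used by
 * compare_signatures_dense(): one bit per bucket entry, primary bucket
 * in the low bits, secondary bucket in the high bits.
 */
static inline uint16_t
dense_hitmask_scalar(const uint16_t *prim_sigs, const uint16_t *sec_sigs,
		uint16_t sig)
{
	uint16_t hitmask = 0;

	for (unsigned int i = 0; i < RTE_HASH_BUCKET_ENTRIES; i++) {
		/* bit i: signature match in entry i of the primary bucket */
		hitmask |= (uint16_t)((sig == prim_sigs[i]) << i);
		/* bit (i + RTE_HASH_BUCKET_ENTRIES): match in the secondary bucket */
		hitmask |= (uint16_t)((sig == sec_sigs[i]) << (i + RTE_HASH_BUCKET_ENTRIES));
	}
	return hitmask;
}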

Signed-off-by: Yoan Picchi <yoan.picchi@arm.com>
Signed-off-by: Harjot Singh <harjot.singh@arm.com>
Reviewed-by: Nathan Brown <nathan.brown@arm.com>
Reviewed-by: Ruifeng Wang <ruifeng.wang@arm.com>
---
 lib/hash/rte_cuckoo_hash.c | 196 ++++++++++++++++++++++++++++---------
 lib/hash/rte_cuckoo_hash.h |   1 +
 2 files changed, 151 insertions(+), 46 deletions(-)
  

Patch

diff --git a/lib/hash/rte_cuckoo_hash.c b/lib/hash/rte_cuckoo_hash.c
index a4b907c45c..61637d02eb 100644
--- a/lib/hash/rte_cuckoo_hash.c
+++ b/lib/hash/rte_cuckoo_hash.c
@@ -435,8 +435,11 @@  rte_hash_create(const struct rte_hash_parameters *params)
 		h->sig_cmp_fn = RTE_HASH_COMPARE_SSE;
 	else
 #elif defined(RTE_ARCH_ARM64)
-	if (rte_cpu_get_flag_enabled(RTE_CPUFLAG_NEON))
+	if (rte_cpu_get_flag_enabled(RTE_CPUFLAG_NEON)) {
 		h->sig_cmp_fn = RTE_HASH_COMPARE_NEON;
+		if (rte_cpu_get_flag_enabled(RTE_CPUFLAG_SVE))
+			h->sig_cmp_fn = RTE_HASH_COMPARE_SVE;
+	}
 	else
 #endif
 		h->sig_cmp_fn = RTE_HASH_COMPARE_SCALAR;
@@ -1853,37 +1856,103 @@  rte_hash_free_key_with_position(const struct rte_hash *h,
 #if defined(__ARM_NEON)
 
 static inline void
-compare_signatures_dense(uint32_t *prim_hash_matches, uint32_t *sec_hash_matches,
-			const struct rte_hash_bucket *prim_bkt,
-			const struct rte_hash_bucket *sec_bkt,
+compare_signatures_dense(uint16_t *hitmask_buffer,
+			const uint16_t *prim_bucket_sigs,
+			const uint16_t *sec_bucket_sigs,
 			uint16_t sig,
 			enum rte_hash_sig_compare_function sig_cmp_fn)
 {
 	unsigned int i;
 
+	static_assert(sizeof(*hitmask_buffer) >= 2*(RTE_HASH_BUCKET_ENTRIES/8),
+	"The hitmask buffer must be wide enough to hold a dense hitmask for both buckets");
+
 	/* For match mask every bits indicates the match */
 	switch (sig_cmp_fn) {
+#if RTE_HASH_BUCKET_ENTRIES <= 8
 	case RTE_HASH_COMPARE_NEON: {
-		uint16x8_t vmat, x;
+		uint16x8_t vmat, hit1, hit2;
 		const uint16x8_t mask = {0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80};
 		const uint16x8_t vsig = vld1q_dup_u16((uint16_t const *)&sig);
 
 		/* Compare all signatures in the primary bucket */
-		vmat = vceqq_u16(vsig, vld1q_u16((uint16_t const *)prim_bkt->sig_current));
-		x = vandq_u16(vmat, mask);
-		*prim_hash_matches = (uint32_t)(vaddvq_u16(x));
+		vmat = vceqq_u16(vsig, vld1q_u16(prim_bucket_sigs));
+		hit1 = vandq_u16(vmat, mask);
+
 		/* Compare all signatures in the secondary bucket */
-		vmat = vceqq_u16(vsig, vld1q_u16((uint16_t const *)sec_bkt->sig_current));
-		x = vandq_u16(vmat, mask);
-		*sec_hash_matches = (uint32_t)(vaddvq_u16(x));
+		vmat = vceqq_u16(vsig, vld1q_u16(sec_bucket_sigs));
+		hit2 = vandq_u16(vmat, mask);
+
+		hit2 = vshlq_n_u16(hit2, RTE_HASH_BUCKET_ENTRIES);
+		hit2 = vorrq_u16(hit1, hit2);
+		*hitmask_buffer = vaddvq_u16(hit2);
+		}
+		break;
+#endif
+#if defined(RTE_HAS_SVE_ACLE)
+	case RTE_HASH_COMPARE_SVE: {
+		svuint16_t vsign, shift, sv_matches;
+		svbool_t pred, match, bucket_wide_pred;
+		int i = 0;
+		uint64_t vl = svcnth();
+
+		vsign = svdup_u16(sig);
+		shift = svindex_u16(0, 1);
+
+		if (vl >= 2 * RTE_HASH_BUCKET_ENTRIES && RTE_HASH_BUCKET_ENTRIES <= 8) {
+			svuint16_t primary_array_vect, secondary_array_vect;
+			bucket_wide_pred = svwhilelt_b16(0, RTE_HASH_BUCKET_ENTRIES);
+			primary_array_vect = svld1_u16(bucket_wide_pred, prim_bucket_sigs);
+			secondary_array_vect = svld1_u16(bucket_wide_pred, sec_bucket_sigs);
+
+			/* We merged the two vectors so we can do both comparisons at once */
+			primary_array_vect = svsplice_u16(bucket_wide_pred,
+				primary_array_vect,
+				secondary_array_vect);
+			pred = svwhilelt_b16(0, 2*RTE_HASH_BUCKET_ENTRIES);
+
+			/* Compare all signatures in the buckets */
+			match = svcmpeq_u16(pred, vsign, primary_array_vect);
+			if (svptest_any(svptrue_b16(), match)) {
+				sv_matches = svdup_u16(1);
+				sv_matches = svlsl_u16_z(match, sv_matches, shift);
+				*hitmask_buffer = svorv_u16(svptrue_b16(), sv_matches);
+			}
+		} else {
+			do {
+				pred = svwhilelt_b16(i, RTE_HASH_BUCKET_ENTRIES);
+				uint16_t lower_half = 0;
+				uint16_t upper_half = 0;
+				/* Compare all signatures in the primary bucket */
+				match = svcmpeq_u16(pred, vsign, svld1_u16(pred,
+							&prim_bucket_sigs[i]));
+				if (svptest_any(svptrue_b16(), match)) {
+					sv_matches = svdup_u16(1);
+					sv_matches = svlsl_u16_z(match, sv_matches, shift);
+					lower_half = svorv_u16(svptrue_b16(), sv_matches);
+				}
+				/* Compare all signatures in the secondary bucket */
+				match = svcmpeq_u16(pred, vsign, svld1_u16(pred,
+							&sec_bucket_sigs[i]));
+				if (svptest_any(svptrue_b16(), match)) {
+					sv_matches = svdup_u16(1);
+					sv_matches = svlsl_u16_z(match, sv_matches, shift);
+					upper_half = svorv_u16(svptrue_b16(), sv_matches)
+						<< RTE_HASH_BUCKET_ENTRIES;
+				}
+				hitmask_buffer[i/8] = upper_half | lower_half;
+				i += vl;
+			} while (i < RTE_HASH_BUCKET_ENTRIES);
+		}
 		}
 		break;
+#endif
 	default:
 		for (i = 0; i < RTE_HASH_BUCKET_ENTRIES; i++) {
-			*prim_hash_matches |=
-				((sig == prim_bkt->sig_current[i]) << i);
-			*sec_hash_matches |=
-				((sig == sec_bkt->sig_current[i]) << i);
+			*hitmask_buffer |=
+				((sig == prim_bucket_sigs[i]) << i);
+			*hitmask_buffer |=
+				((sig == sec_bucket_sigs[i]) << i) << RTE_HASH_BUCKET_ENTRIES;
 		}
 	}
 }
@@ -1901,7 +1970,7 @@  compare_signatures_sparse(uint32_t *prim_hash_matches, uint32_t *sec_hash_matche
 
 	/* For match mask the first bit of every two bits indicates the match */
 	switch (sig_cmp_fn) {
-#if defined(__SSE2__)
+#if defined(__SSE2__) && RTE_HASH_BUCKET_ENTRIES <= 8
 	case RTE_HASH_COMPARE_SSE:
 		/* Compare all signatures in the bucket */
 		*prim_hash_matches = _mm_movemask_epi8(_mm_cmpeq_epi16(
@@ -1941,14 +2010,18 @@  __bulk_lookup_l(const struct rte_hash *h, const void **keys,
 	uint64_t hits = 0;
 	int32_t i;
 	int32_t ret;
-	uint32_t prim_hitmask[RTE_HASH_LOOKUP_BULK_MAX] = {0};
-	uint32_t sec_hitmask[RTE_HASH_LOOKUP_BULK_MAX] = {0};
 	struct rte_hash_bucket *cur_bkt, *next_bkt;
 
 #if defined(__ARM_NEON)
 	const int hitmask_padding = 0;
+	uint16_t hitmask_buffer[RTE_HASH_LOOKUP_BULK_MAX] = {0};
+
+	static_assert(sizeof(*hitmask_buffer)*8/2 == RTE_HASH_BUCKET_ENTRIES,
+	"The hitmask must be exactly wide enough to accept the whole hitmask when it is dense");
 #else
 	const int hitmask_padding = 1;
+	uint32_t prim_hitmask_buffer[RTE_HASH_LOOKUP_BULK_MAX] = {0};
+	uint32_t sec_hitmask_buffer[RTE_HASH_LOOKUP_BULK_MAX] = {0};
 #endif
 
 	__hash_rw_reader_lock(h);
@@ -1956,18 +2029,24 @@  __bulk_lookup_l(const struct rte_hash *h, const void **keys,
 	/* Compare signatures and prefetch key slot of first hit */
 	for (i = 0; i < num_keys; i++) {
 #if defined(__ARM_NEON)
-		compare_signatures_dense(&prim_hitmask[i], &sec_hitmask[i],
-			primary_bkt[i], secondary_bkt[i],
+		uint16_t *hitmask = &hitmask_buffer[i];
+		compare_signatures_dense(hitmask,
+			primary_bkt[i]->sig_current,
+			secondary_bkt[i]->sig_current,
 			sig[i], h->sig_cmp_fn);
+		const unsigned int prim_hitmask = *(uint8_t *)(hitmask);
+		const unsigned int sec_hitmask = *((uint8_t *)(hitmask)+1);
 #else
-		compare_signatures_sparse(&prim_hitmask[i], &sec_hitmask[i],
+		compare_signatures_sparse(&prim_hitmask_buffer[i], &sec_hitmask_buffer[i],
 			primary_bkt[i], secondary_bkt[i],
 			sig[i], h->sig_cmp_fn);
+		const unsigned int prim_hitmask = prim_hitmask_buffer[i];
+		const unsigned int sec_hitmask = sec_hitmask_buffer[i];
 #endif
 
-		if (prim_hitmask[i]) {
+		if (prim_hitmask) {
 			uint32_t first_hit =
-					__builtin_ctzl(prim_hitmask[i])
+					__builtin_ctzl(prim_hitmask)
 					>> hitmask_padding;
 			uint32_t key_idx =
 				primary_bkt[i]->key_idx[first_hit];
@@ -1979,9 +2058,9 @@  __bulk_lookup_l(const struct rte_hash *h, const void **keys,
 			continue;
 		}
 
-		if (sec_hitmask[i]) {
+		if (sec_hitmask) {
 			uint32_t first_hit =
-					__builtin_ctzl(sec_hitmask[i])
+					__builtin_ctzl(sec_hitmask)
 					>> hitmask_padding;
 			uint32_t key_idx =
 				secondary_bkt[i]->key_idx[first_hit];
@@ -1996,9 +2075,17 @@  __bulk_lookup_l(const struct rte_hash *h, const void **keys,
 	/* Compare keys, first hits in primary first */
 	for (i = 0; i < num_keys; i++) {
 		positions[i] = -ENOENT;
-		while (prim_hitmask[i]) {
+#if defined(__ARM_NEON)
+		uint16_t *hitmask = &hitmask_buffer[i];
+		unsigned int prim_hitmask = *(uint8_t *)(hitmask);
+		unsigned int sec_hitmask = *((uint8_t *)(hitmask)+1);
+#else
+		unsigned int prim_hitmask = prim_hitmask_buffer[i];
+		unsigned int sec_hitmask = sec_hitmask_buffer[i];
+#endif
+		while (prim_hitmask) {
 			uint32_t hit_index =
-					__builtin_ctzl(prim_hitmask[i])
+					__builtin_ctzl(prim_hitmask)
 					>> hitmask_padding;
 			uint32_t key_idx =
 				primary_bkt[i]->key_idx[hit_index];
@@ -2021,12 +2108,12 @@  __bulk_lookup_l(const struct rte_hash *h, const void **keys,
 				positions[i] = key_idx - 1;
 				goto next_key;
 			}
-			prim_hitmask[i] &= ~(1 << (hit_index << hitmask_padding));
+			prim_hitmask &= ~(1 << (hit_index << hitmask_padding));
 		}
 
-		while (sec_hitmask[i]) {
+		while (sec_hitmask) {
 			uint32_t hit_index =
-					__builtin_ctzl(sec_hitmask[i])
+					__builtin_ctzl(sec_hitmask)
 					>> hitmask_padding;
 			uint32_t key_idx =
 				secondary_bkt[i]->key_idx[hit_index];
@@ -2050,7 +2137,7 @@  __bulk_lookup_l(const struct rte_hash *h, const void **keys,
 				positions[i] = key_idx - 1;
 				goto next_key;
 			}
-			sec_hitmask[i] &= ~(1 << (hit_index << hitmask_padding));
+			sec_hitmask &= ~(1 << (hit_index << hitmask_padding));
 		}
 next_key:
 		continue;
@@ -2100,15 +2187,18 @@  __bulk_lookup_lf(const struct rte_hash *h, const void **keys,
 	uint64_t hits = 0;
 	int32_t i;
 	int32_t ret;
-	uint32_t prim_hitmask[RTE_HASH_LOOKUP_BULK_MAX] = {0};
-	uint32_t sec_hitmask[RTE_HASH_LOOKUP_BULK_MAX] = {0};
 	struct rte_hash_bucket *cur_bkt, *next_bkt;
 	uint32_t cnt_b, cnt_a;
 
 #if defined(__ARM_NEON)
 	const int hitmask_padding = 0;
+	uint16_t hitmask_buffer[RTE_HASH_LOOKUP_BULK_MAX] = {0};
+	static_assert(sizeof(*hitmask_buffer)*8/2 == RTE_HASH_BUCKET_ENTRIES,
+	"The hitmask must be exactly wide enough to accept the whole hitmask when it is dense");
 #else
 	const int hitmask_padding = 1;
+	uint32_t prim_hitmask_buffer[RTE_HASH_LOOKUP_BULK_MAX] = {0};
+	uint32_t sec_hitmask_buffer[RTE_HASH_LOOKUP_BULK_MAX] = {0};
 #endif
 
 	for (i = 0; i < num_keys; i++)
@@ -2125,18 +2215,24 @@  __bulk_lookup_lf(const struct rte_hash *h, const void **keys,
 		/* Compare signatures and prefetch key slot of first hit */
 		for (i = 0; i < num_keys; i++) {
 #if defined(__ARM_NEON)
-			compare_signatures_dense(&prim_hitmask[i], &sec_hitmask[i],
-				primary_bkt[i], secondary_bkt[i],
+			uint16_t *hitmask = &hitmask_buffer[i];
+			compare_signatures_dense(hitmask,
+				primary_bkt[i]->sig_current,
+				secondary_bkt[i]->sig_current,
 				sig[i], h->sig_cmp_fn);
+			const unsigned int prim_hitmask = *(uint8_t *)(hitmask);
+			const unsigned int sec_hitmask = *((uint8_t *)(hitmask)+1);
 #else
-			compare_signatures_sparse(&prim_hitmask[i], &sec_hitmask[i],
+			compare_signatures_sparse(&prim_hitmask_buffer[i], &sec_hitmask_buffer[i],
 				primary_bkt[i], secondary_bkt[i],
 				sig[i], h->sig_cmp_fn);
+			const unsigned int prim_hitmask = prim_hitmask_buffer[i];
+			const unsigned int sec_hitmask = sec_hitmask_buffer[i];
 #endif
 
-			if (prim_hitmask[i]) {
+			if (prim_hitmask) {
 				uint32_t first_hit =
-						__builtin_ctzl(prim_hitmask[i])
+						__builtin_ctzl(prim_hitmask)
 						>> hitmask_padding;
 				uint32_t key_idx =
 					primary_bkt[i]->key_idx[first_hit];
@@ -2148,9 +2244,9 @@  __bulk_lookup_lf(const struct rte_hash *h, const void **keys,
 				continue;
 			}
 
-			if (sec_hitmask[i]) {
+			if (sec_hitmask) {
 				uint32_t first_hit =
-						__builtin_ctzl(sec_hitmask[i])
+						__builtin_ctzl(sec_hitmask)
 						>> hitmask_padding;
 				uint32_t key_idx =
 					secondary_bkt[i]->key_idx[first_hit];
@@ -2164,9 +2260,17 @@  __bulk_lookup_lf(const struct rte_hash *h, const void **keys,
 
 		/* Compare keys, first hits in primary first */
 		for (i = 0; i < num_keys; i++) {
-			while (prim_hitmask[i]) {
+#if defined(__ARM_NEON)
+			uint16_t *hitmask = &hitmask_buffer[i];
+			unsigned int prim_hitmask = *(uint8_t *)(hitmask);
+			unsigned int sec_hitmask = *((uint8_t *)(hitmask)+1);
+#else
+			unsigned int prim_hitmask = prim_hitmask_buffer[i];
+			unsigned int sec_hitmask = sec_hitmask_buffer[i];
+#endif
+			while (prim_hitmask) {
 				uint32_t hit_index =
-						__builtin_ctzl(prim_hitmask[i])
+						__builtin_ctzl(prim_hitmask)
 						>> hitmask_padding;
 				uint32_t key_idx =
 				__atomic_load_n(
@@ -2193,12 +2297,12 @@  __bulk_lookup_lf(const struct rte_hash *h, const void **keys,
 					positions[i] = key_idx - 1;
 					goto next_key;
 				}
-				prim_hitmask[i] &= ~(1 << (hit_index << hitmask_padding));
+				prim_hitmask &= ~(1 << (hit_index << hitmask_padding));
 			}
 
-			while (sec_hitmask[i]) {
+			while (sec_hitmask) {
 				uint32_t hit_index =
-						__builtin_ctzl(sec_hitmask[i])
+						__builtin_ctzl(sec_hitmask)
 						>> hitmask_padding;
 				uint32_t key_idx =
 				__atomic_load_n(
@@ -2226,7 +2330,7 @@  __bulk_lookup_lf(const struct rte_hash *h, const void **keys,
 					positions[i] = key_idx - 1;
 					goto next_key;
 				}
-				sec_hitmask[i] &= ~(1 << (hit_index << hitmask_padding));
+				sec_hitmask &= ~(1 << (hit_index << hitmask_padding));
 			}
 next_key:
 			continue;
diff --git a/lib/hash/rte_cuckoo_hash.h b/lib/hash/rte_cuckoo_hash.h
index eb2644f74b..356ec2a69e 100644
--- a/lib/hash/rte_cuckoo_hash.h
+++ b/lib/hash/rte_cuckoo_hash.h
@@ -148,6 +148,7 @@  enum rte_hash_sig_compare_function {
 	RTE_HASH_COMPARE_SCALAR = 0,
 	RTE_HASH_COMPARE_SSE,
 	RTE_HASH_COMPARE_NEON,
+	RTE_HASH_COMPARE_SVE,
 	RTE_HASH_COMPARE_NUM
 };
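
The new RTE_HASH_COMPARE_SVE path is selected inside rte_hash_create() when the CPU
reports the SVE flag, so existing bulk-lookup callers pick it up without changes. A
minimal caller sketch, assuming an already created and populated table 'h' and an
array of key pointers (both hypothetical names), might look like:

#include <errno.h>
#include <rte_hash.h>

/* Hypothetical caller: 'h' is an existing, populated rte_hash table and
 * 'keys' holds num_keys pointers (num_keys <= RTE_HASH_LOOKUP_BULK_MAX).
 * The signature-compare implementation (SVE, NEON, SSE or scalar) is
 * chosen internally by the library.
 */
static void
lookup_batch_example(const struct rte_hash *h, const void **keys, uint32_t num_keys)
{
	int32_t positions[RTE_HASH_LOOKUP_BULK_MAX];

	if (rte_hash_lookup_bulk(h, keys, num_keys, positions) != 0)
		return;	/* invalid parameters */

	for (uint32_t i = 0; i < num_keys; i++) {
		if (positions[i] != -ENOENT) {
			/* keys[i] was found; positions[i] is its index */
		}
	}
}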