未验证 提交 efbdad05 编写于 作者: T Tao Luo 提交者: GitHub

make search_compute support avx default (#20779)

* make search_compute support avx only

* clean search_compute.h

* rename sse_axpy to avx_axpy

test=develop

* update CMakeLists.txt

test=develop
上级 3556514e
...@@ -49,7 +49,7 @@ if (WITH_DISTRIBUTE) ...@@ -49,7 +49,7 @@ if (WITH_DISTRIBUTE)
endif() endif()
SET(OP_ONLY_MKL "") SET(OP_ONLY_MKL "")
if (NOT WITH_MKL) if (NOT WITH_MKL OR NOT WITH_AVX)
SET(OP_ONLY_MKL ${OP_ONLY_MKL} match_matrix_tensor_op) SET(OP_ONLY_MKL ${OP_ONLY_MKL} match_matrix_tensor_op)
SET(OP_ONLY_MKL ${OP_ONLY_MKL} var_conv_2d_op) SET(OP_ONLY_MKL ${OP_ONLY_MKL} var_conv_2d_op)
endif() endif()
......
...@@ -286,8 +286,8 @@ class CPUMatchMatrixTensorOPGradKernel : public framework::OpKernel<T> { ...@@ -286,8 +286,8 @@ class CPUMatchMatrixTensorOPGradKernel : public framework::OpKernel<T> {
auto* r_data = bottom_r_data + (offset_r[b] + j) * dim_in; auto* r_data = bottom_r_data + (offset_r[b] + j) * dim_in;
auto* r_diff = bottom_r_diff + (offset_r[b] + j) * dim_in; auto* r_diff = bottom_r_diff + (offset_r[b] + j) * dim_in;
if (diff != 0.0) { if (diff != 0.0) {
sse_axpy(r_data, l_trans_diff, dim_in, diff); avx_axpy(r_data, l_trans_diff, dim_in, diff);
sse_axpy(l_trans_data, r_diff, dim_in, diff); avx_axpy(l_trans_data, r_diff, dim_in, diff);
} }
} }
} }
......
...@@ -73,22 +73,10 @@ void call_gemm_batched(const framework::ExecutionContext& ctx, ...@@ -73,22 +73,10 @@ void call_gemm_batched(const framework::ExecutionContext& ctx,
} }
} }
#ifndef TYPE_USE_FLOAT
#define TYPE_USE_FLOAT
#endif
#ifndef USE_SSE
#define USE_SSE
#endif
#if defined(TYPE_USE_FLOAT)
#define __m256x __m256 #define __m256x __m256
#define __m128x __m128
static const unsigned int AVX_STEP_SIZE = 8; static const unsigned int AVX_STEP_SIZE = 8;
static const unsigned int SSE_STEP_SIZE = 4;
static const unsigned int AVX_CUT_LEN_MASK = 7U; static const unsigned int AVX_CUT_LEN_MASK = 7U;
static const unsigned int SSE_CUT_LEN_MASK = 3U;
#define _mm256_mul_px _mm256_mul_ps #define _mm256_mul_px _mm256_mul_ps
#define _mm256_add_px _mm256_add_ps #define _mm256_add_px _mm256_add_ps
...@@ -96,20 +84,11 @@ static const unsigned int SSE_CUT_LEN_MASK = 3U; ...@@ -96,20 +84,11 @@ static const unsigned int SSE_CUT_LEN_MASK = 3U;
#define _mm256_store_px _mm256_storeu_ps #define _mm256_store_px _mm256_storeu_ps
#define _mm256_broadcast_sx _mm256_broadcast_ss #define _mm256_broadcast_sx _mm256_broadcast_ss
#define _mm_add_px _mm_add_ps
#define _mm_mul_px _mm_mul_ps
#define _mm_load_px _mm_loadu_ps
#define _mm_store_px _mm_storeu_ps
#define _mm_load1_px _mm_load1_ps
#endif
template <typename T> template <typename T>
inline void sse_axpy(const T* x, T* y, size_t len, const T alpha) { inline void avx_axpy(const T* x, T* y, size_t len, const T alpha) {
unsigned int jjj, lll; unsigned int jjj, lll;
jjj = lll = 0; jjj = lll = 0;
#if defined(USE_AVX)
lll = len & ~AVX_CUT_LEN_MASK; lll = len & ~AVX_CUT_LEN_MASK;
__m256x mm_alpha = _mm256_broadcast_sx(&alpha); __m256x mm_alpha = _mm256_broadcast_sx(&alpha);
for (jjj = 0; jjj < lll; jjj += AVX_STEP_SIZE) { for (jjj = 0; jjj < lll; jjj += AVX_STEP_SIZE) {
...@@ -119,16 +98,6 @@ inline void sse_axpy(const T* x, T* y, size_t len, const T alpha) { ...@@ -119,16 +98,6 @@ inline void sse_axpy(const T* x, T* y, size_t len, const T alpha) {
_mm256_mul_px(mm_alpha, _mm256_load_px(x + jjj)))); _mm256_mul_px(mm_alpha, _mm256_load_px(x + jjj))));
} }
#elif defined(USE_SSE)
lll = len & ~SSE_CUT_LEN_MASK;
__m128x mm_alpha = _mm_load1_px(&alpha);
for (jjj = 0; jjj < lll; jjj += SSE_STEP_SIZE) {
_mm_store_px(y + jjj,
_mm_add_px(_mm_load_px(y + jjj),
_mm_mul_px(mm_alpha, _mm_load_px(x + jjj))));
}
#endif
for (; jjj < len; jjj++) { for (; jjj < len; jjj++) {
y[jjj] += alpha * x[jjj]; y[jjj] += alpha * x[jjj];
} }
......
...@@ -70,10 +70,10 @@ if(NOT WITH_MKLML) ...@@ -70,10 +70,10 @@ if(NOT WITH_MKLML)
list(REMOVE_ITEM TEST_OPS test_fusion_seqexpand_concat_fc_op) list(REMOVE_ITEM TEST_OPS test_fusion_seqexpand_concat_fc_op)
endif() endif()
if(NOT WITH_MKL) if(NOT WITH_MKL OR NOT WITH_AVX)
list(REMOVE_ITEM TEST_OPS test_match_matrix_tensor_op) list(REMOVE_ITEM TEST_OPS test_match_matrix_tensor_op)
list(REMOVE_ITEM TEST_OPS test_var_conv_2d) list(REMOVE_ITEM TEST_OPS test_var_conv_2d)
endif(NOT WITH_MKL) endif()
if(WITH_GPU OR NOT WITH_MKLML) if(WITH_GPU OR NOT WITH_MKLML)
# matmul with multiple heads need MKL support # matmul with multiple heads need MKL support
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册