Unverified commit d7850dcd, authored by Tao Luo, committed by GitHub

add noavx_axpy and noavx_axpy_noadd (#24207)

* remove the double kernel registration for the pyramid_hash op

* add noavx_axpy and noavx_axpy_noadd

test=develop
Parent ec00d112
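For readers skimming the diff: axpy is the classic scaled accumulate, y[i] += alpha * x[i], while axpy_noadd overwrites instead of accumulating, y[i] = alpha * x[i]; pyramid_hash uses the latter to scale inference-time output by _drop_out_percent. A minimal scalar sketch of those semantics, assuming nothing beyond standard C++ (the *_reference names are ours, not part of the patch):

#include <cstddef>

// Reference semantics of axpy: accumulate a scaled vector into y.
template <typename T>
void axpy_reference(const T* x, T* y, size_t len, T alpha) {
  for (size_t i = 0; i < len; ++i) y[i] += alpha * x[i];
}

// Reference semantics of axpy_noadd: overwrite y with a scaled copy of x.
template <typename T>
void axpy_noadd_reference(const T* x, T* y, size_t len, T alpha) {
  for (size_t i = 0; i < len; ++i) y[i] = alpha * x[i];
}

The header changed below vectorizes both loops with AVX when PADDLE_WITH_AVX is defined, falls back to an SSE loop otherwise, and keeps the scalar loop for the tail elements.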
paddle/fluid/operators/CMakeLists.txt:

@@ -53,7 +53,7 @@ if (NOT WITH_MKL OR NOT WITH_AVX)
   SET(OP_MKL_DEPS ${OP_MKL_DEPS} match_matrix_tensor_op)
   SET(OP_MKL_DEPS ${OP_MKL_DEPS} var_conv_2d_op)
 endif()
-if(WITH_COVERAGE OR NOT WITH_AVX OR WIN32)
+if(WITH_COVERAGE OR WIN32)
   SET(OP_MKL_DEPS ${OP_MKL_DEPS} pyramid_hash_op)
 endif()
paddle/fluid/operators/match_matrix_tensor_op.cc:

@@ -288,8 +288,8 @@ class CPUMatchMatrixTensorOPGradKernel : public framework::OpKernel<T> {
           auto* r_data = bottom_r_data + (offset_r[b] + j) * dim_in;
           auto* r_diff = bottom_r_diff + (offset_r[b] + j) * dim_in;
           if (diff != 0.0) {
-            avx_axpy(r_data, l_trans_diff, dim_in, diff);
-            avx_axpy(l_trans_data, r_diff, dim_in, diff);
+            axpy(r_data, l_trans_diff, dim_in, diff);
+            axpy(l_trans_data, r_diff, dim_in, diff);
           }
         }
       }
paddle/fluid/operators/pyramid_hash_op.cc:

@@ -385,8 +385,8 @@ class CPUPyramidHashOPKernel : public framework::OpKernel<T> {
     }
     auto weight_type = _blobs_0->type();
     if (_is_training == 0 && weight_type != framework::proto::VarType::INT8) {
-      avx_axpy_noadd(top_data, top_data, top->dims()[0] * top->dims()[1],
-                     _drop_out_percent);
+      axpy_noadd(top_data, top_data, top->dims()[0] * top->dims()[1],
+                 _drop_out_percent);
     }
   }
 };
@@ -451,7 +451,7 @@ class CPUPyramidHashOPGradKernel : public framework::OpKernel<T> {
                     int _space_len) const {
     for (int j = 0; j != _num_emb; j += _rand_len) {
       unsigned int pos = XXH32(hash_id, len * sizeof(T), j) % _space_len;
-      avx_axpy(top_pos + j, weights + pos, _rand_len, mlr);
+      axpy(top_pos + j, weights + pos, _rand_len, mlr);
     }
   }
@@ -525,9 +525,7 @@ REGISTER_OPERATOR(pyramid_hash_grad, ops::PyramidHashOpGrad);
 REGISTER_OP_CPU_KERNEL(
     pyramid_hash, ops::CPUPyramidHashOPKernel<plt::CPUDeviceContext, float>,
-    ops::CPUPyramidHashOPKernel<plt::CPUDeviceContext, double>,
     ops::CPUPyramidHashOPKernel<plt::CPUDeviceContext, int8_t>);
 REGISTER_OP_CPU_KERNEL(
     pyramid_hash_grad,
-    ops::CPUPyramidHashOPGradKernel<plt::CPUDeviceContext, float>,
-    ops::CPUPyramidHashOPGradKernel<plt::CPUDeviceContext, double>);
+    ops::CPUPyramidHashOPGradKernel<plt::CPUDeviceContext, float>);
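Both pyramid_hash kernels locate each _rand_len-wide chunk of an embedding row by hashing the input token ids with XXH32, seeded with the chunk offset j, then taking the hash modulo the parameter-table length _space_len. A standalone sketch of that addressing pattern, assuming the xxhash library is available; chunk_in_table and the int id type are hypothetical, for illustration only:

#include <cstddef>

#include "xxhash.h"  // provides XXH32(input, length, seed)

// Hypothetical illustration of the pyramid_hash addressing scheme:
// chunk j of an n-gram's embedding row is read from (or accumulated
// into) a table position derived from a seeded hash of the token ids.
inline const float* chunk_in_table(const float* weights, const int* hash_id,
                                   size_t len, int j, int space_len) {
  unsigned int pos = XXH32(hash_id, len * sizeof(int), j) % space_len;
  return weights + pos;
}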
paddle/fluid/operators/search_compute.h:

@@ -83,16 +83,23 @@ static const unsigned int AVX_CUT_LEN_MASK = 7U;
 #define _mm256_store_px _mm256_storeu_ps
 #define _mm256_broadcast_sx _mm256_broadcast_ss
-#define _mm256_mul_pd _mm256_mul_pd
-#define _mm256_add_pd _mm256_add_pd
-#define _mm256_load_pd _mm256_loadu_pd
-#define _mm256_store_pd _mm256_storeu_pd
-#define _mm256_broadcast_sd _mm256_broadcast_sd
-inline void avx_axpy(const float* x, float* y, size_t len, const float alpha) {
+#define __m128x __m128
+
+static const unsigned int SSE_STEP_SIZE = 2;
+static const unsigned int SSE_CUT_LEN_MASK = 1U;
+
+#define _mm_add_px _mm_add_ps
+#define _mm_mul_px _mm_mul_ps
+#define _mm_load_px _mm_loadu_ps
+#define _mm_store_px _mm_storeu_ps
+#define _mm_load1_px _mm_load1_ps
+
+template <typename T>
+inline void axpy(const T* x, T* y, size_t len, const T alpha) {
   unsigned int jjj, lll;
   jjj = lll = 0;
+#ifdef PADDLE_WITH_AVX
   lll = len & ~AVX_CUT_LEN_MASK;
   __m256x mm_alpha = _mm256_broadcast_sx(&alpha);
   for (jjj = 0; jjj < lll; jjj += AVX_STEP_SIZE) {
@@ -101,66 +108,51 @@ inline void avx_axpy(const float* x, float* y, size_t len, const float alpha) {
         _mm256_add_px(_mm256_load_px(y + jjj),
                       _mm256_mul_px(mm_alpha, _mm256_load_px(x + jjj))));
   }
-  for (; jjj < len; jjj++) {
-    y[jjj] += alpha * x[jjj];
+#else
+  lll = len & ~SSE_CUT_LEN_MASK;
+  __m128x mm_alpha = _mm_load1_px(&alpha);
+  for (jjj = 0; jjj < lll; jjj += SSE_STEP_SIZE) {
+    _mm_store_px(y + jjj,
+                 _mm_add_px(_mm_load_px(y + jjj),
+                            _mm_mul_px(mm_alpha, _mm_load_px(x + jjj))));
   }
-}
-
-inline void avx_axpy(const double* x, double* y, size_t len,
-                     const float alpha) {
-  unsigned int jjj, lll;
-  jjj = lll = 0;
-  lll = len & ~AVX_CUT_LEN_MASK;
-  double alpha_d = static_cast<double>(alpha);
-  __m256d mm_alpha = _mm256_broadcast_sd(&alpha_d);
-  for (jjj = 0; jjj < lll; jjj += AVX_STEP_SIZE) {
-    _mm256_store_pd(
-        y + jjj,
-        _mm256_add_pd(_mm256_load_pd(y + jjj),
-                      _mm256_mul_pd(mm_alpha, _mm256_load_pd(x + jjj))));
-  }
+#endif
   for (; jjj < len; jjj++) {
     y[jjj] += alpha * x[jjj];
   }
 }
 
-inline void avx_axpy_noadd(const double* x, double* y, size_t len,
-                           const float alpha) {
-  unsigned int jjj, lll;
-  jjj = lll = 0;
-  double alpha_d = static_cast<double>(alpha);
-  lll = len & ~AVX_CUT_LEN_MASK;
-  __m256d mm_alpha = _mm256_broadcast_sd(&alpha_d);
-  for (jjj = 0; jjj < lll; jjj += AVX_STEP_SIZE) {
-    _mm256_store_pd(y + jjj, _mm256_mul_pd(mm_alpha, _mm256_load_pd(x + jjj)));
-  }
-  for (; jjj < len; jjj++) {
-    y[jjj] = alpha * x[jjj];
-  }
-}
-
-inline void avx_axpy_noadd(const float* x, float* y, size_t len,
-                           const float alpha) {
+template <typename T>
+inline void axpy_noadd(const T* x, T* y, size_t len, const T alpha) {
   unsigned int jjj, lll;
   jjj = lll = 0;
+#ifdef PADDLE_WITH_AVX
   lll = len & ~AVX_CUT_LEN_MASK;
   __m256x mm_alpha = _mm256_broadcast_sx(&alpha);
   for (jjj = 0; jjj < lll; jjj += AVX_STEP_SIZE) {
     _mm256_store_px(y + jjj, _mm256_mul_px(mm_alpha, _mm256_load_px(x + jjj)));
   }
+#else
+  lll = len & ~SSE_CUT_LEN_MASK;
+  __m128x mm_alpha = _mm_load1_px(&alpha);
+  for (jjj = 0; jjj < lll; jjj += SSE_STEP_SIZE) {
+    _mm_store_px(y + jjj, _mm_mul_px(mm_alpha, _mm_load_px(x + jjj)));
+  }
+#endif
   for (; jjj < len; jjj++) {
     y[jjj] = alpha * x[jjj];
   }
 }
 
-inline void avx_axpy_noadd(const int8_t* x, int8_t* y, size_t len,
-                           const float alpha) {
+inline void axpy_noadd(const int8_t* x, int8_t* y, size_t len,
+                       const float alpha) {
   PADDLE_THROW(platform::errors::Unimplemented(
-      "int8_t input of avx_axpy_noadd is not supported"));
+      "int8_t input of axpy_noadd is not supported"));
 }
 
 }  // namespace operators
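Because axpy and axpy_noadd now select the AVX or SSE body at compile time via PADDLE_WITH_AVX, either build can be sanity-checked against the scalar tail-loop semantics. A minimal hypothetical harness (check_axpy is our name, not part of the patch); a length such as 11 that is not a multiple of the vector width exercises both the vector body and the scalar tail:

#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical check: run an axpy implementation and compare it with
// the scalar definition y[i] += alpha * x[i].
template <typename T, typename AxpyFn>
void check_axpy(AxpyFn axpy_impl, size_t len) {
  std::vector<T> x(len), y(len), ref(len);
  for (size_t i = 0; i < len; ++i) {
    x[i] = static_cast<T>(i) * static_cast<T>(0.25);
    y[i] = ref[i] = static_cast<T>(len - i);
  }
  const T alpha = static_cast<T>(1.5);
  axpy_impl(x.data(), y.data(), len, alpha);
  for (size_t i = 0; i < len; ++i) {
    ref[i] += alpha * x[i];  // scalar reference semantics
    assert(std::fabs(static_cast<double>(y[i] - ref[i])) < 1e-5);
  }
}

// Usage sketch: check_axpy<float>(paddle::operators::axpy<float>, 11);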
python/paddle/fluid/tests/unittests/CMakeLists.txt:

@@ -91,7 +91,7 @@ if(NOT WITH_MKL OR NOT WITH_AVX)
   list(REMOVE_ITEM TEST_OPS test_match_matrix_tensor_op)
   list(REMOVE_ITEM TEST_OPS test_var_conv_2d)
 endif()
-if(WITH_COVERAGE OR NOT WITH_AVX OR WIN32)
+if(WITH_COVERAGE OR WIN32)
   list(REMOVE_ITEM TEST_OPS test_pyramid_hash_op)
   list(REMOVE_ITEM TEST_OPS test_fleet_pyramid_hash)
 endif()