From 70540b2684c5bef920f3bd0c445b391ce9f9fb49 Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Fri, 4 Mar 2022 17:02:51 +0800 Subject: [PATCH] [phi] move cpu_vec (#39714) move cpu_vec.h to phi/kernels/funcs. --- paddle/fluid/operators/attention_lstm_op.cc | 18 +- .../fused/fused_embedding_fc_lstm_op.cc | 6 +- .../fused/fusion_seqexpand_concat_fc_op.cc | 6 +- paddle/fluid/operators/math/CMakeLists.txt | 1 - paddle/phi/kernels/funcs/cpu_vec.h | 675 ++++++++++++++++++ paddle/phi/tests/kernels/CMakeLists.txt | 2 + .../tests/kernels/test_cpu_vec.cc} | 112 +-- 7 files changed, 756 insertions(+), 64 deletions(-) create mode 100644 paddle/phi/kernels/funcs/cpu_vec.h rename paddle/{fluid/operators/math/cpu_vec_test.cc => phi/tests/kernels/test_cpu_vec.cc} (75%) diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index a23e484d0a8..78ea8b6b6fb 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -14,10 +14,10 @@ limitations under the License. */ #include "paddle/fluid/operators/attention_lstm_op.h" #include -#include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/fc.h" #include "paddle/fluid/platform/cpu_info.h" #include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/cpu_vec.h" namespace paddle { namespace operators { @@ -269,10 +269,10 @@ use lstm_x_t as input and compute as standard LSTM. template inline void bias_relu(const int n, const T* x, const T* bias, T* y) { if (bias) { - math::vec_add_bias(n, *bias, x, y); - math::vec_relu(n, y, y); + phi::funcs::vec_add_bias(n, *bias, x, y); + phi::funcs::vec_relu(n, y, y); } else { - math::vec_relu(n, x, y); + phi::funcs::vec_relu(n, x, y); } } @@ -283,14 +283,14 @@ inline void vec_softmax(const int n, const T* x, T* y) { for (int i = 1; i < n; ++i) { scalar = scalar < x[i] ? x[i] : scalar; } - math::vec_add_bias(n, -scalar, x, y); // sub - math::vec_exp(n, y, y); // exp + phi::funcs::vec_add_bias(n, -scalar, x, y); // sub + phi::funcs::vec_exp(n, y, y); // exp // sum scalar = T(0); for (int i = 0; i < n; ++i) { scalar += y[i]; } - math::vec_scal(n, static_cast(1) / scalar, y); // scale + phi::funcs::vec_scal(n, static_cast(1) / scalar, y); // scale } template @@ -344,12 +344,12 @@ class AttentionLSTMKernel : public framework::OpKernel { auto& act_cell_str = ctx.Attr("cell_activation"); auto& act_cand_str = ctx.Attr("candidate_activation"); if (platform::MayIUse(platform::avx)) { - math::VecActivations act_functor; + phi::funcs::VecActivations act_functor; act_gate = act_functor(act_gate_str); act_cell = act_functor(act_cell_str); act_cand = act_functor(act_cand_str); } else { - math::VecActivations act_functor; + phi::funcs::VecActivations act_functor; act_gate = act_functor(act_gate_str); act_cell = act_functor(act_cell_str); act_cand = act_functor(act_cand_str); diff --git a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc index 0c83c36b475..7308f307792 100644 --- a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc +++ b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc @@ -14,9 +14,9 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.h" #include -#include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/platform/cpu_info.h" #include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/cpu_vec.h" #include "paddle/phi/kernels/funcs/sequence2batch.h" namespace paddle { @@ -243,12 +243,12 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel { auto& act_cell_str = ctx.Attr("cell_activation"); \ auto& act_cand_str = ctx.Attr("candidate_activation"); \ if (platform::MayIUse(platform::avx)) { \ - math::VecActivations act_functor; \ + phi::funcs::VecActivations act_functor; \ act_gate = act_functor(act_gate_str); \ act_cell = act_functor(act_cell_str); \ act_cand = act_functor(act_cand_str); \ } else { \ - math::VecActivations act_functor; \ + phi::funcs::VecActivations act_functor; \ act_gate = act_functor(act_gate_str); \ act_cell = act_functor(act_cell_str); \ act_cand = act_functor(act_cand_str); \ diff --git a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc index 88fb7349d53..1000d0488dc 100644 --- a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc @@ -14,10 +14,10 @@ limitations under the License. */ #include "paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.h" #include -#include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/fc.h" #include "paddle/fluid/platform/cpu_info.h" #include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/cpu_vec.h" namespace paddle { namespace operators { @@ -196,10 +196,10 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel { std::function fc_act; auto& fc_act_str = ctx.Attr("fc_activation"); if (platform::MayIUse(platform::avx)) { - math::VecActivations act_functor; + phi::funcs::VecActivations act_functor; fc_act = act_functor(fc_act_str); } else { - math::VecActivations act_functor; + phi::funcs::VecActivations act_functor; fc_act = act_functor(fc_act_str); } diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index ba047355ad7..14b12ca3acb 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -70,7 +70,6 @@ if(WITH_GPU AND (NOT WITH_ROCM)) endif() endif() -cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) if(WITH_TESTING AND TEST im2col_test) set_tests_properties(im2col_test PROPERTIES TIMEOUT 120) endif() diff --git a/paddle/phi/kernels/funcs/cpu_vec.h b/paddle/phi/kernels/funcs/cpu_vec.h new file mode 100644 index 00000000000..7bb2a5fcfb3 --- /dev/null +++ b/paddle/phi/kernels/funcs/cpu_vec.h @@ -0,0 +1,675 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include +#include +#include + +#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/fluid/platform/enforce.h" + +#ifdef PADDLE_WITH_MKLML +#include "paddle/fluid/platform/dynload/mklml.h" +#endif + +namespace phi { +namespace funcs { + +#define SIGMOID_THRESHOLD_MIN -40.0 +#define SIGMOID_THRESHOLD_MAX 13.0 + +#define YMM_FLOAT_BLOCK 8 +#define AVX_DOUBLE_BLOCK 4 +#define YMM_FLOAT_BLOCK 8 +#define AVX2_DOUBLE_BLOCK 4 +#define ZMM_FLOAT_BLOCK 16 +#define AVX512_DOUBLE_BLOCK 8 + +template +inline void vec_exp(const int n, const T* x, T* y) { + for (int i = 0; i < n; ++i) { + y[i] = std::exp(x[i]); + } +} + +template +inline void vec_scal(const int n, const T a, T* x) { + for (int i = 0; i < n; ++i) { + x[i] = a * x[i]; + } +} + +#ifdef PADDLE_WITH_MKLML +template <> +inline void vec_exp(const int n, const float* x, float* y) { + constexpr int small_enough = 128; + if (n < small_enough) { + for (int i = 0; i < n; ++i) { + y[i] = std::exp(x[i]); + } + } else { + paddle::platform::dynload::vsExp(n, x, y); + } +} + +template <> +inline void vec_exp(const int n, const double* x, double* y) { + paddle::platform::dynload::vdExp(n, x, y); +} + +template <> +inline void vec_scal(const int n, const float a, float* x) { + paddle::platform::dynload::cblas_sscal(n, a, x, 1); +} + +template <> +inline void vec_scal(const int n, const double a, double* x) { + paddle::platform::dynload::cblas_dscal(n, a, x, 1); +} +#endif + +// MKL scal only support inplace, choose this if src and dst are not equal +template +inline void vec_scal(const int n, const T a, const T* x, T* y) { + for (int i = 0; i < n; ++i) { + y[i] = a * x[i]; + } +} + +template <> +inline void vec_scal(const int n, + const float a, + const float* x, + float* y) { +#ifdef __AVX__ + constexpr int block = YMM_FLOAT_BLOCK; + if (n < block) { + vec_scal(n, a, x, y); + return; + } + const int rest = n % block; + const int end = n - rest; + int i = 0; + __m256 scalar = _mm256_set1_ps(a); + __m256 tmp; +#define MOVE_ONE_STEP \ + tmp = _mm256_loadu_ps(x + i); \ + tmp = _mm256_mul_ps(tmp, scalar); \ + _mm256_storeu_ps(y + i, tmp) + for (i = 0; i < end; i += block) { + MOVE_ONE_STEP; + } +#undef MOVE_ONE_STEP + if (rest == 0) { + return; + } + // can not continue move step if src and dst are inplace + for (i = n - rest; i < n; ++i) { + y[i] = a * x[i]; + } +#else + vec_scal(n, a, x, y); +#endif +} + +template <> +inline void vec_scal(const int n, + const float a, + const float* x, + float* y) { + vec_scal(n, a, x, y); +} + +template <> +inline void vec_scal(const int n, + const float a, + const float* x, + float* y) { + // TODO(TJ): enable me + vec_scal(n, a, x, y); +} + +template +inline void vec_sum(const size_t n, const T* x, T* s) { + s[0] = x[0]; + for (size_t i = 1; i < n; ++i) { + s[0] += x[i]; + } +} + +template <> +inline void vec_sum(const size_t n, + const float* x, + float* s) { +#ifdef __AVX__ + constexpr unsigned int block = YMM_FLOAT_BLOCK; + if (n < block) { + vec_sum(n, x, s); + return; + } + + unsigned int i, end; + i = end = 0; + s[0] = 0.f; + + end = n & ~(block - 1); + __m256 tmp = _mm256_setzero_ps(); + for (i = 0; i < end; i += block) { + tmp = _mm256_add_ps(tmp, _mm256_loadu_ps(x + i)); + } + + __m256 hsum = _mm256_hadd_ps(tmp, tmp); + hsum = _mm256_add_ps(hsum, _mm256_permute2f128_ps(hsum, hsum, 0x1)); + _mm_store_ss( + s, + _mm_hadd_ps(_mm256_castps256_ps128(hsum), _mm256_castps256_ps128(hsum))); + + for (; i < n; i++) { + s[0] += x[i]; + } +#else + vec_sum(n, x, s); +#endif +} + +template +inline 
void vec_mul(const size_t n, const T* x, const T* y, T* z) { + for (size_t i = 0; i < n; ++i) { + z[i] = x[i] * y[i]; + } +} + +template <> +inline void vec_mul(const size_t n, + const float* x, + const float* y, + float* z) { +#ifdef __AVX__ + constexpr unsigned int block = YMM_FLOAT_BLOCK; + if (n < block) { + vec_mul(n, x, y, z); + return; + } + + unsigned int i = 0, end = 0; + end = n & ~(block - 1); + for (i = 0; i < end; i += block) { + _mm256_storeu_ps( + z + i, _mm256_mul_ps(_mm256_loadu_ps(x + i), _mm256_loadu_ps(y + i))); + } + + for (; i < n; i++) { + z[i] = x[i] * y[i]; + } +#else + vec_mul(n, x, y, z); +#endif +} + +template +inline void vec_mul_reduce(const size_t n, const T* x, const T* y, T* z) { + z[0] = x[0] * y[0]; + for (size_t i = 1; i < n; ++i) { + z[0] += x[i] * y[i]; + } +} + +template <> +inline void vec_mul_reduce(const size_t n, + const float* x, + const float* y, + float* z) { +#ifdef __AVX__ + constexpr unsigned int block = YMM_FLOAT_BLOCK; + if (n < block) { + vec_mul_reduce(n, x, y, z); + return; + } + + unsigned int i = 0, end = 0; + z[0] = 0.f; + + end = n & ~(block - 1); + __m256 tmp = _mm256_setzero_ps(); + for (i = 0; i < end; i += block) { + tmp = _mm256_add_ps( + tmp, _mm256_mul_ps(_mm256_loadu_ps(x + i), _mm256_loadu_ps(y + i))); + } + + __m256 hsum = _mm256_hadd_ps(tmp, tmp); + hsum = _mm256_add_ps(hsum, _mm256_permute2f128_ps(hsum, hsum, 0x1)); + _mm_store_ss( + z, + _mm_hadd_ps(_mm256_castps256_ps128(hsum), _mm256_castps256_ps128(hsum))); + + for (; i < n; i++) { + z[0] += x[i] * y[i]; + } +#else + vec_mul_reduce(n, x, y, z); +#endif +} + +template +inline void vec_bias_sub(const int n, const T a, const T* x, T* y) { + for (int i = 0; i < n; ++i) { + y[i] = a - x[i]; + } +} + +template <> +inline void vec_bias_sub(const int n, + const float a, + const float* x, + float* y) { +#ifdef __AVX__ + constexpr int block = YMM_FLOAT_BLOCK; + if (n < block) { + vec_bias_sub(n, a, x, y); + return; + } + const int rest = n % block; + const int end = n - rest; + int i = 0; + __m256 bias = _mm256_set1_ps(a); + __m256 tmp; +#define MOVE_ONE_STEP \ + tmp = _mm256_loadu_ps(x + i); \ + tmp = _mm256_sub_ps(bias, tmp); \ + _mm256_storeu_ps(y + i, tmp) + for (i = 0; i < end; i += block) { + MOVE_ONE_STEP; + } +#undef MOVE_ONE_STEP + if (rest == 0) { + return; + } + // can not continue move step if src and dst are inplace + for (i = n - rest; i < n; ++i) { + y[i] = a - x[i]; + } +#else + vec_bias_sub(n, a, x, y); +#endif +} + +template <> +inline void vec_bias_sub(const int n, + const float a, + const float* x, + float* y) { + vec_bias_sub(n, a, x, y); +} + +template <> +inline void vec_bias_sub(const int n, + const float a, + const float* x, + float* y) { + // TODO(TJ): enable me + vec_bias_sub(n, a, x, y); +} + +// out = x*y + (1-x)*z +template +inline void vec_cross(const int n, const T* x, const T* y, const T* z, T* out) { + for (int i = 0; i < n; ++i) { + out[i] = x[i] * y[i] + (static_cast(1) - x[i]) * z[i]; + } +} + +template <> +inline void vec_cross( + const int n, const float* x, const float* y, const float* z, float* out) { +#ifdef __AVX__ + constexpr int block = YMM_FLOAT_BLOCK; + if (n < block) { + vec_cross(n, x, y, z, out); + return; + } + const int rest = n % block; + const int end = n - rest; + int i = 0; + __m256 bias = _mm256_set1_ps(1.f); + __m256 tmpx, tmpy, tmpz; + for (i = 0; i < end; i += block) { + tmpx = _mm256_loadu_ps(x + i); + tmpy = _mm256_loadu_ps(y + i); + tmpz = _mm256_loadu_ps(z + i); + tmpy = _mm256_mul_ps(tmpx, tmpy); + tmpx = 
_mm256_sub_ps(bias, tmpx); + tmpz = _mm256_mul_ps(tmpx, tmpz); + tmpz = _mm256_add_ps(tmpy, tmpz); + _mm256_storeu_ps(out + i, tmpz); + } + if (rest == 0) { + return; + } + // can not continue move step if src and dst are inplace + for (i = n - rest; i < n; ++i) { + out[i] = x[i] * y[i] + (1.f - x[i]) * z[i]; + } +#else + vec_cross(n, x, y, z, out); +#endif +} + +template <> +inline void vec_cross( + const int n, const float* x, const float* y, const float* z, float* out) { + vec_cross(n, x, y, z, out); +} + +template <> +inline void vec_cross( + const int n, const float* x, const float* y, const float* z, float* out) { + // TODO(TJ): enable me + vec_cross(n, x, y, z, out); +} + +template +inline void vec_clip(const size_t n, const T a, const T* x, T* y) { + for (size_t i = 0; i < n; ++i) { + y[i] = x[i] < a ? a : x[i]; + } +} + +template <> +inline void vec_clip(const size_t n, + const float a, + const float* x, + float* y) { +#ifdef __AVX__ + constexpr unsigned int block = YMM_FLOAT_BLOCK; + if (n < block) { + vec_clip(n, a, x, y); + return; + } + + unsigned int i = 0, end = 0; + end = n & ~(block - 1); + __m256 threshold = _mm256_set1_ps(a); + + for (i = 0; i < end; i += block) { + _mm256_storeu_ps(y + i, _mm256_max_ps(_mm256_loadu_ps(x + i), threshold)); + } + + for (; i < n; i++) { + y[i] = x[i] < a ? a : x[i]; + } +#else + vec_clip(n, a, x, y); +#endif +} + +template +inline void vec_add_bias(const int n, const T a, const T* x, T* y) { + for (int i = 0; i < n; ++i) { + y[i] = x[i] + a; + } +} + +template <> +inline void vec_add_bias(const int n, + const float a, + const float* x, + float* y) { +#ifdef __AVX__ + constexpr int block = YMM_FLOAT_BLOCK; + if (n < block) { + vec_add_bias(n, a, x, y); + return; + } + const int rest = n % block; + const int end = n - rest; + int i = 0; + __m256 bias = _mm256_set1_ps(a); + __m256 tmp; +#define MOVE_ONE_STEP \ + tmp = _mm256_loadu_ps(x + i); \ + tmp = _mm256_add_ps(tmp, bias); \ + _mm256_storeu_ps(y + i, tmp) + for (i = 0; i < end; i += block) { + MOVE_ONE_STEP; + } +#undef MOVE_ONE_STEP + if (rest == 0) { + return; + } + // can not continue move step if src and dst are inplace + for (i = n - rest; i < n; ++i) { + y[i] = x[i] + a; + } +#else + vec_add_bias(n, a, x, y); +#endif +} + +template <> +inline void vec_add_bias(const int n, + const float a, + const float* x, + float* y) { + vec_add_bias(n, a, x, y); +} + +template <> +inline void vec_add_bias(const int n, + const float a, + const float* x, + float* y) { + // TODO(TJ): enable me + vec_add_bias(n, a, x, y); +} + +template +inline void vec_identity(const int n, const T* x, T* y) { + // do nothing + return; +} + +template +inline void vec_sigmoid(const int n, const T* x, T* y) { + const T min = SIGMOID_THRESHOLD_MIN; + const T max = SIGMOID_THRESHOLD_MAX; + for (int i = 0; i < n; ++i) { + y[i] = (x[i] < min) ? min : ((x[i] > max) ? 
max : x[i]); + y[i] = static_cast(0) - y[i]; + } + vec_exp(n, y, y); + for (int i = 0; i < n; ++i) { + y[i] = static_cast(1) / (static_cast(1) + y[i]); + } +} + +template <> +inline void vec_sigmoid(const int n, + const float* x, + float* y) { +#ifdef __AVX__ + constexpr int block = YMM_FLOAT_BLOCK; + if (n < block) { + vec_sigmoid(n, x, y); + return; + } + const int rest = n % block; + const int end = n - rest; + int i = 0; + __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); + __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); + __m256 zeros = _mm256_setzero_ps(); + __m256 tmp; +#define MOVE_ONE_STEP \ + tmp = _mm256_loadu_ps(x + i); \ + tmp = _mm256_max_ps(tmp, min); \ + tmp = _mm256_min_ps(tmp, max); \ + tmp = _mm256_sub_ps(zeros, tmp); \ + _mm256_storeu_ps(y + i, tmp) + for (i = 0; i < end; i += block) { + MOVE_ONE_STEP; + } +#undef MOVE_ONE_STEP + if (rest != 0) { + // can not continue move step since the src and dst address could be equal + const float xmin = SIGMOID_THRESHOLD_MIN; + const float xmax = SIGMOID_THRESHOLD_MAX; + for (i = n - rest; i < n; ++i) { + y[i] = 0.f - ((x[i] < xmin) ? xmin : ((x[i] > xmax) ? xmax : x[i])); + } + } + + vec_exp(n, y, y); + + __m256 ones = _mm256_set1_ps(1.0f); +#define MOVE_ONE_STEP \ + tmp = _mm256_loadu_ps(y + i); \ + tmp = _mm256_add_ps(ones, tmp); \ + tmp = _mm256_div_ps(ones, tmp); \ + _mm256_storeu_ps(y + i, tmp) + for (i = 0; i < end; i += block) { + MOVE_ONE_STEP; + } +#undef MOVE_ONE_STEP + if (rest == 0) { + return; + } + // can not continue move step + for (i = n - rest; i < n; ++i) { + y[i] = 1.f / (1.f + y[i]); + } +#else + vec_sigmoid(n, x, y); +#endif +} + +template <> +inline void vec_sigmoid(const int n, + const float* x, + float* y) { + vec_sigmoid(n, x, y); +} + +template <> +inline void vec_sigmoid(const int n, + const float* x, + float* y) { + // TODO(TJ): enable me + vec_sigmoid(n, x, y); +} + +template +inline void vec_tanh(const int n, const T* x, T* y) { + vec_scal(n, static_cast(2), x, y); + vec_sigmoid(n, y, y); + vec_scal(n, static_cast(2), y); + vec_add_bias(n, static_cast(-1), y, y); +} + +// TODO(TJ): make relu clip +template +inline void vec_relu(const int n, const T* x, T* y) { + for (int i = 0; i < n; ++i) { + y[i] = x[i] > 0 ? 
x[i] : 0; + } +} + +template <> +inline void vec_relu(const int n, + const float* x, + float* y) { +#ifdef __AVX__ + constexpr int block = YMM_FLOAT_BLOCK; + if (n < block * 4) { + vec_relu(n, x, y); + return; + } + + const int rest = n % block; + const int end = n - rest; + int i = 0; + __m256 zeros = _mm256_setzero_ps(); + __m256 tmp; +#define MOVE_ONE_STEP \ + tmp = _mm256_loadu_ps(x + i); \ + tmp = _mm256_max_ps(tmp, zeros); \ + _mm256_storeu_ps(y + i, tmp) + for (i = 0; i < end; i += block) { + MOVE_ONE_STEP; + } + if (rest == 0) { + return; + } + i = n - block; + MOVE_ONE_STEP; +#undef MOVE_ONE_STEP + +#else + vec_relu(n, x, y); +#endif +} + +template <> +inline void vec_relu(const int n, + const float* x, + float* y) { + vec_relu(n, x, y); +} + +template <> +inline void vec_relu(const int n, + const float* x, + float* y) { + // TODO(TJ): enable me + vec_relu(n, x, y); +} + +// TODO(TJ): optimize double of sigmoid, tanh and relu if necessary + +template +class VecActivations { + public: + std::function operator()( + const std::string& type) { + if (type == "sigmoid") { + return vec_sigmoid; + } else if (type == "relu") { + return vec_relu; + } else if (type == "tanh") { + return vec_tanh; + } else if (type == "identity" || type == "") { + return vec_identity; + } + PADDLE_THROW(phi::errors::InvalidArgument( + "Expected type should be one of sigmod, relu, tanh, identity. But got " + "not support type: %s.", + type)); + } +}; + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/tests/kernels/CMakeLists.txt b/paddle/phi/tests/kernels/CMakeLists.txt index c92e10f8dd7..317dcce92c8 100644 --- a/paddle/phi/tests/kernels/CMakeLists.txt +++ b/paddle/phi/tests/kernels/CMakeLists.txt @@ -22,3 +22,5 @@ endif() if(WITH_ROCM) hip_test(test_math_function_gpu SRCS test_math_function.cu DEPS math_function) endif() + +cc_test(test_cpu_vec SRCS test_cpu_vec.cc DEPS blas cpu_info) diff --git a/paddle/fluid/operators/math/cpu_vec_test.cc b/paddle/phi/tests/kernels/test_cpu_vec.cc similarity index 75% rename from paddle/fluid/operators/math/cpu_vec_test.cc rename to paddle/phi/tests/kernels/test_cpu_vec.cc index 859afec3781..271143f9f6f 100644 --- a/paddle/fluid/operators/math/cpu_vec_test.cc +++ b/paddle/phi/tests/kernels/test_cpu_vec.cc @@ -18,7 +18,10 @@ limitations under the License. 
*/ #include "glog/logging.h" #include "gtest/gtest.h" -#include "paddle/fluid/operators/math/cpu_vec.h" +#include "paddle/phi/kernels/funcs/cpu_vec.h" + +namespace phi { +namespace tests { inline double GetCurrentUS() { struct timeval time; @@ -62,7 +65,9 @@ void ref_relu(const int n, const T* x, T* y) { } template -void RandomVec(const int n, T* a, const T lower = static_cast(-20.f), +void RandomVec(const int n, + T* a, + const T lower = static_cast(-20.f), const T upper = static_cast(20.f)) { static unsigned int seed = 100; std::mt19937 rng(seed++); @@ -73,7 +78,8 @@ void RandomVec(const int n, T* a, const T lower = static_cast(-20.f), } template -void TestAndBench(const int n, std::function tgt, +void TestAndBench(const int n, + std::function tgt, std::function ref) { std::vector x(n); std::vector ytgt(n), yref(n); @@ -101,47 +107,48 @@ void TestAndBench(const int n, std::function tgt, TEST(CpuVecTest, sigmoid) { namespace platform = paddle::platform; - using namespace paddle::operators::math; // NOLINT + using namespace phi::funcs; // NOLINT for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { TestAndBench(sz, vec_sigmoid, ref_sigmoid); - TestAndBench(sz, vec_sigmoid, - ref_sigmoid); - TestAndBench(sz, vec_sigmoid, - ref_sigmoid); - TestAndBench(sz, vec_sigmoid, - ref_sigmoid); + TestAndBench( + sz, vec_sigmoid, ref_sigmoid); + TestAndBench( + sz, vec_sigmoid, ref_sigmoid); + TestAndBench( + sz, vec_sigmoid, ref_sigmoid); } TestAndBench(30, vec_sigmoid, ref_sigmoid); } TEST(CpuVecTest, tanh) { namespace platform = paddle::platform; - using namespace paddle::operators::math; // NOLINT + using namespace phi::funcs; // NOLINT for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { TestAndBench(sz, vec_tanh, ref_tanh); TestAndBench(sz, vec_tanh, ref_tanh); TestAndBench(sz, vec_tanh, ref_tanh); - TestAndBench(sz, vec_tanh, - ref_tanh); + TestAndBench( + sz, vec_tanh, ref_tanh); } TestAndBench(30, vec_tanh, ref_tanh); } TEST(CpuVecTest, relu) { namespace platform = paddle::platform; - using namespace paddle::operators::math; // NOLINT + using namespace phi::funcs; // NOLINT for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { TestAndBench(sz, vec_relu, ref_relu); TestAndBench(sz, vec_relu, ref_relu); TestAndBench(sz, vec_relu, ref_relu); - TestAndBench(sz, vec_relu, - ref_relu); + TestAndBench( + sz, vec_relu, ref_relu); } TestAndBench(30, vec_relu, ref_relu); } template -void compare_sum(size_t n, std::function tgt, +void compare_sum(size_t n, + std::function tgt, std::function ref) { std::vector x(n); T ytgt_data, yref_data; @@ -155,18 +162,19 @@ void compare_sum(size_t n, std::function tgt, TEST(CpuVecTest, vec_sum) { namespace platform = paddle::platform; - using namespace paddle::operators::math; // NOLINT + using namespace phi::funcs; // NOLINT for (size_t sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { compare_sum(sz, vec_sum, vec_sum); - compare_sum(sz, vec_sum, - vec_sum); + compare_sum( + sz, vec_sum, vec_sum); } compare_sum(30U, vec_sum, vec_sum); } template void compare_clip( - size_t n, T threshold, + size_t n, + T threshold, std::function tgt, std::function ref) { std::vector x(n); @@ -185,20 +193,23 @@ void compare_clip( TEST(CpuVecTest, vec_clip) { namespace platform = paddle::platform; - using namespace paddle::operators::math; // NOLINT + using namespace phi::funcs; // NOLINT for (size_t sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { - compare_clip(sz, -4.f, vec_clip, - vec_clip); - compare_clip(sz, -1.1f, vec_clip, + compare_clip( + sz, -4.f, vec_clip, vec_clip); + 
compare_clip(sz, + -1.1f, + vec_clip, vec_clip); } - compare_clip(30U, 1.0, vec_clip, - vec_clip); + compare_clip( + 30U, 1.0, vec_clip, vec_clip); } template void compare_mul( - size_t n, std::function tgt, + size_t n, + std::function tgt, std::function ref) { std::vector x(n), y(n); std::vector ztgt(n), zref(n); @@ -220,18 +231,19 @@ void compare_mul( TEST(CpuVecTest, vec_mul) { namespace platform = paddle::platform; - using namespace paddle::operators::math; // NOLINT + using namespace phi::funcs; // NOLINT for (size_t sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { compare_mul(sz, vec_mul, vec_mul); - compare_mul(sz, vec_mul, - vec_mul); + compare_mul( + sz, vec_mul, vec_mul); } compare_mul(30U, vec_mul, vec_mul); } template void compare_mul_reduce( - size_t n, std::function tgt, + size_t n, + std::function tgt, std::function ref) { std::vector x(n), y(n); T ztgt_data, zref_data; @@ -249,19 +261,21 @@ void compare_mul_reduce( TEST(CpuVecTest, vec_mul_reduce) { namespace platform = paddle::platform; - using namespace paddle::operators::math; // NOLINT + using namespace phi::funcs; // NOLINT for (size_t sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { - compare_mul_reduce(sz, vec_mul_reduce, - vec_mul_reduce); - compare_mul_reduce(sz, vec_mul_reduce, + compare_mul_reduce( + sz, vec_mul_reduce, vec_mul_reduce); + compare_mul_reduce(sz, + vec_mul_reduce, vec_mul_reduce); } - compare_mul_reduce(30U, vec_mul_reduce, - vec_mul_reduce); + compare_mul_reduce( + 30U, vec_mul_reduce, vec_mul_reduce); } template -void TestInplace(const int n, std::function tgt, +void TestInplace(const int n, + std::function tgt, std::function ref) { std::vector x(n); std::vector ytgt(n), yref(n); @@ -283,22 +297,22 @@ void TestInplace(const int n, std::function tgt, TEST(CpuVecTest, inplace_sigmoid) { namespace platform = paddle::platform; - using namespace paddle::operators::math; // NOLINT + using namespace phi::funcs; // NOLINT for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { TestInplace(sz, vec_sigmoid, ref_sigmoid); - TestInplace(sz, vec_sigmoid, - ref_sigmoid); - TestInplace(sz, vec_sigmoid, - ref_sigmoid); - TestInplace(sz, vec_sigmoid, - ref_sigmoid); + TestInplace( + sz, vec_sigmoid, ref_sigmoid); + TestInplace( + sz, vec_sigmoid, ref_sigmoid); + TestInplace( + sz, vec_sigmoid, ref_sigmoid); } TestInplace(30, vec_sigmoid, ref_sigmoid); } TEST(CpuVecTest, inplace_tanh) { namespace platform = paddle::platform; - using namespace paddle::operators::math; // NOLINT + using namespace phi::funcs; // NOLINT for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { TestInplace(sz, vec_tanh, ref_tanh); TestInplace(sz, vec_tanh, ref_tanh); @@ -310,7 +324,7 @@ TEST(CpuVecTest, inplace_tanh) { TEST(CpuVecTest, inplace_relu) { namespace platform = paddle::platform; - using namespace paddle::operators::math; // NOLINT + using namespace phi::funcs; // NOLINT for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { TestInplace(sz, vec_relu, ref_relu); TestInplace(sz, vec_relu, ref_relu); @@ -319,3 +333,5 @@ TEST(CpuVecTest, inplace_relu) { } TestInplace(30, vec_relu, ref_relu); } +} // namespace tests +} // namespace phi -- GitLab
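
Note for downstream users: the practical effect of this patch is an include-path and namespace change only — consumers now include "paddle/phi/kernels/funcs/cpu_vec.h" and call the helpers under phi::funcs instead of paddle::operators::math. Below is a minimal, hedged sketch of a consumer after the move, mirroring the vec_softmax helper and the MayIUse(avx) activation dispatch shown in attention_lstm_op.cc above. The template argument lists of the vec_* helpers are not visible in this rendering of the diff, so the <T> / <T, isa> forms used here are assumptions based on the call sites; softmax_inplace is an illustrative name, not part of the patch.

// Sketch of a consumer after this patch (assumptions noted above).
#include <algorithm>
#include <functional>
#include <vector>

#include "paddle/fluid/platform/cpu_info.h"    // paddle::platform::MayIUse, avx, isa_any
#include "paddle/phi/kernels/funcs/cpu_vec.h"  // phi::funcs::vec_* after the move

namespace platform = paddle::platform;

// Numerically stable softmax built from the relocated primitives,
// following the vec_softmax helper in attention_lstm_op.cc.
template <typename T>
void softmax_inplace(const int n, const T* x, T* y) {
  T scalar = x[0];
  for (int i = 1; i < n; ++i) scalar = std::max(scalar, x[i]);
  phi::funcs::vec_add_bias<T>(n, -scalar, x, y);            // subtract the max
  phi::funcs::vec_exp<T>(n, y, y);                          // exponentiate
  T sum = T(0);
  for (int i = 0; i < n; ++i) sum += y[i];
  phi::funcs::vec_scal<T>(n, static_cast<T>(1) / sum, y);   // normalize
}

int main() {
  std::vector<float> logits = {1.f, 2.f, 3.f, 4.f}, probs(4);
  softmax_inplace<float>(4, logits.data(), probs.data());

  // Activation selection keeps the same runtime ISA dispatch as before the move.
  std::function<void(const int, const float*, float*)> act;
  if (platform::MayIUse(platform::avx)) {
    phi::funcs::VecActivations<float, platform::avx> act_functor;
    act = act_functor("tanh");
  } else {
    phi::funcs::VecActivations<float, platform::isa_any> act_functor;
    act = act_functor("tanh");
  }
  act(4, probs.data(), probs.data());
  return 0;
}

The only difference from pre-patch code is the header path and the phi::funcs:: qualifier; the functor selection and helper signatures are otherwise unchanged by this move.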