From 923ad5dc7aa01eb37a0fc9fd42e30033315776c3 Mon Sep 17 00:00:00 2001
From: PuQing
Date: Mon, 28 Nov 2022 13:53:13 +0800
Subject: [PATCH] add cpu_info.h (#48403)

---
 paddle/fluid/platform/cpu_info.h              |  15 +-
 paddle/phi/backends/cpu/cpu_info.h            |  56 ++++
 paddle/phi/kernels/funcs/cpu_vec.h            | 249 +++++++++---------
 .../funcs/detail/activation_functions.h       |   2 +-
 paddle/phi/kernels/funcs/detail/avx_mathfun.h |   2 +-
 .../kernels/sparse/cpu/softmax_grad_kernel.cc |  10 +-
 .../phi/kernels/sparse/cpu/softmax_kernel.cc  |  10 +-
 paddle/phi/tests/kernels/test_cpu_vec.cc      |  96 +++----
 8 files changed, 239 insertions(+), 201 deletions(-)
 create mode 100644 paddle/phi/backends/cpu/cpu_info.h

diff --git a/paddle/fluid/platform/cpu_info.h b/paddle/fluid/platform/cpu_info.h
index 29dc0a15aae..6c5bf68227a 100644
--- a/paddle/fluid/platform/cpu_info.h
+++ b/paddle/fluid/platform/cpu_info.h
@@ -50,6 +50,8 @@ inline void cpuid(int reg[4], int x) {
 #endif
 #endif
 
+#include "paddle/phi/backends/cpu/cpu_info.h"
+
 namespace paddle {
 namespace platform {
 
@@ -82,18 +84,7 @@ size_t NPUPinnedMinChunkSize();
 //! Get the maximum chunk size for buddy allocator.
 size_t NPUPinnedMaxChunkSize();
 
-typedef enum {
-  isa_any,
-  sse42,
-  avx,
-  avx2,
-  avx512f,
-  avx512_core,
-  avx512_core_vnni,
-  avx512_mic,
-  avx512_mic_4ops,
-  avx512_bf16,
-} cpu_isa_t;  // Instruction set architecture
+using namespace phi::backends::cpu;  // NOLINT
 
 // May I use some instruction
 bool MayIUse(const cpu_isa_t cpu_isa);
diff --git a/paddle/phi/backends/cpu/cpu_info.h b/paddle/phi/backends/cpu/cpu_info.h
new file mode 100644
index 00000000000..cf7c6d95057
--- /dev/null
+++ b/paddle/phi/backends/cpu/cpu_info.h
@@ -0,0 +1,56 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <stddef.h>
+
+#ifdef _WIN32
+#if defined(__AVX2__)
+#include <immintrin.h>  // avx2
+#elif defined(__AVX__)
+#include <intrin.h>  // avx
+#endif  // AVX
+#else   // WIN32
+#ifdef __AVX__
+#include <immintrin.h>
+#endif
+#endif  // WIN32
+
+#if defined(_WIN32)
+#define ALIGN32_BEG __declspec(align(32))
+#define ALIGN32_END
+#else
+#define ALIGN32_BEG
+#define ALIGN32_END __attribute__((aligned(32)))
+#endif  // _WIN32
+
+namespace phi {
+namespace backends {
+namespace cpu {
+typedef enum {
+  isa_any,
+  sse42,
+  avx,
+  avx2,
+  avx512f,
+  avx512_core,
+  avx512_core_vnni,
+  avx512_mic,
+  avx512_mic_4ops,
+  avx512_bf16,
+} cpu_isa_t;  // Instruction set architecture
+}  // namespace cpu
+}  // namespace backends
+}  // namespace phi
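Context for reviewers: the enum only moves; the runtime probe `MayIUse(const cpu_isa_t)` stays declared in the fluid header, and the `using namespace phi::backends::cpu;` alias keeps names like `paddle::platform::avx` valid for existing callers. A minimal sketch of how the two pieces combine for runtime dispatch — the `scale_rows` helper is illustrative only, not part of this patch:

```cpp
// Illustrative sketch: branch at runtime into the compile-time ISA
// specializations. Assumes paddle::platform::MayIUse(cpu_isa_t), which this
// patch leaves declared in paddle/fluid/platform/cpu_info.h.
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/phi/kernels/funcs/cpu_vec.h"

void scale_rows(int n, float a, const float* x, float* y) {
  if (paddle::platform::MayIUse(phi::backends::cpu::avx)) {
    // AVX specialization: processes YMM_FLOAT_BLOCK (8) floats per step.
    phi::funcs::vec_scal<float, phi::backends::cpu::avx>(n, a, x, y);
  } else {
    // Portable scalar fallback.
    phi::funcs::vec_scal<float, phi::backends::cpu::isa_any>(n, a, x, y);
  }
}
```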
diff --git a/paddle/phi/kernels/funcs/cpu_vec.h b/paddle/phi/kernels/funcs/cpu_vec.h
index 2719f86f522..e7dc6535c15 100644
--- a/paddle/phi/kernels/funcs/cpu_vec.h
+++ b/paddle/phi/kernels/funcs/cpu_vec.h
@@ -17,7 +17,7 @@ limitations under the License. */
 #include <cmath>
 #include <string>
 
-#include "paddle/fluid/platform/cpu_info.h"
+#include "paddle/phi/backends/cpu/cpu_info.h"
 #include "paddle/phi/core/enforce.h"
 
 #ifdef PADDLE_WITH_MKLML
@@ -81,8 +81,7 @@ inline void vec_scal<double>(const int n, const double a, double* x) {
 #endif
 
 // MKL scal only support inplace, choose this if src and dst are not equal
-template <typename T,
-          paddle::platform::cpu_isa_t isa = paddle::platform::isa_any>
+template <typename T, backends::cpu::cpu_isa_t isa = backends::cpu::isa_any>
 inline void vec_scal(const int n, const T a, const T* x, T* y) {
   for (int i = 0; i < n; ++i) {
     y[i] = a * x[i];
@@ -90,14 +89,14 @@ inline void vec_scal(const int n, const T a, const T* x, T* y) {
 }
 
 template <>
-inline void vec_scal<float, paddle::platform::avx>(const int n,
-                                                   const float a,
-                                                   const float* x,
-                                                   float* y) {
+inline void vec_scal<float, backends::cpu::avx>(const int n,
+                                                const float a,
+                                                const float* x,
+                                                float* y) {
 #ifdef __AVX__
   constexpr int block = YMM_FLOAT_BLOCK;
   if (n < block) {
-    vec_scal<float, paddle::platform::isa_any>(n, a, x, y);
+    vec_scal<float, backends::cpu::isa_any>(n, a, x, y);
     return;
   }
   const int rest = n % block;
@@ -121,29 +120,28 @@ inline void vec_scal<float, paddle::platform::avx>(const int n,
     y[i] = a * x[i];
   }
 #else
-  vec_scal<float, paddle::platform::isa_any>(n, a, x, y);
+  vec_scal<float, backends::cpu::isa_any>(n, a, x, y);
 #endif
 }
 
 template <>
-inline void vec_scal<float, paddle::platform::avx2>(const int n,
-                                                    const float a,
-                                                    const float* x,
-                                                    float* y) {
-  vec_scal<float, paddle::platform::avx>(n, a, x, y);
+inline void vec_scal<float, backends::cpu::avx2>(const int n,
+                                                 const float a,
+                                                 const float* x,
+                                                 float* y) {
+  vec_scal<float, backends::cpu::avx>(n, a, x, y);
 }
 
 template <>
-inline void vec_scal<float, paddle::platform::avx512f>(const int n,
-                                                       const float a,
-                                                       const float* x,
-                                                       float* y) {
+inline void vec_scal<float, backends::cpu::avx512f>(const int n,
+                                                    const float a,
+                                                    const float* x,
+                                                    float* y) {
   // TODO(TJ): enable me
-  vec_scal<float, paddle::platform::avx2>(n, a, x, y);
+  vec_scal<float, backends::cpu::avx2>(n, a, x, y);
 }
 
-template <typename T,
-          paddle::platform::cpu_isa_t isa = paddle::platform::isa_any>
+template <typename T, backends::cpu::cpu_isa_t isa = backends::cpu::isa_any>
 inline void vec_sum(const size_t n, const T* x, T* s) {
   s[0] = x[0];
   for (size_t i = 1; i < n; ++i) {
@@ -152,13 +150,13 @@ inline void vec_sum(const size_t n, const T* x, T* s) {
 }
 
 template <>
-inline void vec_sum<float, paddle::platform::avx>(const size_t n,
-                                                  const float* x,
-                                                  float* s) {
+inline void vec_sum<float, backends::cpu::avx>(const size_t n,
+                                               const float* x,
+                                               float* s) {
 #ifdef __AVX__
   constexpr unsigned int block = YMM_FLOAT_BLOCK;
   if (n < block) {
-    vec_sum<float, paddle::platform::isa_any>(n, x, s);
+    vec_sum<float, backends::cpu::isa_any>(n, x, s);
     return;
   }
 
@@ -182,12 +180,11 @@ inline void vec_sum<float, paddle::platform::avx>(const size_t n,
     s[0] += x[i];
   }
 #else
-  vec_sum<float, paddle::platform::isa_any>(n, x, s);
+  vec_sum<float, backends::cpu::isa_any>(n, x, s);
 #endif
 }
 
-template <typename T,
-          paddle::platform::cpu_isa_t isa = paddle::platform::isa_any>
+template <typename T, backends::cpu::cpu_isa_t isa = backends::cpu::isa_any>
 inline void vec_mul(const size_t n, const T* x, const T* y, T* z) {
   for (size_t i = 0; i < n; ++i) {
     z[i] = x[i] * y[i];
@@ -195,14 +192,14 @@ inline void vec_mul(const size_t n, const T* x, const T* y, T* z) {
 }
 
 template <>
-inline void vec_mul<float, paddle::platform::avx>(const size_t n,
-                                                  const float* x,
-                                                  const float* y,
-                                                  float* z) {
+inline void vec_mul<float, backends::cpu::avx>(const size_t n,
+                                               const float* x,
+                                               const float* y,
+                                               float* z) {
 #ifdef __AVX__
   constexpr unsigned int block = YMM_FLOAT_BLOCK;
   if (n < block) {
-    vec_mul<float, paddle::platform::isa_any>(n, x, y, z);
+    vec_mul<float, backends::cpu::isa_any>(n, x, y, z);
     return;
   }
 
@@ -217,12 +214,11 @@ inline void vec_mul<float, paddle::platform::avx>(const size_t n,
     z[i] = x[i] * y[i];
   }
 #else
-  vec_mul<float, paddle::platform::isa_any>(n, x, y, z);
+  vec_mul<float, backends::cpu::isa_any>(n, x, y, z);
 #endif
 }
 
-template <typename T,
-          paddle::platform::cpu_isa_t isa = paddle::platform::isa_any>
+template <typename T, backends::cpu::cpu_isa_t isa = backends::cpu::isa_any>
 inline void vec_mul_reduce(const size_t n, const T* x, const T* y, T* z) {
   z[0] = x[0] * y[0];
   for (size_t i = 1; i < n; ++i) {
@@ -231,14 +227,14 @@ inline void vec_mul_reduce(const size_t n, const T* x, const T* y, T* z) {
 }
 
 template <>
-inline void vec_mul_reduce<float, paddle::platform::avx>(const size_t n,
-                                                         const float* x,
-                                                         const float* y,
-                                                         float* z) {
+inline void vec_mul_reduce<float, backends::cpu::avx>(const size_t n,
+                                                      const float* x,
+                                                      const float* y,
+                                                      float* z) {
 #ifdef __AVX__
   constexpr unsigned int block = YMM_FLOAT_BLOCK;
   if (n < block) {
-    vec_mul_reduce<float, paddle::platform::isa_any>(n, x, y, z);
+    vec_mul_reduce<float, backends::cpu::isa_any>(n, x, y, z);
     return;
   }
 
@@ -262,12 +258,11 @@ inline void vec_mul_reduce<float, paddle::platform::avx>(const size_t n,
     z[0] += x[i] * y[i];
  }
 #else
-  vec_mul_reduce<float, paddle::platform::isa_any>(n, x, y, z);
+  vec_mul_reduce<float, backends::cpu::isa_any>(n, x, y, z);
 #endif
 }
 
-template <typename T,
-          paddle::platform::cpu_isa_t isa = paddle::platform::isa_any>
+template <typename T, backends::cpu::cpu_isa_t isa = backends::cpu::isa_any>
 inline void vec_bias_sub(const int n, const T a, const T* x, T* y) {
   for (int i = 0; i < n; ++i) {
     y[i] = a - x[i];
@@ -275,14 +270,14 @@ inline void vec_bias_sub(const int n, const T a, const T* x, T* y) {
 }
 
 template <>
-inline void vec_bias_sub<float, paddle::platform::avx>(const int n,
-                                                       const float a,
-                                                       const float* x,
-                                                       float* y) {
+inline void vec_bias_sub<float, backends::cpu::avx>(const int n,
+                                                    const float a,
+                                                    const float* x,
+                                                    float* y) {
 #ifdef __AVX__
   constexpr int block = YMM_FLOAT_BLOCK;
   if (n < block) {
-    vec_bias_sub<float, paddle::platform::isa_any>(n, a, x, y);
+    vec_bias_sub<float, backends::cpu::isa_any>(n, a, x, y);
     return;
   }
   const int rest = n % block;
@@ -306,30 +301,29 @@ inline void vec_bias_sub<float, paddle::platform::avx>(const int n,
     y[i] = a - x[i];
   }
 #else
-  vec_bias_sub<float, paddle::platform::isa_any>(n, a, x, y);
+  vec_bias_sub<float, backends::cpu::isa_any>(n, a, x, y);
 #endif
 }
 
 template <>
-inline void vec_bias_sub<float, paddle::platform::avx2>(const int n,
-                                                        const float a,
-                                                        const float* x,
-                                                        float* y) {
-  vec_bias_sub<float, paddle::platform::avx>(n, a, x, y);
+inline void vec_bias_sub<float, backends::cpu::avx2>(const int n,
+                                                     const float a,
+                                                     const float* x,
+                                                     float* y) {
+  vec_bias_sub<float, backends::cpu::avx>(n, a, x, y);
 }
 
 template <>
-inline void vec_bias_sub<float, paddle::platform::avx512f>(const int n,
-                                                           const float a,
-                                                           const float* x,
-                                                           float* y) {
+inline void vec_bias_sub<float, backends::cpu::avx512f>(const int n,
+                                                        const float a,
+                                                        const float* x,
+                                                        float* y) {
   // TODO(TJ): enable me
-  vec_bias_sub<float, paddle::platform::avx2>(n, a, x, y);
+  vec_bias_sub<float, backends::cpu::avx2>(n, a, x, y);
 }
 
 // out = x*y + (1-x)*z
-template <typename T,
-          paddle::platform::cpu_isa_t isa = paddle::platform::isa_any>
+template <typename T, backends::cpu::cpu_isa_t isa = backends::cpu::isa_any>
 inline void vec_cross(const int n, const T* x, const T* y, const T* z, T* out) {
   for (int i = 0; i < n; ++i) {
     out[i] = x[i] * y[i] + (static_cast<T>(1) - x[i]) * z[i];
@@ -337,12 +331,12 @@ inline void vec_cross(const int n, const T* x, const T* y, const T* z, T* out) {
 }
 
 template <>
-inline void vec_cross<float, paddle::platform::avx>(
+inline void vec_cross<float, backends::cpu::avx>(
     const int n, const float* x, const float* y, const float* z, float* out) {
 #ifdef __AVX__
   constexpr int block = YMM_FLOAT_BLOCK;
   if (n < block) {
-    vec_cross<float, paddle::platform::isa_any>(n, x, y, z, out);
+    vec_cross<float, backends::cpu::isa_any>(n, x, y, z, out);
     return;
   }
   const int rest = n % block;
@@ -368,25 +362,24 @@ inline void vec_cross<float, paddle::platform::avx>(
     out[i] = x[i] * y[i] + (1.f - x[i]) * z[i];
   }
 #else
-  vec_cross<float, paddle::platform::isa_any>(n, x, y, z, out);
+  vec_cross<float, backends::cpu::isa_any>(n, x, y, z, out);
 #endif
 }
 
 template <>
-inline void vec_cross<float, paddle::platform::avx2>(
+inline void vec_cross<float, backends::cpu::avx2>(
     const int n, const float* x, const float* y, const float* z, float* out) {
-  vec_cross<float, paddle::platform::avx>(n, x, y, z, out);
+  vec_cross<float, backends::cpu::avx>(n, x, y, z, out);
 }
 
 template <>
-inline void vec_cross<float, paddle::platform::avx512f>(
+inline void vec_cross<float, backends::cpu::avx512f>(
     const int n, const float* x, const float* y, const float* z, float* out) {
   // TODO(TJ): enable me
-  vec_cross<float, paddle::platform::avx>(n, x, y, z, out);
+  vec_cross<float, backends::cpu::avx>(n, x, y, z, out);
 }
 
-template <typename T,
-          paddle::platform::cpu_isa_t isa = paddle::platform::isa_any>
+template <typename T, backends::cpu::cpu_isa_t isa = backends::cpu::isa_any>
 inline void vec_clip(const size_t n, const T a, const T* x, T* y) {
   for (size_t i = 0; i < n; ++i) {
     y[i] = x[i] < a ? a : x[i];
@@ -394,14 +387,14 @@ inline void vec_clip(const size_t n, const T a, const T* x, T* y) {
 }
 
 template <>
-inline void vec_clip<float, paddle::platform::avx>(const size_t n,
-                                                   const float a,
-                                                   const float* x,
-                                                   float* y) {
+inline void vec_clip<float, backends::cpu::avx>(const size_t n,
+                                                const float a,
+                                                const float* x,
+                                                float* y) {
 #ifdef __AVX__
   constexpr unsigned int block = YMM_FLOAT_BLOCK;
   if (n < block) {
-    vec_clip<float, paddle::platform::isa_any>(n, a, x, y);
+    vec_clip<float, backends::cpu::isa_any>(n, a, x, y);
     return;
   }
 
@@ -417,12 +410,11 @@ inline void vec_clip<float, paddle::platform::avx>(const size_t n,
     y[i] = x[i] < a ? a : x[i];
   }
 #else
-  vec_clip<float, paddle::platform::isa_any>(n, a, x, y);
+  vec_clip<float, backends::cpu::isa_any>(n, a, x, y);
 #endif
 }
 
-template <typename T,
-          paddle::platform::cpu_isa_t isa = paddle::platform::isa_any>
+template <typename T, backends::cpu::cpu_isa_t isa = backends::cpu::isa_any>
 inline void vec_add_bias(const int n, const T a, const T* x, T* y) {
   for (int i = 0; i < n; ++i) {
     y[i] = x[i] + a;
@@ -430,14 +422,14 @@ inline void vec_add_bias(const int n, const T a, const T* x, T* y) {
 }
 
 template <>
-inline void vec_add_bias<float, paddle::platform::avx>(const int n,
-                                                       const float a,
-                                                       const float* x,
-                                                       float* y) {
+inline void vec_add_bias<float, backends::cpu::avx>(const int n,
+                                                    const float a,
+                                                    const float* x,
+                                                    float* y) {
 #ifdef __AVX__
   constexpr int block = YMM_FLOAT_BLOCK;
   if (n < block) {
-    vec_add_bias<float, paddle::platform::isa_any>(n, a, x, y);
+    vec_add_bias<float, backends::cpu::isa_any>(n, a, x, y);
     return;
   }
   const int rest = n % block;
@@ -461,36 +453,34 @@ inline void vec_add_bias<float, paddle::platform::avx>(const int n,
     y[i] = x[i] + a;
   }
 #else
-  vec_add_bias<float, paddle::platform::isa_any>(n, a, x, y);
+  vec_add_bias<float, backends::cpu::isa_any>(n, a, x, y);
 #endif
 }
 
 template <>
-inline void vec_add_bias<float, paddle::platform::avx2>(const int n,
-                                                        const float a,
-                                                        const float* x,
-                                                        float* y) {
-  vec_add_bias<float, paddle::platform::avx>(n, a, x, y);
+inline void vec_add_bias<float, backends::cpu::avx2>(const int n,
+                                                     const float a,
+                                                     const float* x,
+                                                     float* y) {
+  vec_add_bias<float, backends::cpu::avx>(n, a, x, y);
 }
 
 template <>
-inline void vec_add_bias<float, paddle::platform::avx512f>(const int n,
-                                                           const float a,
-                                                           const float* x,
-                                                           float* y) {
+inline void vec_add_bias<float, backends::cpu::avx512f>(const int n,
+                                                        const float a,
+                                                        const float* x,
+                                                        float* y) {
   // TODO(TJ): enable me
-  vec_add_bias<float, paddle::platform::avx2>(n, a, x, y);
+  vec_add_bias<float, backends::cpu::avx2>(n, a, x, y);
 }
 
-template <typename T,
-          paddle::platform::cpu_isa_t isa = paddle::platform::isa_any>
+template <typename T, backends::cpu::cpu_isa_t isa = backends::cpu::isa_any>
 inline void vec_identity(const int n, const T* x, T* y) {
   // do nothing
   return;
 }
 
-template <typename T,
-          paddle::platform::cpu_isa_t isa = paddle::platform::isa_any>
+template <typename T, backends::cpu::cpu_isa_t isa = backends::cpu::isa_any>
 inline void vec_sigmoid(const int n, const T* x, T* y) {
   const T min = SIGMOID_THRESHOLD_MIN;
   const T max = SIGMOID_THRESHOLD_MAX;
@@ -505,13 +495,13 @@ inline void vec_sigmoid(const int n, const T* x, T* y) {
 }
 
 template <>
-inline void vec_sigmoid<float, paddle::platform::avx>(const int n,
-                                                      const float* x,
-                                                      float* y) {
+inline void vec_sigmoid<float, backends::cpu::avx>(const int n,
+                                                   const float* x,
+                                                   float* y) {
 #ifdef __AVX__
   constexpr int block = YMM_FLOAT_BLOCK;
   if (n < block) {
-    vec_sigmoid<float, paddle::platform::isa_any>(n, x, y);
+    vec_sigmoid<float, backends::cpu::isa_any>(n, x, y);
     return;
   }
   const int rest = n % block;
@@ -560,27 +550,26 @@ inline void vec_sigmoid<float, paddle::platform::avx>(const int n,
     y[i] = 1.f / (1.f + y[i]);
   }
 #else
-  vec_sigmoid<float, paddle::platform::isa_any>(n, x, y);
+  vec_sigmoid<float, backends::cpu::isa_any>(n, x, y);
 #endif
 }
 
 template <>
-inline void vec_sigmoid<float, paddle::platform::avx2>(const int n,
-                                                       const float* x,
-                                                       float* y) {
-  vec_sigmoid<float, paddle::platform::avx>(n, x, y);
+inline void vec_sigmoid<float, backends::cpu::avx2>(const int n,
+                                                    const float* x,
+                                                    float* y) {
+  vec_sigmoid<float, backends::cpu::avx>(n, x, y);
 }
 
 template <>
-inline void vec_sigmoid<float, paddle::platform::avx512f>(const int n,
-                                                          const float* x,
-                                                          float* y) {
+inline void vec_sigmoid<float, backends::cpu::avx512f>(const int n,
+                                                       const float* x,
+                                                       float* y) {
   // TODO(TJ): enable me
-  vec_sigmoid<float, paddle::platform::avx2>(n, x, y);
+  vec_sigmoid<float, backends::cpu::avx2>(n, x, y);
 }
 
-template <typename T,
-          paddle::platform::cpu_isa_t isa = paddle::platform::isa_any>
+template <typename T, backends::cpu::cpu_isa_t isa = backends::cpu::isa_any>
 inline void vec_tanh(const int n, const T* x, T* y) {
   vec_scal<T, isa>(n, static_cast<T>(2), x, y);
   vec_sigmoid<T, isa>(n, y, y);
@@ -589,8 +578,7 @@ inline void vec_tanh(const int n, const T* x, T* y) {
 }
 
 // TODO(TJ): make relu clip
-template <typename T,
-          paddle::platform::cpu_isa_t isa = paddle::platform::isa_any>
+template <typename T, backends::cpu::cpu_isa_t isa = backends::cpu::isa_any>
 inline void vec_relu(const int n, const T* x, T* y) {
   for (int i = 0; i < n; ++i) {
     y[i] = x[i] > 0 ? x[i] : 0;
@@ -598,13 +586,13 @@ inline void vec_relu(const int n, const T* x, T* y) {
 }
 
 template <>
-inline void vec_relu<float, paddle::platform::avx>(const int n,
-                                                   const float* x,
-                                                   float* y) {
+inline void vec_relu<float, backends::cpu::avx>(const int n,
+                                                const float* x,
+                                                float* y) {
 #ifdef __AVX__
   constexpr int block = YMM_FLOAT_BLOCK;
   if (n < block * 4) {
-    vec_relu<float, paddle::platform::isa_any>(n, x, y);
+    vec_relu<float, backends::cpu::isa_any>(n, x, y);
     return;
   }
 
@@ -628,29 +616,28 @@ inline void vec_relu<float, paddle::platform::avx>(const int n,
 #undef MOVE_ONE_STEP
 
 #else
-  vec_relu<float, paddle::platform::isa_any>(n, x, y);
+  vec_relu<float, backends::cpu::isa_any>(n, x, y);
 #endif
 }
 
 template <>
-inline void vec_relu<float, paddle::platform::avx2>(const int n,
-                                                    const float* x,
-                                                    float* y) {
-  vec_relu<float, paddle::platform::avx>(n, x, y);
+inline void vec_relu<float, backends::cpu::avx2>(const int n,
+                                                 const float* x,
+                                                 float* y) {
+  vec_relu<float, backends::cpu::avx>(n, x, y);
 }
 
 template <>
-inline void vec_relu<float, paddle::platform::avx512f>(const int n,
-                                                       const float* x,
-                                                       float* y) {
+inline void vec_relu<float, backends::cpu::avx512f>(const int n,
+                                                    const float* x,
+                                                    float* y) {
   // TODO(TJ): enable me
-  vec_relu<float, paddle::platform::avx2>(n, x, y);
+  vec_relu<float, backends::cpu::avx2>(n, x, y);
 }
 
 // TODO(TJ): optimize double of sigmoid, tanh and relu if necessary
-template <typename T,
-          paddle::platform::cpu_isa_t isa = paddle::platform::isa_any>
+template <typename T, backends::cpu::cpu_isa_t isa = backends::cpu::isa_any>
 class VecActivations {
  public:
  std::function<void(const int, const T*, T*)> operator()(
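The `vec_*` primitives above compose into row-wise routines; the sparse softmax kernel later in this patch chains `vec_add_bias`, `vec_exp`, `vec_sum`, and `vec_scal` per CSR row. A simplified dense sketch of that composition — the `row_softmax` helper is illustrative, not part of this patch:

```cpp
#include <algorithm>

#include "paddle/phi/kernels/funcs/cpu_vec.h"

// Numerically stable softmax over one contiguous row, built from the vec_*
// primitives the same way SoftmaxCsrKernel handles each CSR row.
void row_softmax(int n, const float* x, float* out) {
  // Shift by the row max for numerical stability: out = x - max.
  float max_val = *std::max_element(x, x + n);
  phi::funcs::vec_add_bias<float, phi::backends::cpu::avx>(n, -max_val, x, out);
  // out = exp(out), then normalize by the row sum.
  phi::funcs::vec_exp<float>(n, out, out);
  float sum = 0.f;
  phi::funcs::vec_sum<float, phi::backends::cpu::avx>(n, out, &sum);
  phi::funcs::vec_scal<float, phi::backends::cpu::avx>(n, 1.f / sum, out, out);
}
```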
diff --git a/paddle/phi/kernels/funcs/detail/activation_functions.h b/paddle/phi/kernels/funcs/detail/activation_functions.h
index d41dca33f75..26be2a83280 100644
--- a/paddle/phi/kernels/funcs/detail/activation_functions.h
+++ b/paddle/phi/kernels/funcs/detail/activation_functions.h
@@ -18,7 +18,7 @@ limitations under the License. */
 #include <math.h>
 #include <string>
 
-#include "paddle/fluid/platform/cpu_info.h"
+#include "paddle/phi/backends/cpu/cpu_info.h"
 #include "paddle/phi/core/hostdevice.h"
 
 namespace phi {
diff --git a/paddle/phi/kernels/funcs/detail/avx_mathfun.h b/paddle/phi/kernels/funcs/detail/avx_mathfun.h
index 90017f3c760..d036176e30e 100644
--- a/paddle/phi/kernels/funcs/detail/avx_mathfun.h
+++ b/paddle/phi/kernels/funcs/detail/avx_mathfun.h
@@ -42,7 +42,7 @@
   (this is the zlib license)
 */
 #pragma once
-#include "paddle/fluid/platform/cpu_info.h"
+#include "paddle/phi/backends/cpu/cpu_info.h"
 
 /* __m128 is ugly to write */
 typedef __m256 v8sf;  // vector of 8 float (avx)
diff --git a/paddle/phi/kernels/sparse/cpu/softmax_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/softmax_grad_kernel.cc
index e56fe869705..b932f3acdb9 100644
--- a/paddle/phi/kernels/sparse/cpu/softmax_grad_kernel.cc
+++ b/paddle/phi/kernels/sparse/cpu/softmax_grad_kernel.cc
@@ -14,15 +14,13 @@ limitations under the License. */
 
 #include "paddle/phi/kernels/sparse/softmax_grad_kernel.h"
 
-#include "paddle/fluid/platform/cpu_info.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/backends/cpu/cpu_info.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/visit_type.h"
 #include "paddle/phi/kernels/funcs/cpu_vec.h"
 #include "paddle/phi/kernels/sparse/empty_kernel.h"
 
-namespace plt = paddle::platform;
-
 namespace phi {
 namespace sparse {
 
@@ -72,11 +70,11 @@ void SoftmaxCsrGradKernel(const Context& dev_ctx,
                               out_crows_data[crow_idx]);
 
       T sum = 0;
-      phi::funcs::vec_mul_reduce<T, plt::avx>(
+      phi::funcs::vec_mul_reduce<T, backends::cpu::avx>(
           row_nnz, dout_data, out_data, &sum);
-      phi::funcs::vec_add_bias<T, plt::avx>(
+      phi::funcs::vec_add_bias<T, backends::cpu::avx>(
          row_nnz, static_cast<T>(-1) * sum, dout_data, dx_data);
-      phi::funcs::vec_mul<T, plt::avx>(
+      phi::funcs::vec_mul<T, backends::cpu::avx>(
          row_nnz, dx_data, out_data, dx_data);
 
       out_data = out_data + row_nnz;
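For reference, the three calls above implement the usual softmax backward identity, dx = (dout - sum(dout * out)) * out, for each CSR row. A scalar equivalent of the same computation (illustrative only, not part of the patch):

```cpp
// Scalar reference for one row of SoftmaxCsrGradKernel:
//   sum = dot(dout, out);  dx = (dout - sum) * out
void row_softmax_grad(int n, const float* dout, const float* out, float* dx) {
  float sum = 0.f;
  for (int i = 0; i < n; ++i) sum += dout[i] * out[i];  // vec_mul_reduce
  for (int i = 0; i < n; ++i) dx[i] = dout[i] - sum;    // vec_add_bias(-sum)
  for (int i = 0; i < n; ++i) dx[i] *= out[i];          // vec_mul
}
```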
*/ #include "paddle/phi/kernels/sparse/softmax_kernel.h" -#include "paddle/fluid/platform/cpu_info.h" #include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/backends/cpu/cpu_info.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/funcs/cpu_vec.h" #include "paddle/phi/kernels/sparse/empty_kernel.h" -namespace plt = paddle::platform; - namespace phi { namespace sparse { @@ -70,14 +68,14 @@ void SoftmaxCsrKernel(const Context& dev_ctx, x_crows_data[crow_idx]); row_max_val = *std::max_element(x_data, x_data + row_nnz); - phi::funcs::vec_add_bias( + phi::funcs::vec_add_bias( row_nnz, static_cast(-1) * row_max_val, x_data, out_data); phi::funcs::vec_exp(row_nnz, out_data, out_data); T sum = 0; - phi::funcs::vec_sum(row_nnz, out_data, &sum); - phi::funcs::vec_scal( + phi::funcs::vec_sum(row_nnz, out_data, &sum); + phi::funcs::vec_scal( row_nnz, static_cast(1) / sum, out_data, out_data); x_data = x_data + row_nnz; diff --git a/paddle/phi/tests/kernels/test_cpu_vec.cc b/paddle/phi/tests/kernels/test_cpu_vec.cc index 271143f9f6f..3cb925e8ef4 100644 --- a/paddle/phi/tests/kernels/test_cpu_vec.cc +++ b/paddle/phi/tests/kernels/test_cpu_vec.cc @@ -106,42 +106,43 @@ void TestAndBench(const int n, } TEST(CpuVecTest, sigmoid) { - namespace platform = paddle::platform; using namespace phi::funcs; // NOLINT for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { TestAndBench(sz, vec_sigmoid, ref_sigmoid); TestAndBench( - sz, vec_sigmoid, ref_sigmoid); + sz, vec_sigmoid, ref_sigmoid); TestAndBench( - sz, vec_sigmoid, ref_sigmoid); + sz, vec_sigmoid, ref_sigmoid); TestAndBench( - sz, vec_sigmoid, ref_sigmoid); + sz, vec_sigmoid, ref_sigmoid); } TestAndBench(30, vec_sigmoid, ref_sigmoid); } TEST(CpuVecTest, tanh) { - namespace platform = paddle::platform; using namespace phi::funcs; // NOLINT for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { TestAndBench(sz, vec_tanh, ref_tanh); - TestAndBench(sz, vec_tanh, ref_tanh); - TestAndBench(sz, vec_tanh, ref_tanh); TestAndBench( - sz, vec_tanh, ref_tanh); + sz, vec_tanh, ref_tanh); + TestAndBench( + sz, vec_tanh, ref_tanh); + TestAndBench( + sz, vec_tanh, ref_tanh); } TestAndBench(30, vec_tanh, ref_tanh); } TEST(CpuVecTest, relu) { - namespace platform = paddle::platform; using namespace phi::funcs; // NOLINT for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { TestAndBench(sz, vec_relu, ref_relu); - TestAndBench(sz, vec_relu, ref_relu); - TestAndBench(sz, vec_relu, ref_relu); TestAndBench( - sz, vec_relu, ref_relu); + sz, vec_relu, ref_relu); + TestAndBench( + sz, vec_relu, ref_relu); + TestAndBench( + sz, vec_relu, ref_relu); } TestAndBench(30, vec_relu, ref_relu); } @@ -161,14 +162,16 @@ void compare_sum(size_t n, } TEST(CpuVecTest, vec_sum) { - namespace platform = paddle::platform; using namespace phi::funcs; // NOLINT for (size_t sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { - compare_sum(sz, vec_sum, vec_sum); compare_sum( - sz, vec_sum, vec_sum); + sz, vec_sum, vec_sum); + compare_sum(sz, + vec_sum, + vec_sum); } - compare_sum(30U, vec_sum, vec_sum); + compare_sum( + 30U, vec_sum, vec_sum); } template @@ -192,18 +195,17 @@ void compare_clip( } TEST(CpuVecTest, vec_clip) { - namespace platform = paddle::platform; using namespace phi::funcs; // NOLINT for (size_t sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { compare_clip( - sz, -4.f, vec_clip, vec_clip); + sz, -4.f, vec_clip, vec_clip); compare_clip(sz, -1.1f, - vec_clip, - vec_clip); + vec_clip, + vec_clip); } 
   compare_clip<float>(
-      30U, 1.0, vec_clip<float>, vec_clip<float, platform::avx>);
+      30U, 1.0, vec_clip<float>, vec_clip<float, phi::backends::cpu::avx>);
 }
@@ -230,14 +232,16 @@ void compare_mul(
   }
 }
 
 TEST(CpuVecTest, vec_mul) {
-  namespace platform = paddle::platform;
   using namespace phi::funcs;  // NOLINT
   for (size_t sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
-    compare_mul<float>(sz, vec_mul<float>, vec_mul<float, platform::avx>);
-    compare_mul<double>(
-        sz, vec_mul<double>, vec_mul<double, platform::avx>);
+    compare_mul<float>(
+        sz, vec_mul<float>, vec_mul<float, phi::backends::cpu::avx>);
+    compare_mul<double>(sz,
+                        vec_mul<double>,
+                        vec_mul<double, phi::backends::cpu::avx>);
   }
-  compare_mul<float>(30U, vec_mul<float>, vec_mul<float, platform::avx>);
+  compare_mul<float>(
+      30U, vec_mul<float>, vec_mul<float, phi::backends::cpu::avx>);
 }
@@ -260,17 +264,18 @@ void compare_mul_reduce(
   ASSERT_FLOAT_EQ(ztgt_data[0], zref_data[0]);
 }
 
 TEST(CpuVecTest, vec_mul_reduce) {
-  namespace platform = paddle::platform;
   using namespace phi::funcs;  // NOLINT
   for (size_t sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
-    compare_mul_reduce<float>(
-        sz, vec_mul_reduce<float>, vec_mul_reduce<float, platform::avx>);
+    compare_mul_reduce<float>(sz,
+                              vec_mul_reduce<float>,
+                              vec_mul_reduce<float, phi::backends::cpu::avx>);
     compare_mul_reduce<double>(sz,
                                vec_mul_reduce<double>,
-                               vec_mul_reduce<double, platform::avx>);
+                               vec_mul_reduce<double, phi::backends::cpu::avx>);
   }
-  compare_mul_reduce<float>(
-      30U, vec_mul_reduce<float>, vec_mul_reduce<float, platform::avx>);
+  compare_mul_reduce<float>(30U,
+                            vec_mul_reduce<float>,
+                            vec_mul_reduce<float, phi::backends::cpu::avx>);
 }
@@ -296,40 +301,43 @@ void TestInplace(const int n,
 }
 
 TEST(CpuVecTest, inplace_sigmoid) {
-  namespace platform = paddle::platform;
   using namespace phi::funcs;  // NOLINT
   for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
     TestInplace<float>(sz, vec_sigmoid<float>, ref_sigmoid<float>);
     TestInplace<float>(
-        sz, vec_sigmoid<float, platform::avx>, ref_sigmoid<float>);
+        sz, vec_sigmoid<float, phi::backends::cpu::avx>, ref_sigmoid<float>);
     TestInplace<float>(
-        sz, vec_sigmoid<float, platform::avx2>, ref_sigmoid<float>);
+        sz, vec_sigmoid<float, phi::backends::cpu::avx2>, ref_sigmoid<float>);
     TestInplace<float>(
-        sz, vec_sigmoid<float, platform::avx512f>, ref_sigmoid<float>);
+        sz, vec_sigmoid<float, phi::backends::cpu::avx512f>, ref_sigmoid<float>);
   }
   TestInplace<double>(30, vec_sigmoid<double>, ref_sigmoid<double>);
 }
 
 TEST(CpuVecTest, inplace_tanh) {
-  namespace platform = paddle::platform;
   using namespace phi::funcs;  // NOLINT
   for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
     TestInplace<float>(sz, vec_tanh<float>, ref_tanh<float>);
-    TestInplace<float>(sz, vec_tanh<float, platform::avx>, ref_tanh<float>);
-    TestInplace<float>(sz, vec_tanh<float, platform::avx2>, ref_tanh<float>);
-    TestInplace<float>(sz, vec_tanh<float, platform::avx512f>, ref_tanh<float>);
+    TestInplace<float>(
+        sz, vec_tanh<float, phi::backends::cpu::avx>, ref_tanh<float>);
+    TestInplace<float>(
+        sz, vec_tanh<float, phi::backends::cpu::avx2>, ref_tanh<float>);
+    TestInplace<float>(
+        sz, vec_tanh<float, phi::backends::cpu::avx512f>, ref_tanh<float>);
   }
   TestInplace<double>(30, vec_tanh<double>, ref_tanh<double>);
 }
 
 TEST(CpuVecTest, inplace_relu) {
-  namespace platform = paddle::platform;
   using namespace phi::funcs;  // NOLINT
   for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
     TestInplace<float>(sz, vec_relu<float>, ref_relu<float>);
-    TestInplace<float>(sz, vec_relu<float, platform::avx>, ref_relu<float>);
-    TestInplace<float>(sz, vec_relu<float, platform::avx2>, ref_relu<float>);
-    TestInplace<float>(sz, vec_relu<float, platform::avx512f>, ref_relu<float>);
+    TestInplace<float>(
+        sz, vec_relu<float, phi::backends::cpu::avx>, ref_relu<float>);
+    TestInplace<float>(
+        sz, vec_relu<float, phi::backends::cpu::avx2>, ref_relu<float>);
+    TestInplace<float>(
+        sz, vec_relu<float, phi::backends::cpu::avx512f>, ref_relu<float>);
  }
   TestInplace<double>(30, vec_relu<double>, ref_relu<double>);
 }
-- 
GitLab
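As a usage note, the tests above pit each ISA specialization against the scalar reference. A trimmed, self-contained version of that pattern — a sketch assuming only cpu_vec.h and GoogleTest, not code taken from the file above:

```cpp
#include <gtest/gtest.h>

#include <random>
#include <vector>

#include "paddle/phi/kernels/funcs/cpu_vec.h"

TEST(CpuVecSketch, relu_avx_matches_scalar) {
  std::mt19937 rng(0);
  std::uniform_real_distribution<float> dist(-1.f, 1.f);
  // Cover sizes below, at, and above the 8-float AVX block boundary.
  for (int n : {1, 8, 30, 512}) {
    std::vector<float> x(n), y_any(n), y_avx(n);
    for (auto& v : x) v = dist(rng);
    phi::funcs::vec_relu<float, phi::backends::cpu::isa_any>(
        n, x.data(), y_any.data());
    phi::funcs::vec_relu<float, phi::backends::cpu::avx>(
        n, x.data(), y_avx.data());
    for (int i = 0; i < n; ++i) ASSERT_FLOAT_EQ(y_any[i], y_avx[i]);
  }
}
```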