diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index 9b943440a869e213db4ed761cfe7c508bc5e94ae..75fc59125f21901b6781315eb3d7dba36b7f11f2 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -231,10 +231,10 @@ use lstm_x_t as input and compute as standard LSTM. template inline void bias_relu(const int n, const T* x, const T* bias, T* y) { if (bias) { - math::vec_add_bias(n, *bias, x, y); - math::vec_relu(n, y, y); + math::vec_add_bias(n, *bias, x, y); + math::vec_relu(n, y, y); } else { - math::vec_relu(n, x, y); + math::vec_relu(n, x, y); } } @@ -245,8 +245,8 @@ inline void vec_softmax(const int n, const T* x, T* y) { for (int i = 1; i < n; ++i) { scalar = scalar < x[i] ? x[i] : scalar; } - math::vec_add_bias(n, -scalar, x, y); // sub - math::vec_exp(n, y, y); // exp + math::vec_add_bias(n, -scalar, x, y); // sub + math::vec_exp(n, y, y); // exp // sum scalar = T(0); for (int i = 0; i < n; ++i) { @@ -302,13 +302,13 @@ class AttentionLSTMKernel : public framework::OpKernel { auto& act_gate_str = ctx.Attr("gate_activation"); auto& act_cell_str = ctx.Attr("cell_activation"); auto& act_cand_str = ctx.Attr("candidate_activation"); - if (platform::jit::MayIUse(platform::jit::avx)) { - math::VecActivations act_functor; + if (platform::MayIUse(platform::avx)) { + math::VecActivations act_functor; act_gate = act_functor(act_gate_str); act_cell = act_functor(act_cell_str); act_cand = act_functor(act_cand_str); } else { - math::VecActivations act_functor; + math::VecActivations act_functor; act_gate = act_functor(act_gate_str); act_cell = act_functor(act_cell_str); act_cand = act_functor(act_cand_str); diff --git a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc index 6d463538d232e1a38f845e7abc3786568ca3bb21..1eb6523a2dfb358490a07bf1b806d5638442a4d5 100644 --- a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc +++ b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc @@ -217,13 +217,13 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel { auto& act_gate_str = ctx.Attr("gate_activation"); \ auto& act_cell_str = ctx.Attr("cell_activation"); \ auto& act_cand_str = ctx.Attr("candidate_activation"); \ - if (platform::jit::MayIUse(platform::jit::avx)) { \ - math::VecActivations act_functor; \ + if (platform::MayIUse(platform::avx)) { \ + math::VecActivations act_functor; \ act_gate = act_functor(act_gate_str); \ act_cell = act_functor(act_cell_str); \ act_cand = act_functor(act_cand_str); \ } else { \ - math::VecActivations act_functor; \ + math::VecActivations act_functor; \ act_gate = act_functor(act_gate_str); \ act_cell = act_functor(act_cell_str); \ act_cand = act_functor(act_cand_str); \ diff --git a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc index 288b56fc2485138b20c5b53af3e950f1c1886ba5..17ed9771d074cf7ae8c6735e4cb859139503a0af 100644 --- a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc @@ -151,11 +151,11 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel { std::function fc_act; auto& fc_act_str = ctx.Attr("fc_activation"); - if (platform::jit::MayIUse(platform::jit::avx)) { - math::VecActivations act_functor; + if (platform::MayIUse(platform::avx)) { + math::VecActivations act_functor; fc_act = act_functor(fc_act_str); } else { - math::VecActivations act_functor; + math::VecActivations act_functor; fc_act = act_functor(fc_act_str); } diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h index 7d81aee596934308763002d440f52400f45b5f20..e1e4d168db3ca594b44396a6e30c5bfc03483eaf 100644 --- a/paddle/fluid/operators/math/cpu_vec.h +++ b/paddle/fluid/operators/math/cpu_vec.h @@ -77,7 +77,7 @@ inline void vec_scal(const int n, const double a, double* x) { #endif // MKL scal only support inplace, choose this if src and dst are not equal -template +template inline void vec_scal(const int n, const T a, const T* x, T* y) { for (int i = 0; i < n; ++i) { y[i] = a * x[i]; @@ -85,12 +85,12 @@ inline void vec_scal(const int n, const T a, const T* x, T* y) { } template <> -inline void vec_scal(const int n, const float a, - const float* x, float* y) { +inline void vec_scal(const int n, const float a, + const float* x, float* y) { #ifdef __AVX__ constexpr int block = YMM_FLOAT_BLOCK; if (n < block) { - vec_scal(n, a, x, y); + vec_scal(n, a, x, y); return; } const int rest = n % block; @@ -114,24 +114,24 @@ inline void vec_scal(const int n, const float a, y[i] = a * x[i]; } #else - vec_scal(n, a, x, y); + vec_scal(n, a, x, y); #endif } template <> -inline void vec_scal(const int n, const float a, - const float* x, float* y) { - vec_scal(n, a, x, y); +inline void vec_scal(const int n, const float a, + const float* x, float* y) { + vec_scal(n, a, x, y); } template <> -inline void vec_scal(const int n, const float a, - const float* x, float* y) { +inline void vec_scal(const int n, const float a, + const float* x, float* y) { // TODO(TJ): enable me - vec_scal(n, a, x, y); + vec_scal(n, a, x, y); } -template +template inline void vec_bias_sub(const int n, const T a, const T* x, T* y) { for (int i = 0; i < n; ++i) { y[i] = a - x[i]; @@ -139,12 +139,12 @@ inline void vec_bias_sub(const int n, const T a, const T* x, T* y) { } template <> -inline void vec_bias_sub(const int n, const float a, - const float* x, float* y) { +inline void vec_bias_sub(const int n, const float a, + const float* x, float* y) { #ifdef __AVX__ constexpr int block = YMM_FLOAT_BLOCK; if (n < block) { - vec_bias_sub(n, a, x, y); + vec_bias_sub(n, a, x, y); return; } const int rest = n % block; @@ -168,27 +168,25 @@ inline void vec_bias_sub(const int n, const float a, y[i] = a - x[i]; } #else - vec_bias_sub(n, a, x, y); + vec_bias_sub(n, a, x, y); #endif } template <> -inline void vec_bias_sub(const int n, const float a, - const float* x, float* y) { - vec_bias_sub(n, a, x, y); +inline void vec_bias_sub(const int n, const float a, + const float* x, float* y) { + vec_bias_sub(n, a, x, y); } template <> -inline void vec_bias_sub(const int n, - const float a, - const float* x, - float* y) { +inline void vec_bias_sub(const int n, const float a, + const float* x, float* y) { // TODO(TJ): enable me - vec_bias_sub(n, a, x, y); + vec_bias_sub(n, a, x, y); } // out = x*y + (1-x)*z -template +template inline void vec_cross(const int n, const T* x, const T* y, const T* z, T* out) { for (int i = 0; i < n; ++i) { out[i] = x[i] * y[i] + (static_cast(1) - x[i]) * z[i]; @@ -196,13 +194,13 @@ inline void vec_cross(const int n, const T* x, const T* y, const T* z, T* out) { } template <> -inline void vec_cross(const int n, const float* x, - const float* y, const float* z, - float* out) { +inline void vec_cross(const int n, const float* x, + const float* y, const float* z, + float* out) { #ifdef __AVX__ constexpr int block = YMM_FLOAT_BLOCK; if (n < block) { - vec_cross(n, x, y, z, out); + vec_cross(n, x, y, z, out); return; } const int rest = n % block; @@ -228,25 +226,26 @@ inline void vec_cross(const int n, const float* x, out[i] = x[i] * y[i] + (1.f - x[i]) * z[i]; } #else - vec_cross(n, x, y, z, out); + vec_cross(n, x, y, z, out); #endif } template <> -inline void vec_cross(const int n, const float* x, - const float* y, - const float* z, float* out) { - vec_cross(n, x, y, z, out); +inline void vec_cross(const int n, const float* x, + const float* y, const float* z, + float* out) { + vec_cross(n, x, y, z, out); } template <> -inline void vec_cross( - const int n, const float* x, const float* y, const float* z, float* out) { +inline void vec_cross(const int n, const float* x, + const float* y, const float* z, + float* out) { // TODO(TJ): enable me - vec_cross(n, x, y, z, out); + vec_cross(n, x, y, z, out); } -template +template inline void vec_add_bias(const int n, const T a, const T* x, T* y) { for (int i = 0; i < n; ++i) { y[i] = x[i] + a; @@ -254,12 +253,12 @@ inline void vec_add_bias(const int n, const T a, const T* x, T* y) { } template <> -inline void vec_add_bias(const int n, const float a, - const float* x, float* y) { +inline void vec_add_bias(const int n, const float a, + const float* x, float* y) { #ifdef __AVX__ constexpr int block = YMM_FLOAT_BLOCK; if (n < block) { - vec_add_bias(n, a, x, y); + vec_add_bias(n, a, x, y); return; } const int rest = n % block; @@ -283,32 +282,30 @@ inline void vec_add_bias(const int n, const float a, y[i] = x[i] + a; } #else - vec_add_bias(n, a, x, y); + vec_add_bias(n, a, x, y); #endif } template <> -inline void vec_add_bias(const int n, const float a, - const float* x, float* y) { - vec_add_bias(n, a, x, y); +inline void vec_add_bias(const int n, const float a, + const float* x, float* y) { + vec_add_bias(n, a, x, y); } template <> -inline void vec_add_bias(const int n, - const float a, - const float* x, - float* y) { +inline void vec_add_bias(const int n, const float a, + const float* x, float* y) { // TODO(TJ): enable me - vec_add_bias(n, a, x, y); + vec_add_bias(n, a, x, y); } -template +template inline void vec_identity(const int n, const T* x, T* y) { // do nothing return; } -template +template inline void vec_sigmoid(const int n, const T* x, T* y) { const T min = SIGMOID_THRESHOLD_MIN; const T max = SIGMOID_THRESHOLD_MAX; @@ -323,12 +320,12 @@ inline void vec_sigmoid(const int n, const T* x, T* y) { } template <> -inline void vec_sigmoid(const int n, const float* x, - float* y) { +inline void vec_sigmoid(const int n, const float* x, + float* y) { #ifdef __AVX__ constexpr int block = YMM_FLOAT_BLOCK; if (n < block) { - vec_sigmoid(n, x, y); + vec_sigmoid(n, x, y); return; } const int rest = n % block; @@ -377,25 +374,24 @@ inline void vec_sigmoid(const int n, const float* x, y[i] = 1.f / (1.f + y[i]); } #else - vec_sigmoid(n, x, y); + vec_sigmoid(n, x, y); #endif } template <> -inline void vec_sigmoid(const int n, const float* x, - float* y) { - vec_sigmoid(n, x, y); +inline void vec_sigmoid(const int n, const float* x, + float* y) { + vec_sigmoid(n, x, y); } template <> -inline void vec_sigmoid(const int n, - const float* x, - float* y) { +inline void vec_sigmoid(const int n, const float* x, + float* y) { // TODO(TJ): enable me - vec_sigmoid(n, x, y); + vec_sigmoid(n, x, y); } -template +template inline void vec_tanh(const int n, const T* x, T* y) { vec_scal(n, static_cast(2), x, y); vec_sigmoid(n, y, y); @@ -404,7 +400,7 @@ inline void vec_tanh(const int n, const T* x, T* y) { } // TODO(TJ): make relu clip -template +template inline void vec_relu(const int n, const T* x, T* y) { for (int i = 0; i < n; ++i) { y[i] = x[i] > 0 ? x[i] : 0; @@ -412,12 +408,12 @@ inline void vec_relu(const int n, const T* x, T* y) { } template <> -inline void vec_relu(const int n, const float* x, - float* y) { +inline void vec_relu(const int n, const float* x, + float* y) { #ifdef __AVX__ constexpr int block = YMM_FLOAT_BLOCK; if (n < block * 4) { - vec_relu(n, x, y); + vec_relu(n, x, y); return; } @@ -441,26 +437,26 @@ inline void vec_relu(const int n, const float* x, #undef MOVE_ONE_STEP #else - vec_relu(n, x, y); + vec_relu(n, x, y); #endif } template <> -inline void vec_relu(const int n, const float* x, - float* y) { - vec_relu(n, x, y); +inline void vec_relu(const int n, const float* x, + float* y) { + vec_relu(n, x, y); } template <> -inline void vec_relu(const int n, const float* x, - float* y) { +inline void vec_relu(const int n, const float* x, + float* y) { // TODO(TJ): enable me - vec_relu(n, x, y); + vec_relu(n, x, y); } // TODO(TJ): optimize double of sigmoid, tanh and relu if necessary -template +template class VecActivations { public: std::function operator()( diff --git a/paddle/fluid/operators/math/cpu_vec_test.cc b/paddle/fluid/operators/math/cpu_vec_test.cc index c37fa291a259550a3cb6d4f3dd9d5a415c3a2130..28eb9cadc9d4258bf4f8f71a06e029531e448014 100644 --- a/paddle/fluid/operators/math/cpu_vec_test.cc +++ b/paddle/fluid/operators/math/cpu_vec_test.cc @@ -104,38 +104,42 @@ void TestAndBench(const int n, std::function tgt, } TEST(CpuVecTest, sigmoid) { - namespace jit = paddle::platform::jit; + namespace platform = paddle::platform; using namespace paddle::operators::math; // NOLINT for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { TestAndBench(sz, vec_sigmoid, ref_sigmoid); - TestAndBench(sz, vec_sigmoid, ref_sigmoid); - TestAndBench(sz, vec_sigmoid, ref_sigmoid); - TestAndBench(sz, vec_sigmoid, + TestAndBench(sz, vec_sigmoid, + ref_sigmoid); + TestAndBench(sz, vec_sigmoid, + ref_sigmoid); + TestAndBench(sz, vec_sigmoid, ref_sigmoid); } TestAndBench(30, vec_sigmoid, ref_sigmoid); } TEST(CpuVecTest, tanh) { - namespace jit = paddle::platform::jit; + namespace platform = paddle::platform; using namespace paddle::operators::math; // NOLINT for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { TestAndBench(sz, vec_tanh, ref_tanh); - TestAndBench(sz, vec_tanh, ref_tanh); - TestAndBench(sz, vec_tanh, ref_tanh); - TestAndBench(sz, vec_tanh, ref_tanh); + TestAndBench(sz, vec_tanh, ref_tanh); + TestAndBench(sz, vec_tanh, ref_tanh); + TestAndBench(sz, vec_tanh, + ref_tanh); } TestAndBench(30, vec_tanh, ref_tanh); } TEST(CpuVecTest, relu) { - namespace jit = paddle::platform::jit; + namespace platform = paddle::platform; using namespace paddle::operators::math; // NOLINT for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { TestAndBench(sz, vec_relu, ref_relu); - TestAndBench(sz, vec_relu, ref_relu); - TestAndBench(sz, vec_relu, ref_relu); - TestAndBench(sz, vec_relu, ref_relu); + TestAndBench(sz, vec_relu, ref_relu); + TestAndBench(sz, vec_relu, ref_relu); + TestAndBench(sz, vec_relu, + ref_relu); } TestAndBench(30, vec_relu, ref_relu); } @@ -162,38 +166,40 @@ void TestInplace(const int n, std::function tgt, } TEST(CpuVecTest, inplace_sigmoid) { - namespace jit = paddle::platform::jit; + namespace platform = paddle::platform; using namespace paddle::operators::math; // NOLINT for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { TestInplace(sz, vec_sigmoid, ref_sigmoid); - TestInplace(sz, vec_sigmoid, ref_sigmoid); - TestInplace(sz, vec_sigmoid, ref_sigmoid); - TestInplace(sz, vec_sigmoid, + TestInplace(sz, vec_sigmoid, + ref_sigmoid); + TestInplace(sz, vec_sigmoid, + ref_sigmoid); + TestInplace(sz, vec_sigmoid, ref_sigmoid); } TestInplace(30, vec_sigmoid, ref_sigmoid); } TEST(CpuVecTest, inplace_tanh) { - namespace jit = paddle::platform::jit; + namespace platform = paddle::platform; using namespace paddle::operators::math; // NOLINT for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { TestInplace(sz, vec_tanh, ref_tanh); - TestInplace(sz, vec_tanh, ref_tanh); - TestInplace(sz, vec_tanh, ref_tanh); - TestInplace(sz, vec_tanh, ref_tanh); + TestInplace(sz, vec_tanh, ref_tanh); + TestInplace(sz, vec_tanh, ref_tanh); + TestInplace(sz, vec_tanh, ref_tanh); } TestInplace(30, vec_tanh, ref_tanh); } TEST(CpuVecTest, inplace_relu) { - namespace jit = paddle::platform::jit; + namespace platform = paddle::platform; using namespace paddle::operators::math; // NOLINT for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) { TestInplace(sz, vec_relu, ref_relu); - TestInplace(sz, vec_relu, ref_relu); - TestInplace(sz, vec_relu, ref_relu); - TestInplace(sz, vec_relu, ref_relu); + TestInplace(sz, vec_relu, ref_relu); + TestInplace(sz, vec_relu, ref_relu); + TestInplace(sz, vec_relu, ref_relu); } TestInplace(30, vec_relu, ref_relu); } diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc index 52cbdf685dee651cbc1490dc6faacb8680004c89..78d0c3e8808f0daf6a18d2217664e965773b95ff 100644 --- a/paddle/fluid/operators/math/jit_code.cc +++ b/paddle/fluid/operators/math/jit_code.cc @@ -22,7 +22,7 @@ namespace math { namespace jitkernel { namespace gen { -using namespace platform::jit; // NOLINT +using namespace platform; // NOLINT bool VXXJitCode::init(int d, int scalar_index) { // It's not necessary to use avx512 since it would slow down the frequency diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index a9214621295a7740b804b26c02d216dd5118d8bb..e2b4761435594fdc952ff5dba5b5fa4f4aa98e6c 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -179,7 +179,7 @@ class VActJitCode : public JitCode { template void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12, // NOLINT int fy_idx = 13, int mask_idx = 14, int tmp_idx = 15) { - using namespace platform::jit; // NOLINT + using namespace platform; // NOLINT // check all idx can not equal JMM jmm_src = JMM(src_idx); JMM jmm_fx = JMM(fx_idx); diff --git a/paddle/fluid/operators/math/jit_gen.cc b/paddle/fluid/operators/math/jit_gen.cc index 6af39518ed926554c8c839bba701d3827923dba0..5c6672928e8c03ccb1920bd828f785084e422fc2 100644 --- a/paddle/fluid/operators/math/jit_gen.cc +++ b/paddle/fluid/operators/math/jit_gen.cc @@ -36,7 +36,7 @@ void JitCode::preCode() { for (int i = 0; i < num_g_abi_regs; ++i) { push(Xbyak::Reg64(g_abi_regs[i])); } - if (platform::jit::MayIUse(platform::jit::avx512f)) { + if (platform::MayIUse(platform::avx512f)) { mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt); } } diff --git a/paddle/fluid/operators/math/jit_kernel.cc b/paddle/fluid/operators/math/jit_kernel.cc index 68b708b345334bc63b5e2e88c308d20ca6378e6b..118696ba47986e2dbf97535333c9817b7c264a54 100644 --- a/paddle/fluid/operators/math/jit_kernel.cc +++ b/paddle/fluid/operators/math/jit_kernel.cc @@ -21,8 +21,6 @@ namespace operators { namespace math { namespace jitkernel { -namespace jit = platform::jit; - KernelPool& KernelPool::Instance() { static thread_local KernelPool g_jit_kernels; return g_jit_kernels; diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index a0f93fd8e7eb7d81211724a6991a681e2a0ed9ce..8cf588efba52314650bfd376b95b10e6d4336b2e 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -30,7 +30,6 @@ namespace paddle { namespace operators { namespace math { namespace jitkernel { -namespace jit = platform::jit; #ifdef PADDLE_WITH_MKLML template @@ -125,7 +124,7 @@ bool VMulKernelImpl::useJIT(int d) { #ifdef PADDLE_WITH_MKLML template <> bool VMulKernelImpl::useMKL(int d) { - return jit::MayIUse(jit::avx512f) && d > 512; + return platform::MayIUse(platform::avx512f) && d > 512; } template <> diff --git a/paddle/fluid/operators/math/jit_kernel_crf_decode.cc b/paddle/fluid/operators/math/jit_kernel_crf_decode.cc index 4d26b81948238f18b097f535534fcfe9049b93c3..eeb305a88bee8f0e21b205684d24b19ca4631f65 100644 --- a/paddle/fluid/operators/math/jit_kernel_crf_decode.cc +++ b/paddle/fluid/operators/math/jit_kernel_crf_decode.cc @@ -25,10 +25,8 @@ namespace operators { namespace math { namespace jitkernel { -namespace jit = platform::jit; - /* CRF Decode JitKernel */ -template +template class CRFDecodeKernelImpl : public CRFDecodeKernel { public: explicit CRFDecodeKernelImpl(int tag_num) : CRFDecodeKernel() { @@ -101,7 +99,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel { #define INTRIAVX_FLOAT(block) \ template <> \ - CRFDecodeKernelImpl::CRFDecodeKernelImpl( \ + CRFDecodeKernelImpl::CRFDecodeKernelImpl( \ int tag_num) \ : CRFDecodeKernel() { \ this->num_ = tag_num; \ @@ -109,7 +107,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel { this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \ } \ template <> \ - void CRFDecodeKernelImpl::Compute( \ + void CRFDecodeKernelImpl::Compute( \ const int seq_len, const float* x, const float* w, float* alpha, \ int* track) const { \ INIT_ALPHA(YMM_FLOAT_BLOCK) \ @@ -204,7 +202,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel { #define INTRIAVX512_FLOAT(block) \ template <> \ - CRFDecodeKernelImpl::CRFDecodeKernelImpl( \ + CRFDecodeKernelImpl::CRFDecodeKernelImpl( \ int tag_num) \ : CRFDecodeKernel() { \ this->num_ = tag_num; \ @@ -212,7 +210,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel { this->rest_ = this->num_ % ZMM_FLOAT_BLOCK; \ } \ template <> \ - void CRFDecodeKernelImpl::Compute( \ + void CRFDecodeKernelImpl::Compute( \ const int seq_len, const float* x, const float* w, float* alpha, \ int* track) const { \ INIT_ALPHA(ZMM_FLOAT_BLOCK) \ @@ -270,14 +268,14 @@ INTRIAVX_FLOAT(kEQ16); INTRIAVX_FLOAT(kGT16); #endif #ifdef __AVX2__ -INTRIAVX2_FLOAT(jit::avx2, kEQ8); -INTRIAVX2_FLOAT(jit::avx2, kGT8LT16); -INTRIAVX2_FLOAT(jit::avx2, kEQ16); -INTRIAVX2_FLOAT(jit::avx2, kGT16); +INTRIAVX2_FLOAT(platform::avx2, kEQ8); +INTRIAVX2_FLOAT(platform::avx2, kGT8LT16); +INTRIAVX2_FLOAT(platform::avx2, kEQ16); +INTRIAVX2_FLOAT(platform::avx2, kGT16); #endif #ifdef __AVX512F__ -INTRIAVX2_FLOAT(jit::avx512f, kEQ8); -INTRIAVX2_FLOAT(jit::avx512f, kGT8LT16); +INTRIAVX2_FLOAT(platform::avx512f, kEQ8); +INTRIAVX2_FLOAT(platform::avx512f, kGT8LT16); INTRIAVX512_FLOAT(kEQ16); INTRIAVX512_FLOAT(kGT16); #endif diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc index 686f3dd9836cb9192088771753065c6add639620..7945cfb253a61b7d1191c39537254126e2bb85dd 100644 --- a/paddle/fluid/operators/math/jit_kernel_exp.cc +++ b/paddle/fluid/operators/math/jit_kernel_exp.cc @@ -29,7 +29,6 @@ namespace paddle { namespace operators { namespace math { namespace jitkernel { -namespace jit = platform::jit; #ifdef PADDLE_WITH_MKLML // try to use MKL to speedup diff --git a/paddle/fluid/operators/math/jit_kernel_layer_norm.cc b/paddle/fluid/operators/math/jit_kernel_layer_norm.cc index 49904e6e8c7cd346bcbfb67c3a7574118b36e058..fead13ebadcd131afafc308740cdd39b1c53bc08 100644 --- a/paddle/fluid/operators/math/jit_kernel_layer_norm.cc +++ b/paddle/fluid/operators/math/jit_kernel_layer_norm.cc @@ -22,10 +22,8 @@ namespace operators { namespace math { namespace jitkernel { -namespace jit = platform::jit; - /* Layer Norm JitKernel */ -template +template class LayerNormKernelImpl : public LayerNormKernel { public: explicit LayerNormKernelImpl(int right) : LayerNormKernel() { @@ -90,7 +88,7 @@ class LayerNormKernelImpl : public LayerNormKernel { this->end_ = this->num_ - this->rest_; \ } \ template <> \ - void LayerNormKernelImpl::Compute( \ + void LayerNormKernelImpl::Compute( \ float* x, float* out, float* mean, float* var, const float* scale, \ const float* bias, int height, const float epsilon) const { \ __m256 sum; \ @@ -219,16 +217,16 @@ class LayerNormKernelImpl : public LayerNormKernel { } #ifdef __AVX__ -INTRIAVX_FLOAT(jit::avx, kEQ8); -INTRIAVX_FLOAT(jit::avx, kGT8LT16); -INTRIAVX_FLOAT(jit::avx, kEQ16); -INTRIAVX_FLOAT(jit::avx, kGT16); +INTRIAVX_FLOAT(platform::avx, kEQ8); +INTRIAVX_FLOAT(platform::avx, kGT8LT16); +INTRIAVX_FLOAT(platform::avx, kEQ16); +INTRIAVX_FLOAT(platform::avx, kGT16); #endif #ifdef __AVX2__ -INTRIAVX_FLOAT(jit::avx2, kEQ8); -INTRIAVX_FLOAT(jit::avx2, kGT8LT16); -INTRIAVX_FLOAT(jit::avx2, kEQ16); -INTRIAVX_FLOAT(jit::avx2, kGT16); +INTRIAVX_FLOAT(platform::avx2, kEQ8); +INTRIAVX_FLOAT(platform::avx2, kGT8LT16); +INTRIAVX_FLOAT(platform::avx2, kEQ16); +INTRIAVX_FLOAT(platform::avx2, kGT16); #endif #undef INTRIAVX_FLOAT diff --git a/paddle/fluid/operators/math/jit_kernel_macro.h b/paddle/fluid/operators/math/jit_kernel_macro.h index 5a3efd979f803d396a5084c199b1d71b88a77126..4dba3b56810794cb4839d26386ae77a8f4507977 100644 --- a/paddle/fluid/operators/math/jit_kernel_macro.h +++ b/paddle/fluid/operators/math/jit_kernel_macro.h @@ -92,7 +92,6 @@ namespace jitkernel { JITKERNEL_DECLARE, JITKERNEL_FIND_KEY, \ JITKERNEL_IMPL) -namespace jit = platform::jit; // TODO(TJ): below defines are deprecated, would be remove recently #define SEARCH_BLOCK(macro_, ker, dtype, isa) \ if (d < YMM_FLOAT_BLOCK) { \ @@ -107,15 +106,15 @@ namespace jit = platform::jit; macro_(ker, dtype, isa, kGT16); \ } -#define SEARCH_ISA_BLOCK(macro_, ker, dtype) \ - if (jit::MayIUse(jit::avx512f)) { \ - SEARCH_BLOCK(macro_, ker, dtype, jit::avx512f); \ - } else if (jit::MayIUse(jit::avx2)) { \ - SEARCH_BLOCK(macro_, ker, dtype, jit::avx2); \ - } else if (jit::MayIUse(jit::avx)) { \ - SEARCH_BLOCK(macro_, ker, dtype, jit::avx); \ - } else { \ - SEARCH_BLOCK(macro_, ker, dtype, jit::isa_any); \ +#define SEARCH_ISA_BLOCK(macro_, ker, dtype) \ + if (platform::MayIUse(platform::avx512f)) { \ + SEARCH_BLOCK(macro_, ker, dtype, platform::avx512f); \ + } else if (platform::MayIUse(platform::avx2)) { \ + SEARCH_BLOCK(macro_, ker, dtype, platform::avx2); \ + } else if (platform::MayIUse(platform::avx)) { \ + SEARCH_BLOCK(macro_, ker, dtype, platform::avx); \ + } else { \ + SEARCH_BLOCK(macro_, ker, dtype, platform::isa_any); \ } #define JITKERNEL_KEY(ker_key, dtype_key) \ @@ -156,10 +155,10 @@ namespace jit = platform::jit; marco_declare, macro_key, macro_impl) #define FOR_EACH_ISA(macro_, block) \ - macro_(jit::avx512f, block); \ - macro_(jit::avx2, block); \ - macro_(jit::avx, block); \ - macro_(jit::isa_any, block) + macro_(platform::avx512f, block); \ + macro_(platform::avx2, block); \ + macro_(platform::avx, block); \ + macro_(platform::isa_any, block) #define FOR_EACH_BLOCK(macro_, isa) \ macro_(isa, kLT8); \ @@ -168,11 +167,11 @@ namespace jit = platform::jit; macro_(isa, kEQ16); \ macro_(isa, kGT16) -#define FOR_EACH_ISA_BLOCK(macro_) \ - FOR_EACH_BLOCK(macro_, jit::avx512f); \ - FOR_EACH_BLOCK(macro_, jit::avx2); \ - FOR_EACH_BLOCK(macro_, jit::avx); \ - FOR_EACH_BLOCK(macro_, jit::isa_any) +#define FOR_EACH_ISA_BLOCK(macro_) \ + FOR_EACH_BLOCK(macro_, platform::avx512f); \ + FOR_EACH_BLOCK(macro_, platform::avx2); \ + FOR_EACH_BLOCK(macro_, platform::avx); \ + FOR_EACH_BLOCK(macro_, platform::isa_any) } // namespace jitkernel } // namespace math diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index ed86a47e159cacd4f5572e22c7633f725aaeb516..19f7bd8909499c12fd5bee4db0d0a71a632e7f19 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -705,7 +705,7 @@ TEST(JitKernel, pool) { jit::lstm_attr_t attr(frame_size, act_gate, act_cand, act_cell, false); // empty call it to avoid unknown flag 'use_pinned_memory' on Mac - paddle::platform::jit::MayIUse(paddle::platform::jit::avx); + paddle::platform::MayIUse(paddle::platform::avx); const auto& plstm1 = jit::KernelPool::Instance() .template Get, const jit::lstm_attr_t&>(attr); diff --git a/paddle/fluid/platform/cpu_info.cc b/paddle/fluid/platform/cpu_info.cc index d466f28d1ea0a8327f8d7a45c3e55c5aacd61544..f9a32bfa4c15261ba6b79fc4efd3a1961f7c6d4d 100644 --- a/paddle/fluid/platform/cpu_info.cc +++ b/paddle/fluid/platform/cpu_info.cc @@ -123,7 +123,6 @@ size_t CUDAPinnedMaxChunkSize() { return CUDAPinnedMaxAllocSize() / 256; } -namespace jit { #ifdef PADDLE_WITH_XBYAK static Xbyak::util::Cpu cpu; bool MayIUse(const cpu_isa_t cpu_isa) { @@ -165,6 +164,5 @@ bool MayIUse(const cpu_isa_t cpu_isa) { } #endif -} // namespace jit } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/cpu_info.h b/paddle/fluid/platform/cpu_info.h index fd31ef77b46d5b5b641983a0421da31914c87c18..55dba545ff133b1c219ee58f6d1bb2d2130d1a59 100644 --- a/paddle/fluid/platform/cpu_info.h +++ b/paddle/fluid/platform/cpu_info.h @@ -39,7 +39,6 @@ size_t CUDAPinnedMinChunkSize(); //! Get the maximum chunk size for buddy allocator. size_t CUDAPinnedMaxChunkSize(); -namespace jit { typedef enum { isa_any, sse42, @@ -55,7 +54,5 @@ typedef enum { // May I use some instruction bool MayIUse(const cpu_isa_t cpu_isa); -} // namespace jit - } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index 51b46450e419c6641b415a58b7551f7dd56627b2..0d10d82d74a2011b1b2bc088fe88cbfdb49600b8 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -116,7 +116,7 @@ void InitDevices(bool init_p2p, const std::vector devices) { #endif #if !defined(_WIN32) && !defined(__APPLE__) && !defined(__OSX__) - if (platform::jit::MayIUse(platform::jit::avx)) { + if (platform::MayIUse(platform::avx)) { #ifndef __AVX__ LOG(WARNING) << "AVX is available, Please re-compile on local machine"; #endif @@ -131,10 +131,10 @@ void InitDevices(bool init_p2p, const std::vector devices) { " version or compile from source code." #ifdef __AVX512F__ - if (!platform::jit::MayIUse(platform::jit::avx512f)) { - if (platform::jit::MayIUse(platform::jit::avx2)) { + if (!platform::MayIUse(platform::avx512f)) { + if (platform::MayIUse(platform::avx2)) { AVX_GUIDE(AVX512, AVX2); - } else if (platform::jit::MayIUse(platform::jit::avx)) { + } else if (platform::MayIUse(platform::avx)) { AVX_GUIDE(AVX512, AVX); } else { AVX_GUIDE(AVX512, NonAVX); @@ -143,8 +143,8 @@ void InitDevices(bool init_p2p, const std::vector devices) { #endif #ifdef __AVX2__ - if (!platform::jit::MayIUse(platform::jit::avx2)) { - if (platform::jit::MayIUse(platform::jit::avx)) { + if (!platform::MayIUse(platform::avx2)) { + if (platform::MayIUse(platform::avx)) { AVX_GUIDE(AVX2, AVX); } else { AVX_GUIDE(AVX2, NonAVX); @@ -153,7 +153,7 @@ void InitDevices(bool init_p2p, const std::vector devices) { #endif #ifdef __AVX__ - if (!platform::jit::MayIUse(platform::jit::avx)) { + if (!platform::MayIUse(platform::avx)) { AVX_GUIDE(AVX, NonAVX); } #endif