diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt
index 4678b008d79ae830934bc677dc4845dc9eb68695..9763d14d54aed15cce50110195f6affffc666e09 100644
--- a/paddle/fluid/operators/math/CMakeLists.txt
+++ b/paddle/fluid/operators/math/CMakeLists.txt
@@ -76,5 +76,5 @@ if(WITH_GPU)
 endif()
 cc_test(concat_test SRCS concat_test.cc DEPS concat)
 cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
-cc_library(jit_kernel SRCS jit_kernel.cc DEPS cpu_info cblas)
+cc_library(jit_kernel SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_lstm.cc DEPS cpu_info cblas)
 cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
diff --git a/paddle/fluid/operators/math/jit_kernel.cc b/paddle/fluid/operators/math/jit_kernel.cc
index 71b1ffc6670c15859b81e590fdc46f2e41c567f6..4fd1d1794274e47fc3dc2dbd752b5cf747c23741 100644
--- a/paddle/fluid/operators/math/jit_kernel.cc
+++ b/paddle/fluid/operators/math/jit_kernel.cc
@@ -13,17 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/math/jit_kernel.h"
-#include <functional>
 #include <string>
-#include "paddle/fluid/operators/math/cpu_vec.h"
-
-#ifdef PADDLE_WITH_MKLML
-#include "paddle/fluid/platform/dynload/mklml.h"
-#endif
-
-#ifdef __AVX__
-#include <immintrin.h>
-#endif
 
 namespace paddle {
 namespace operators {
@@ -36,115 +26,6 @@ KernelPool& KernelPool::Instance() {
   static KernelPool g_jit_kernels;
   return g_jit_kernels;
 }
-
-#define SEARCH_BLOCK(src, t, isa)                             \
-  if (d < AVX_FLOAT_BLOCK) {                                  \
-    Compute = src<t, isa, kLT8>;                              \
-  } else if (d == AVX_FLOAT_BLOCK) {                          \
-    Compute = src<t, isa, kEQ8>;                              \
-  } else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) { \
-    Compute = src<t, isa, kGT8LT16>;                          \
-  } else if (d == AVX512_FLOAT_BLOCK) {                       \
-    Compute = src<t, isa, kEQ16>;                             \
-  } else {                                                    \
-    Compute = src<t, isa, kGT16>;                             \
-  }
-
-#define SEARCH_ISA_BLOCK(src, t)        \
-  if (jit::MayIUse(jit::avx512f)) {     \
-    SEARCH_BLOCK(src, t, jit::avx512f); \
-  } else if (jit::MayIUse(jit::avx2)) { \
-    SEARCH_BLOCK(src, t, jit::avx2);    \
-  } else if (jit::MayIUse(jit::avx)) {  \
-    SEARCH_BLOCK(src, t, jit::avx);     \
-  } else {                              \
-    SEARCH_BLOCK(src, t, jit::isa_any); \
-  }
-
-// do not include lt8, eq8, eq16
-#define FOR_EACH_COMMON_BLOCK(macro_, isa) \
-  macro_(isa, kGT8LT16) macro_(isa, kGT16)
-
-#define FOR_EACH_ISA_COMMON_BLOCK(macro_) \
-  FOR_EACH_BLOCK(macro_, jit::avx512f)    \
-  FOR_EACH_BLOCK(macro_, jit::avx2)       \
-  FOR_EACH_BLOCK(macro_, jit::avx)        \
-  FOR_EACH_BLOCK(macro_, jit::any)
-
-#define VMUL_ANY                \
-  for (int i = 0; i < n; ++i) { \
-    z[i] = x[i] * y[i];         \
-  }
-
-template <typename T, platform::jit::cpu_isa_t isa, jit_block>
-static void VMulCompute(const int n, const T* x, const T* y, T* z) {
-  VMUL_ANY
-}
-
-#ifdef PADDLE_USE_MKLML
-#define DEFINE_VMUL_COMPUTE_FLOAT(isa, block)                      \
-  template <>                                                      \
-  void VMulCompute<float, isa, block>(const int n, const float* x, \
-                                      const float* y, float* z) {  \
-    platform::dynload::vsMul(n, x, y, z);                          \
-  }
-
-#define DEFINE_VMUL_COMPUTE_DOUBLE(isa, block)                       \
-  template <>                                                        \
-  void VMulCompute<double, isa, block>(const int n, const double* x, \
-                                       const double* y, float* z) {  \
-    platform::dynload::vdMul(n, x, y, z);                            \
-  }
-
-FOR_EACH_ISA_COMMON_BLOCK(DEFINE_VMUL_COMPUTE_FLOAT)
-FOR_EACH_ISA_COMMON_BLOCK(DEFINE_VMUL_COMPUTE_DOUBLE)
-DEFINE_VMUL_COMPUTE_FLOAT(jit::avx, kLT8)
-DEFINE_VMUL_COMPUTE_FLOAT(jit::avx, kEQ16)
-#endif
-
-// mkl > avx > for, ">" means better
-#ifdef PADDLE_USE_MKLML
-DEFINE_VMUL_COMPUTE_FLOAT(jit::avx, kEQ8)
-#elif defined __AVX__
-template <>
-void VMulCompute<float, jit::avx, kEQ8>(const int n, const float* x,
-                                        const float* y, float* z) {
-  __m256 tmpx, tmpy;
-  tmpx = _mm256_loadu_ps(x);
-  tmpy = _mm256_loadu_ps(y);
-  tmpx = _mm256_mul_ps(tmpx, tmpy);
-  _mm256_storeu_ps(z, tmpx);
-}
-#endif
-
-// avx2 > mkl > for
-#ifdef __AVX2__
-template <>
-void VMulCompute<float, jit::avx2, kEQ8>(const int n, const float* x,
-                                         const float* y, float* z) {
-  __m256 tmpx, tmpy;
-  tmpx = _mm256_loadu_ps(x);
-  tmpy = _mm256_loadu_ps(y);
-  tmpx = _mm256_mul_ps(tmpx, tmpy);
-  _mm256_storeu_ps(z, tmpx);
-}
-#elif defined PADDLE_USE_MKLML
-DEFINE_VMUL_COMPUTE_FLOAT(jit::avx2, kEQ8)
-#endif
-// TODO(TJ): test and complete avx512
-
-#undef DEFINE_VMUL_COMPUTE_FLOAT
-#undef DEFINE_VMUL_COMPUTE_DOUBLE
-#undef VMUL_ANY
-
-template <>
-VMulKernel<float>::VMulKernel(int d) {
-  SEARCH_ISA_BLOCK(VMulCompute, float);
-}
-
-template <>
-VMulKernel<double>::VMulKernel(int d) {
-  SEARCH_ISA_BLOCK(VMulCompute, double);
-}
 
 template <>
 const std::shared_ptr<VMulKernel<float>> KernelPool::Get<VMulKernel<float>>(
@@ -170,52 +51,6 @@ const std::shared_ptr<VMulKernel<double>> KernelPool::Get<VMulKernel<double>>(
   return std::dynamic_pointer_cast<VMulKernel<double>>(kers_.at(key));
 }
 
-template <>
-LSTMKernel<float>::LSTMKernel(int d, const std::string& act_gate_str,
-                              const std::string& act_cand_str,
-                              const std::string& act_cell_str)
-    : Kernel(), d_(d) {
-  d2_ = d * 2;
-  d3_ = d * 3;
-  if (platform::jit::MayIUse(platform::jit::avx512f)) {
-    math::VecActivations<float, platform::jit::avx512f> act_functor;
-    act_gate_ = act_functor(act_gate_str);
-    act_cell_ = act_functor(act_cell_str);
-    act_cand_ = act_functor(act_cand_str);
-  } else if (platform::jit::MayIUse(platform::jit::avx2)) {
-    math::VecActivations<float, platform::jit::avx2> act_functor;
-    act_gate_ = act_functor(act_gate_str);
-    act_cell_ = act_functor(act_cell_str);
-    act_cand_ = act_functor(act_cand_str);
-  } else if (platform::jit::MayIUse(platform::jit::avx)) {
-    math::VecActivations<float, platform::jit::avx> act_functor;
-    act_gate_ = act_functor(act_gate_str);
-    act_cell_ = act_functor(act_cell_str);
-    act_cand_ = act_functor(act_cand_str);
-    // ComputeCtHt = [&](float*gates,const float*ct_1,float*ct, float*ht) {
-    //   // gates: W_ch, W_ih, W_fh, W_oh
-    //   act_gate(d3_, gates + d_, gates + d_);
-
-    //   /* C_t = C_t-1 * fgated + cand_gated * igated */
-    //   act_cand(d_, gates, gates);
-    //   blas.VMUL(d_, gates, gates + d_, gates + d_);
-    //   blas.VMUL(d_, ct_1, gates + d2_, gates + d2_);
-    //   blas.VADD(d_, gates + d_, gates + d2_, ct);
-
-    //   /* H_t = act_cell(C_t) * ogated */
-    //   act_cell(d_, ct, gates + d2_);
-    //   blas.VMUL(d_, gates + d2_, gates + d3_, ht)
-    //   GET_Ct(ct_1, gates, ct);
-    //   GET_Ht(ct, gates, ht);
-    // };
-  } else {
-    math::VecActivations<float, platform::jit::isa_any> act_functor;
-    act_gate_ = act_functor(act_gate_str);
-    act_cell_ = act_functor(act_cell_str);
-    act_cand_ = act_functor(act_cand_str);
-  }
-}
-
 template <>
 const std::shared_ptr<LSTMKernel<float>> KernelPool::Get<
     LSTMKernel<float>, int, const std::string&, const std::string&,
diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h
index 6005ea76f415a16b125ba76d5cbfebc787e67fe3..3849d29040bf5cb928501a617da59ad299720d0e 100644
--- a/paddle/fluid/operators/math/jit_kernel.h
+++ b/paddle/fluid/operators/math/jit_kernel.h
@@ -87,5 +87,3 @@ class LSTMKernel : public Kernel {
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
-
-#include "paddle/fluid/operators/math/jit_kernel_impl.h"
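With the split, jit_kernel.cc keeps only the KernelPool singleton and its Get<> specializations; the per-kernel Compute implementations move into jit_kernel_blas.cc and jit_kernel_lstm.cc. A minimal caller-side sketch, assuming the KernelPool::Get<> and VMulKernel<T>::Compute interfaces shown in this patch (the example function itself is hypothetical):

```cpp
#include "paddle/fluid/operators/math/jit_kernel.h"

namespace jk = paddle::operators::math::jitkernel;

// Hypothetical caller: multiply two float vectors of length d through the
// cached jit kernel. Repeated Get() calls with the same kernel type and
// size key return the same shared instance from the pool.
void VMulExample(int d, const float* x, const float* y, float* z) {
  const auto& ker = jk::KernelPool::Instance().Get<jk::VMulKernel<float>>(d);
  // Compute is bound once, at kernel construction, to the best available
  // implementation (MKL, AVX intrinsics, or the plain loop).
  ker->Compute(d, x, y, z);
}
```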
diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc
new file mode 100644
index 0000000000000000000000000000000000000000..29394e31893621de85c91a0b04661b5aaa51d208
--- /dev/null
+++ b/paddle/fluid/operators/math/jit_kernel_blas.cc
@@ -0,0 +1,164 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/math/jit_kernel.h"
+#include <string>
+
+#ifdef PADDLE_WITH_MKLML
+#include "paddle/fluid/platform/dynload/mklml.h"
+#endif
+
+#ifdef __AVX__
+#include <immintrin.h>
+#endif
+
+namespace paddle {
+namespace operators {
+namespace math {
+namespace jitkernel {
+
+namespace jit = platform::jit;
+
+#define SEARCH_BLOCK(src, t, isa)                             \
+  if (d < AVX_FLOAT_BLOCK) {                                  \
+    Compute = src<t, isa, kLT8>;                              \
+  } else if (d == AVX_FLOAT_BLOCK) {                          \
+    Compute = src<t, isa, kEQ8>;                              \
+  } else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) { \
+    Compute = src<t, isa, kGT8LT16>;                          \
+  } else if (d == AVX512_FLOAT_BLOCK) {                       \
+    Compute = src<t, isa, kEQ16>;                             \
+  } else {                                                    \
+    Compute = src<t, isa, kGT16>;                             \
+  }
+
+#define SEARCH_ISA_BLOCK(src, t)        \
+  if (jit::MayIUse(jit::avx512f)) {     \
+    SEARCH_BLOCK(src, t, jit::avx512f); \
+  } else if (jit::MayIUse(jit::avx2)) { \
+    SEARCH_BLOCK(src, t, jit::avx2);    \
+  } else if (jit::MayIUse(jit::avx)) {  \
+    SEARCH_BLOCK(src, t, jit::avx);     \
+  } else {                              \
+    SEARCH_BLOCK(src, t, jit::isa_any); \
+  }
+
+// do not include lt8, eq8, eq16
+#define FOR_EACH_COMMON_BLOCK(macro_, isa) \
+  macro_(isa, kGT8LT16) macro_(isa, kGT16)
+
+#define FOR_EACH_ISA_COMMON_BLOCK(macro_)     \
+  FOR_EACH_COMMON_BLOCK(macro_, jit::avx512f) \
+  FOR_EACH_COMMON_BLOCK(macro_, jit::avx2)    \
+  FOR_EACH_COMMON_BLOCK(macro_, jit::avx)     \
+  FOR_EACH_COMMON_BLOCK(macro_, jit::isa_any)
+
+#define FOR_EACH_ALL_BLOCK(macro_, isa)                                        \
+  macro_(isa, kLT8) macro_(isa, kEQ8) macro_(isa, kGT8LT16) macro_(isa, kEQ16) \
+      macro_(isa, kGT16)
+
+#define FOR_EACH_ISA_ALL_BLOCK(macro_)     \
+  FOR_EACH_ALL_BLOCK(macro_, jit::avx512f) \
+  FOR_EACH_ALL_BLOCK(macro_, jit::avx2)    \
+  FOR_EACH_ALL_BLOCK(macro_, jit::avx)     \
+  FOR_EACH_ALL_BLOCK(macro_, jit::isa_any)
+
+/* VMUL JitKernel */
+#define VMUL_ANY                \
+  for (int i = 0; i < n; ++i) { \
+    z[i] = x[i] * y[i];         \
+  }
+
+template <typename T, platform::jit::cpu_isa_t isa, jit_block>
+static void VMulCompute(const int n, const T* x, const T* y, T* z) {
+  VMUL_ANY
+}
+
+#ifdef PADDLE_USE_MKLML
+#define VMUL_MKL_FLOAT(isa, block)                                 \
+  template <>                                                      \
+  void VMulCompute<float, isa, block>(const int n, const float* x, \
+                                      const float* y, float* z) {  \
+    platform::dynload::vsMul(n, x, y, z);                          \
+  }
+
+#define VMUL_MKL_DOUBLE(isa, block)                                  \
+  template <>                                                        \
+  void VMulCompute<double, isa, block>(const int n, const double* x, \
+                                       const double* y, double* z) { \
+    platform::dynload::vdMul(n, x, y, z);                            \
+  }
+
+FOR_EACH_ISA_COMMON_BLOCK(VMUL_MKL_FLOAT)
+FOR_EACH_ISA_ALL_BLOCK(VMUL_MKL_DOUBLE)
+#endif
+
+/// lt8
+#ifdef PADDLE_USE_MKLML
+VMUL_MKL_FLOAT(jit::avx, kLT8)
+#endif
+
+/// eq8
+#define VMUL_INTRI8_FLOAT(isa)                                    \
+  template <>                                                     \
+  void VMulCompute<float, isa, kEQ8>(const int n, const float* x, \
+                                     const float* y, float* z) {  \
+    __m256 tmpx, tmpy;                                            \
+    tmpx = _mm256_loadu_ps(x);                                    \
+    tmpy = _mm256_loadu_ps(y);                                    \
+    tmpx = _mm256_mul_ps(tmpx, tmpy);                             \
+    _mm256_storeu_ps(z, tmpx);                                    \
+  }
+
+// mkl > avx > for, ">" means better
+#ifdef PADDLE_USE_MKLML
+VMUL_MKL_FLOAT(jit::avx, kEQ8)
+#elif defined __AVX__
+VMUL_INTRI8_FLOAT(jit::avx)
+#endif
+// avx2 > mkl > for
+#ifdef __AVX2__
+VMUL_INTRI8_FLOAT(jit::avx2)
+#elif defined PADDLE_USE_MKLML
+VMUL_MKL_FLOAT(jit::avx2, kEQ8)
+#endif
+// TODO(TJ): test and complete avx512
+
+/// eq16
+#ifdef PADDLE_USE_MKLML
+// TODO(TJ): test and complete me
+VMUL_MKL_FLOAT(jit::avx, kEQ16)
+VMUL_MKL_FLOAT(jit::avx2, kEQ16)
+VMUL_MKL_FLOAT(jit::avx512f, kEQ16)
+#endif
+
+#define USE_VMUL_KERNEL(T, func)     \
+  template <>                        \
+  VMulKernel<T>::VMulKernel(int d) { \
+    SEARCH_ISA_BLOCK(func, T);       \
+  }
+
+USE_VMUL_KERNEL(float, VMulCompute);
+USE_VMUL_KERNEL(double, VMulCompute);
+
+#undef VMUL_ANY
+#undef VMUL_INTRI8_FLOAT
+#undef VMUL_MKL_FLOAT
+#undef VMUL_MKL_DOUBLE
+#undef USE_VMUL_KERNEL
+
+}  // namespace jitkernel
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
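The SEARCH_ISA_BLOCK / SEARCH_BLOCK pair above binds Compute exactly once, at kernel construction: first by the best instruction set MayIUse() reports, then by the size class of d. A standalone sketch of the size rule, assuming AVX_FLOAT_BLOCK == 8 and AVX512_FLOAT_BLOCK == 16 (floats per 256-bit and 512-bit vector, respectively):

```cpp
// Sketch of SEARCH_BLOCK's size classification, assuming the block
// constants AVX_FLOAT_BLOCK == 8 and AVX512_FLOAT_BLOCK == 16.
enum jit_block { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 };

jit_block SelectBlock(int d) {
  if (d < 8) return kLT8;       // shorter than one AVX float vector
  if (d == 8) return kEQ8;      // exactly one AVX float vector
  if (d < 16) return kGT8LT16;  // between AVX and AVX512 widths
  if (d == 16) return kEQ16;    // exactly one AVX512 float vector
  return kGT16;                 // longer than one AVX512 float vector
}
```

Note that only kEQ8 gets a dedicated intrinsic path here: VMUL_INTRI8_FLOAT processes exactly one __m256 with no loop or tail handling, while the other blocks fall back to MKL or the scalar VMUL_ANY loop.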
diff --git a/paddle/fluid/operators/math/jit_kernel_impl.h b/paddle/fluid/operators/math/jit_kernel_impl.h
deleted file mode 100644
index 46fef31ff03852a963082dcfbd826a76b3d39171..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/math/jit_kernel_impl.h
+++ /dev/null
@@ -1,27 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-#include <functional>
-#include <string>
-#include <unordered_map>
-#include "paddle/fluid/platform/cpu_info.h"
-
-namespace paddle {
-namespace operators {
-namespace math {
-namespace jitkernel {}  // namespace jitkernel
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/fluid/operators/math/jit_kernel_lstm.cc b/paddle/fluid/operators/math/jit_kernel_lstm.cc
new file mode 100644
index 0000000000000000000000000000000000000000..895784a4fa6c183e7e4462ba1d2fdb76898052ff
--- /dev/null
+++ b/paddle/fluid/operators/math/jit_kernel_lstm.cc
@@ -0,0 +1,76 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/math/jit_kernel.h"
+#include <functional>
+#include <string>
+#include "paddle/fluid/operators/math/cpu_vec.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+namespace jitkernel {
+
+namespace jit = platform::jit;
+
+template <>
+LSTMKernel<float>::LSTMKernel(int d, const std::string& act_gate_str,
+                              const std::string& act_cand_str,
+                              const std::string& act_cell_str)
+    : Kernel(), d_(d) {
+  d2_ = d * 2;
+  d3_ = d * 3;
+  if (platform::jit::MayIUse(platform::jit::avx512f)) {
+    math::VecActivations<float, platform::jit::avx512f> act_functor;
+    act_gate_ = act_functor(act_gate_str);
+    act_cell_ = act_functor(act_cell_str);
+    act_cand_ = act_functor(act_cand_str);
+  } else if (platform::jit::MayIUse(platform::jit::avx2)) {
+    math::VecActivations<float, platform::jit::avx2> act_functor;
+    act_gate_ = act_functor(act_gate_str);
+    act_cell_ = act_functor(act_cell_str);
+    act_cand_ = act_functor(act_cand_str);
+  } else if (platform::jit::MayIUse(platform::jit::avx)) {
+    math::VecActivations<float, platform::jit::avx> act_functor;
+    act_gate_ = act_functor(act_gate_str);
+    act_cell_ = act_functor(act_cell_str);
+    act_cand_ = act_functor(act_cand_str);
+    // ComputeCtHt = [&](float*gates,const float*ct_1,float*ct, float*ht) {
+    //   // gates: W_ch, W_ih, W_fh, W_oh
+    //   act_gate(d3_, gates + d_, gates + d_);
+
+    //   /* C_t = C_t-1 * fgated + cand_gated * igated */
+    //   act_cand(d_, gates, gates);
+    //   blas.VMUL(d_, gates, gates + d_, gates + d_);
+    //   blas.VMUL(d_, ct_1, gates + d2_, gates + d2_);
+    //   blas.VADD(d_, gates + d_, gates + d2_, ct);
+
+    //   /* H_t = act_cell(C_t) * ogated */
+    //   act_cell(d_, ct, gates + d2_);
+    //   blas.VMUL(d_, gates + d2_, gates + d3_, ht)
+    //   GET_Ct(ct_1, gates, ct);
+    //   GET_Ht(ct, gates, ht);
+    // };
+  } else {
+    math::VecActivations<float, platform::jit::isa_any> act_functor;
+    act_gate_ = act_functor(act_gate_str);
+    act_cell_ = act_functor(act_cell_str);
+    act_cand_ = act_functor(act_cand_str);
+  }
+}
+
+}  // namespace jitkernel
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
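The ComputeCtHt lambda is still commented out above (carried over verbatim from jit_kernel.cc); so far only the activation functors are selected per ISA. For reference, a scalar sketch of the step those comments outline, using the gate layout [cand, i, f, o] from the `W_ch, W_ih, W_fh, W_oh` note. The function name is hypothetical, and tanh/sigmoid stand in for the activations normally chosen via the act_*_str arguments:

```cpp
#include <cmath>

static float Sigmoid(float x) { return 1.f / (1.f + std::exp(-x)); }

// Scalar reference for one LSTM step: gates holds [cand, i, f, o],
// each of width d; ct_1 is the previous cell state.
void ComputeCtHtRef(const float* gates, const float* ct_1, float* ct,
                    float* ht, int d) {
  for (int i = 0; i < d; ++i) {
    float cand = std::tanh(gates[i]);          // act_cand on W_ch
    float igated = Sigmoid(gates[d + i]);      // act_gate on W_ih
    float fgated = Sigmoid(gates[2 * d + i]);  // act_gate on W_fh
    float ogated = Sigmoid(gates[3 * d + i]);  // act_gate on W_oh
    ct[i] = ct_1[i] * fgated + cand * igated;  // C_t = C_{t-1}*f + cand*i
    ht[i] = std::tanh(ct[i]) * ogated;         // H_t = act_cell(C_t) * o
  }
}
```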