diff --git a/paddle/fluid/operators/math/jit_kernel.cc b/paddle/fluid/operators/math/jit_kernel.cc
index 4fd1d1794274e47fc3dc2dbd752b5cf747c23741..8859c0f7d8f62fdfb2704cc80df01425073a9ece 100644
--- a/paddle/fluid/operators/math/jit_kernel.cc
+++ b/paddle/fluid/operators/math/jit_kernel.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/math/jit_kernel.h"
+#include <iostream>
 #include <string>
 
 namespace paddle {
@@ -27,29 +28,35 @@ KernelPool& KernelPool::Instance() {
   return g_jit_kernels;
 }
 
-template <>
-const std::shared_ptr<VMulKernel<float>> KernelPool::Get<VMulKernel<float>>(
-    int d) {
-  std::string key = "f" + std::to_string(d);
+const std::shared_ptr<Kernel> KernelPool::Get(const std::string& key) const {
   if (kers_.find(key) == kers_.end()) {
-    auto p = std::make_shared<VMulKernel<float>>(d);
-    kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)});
-    return p;
+    return nullptr;
   }
-  return std::dynamic_pointer_cast<VMulKernel<float>>(kers_.at(key));
+  return kers_.at(key);
 }
 
-template <>
-const std::shared_ptr<VMulKernel<double>> KernelPool::Get<VMulKernel<double>>(
-    int d) {
-  std::string key = "d" + std::to_string(d);
-  if (kers_.find(key) == kers_.end()) {
-    auto p = std::make_shared<VMulKernel<double>>(d);
-    kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)});
-    return p;
+#define DEFINE_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key)        \
+  template <>                                                              \
+  const std::shared_ptr<ker_class<ker_dtype>>                              \
+  KernelPool::Get<ker_class<ker_dtype>>(int d) {                           \
+    std::string key = #ker_key #dtype_key + std::to_string(d);             \
+    if (kers_.find(key) == kers_.end()) {                                  \
+      auto p = std::make_shared<ker_class<ker_dtype>>(d);                  \
+      kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)});           \
+      return p;                                                            \
+    }                                                                      \
+    return std::dynamic_pointer_cast<ker_class<ker_dtype>>(kers_.at(key)); \
   }
-  return std::dynamic_pointer_cast<VMulKernel<double>>(kers_.at(key));
-}
+
+#define REGISTER_BLAS_JITKERNEL(ker_key, ker_class) \
+  DEFINE_WITH_DTYPE(ker_key, ker_class, float, f);  \
+  DEFINE_WITH_DTYPE(ker_key, ker_class, double, d)
+
+REGISTER_BLAS_JITKERNEL(vmul, VMulKernel);
+REGISTER_BLAS_JITKERNEL(vadd, VAddKernel);
+
+#undef REGISTER_BLAS_JITKERNEL
+#undef DEFINE_WITH_DTYPE
 
 template <>
 const std::shared_ptr<LSTMKernel<float>>
@@ -57,7 +64,8 @@ KernelPool::Get<LSTMKernel<float>, int, const std::string&, const std::string&,
                 const std::string&>(int d, const std::string& act_gate,
                                     const std::string& act_cand,
                                     const std::string& act_cell) {
-  std::string key = "f" + std::to_string(d) + act_gate + act_cand + act_cell;
+  std::string key =
+      "lstmf" + std::to_string(d) + act_gate + act_cand + act_cell;
   if (kers_.find(key) == kers_.end()) {
     auto p = std::make_shared<LSTMKernel<float>>(d, act_gate, act_cand,
                                                  act_cell);
diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h
index 3849d29040bf5cb928501a617da59ad299720d0e..610f6714041066dc000a5560f4c925ff7227b1bb 100644
--- a/paddle/fluid/operators/math/jit_kernel.h
+++ b/paddle/fluid/operators/math/jit_kernel.h
@@ -54,6 +54,8 @@ class KernelPool {
   template <typename Ker, typename... ARGS>
   const std::shared_ptr<Ker> Get(ARGS... args);
 
+  const std::shared_ptr<Kernel> Get(const std::string &key) const;
+
  private:
   KernelPool() = default;
   std::unordered_map<std::string, std::shared_ptr<Kernel>> kers_;
@@ -68,6 +70,13 @@ class VMulKernel : public Kernel {
   void (*Compute)(const int n, const T *, const T *, T *);
 };
 
+template <typename T>
+class VAddKernel : public Kernel {
+ public:
+  explicit VAddKernel(int n);
+  void (*Compute)(const int n, const T *, const T *, T *);
+};
+
 template <typename T>
 class LSTMKernel : public Kernel {
  public:
diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc
index 29394e31893621de85c91a0b04661b5aaa51d208..4ce60ffc043c80fcb58722b2e66f762bffa42431 100644
--- a/paddle/fluid/operators/math/jit_kernel_blas.cc
+++ b/paddle/fluid/operators/math/jit_kernel_blas.cc
@@ -74,15 +74,22 @@ namespace jit = platform::jit;
   FOR_EACH_ALL_BLOCK(macro_, jit::avx)     \
   FOR_EACH_ALL_BLOCK(macro_, jit::any)
 
-/* VMUL JitKernel */
-#define VMUL_ANY                \
-  for (int i = 0; i < n; ++i) { \
-    z[i] = x[i] * y[i];         \
-  }
+#define BIND_KERNEL_WITH_DTYPE(ker_class, ker_func, ker_dtype) \
+  template <>                                                  \
+  ker_class<ker_dtype>::ker_class(int d) {                     \
+    SEARCH_ISA_BLOCK(ker_func, ker_dtype);                     \
+  }
 
+#define BIND_KERNEL(ker_class, ker_func)              \
+  BIND_KERNEL_WITH_DTYPE(ker_class, ker_func, float); \
+  BIND_KERNEL_WITH_DTYPE(ker_class, ker_func, double)
+
+/* VMUL JitKernel */
 template <typename T, platform::jit::cpu_isa_t isa, jit_block>
 static void VMulCompute(const int n, const T* x, const T* y, T* z) {
-  VMUL_ANY
+  for (int i = 0; i < n; ++i) {
+    z[i] = x[i] * y[i];
+  }
 }
 
 #ifdef PADDLE_USE_MKLML
@@ -107,6 +114,8 @@ FOR_EACH_ISA_ALL_BLOCK(VMUL_MKL_DOUBLE)
 /// lt8
 #ifdef PADDLE_USE_MKLML
 VMUL_MKL_FLOAT(jit::avx, kLT8)
+VMUL_MKL_FLOAT(jit::avx2, kLT8)
+VMUL_MKL_FLOAT(jit::avx512f, kLT8)
 #endif
 
 /// eq8
@@ -143,20 +152,93 @@ VMUL_MKL_FLOAT(jit::avx2, kEQ16)
 VMUL_MKL_FLOAT(jit::avx512f, kEQ16)
 #endif
 
-#define USE_VMUL_KERNEL(T, func)     \
-  template <>                        \
-  VMulKernel<T>::VMulKernel(int d) { \
-    SEARCH_ISA_BLOCK(func, T);       \
-  }
-
-USE_VMUL_KERNEL(float, VMulCompute);
-USE_VMUL_KERNEL(double, VMulCompute);
-
-#undef VMUL_ANY
 #undef VMUL_INTRI8_FLOAT
 #undef VMUL_MKL_FLOAT
 #undef VMUL_MKL_DOUBLE
-#undef USE_VMUL_KERNEL
+
+/* VADD */
+template <typename T, platform::jit::cpu_isa_t isa, jit_block>
+static void VAddCompute(const int n, const T* x, const T* y, T* z) {
+  for (int i = 0; i < n; ++i) {
+    z[i] = x[i] + y[i];
+  }
+}
+
+#ifdef PADDLE_USE_MKLML
+#define VADD_MKL_FLOAT(isa, block)                                 \
+  template <>                                                      \
+  void VAddCompute<float, isa, block>(const int n, const float* x, \
+                                      const float* y, float* z) {  \
+    platform::dynload::vsAdd(n, x, y, z);                          \
+  }
+
+#define VADD_MKL_DOUBLE(isa, block)                                  \
+  template <>                                                        \
+  void VAddCompute<double, isa, block>(const int n, const double* x, \
+                                       const double* y, double* z) { \
+    platform::dynload::vdAdd(n, x, y, z);                            \
+  }
+
+FOR_EACH_ISA_COMMON_BLOCK(VADD_MKL_FLOAT)
+FOR_EACH_ISA_ALL_BLOCK(VADD_MKL_DOUBLE)
+#endif
+
+/// lt8
+#ifdef PADDLE_USE_MKLML
+VADD_MKL_FLOAT(jit::avx, kLT8)
+VADD_MKL_FLOAT(jit::avx2, kLT8)
+VADD_MKL_FLOAT(jit::avx512f, kLT8)
+#endif
+
+/// eq8
+#define VADD_INTRI8_FLOAT(isa)                                    \
+  template <>                                                     \
+  void VAddCompute<float, isa, kEQ8>(const int n, const float* x, \
+                                     const float* y, float* z) {  \
+    __m256 tmpx, tmpy;                                            \
+    tmpx = _mm256_loadu_ps(x);                                    \
+    tmpy = _mm256_loadu_ps(y);                                    \
+    tmpx = _mm256_add_ps(tmpx, tmpy);                             \
+    _mm256_storeu_ps(z, tmpx);                                    \
+  }
+
+// mkl > avx > for, ">" means better
+#ifdef PADDLE_USE_MKLML
+VADD_MKL_FLOAT(jit::avx, kEQ8)
+#elif defined __AVX__
+VADD_INTRI8_FLOAT(jit::avx)
+#endif
+// avx2 > mkl > for
+#ifdef __AVX2__
+VADD_INTRI8_FLOAT(jit::avx2)
+#elif defined PADDLE_USE_MKLML
+VADD_MKL_FLOAT(jit::avx2, kEQ8)
+#endif
+// TODO(TJ): test and complete avx512
+
+/// eq16
+#ifdef PADDLE_USE_MKLML
+// TODO(TJ): test and complete me
+VADD_MKL_FLOAT(jit::avx, kEQ16)
+VADD_MKL_FLOAT(jit::avx2, kEQ16)
+VADD_MKL_FLOAT(jit::avx512f, kEQ16)
+#endif
+
+#undef VADD_INTRI8_FLOAT
+#undef VADD_MKL_FLOAT
+#undef VADD_MKL_DOUBLE
+
+BIND_KERNEL(VMulKernel, VMulCompute);
+BIND_KERNEL(VAddKernel, VAddCompute);
+
+#undef BIND_KERNEL
+#undef BIND_KERNEL_WITH_DTYPE
+#undef FOR_EACH_ISA_ALL_BLOCK
+#undef FOR_EACH_ALL_BLOCK
+#undef FOR_EACH_ISA_COMMON_BLOCK
+#undef FOR_EACH_COMMON_BLOCK
+#undef SEARCH_ISA_BLOCK
+#undef SEARCH_BLOCK
 
 }  // namespace jitkernel
 }  // namespace math
diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc
index 041234442d372df8deacb3663cff1fef478bac2c..6b2502910180a753c4878f3cd431f585fa7ae5ba 100644
--- a/paddle/fluid/operators/math/jit_kernel_test.cc
+++ b/paddle/fluid/operators/math/jit_kernel_test.cc
@@ -23,25 +23,30 @@ TEST(JitKernel, pool) {
   namespace jit = paddle::operators::math::jitkernel;
   const int frame_size = 4;
   std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh";
-  const auto& p1 =
+  const auto& plstm1 =
       jit::KernelPool::Instance()
           .template Get<jit::LSTMKernel<float>, int, const std::string&,
                         const std::string&, const std::string&>(
               frame_size, act_gate, act_cand, act_cell);
-  const auto& p2 =
+  const auto& plstm2 =
       jit::KernelPool::Instance()
           .template Get<jit::LSTMKernel<float>, int, const std::string&,
                         const std::string&, const std::string&>(
               frame_size, act_gate, act_cand, act_cell);
-  EXPECT_EQ(p1, p2);
+  EXPECT_EQ(plstm1, plstm2);
 
-  const auto& p3 =
+  const auto& pvmul_f =
       jit::KernelPool::Instance().template Get<jit::VMulKernel<float>>(4);
-  EXPECT_TRUE(std::dynamic_pointer_cast<jit::Kernel>(p2) !=
-              std::dynamic_pointer_cast<jit::Kernel>(p3));
+  EXPECT_TRUE(std::dynamic_pointer_cast<jit::Kernel>(plstm2) !=
+              std::dynamic_pointer_cast<jit::Kernel>(pvmul_f));
 
-  const auto& p4 =
+  const auto& pvmul_d =
       jit::KernelPool::Instance().template Get<jit::VMulKernel<double>>(4);
-  EXPECT_TRUE(std::dynamic_pointer_cast<jit::Kernel>(p3) !=
-              std::dynamic_pointer_cast<jit::Kernel>(p4));
+  EXPECT_TRUE(std::dynamic_pointer_cast<jit::Kernel>(pvmul_f) !=
+              std::dynamic_pointer_cast<jit::Kernel>(pvmul_d));
+
+  const auto& pvmul_from_key = jit::KernelPool::Instance().Get("vmulf4");
+  EXPECT_TRUE(pvmul_f == pvmul_from_key);
+  const auto& pvmul_from_key2 = jit::KernelPool::Instance().Get("vmulf5");
+  EXPECT_TRUE(pvmul_from_key2 == nullptr);
 }
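
For reference, a minimal usage sketch (not part of the diff) of the two lookup paths this change sets up, mirroring what jit_kernel_test.cc exercises. Example() is a hypothetical caller; the key format "vmul"/"vadd" + "f"/"d" + size (e.g. "vmulf8") comes from the DEFINE_WITH_DTYPE macro in jit_kernel.cc.

// Sketch only; Example() is a hypothetical caller, not code from this patch.
#include "paddle/fluid/operators/math/jit_kernel.h"

namespace jit = paddle::operators::math::jitkernel;

void Example() {
  float x[8] = {0}, y[8] = {0}, z[8] = {0};

  // Typed lookup: builds VMulKernel<float> for d == 8 on first use and
  // caches it in the pool under the key "vmulf8".
  const auto& vmul =
      jit::KernelPool::Instance().template Get<jit::VMulKernel<float>>(8);
  vmul->Compute(8, x, y, z);  // z[i] = x[i] * y[i]

  // String-key lookup added by this patch: returns the cached kernel as a
  // shared_ptr<Kernel>, or nullptr if nothing was registered under that key.
  const auto& cached = jit::KernelPool::Instance().Get("vmulf8");
  // cached and vmul refer to the same kernel instance at this point.
}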