From 74292f414c033bf7cb53b4f87a82f7ff6c18a4b2 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 17 Dec 2018 14:51:52 +0000 Subject: [PATCH] enable eltwise nchw16c mul nc --- .../elementwise/elementwise_mul_mkldnn_op.cc | 10 ++- paddle/fluid/operators/jit/gen/CMakeLists.txt | 1 + paddle/fluid/operators/jit/gen/blas.cc | 43 +++++++++++++ paddle/fluid/operators/jit/gen/blas.h | 12 ++++ paddle/fluid/operators/jit/helper.cc | 3 +- paddle/fluid/operators/jit/helper.h | 1 + paddle/fluid/operators/jit/kernel_base.h | 11 +++- .../fluid/operators/jit/refer/CMakeLists.txt | 1 + paddle/fluid/operators/jit/refer/refer.cc | 2 + paddle/fluid/operators/jit/refer/refer.h | 15 +++++ paddle/fluid/operators/jit/test.cc | 62 +++++++++++++++++++ 11 files changed, 153 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc index c600d1e3d76..71f4b71330a 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc @@ -18,7 +18,7 @@ limitations under the License. */ #include "paddle/fluid/platform/mkldnn_helper.h" -#include "paddle/fluid/operators/math/jit_kernel.h" +#include "paddle/fluid/operators/jit/kernels.h" #include "xbyak/xbyak.h" #include "xbyak/xbyak_util.h" @@ -108,10 +108,8 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { constexpr int simd_width = 16; int C = c / simd_width; - const auto& multiply = - math::jitkernel::KernelPool::Instance() - .template Get>(n); - + auto multiply = jit::Get(0); #pragma omp parallel for collapse(2) for (int ni = 0; ni < n; ni++) { for (int ci = 0; ci < C; ci++) { @@ -122,7 +120,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { auto ptr_z = z_data + ni * C * h * w * simd_width + ci * h * w * simd_width; - multiply->Compute(ptr_x, ptr_y, ptr_z, h, w); + multiply(ptr_x, ptr_y, ptr_z, h, w); } } } diff --git a/paddle/fluid/operators/jit/gen/CMakeLists.txt b/paddle/fluid/operators/jit/gen/CMakeLists.txt index 8ad9587b5ef..a7f9e18318d 100644 --- a/paddle/fluid/operators/jit/gen/CMakeLists.txt +++ b/paddle/fluid/operators/jit/gen/CMakeLists.txt @@ -25,3 +25,4 @@ USE_JITKERNEL_GEN(lstmc1h1) USE_JITKERNEL_GEN(gruh1) USE_JITKERNEL_GEN(gruhtpart1) USE_JITKERNEL_GEN(gruhtpart2) +USE_JITKERNEL_GEN(nchw16cmulnc) diff --git a/paddle/fluid/operators/jit/gen/blas.cc b/paddle/fluid/operators/jit/gen/blas.cc index b24f44c9f3b..65b9a52ff2d 100644 --- a/paddle/fluid/operators/jit/gen/blas.cc +++ b/paddle/fluid/operators/jit/gen/blas.cc @@ -104,6 +104,48 @@ void VXXJitCode::genCode() { ret(); } +void NCHW16CMulNCJitCode::genCode() { + // RDI is ptr x_input + // RSI is ptr y_input + // RDX is ptr output + // RCX is height + // r8 is width + + push(rbx); + + xor_(rax, rax); + xor_(r10, r10); + vmovups(zmm3, ptr[rsi]); + + L("h_loop"); + xor_(rbx, rbx); + L("w_loop"); + vmovups(zmm2, ptr[rdi + rax]); + vmulps(zmm1, zmm2, zmm3); + vmovups(ptr[rdx + rax], zmm1); + add(rax, 64); + inc(rbx); + cmp(r8, rbx); + jnz("w_loop"); + inc(r10); + cmp(r10, rcx); + jnz("h_loop"); + + pop(rbx); + ret(); +} + +class NCHW16CMulNCCreator : public JitCodeCreator { + public: + bool UseMe(const int& attr) const override { + return platform::MayIUse(platform::avx512f); + } + size_t CodeSize(const int& d) const override { return 256 * 1024; } + std::unique_ptr CreateJitCode(const int& attr) const override { + return make_unique(attr, CodeSize(attr)); + } +}; + #define DECLARE_BLAS_CREATOR(name) \ class name##Creator : public JitCodeCreator { \ public: \ @@ -141,3 +183,4 @@ REGISTER_JITKERNEL_GEN(vadd, gen::VAddCreator); REGISTER_JITKERNEL_GEN(vaddrelu, gen::VAddReluCreator); REGISTER_JITKERNEL_GEN(vscal, gen::VScalCreator); REGISTER_JITKERNEL_GEN(vaddbias, gen::VAddBiasCreator); +REGISTER_JITKERNEL_GEN(nchw16cmulnc, gen::NCHW16CMulNCCreator); diff --git a/paddle/fluid/operators/jit/gen/blas.h b/paddle/fluid/operators/jit/gen/blas.h index 5a2192052f8..29be4e73589 100644 --- a/paddle/fluid/operators/jit/gen/blas.h +++ b/paddle/fluid/operators/jit/gen/blas.h @@ -99,6 +99,18 @@ DECLARE_BLAS_JITCODE(VAddBias, operand_type::add, 1, false); #undef DECLARE_BLAS_JITCODE +// nChw16c = nChw16c .* NC +class NCHW16CMulNCJitCode : public JitCode { + public: + DECLARE_JIT_CODE(NCHW16CMulNCJitCode); + explicit NCHW16CMulNCJitCode(int d /*unused*/, size_t code_size, + void* code_ptr = nullptr) + : JitCode(code_size, code_ptr) { + this->genCode(); + } + void genCode() override; +}; + } // namespace gen } // namespace jit } // namespace operators diff --git a/paddle/fluid/operators/jit/helper.cc b/paddle/fluid/operators/jit/helper.cc index a0ff82043fc..a1bb51fa666 100644 --- a/paddle/fluid/operators/jit/helper.cc +++ b/paddle/fluid/operators/jit/helper.cc @@ -44,8 +44,9 @@ const char* to_string(KernelType kt) { ONE_CASE(gruhtpart2); ONE_CASE(crfdecoding); ONE_CASE(layernorm); + ONE_CASE(nchw16cmulnc); default: - PADDLE_THROW("Not support type: %d", kt); + PADDLE_THROW("Not support type: %d, or forget to add it.", kt); return "NOT JITKernel"; } return nullptr; diff --git a/paddle/fluid/operators/jit/helper.h b/paddle/fluid/operators/jit/helper.h index 44952fb9079..275170ca2b5 100644 --- a/paddle/fluid/operators/jit/helper.h +++ b/paddle/fluid/operators/jit/helper.h @@ -93,6 +93,7 @@ inline typename KernelTuples::func_type GetRefer() { template +// TODO(TJ): const & attr typename KernelTuples::func_type Get(typename KernelTuples::attr_type attr) { auto jitfunc = GetJitCode(attr); if (jitfunc) { diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h index 59531c2f17c..9ba0a958313 100644 --- a/paddle/fluid/operators/jit/kernel_base.h +++ b/paddle/fluid/operators/jit/kernel_base.h @@ -39,7 +39,8 @@ typedef enum { gruhtpart1, gruhtpart2, crfdecoding, - layernorm + layernorm, + nchw16cmulnc, } KernelType; template @@ -126,6 +127,14 @@ struct LayerNormTuples { const float, int); }; +// nChw16c = nChw16c .* NC +template +struct NCHW16CMulNCTuples { + typedef T data_type; + typedef int attr_type; + typedef void (*func_type)(const T*, const T*, T*, int, int); +}; + // Just for adding to kernel pool without template class Kernel { public: diff --git a/paddle/fluid/operators/jit/refer/CMakeLists.txt b/paddle/fluid/operators/jit/refer/CMakeLists.txt index f3a0e9b11f6..86432bfffe7 100644 --- a/paddle/fluid/operators/jit/refer/CMakeLists.txt +++ b/paddle/fluid/operators/jit/refer/CMakeLists.txt @@ -25,3 +25,4 @@ USE_JITKERNEL_REFER(gruhtpart1) USE_JITKERNEL_REFER(gruhtpart2) USE_JITKERNEL_REFER(crfdecoding) USE_JITKERNEL_REFER(layernorm) +USE_JITKERNEL_REFER(nchw16cmulnc) diff --git a/paddle/fluid/operators/jit/refer/refer.cc b/paddle/fluid/operators/jit/refer/refer.cc index 00daa0d4786..1aee6ff9500 100644 --- a/paddle/fluid/operators/jit/refer/refer.cc +++ b/paddle/fluid/operators/jit/refer/refer.cc @@ -45,4 +45,6 @@ REGISTER_REFER_KERNEL(gruhtpart2, GRUHtPart2); REGISTER_REFER_KERNEL(crfdecoding, CRFDecoding); REGISTER_REFER_KERNEL(layernorm, LayerNorm); +REGISTER_REFER_KERNEL(nchw16cmulnc, NCHW16CMulNC); + #undef REGISTER_REFER_KERNEL diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h index 5780ea05bdf..6f72c2b724b 100644 --- a/paddle/fluid/operators/jit/refer/refer.h +++ b/paddle/fluid/operators/jit/refer/refer.h @@ -319,6 +319,19 @@ void LayerNorm(T* x, T* out, T* mean, T* var, const T* scale, const T* bias, } } +template +void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) { + int offset = 0; + for (int h = 0; h < height; ++h) { + for (int w = 0; w < width; ++w) { + for (int i = 0; i < 16; ++i) { + z[i + offset] = y[i] * x[i + offset]; + } + offset += ZMM_FLOAT_BLOCK; + } + } +} + #define DECLARE_REFER_KERNEL(name, tuples) \ template \ class name##Kernel : public ReferKernel> { \ @@ -355,6 +368,8 @@ DECLARE_REFER_KERNEL(GRUHtPart2, GRUTuples); DECLARE_REFER_KERNEL(CRFDecoding, CRFDecodingTuples); DECLARE_REFER_KERNEL(LayerNorm, LayerNormTuples); +DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples); + #undef DECLARE_REFER_KERNEL } // namespace refer diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc index 85eadea7516..32937d9c005 100644 --- a/paddle/fluid/operators/jit/test.cc +++ b/paddle/fluid/operators/jit/test.cc @@ -19,6 +19,7 @@ #include "glog/logging.h" #include "gtest/gtest.h" #include "paddle/fluid/operators/jit/kernels.h" +#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/place.h" template @@ -414,6 +415,59 @@ void TestGRUKernel() { } } +template +void TestNCHW16CMulNCKernel() { + VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); + const int n = 3, c = 16 * 4, h = 10, w = 10; + auto ref = jit::GetRefer>(); + EXPECT_TRUE(ref != nullptr); + int sz = n * c * h * w; + std::vector x(sz), y(n * c), zref(sz); + std::vector ztgt(sz), zjit(sz); + RandomVec(sz, x.data(), -2.f, 2.f); + RandomVec(n * c, y.data(), -2.f, 2.f); + + const T* x_data = x.data(); + const T* y_data = y.data(); + T* zref_data = zref.data(); + T* ztgt_data = ztgt.data(); + T* zjit_data = zjit.data(); + constexpr int simd_width = ZMM_FLOAT_BLOCK; + int C = c / simd_width; + auto tgt = jit::Get, PlaceType>(0); + auto jitcode = jit::GetJitCode, PlaceType>(0); + EXPECT_TRUE(tgt != nullptr); + + if (std::is_same::value && + paddle::platform::MayIUse(paddle::platform::avx512f)) { + EXPECT_TRUE(jitcode != nullptr); + } + for (int ni = 0; ni < n; ni++) { + for (int ci = 0; ci < C; ci++) { + auto ptr_x = + x_data + ni * C * h * w * simd_width + ci * h * w * simd_width; + auto ptr_y = y_data + ni * C * simd_width + ci * simd_width; + auto ptr_zref = + zref_data + ni * C * h * w * simd_width + ci * h * w * simd_width; + auto ptr_ztgt = + ztgt_data + ni * C * h * w * simd_width + ci * h * w * simd_width; + + ref(ptr_x, ptr_y, ptr_zref, h, w); + tgt(ptr_x, ptr_y, ptr_ztgt, h, w); + + if (jitcode) { + auto ptr_zjit = + zjit_data + ni * C * h * w * simd_width + ci * h * w * simd_width; + jitcode(ptr_x, ptr_y, ptr_zjit, h, w); + } + } + } + ExpectEQ(ztgt_data, zref_data, sz); + if (jitcode) { + ExpectEQ(zjit_data, zref_data, sz); + } +} + // XYZNTuple TEST(JITKernel, vmul) { namespace jit = paddle::operators::jit; @@ -515,6 +569,14 @@ TEST(JITKernel, gruhtpart2) { TestGRUKernel(); } +TEST(JITKernel, nchw16cmulnc) { + namespace jit = paddle::operators::jit; + TestNCHW16CMulNCKernel(); + TestNCHW16CMulNCKernel(); +} + // TODO(yihua/TJ): add crf decoding and layer norm unit tests TEST(JITKernel, pool) { -- GitLab