diff --git a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc similarity index 85% rename from paddle/fluid/operators/elementwise_mul_mkldnn_op.cc rename to paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc index 216c7ed9c6671767b4292b9b011ce5ebee7ed3d5..10290a4aeff6b6a023fb28961d12728aff891e83 100644 --- a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc @@ -13,13 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. */ #include -#include "paddle/fluid/operators/elementwise_op.h" -#include "paddle/fluid/operators/elementwise_op_function.h" +#include "paddle/fluid/operators/elementwise/elementwise_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/platform/mkldnn_helper.h" -#include "xbyak/xbyak.h" -#include "xbyak/xbyak_util.h" +#include "paddle/fluid/operators/math/jit_kernel.h" +#include "xbyak.h" +#include "xbyak_util.h" namespace paddle { namespace operators { @@ -27,47 +28,6 @@ namespace operators { using framework::DataLayout; using mkldnn::memory; -struct vector_mul : public Xbyak::CodeGenerator { - vector_mul() { - // RDI is ptr X - // RSI is ptr Y - // RDX is ptr Z - // RCX is h - // r8 is w - - push(rbx); - - xor_(rax, rax); - xor_(r10, r10); - vmovups(zmm3, ptr[rsi]); - - L("h_loop"); - xor_(rbx, rbx); - L("w_loop"); - vmovups(zmm2, ptr[rdi + rax]); - vmulps(zmm1, zmm2, zmm3); - vmovups(ptr[rdx + rax], zmm1); - add(rax, 64); - inc(rbx); - cmp(r8, rbx); - jnz("w_loop"); - inc(r10); - cmp(r10, rcx); - jnz("h_loop"); - - pop(rbx); - ret(); - } -}; - -void check(const float* x, const float* y, float* z, int w) { - for (int wi = 0; wi < w; wi++) { - for (int i = 0; i < 16; i++) { - z[wi * 16 + i] = x[wi * 16 + i] * y[i]; - } - } -} - static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) { std::transform(format.begin(), format.end(), format.begin(), ::tolower); @@ -163,12 +123,9 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { constexpr int simd_width = 16; int C = c / simd_width; - vector_mul mul; - - using mul_func_t = - void (*)(const float*, const float*, float*, int, int); - - mul_func_t mul_func = (mul_func_t)mul.getCode(); + const auto& multiply = + math::jitkernel::KernelPool::Instance() + .template Get>(n); #pragma omp parallel for collapse(2) for (int ni = 0; ni < n; ni++) { @@ -180,7 +137,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { auto ptr_z = z_data + ni * C * h * w * simd_width + ci * h * w * simd_width; - mul_func(ptr_x, ptr_y, ptr_z, h, w); + multiply->Compute(ptr_x, ptr_y, ptr_z, h, w); } } } diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index 71205b211b7f571f8081640ef60222de051ff49d..dbfe6290137e8cfdf4308cbe4e8f90a9d1568d7f 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -156,6 +156,42 @@ class VActJitCode : public JitCode { ymm_t ymm_dst = ymm_t(1); }; +#ifdef PADDLE_WITH_MKLDNN +struct EltwiseMulnChw16cNC : public Xbyak::CodeGenerator { + explicit EltwiseMulnChw16cNC(size_t code_size = 256 * 1024) + : Xbyak::CodeGenerator(code_size) { + // RDI is ptr x_input + // RSI is ptr y_input + // RDX is ptr output + // RCX is height + // r8 is width + + push(rbx); + + xor_(rax, rax); + xor_(r10, r10); + vmovups(zmm3, ptr[rsi]); + + L("h_loop"); + xor_(rbx, rbx); + L("w_loop"); + vmovups(zmm2, ptr[rdi + rax]); + vmulps(zmm1, zmm2, zmm3); + vmovups(ptr[rdx + rax], zmm1); + add(rax, 64); + inc(rbx); + cmp(r8, rbx); + jnz("w_loop"); + inc(r10); + cmp(r10, rcx); + jnz("h_loop"); + + pop(rbx); + ret(); + } +}; +#endif + } // namespace gen } // namespace jitkernel } // namespace math diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index 4d8d3cd79a16a3ea61c4f63da3493e105847d30b..110de3b1408a644fbdf57c2ffa5475d4d925df25 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -94,6 +94,15 @@ class VAddBiasKernel : public Kernel { void (*Compute)(const T *, const T *, T *, int); }; +#ifdef PADDLE_WITH_MKLDNN +template +class EltwiseMulnChw16cNCKernel : public Kernel { + public: + // nChw16c = nChw16c .* NC + void (*Compute)(const float *, const float *, float *, int, int); +}; +#endif + template class VActKernel : public Kernel { public: diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index 36a50f20434f313e93bfa3dd2c9d46963024caf7..a143b51439f55d1f80d7936dfad46e31bd19f0cb 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -226,6 +226,44 @@ bool VAddKernelImpl::useMKL(int d) { } #endif +#ifdef PADDLE_WITH_MKLDNN +/* EltwiseMul for nChw16c & NC inputs JitKernel */ +template +class EltwiseMulnChw16cNCKernelImpl + : public math::jitkernel::EltwiseMulnChw16cNCKernel { + public: + JITKERNEL_DECLARE_STATIC_FUNC; + explicit EltwiseMulnChw16cNCKernelImpl(int d) + : EltwiseMulnChw16cNCKernel() { + using mul_func_t = void (*)(const float*, const float*, float*, int, int); +#ifdef PADDLE_WITH_XBYAK + if (useJIT(d)) { + // roughly estimate the size of code + size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8; + sz = sz > 4096 ? sz : 4096; + jitcode_.reset(new gen::EltwiseMulnChw16cNC(sz)); + this->Compute = (mul_func_t)jitcode_->getCode(); + return; + } +#endif + PADDLE_THROW( + "This kernel shouldn't be used in Non-Xbyak, Non-MKL-DNN " + "environemnt"); + } + +#ifdef PADDLE_WITH_XBYAK + + private: + std::unique_ptr jitcode_{nullptr}; +}; + +template <> +bool EltwiseMulnChw16cNCKernelImpl::useJIT(int d) { + return true; +} +#endif +#endif + /* VAddRelu JitKernel */ template class VAddReluKernelImpl : public VAddReluKernel { @@ -394,6 +432,9 @@ REGISTER_JITKERNEL(vscal, VScalKernel); REGISTER_JITKERNEL(vaddbias, VAddBiasKernel); REGISTER_JITKERNEL(vrelu, VReluKernel); REGISTER_JITKERNEL(videntity, VIdentityKernel); +#ifdef PADDLE_WITH_MKLDNN +REGISTER_JITKERNEL(eltwise_mul_nchw16c, EltwiseMulnChw16cNCKernel); +#endif } // namespace jitkernel } // namespace math