diff --git a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc
index 13e4cc04df01fd1c159294a7daae40616da2cc4e..21716e271d36365b9759c1616ec3b4c0e09a3cba 100644
--- a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc
+++ b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc
@@ -30,16 +30,44 @@ struct vector_mul : public Xbyak::CodeGenerator {
     // RDI is ptr X
     // RSI is ptr Y
     // RDX is ptr Z
+    // ECX is h (int, must be > 0; upper bits of RCX are undefined)
+    // R8D is w (int, must be > 0; upper bits of R8 are undefined)
 
-    vmovups(zmm2, ptr[rdi]);
+    push(rbx);  // rbx is callee-saved; restored before ret
+
+    xor_(rax, rax);  // rax: byte offset into X and Z
+    xor_(r10, r10);  // r10: h-loop counter
     vmovups(zmm3, ptr[rsi]);
-    vmulps(zmm1, zmm2, zmm3);
-    vmovups(ptr[rdx], zmm1);
+    L("h_loop");
+    xor_(rbx, rbx);  // rbx: w-loop counter
+    L("w_loop");
+    vmovups(zmm2, ptr[rdi + rax]);
+    vmulps(zmm1, zmm2, zmm3);
+    vmovups(ptr[rdx + rax], zmm1);
+    add(rax, 64);  // one ZMM register = 16 floats = 64 bytes
+    inc(rbx);
+    cmp(r8d, ebx);  // 32-bit compare: w is passed as int
+    jnz("w_loop");
+    inc(r10);
+    cmp(r10d, ecx);  // 32-bit compare: h is passed as int
+    jnz("h_loop");
+
+    pop(rbx);
 
     ret();
   }
 };
 
+// Scalar equivalent of one h-iteration of vector_mul: multiplies each of
+// the w consecutive 16-float groups of x element-wise by the 16 floats of y.
+void check(const float* x, const float* y, float* z, int w) {
+  for (int wi = 0; wi < w; wi++) {
+    for (int i = 0; i < 16; i++) {
+      z[wi * 16 + i] = x[wi * 16 + i] * y[i];
+    }
+  }
+}
+
 template <typename T>
 class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
  public:
@@ -65,7 +93,6 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
         PADDLE_THROW("Not implemented when post is 1");
       } else {
         // Just check whether it works for RE-Resnext.
-
         PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions");
 
         int n = x_dims[0];
@@ -81,26 +108,21 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
 
         vector_mul mul;
 
-        using mul_func_t = void (*)(const float*, const float*, float*);
+        using mul_func_t =
+            void (*)(const float*, const float*, float*, int, int);
 
         mul_func_t mul_func = (mul_func_t)mul.getCode();
 
         for (int ni = 0; ni < n; ni++) {
           for (int ci = 0; ci < C; ci++) {
-            for (int hi = 0; hi < h; hi++) {
-              for (int wi = 0; wi < w; wi++) {
-                auto ptr_x = x_data + ni * C * h * w * simd_width +
-                             ci * h * w * simd_width + hi * w * simd_width +
-                             wi * simd_width;
-                auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
-
-                auto ptr_z = z_data + ni * C * h * w * simd_width +
-                             ci * h * w * simd_width + hi * w * simd_width +
-                             wi * simd_width;
-
-                mul_func(ptr_x, ptr_y, ptr_z);
-              }
-            }
+            auto ptr_x =
+                x_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
+
+            auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
+            auto ptr_z =
+                z_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
+
+            mul_func(ptr_x, ptr_y, ptr_z, h, w);  // one call per h*w plane
           }
         }
       }