diff --git a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc
index 22289ab41798eb821e696e029004914876b6f274..595a6232da60fe83a852f96ef28b1195aff91c57 100644
--- a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc
+++ b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc
@@ -17,11 +17,38 @@ limitations under the License. */
 
 #include "paddle/fluid/platform/mkldnn_helper.h"
 
+#include "xbyak/xbyak.h"
+#include "xbyak/xbyak_util.h"
+
 namespace paddle {
 namespace operators {
 
 using framework::DataLayout;
 
+// JIT-generated kernel (Xbyak) that multiplies one block of 16 packed floats:
+// Z[0..15] = X[0..15] * Y[0..15].
+//
+// The emitted code uses ZMM registers, so it must only be executed on CPUs
+// with AVX-512F. It reads its arguments from RDI/RSI/RDX (the first three
+// integer arguments of the System V AMD64 calling convention).
+// NOTE(review): no CPU-feature or ABI guard is visible in this patch --
+// confirm that callers only reach this path on AVX-512 capable x86-64
+// System V targets.
+struct vector_mul : public Xbyak::CodeGenerator {
+  vector_mul() {
+    // RDI is ptr X
+    // RSI is ptr Y
+    // RDX is ptr Z
+
+    vmovups(zmm2, ptr[rdi]);   // unaligned load of 16 floats from X
+    vmovups(zmm3, ptr[rsi]);   // unaligned load of 16 floats from Y
+    vmulps(zmm1, zmm2, zmm3);  // element-wise product
+    vmovups(ptr[rdx], zmm1);   // unaligned store of 16 floats to Z
+
+    ret();
+  }
+};
+
 template <typename T>
 class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
  public:
@@ -61,6 +88,16 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
       constexpr int simd_width = 16;
       int C = c / simd_width;
 
+      // Generate the JIT kernel once and reuse it; constructing a
+      // CodeGenerator on every call would re-allocate and re-emit the code.
+      static vector_mul mul;
+
+      using mul_func_t = void (*)(const float*, const float*, float*);
+
+      const mul_func_t mul_func = mul.getCode<mul_func_t>();
+
+      auto ptr_x = x_data;
+
       for (int ni = 0; ni < n; ni++) {
         for (int ci = 0; ci < C; ci++) {
           for (int hi = 0; hi < h; hi++) {
@@ -74,9 +111,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
                            ci * h * w * simd_width + hi * w * simd_width +
                            wi * simd_width;
 
-              for (int i = 0; i < simd_width; i++) {
-                ptr_z[i] = ptr_x[i] * ptr_y[i];
-              }
+              mul_func(ptr_x, ptr_y, ptr_z);
             }
           }
         }