diff --git a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc index 0457aeed616fa33a8ac05d696ff7327f63138ce9..0266edac75d1ef2baa33c2f222cbde539550dabc 100644 --- a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc @@ -25,9 +25,9 @@ using paddle::platform::MKLDNNDeviceContext; using paddle::platform::MKLDNNGetDataType; using paddle::platform::to_void_cast; using Tensor = paddle::framework::Tensor; -using paddle::framework::vectorize; -using paddle::framework::make_ddim; using paddle::framework::GradVarName; +using paddle::framework::make_ddim; +using paddle::framework::vectorize; template class MatMulV2MKLDNNHandler @@ -123,45 +123,58 @@ class MatMulV2MKLDNNHandler } }; -template -class MatMulV2MKLDNNKernel - : public paddle::operators::MatMulGradMKLDNNKernel { - public: - void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); } +bool IsOutputFused(const ExecutionContext& ctx) { + auto& fused_reshape_Out = ctx.Attr>("fused_reshape_Out"); + auto& fused_transpose_Out = ctx.Attr>("fused_transpose_Out"); + return !fused_reshape_Out.empty() && !fused_transpose_Out.empty(); +} + +float ComputeOutputScale(const ExecutionContext& ctx) { + float scale_x = ctx.Attr("Scale_x"); + float scale_y = ctx.Attr("Scale_y"); + bool force_fp32_out = ctx.Attr("force_fp32_output"); + float scale_out = force_fp32_out ? 1.f : ctx.Attr("Scale_out"); + return scale_out / (scale_x * scale_y); +} - protected: - void ExecuteMatMul(const ExecutionContext& ctx, +template +void ExecuteMatMulV2(const ExecutionContext& ctx, const MKLDNNDeviceContext& dev_ctx, const dnnl::engine onednn_engine, paddle::platform::Place cpu_place, const Tensor* x, std::vector& x_dims, bool trans_x, const Tensor* y, std::vector& y_dims, bool trans_y, Tensor* out, std::vector& out_dims, - int execution_number = 0) const { - MatMulV2MKLDNNHandler handler(onednn_engine, ctx.GetPlace(), x_dims, - trans_x, y_dims, trans_y, - IsOutputFused(ctx)); + int execution_number = 0) { + MatMulV2MKLDNNHandler handler(onednn_engine, ctx.GetPlace(), x_dims, + trans_x, y_dims, trans_y, + IsOutputFused(ctx)); - const auto src_memory_p = handler.AcquireSrcMemory(x); - const auto weights_memory_p = handler.AcquireWeightsMemory(y); - const auto dst_memory_p = handler.AcquireDstMemory(out); + const auto src_memory_p = handler.AcquireSrcMemory(x); + const auto weights_memory_p = handler.AcquireWeightsMemory(y); + const auto dst_memory_p = handler.AcquireDstMemory(out); - auto matmul_p = handler.AcquireForwardPrimitive(); + auto matmul_p = handler.AcquireForwardPrimitive(); - std::unordered_map matmul_args = { - {DNNL_ARG_SRC, *src_memory_p}, - {DNNL_ARG_WEIGHTS, *weights_memory_p}, - {DNNL_ARG_DST, *dst_memory_p}}; + std::unordered_map matmul_args = { + {DNNL_ARG_SRC, *src_memory_p}, + {DNNL_ARG_WEIGHTS, *weights_memory_p}, + {DNNL_ARG_DST, *dst_memory_p}}; - auto& astream = MKLDNNDeviceContext::tls().get_stream(); - matmul_p->execute(astream, matmul_args); - astream.wait(); + auto& astream = MKLDNNDeviceContext::tls().get_stream(); + matmul_p->execute(astream, matmul_args); + astream.wait(); - auto format = paddle::platform::MKLDNNFormatForSize( - out->dims().size(), dnnl::memory::format_tag::nchw); - out->set_layout(paddle::framework::DataLayout::kMKLDNN); - out->set_format(format); - } + auto format = paddle::platform::MKLDNNFormatForSize( + out->dims().size(), dnnl::memory::format_tag::nchw); + out->set_layout(paddle::framework::DataLayout::kMKLDNN); + out->set_format(format); +} + +template +class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel { + public: + void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); } private: void CalculateMatrixDims(const ExecutionContext& ctx, @@ -207,13 +220,6 @@ class MatMulV2MKLDNNKernel } } - bool IsOutputFused(const ExecutionContext& ctx) const { - auto& fused_reshape_Out = ctx.Attr>("fused_reshape_Out"); - auto& fused_transpose_Out = - ctx.Attr>("fused_transpose_Out"); - return !fused_reshape_Out.empty() && !fused_transpose_Out.empty(); - } - void RunKernel(const ExecutionContext& ctx) const { const auto& dev_ctx = ctx.template device_context(); const auto& onednn_engine = dev_ctx.GetEngine(); @@ -237,13 +243,14 @@ class MatMulV2MKLDNNKernel CalculateMatrixDims(ctx, x_dims, y_dims, x_bd_dims, y_bd_dims, out_dims, out); - ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, x_bd_dims, - trans_x, y, y_bd_dims, trans_y, out, out_dims); + ExecuteMatMulV2(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, + x_bd_dims, trans_x, y, y_bd_dims, trans_y, out, + out_dims); } }; template -class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel { +class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel { public: void Compute(const ExecutionContext& ctx) const override { RunKernel(ctx); } @@ -316,7 +323,7 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel { // if no broadcasting is needed, we can simply use matmul's grad and avoid // using reduce_sum if (!is_broadcast) { - paddle::operators::MatMulGradMKLDNNKernel::Compute(ctx); + matmul_v1_grad_mkldnn_kernel.Compute(ctx); return; } @@ -342,33 +349,29 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel { dy_bd_dims); if (trans_x && trans_y) { - this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y, - y_dims, true, dout, dout_dims, true, &dx_tmp, - dx_bd_dims, 1); - this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout, - dout_dims, true, x, x_dims, true, &dy_tmp, dy_bd_dims, - 2); + ExecuteMatMulV2(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y, y_dims, + true, dout, dout_dims, true, &dx_tmp, dx_bd_dims, 1); + ExecuteMatMulV2(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout, + dout_dims, true, x, x_dims, true, &dy_tmp, dy_bd_dims, + 2); } else if (trans_x) { - this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y, - y_dims, false, dout, dout_dims, true, &dx_tmp, - dx_bd_dims, 1); - this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, - x_dims, false, dout, dout_dims, false, &dy_tmp, - dy_bd_dims, 2); + ExecuteMatMulV2(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), y, y_dims, + false, dout, dout_dims, true, &dx_tmp, dx_bd_dims, 1); + ExecuteMatMulV2(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, x_dims, + false, dout, dout_dims, false, &dy_tmp, dy_bd_dims, 2); } else if (trans_y) { - this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout, - dout_dims, false, y, y_dims, false, &dx_tmp, - dx_bd_dims, 1); - this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout, - dout_dims, true, x, x_dims, false, &dy_tmp, - dy_bd_dims, 2); + ExecuteMatMulV2(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout, + dout_dims, false, y, y_dims, false, &dx_tmp, + dx_bd_dims, 1); + ExecuteMatMulV2(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout, + dout_dims, true, x, x_dims, false, &dy_tmp, dy_bd_dims, + 2); } else { - this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout, - dout_dims, false, y, y_dims, true, &dx_tmp, - dx_bd_dims, 1); - this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, - x_dims, true, dout, dout_dims, false, &dy_tmp, - dy_bd_dims, 2); + ExecuteMatMulV2(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), dout, + dout_dims, false, y, y_dims, true, &dx_tmp, dx_bd_dims, + 1); + ExecuteMatMulV2(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), x, x_dims, + true, dout, dout_dims, false, &dy_tmp, dy_bd_dims, 2); } if (x_dims != dx_bd_dims) { @@ -389,8 +392,12 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel { dy->set_layout(paddle::framework::DataLayout::kMKLDNN); dy->set_format(y->format()); } + + private: + paddle::operators::MatMulGradMKLDNNKernel matmul_v1_grad_mkldnn_kernel; }; } // anonymous namespace + namespace ops = paddle::operators; REGISTER_OP_KERNEL(matmul_v2, MKLDNN, ::paddle::platform::CPUPlace,