From 25fc2a1fdb4b949f94f97a6d954ba13862f6c38a Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Fri, 19 Mar 2021 13:28:04 +0100 Subject: [PATCH] [oneDNN] Added Elementwise Mul grad fp32/bf16 (#31647) --- .../operators/elementwise/elementwise_op.h | 5 +- .../mkldnn/elementwise_add_mkldnn_op.cc | 11 ++ .../mkldnn/elementwise_mkldnn_op.h | 1 - .../mkldnn/elementwise_mul_mkldnn_op.cc | 116 ++++++++++++++++++ paddle/fluid/platform/mkldnn_reuse.h | 10 +- .../test_elementwise_mul_bf16_mkldnn_op.py | 66 ++++++++-- .../mkldnn/test_elementwise_mul_mkldnn_op.py | 12 +- 7 files changed, 206 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index 6ec73b02ade..e09f94a6c0f 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -276,7 +276,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN // If broadcasting is needed, use native implementation - auto CanMKLDNNElementwiseAddGradBeUsed = [&]() { + auto CanMKLDNNElementwiseGradBeUsed = [&]() { auto dx_dims = ctx.Input("X")->dims(); auto dy_dims = ctx.Input("Y")->dims(); // No broadcast or broadcasting of data on inner dims is supported @@ -284,8 +284,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { }; if (this->CanMKLDNNBeUsed(ctx, input_data_type) && - (ctx.Type() != "elementwise_add_grad" || - CanMKLDNNElementwiseAddGradBeUsed())) { + CanMKLDNNElementwiseGradBeUsed()) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc index 4db4adfe9e9..b43dddfcf19 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc @@ -61,6 +61,9 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel { platform::EventRole::kUniqueOp); reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); astream.wait(); + + dx->set_layout(DataLayout::kMKLDNN); + dx->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p)); } if (dy) { @@ -75,6 +78,9 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel { reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); astream.wait(); + + dy->set_layout(DataLayout::kMKLDNN); + dy->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p)); } else { // Broadcasting platform::ReductionMKLDNNHandler handler_sum( @@ -86,6 +92,11 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel { reduction_p->execute(astream, {{DNNL_ARG_SRC, *reorder_src_memory_p}, {DNNL_ARG_DST, *dy_memory_p}}); astream.wait(); + + dy->set_layout(DataLayout::kMKLDNN); + dy->set_format( + platform::GetMKLDNNFormat(dy_memory_p->get_desc().reshape( + paddle::framework::vectorize(dy->dims())))); } } } diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h index 8a646e5865d..df827117a0d 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h @@ -15,7 +15,6 @@ #pragma once #include #include -#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/operators/elementwise/elementwise_add_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc index 293b5a1a2d3..c9209cc39d5 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc @@ -14,6 +14,118 @@ limitations under the License. */ #include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h" +namespace paddle { +namespace framework { +class ExecutionContext; +} // namespace framework +namespace platform { +class CPUDeviceContext; +struct CPUPlace; +} // namespace platform +} // namespace paddle + +namespace paddle { +namespace operators { +template +class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + ElemwiseGradKernel::Compute(ctx); + + auto& dev_ctx = + ctx.template device_context(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); + + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + int axis = ctx.Attr("axis"); + + auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); + + if (dx) { + // dx = dout*y + platform::BinaryMKLDNNHandler handler( + dnnl::algorithm::binary_mul, axis, dev_ctx, mkldnn_engine, + ctx.GetPlace(), dout, y, dx, 1.0f, 1.0f, 1.0f, + ctx.InputName(framework::GradVarName("Out"))); + + const auto src_dout_memory = handler.AcquireSrcMemory(dout); + const auto src_y_memory = handler.AcquireSecondSrcMemory(y); + const auto dst_dx_memory = handler.AcquireDstMemory(dx); + + const auto binary_prim = handler.AcquireForwardPrimitive(); + + const std::unordered_map args = { + {DNNL_ARG_SRC_0, *src_dout_memory}, + {DNNL_ARG_SRC_1, *src_y_memory}, + {DNNL_ARG_DST, *dst_dx_memory}}; + + binary_prim->execute(astream, args); + astream.wait(); + + dx->set_layout(framework::DataLayout::kMKLDNN); + dx->set_format(platform::GetMKLDNNFormat(*dst_dx_memory)); + } + + if (dy) { + // dy = dout*x + // Handler is having nullptr passed instead of output tensor as + // we want Dst buffer to be allocated by oneDNN not to use Tensor + platform::BinaryMKLDNNHandler handler( + dnnl::algorithm::binary_mul, axis, dev_ctx, mkldnn_engine, + ctx.GetPlace(), dout, x, nullptr, 1.0f, 1.0f, 1.0f, + ctx.InputName(framework::GradVarName("Out"))); + + const auto src_dout_memory = handler.AcquireSrcMemory(dout); + const auto src_x_memory = handler.AcquireSecondSrcMemory(x); + + // If broadcasting is in use then let's write to temporary + // buffer allocated by oneDNN + const auto dst_dy_memory = (dout->dims() == dy->dims()) + ? handler.AcquireDstMemory(dy) + : handler.AcquireDstMemory(); + + const auto binary_prim = handler.AcquireForwardPrimitive(); + + const std::unordered_map args = { + {DNNL_ARG_SRC_0, *src_dout_memory}, + {DNNL_ARG_SRC_1, *src_x_memory}, + {DNNL_ARG_DST, *dst_dy_memory}}; + + binary_prim->execute(astream, args); + astream.wait(); + + dy->set_layout(framework::DataLayout::kMKLDNN); + + // Reduction is needed for broadcasting scenario + if (dout->dims() != dy->dims()) { + platform::ReductionMKLDNNHandler handler_sum( + dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, mkldnn_engine, + ctx.GetPlace(), dout, dy, + ctx.InputName(framework::GradVarName("Out"))); + auto dy_memory_p = handler_sum.AcquireDstMemory(dy); + auto reduction_p = handler_sum.AcquireForwardPrimitive(); + // As source we use mem object with results from binary operation + reduction_p->execute(astream, {{DNNL_ARG_SRC, *dst_dy_memory}, + {DNNL_ARG_DST, *dy_memory_p}}); + astream.wait(); + dy->set_format( + platform::GetMKLDNNFormat(dy_memory_p->get_desc().reshape( + paddle::framework::vectorize(dy->dims())))); + + } else { + dy->set_format(platform::GetMKLDNNFormat(*dst_dy_memory)); + } + } + } +}; + +} // namespace operators +} // namespace paddle + namespace ops = paddle::operators; REGISTER_OP_KERNEL( @@ -23,3 +135,7 @@ REGISTER_OP_KERNEL( dnnl::algorithm::binary_mul>, ops::EltwiseMKLDNNKernel, ops::EltwiseMKLDNNKernel) + +REGISTER_OP_KERNEL(elementwise_mul_grad, MKLDNN, ::paddle::platform::CPUPlace, + ops::EltwiseMulMKLDNNGradKernel, + ops::EltwiseMulMKLDNNGradKernel) diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 0503c3f71a8..c79b642c51b 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -87,6 +87,11 @@ class MKLDNNHandlerT { "@dst_mem_p"); } + template + std::shared_ptr AcquireDstMemory(void) { + return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_desc(), "@dstt_mem_p"); + } + template std::shared_ptr AcquireDstMemory( const framework::Tensor* output) { @@ -561,7 +566,10 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT { const auto src_x_tz = framework::vectorize(x->dims()); const auto src_y_tz = framework::vectorize(y->dims()); - const auto dst_tz = framework::vectorize(z->dims()); + // if output tensor(z) is nullptr then we are computing into oneDNN + // managed buffer + const auto dst_tz = + (z == nullptr) ? src_x_tz : framework::vectorize(z->dims()); const auto src0_md = dnnl::memory::desc( src_x_tz, platform::MKLDNNGetDataType(), x->format()); diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_bf16_mkldnn_op.py index c2716420fba..9b7f4b9b860 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_bf16_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_bf16_mkldnn_op.py @@ -30,10 +30,9 @@ class TestElementwiseMulBf16MklDNNOp(OpTest): self.axis = -1 self.generate_data() - self.inputs = { - 'X': convert_float_to_uint16(self.x), - 'Y': convert_float_to_uint16(self.y) - } + self.x_bf16 = convert_float_to_uint16(self.x) + self.y_bf16 = convert_float_to_uint16(self.y) + self.inputs = {'X': self.x_bf16, 'Y': self.y_bf16} self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn} self.outputs = {'Out': convert_float_to_uint16(self.out)} @@ -46,13 +45,66 @@ class TestElementwiseMulBf16MklDNNOp(OpTest): self.check_output_with_place(core.CPUPlace()) def test_check_grad_normal(self): - pass + self.check_grad_with_place( + core.CPUPlace(), ["X", "Y"], + "Out", + check_dygraph=False, + user_defined_grads=[ + np.multiply(self.x, self.y), np.multiply(self.x, self.x) + ], + user_defined_grad_outputs=[self.x_bf16]) def test_check_grad_ingore_x(self): - pass + self.check_grad_with_place( + core.CPUPlace(), ["Y"], + "Out", + check_dygraph=False, + user_defined_grads=[np.multiply(self.y, self.x)], + user_defined_grad_outputs=[self.y_bf16]) def test_check_grad_ingore_y(self): - pass + self.check_grad_with_place( + core.CPUPlace(), ["X"], + "Out", + check_dygraph=False, + user_defined_grads=[np.multiply(self.x, self.y)], + user_defined_grad_outputs=[self.x_bf16]) + + +class TestElementwiseMulBroadcastingBf16MklDNNOp( + TestElementwiseMulBf16MklDNNOp): + def generate_data(self): + self.x = np.random.uniform(1, 2, [1, 2, 3, 100]).astype(np.float32) + self.y = np.random.uniform(1, 2, [100]).astype(np.float32) + self.out = np.multiply(self.x, self.y) + + # Compute partial sums along all axes but last one + def compute_reduced_gradients(self, out_grads): + part_sum = np.add.reduceat(out_grads, [0], axis=0) + part_sum = np.add.reduceat(part_sum, [0], axis=1) + part_sum = np.add.reduceat(part_sum, [0], axis=2) + return part_sum.flatten() + + def test_check_grad_normal(self): + self.check_grad_with_place( + core.CPUPlace(), ["X", "Y"], + "Out", + check_dygraph=False, + user_defined_grads=[ + np.multiply(self.x, self.y), + self.compute_reduced_gradients(np.multiply(self.x, self.x)) + ], + user_defined_grad_outputs=[self.x_bf16]) + + def test_check_grad_ingore_x(self): + self.check_grad_with_place( + core.CPUPlace(), ["Y"], + "Out", + check_dygraph=False, + user_defined_grads=[ + self.compute_reduced_gradients(np.multiply(self.x, self.x)) + ], + user_defined_grad_outputs=[self.x_bf16]) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_mkldnn_op.py index d66f3dfb891..03dc2421b65 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_mkldnn_op.py @@ -17,6 +17,7 @@ import unittest import numpy as np from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci from paddle.fluid.tests.unittests.test_elementwise_mul_op import ElementwiseMulOp +from paddle import enable_static class TestMKLDNNElementwiseMulOp(ElementwiseMulOp): @@ -51,13 +52,17 @@ class TestMKLDNNElementwiseMulOp4(TestMKLDNNElementwiseMulOp): def test_check_grad_normal(self): pass - def test_check_grad_ingore_x(self): - pass - def test_check_grad_ingore_y(self): pass +class TestMKLDNNElementwiseMulOp5(TestMKLDNNElementwiseMulOp): + def init_input_output(self): + self.x = np.random.uniform(1, 2, [2, 3, 4, 100]).astype(self.dtype) + self.y = np.random.uniform(1, 2, [100]).astype(self.dtype) + self.out = np.multiply(self.x, self.y) + + ''' INT8 Tests ''' @@ -140,4 +145,5 @@ class TestUint8Scales(TestInt8Scales): if __name__ == '__main__': + enable_static() unittest.main() -- GitLab