From f7465641c35a990837f988c931e567be80d2ef01 Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Fri, 19 Feb 2021 14:17:33 +0100 Subject: [PATCH] Added reshape grad bf16 (#31035) * - added Reshape grad bf16 * - Added reshape grad bf16 * - cosmetics in py --- paddle/fluid/operators/reshape_op.cc | 4 +++- .../tests/unittests/mkldnn/test_reshape_bf16_op.py | 12 +++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 41f631f5547..0e11771d87c 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -642,12 +642,14 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR( reshape2_grad, float, ops::ReshapeGradKernel, double, ops::ReshapeGradKernel, int, ops::ReshapeGradKernel, uint8_t, ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, bool, - ops::ReshapeGradKernel, paddle::platform::complex64, ops::ReshapeGradKernel, + ops::ReshapeGradKernel, paddle::platform::bfloat16, ops::ReshapeGradKernel, + paddle::platform::complex64, ops::ReshapeGradKernel, paddle::platform::complex128, ops::ReshapeGradKernel); REGISTER_OP_CPU_KERNEL_FUNCTOR( reshape2_grad_grad, float, ops::ReshapeDoubleGradKernel, double, ops::ReshapeDoubleGradKernel, int, ops::ReshapeDoubleGradKernel, uint8_t, ops::ReshapeDoubleGradKernel, int64_t, ops::ReshapeDoubleGradKernel, bool, + ops::ReshapeDoubleGradKernel, paddle::platform::bfloat16, ops::ReshapeDoubleGradKernel, paddle::platform::complex64, ops::ReshapeDoubleGradKernel, paddle::platform::complex128, ops::ReshapeDoubleGradKernel); diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reshape_bf16_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_reshape_bf16_op.py index 5128dc1c4a3..ac9b881313a 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_reshape_bf16_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_reshape_bf16_op.py @@ -28,7 +28,7 @@ from paddle import enable_static class TestReshapeBf16Op(OpTest): def setUp(self): self.op_type = "reshape2" - self.use_mkldnn = True + self.use_mkldnn = False self.mkldnn_data_type = "bfloat16" self.init_data() self.init_input_data() @@ -56,6 +56,16 @@ class TestReshapeBf16Op(OpTest): def test_check_output(self): self.check_output_with_place(core.CPUPlace(), no_check_set=['XShape']) + def test_check_grad(self): + self.check_grad_with_place( + core.CPUPlace(), ["X"], + "Out", + check_dygraph=False, + user_defined_grads=[self.inputs["X"]], + user_defined_grad_outputs=[ + self.inputs["X"].reshape(self.infered_shape) + ]) + if __name__ == '__main__': enable_static() -- GitLab