diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 41f631f5547369a491e886434b243336fc57b0b4..0e11771d87c99b007d2f3303ff04bdf2c216b3e5 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -642,12 +642,14 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR( reshape2_grad, float, ops::ReshapeGradKernel, double, ops::ReshapeGradKernel, int, ops::ReshapeGradKernel, uint8_t, ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, bool, - ops::ReshapeGradKernel, paddle::platform::complex64, ops::ReshapeGradKernel, + ops::ReshapeGradKernel, paddle::platform::bfloat16, ops::ReshapeGradKernel, + paddle::platform::complex64, ops::ReshapeGradKernel, paddle::platform::complex128, ops::ReshapeGradKernel); REGISTER_OP_CPU_KERNEL_FUNCTOR( reshape2_grad_grad, float, ops::ReshapeDoubleGradKernel, double, ops::ReshapeDoubleGradKernel, int, ops::ReshapeDoubleGradKernel, uint8_t, ops::ReshapeDoubleGradKernel, int64_t, ops::ReshapeDoubleGradKernel, bool, + ops::ReshapeDoubleGradKernel, paddle::platform::bfloat16, ops::ReshapeDoubleGradKernel, paddle::platform::complex64, ops::ReshapeDoubleGradKernel, paddle::platform::complex128, ops::ReshapeDoubleGradKernel); diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reshape_bf16_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_reshape_bf16_op.py index 5128dc1c4a3447a3d975f0f9d31019fbf4cc060d..ac9b881313a31663c09e25eeae4108991e0f84ff 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_reshape_bf16_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_reshape_bf16_op.py @@ -28,7 +28,7 @@ from paddle import enable_static class TestReshapeBf16Op(OpTest): def setUp(self): self.op_type = "reshape2" - self.use_mkldnn = True + self.use_mkldnn = False self.mkldnn_data_type = "bfloat16" self.init_data() self.init_input_data() @@ -56,6 +56,16 @@ class TestReshapeBf16Op(OpTest): def test_check_output(self): self.check_output_with_place(core.CPUPlace(), no_check_set=['XShape']) + def test_check_grad(self): + self.check_grad_with_place( + core.CPUPlace(), ["X"], + "Out", + check_dygraph=False, + user_defined_grads=[self.inputs["X"]], + user_defined_grad_outputs=[ + self.inputs["X"].reshape(self.infered_shape) + ]) + if __name__ == '__main__': enable_static()