未验证 提交 f7465641 编写于 作者: J Jacek Czaja 提交者: GitHub

Added reshape grad bf16 (#31035)

* - added Reshape grad bf16

* - Added reshape grad bf16

* - cosmetics in py
上级 4dbe16c4
...@@ -642,12 +642,14 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR( ...@@ -642,12 +642,14 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(
reshape2_grad, float, ops::ReshapeGradKernel, double, reshape2_grad, float, ops::ReshapeGradKernel, double,
ops::ReshapeGradKernel, int, ops::ReshapeGradKernel, uint8_t, ops::ReshapeGradKernel, int, ops::ReshapeGradKernel, uint8_t,
ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, bool, ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, bool,
ops::ReshapeGradKernel, paddle::platform::complex64, ops::ReshapeGradKernel, ops::ReshapeGradKernel, paddle::platform::bfloat16, ops::ReshapeGradKernel,
paddle::platform::complex64, ops::ReshapeGradKernel,
paddle::platform::complex128, ops::ReshapeGradKernel); paddle::platform::complex128, ops::ReshapeGradKernel);
REGISTER_OP_CPU_KERNEL_FUNCTOR( REGISTER_OP_CPU_KERNEL_FUNCTOR(
reshape2_grad_grad, float, ops::ReshapeDoubleGradKernel, double, reshape2_grad_grad, float, ops::ReshapeDoubleGradKernel, double,
ops::ReshapeDoubleGradKernel, int, ops::ReshapeDoubleGradKernel, uint8_t, ops::ReshapeDoubleGradKernel, int, ops::ReshapeDoubleGradKernel, uint8_t,
ops::ReshapeDoubleGradKernel, int64_t, ops::ReshapeDoubleGradKernel, bool, ops::ReshapeDoubleGradKernel, int64_t, ops::ReshapeDoubleGradKernel, bool,
ops::ReshapeDoubleGradKernel, paddle::platform::bfloat16,
ops::ReshapeDoubleGradKernel, paddle::platform::complex64, ops::ReshapeDoubleGradKernel, paddle::platform::complex64,
ops::ReshapeDoubleGradKernel, paddle::platform::complex128, ops::ReshapeDoubleGradKernel, paddle::platform::complex128,
ops::ReshapeDoubleGradKernel); ops::ReshapeDoubleGradKernel);
......
...@@ -28,7 +28,7 @@ from paddle import enable_static ...@@ -28,7 +28,7 @@ from paddle import enable_static
class TestReshapeBf16Op(OpTest): class TestReshapeBf16Op(OpTest):
def setUp(self): def setUp(self):
self.op_type = "reshape2" self.op_type = "reshape2"
self.use_mkldnn = True self.use_mkldnn = False
self.mkldnn_data_type = "bfloat16" self.mkldnn_data_type = "bfloat16"
self.init_data() self.init_data()
self.init_input_data() self.init_input_data()
...@@ -56,6 +56,16 @@ class TestReshapeBf16Op(OpTest): ...@@ -56,6 +56,16 @@ class TestReshapeBf16Op(OpTest):
def test_check_output(self): def test_check_output(self):
self.check_output_with_place(core.CPUPlace(), no_check_set=['XShape']) self.check_output_with_place(core.CPUPlace(), no_check_set=['XShape'])
def test_check_grad(self):
self.check_grad_with_place(
core.CPUPlace(), ["X"],
"Out",
check_dygraph=False,
user_defined_grads=[self.inputs["X"]],
user_defined_grad_outputs=[
self.inputs["X"].reshape(self.infered_shape)
])
if __name__ == '__main__': if __name__ == '__main__':
enable_static() enable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册