未验证 提交 9e527d99 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] Added basic changes for elementwise_add_grad bf16 (#30925)

上级 c98f144f
...@@ -90,4 +90,5 @@ REGISTER_OP_KERNEL( ...@@ -90,4 +90,5 @@ REGISTER_OP_KERNEL(
ops::EltwiseMKLDNNKernel<uint8_t, dnnl::algorithm::binary_add>) ops::EltwiseMKLDNNKernel<uint8_t, dnnl::algorithm::binary_add>)
REGISTER_OP_KERNEL(elementwise_add_grad, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(elementwise_add_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::EltwiseAddMKLDNNGradKernel<paddle::platform::bfloat16>,
ops::EltwiseAddMKLDNNGradKernel<float>) ops::EltwiseAddMKLDNNGradKernel<float>)
...@@ -30,10 +30,10 @@ class TestElementwiseAddBf16MklDNNOp(OpTest): ...@@ -30,10 +30,10 @@ class TestElementwiseAddBf16MklDNNOp(OpTest):
self.axis = -1 self.axis = -1
self.generate_data() self.generate_data()
self.inputs = { self.x_bf16 = convert_float_to_uint16(self.x)
'X': convert_float_to_uint16(self.x), self.y_bf16 = convert_float_to_uint16(self.y)
'Y': convert_float_to_uint16(self.y)
} self.inputs = {'X': self.x_bf16, 'Y': self.y_bf16}
self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn} self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn}
self.outputs = {'Out': convert_float_to_uint16(self.out)} self.outputs = {'Out': convert_float_to_uint16(self.out)}
...@@ -45,14 +45,30 @@ class TestElementwiseAddBf16MklDNNOp(OpTest): ...@@ -45,14 +45,30 @@ class TestElementwiseAddBf16MklDNNOp(OpTest):
def test_check_output(self): def test_check_output(self):
self.check_output_with_place(core.CPUPlace()) self.check_output_with_place(core.CPUPlace())
# elementwise_add grad is just passing upper gradients to either X or Y or both
def test_check_grad_normal(self): def test_check_grad_normal(self):
pass self.check_grad_with_place(
core.CPUPlace(), ["X", "Y"],
"Out",
check_dygraph=False,
user_defined_grads=[self.x_bf16, self.x_bf16],
user_defined_grad_outputs=[self.x_bf16])
def test_check_grad_ingore_x(self): def test_check_grad_ingore_x(self):
pass self.check_grad_with_place(
core.CPUPlace(), ["Y"],
"Out",
check_dygraph=False,
user_defined_grads=[self.y_bf16],
user_defined_grad_outputs=[self.y_bf16])
def test_check_grad_ingore_y(self): def test_check_grad_ingore_y(self):
pass self.check_grad_with_place(
core.CPUPlace(), ["X"],
"Out",
check_dygraph=False,
user_defined_grads=[self.x_bf16],
user_defined_grad_outputs=[self.x_bf16])
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册