未验证 提交 af247f95 编写于 作者: H hong 提交者: GitHub

fix reduce prod backward bug (#41357)

上级 e0ccaeaf
......@@ -1559,8 +1559,6 @@ class OpTest(unittest.TestCase):
def _compare_numpy(self, name, actual_np, expect_np):
with _test_eager_guard():
print(actual_np)
print(expect_np)
super()._compare_numpy(name, actual_np, expect_np)
def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
......
......@@ -238,10 +238,14 @@ class TestMin8DOp(OpTest):
self.check_output(check_eager=True)
def raw_reduce_prod(x, dim=[0], keep_dim=False):
return paddle.prod(x, dim, keep_dim)
class TestProdOp(OpTest):
def setUp(self):
self.op_type = "reduce_prod"
self.python_api = paddle.prod
self.python_api = raw_reduce_prod
self.init_data_type()
self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.data_type)}
self.outputs = {'Out': self.inputs['X'].prod(axis=0)}
......@@ -251,15 +255,16 @@ class TestProdOp(OpTest):
) else "float64"
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_eager=True)
class TestProd6DOp(OpTest):
def setUp(self):
self.op_type = "reduce_prod"
self.python_api = raw_reduce_prod
self.init_data_type()
self.inputs = {
'X': np.random.random((5, 6, 2, 3, 4, 2)).astype(self.data_type)
......@@ -274,15 +279,16 @@ class TestProd6DOp(OpTest):
) else "float64"
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_eager=True)
class TestProd8DOp(OpTest):
def setUp(self):
self.op_type = "reduce_prod"
self.python_api = raw_reduce_prod
self.init_data_type()
self.inputs = {
'X': np.random.random(
......@@ -298,10 +304,10 @@ class TestProd8DOp(OpTest):
) else "float64"
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_eager=True)
class TestAllOp(OpTest):
......
......@@ -769,7 +769,7 @@
func : UnchangedInferMeta
param : [x]
kernel :
func : reduce_prod_grad
func : prod_grad
- backward_api : relu_double_grad
forward : relu_grad (Tensor out, Tensor grad_out) -> Tensor(grad_x)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册