未验证 提交 65b0181e 编写于 作者: W wawltor 提交者: GitHub

add add_n for the 0d tensor (#49854)

上级 8fdb9087
......@@ -297,7 +297,7 @@ void AddNInferMeta(const std::vector<const MetaTensor*>& x,
if (N == 1) {
VLOG(3) << "Warning: SumOp have only one input, may waste memory";
}
bool is_all_0d_tensor = true;
phi::DDim in_dim({0});
for (size_t i = 0; i < x.size(); ++i) {
auto x_dim = x[i]->dims();
......@@ -313,6 +313,7 @@ void AddNInferMeta(const std::vector<const MetaTensor*>& x,
if (x_dim.size() == 0) {
continue;
}
is_all_0d_tensor = false;
if (phi::product(in_dim) == 0) {
in_dim = x_dim;
} else {
......@@ -360,7 +361,11 @@ void AddNInferMeta(const std::vector<const MetaTensor*>& x,
}
}
}
out->set_dims(in_dim);
if (is_all_0d_tensor) {
out->set_dims(make_ddim({}));
} else {
out->set_dims(in_dim);
}
out->share_lod(*x[0]);
}
......
......@@ -900,6 +900,31 @@ class TestSundryAPI(unittest.TestCase):
np.testing.assert_array_equal(out3_1.numpy(), out3_2.numpy())
np.testing.assert_array_equal(out3_2.numpy(), np.asarray(1))
def test_add_n(self):
x1 = paddle.rand([])
x1.stop_gradient = False
x2 = paddle.rand([])
x2.stop_gradient = False
x3 = paddle.rand([])
x3.stop_gradient = False
out1 = paddle.add_n(x1)
out2 = paddle.add_n([x2, x3])
out1.backward()
out2.backward()
self.assertEqual(x1.grad.shape, [])
self.assertTrue(x1.grad.numpy() == 1)
self.assertEqual(x2.grad.shape, [])
self.assertTrue(x2.grad.numpy() == 1)
self.assertEqual(x3.grad.shape, [])
self.assertTrue(x3.grad.numpy() == 1)
self.assertEqual(out1.shape, [])
self.assertEqual(out1.grad.shape, [])
self.assertEqual(out2.shape, [])
self.assertEqual(out2.grad.shape, [])
def test_reshape_list(self):
x = paddle.rand([])
x.stop_gradient = False
......@@ -1534,6 +1559,46 @@ class TestSundryAPIStatic(unittest.TestCase):
np.testing.assert_array_equal(out3_1, out3_2)
np.testing.assert_array_equal(out3_2, np.asarray(1))
@prog_scope()
def test_add_n(self):
x1 = paddle.rand([])
x1.stop_gradient = False
x2 = paddle.rand([])
x2.stop_gradient = False
x3 = paddle.rand([])
x3.stop_gradient = False
out1 = paddle.add_n(x1)
out2 = paddle.add_n([x2, x3])
paddle.static.append_backward(out1.sum())
paddle.static.append_backward(out2.sum())
prog = paddle.static.default_main_program()
block = prog.global_block()
res = self.exe.run(
prog,
fetch_list=[
out1,
out2,
x1.grad_name,
x2.grad_name,
x3.grad_name,
out1.grad_name,
out2.grad_name,
],
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ())
self.assertEqual(res[2], 1)
self.assertEqual(res[3].shape, ())
self.assertEqual(res[3], 1)
self.assertEqual(res[4].shape, ())
self.assertEqual(res[4], 1)
self.assertEqual(res[5].shape, ())
self.assertEqual(res[6].shape, ())
@prog_scope()
def test_reshape_list(self):
x1 = paddle.rand([])
......
......@@ -592,6 +592,28 @@ class TestSundryAPI(unittest.TestCase):
np.testing.assert_array_equal(out3_1.numpy(), out3_2.numpy())
np.testing.assert_array_equal(out3_2.numpy(), np.asarray(1))
def test_add_n(self):
x1 = paddle.rand([])
x1.stop_gradient = False
x2 = paddle.rand([])
x2.stop_gradient = False
x3 = paddle.rand([])
x3.stop_gradient = False
out1 = paddle.add_n(x1)
out2 = paddle.add_n([x2, x3])
out1.retain_grads()
out2.retain_grads()
out1.backward()
out2.backward()
self.assertEqual(out1.shape, [])
self.assertEqual(out1.grad.shape, [])
self.assertEqual(out2.shape, [])
self.assertEqual(out2.grad.shape, [])
def test_reshape_list(self):
x = paddle.rand([])
x.stop_gradient = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册