未验证 提交 6adfcdf6 编写于 作者: Z zqw_1997 提交者: GitHub

[Zero-Dim] Support output 0D for squeeze, unbind, unstack. (#52843)

* test=allcase

* test=allcase

* test=allcase

* test=allcase

* test=allcase

* fix test cases, test=allcase

* fix test cases, test=allcase

* modify the test_squeeze to not use Tensor type axis, test=allcase

* add grad check for unbind and unstack, test=allcase

* check for squeeze axis tensor type, test=allcase

* fix bug, test=allcase
上级 96180fff
...@@ -3761,6 +3761,9 @@ void SqueezeInferMeta(const MetaTensor& x, ...@@ -3761,6 +3761,9 @@ void SqueezeInferMeta(const MetaTensor& x,
if (!config.is_runtime && axes.FromTensor()) { if (!config.is_runtime && axes.FromTensor()) {
// compile time infershape, set all elements to -1. // compile time infershape, set all elements to -1.
int output_size = x.dims().size() - axes.GetData().size(); int output_size = x.dims().size() - axes.GetData().size();
if (x.dims().size() == 0 && output_size == -1) {
output_size = 0;
}
std::vector<int64_t> vec_out_dims(output_size, -1); std::vector<int64_t> vec_out_dims(output_size, -1);
out->set_dims(phi::make_ddim(vec_out_dims)); out->set_dims(phi::make_ddim(vec_out_dims));
} else { } else {
......
...@@ -2401,6 +2401,56 @@ class TestSundryAPI(unittest.TestCase): ...@@ -2401,6 +2401,56 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out1.shape, [2, 3, 12, 12]) self.assertEqual(out1.shape, [2, 3, 12, 12])
self.assertEqual(input_x.grad.shape, [2, 3, 6, 6]) self.assertEqual(input_x.grad.shape, [2, 3, 6, 6])
def test_unstack(self):
x1 = paddle.full([1], 0)
x2 = paddle.full([2], 2)
x1.retain_grads()
x2.retain_grads()
x1.stop_gradient = False
x2.stop_gradient = False
[out1] = paddle.unstack(x1, 0)
out1.retain_grads()
out1.backward()
[out2_1, out2_2] = paddle.unstack(x2, 0)
out2 = paddle.add_n([out2_1, out2_2])
out2.retain_grads()
out2.backward()
self.assertEqual(out1.shape, [])
self.assertEqual(out1.numpy(), 0)
self.assertEqual(out2_1.shape, [])
self.assertEqual(out2_1.numpy(), 2)
self.assertEqual(out2_2.shape, [])
self.assertEqual(out2_2.numpy(), 2)
self.assertEqual(x2.grad.shape, [2])
def test_unbind(self):
x1 = paddle.full([1], 0)
x2 = paddle.full([2], 2)
x1.retain_grads()
x2.retain_grads()
x1.stop_gradient = False
x2.stop_gradient = False
[out1] = paddle.unbind(x1, 0)
out1.retain_grads()
out1.backward()
[out2_1, out2_2] = paddle.unbind(x2, 0)
out2 = paddle.add_n([out2_1, out2_2])
out2.retain_grads()
out2.backward()
self.assertEqual(out1.shape, [])
self.assertEqual(out1.numpy(), 0)
self.assertEqual(out2_1.shape, [])
self.assertEqual(out2_1.numpy(), 2)
self.assertEqual(out2_2.shape, [])
self.assertEqual(out2_2.numpy(), 2)
self.assertEqual(x2.grad.shape, [2])
def test_maseked_select(self): def test_maseked_select(self):
x = paddle.rand([]) x = paddle.rand([])
x.stop_gradient = False x.stop_gradient = False
...@@ -2415,6 +2465,26 @@ class TestSundryAPI(unittest.TestCase): ...@@ -2415,6 +2465,26 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(x.grad.shape, []) self.assertEqual(x.grad.shape, [])
self.assertEqual(x.grad.numpy(), 1) self.assertEqual(x.grad.numpy(), 1)
def test_squeeze(self):
x1 = paddle.full([], 2)
x1.stop_gradient = False
x1.retain_grads()
out1 = paddle.squeeze(x1, axis=0)
out1.retain_grads()
out1.backward()
self.assertEqual(out1.shape, [])
self.assertEqual(x1.grad.shape, [])
x2 = paddle.full([], 3)
x3 = paddle.full([1], 0, dtype='int32')
x2.stop_gradient = False
x2.retain_grads()
out2 = paddle.squeeze(x2, axis=x3)
out2.retain_grads()
out2.backward()
self.assertEqual(out2.shape, [])
self.assertEqual(x2.grad.shape, [])
def test_unsqueeze(self): def test_unsqueeze(self):
x1 = paddle.full([], 2) x1 = paddle.full([], 2)
x1.stop_gradient = False x1.stop_gradient = False
...@@ -4242,6 +4312,50 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -4242,6 +4312,50 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res1[0].shape, (2, 3, 12, 12)) self.assertEqual(res1[0].shape, (2, 3, 12, 12))
self.assertEqual(res1[1].shape, (2, 3, 6, 6)) self.assertEqual(res1[1].shape, (2, 3, 6, 6))
@prog_scope()
def test_unstack(self):
x1 = paddle.full([1], 0, 'float32')
x1.stop_gradient = False
out1 = paddle.unstack(x1, 0)
out1 = paddle.add_n(out1)
paddle.static.append_backward(out1)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, feed={}, fetch_list=[out1, x1.grad_name])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (1,))
x2 = paddle.full([2], 2, 'float32')
x2.stop_gradient = False
out2 = paddle.unstack(x2, 0)
out2_sum = paddle.add_n(out2)
paddle.static.append_backward(out2_sum)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, feed={}, fetch_list=[out2_sum, x2.grad_name])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (2,))
@prog_scope()
def test_unbind(self):
x1 = paddle.full([1], 0, 'float32')
x1.stop_gradient = False
out1 = paddle.unbind(x1, 0)
out1 = paddle.add_n(out1)
paddle.static.append_backward(out1)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, feed={}, fetch_list=[out1, x1.grad_name])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (1,))
x2 = paddle.full([2], 2, 'float32')
x2.stop_gradient = False
out2 = paddle.unbind(x2, 0)
out2_sum = paddle.add_n(out2)
paddle.static.append_backward(out2_sum)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, feed={}, fetch_list=[out2_sum, x2.grad_name])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (2,))
@prog_scope() @prog_scope()
def test_maseked_select(self): def test_maseked_select(self):
x = paddle.rand([]) x = paddle.rand([])
...@@ -4258,6 +4372,34 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -4258,6 +4372,34 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[3].shape, ()) self.assertEqual(res[3].shape, ())
self.assertEqual(res[3], 1) self.assertEqual(res[3], 1)
@prog_scope()
def test_squeeze(self):
x1 = paddle.full([], 2)
x1.stop_gradient = False
out1 = paddle.squeeze(x1, axis=0)
paddle.static.append_backward(out1.sum())
x2 = paddle.full([], 3)
x3 = paddle.full([], 0, dtype='int32')
x2.stop_gradient = False
out2 = paddle.squeeze(x2, axis=x3)
paddle.static.append_backward(out2.sum())
prog = paddle.static.default_main_program()
res = self.exe.run(
prog,
fetch_list=[
out1,
out2,
x1.grad_name,
x2.grad_name,
],
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ())
self.assertEqual(res[3].shape, ())
@prog_scope() @prog_scope()
def test_unsqueeze(self): def test_unsqueeze(self):
x1 = paddle.full([], 2) x1 = paddle.full([], 2)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册