Unverified commit 47fa8066, authored by GGBond8488, committed by GitHub

【0D output】support 0D output for matrix_rank/multi_dot (#52861)

* support_0D_output_for_matrix_rank_multi_dot, test=allcase

* add 0D output test for matrix_rank and multi_dot, test=allcase

* fix assert error, test=allcase

* fix test error, test=allcase

* fix other test error, test=allcase

* fix other test error, test=allcase

* fix test error, test=allcase

* fix matrix_rank and multi_dot test error, test=allcase

* fix test error, test=allcase

* fix zero-dim test, test=allcase

* add static backward test for multi_dot, test=allcase

* add tol 2d broadcast test case, test=allcase
Parent 07878a34
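In short: after this patch, matrix_rank on a single 2-D matrix, and multi_dot on a chain that contracts to a scalar, both return 0-D tensors instead of shape-[1] tensors. A minimal dynamic-graph sketch, assuming a Paddle build that includes this commit (shapes taken from the tests below):

import paddle

# matrix_rank of one 2-D matrix: 0-D output after this patch (was shape [1]).
rank = paddle.linalg.matrix_rank(paddle.eye(10))
print(rank.shape)  # []

# multi_dot of vector @ matrix @ vector contracts to a scalar: 0-D output.
out = paddle.linalg.multi_dot(
    [paddle.randn([4]), paddle.randn([4, 5]), paddle.randn([5])]
)
print(out.shape)  # []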
@@ -72,7 +72,7 @@ static void BinarySameInputDimsCheck(const MetaTensor& x,
static DDim CheckAndGetOutputDim(const DDim& dim_x) {
  auto x_vec = phi::vectorize(dim_x);
  if (x_vec.size() == 2) {
-    return phi::make_ddim({1});
+    return phi::make_ddim({});
  }
  x_vec.erase(x_vec.end() - 2, x_vec.end());
  return phi::make_ddim(x_vec);
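The helper above encodes the new shape rule for matrix_rank: a single 2-D matrix produces a 0-D rank, while a batched input keeps only its leading batch dims. A Python paraphrase of that C++ logic (a sketch for illustration, not code from the patch):

def check_and_get_output_dim(dim_x):
    """Python paraphrase of the C++ CheckAndGetOutputDim above."""
    x_vec = list(dim_x)
    if len(x_vec) == 2:
        return []        # single matrix -> 0-D rank (was [1] before this patch)
    del x_vec[-2:]       # batched input -> drop the trailing matrix dims
    return x_vec

assert check_and_get_output_dim([10, 10]) == []
assert check_and_get_output_dim([3, 4, 5]) == [3]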
@@ -2345,7 +2345,7 @@ void MultiDotInferMeta(const std::vector<const MetaTensor*>& x,
  // If the last tensor is 1D of size n view it as a column vector (n, 1)
  if (last_dim.size() == 1) {
    last_dim = phi::make_ddim({static_cast<int>(last_dim[0]), 1});
-    out_dim = is_vector ? phi::make_ddim({1}) : phi::make_ddim({first_dim[0]});
+    out_dim = is_vector ? phi::make_ddim({}) : phi::make_ddim({first_dim[0]});
  } else {
    out_dim = is_vector ? phi::make_ddim({last_dim[1]})
                        : phi::make_ddim({first_dim[0], last_dim[1]});
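The analogous output-shape rule for multi_dot, sketched in Python. Here is_vector is assumed to be the flag MultiDotInferMeta sets when the first tensor was 1-D and got viewed as a row vector (1, n); that part of the function sits outside this hunk, so treat it as an assumption:

def multi_dot_out_dim(first_dim, last_dim, is_vector):
    # is_vector: assumed true when the first tensor was 1-D (viewed as (1, n)).
    if len(last_dim) == 1:
        # Last tensor viewed as a column vector (n, 1).
        return [] if is_vector else [first_dim[0]]
    return [last_dim[1]] if is_vector else [first_dim[0], last_dim[1]]

assert multi_dot_out_dim([1, 4], [5], is_vector=True) == []      # vector chain -> 0-D
assert multi_dot_out_dim([3, 4], [5], is_vector=False) == [3]
assert multi_dot_out_dim([1, 4], [5, 6], is_vector=True) == [6]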
@@ -38,7 +38,7 @@ namespace detail {
static DDim CheckAndGetOutputDim(const DDim& dim_x) {
  auto x_vec = phi::vectorize(dim_x);
  if (x_vec.size() == 2) {
-    return phi::make_ddim({1});
+    return phi::make_ddim({});
  }
  x_vec.erase(x_vec.end() - 2, x_vec.end());
  return phi::make_ddim(x_vec);
@@ -2123,6 +2123,23 @@ class TestSundryAPI(unittest.TestCase):
        self.assertEqual(out1.shape, [2, 3])
        self.assertEqual(x1.grad.shape, [3, 3, 3])

    def test_multi_dot(self):
        a = paddle.randn([4])
        a.stop_gradient = False
        b = paddle.randn([4, 5])
        b.stop_gradient = False
        c = paddle.randn([5])
        c.stop_gradient = False

        out = paddle.linalg.multi_dot([a, b, c])
        out.retain_grads()
        out.backward()

        self.assertEqual(out.shape, [])
        self.assertEqual(a.grad.shape, [4])
        self.assertEqual(b.grad.shape, [4, 5])
        self.assertEqual(c.grad.shape, [5])

class TestSundryAPIStatic(unittest.TestCase):
    def setUp(self):
@@ -3710,6 +3727,26 @@ class TestSundryAPIStatic(unittest.TestCase):
        self.assertEqual(res[0].shape, (2, 3))
        self.assertEqual(res[1].shape, (3, 3, 3))

    @prog_scope()
    def test_multi_dot(self):
        a = paddle.randn([4])
        a.stop_gradient = False
        b = paddle.randn([4, 5])
        b.stop_gradient = False
        c = paddle.randn([5])
        c.stop_gradient = False

        out = paddle.linalg.multi_dot([a, b, c])
        paddle.static.append_backward(out.sum())

        prog = paddle.static.default_main_program()
        res = self.exe.run(
            prog, fetch_list=[out, a.grad_name, b.grad_name, c.grad_name]
        )
        self.assertEqual(res[0].shape, ())
        self.assertEqual(res[1].shape, (4,))
        self.assertEqual(res[2].shape, (4, 5))
        self.assertEqual(res[3].shape, (5,))

# Used to test APIs whose zero-dim input tensors don't have grad and don't need backward testing in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
@@ -3901,6 +3938,38 @@ class TestNoBackwardAPI(unittest.TestCase):
        self.assertEqual(inverse.shape, [1])
        self.assertEqual(counts.shape, [1])

    def test_matrix_rank(self):
        # 2D : OUTPUT 0D
        x = paddle.eye(10)
        x.stop_gradient = False
        out = paddle.linalg.matrix_rank(x)
        self.assertEqual(out.shape, [])
        np.testing.assert_equal(out, np.array(10))

        # 3D : OUTPUT 1D
        c = paddle.ones(shape=[3, 4, 5])
        c.stop_gradient = False
        out_c = paddle.linalg.matrix_rank(c)
        self.assertEqual(out_c.shape, [3])
        np.testing.assert_equal(out_c, np.array([1, 1, 1]))

        # 2D, tol->float : OUTPUT 0D
        x_tol = paddle.eye(10)
        x_tol.stop_gradient = False
        out_tol = paddle.linalg.matrix_rank(x_tol, tol=0.1)
        self.assertEqual(out_tol.shape, [])

        # 3D, tol->float : OUTPUT 1D
        c_tol = paddle.ones(shape=[3, 4, 5])
        c_tol.stop_gradient = False
        out_c_tol = paddle.linalg.matrix_rank(c_tol, tol=0.1)
        self.assertEqual(out_c_tol.shape, [3])

        # 2D, tol is a Tensor of shape [2] : OUTPUT 1D
        tol_2 = paddle.randn([2])
        d = paddle.eye(10)
        out_d = paddle.linalg.matrix_rank(d, tol=tol_2)
        self.assertEqual(out_d.shape, [2])

class TestNoBackwardAPIStatic(unittest.TestCase):
    def setUp(self):
@@ -4135,6 +4204,51 @@ class TestNoBackwardAPIStatic(unittest.TestCase):
        self.assertEqual(res[2].shape, (1,))
        self.assertEqual(res[3].shape, (1,))

    @prog_scope()
    def test_static_matrix_rank(self):
        # 2D : OUTPUT 0D
        x = paddle.eye(10)
        x.stop_gradient = False
        out = paddle.linalg.matrix_rank(x)
        prog = paddle.static.default_main_program()
        res = self.exe.run(prog, fetch_list=[out])
        self.assertEqual(res[0].shape, ())

        # 3D : OUTPUT 1D
        c = paddle.ones(shape=[3, 4, 5])
        c.stop_gradient = False
        out_c = paddle.linalg.matrix_rank(c)
        prog = paddle.static.default_main_program()
        self.exe.run(paddle.static.default_startup_program())
        res = self.exe.run(prog, fetch_list=[out_c])
        self.assertEqual(res[0].shape, (3,))

        # 2D, tol->float : OUTPUT 0D
        x_tol = paddle.eye(10)
        x_tol.stop_gradient = False
        out_tol = paddle.linalg.matrix_rank(x_tol, tol=0.1)
        prog = paddle.static.default_main_program()
        res = self.exe.run(prog, fetch_list=[out_tol])
        self.assertEqual(res[0].shape, ())

        # 3D, tol->float : OUTPUT 1D
        c_tol = paddle.ones(shape=[3, 4, 5])
        c_tol.stop_gradient = False
        out_c_tol = paddle.linalg.matrix_rank(c_tol, tol=0.1)
        prog = paddle.static.default_main_program()
        self.exe.run(paddle.static.default_startup_program())
        res = self.exe.run(prog, fetch_list=[out_c_tol])
        self.assertEqual(res[0].shape, (3,))

        # 2D, tol is a Tensor of shape [2] : OUTPUT 1D
        tol_2 = paddle.randn([2])
        d = paddle.eye(10)
        out_d = paddle.linalg.matrix_rank(d, tol=tol_2)
        prog = paddle.static.default_main_program()
        self.exe.run(paddle.static.default_startup_program())
        res = self.exe.run(prog, fetch_list=[out_d])
        self.assertEqual(res[0].shape, (2,))

unary_apis_with_complex_input = [
    paddle.real,
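As a closing sanity check on the convention: the new 0-D behavior matches NumPy, where matrix_rank of a single matrix is a plain 0-D scalar. A quick sketch, independent of Paddle:

import numpy as np

# Rank of the 10x10 identity is 10, and NumPy returns a 0-D scalar for it,
# which is the convention the new Paddle output follows.
assert np.linalg.matrix_rank(np.eye(10)) == 10
assert np.ndim(np.linalg.matrix_rank(np.eye(10))) == 0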