未验证 提交 2d02b0c1 编写于 作者: G GGBond8488 提交者: GitHub

[Cherry-pick]Cherry pick 0d output (#53538)

* add 0D output support for inalg.slogdet,test=allcase

* fix zerom dime test error test=allcase

* fix test error test=allcase

* add static backward test, test=allcase

* support_0D_output_for_matrix_rank_multi_dot, test=allcase

* add 0D output test for matrox_rank and mutli_dot test=allcase

* fix assert error ,test=allcase

* fix test error, test=allcase

* fix other test error, test=allcase

* fix other test error, test=allcase

* fix test error, test=allcase

* fix matrix_rank and multi dot test err test=allcase

* fix test error test=allcase

* fix test zero dim test, test=allcase

* add static backward test for multi_dot, test=allcase

* add tol 2d broadcast test case, test=allcase

* fix test error test=allcase

* fix test error test=allcase

* test=allcase

* support_0d_output_for_linalg.norm

* fix test error test=allcase

* fix 0D test

* fix test error test=allcase

* fix test error test=allcase

* fix tets,test=allcase

* fix error,test=allcase

* fix errors ,test=allcase

* add static backward , test=allcase

* add static backwward test, test=allcase

* slogdet_support_0D_output

* add new case

* fix tests, test=allcase

* cherry-pick

* cherry-pick

* fix trace gpu kernel 0d error, test=allcase

* fix windows error, test=allcase

* add matrixrank cherry-pick
上级 1d23e0bb
......@@ -27,7 +27,7 @@ namespace detail {
static DDim CheckAndGetOutputDim(const DDim& dim_x) {
auto x_vec = phi::vectorize(dim_x);
if (x_vec.size() == 2) {
return phi::make_ddim({1});
return phi::make_ddim({});
}
x_vec.erase(x_vec.end() - 2, x_vec.end());
return phi::make_ddim(x_vec);
......
......@@ -72,7 +72,7 @@ static void BinarySameInputDimsCheck(const MetaTensor& x,
static DDim CheckAndGetOutputDim(const DDim& dim_x) {
auto x_vec = phi::vectorize(dim_x);
if (x_vec.size() == 2) {
return phi::make_ddim({1});
return phi::make_ddim({});
}
x_vec.erase(x_vec.end() - 2, x_vec.end());
return phi::make_ddim(x_vec);
......@@ -990,7 +990,7 @@ void DistInferMeta(const MetaTensor& x,
"The Input(Y) has not been initialized properly. The "
"shape of Input(Y) = [%s].",
y_dims));
out->set_dims({1});
out->set_dims(phi::make_ddim({}));
out->set_dtype(x.dtype());
}
......
......@@ -2344,7 +2344,7 @@ void MultiDotInferMeta(const std::vector<const MetaTensor*>& x,
// If the last tensor is 1D of size n view it as a column vector (n, 1)
if (last_dim.size() == 1) {
last_dim = phi::make_ddim({static_cast<int>(last_dim[0]), 1});
out_dim = is_vector ? phi::make_ddim({1}) : phi::make_ddim({first_dim[0]});
out_dim = is_vector ? phi::make_ddim({}) : phi::make_ddim({first_dim[0]});
} else {
out_dim = is_vector ? phi::make_ddim({last_dim[1]})
: phi::make_ddim({first_dim[0], last_dim[1]});
......
......@@ -38,7 +38,7 @@ namespace detail {
static DDim CheckAndGetOutputDim(const DDim& dim_x) {
auto x_vec = phi::vectorize(dim_x);
if (x_vec.size() == 2) {
return phi::make_ddim({1});
return phi::make_ddim({});
}
x_vec.erase(x_vec.end() - 2, x_vec.end());
return phi::make_ddim(x_vec);
......@@ -4405,7 +4405,6 @@ void TraceInferMeta(
auto sizes = vectorize(x_dims);
if (x_dims.size() == 2) {
sizes.clear();
sizes.push_back(1);
} else {
sizes.erase(sizes.begin() + std::max(dim1_, dim2_));
sizes.erase(sizes.begin() + std::min(dim1_, dim2_));
......
......@@ -32,7 +32,10 @@ void TraceKernel(const Context& ctx,
auto diag = funcs::Diagonal<T, Context>(ctx, &x, offset, axis1, axis2);
if (diag.numel() > 0) {
std::vector<int> reduce_dims;
reduce_dims.push_back(out->dims().size());
// Adapt to 0D output
auto out_dim_size = out->dims().size();
if (out_dim_size == 0) out_dim_size = 1;
reduce_dims.push_back(out_dim_size);
funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
ctx, diag, out, kps::IdentityFunctor<T>(), reduce_dims);
} else {
......
......@@ -90,10 +90,10 @@ void DeterminantGradKernel(const Context& dev_ctx,
" input tensor's, but here differ %d",
input_dims_size - out_grad.dims().size()));
} else if (input_dims_size == 2) {
// input dims size 2 and grad dims size 1 is possible
// input dims size 2 and grad dims size 0 is possible
PADDLE_ENFORCE_EQ(
out_grad.dims().size(),
1,
0,
phi::errors::InvalidArgument(
"The grad tensor of det dims size should be 2 less than"
" input tensor's, but here differ %d",
......
......@@ -116,7 +116,7 @@ void DeterminantKernel(const Context& dev_ctx,
out->Resize(output_dims);
} else {
// when input is a two-dimension matrix, The det value is a number.
out->Resize({1});
out->Resize(phi::make_ddim({}));
}
VLOG(10) << "output dim:" << out->dims();
}
......
......@@ -91,7 +91,8 @@ void TraceGradKernel(const Context& ctx,
auto input_dims = in_grad->dims();
auto input_stride = phi::stride(input_dims);
auto output_dims = out_grad.dims();
auto output_stride = phi::stride(output_dims);
auto output_stride = output_dims.size() == 0 ? phi::DDim(output_dims)
: phi::stride(output_dims);
auto* out_data = out_grad.data<T>();
T* x_data = ctx.template Alloc<T>(in_grad);
......
......@@ -2311,6 +2311,128 @@ class TestSundryAPI(unittest.TestCase):
self.assertTrue(out1.shape, [2, 3])
self.assertTrue(x1.grad.shape, [3, 3, 3])
def test_multi_dot(self):
a = paddle.randn([4])
a.stop_gradient = False
b = paddle.randn([4, 5])
b.stop_gradient = False
c = paddle.randn([5])
c.stop_gradient = False
out = paddle.linalg.multi_dot([a, b, c])
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
self.assertEqual(a.grad.shape, [4])
self.assertEqual(b.grad.shape, [4, 5])
self.assertEqual(c.grad.shape, [5])
def test_dist(self):
x = paddle.to_tensor([[3, 3], [3, 3]], dtype="float32")
y = paddle.to_tensor([[3, 3], [3, 1]], dtype="float32")
x.stop_gradient = False
y.stop_gradient = False
out = paddle.dist(x, y, 0)
out.backward()
self.assertEqual(out.shape, [])
np.testing.assert_allclose(out, np.array(1))
self.assertEqual(x.grad.shape, [2, 2])
self.assertEqual(y.grad.shape, [2, 2])
def test_trace(self):
x = paddle.to_tensor([[3, 2], [1, 9]], dtype="float32")
x.stop_gradient = False
out = paddle.trace(x)
out.backward()
self.assertEqual(out.shape, [])
np.testing.assert_allclose(out, np.array(12))
self.assertEqual(x.grad.shape, [2, 2])
def test_cond(self):
pass
# def assert_shape(out):
# self.assertEqual(out.shape, [])
# x = paddle.to_tensor([[1.0, 0, -1], [0, 1, 0], [1, 0, 1]])
# x.stop_gradient = False
# p = 2 : use paddle.sum, paddle.max, paddle.min
# out = paddle.linalg.cond(x)
# assert_shape(out)
# p = fro : use paddle.sum
# out_fro = paddle.linalg.cond(x, p='fro')
# assert_shape(out_fro)
# p = nuc : use paddle.sum, paddle.max, paddle.min
# out_nuc = paddle.linalg.cond(x, p='nuc')
# assert_shape(out_nuc)
# p in (-1, 1) : use paddle.sum, paddle.max, paddle.min
# out_1 = paddle.linalg.cond(x, p=1)
# assert_shape(out_1)
# out_minus_1 = paddle.linalg.cond(x, p=-1)
# assert_shape(out_minus_1)
# p in (-2, 2) :use paddle.max, paddle.min
# out_2 = paddle.linalg.cond(x, p=2)
# assert_shape(out_2)
# out_minus_2 = paddle.linalg.cond(x, p=-2)
# assert_shape(out_minus_2)
# p in (-inf, inf):use paddle.sum, paddle.max, paddle.min
# out_inf = paddle.linalg.cond(x, p=float("inf"))
# assert_shape(out_inf)
# out_minus_inf = paddle.linalg.cond(x, p=-float("inf"))
# assert_shape(out_minus_inf)
# out_minus_inf.backward()
# self.assertTrue(x.grad.shape, [3, 3])
# a = paddle.randn([2, 4, 4])
# a.stop_gradient = False
# a_cond_fro = paddle.linalg.cond(a, p='fro')
# a_cond_fro.backward()
# self.assertEqual(len(a_cond_fro.shape), 1)
# self.assertEqual(a.grad.shape, [2, 4, 4])
def test_cov(self):
xt = paddle.randn((3, 4))
xt.stop_gradient = False
xt_1 = paddle.randn((12,))
xt_1.stop_gradient = False
xt_out = paddle.linalg.cov(xt)
xt_out.retain_grads()
xt_out.backward()
self.assertEqual(xt_out.shape, [3, 3])
self.assertEqual(xt.grad.shape, [3, 4])
xt_1_out = paddle.linalg.cov(xt_1)
xt_1.retain_grads()
xt_1_out.backward()
self.assertEqual(xt_1_out.shape, [])
self.assertEqual(xt_1.grad.shape, [12])
def test_det(self):
xt = paddle.randn([3, 3, 3])
xt.stop_gradient = False
xt_1 = paddle.randn([3, 3])
xt_1.stop_gradient = False
xt_out = paddle.linalg.det(xt)
xt.retain_grads()
xt_out.backward()
self.assertEqual(xt_out.shape, [3])
self.assertEqual(xt.grad.shape, [3, 3, 3])
xt_1_out = paddle.linalg.det(xt_1)
xt_1.retain_grads()
xt_1_out.backward()
self.assertEqual(xt_1_out.shape, [])
self.assertEqual(xt_1.grad.shape, [3, 3])
class TestSundryAPIStatic(unittest.TestCase):
def setUp(self):
......@@ -4122,6 +4244,100 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[0].shape, (2, 3))
self.assertEqual(res[1].shape, (3, 3, 3))
@prog_scope()
def test_multi_dot(self):
a = paddle.randn([4])
a.stop_gradient = False
b = paddle.randn([4, 5])
b.stop_gradient = False
c = paddle.randn([5])
c.stop_gradient = False
out = paddle.linalg.multi_dot([a, b, c])
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
res = self.exe.run(
prog, fetch_list=[out, a.grad_name, b.grad_name, c.grad_name]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (4,))
self.assertEqual(res[2].shape, (4, 5))
self.assertEqual(res[3].shape, (5,))
@prog_scope()
def test_dist(self):
x = paddle.to_tensor([[3, 3], [3, 3]], dtype="float32")
y = paddle.to_tensor([[3, 3], [3, 1]], dtype="float32")
x.stop_gradient = False
y.stop_gradient = False
out = paddle.dist(x, y)
paddle.static.append_backward(out)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out, x.grad_name, y.grad_name])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (2, 2))
self.assertEqual(res[1].shape, (2, 2))
np.testing.assert_array_equal(res[0], np.array(2).astype(np.float32))
@prog_scope()
def test_trace(self):
x = paddle.to_tensor([[3, 2], [1, 9]], dtype="float32")
x.stop_gradient = False
out = paddle.trace(x)
paddle.static.append_backward(out)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out, x.grad_name])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (2, 2))
np.testing.assert_allclose(res[0], np.array(12))
@prog_scope()
def test_cond(self):
pass
# use paddle.sum, paddle.max, paddle.min
# x = paddle.to_tensor([[1.0, 0, -1], [0, 1, 0], [1, 0, 1]])
# x.stop_gradient = False
# out = paddle.linalg.cond(x)
# paddle.static.append_backward(out)
# prog = paddle.static.default_main_program()
# res = self.exe.run(prog, fetch_list=[out, x.grad_name])
# self.assertTrue(res[0].shape, ())
# self.assertTrue(res[1].shape, (3, 3))
# np.testing.assert_allclose(out, np.array(1.41421342))
@prog_scope()
def test_cov(self):
xt_1 = paddle.randn((12,))
xt_1.stop_gradient = False
out = paddle.linalg.cov(xt_1)
paddle.static.append_backward(out)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out, xt_1.grad_name])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (12,))
@prog_scope()
def test_det(self):
xt_1 = paddle.randn((3, 3))
xt_1.stop_gradient = False
out = paddle.linalg.det(xt_1)
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out, xt_1.grad_name])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (3, 3))
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
......@@ -4313,6 +4529,38 @@ class TestNoBackwardAPI(unittest.TestCase):
self.assertEqual(inverse.shape, [1])
self.assertEqual(counts.shape, [1])
def test_matrix_rank(self):
x = paddle.eye(10)
x.stop_gradient = False
out = paddle.linalg.matrix_rank(x)
self.assertEqual(out.shape, [])
np.testing.assert_equal(out, np.array(10))
c = paddle.ones(shape=[3, 4, 5])
c.stop_gradient = False
out_c = paddle.linalg.matrix_rank(c)
self.assertEqual(out_c.shape, [3])
np.testing.assert_equal(out_c, np.array([1, 1, 1]))
# 2D, tol->float : OUTPUT 0D
x_tol = paddle.eye(10)
x_tol.stop_gradient = False
out_tol = paddle.linalg.matrix_rank(x_tol, tol=0.1)
self.assertEqual(out_tol.shape, [])
# 3D, tol->float : OUTPUT 1D
c_tol = paddle.ones(shape=[3, 4, 5])
c_tol.stop_gradient = False
out_c_tol = paddle.linalg.matrix_rank(c_tol, tol=0.1)
self.assertEqual(out_c_tol.shape, [3])
tol_2 = paddle.randn([2])
# 2D, tol->Tensor[1,2] : OUTPUT 1D
d = paddle.eye(10)
out_d = paddle.linalg.matrix_rank(d, tol=tol_2)
self.assertEqual(out_d.shape, [2])
class TestNoBackwardAPIStatic(unittest.TestCase):
def setUp(self):
......@@ -4547,6 +4795,51 @@ class TestNoBackwardAPIStatic(unittest.TestCase):
self.assertEqual(res[2].shape, (1,))
self.assertEqual(res[3].shape, (1,))
@prog_scope()
def test_static_matrix_rank(self):
# 2D : OUTPUT 0D
x = paddle.eye(10)
x.stop_gradient = False
out = paddle.linalg.matrix_rank(x)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out])
self.assertEqual(res[0].shape, ())
# 3D : OUTPUT 1D
c = paddle.ones(shape=[3, 4, 5])
c.stop_gradient = False
out_c = paddle.linalg.matrix_rank(c)
prog = paddle.static.default_main_program()
self.exe.run(paddle.static.default_startup_program())
res = self.exe.run(prog, fetch_list=[out_c])
self.assertEqual(res[0].shape, (3,))
# 2D, tol->float : OUTPUT 0D
x_tol = paddle.eye(10)
x_tol.stop_gradient = False
out_tol = paddle.linalg.matrix_rank(x_tol, tol=0.1)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out_tol])
self.assertEqual(res[0].shape, ())
# 3D, tol->float : OUTPUT 1D
c_tol = paddle.ones(shape=[3, 4, 5])
c_tol.stop_gradient = False
out_c_tol = paddle.linalg.matrix_rank(c_tol, tol=0.1)
prog = paddle.static.default_main_program()
self.exe.run(paddle.static.default_startup_program())
res = self.exe.run(prog, fetch_list=[out_c_tol])
self.assertEqual(res[0].shape, (3,))
tol_2 = paddle.randn([2])
# 2D, tol->Tensor[1,2] : OUTPUT 1D
d = paddle.eye(10)
out_d = paddle.linalg.matrix_rank(d, tol=tol_2)
prog = paddle.static.default_main_program()
self.exe.run(paddle.static.default_startup_program())
res = self.exe.run(prog, fetch_list=[out_d])
self.assertEqual(res[0].shape, (2,))
unary_apis_with_complex_input = [
paddle.real,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册