Unverified Commit 1a3d2592 authored by 傅剑寒, committed by GitHub

[Zero-Dim] Support 0D for paddle.diagflat (#48735)

* [Zero-Dim] Support 0D for paddle.diagflat
Parent 65420271
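With this change, paddle.diagflat accepts 0-D (scalar) tensors and treats them as length-1 vectors. A minimal usage sketch based on the tests added in this commit:

import paddle

x = paddle.rand([])                  # 0-D tensor
print(paddle.diagflat(x).shape)      # [1, 1]
print(paddle.diagflat(x, 1).shape)   # [2, 2]: the offset pushes the diagonal out
print(paddle.diagflat(x, -1).shape)  # [2, 2]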
@@ -563,8 +563,8 @@ void DiagInferMeta(const MetaTensor& x,
                    MetaTensor* out) {
   auto x_dims = x.dims();
 
-  if (x_dims.size() == 1UL) {
-    int64_t size_ = x_dims[0] + std::abs(offset);
+  if (x_dims.size() <= 1) {
+    int64_t size_ = (x_dims.size() == 1UL ? x_dims[0] : 1) + std::abs(offset);
     out->set_dims({size_, size_});
     out->set_dtype(x.dtype());
   } else if (x_dims.size() == 2UL) {
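The rule above generalizes the 1-D case: for rank <= 1 the output is a square matrix whose side is length + |offset|, with a 0-D input contributing length 1. A hypothetical Python mirror of the rule, for illustration only (diag_out_side is not part of the commit):

def diag_out_side(x_shape, offset):
    # 0-D input behaves as a length-1 vector
    length = x_shape[0] if len(x_shape) == 1 else 1
    return length + abs(offset)

assert diag_out_side([], 1) == 2   # 0-D, offset 1 -> 2 x 2 output
assert diag_out_side([], 0) == 1   # 0-D, offset 0 -> 1 x 1 output
assert diag_out_side([3], 2) == 5  # 1-D, length 3 -> 5 x 5 output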
@@ -32,9 +32,9 @@ void DiagGradKernel(const Context& dev_ctx,
   auto dx_dims = x_grad->dims();
   auto dout_dims = out_grad.dims();
 
-  if (dx_dims.size() == 1) {
-    auto dx_length = dx_dims[0];
-    int dx_stride = phi::funcs::ComputeStride(0, dx_dims);
+  if (dx_dims.size() <= 1) {
+    auto dx_length = (dx_dims.size() == 1 ? dx_dims[0] : int64_t(1));
+    int dx_stride = 1;
 
     auto dout_stride_0 = phi::funcs::ComputeStride(0, dout_dims);
     auto dout_stride_1 = phi::funcs::ComputeStride(1, dout_dims);
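For rank <= 1 the backward pass just gathers the offset diagonal of out_grad back into the input gradient; a 0-D input holds a single element, so the input stride degenerates to 1, which is why ComputeStride is replaced by the constant above. A NumPy sketch of that gather (diag_backward is a hypothetical illustration, not the kernel):

import numpy as np

def diag_backward(dout, x_ndim, offset=0):
    # gather the offset diagonal of dout into the input gradient
    dx = np.diagonal(dout, offset=offset).copy()
    return dx.reshape(()) if x_ndim == 0 else dx  # 0-D input keeps its shape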
@@ -33,12 +33,12 @@ void DiagKernel(const Context& dev_ctx,
   auto out_dims = out->dims();
 
   int64_t i;
-  if (x_dims.size() == 1) {
+  if (x_dims.size() <= 1) {
     phi::funcs::SetConstant<Context, T> set_padding_value;
     set_padding_value(dev_ctx, out, static_cast<T>(padding_value));
 
-    auto x_length = x_dims[0];
-    const int& x_stride = phi::funcs::ComputeStride(0, x_dims);
+    auto x_length = (x_dims.size() == 1UL ? x_dims[0] : int64_t(1));
+    const int& x_stride = 1;
 
     auto out_stride_0 = phi::funcs::ComputeStride(0, out_dims);
     auto out_stride_1 = phi::funcs::ComputeStride(1, out_dims);
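The forward path for rank <= 1 fills the output with padding_value and then writes the input along the chosen diagonal; a 0-D input writes exactly one element, so a fixed stride of 1 is safe. A NumPy sketch under those assumptions (diag_forward is illustrative only):

import numpy as np

def diag_forward(x, offset=0, padding_value=0.0):
    x = np.atleast_1d(np.asarray(x))   # 0-D is treated as length 1
    n = x.shape[0] + abs(offset)
    out = np.full((n, n), padding_value, dtype=x.dtype)
    row, col = max(-offset, 0), max(offset, 0)
    for i in range(x.shape[0]):        # x_stride is 1 in this path
        out[row + i, col + i] = x[i]
    return out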
@@ -73,10 +73,10 @@ void DiagGradKernel(const Context& dev_ctx,
     return std::tuple<int64_t, int64_t>{block_size, grid_size};
   };
 
-  if (dx_dims.size() == 1) {
-    auto dx_length = dx_dims[0];
+  if (dx_dims.size() <= 1) {
+    auto dx_length = (dx_dims.size() == 1 ? dx_dims[0] : int64_t(1));
     auto size = (offset > 0) ? dx_length + offset : dx_length - offset;
-    int dx_stride = phi::funcs::ComputeStride(0, dx_dims);
+    int dx_stride = 1;
     if (size > 0) {
       auto dout_stride_0 = phi::funcs::ComputeStride(0, dout_dims);
       auto dout_stride_1 = phi::funcs::ComputeStride(1, dout_dims);
@@ -77,13 +77,13 @@ void DiagKernel(const Context& dev_ctx,
     return std::tuple<int64_t, int64_t>{block_size, grid_size};
   };
 
-  if (x_dims.size() == 1) {
+  if (x_dims.size() <= 1) {
     phi::funcs::SetConstant<Context, T> set_padding_value;
     set_padding_value(dev_ctx, out, static_cast<T>(padding_value));
 
-    auto x_length = x_dims[0];
+    auto x_length = (x_dims.size() == 1UL ? x_dims[0] : int64_t(1));
     auto size = (offset > 0) ? x_length + offset : x_length - offset;
-    const int& x_stride = phi::funcs::ComputeStride(0, x_dims);
+    const int& x_stride = 1;
     if (size > 0) {
       const auto& out_stride_0 = phi::funcs::ComputeStride(0, out_dims);
       const auto& out_stride_1 = phi::funcs::ComputeStride(1, out_dims);
@@ -653,6 +653,35 @@ class TestSundryAPI(unittest.TestCase):
             self.assertEqual(out.numpy()[1][i], updates.numpy()[i])
         self.assertEqual(out.grad.shape, [2, 3])
 
+    def test_diagflat(self):
+        x1 = paddle.rand([])
+        x2 = paddle.rand([])
+        x3 = paddle.rand([])
+
+        x1.stop_gradient = False
+        x2.stop_gradient = False
+        x3.stop_gradient = False
+
+        out1 = paddle.diagflat(x1, 1)
+        out2 = paddle.diagflat(x2, -1)
+        out3 = paddle.diagflat(x3, 0)
+
+        out1.backward()
+        out2.backward()
+        out3.backward()
+
+        self.assertEqual(out1.shape, [2, 2])
+        self.assertEqual(out2.shape, [2, 2])
+        self.assertEqual(out3.shape, [1, 1])
+
+        self.assertEqual(out1.grad.shape, [2, 2])
+        self.assertEqual(out2.grad.shape, [2, 2])
+        self.assertEqual(out3.grad.shape, [1, 1])
+
+        self.assertEqual(x1.grad.shape, [])
+        self.assertEqual(x2.grad.shape, [])
+        self.assertEqual(x3.grad.shape, [])
+
 
 class TestSundryAPIStatic(unittest.TestCase):
     def setUp(self):
@@ -796,6 +825,26 @@ class TestSundryAPIStatic(unittest.TestCase):
         for i in range(3):
             self.assertEqual(res[0][1][i], 4)
 
+    @prog_scope()
+    def test_diagflat(self):
+        x1 = paddle.rand([])
+        out1 = paddle.diagflat(x1, 1)
+        paddle.static.append_backward(out1)
+
+        x2 = paddle.rand([])
+        out2 = paddle.diagflat(x2, -1)
+        paddle.static.append_backward(out2)
+
+        x3 = paddle.rand([])
+        out3 = paddle.diagflat(x3)
+        paddle.static.append_backward(out3)
+
+        prog = paddle.static.default_main_program()
+        res1, res2, res3 = self.exe.run(prog, fetch_list=[out1, out2, out3])
+
+        self.assertEqual(res1.shape, (2, 2))
+        self.assertEqual(res2.shape, (2, 2))
+        self.assertEqual(res3.shape, (1, 1))
+
 
 # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
 class TestNoBackwardAPI(unittest.TestCase):
@@ -475,6 +475,35 @@ class TestSundryAPI(unittest.TestCase):
         for i in range(3):
             self.assertEqual(out.numpy()[1][i], updates.numpy()[i])
 
+    def test_diagflat(self):
+        x1 = paddle.rand([])
+        x2 = paddle.rand([])
+        x3 = paddle.rand([])
+
+        x1.stop_gradient = False
+        x2.stop_gradient = False
+        x3.stop_gradient = False
+
+        out1 = paddle.diagflat(x1, 1)
+        out2 = paddle.diagflat(x2, -1)
+        out3 = paddle.diagflat(x3, 0)
+
+        out1.backward()
+        out2.backward()
+        out3.backward()
+
+        self.assertEqual(out1.shape, [2, 2])
+        self.assertEqual(out2.shape, [2, 2])
+        self.assertEqual(out3.shape, [1, 1])
+
+        self.assertEqual(out1.grad.shape, [2, 2])
+        self.assertEqual(out2.grad.shape, [2, 2])
+        self.assertEqual(out3.grad.shape, [1, 1])
+
+        self.assertEqual(x1.grad.shape, [])
+        self.assertEqual(x2.grad.shape, [])
+        self.assertEqual(x3.grad.shape, [])
+
 
 # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
 class TestNoBackwardAPI(unittest.TestCase):
@@ -1479,7 +1479,7 @@ def diagflat(x, offset=0, name=None):
     """
     padding_value = 0
     if in_dygraph_mode():
-        if len(x.shape) == 1:
+        if len(x.shape) <= 1:
             return _C_ops.diag(x, offset, padding_value)
         else:
             y = _C_ops.flatten(x, 0, -1)
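With rank <= 1 inputs now routed straight to the diag kernel, paddle.diagflat lines up with NumPy, whose diagflat already flattens scalars to length-1 vectors. A quick comparison sketch (illustrative, not part of the commit):

import numpy as np
import paddle

x = paddle.full([], 3.0)  # 0-D tensor
np.testing.assert_allclose(
    paddle.diagflat(x, 1).numpy(),
    np.diagflat(3.0, 1),  # [[0., 3.], [0., 0.]]
)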
@@ -1509,7 +1509,7 @@ def diagflat(x, offset=0, name=None):
     out1_shape = helper.create_variable_for_type_inference(x.dtype)
     out2 = helper.create_variable_for_type_inference(dtype=x.dtype)
 
-    if len(x.shape) == 1:
+    if len(x.shape) <= 1:
         helper.append_op(
             type='diag_v2',
             inputs={'X': x},