Unverified commit 9d90738c authored by GGBond8488, committed by GitHub

add 0D support for trace (#53208)

* add 0D support for trace, test=allcase

* fix trace gpu kernel 0d error, test=allcase

* fix windows error, test=allcase
Parent a85e038a
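
For reference, a minimal dynamic-graph sketch of the behavior this commit targets, using the same input as the new unit tests; it assumes a Paddle build that already includes this change. paddle.trace of a 2-D tensor now yields a 0-D tensor (shape []) instead of a tensor of shape [1], while the gradient keeps the input's shape.

    import numpy as np
    import paddle

    x = paddle.to_tensor([[3.0, 2.0], [1.0, 9.0]])
    x.stop_gradient = False
    out = paddle.trace(x)
    out.backward()

    print(out.shape)      # [] -- 0-D output after this change (previously [1])
    np.testing.assert_allclose(out.numpy(), 12.0)
    print(x.grad.shape)   # [2, 2] -- gradient keeps the input shape
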
......
@@ -4402,7 +4402,6 @@ void TraceInferMeta(
   auto sizes = vectorize(x_dims);
   if (x_dims.size() == 2) {
     sizes.clear();
-    sizes.push_back(1);
   } else {
     sizes.erase(sizes.begin() + std::max(dim1_, dim2_));
     sizes.erase(sizes.begin() + std::min(dim1_, dim2_));
......
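
The hunk above changes only the inferred output shape: trace drops the two traced dimensions, so a 2-D input now produces a 0-D output rather than shape [1]. A small Python mirror of that rule follows; the helper name infer_trace_shape is hypothetical, and it assumes dim1/dim2 are already non-negative indices.

    def infer_trace_shape(x_shape, dim1, dim2):
        """Mirror of the shape rule above: trace removes the two traced
        dimensions; a 2-D input now maps to a 0-D (scalar) output."""
        sizes = list(x_shape)
        if len(sizes) == 2:
            return []  # was [1] before this commit
        sizes.pop(max(dim1, dim2))
        sizes.pop(min(dim1, dim2))
        return sizes

    assert infer_trace_shape([2, 2], 0, 1) == []
    assert infer_trace_shape([3, 4, 5], 1, 2) == [3]
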
......
@@ -32,7 +32,10 @@ void TraceKernel(const Context& ctx,
   auto diag = funcs::Diagonal<T, Context>(ctx, &x, offset, axis1, axis2);
   if (diag.numel() > 0) {
     std::vector<int> reduce_dims;
-    reduce_dims.push_back(out->dims().size());
+    // Adapt to 0D output
+    auto out_dim_size = out->dims().size();
+    if (out_dim_size == 0) out_dim_size = 1;
+    reduce_dims.push_back(out_dim_size);
     funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
         ctx, diag, out, kps::IdentityFunctor<T>(), reduce_dims);
   } else {
......
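
Conceptually this kernel computes trace as "take the diagonal, then reduce-sum it", and the new branch only keeps the reduce-axis bookkeeping valid when the output is 0-D. A NumPy stand-in for that decomposition (illustrative only, not the kernel code):

    import numpy as np

    x = np.array([[3.0, 2.0], [1.0, 9.0]])
    diag = np.diagonal(x, offset=0, axis1=0, axis2=1)  # shape (2,)
    out = diag.sum()                                   # reduce-sum -> 0-D scalar

    assert out.shape == ()
    assert out == 12.0
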
......
@@ -91,7 +91,8 @@ void TraceGradKernel(const Context& ctx,
   auto input_dims = in_grad->dims();
   auto input_stride = phi::stride(input_dims);
   auto output_dims = out_grad.dims();
-  auto output_stride = phi::stride(output_dims);
+  auto output_stride = output_dims.size() == 0 ? phi::DDim(output_dims)
+                                               : phi::stride(output_dims);
   auto* out_data = out_grad.data<T>();
   T* x_data = ctx.template Alloc<T>(in_grad);
......
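
The backward pass scatters the upstream gradient back onto the traced diagonal; the change above only avoids calling phi::stride on a 0-D DDim. A NumPy sketch of the expected gradient for the 2x2 case used in the new tests (illustrative, not the kernel code):

    import numpy as np

    out_grad = np.float64(1.0)          # 0-D upstream gradient of trace(x)
    x_grad = np.zeros((2, 2))
    np.fill_diagonal(x_grad, out_grad)  # d(trace)/dx is 1 on the diagonal, 0 elsewhere

    assert x_grad.shape == (2, 2)
    np.testing.assert_allclose(x_grad, np.eye(2))
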
......
@@ -2437,6 +2437,16 @@ class TestSundryAPI(unittest.TestCase):
         self.assertEqual(b.grad.shape, [4, 5])
         self.assertEqual(c.grad.shape, [5])
 
+    def test_trace(self):
+        x = paddle.to_tensor([[3, 2], [1, 9]], dtype="float32")
+        x.stop_gradient = False
+        out = paddle.trace(x)
+        out.backward()
+
+        self.assertEqual(out.shape, [])
+        np.testing.assert_allclose(out, np.array(12))
+        self.assertEqual(x.grad.shape, [2, 2])
+
 
 class TestSundryAPIStatic(unittest.TestCase):
     def setUp(self):
......
@@ -4426,6 +4436,20 @@ class TestSundryAPIStatic(unittest.TestCase):
         self.assertEqual(res[2].shape, (4, 5))
         self.assertEqual(res[3].shape, (5,))
 
+    @prog_scope()
+    def test_trace(self):
+        x = paddle.to_tensor([[3, 2], [1, 9]], dtype="float32")
+        x.stop_gradient = False
+        out = paddle.trace(x)
+        paddle.static.append_backward(out)
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, fetch_list=[out, x.grad_name])
+
+        self.assertEqual(res[0].shape, ())
+        self.assertEqual(res[1].shape, (2, 2))
+        np.testing.assert_allclose(res[0], np.array(12))
+
 
 # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
 class TestNoBackwardAPI(unittest.TestCase):
......