Unverified commit 50a8b655, authored by wangzhen38, committed by GitHub

[0 Tensor support] cumprod (#49550)

Parent 1574a862
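
This change adds 0-d (scalar) tensor support to `paddle.cumprod`: the infer-meta function now validates the axis for rank-0 inputs, the CPU and GPU kernels gain an early-exit copy path, and unit tests cover both dynamic and static graph modes. A minimal sketch of the resulting behavior, assuming a build that includes this commit (values mirror the new tests):

```python
import paddle

x = paddle.full([], 1.0, 'float32')  # 0-d (scalar) tensor, shape []
x.stop_gradient = False

out = paddle.cumprod(x, 0)  # axis 0 is now accepted for rank-0 input
out.backward()

print(out.shape)     # [] -> the output stays 0-d
print(x.grad.shape)  # [] -> and so does the gradient
```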
@@ -82,7 +82,7 @@ class CumprodGradOp : public framework::OperatorWithKernel {
 namespace ops = paddle::operators;
 DECLARE_INFER_SHAPE_FUNCTOR(cumprod,
                             CumprodInferShapeFunctor,
-                            PD_INFER_META(phi::UnchangedInferMeta));
+                            PD_INFER_META(phi::UnchangedInferMetaCheckAxis));
 REGISTER_OPERATOR(cumprod,
                   ops::CumprodOp,
......
@@ -441,8 +441,7 @@
   args : (Tensor x, int dim)
   output : Tensor(out)
   infer_meta :
-    func : UnchangedInferMeta
-    param: [x]
+    func : UnchangedInferMetaCheckAxis
   kernel :
     func : cumprod
   backward : cumprod_grad
......
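
Dropping `param: [x]` and switching to `UnchangedInferMetaCheckAxis` means shape inference still propagates `x`'s shape unchanged, but now also receives `dim` and range-checks it against the input rank, including rank 0. The practical effect, sketched from the Python API (error type per the new tests; exact message may vary):

```python
import paddle

x = paddle.full([], 1.0, 'float32')

paddle.cumprod(x, 0)        # accepted: axis 0 is valid for a 0-d tensor
try:
    paddle.cumprod(x, 2)    # rejected: axis out of range for rank 0
except ValueError as err:
    print('rejected:', err)
```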
@@ -44,6 +44,10 @@ void CumprodGradKernel(const Context& dev_ctx,
   size_t mid_dim = 1;
   size_t inner_dim = 1;
   GetCumprodDimInfo(shape, dim, &outer_dim, &mid_dim, &inner_dim);
+  if (shape.size() == 0) {
+    phi::Copy<Context>(dev_ctx, d_out, dev_ctx.GetPlace(), false, d_x);
+    return;
+  }
   size_t numel = outer_dim * mid_dim * inner_dim;
   // deal with complex
......
@@ -37,6 +37,10 @@ void CumprodKernel(const Context& dev_ctx,
   size_t mid_dim = 1;
   size_t inner_dim = 1;
   GetCumprodDimInfo(shape, dim, &outer_dim, &mid_dim, &inner_dim);
+  if (shape.size() == 0) {
+    phi::Copy<Context>(dev_ctx, input, dev_ctx.GetPlace(), false, out);
+    return;
+  }
   for (size_t i = 0; i < outer_dim; i++) {
     for (size_t j = 0; j < mid_dim; j++) {
......
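
Both CPU kernels take the same shortcut: a rank-0 cumprod is the identity (the single element is its own cumulative product), so the forward pass copies `input` to `out` and the backward pass copies `d_out` to `d_x` without entering the scan loops. A quick check of that invariant (3.0 is an illustrative value, not from the tests):

```python
import paddle

x = paddle.full([], 3.0, 'float32')
x.stop_gradient = False

out = paddle.cumprod(x, 0)
out.backward()

print(float(out))     # 3.0 -> forward is a plain copy
print(float(x.grad))  # 1.0 -> with out == x, the upstream gradient passes through
```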
@@ -30,6 +30,18 @@ static void GetCumprodDimInfo(const DDim& dim,
           "rank of input x which is %d.But received dim=%d",
           -dim.size(),
           cumprod_dim));
+  if (dim.size() == 0) {
+    PADDLE_ENFORCE_LE(
+        cumprod_dim,
+        dim.size(),
+        phi::errors::InvalidArgument(
+            "The input dim of CumprodOp should be smaller than the "
+            "rank of input x which is %d.But received dim=%d",
+            dim.size(),
+            cumprod_dim));
+    return;
+  }
   PADDLE_ENFORCE_LT(cumprod_dim,
                     dim.size(),
                     phi::errors::InvalidArgument(
......
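
For context, `GetCumprodDimInfo` factors the input shape into `outer * mid * inner`, where `mid` is the extent of the cumprod axis; the kernels then run `outer * inner` independent scans of length `mid`. With the new early return, a rank-0 shape leaves the callers' defaults of 1/1/1 untouched. A hypothetical Python rendering of the decomposition (illustrative, not Paddle code):

```python
from functools import reduce
from operator import mul

def get_cumprod_dim_info(shape, dim):
    """Split `shape` into (outer, mid, inner) blocks around axis `dim`."""
    if dim < 0:
        dim += len(shape)
    outer = reduce(mul, shape[:dim], 1)      # product of dims before the axis
    mid = shape[dim] if shape else 1         # extent of the scanned axis
    inner = reduce(mul, shape[dim + 1:], 1)  # product of dims after the axis
    return outer, mid, inner

print(get_cumprod_dim_info((2, 3, 4), 1))  # (2, 3, 4): 2*4 scans of length 3
print(get_cumprod_dim_info((), 0))         # (1, 1, 1): rank-0 fast path
```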
@@ -136,6 +136,10 @@ void CumprodGradKernel(const Context &dev_ctx,
   size_t outer_dim, mid_dim, inner_dim;
   GetCumprodDimInfo(x.dims(), dim, &outer_dim, &mid_dim, &inner_dim);
+  if (x.dims().size() == 0) {
+    phi::Copy<Context>(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
+    return;
+  }
   if (outer_dim == 0 || mid_dim == 0 || inner_dim == 0) return;
   size_t numel = outer_dim * mid_dim * inner_dim;
......
@@ -32,6 +32,10 @@ void CumprodKernel(const Context &dev_ctx,
   auto *y = out;
   size_t outer_dim, mid_dim, inner_dim;
   GetCumprodDimInfo(x->dims(), dim, &outer_dim, &mid_dim, &inner_dim);
+  if (x->dims().size() == 0) {
+    phi::Copy<Context>(dev_ctx, input, dev_ctx.GetPlace(), false, out);
+    return;
+  }
   const auto *x_data = x->data<T>();
   auto *y_data = dev_ctx.template Alloc<T>(y);
......
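
The GPU kernels mirror the CPU early exit: for a rank-0 tensor there is nothing to scan, so a `phi::Copy` replaces the kernel launch. A quick backend parity check (the GPU leg runs only on a CUDA build):

```python
import paddle

devices = ['cpu']
if paddle.is_compiled_with_cuda():
    devices.append('gpu')

for device in devices:
    paddle.set_device(device)
    x = paddle.full([], 2.0, 'float32')
    print(device, float(paddle.cumprod(x, 0)))  # 2.0 on every backend
```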
@@ -563,6 +563,18 @@ class TestSundryAPI(unittest.TestCase):
         self.assertEqual(out.grad.shape, [])
         self.assertEqual(x.grad.shape, [])
 
+    def test_cumprod(self):
+        x = paddle.full([], 1.0, 'float32')
+        x.stop_gradient = False
+        out = paddle.cumprod(x, 0)
+        out.backward()
+
+        with self.assertRaises(ValueError):
+            tmp = paddle.cumprod(x, 2)
+        self.assertEqual(out.shape, [])
+        self.assertEqual(out.grad.shape, [])
+        self.assertEqual(x.grad.shape, [])
+
     def test_clip(self):
         x = paddle.uniform([], None, -10, 10)
         x.stop_gradient = False
@@ -994,6 +1006,19 @@ class TestSundryAPIStatic(unittest.TestCase):
         res = self.exe.run(prog, fetch_list=[out])
         self.assertEqual(res[0].shape, ())
 
+    @prog_scope()
+    def test_cumprod(self):
+        x = paddle.full([], 1.0, 'float32')
+        x.stop_gradient = False
+        out = paddle.cumprod(x, 0)
+        paddle.static.append_backward(out)
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, fetch_list=[out])
+        with self.assertRaises(ValueError):
+            tmp = paddle.cumprod(x, 2)
+        self.assertEqual(res[0].shape, ())
 
     @prog_scope()
     def test_clip(self):
         x = paddle.uniform([], None, -10, 10)
......