diff --git a/paddle/fluid/operators/cum_op.cc b/paddle/fluid/operators/cum_op.cc index 6d1089ecf72a44850cfdd9e7cfd666fc15b5e7de..4c23020413ee5fbabbec88ce81439ce821df4008 100644 --- a/paddle/fluid/operators/cum_op.cc +++ b/paddle/fluid/operators/cum_op.cc @@ -33,6 +33,27 @@ class CumOp : public framework::OperatorWithKernel { } }; +class CumGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "cumsum"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), + "Input", + "Out@GRAD", + "cumsum"); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } + + phi::KernelKey GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto input_data_type = + framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return phi::KernelKey(input_data_type, ctx.GetPlace()); + } +}; + class CumsumOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -69,12 +90,13 @@ class CumsumGradMaker : public framework::SingleGradOpMaker { protected: void Apply(GradOpPtr grad_op) const override { - grad_op->SetType("cumsum"); - grad_op->SetInput("X", this->OutputGrad("Out")); - grad_op->SetOutput("Out", this->InputGrad("X")); + grad_op->SetType("cumsum_grad"); + grad_op->SetInput("X", this->Input("X")); + grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); grad_op->SetAttrMap(this->Attrs()); grad_op->SetAttr("reverse", - !PADDLE_GET_CONST(bool, this->GetAttr("reverse"))); + PADDLE_GET_CONST(bool, this->GetAttr("reverse"))); } }; @@ -153,6 +175,7 @@ using CPU = phi::CPUContext; DECLARE_INFER_SHAPE_FUNCTOR(cumsum, CumsumInferShapeFunctor, PD_INFER_META(phi::CumScalarAxisInferMeta)); + DECLARE_INFER_SHAPE_FUNCTOR(logcumsumexp, LogcumsumexpInferShapeFunctor, PD_INFER_META(phi::CumInferMeta)); @@ -169,6 +192,7 @@ REGISTER_OPERATOR(logcumsumexp, ops::LogcumsumexpGradMaker, LogcumsumexpInferShapeFunctor); REGISTER_OPERATOR(logcumsumexp_grad, ops::LogcumsumexpGradOp); +REGISTER_OPERATOR(cumsum_grad, ops::CumGradOp); REGISTER_OP_VERSION(cumsum).AddCheckpoint( R"ROC( diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 10ca2aee865afd203ee690dde0ac06ba45463423..f47e206c7ce2fe2742529382ef18092f92571cde 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -316,9 +316,14 @@ - backward_op : cumsum_grad forward : cumsum(Tensor x, Scalar axis, bool flatten, bool exclusive, bool reverse) -> Tensor(out) - args : (Tensor out_grad, Scalar axis, bool flatten, bool exclusive, bool reverse) + args : (Tensor x, Tensor out_grad, Scalar axis, bool flatten, bool exclusive, bool reverse) output : Tensor(x_grad) - invoke : cumsum(out_grad, axis, flatten, exclusive, !reverse) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + func : cumsum_grad + data_type: x - backward_op : deformable_conv_grad forward : deformable_conv(Tensor x, Tensor offset, Tensor filter, Tensor mask, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step) -> Tensor(out) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 04373fa29edf9e58661d8c52cf85d08be0ceaf13..5a7b2cf16a1f8cdc896da44193befe43674b23cd 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -424,10 +424,36 @@ void CumInferMeta(const MetaTensor& x, out->set_dims(phi::make_ddim({phi::product(x_dims)})); out->set_dtype(x.dtype()); } else { + if (x_dims.size() > 0) { + PADDLE_ENFORCE_GE( + axis, + -x_dims.size(), + phi::errors::OutOfRange( + "axis is out of range (expected to be in range of [%ld, " + "%ld), but got %ld).", + -(x_dims.size()), + x_dims.size(), + axis)); + PADDLE_ENFORCE_LT( + axis, + x_dims.size(), + phi::errors::OutOfRange( + "axis is out of range (expected to be in range of [%ld, " + "%ld), but got %ld).", + -(x_dims.size()), + x_dims.size(), + axis)); + } else { + PADDLE_ENFORCE_EQ( + (axis == 0 || axis == -1), + true, + errors::InvalidArgument("The axis must be -1 or 0 in 0D Tensor, " + "but the value given is %d.", + axis)); + } out->set_dims(x_dims); out->set_dtype(x.dtype()); } - out->share_lod(x); } diff --git a/paddle/phi/kernels/cpu/cum_grad_kernel.cc b/paddle/phi/kernels/cpu/cum_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..32be44661348fbf8e3e4e6713637af88c89fe560 --- /dev/null +++ b/paddle/phi/kernels/cpu/cum_grad_kernel.cc @@ -0,0 +1,48 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/cum_grad_kernel.h" +#include "paddle/phi/kernels/cum_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" +namespace phi { + +template +void CumsumGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const Scalar& axis, + bool flatten, + bool exclusive, + bool reverse, + DenseTensor* x_grad) { + x_grad->Resize(x.dims()); + CumsumKernel( + dev_ctx, out_grad, axis, flatten, exclusive, !reverse, x_grad); +} + +} // namespace phi + +PD_REGISTER_KERNEL(cumsum_grad, + CPU, + ALL_LAYOUT, + phi::CumsumGradKernel, + float, + double, + int16_t, + int, + int64_t) {} diff --git a/paddle/phi/kernels/cpu/cum_kernel.cc b/paddle/phi/kernels/cpu/cum_kernel.cc index 2b6a9be371afb6f2c64355e241c2611571bca395..f7ec5bbbf9e844fc51538641e180018c32b9e438 100644 --- a/paddle/phi/kernels/cpu/cum_kernel.cc +++ b/paddle/phi/kernels/cpu/cum_kernel.cc @@ -57,6 +57,14 @@ void ScanKernel(const Context& dev_ctx, bool reverse, Reducer reducer, DenseTensor* out) { + dev_ctx.template Alloc(out); + + if (x.numel() == 1) { + auto raw_dims = out->dims(); + phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); + out->Resize(raw_dims); + return; + } auto out_dims = out->dims(); PADDLE_ENFORCE_EQ( @@ -72,8 +80,6 @@ void ScanKernel(const Context& dev_ctx, axis += out_dims.size(); } - dev_ctx.template Alloc(out); - int pre = 1; int post = 1; int mid = out_dims[axis]; diff --git a/paddle/phi/kernels/cum_grad_kernel.h b/paddle/phi/kernels/cum_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..f2428524fe5a2fa87bbe20ce417538e708fa8ab4 --- /dev/null +++ b/paddle/phi/kernels/cum_grad_kernel.h @@ -0,0 +1,32 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void CumsumGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const Scalar& axis, + bool flatten, + bool exclusive, + bool reverse, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/gpu/cum_grad_kernel.cu b/paddle/phi/kernels/gpu/cum_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..6039c313b78ed007f74c093111b76a0efe399f30 --- /dev/null +++ b/paddle/phi/kernels/gpu/cum_grad_kernel.cu @@ -0,0 +1,75 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/cum_grad_kernel.h" +#include "paddle/phi/kernels/cum_kernel.h" + +#include +#include +#include +#include +#ifdef __NVCC__ +#include +#endif +#ifdef __HIPCC__ +#include +namespace cub = hipcub; +#endif + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void CumsumGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const Scalar& axis, + bool flatten, + bool exclusive, + bool reverse, + DenseTensor* x_grad) { + x_grad->Resize(x.dims()); + CumsumKernel( + dev_ctx, out_grad, axis, flatten, exclusive, !reverse, x_grad); +} + +} // namespace phi + +#ifdef PADDLE_WITH_HIP +PD_REGISTER_KERNEL(cumsum_grad, + GPU, + ALL_LAYOUT, + phi::CumsumGradKernel, + float, + double, + int16_t, + int, + int64_t) {} +#else +PD_REGISTER_KERNEL(cumsum_grad, + GPU, + ALL_LAYOUT, + phi::CumsumGradKernel, + float, + double, + int16_t, + int, + int64_t, + phi::dtype::float16) {} +#endif diff --git a/paddle/phi/kernels/gpu/cum_kernel.cu b/paddle/phi/kernels/gpu/cum_kernel.cu index 0c6cd8b5562af4897238e37dc77901224d6a0621..9bf06d7bf19dcd763263447629d3521516bcf736 100644 --- a/paddle/phi/kernels/gpu/cum_kernel.cu +++ b/paddle/phi/kernels/gpu/cum_kernel.cu @@ -270,6 +270,16 @@ void ScanKernel(const Context& dev_ctx, bool reverse, Op op, DenseTensor* out) { + T* out_data = dev_ctx.template Alloc(out); + + // For 0D Tensor + if (out->numel() == 1) { + auto raw_dims = out->dims(); + phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); + out->Resize(raw_dims); + return; + } + auto out_dims = out->dims(); auto size = x.numel(); @@ -286,7 +296,6 @@ void ScanKernel(const Context& dev_ctx, axis += out_dims.size(); } - T* out_data = dev_ctx.template Alloc(out); const T* in_data = x.data(); // Use thrust for parallel acceleration when the input size is equal to the diff --git a/paddle/phi/kernels/xpu/cum_grad_kernel.cc b/paddle/phi/kernels/xpu/cum_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..7b11ba47f0a79c708e9662cd158f024188cbb8f3 --- /dev/null +++ b/paddle/phi/kernels/xpu/cum_grad_kernel.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/cum_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" +namespace phi { + +template +void CumsumGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const Scalar& axis, + bool flatten, + bool exclusive, + bool reverse, + DenseTensor* x_grad) { + x_grad->Resize(x.dims()); + CumsumKernel( + dev_ctx, out_grad, axis, flatten, exclusive, !reverse, x_grad); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + cumsum_grad, XPU, ALL_LAYOUT, phi::CumsumGradKernel, float, int, int64_t) {} diff --git a/paddle/phi/kernels/xpu/cum_kernel.cc b/paddle/phi/kernels/xpu/cum_kernel.cc index 17eca4008607e65457ccae3b813dc43d5e92ac44..13a1dab66d72f28665ce2d27558230a37e457a0a 100644 --- a/paddle/phi/kernels/xpu/cum_kernel.cc +++ b/paddle/phi/kernels/xpu/cum_kernel.cc @@ -30,6 +30,15 @@ void CumsumKernel(const Context& dev_ctx, using XPUType = typename XPUTypeTrait::Type; dev_ctx.template Alloc(out); + if (x.numel() == 1) { + int r = xpu::copy(dev_ctx.x_context(), + reinterpret_cast(x.data()), + reinterpret_cast(out->data()), + x.numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy"); + return; + } + // prepare for call xdnn api std::vector x_shape = phi::vectorize(x.dims()); int axis_as_int = axis.to(); diff --git a/paddle/phi/ops/compat/cumsum_sig.cc b/paddle/phi/ops/compat/cumsum_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..00992b15435d2153ccd38d95689ce9e1ee9f31bc --- /dev/null +++ b/paddle/phi/ops/compat/cumsum_sig.cc @@ -0,0 +1,27 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature CumsumOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("cumsum_grad", + {"X", "Out@GRAD"}, + {"axis", "flatten", "exclusive", "reverse"}, + {"X@GRAD"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(cumsum_grad, phi::CumsumOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_cumsum_op.py b/python/paddle/fluid/tests/unittests/test_cumsum_op.py index 4b0cae035b06dd2c5552f84bfe136a311c1a3055..3e21537cdecf01840980172dd1b502b168fb12b9 100644 --- a/python/paddle/fluid/tests/unittests/test_cumsum_op.py +++ b/python/paddle/fluid/tests/unittests/test_cumsum_op.py @@ -16,15 +16,12 @@ import os import tempfile import unittest -import gradient_checker import numpy as np -from decorator_helper import prog_scope from op_test import OpTest import paddle import paddle.fluid as fluid import paddle.fluid.core as core -import paddle.fluid.layers as layers import paddle.inference as paddle_infer @@ -230,7 +227,7 @@ class TestSumOpExclusive1(OpTest): def setUp(self): self.op_type = "cumsum" self.attrs = {'axis': 2, "exclusive": True} - a = np.random.random((4, 5, 65)).astype("float64") + a = np.random.random((4, 5, 20)).astype("float64") self.inputs = {'X': a} self.outputs = { 'Out': np.concatenate( @@ -245,12 +242,15 @@ class TestSumOpExclusive1(OpTest): def test_check_output(self): self.check_output() + def test_check_grad(self): + self.check_grad(['X'], 'Out') + class TestSumOpExclusive2(OpTest): def setUp(self): self.op_type = "cumsum" self.attrs = {'axis': 2, "exclusive": True} - a = np.random.random((1, 1, 888)).astype("float64") + a = np.random.random((1, 1, 100)).astype("float64") self.inputs = {'X': a} self.outputs = { 'Out': np.concatenate( @@ -265,12 +265,15 @@ class TestSumOpExclusive2(OpTest): def test_check_output(self): self.check_output() + def test_check_grad(self): + self.check_grad(['X'], 'Out') + class TestSumOpExclusive3(OpTest): def setUp(self): self.op_type = "cumsum" self.attrs = {'axis': 2, "exclusive": True} - a = np.random.random((4, 5, 888)).astype("float32") + a = np.random.random((4, 5, 20)).astype("float64") self.inputs = {'X': a} self.outputs = { 'Out': np.concatenate( @@ -285,12 +288,15 @@ class TestSumOpExclusive3(OpTest): def test_check_output(self): self.check_output() + def test_check_grad(self): + self.check_grad(['X'], 'Out') + class TestSumOpExclusive4(OpTest): def setUp(self): self.op_type = "cumsum" self.attrs = {'axis': 2, "exclusive": True} - a = np.random.random((1, 1, 3049)).astype("float64") + a = np.random.random((1, 1, 100)).astype("float64") self.inputs = {'X': a} self.outputs = { 'Out': np.concatenate( @@ -305,12 +311,15 @@ class TestSumOpExclusive4(OpTest): def test_check_output(self): self.check_output() + def test_check_grad(self): + self.check_grad(['X'], 'Out') + class TestSumOpExclusive5(OpTest): def setUp(self): self.op_type = "cumsum" self.attrs = {'axis': 2, "exclusive": True} - a = np.random.random((4, 5, 3096)).astype("float64") + a = np.random.random((4, 5, 40)).astype("float64") self.inputs = {'X': a} self.outputs = { 'Out': np.concatenate( @@ -325,12 +334,15 @@ class TestSumOpExclusive5(OpTest): def test_check_output(self): self.check_output() + def test_check_grad(self): + self.check_grad(['X'], 'Out') + class TestSumOpExclusiveFP16(OpTest): def setUp(self): self.op_type = "cumsum" self.attrs = {'axis': 2, "exclusive": True, "dtype": "float16"} - a = np.random.random((4, 5, 3096)).astype("float64") + a = np.random.random((4, 5, 20)).astype("float64") self.inputs = {'X': a} self.outputs = { 'Out': np.concatenate( @@ -345,6 +357,9 @@ class TestSumOpExclusiveFP16(OpTest): def test_check_output(self): self.check_output() + def test_check_grad(self): + self.check_grad(['X'], 'Out') + class TestSumOpReverseExclusive(OpTest): def setUp(self): @@ -366,6 +381,9 @@ class TestSumOpReverseExclusive(OpTest): def test_check_output(self): self.check_output() + def test_check_grad(self): + self.check_grad(['X'], 'Out') + class BadInputTest(unittest.TestCase): def test_error(self): @@ -407,7 +425,6 @@ class TestTensorAxis(unittest.TestCase): with paddle.static.program_guard(main_prog, starup_prog): # run static x = paddle.static.data(shape=np_x.shape, name='x', dtype=np_x.dtype) - print(x) linear = paddle.nn.Linear(np_x.shape[-1], np_x.shape[-1]) linear_out = linear(x) relu_out = paddle.nn.functional.relu(linear_out) @@ -444,67 +461,5 @@ class TestTensorAxis(unittest.TestCase): np.testing.assert_allclose(static_out[0], infer_out) -class TestCumsumDoubleGradCheck(unittest.TestCase): - def cumsum_wrapper(self, x): - return paddle.cumsum(x[0], 0) - - @prog_scope() - def func(self, place): - # the shape of input variable should be clearly specified, not inlcude -1. - eps = 0.005 - dtype = np.float64 - - data = layers.data('data', [3, 4], False, dtype) - data.persistable = True - out = paddle.cumsum(data, 0) - data_arr = np.random.uniform(-1, 1, data.shape).astype(dtype) - - gradient_checker.double_grad_check( - [data], out, x_init=[data_arr], place=place, eps=eps - ) - gradient_checker.double_grad_check_for_dygraph( - self.cumsum_wrapper, [data], out, x_init=[data_arr], place=place - ) - - def test_grad(self): - paddle.enable_static() - places = [fluid.CPUPlace()] - if core.is_compiled_with_cuda(): - places.append(fluid.CUDAPlace(0)) - for p in places: - self.func(p) - - -class TestCumsumTripleGradCheck(unittest.TestCase): - def cumsum_wrapper(self, x): - return paddle.cumsum(x[0], 0) - - @prog_scope() - def func(self, place): - # the shape of input variable should be clearly specified, not inlcude -1. - eps = 0.005 - dtype = np.float32 - - data = layers.data('data', [2, 3], False, dtype) - data.persistable = True - out = paddle.cumsum(data, 0) - data_arr = np.random.uniform(-1, 1, data.shape).astype(dtype) - - gradient_checker.triple_grad_check( - [data], out, x_init=[data_arr], place=place, eps=eps - ) - gradient_checker.triple_grad_check_for_dygraph( - self.cumsum_wrapper, [data], out, x_init=[data_arr], place=place - ) - - def test_grad(self): - paddle.enable_static() - places = [fluid.CPUPlace()] - if core.is_compiled_with_cuda(): - places.append(fluid.CUDAPlace(0)) - for p in places: - self.func(p) - - if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index 619579df1ec3ac909c5bcd4a213475114edb3536..cc7a257e4c4726cbbdfdadc4af90c653a8724ac9 100755 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -954,6 +954,34 @@ class TestSundryAPI(unittest.TestCase): np.testing.assert_array_equal(out3_1.numpy(), out3_2.numpy()) np.testing.assert_array_equal(out3_2.numpy(), np.asarray(1)) + def test_cumsum(self): + x1 = paddle.rand([]) + x1.stop_gradient = False + + out1 = paddle.cumsum(x1) + out2 = paddle.cumsum(x1, axis=0) + out3 = paddle.cumsum(x1, axis=-1) + + out1.retain_grads() + out2.retain_grads() + out3.retain_grads() + + out1.backward() + out2.backward() + out3.backward() + + self.assertEqual(x1.grad.shape, []) + self.assertTrue(x1.grad.numpy() == 3) + self.assertEqual(out1.shape, [1]) + self.assertEqual(out1.grad.shape, [1]) + self.assertTrue(out1.grad.numpy() == 1) + self.assertEqual(out2.shape, []) + self.assertEqual(out2.grad.shape, []) + self.assertTrue(out2.grad.numpy() == 1) + self.assertEqual(out3.shape, []) + self.assertEqual(out3.grad.shape, []) + self.assertTrue(out3.grad.numpy() == 1) + def test_add_n(self): x1 = paddle.rand([]) x1.stop_gradient = False @@ -1674,6 +1702,45 @@ class TestSundryAPIStatic(unittest.TestCase): np.testing.assert_array_equal(out3_2, np.asarray(1)) @prog_scope() + def test_cumsum(self): + x1 = paddle.rand([]) + x1.stop_gradient = False + + out1 = paddle.cumsum(x1) + out2 = paddle.cumsum(x1, axis=0) + out3 = paddle.cumsum(x1, axis=-1) + + paddle.static.append_backward(out1.sum()) + paddle.static.append_backward(out2.sum()) + paddle.static.append_backward(out3.sum()) + + prog = paddle.static.default_main_program() + res = self.exe.run( + prog, + fetch_list=[ + out1, + out2, + out3, + x1.grad_name, + out1.grad_name, + out2.grad_name, + out3.grad_name, + ], + ) + self.assertEqual(res[0].shape, (1,)) + self.assertEqual(res[1].shape, ()) + self.assertEqual(res[2].shape, ()) + self.assertEqual(res[3].shape, ()) + self.assertEqual(res[3], 1) + self.assertEqual(res[4].shape, (1,)) + self.assertEqual(res[4], 1) + self.assertEqual(res[5].shape, ()) + self.assertEqual(res[5], 1) + self.assertEqual(res[6].shape, ()) + self.assertEqual(res[6], 1) + self.assertEqual(out2.shape, ()) + self.assertEqual(out3.shape, ()) + def test_add_n(self): x1 = paddle.rand([]) x1.stop_gradient = False diff --git a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py index 770227e2e848795c655e6d25a7bae477c3501dc5..a9d95fc963ce338dd06787d13ea26514dc9b4855 100755 --- a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py @@ -592,6 +592,29 @@ class TestSundryAPI(unittest.TestCase): np.testing.assert_array_equal(out3_1.numpy(), out3_2.numpy()) np.testing.assert_array_equal(out3_2.numpy(), np.asarray(1)) + def test_cumsum(self): + x1 = paddle.rand([]) + x1.stop_gradient = False + + out1 = paddle.cumsum(x1) + out2 = paddle.cumsum(x1, axis=0) + out3 = paddle.cumsum(x1, axis=-1) + + out1.retain_grads() + out2.retain_grads() + out3.retain_grads() + + out1.backward() + out2.backward() + out3.backward() + + self.assertEqual(out1.shape, [1]) + self.assertEqual(out1.grad.shape, [1]) + self.assertEqual(out2.shape, []) + self.assertEqual(out2.grad.shape, []) + self.assertEqual(out3.shape, []) + self.assertEqual(out3.grad.shape, []) + def test_add_n(self): x1 = paddle.rand([]) x1.stop_gradient = False