From 7f49b9ba4c610d5e85ce619e4efd0aa321c71003 Mon Sep 17 00:00:00 2001
From: WangZhen <23097963+0x45f@users.noreply.github.com>
Date: Wed, 24 Aug 2022 17:14:51 +0800
Subject: [PATCH] Adapt tensor axis for cumsum (#45372)

---
 paddle/fluid/operators/cum_op.cc              | 12 +++-
 paddle/phi/api/yaml/legacy_api.yaml           |  4 +-
 paddle/phi/api/yaml/legacy_backward.yaml      |  4 +-
 paddle/phi/infermeta/unary.cc                 |  9 +++
 paddle/phi/infermeta/unary.h                  |  7 +++
 paddle/phi/kernels/cpu/cum_kernel.cc          |  4 +-
 paddle/phi/kernels/cum_kernel.h               |  3 +-
 paddle/phi/kernels/gpu/cum_kernel.cu          |  4 +-
 .../fluid/tests/unittests/test_cumsum_op.py   | 62 +++++++++++++++++++
 9 files changed, 98 insertions(+), 11 deletions(-)

diff --git a/paddle/fluid/operators/cum_op.cc b/paddle/fluid/operators/cum_op.cc
index 54e7a374338..09d3f1dbe74 100644
--- a/paddle/fluid/operators/cum_op.cc
+++ b/paddle/fluid/operators/cum_op.cc
@@ -24,6 +24,13 @@ namespace operators {
 class CumOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto input_data_type =
+        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
 };
 
 class CumsumOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -34,7 +41,8 @@ class CumsumOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<int>("axis",
                  "The dimension to accumulate along. -1 means the last "
                  "dimension [default -1].")
-        .SetDefault(-1);
+        .SetDefault(-1)
+        .SupportTensor();
     AddAttr<bool>("flatten",
                   "Whether to compute the cumsum over the flattened array. "
                   "[default false].")
@@ -148,7 +156,7 @@ namespace ops = paddle::operators;
 using CPU = phi::CPUContext;
 DECLARE_INFER_SHAPE_FUNCTOR(cumsum,
                             CumsumInferShapeFunctor,
-                            PD_INFER_META(phi::CumInferMeta));
+                            PD_INFER_META(phi::CumScalarAxisInferMeta));
 DECLARE_INFER_SHAPE_FUNCTOR(logcumsumexp,
                             LogcumsumexpInferShapeFunctor,
                             PD_INFER_META(phi::CumInferMeta));
diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml
index a2bbb28a34c..e68e22965d9 100755
--- a/paddle/phi/api/yaml/legacy_api.yaml
+++ b/paddle/phi/api/yaml/legacy_api.yaml
@@ -638,10 +638,10 @@
     backward : cumprod_grad
 
 - api : cumsum
-  args : (Tensor x, int axis, bool flatten, bool exclusive, bool reverse)
+  args : (Tensor x, Scalar axis, bool flatten, bool exclusive, bool reverse)
   output : Tensor(out)
   infer_meta :
-    func : CumInferMeta
+    func : CumScalarAxisInferMeta
   kernel :
     func : cumsum
   backward : cumsum_grad
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 6dfbdba0c4f..d731b8c4492 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -591,11 +591,11 @@
     func : cumprod_grad
 
 - backward_api : cumsum_grad
-  forward : cumsum(Tensor x, int axis, bool flatten, bool exclusive, bool reverse) -> Tensor(out)
+  forward : cumsum(Tensor x, Scalar axis, bool flatten, bool exclusive, bool reverse) -> Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
     param : [x]
-  args : (Tensor out_grad, int axis, bool flatten, bool exclusive, bool reverse)
+  args : (Tensor out_grad, Scalar axis, bool flatten, bool exclusive, bool reverse)
   output : Tensor(x_grad)
   invoke : cumsum(out_grad, axis, flatten, exclusive, !reverse)
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index c0b6c2b0e22..8eba84a24e0 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -405,6 +405,15 @@ void CumInferMeta(const MetaTensor& x,
   out->share_lod(x);
 }
 
+void CumScalarAxisInferMeta(const MetaTensor& x,
+                            const Scalar& axis,
+                            bool flatten,
+                            bool exclusive,
+                            bool reverse,
+                            MetaTensor* out) {
+  CumInferMeta(x, axis.to<int>(), flatten, exclusive, reverse, out);
+}
+
 void CropTensorInferMeta(const MetaTensor& x,
                          const IntArray& shape,
                          const IntArray& offsets,
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index 6350ef4a970..1479bcf6493 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -95,6 +95,13 @@ void CumInferMeta(const MetaTensor& x,
                   bool reverse,
                   MetaTensor* out);
 
+void CumScalarAxisInferMeta(const MetaTensor& x,
+                            const Scalar& axis,
+                            bool flatten,
+                            bool exclusive,
+                            bool reverse,
+                            MetaTensor* out);
+
 void DecodeJpegInferMeta(const MetaTensor& x,
                          const std::string& mode,
                          MetaTensor* out);
diff --git a/paddle/phi/kernels/cpu/cum_kernel.cc b/paddle/phi/kernels/cpu/cum_kernel.cc
index cd171cc8fc5..2b6a9be371a 100644
--- a/paddle/phi/kernels/cpu/cum_kernel.cc
+++ b/paddle/phi/kernels/cpu/cum_kernel.cc
@@ -135,7 +135,7 @@ void ScanKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void CumsumKernel(const Context& dev_ctx,
                   const DenseTensor& x,
-                  int axis,
+                  const Scalar& axis,
                   bool flatten,
                   bool exclusive,
                   bool reverse,
@@ -143,7 +143,7 @@ void CumsumKernel(const Context& dev_ctx,
   using Reducer = Eigen::internal::SumReducer<T>;
   auto reducer = Reducer();
   ScanKernel<T, Context, Reducer>(
-      dev_ctx, x, axis, flatten, exclusive, reverse, reducer, out);
+      dev_ctx, x, axis.to<int>(), flatten, exclusive, reverse, reducer, out);
 }
 
 template <typename T, typename Context>
diff --git a/paddle/phi/kernels/cum_kernel.h b/paddle/phi/kernels/cum_kernel.h
index 38cdbd7787b..870a305573b 100644
--- a/paddle/phi/kernels/cum_kernel.h
+++ b/paddle/phi/kernels/cum_kernel.h
@@ -14,6 +14,7 @@
 
 #pragma once
 
+#include "paddle/phi/common/scalar.h"
 #include "paddle/phi/core/dense_tensor.h"
 
 namespace phi {
@@ -21,7 +22,7 @@ namespace phi {
 template <typename T, typename Context>
 void CumsumKernel(const Context& dev_ctx,
                   const DenseTensor& x,
-                  int axis,
+                  const Scalar& axis,
                   bool flatten,
                   bool exclusive,
                   bool reverse,
diff --git a/paddle/phi/kernels/gpu/cum_kernel.cu b/paddle/phi/kernels/gpu/cum_kernel.cu
index 40d7f74379f..1db74770a7d 100644
--- a/paddle/phi/kernels/gpu/cum_kernel.cu
+++ b/paddle/phi/kernels/gpu/cum_kernel.cu
@@ -353,7 +353,7 @@ void ScanKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void CumsumKernel(const Context& dev_ctx,
                   const DenseTensor& x,
-                  int axis,
+                  const Scalar& axis,
                   bool flatten,
                   bool exclusive,
                   bool reverse,
@@ -361,7 +361,7 @@ void CumsumKernel(const Context& dev_ctx,
   using Op = cub::Sum;
   auto op = Op();
   ScanKernel<T, Context, Op>(
-      dev_ctx, x, axis, flatten, exclusive, reverse, op, out);
+      dev_ctx, x, axis.to<int>(), flatten, exclusive, reverse, op, out);
 }
 
 template <typename T, typename Context>
diff --git a/python/paddle/fluid/tests/unittests/test_cumsum_op.py b/python/paddle/fluid/tests/unittests/test_cumsum_op.py
index dfcd34cebfe..2f66cd80dde 100644
--- a/python/paddle/fluid/tests/unittests/test_cumsum_op.py
+++ b/python/paddle/fluid/tests/unittests/test_cumsum_op.py
@@ -14,13 +14,16 @@
 
 from __future__ import print_function
 
+import os
 import unittest
+import tempfile
 import numpy as np
 from op_test import OpTest
 import paddle
 import paddle.fluid.core as core
 import paddle.fluid as fluid
 from paddle.fluid import compiler, Program, program_guard
+import paddle.inference as paddle_infer
 
 
 class TestCumsumOp(unittest.TestCase):
@@ -318,5 +321,64 @@ class BadInputTest(unittest.TestCase):
         self.assertRaises(TypeError, test_bad_x)
 
 
+class TestTensorAxis(unittest.TestCase):
+
+    def setUp(self):
+        paddle.seed(2022)
+        self.temp_dir = tempfile.TemporaryDirectory()
+        self.save_path = os.path.join(self.temp_dir.name, 'tensor_axis_cumsum')
+        self.place = paddle.CUDAPlace(
+            0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
+
+    def test_dygraph(self):
+        paddle.disable_static()
+        x = np.random.randn(5, 6)
+        axis = 1
+        np_out = np.cumsum(x, axis)
+        pd_out = paddle.cumsum(paddle.to_tensor(x),
+                               axis=paddle.to_tensor([axis], dtype='int32'))
+        np.testing.assert_allclose(np_out, pd_out.numpy())
+
+    def test_static_and_infer(self):
+        paddle.enable_static()
+        np_x = np.random.randn(9, 10, 11).astype('float32')
+        main_prog = paddle.static.Program()
+        starup_prog = paddle.static.Program()
+        with paddle.static.program_guard(main_prog, starup_prog):
+            # run static
+            x = paddle.static.data(shape=np_x.shape, name='x', dtype=np_x.dtype)
+            print(x)
+            linear = paddle.nn.Linear(np_x.shape[-1], np_x.shape[-1])
+            linear_out = linear(x)
+            relu_out = paddle.nn.functional.relu(linear_out)
+            axis = paddle.full([1], 2, dtype='int64')
+            out = paddle.cumsum(relu_out, axis=axis)
+
+            exe = paddle.static.Executor(self.place)
+            exe.run(starup_prog)
+            static_out = exe.run(feed={'x': np_x}, fetch_list=[out])
+
+            # run infer
+            paddle.static.save_inference_model(self.save_path, [x], [out], exe)
+            config = paddle_infer.Config(self.save_path + '.pdmodel',
+                                         self.save_path + '.pdiparams')
+            if paddle.is_compiled_with_cuda():
+                config.enable_use_gpu(100, 0)
+            else:
+                config.disable_gpu()
+
+            predictor = paddle_infer.create_predictor(config)
+            input_names = predictor.get_input_names()
+            input_handle = predictor.get_input_handle(input_names[0])
+            fake_input = np_x
+            input_handle.reshape(np_x.shape)
+            input_handle.copy_from_cpu(fake_input)
+            predictor.run()
+            output_names = predictor.get_output_names()
+            output_handle = predictor.get_output_handle(output_names[0])
+            infer_out = output_handle.copy_to_cpu()
+            np.testing.assert_allclose(static_out[0], infer_out)
+
+
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab