From bf2db646de438b4da0e864f37c1a592033b14c47 Mon Sep 17 00:00:00 2001 From: LutaoChu <30695251+LutaoChu@users.noreply.github.com> Date: Mon, 10 Aug 2020 19:00:28 +0800 Subject: [PATCH] fix cumsum op for API 2.0, optimize performance update cumsum api and fix up the cumsum op --- paddle/fluid/operators/cum_op.h | 27 ++++--- paddle/fluid/operators/cumsum_op.cc | 18 ++++- paddle/fluid/operators/cumsum_op.cu | 66 +++++++++++---- .../fluid/tests/unittests/test_cumsum_op.py | 81 +++++++++++++++++++ python/paddle/tensor/math.py | 73 ++++++++++++++++- 5 files changed, 231 insertions(+), 34 deletions(-) mode change 100644 => 100755 paddle/fluid/operators/cumsum_op.cc diff --git a/paddle/fluid/operators/cum_op.h b/paddle/fluid/operators/cum_op.h index e336e25f0f4..ab3860ecafc 100644 --- a/paddle/fluid/operators/cum_op.h +++ b/paddle/fluid/operators/cum_op.h @@ -36,25 +36,28 @@ class CumKernel : public framework::OpKernel { int axis = context.Attr("axis"); bool exclusive = context.Attr("exclusive"); bool reverse = context.Attr("reverse"); - auto x_dims = X.dims(); - if (axis == -1) { - axis = x_dims.size() - 1; + auto out_dims = Out.dims(); + + PADDLE_ENFORCE_EQ( + axis < out_dims.size() && axis >= (0 - out_dims.size()), true, + platform::errors::OutOfRange( + "Attr(axis) is out of range, It's expected " + "to be in range of [-%d, %d]. But received Attr(axis) = %d.", + out_dims.size(), out_dims.size() - 1, axis)); + if (axis < 0) { + axis += out_dims.size(); } - PADDLE_ENFORCE_LT( - axis, x_dims.size(), - platform::errors::InvalidArgument("axis(%d) should be less than the " - "dimension(%d) of the input tensor.", - axis, x_dims.size())); + Out.template mutable_data(context.GetPlace()); int pre = 1; int post = 1; - int mid = x_dims[axis]; + int mid = out_dims[axis]; for (int i = 0; i < axis; ++i) { - pre *= x_dims[i]; + pre *= out_dims[i]; } - for (int i = axis + 1; i < x_dims.size(); ++i) { - post *= x_dims[i]; + for (int i = axis + 1; i < out_dims.size(); ++i) { + post *= out_dims[i]; } auto x = framework::EigenVector::Flatten(X); diff --git a/paddle/fluid/operators/cumsum_op.cc b/paddle/fluid/operators/cumsum_op.cc old mode 100644 new mode 100755 index 962d73d0689..2e9db16be55 --- a/paddle/fluid/operators/cumsum_op.cc +++ b/paddle/fluid/operators/cumsum_op.cc @@ -22,7 +22,14 @@ class CumOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + if (ctx->Attrs().Get("flatten")) { + ctx->SetOutputDim( + "Out", + framework::make_ddim({framework::product(ctx->GetInputDim("X"))})); + } else { + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + } + ctx->ShareLoD("X", /*->*/ "Out"); } }; @@ -35,8 +42,11 @@ class CumsumOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("axis", "The dimension to accumulate along. -1 means the last " "dimension [default -1].") - .SetDefault(-1) - .EqualGreaterThan(-1); + .SetDefault(-1); + AddAttr("flatten", + "Whether to compute the cumsum over the flattened array. " + "[default false].") + .SetDefault(false); AddAttr("exclusive", "Whether to perform exclusive cumsum. [default false].") .SetDefault(false); @@ -63,6 +73,8 @@ class CumsumGradMaker : public framework::SingleGradOpMaker { grad_op->SetInput("X", this->OutputGrad("Out")); grad_op->SetOutput("Out", this->InputGrad("X")); grad_op->SetAttr("axis", BOOST_GET_CONST(int, this->GetAttr("axis"))); + grad_op->SetAttr("flatten", + BOOST_GET_CONST(bool, this->GetAttr("flatten"))); grad_op->SetAttr("reverse", !BOOST_GET_CONST(bool, this->GetAttr("reverse"))); grad_op->SetAttr("exclusive", diff --git a/paddle/fluid/operators/cumsum_op.cu b/paddle/fluid/operators/cumsum_op.cu index 7ca5ba3289b..cff0a101e03 100644 --- a/paddle/fluid/operators/cumsum_op.cu +++ b/paddle/fluid/operators/cumsum_op.cu @@ -12,6 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include +#include +#include +#include #include "paddle/fluid/operators/cum_op.h" #include "paddle/fluid/platform/gpu_launch_param_config.h" @@ -251,34 +255,62 @@ class CumCUDAKernel : public framework::OpKernel { int axis = context.Attr("axis"); bool exclusive = context.Attr("exclusive"); bool reverse = context.Attr("reverse"); - auto in_dims = in->dims(); + auto out_dims = out->dims(); auto size = in->numel(); - if (axis == -1) { - axis = in_dims.size() - 1; + PADDLE_ENFORCE_EQ( + axis < out_dims.size() && axis >= (0 - out_dims.size()), true, + platform::errors::OutOfRange( + "Attr(axis) is out of range, It's expected " + "to be in range of [-%d, %d]. But received Attr(axis) = %d.", + out_dims.size(), out_dims.size() - 1, axis)); + if (axis < 0) { + axis += out_dims.size(); } - PADDLE_ENFORCE_LT( - axis, in_dims.size(), - platform::errors::InvalidArgument("axis(%d) should be less than the " - "dimension(%d) of the input tensor.", - axis, in_dims.size())); - - int scan_dim_size = in_dims[axis]; - bool optimize_condition = (axis == (in_dims.size() - 1)) ? true : false; + + T* out_data = out->mutable_data(context.GetPlace()); + const T* in_data = in->data(); + + // Use thrust for parallel acceleration when the input size is equal to the + // length of the ‘axis’ dimension. + if (size == out_dims[axis]) { + if (reverse) { + thrust::device_ptr dev_ptr = + thrust::device_pointer_cast(in_data); + thrust::device_vector vec(dev_ptr, dev_ptr + size); + if (exclusive) { + thrust::exclusive_scan(thrust::device, vec.rbegin(), vec.rend(), + out_data); + } else { + thrust::inclusive_scan(thrust::device, vec.rbegin(), vec.rend(), + out_data); + } + thrust::reverse(thrust::device, out_data, out_data + size); + } else { + if (exclusive) { + thrust::exclusive_scan(thrust::device, in_data, in_data + size, + out_data); + } else { + thrust::inclusive_scan(thrust::device, in_data, in_data + size, + out_data); + } + } + return; + } + + const int& scan_dim_size = out_dims[axis]; + bool optimize_condition = (axis == (out_dims.size() - 1)) ? true : false; int outer_dim_size = 1; int inner_dim_size = 1; // treat all dim index < axis as outer_dim_size for (size_t i = 0; i < axis; i++) { - outer_dim_size *= in_dims[i]; + outer_dim_size *= out_dims[i]; } // treat all dim index > axis as innner_dim_size - for (size_t i = axis + 1; i < in_dims.size(); i++) { - inner_dim_size *= in_dims[i]; + for (size_t i = axis + 1; i < out_dims.size(); i++) { + inner_dim_size *= out_dims[i]; } - T* out_data = out->mutable_data(context.GetPlace()); - const T* in_data = in->data(); - auto& dev_ctx = context.template device_context(); if (optimize_condition) { auto nextPowerOfTwo = [](int x) -> int { diff --git a/python/paddle/fluid/tests/unittests/test_cumsum_op.py b/python/paddle/fluid/tests/unittests/test_cumsum_op.py index a1a80bfdb54..c3283324bdc 100644 --- a/python/paddle/fluid/tests/unittests/test_cumsum_op.py +++ b/python/paddle/fluid/tests/unittests/test_cumsum_op.py @@ -17,9 +17,90 @@ from __future__ import print_function import unittest import numpy as np from op_test import OpTest +import paddle import paddle.fluid.core as core import paddle.fluid as fluid from paddle.fluid import compiler, Program, program_guard +from paddle.imperative import to_variable + + +class TestCumsumOp(unittest.TestCase): + def run_cases(self): + data_np = np.arange(12).reshape(3, 4) + data = to_variable(data_np) + + y = paddle.cumsum(data) + z = np.cumsum(data_np) + self.assertTrue(np.array_equal(z, y.numpy())) + + y = paddle.cumsum(data, axis=0) + z = np.cumsum(data_np, axis=0) + self.assertTrue(np.array_equal(z, y.numpy())) + + y = paddle.cumsum(data, axis=-1) + z = np.cumsum(data_np, axis=-1) + self.assertTrue(np.array_equal(z, y.numpy())) + + y = paddle.cumsum(data, dtype='float64') + self.assertTrue(y.dtype == core.VarDesc.VarType.FP64) + + y = paddle.cumsum(data, dtype=np.int32) + self.assertTrue(y.dtype == core.VarDesc.VarType.INT32) + + y = paddle.cumsum(data, axis=-2) + z = np.cumsum(data_np, axis=-2) + self.assertTrue(np.array_equal(z, y.numpy())) + + def run_static(self, use_gpu=False): + with fluid.program_guard(fluid.Program()): + data_np = np.random.random((100, 100)).astype(np.float32) + x = paddle.nn.data('X', [100, 100]) + y = paddle.cumsum(x) + y2 = paddle.cumsum(x, axis=0) + y3 = paddle.cumsum(x, axis=-1) + y4 = paddle.cumsum(x, dtype='float64') + y5 = paddle.cumsum(x, dtype=np.int32) + y6 = paddle.cumsum(x, axis=-2) + + place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + out = exe.run(feed={'X': data_np}, + fetch_list=[ + y.name, y2.name, y3.name, y4.name, y5.name, + y6.name + ]) + + z = np.cumsum(data_np) + self.assertTrue(np.allclose(z, out[0])) + z = np.cumsum(data_np, axis=0) + self.assertTrue(np.allclose(z, out[1])) + z = np.cumsum(data_np, axis=-1) + self.assertTrue(np.allclose(z, out[2])) + self.assertTrue(out[3].dtype == np.float64) + self.assertTrue(out[4].dtype == np.int32) + z = np.cumsum(data_np, axis=-2) + self.assertTrue(np.allclose(z, out[5])) + + def test_cpu(self): + with paddle.imperative.guard(paddle.fluid.CPUPlace()): + self.run_cases() + + self.run_static() + + def test_gpu(self): + if not fluid.core.is_compiled_with_cuda(): + return + with paddle.imperative.guard(paddle.fluid.CUDAPlace(0)): + self.run_cases() + + self.run_static(use_gpu=True) + + def test_name(self): + with fluid.program_guard(fluid.Program()): + x = paddle.nn.data('x', [3, 4]) + y = paddle.cumsum(x, name='out') + self.assertTrue('out' in y.name) class TestSumOp1(OpTest): diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index f8fa29757d8..c67ac474d47 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -21,7 +21,7 @@ from ..fluid import layers from ..fluid.framework import core, _varbase_creator, in_dygraph_mode, Variable from ..fluid.layer_helper import LayerHelper from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype -from ..fluid.layers.layer_function_generator import _generate_doc_string_, generate_activation_fn +from ..fluid.layers.layer_function_generator import _generate_doc_string_, generate_activation_fn, generate_layer_fn import sys # TODO: define math functions @@ -33,7 +33,6 @@ from ..fluid.layers import ceil #DEFINE_ALIAS from ..fluid.layers import cos #DEFINE_ALIAS from ..fluid.layers import sinh #DEFINE_ALIAS from ..fluid.layers import cosh #DEFINE_ALIAS -from ..fluid.layers import cumsum #DEFINE_ALIAS from ..fluid.layers import elementwise_add #DEFINE_ALIAS from ..fluid.layers import elementwise_div #DEFINE_ALIAS from ..fluid.layers import elementwise_floordiv #DEFINE_ALIAS @@ -1543,3 +1542,73 @@ ${comment} out = helper.create_variable_for_type_inference(dtype=x.dtype) helper.append_op(type="kron", inputs={"X": x, "Y": y}, outputs={"Out": out}) return out + + +def cumsum(x, axis=None, dtype=None, name=None): + """ + :alias_main: paddle.cumsum + :alias: paddle.cumsum,paddle.tensor.cumsum,paddle.tensor.math.cumsum + + The cumulative sum of the elements along a given axis. The first element of the result is the same of the first element of the input. + + Args: + x (Tensor): Input of cumsum operator, the Tensor needed to be cumsumed. + axis (int, optional): The dimension to accumulate along. -1 means the last dimension. The default (None) is to compute the cumsum over the flattened array. + dtype (str, optional): The data type of the output tensor, can be float32, float64, int32, int64. If specified, the input tensor is casted to dtype before the operation is performed. This is useful for preventing data type overflows. The default value is None. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor, the result of cumsum operator, output of cumsum operator. + + Examples: + .. code-block:: python + + import paddle + from paddle.imperative import to_variable + import numpy as np + + paddle.enable_imperative() + data_np = np.arange(12).reshape(3, 4) + data = to_variable(data_np) + + y = paddle.cumsum(data) + print(y.numpy()) + # [ 0 1 3 6 10 15 21 28 36 45 55 66] + + y = paddle.cumsum(data, axis=0) + print(y.numpy()) + # [[ 0 1 2 3] + # [ 4 6 8 10] + # [12 15 18 21]] + + y = paddle.cumsum(data, axis=-1) + print(y.numpy()) + # [[ 0 1 3 6] + # [ 4 9 15 22] + # [ 8 17 27 38]] + + y = paddle.cumsum(data, dtype='float64') + print(y.dtype) + # VarType.FP64 + """ + if axis is None: + flatten = True + else: + flatten = False + if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype): + x = layers.cast(x, dtype) + + if in_dygraph_mode(): + if axis is None: + return core.ops.cumsum(x, 'flatten', flatten) + else: + return core.ops.cumsum(x, 'axis', axis, 'flatten', flatten) + + check_type(x, 'x', (Variable), 'cumsum') + locals_var = locals().copy() + kwargs = dict() + for name, val in locals_var.items(): + if val is not None: + kwargs[name] = val + _cum_sum_ = generate_layer_fn('cumsum') + return _cum_sum_(**kwargs) -- GitLab