未验证 提交 bf2db646 编写于 作者: L LutaoChu 提交者: GitHub

fix cumsum op for API 2.0, optimize performance

update cumsum api and fix up the cumsum op
上级 9c611210
......@@ -36,25 +36,28 @@ class CumKernel : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
int axis = context.Attr<int>("axis");
bool exclusive = context.Attr<bool>("exclusive");
bool reverse = context.Attr<bool>("reverse");
auto x_dims = X.dims();
if (axis == -1) {
axis = x_dims.size() - 1;
auto out_dims = Out.dims();
PADDLE_ENFORCE_EQ(
axis < out_dims.size() && axis >= (0 - out_dims.size()), true,
platform::errors::OutOfRange(
"Attr(axis) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(axis) = %d.",
out_dims.size(), out_dims.size() - 1, axis));
if (axis < 0) {
axis += out_dims.size();
}
PADDLE_ENFORCE_LT(
axis, x_dims.size(),
platform::errors::InvalidArgument("axis(%d) should be less than the "
"dimension(%d) of the input tensor.",
axis, x_dims.size()));
Out.template mutable_data<T>(context.GetPlace());
int pre = 1;
int post = 1;
int mid = x_dims[axis];
int mid = out_dims[axis];
for (int i = 0; i < axis; ++i) {
pre *= x_dims[i];
pre *= out_dims[i];
}
for (int i = axis + 1; i < x_dims.size(); ++i) {
post *= x_dims[i];
for (int i = axis + 1; i < out_dims.size(); ++i) {
post *= out_dims[i];
}
auto x = framework::EigenVector<T>::Flatten(X);
......
......@@ -22,7 +22,14 @@ class CumOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
if (ctx->Attrs().Get<bool>("flatten")) {
ctx->SetOutputDim(
"Out",
framework::make_ddim({framework::product(ctx->GetInputDim("X"))}));
} else {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
ctx->ShareLoD("X", /*->*/ "Out");
}
};
......@@ -35,8 +42,11 @@ class CumsumOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("axis",
"The dimension to accumulate along. -1 means the last "
"dimension [default -1].")
.SetDefault(-1)
.EqualGreaterThan(-1);
.SetDefault(-1);
AddAttr<bool>("flatten",
"Whether to compute the cumsum over the flattened array. "
"[default false].")
.SetDefault(false);
AddAttr<bool>("exclusive",
"Whether to perform exclusive cumsum. [default false].")
.SetDefault(false);
......@@ -63,6 +73,8 @@ class CumsumGradMaker : public framework::SingleGradOpMaker<T> {
grad_op->SetInput("X", this->OutputGrad("Out"));
grad_op->SetOutput("Out", this->InputGrad("X"));
grad_op->SetAttr("axis", BOOST_GET_CONST(int, this->GetAttr("axis")));
grad_op->SetAttr("flatten",
BOOST_GET_CONST(bool, this->GetAttr("flatten")));
grad_op->SetAttr("reverse",
!BOOST_GET_CONST(bool, this->GetAttr("reverse")));
grad_op->SetAttr("exclusive",
......
......@@ -12,6 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/reverse.h>
#include <thrust/scan.h>
#include "paddle/fluid/operators/cum_op.h"
#include "paddle/fluid/platform/gpu_launch_param_config.h"
......@@ -251,34 +255,62 @@ class CumCUDAKernel : public framework::OpKernel<T> {
int axis = context.Attr<int>("axis");
bool exclusive = context.Attr<bool>("exclusive");
bool reverse = context.Attr<bool>("reverse");
auto in_dims = in->dims();
auto out_dims = out->dims();
auto size = in->numel();
if (axis == -1) {
axis = in_dims.size() - 1;
PADDLE_ENFORCE_EQ(
axis < out_dims.size() && axis >= (0 - out_dims.size()), true,
platform::errors::OutOfRange(
"Attr(axis) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(axis) = %d.",
out_dims.size(), out_dims.size() - 1, axis));
if (axis < 0) {
axis += out_dims.size();
}
PADDLE_ENFORCE_LT(
axis, in_dims.size(),
platform::errors::InvalidArgument("axis(%d) should be less than the "
"dimension(%d) of the input tensor.",
axis, in_dims.size()));
int scan_dim_size = in_dims[axis];
bool optimize_condition = (axis == (in_dims.size() - 1)) ? true : false;
T* out_data = out->mutable_data<T>(context.GetPlace());
const T* in_data = in->data<T>();
// Use thrust for parallel acceleration when the input size is equal to the
// length of the ‘axis’ dimension.
if (size == out_dims[axis]) {
if (reverse) {
thrust::device_ptr<const T> dev_ptr =
thrust::device_pointer_cast(in_data);
thrust::device_vector<T> vec(dev_ptr, dev_ptr + size);
if (exclusive) {
thrust::exclusive_scan(thrust::device, vec.rbegin(), vec.rend(),
out_data);
} else {
thrust::inclusive_scan(thrust::device, vec.rbegin(), vec.rend(),
out_data);
}
thrust::reverse(thrust::device, out_data, out_data + size);
} else {
if (exclusive) {
thrust::exclusive_scan(thrust::device, in_data, in_data + size,
out_data);
} else {
thrust::inclusive_scan(thrust::device, in_data, in_data + size,
out_data);
}
}
return;
}
const int& scan_dim_size = out_dims[axis];
bool optimize_condition = (axis == (out_dims.size() - 1)) ? true : false;
int outer_dim_size = 1;
int inner_dim_size = 1;
// treat all dim index < axis as outer_dim_size
for (size_t i = 0; i < axis; i++) {
outer_dim_size *= in_dims[i];
outer_dim_size *= out_dims[i];
}
// treat all dim index > axis as innner_dim_size
for (size_t i = axis + 1; i < in_dims.size(); i++) {
inner_dim_size *= in_dims[i];
for (size_t i = axis + 1; i < out_dims.size(); i++) {
inner_dim_size *= out_dims[i];
}
T* out_data = out->mutable_data<T>(context.GetPlace());
const T* in_data = in->data<T>();
auto& dev_ctx = context.template device_context<DeviceContext>();
if (optimize_condition) {
auto nextPowerOfTwo = [](int x) -> int {
......
......@@ -17,9 +17,90 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
from paddle.imperative import to_variable
class TestCumsumOp(unittest.TestCase):
def run_cases(self):
data_np = np.arange(12).reshape(3, 4)
data = to_variable(data_np)
y = paddle.cumsum(data)
z = np.cumsum(data_np)
self.assertTrue(np.array_equal(z, y.numpy()))
y = paddle.cumsum(data, axis=0)
z = np.cumsum(data_np, axis=0)
self.assertTrue(np.array_equal(z, y.numpy()))
y = paddle.cumsum(data, axis=-1)
z = np.cumsum(data_np, axis=-1)
self.assertTrue(np.array_equal(z, y.numpy()))
y = paddle.cumsum(data, dtype='float64')
self.assertTrue(y.dtype == core.VarDesc.VarType.FP64)
y = paddle.cumsum(data, dtype=np.int32)
self.assertTrue(y.dtype == core.VarDesc.VarType.INT32)
y = paddle.cumsum(data, axis=-2)
z = np.cumsum(data_np, axis=-2)
self.assertTrue(np.array_equal(z, y.numpy()))
def run_static(self, use_gpu=False):
with fluid.program_guard(fluid.Program()):
data_np = np.random.random((100, 100)).astype(np.float32)
x = paddle.nn.data('X', [100, 100])
y = paddle.cumsum(x)
y2 = paddle.cumsum(x, axis=0)
y3 = paddle.cumsum(x, axis=-1)
y4 = paddle.cumsum(x, dtype='float64')
y5 = paddle.cumsum(x, dtype=np.int32)
y6 = paddle.cumsum(x, axis=-2)
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
out = exe.run(feed={'X': data_np},
fetch_list=[
y.name, y2.name, y3.name, y4.name, y5.name,
y6.name
])
z = np.cumsum(data_np)
self.assertTrue(np.allclose(z, out[0]))
z = np.cumsum(data_np, axis=0)
self.assertTrue(np.allclose(z, out[1]))
z = np.cumsum(data_np, axis=-1)
self.assertTrue(np.allclose(z, out[2]))
self.assertTrue(out[3].dtype == np.float64)
self.assertTrue(out[4].dtype == np.int32)
z = np.cumsum(data_np, axis=-2)
self.assertTrue(np.allclose(z, out[5]))
def test_cpu(self):
with paddle.imperative.guard(paddle.fluid.CPUPlace()):
self.run_cases()
self.run_static()
def test_gpu(self):
if not fluid.core.is_compiled_with_cuda():
return
with paddle.imperative.guard(paddle.fluid.CUDAPlace(0)):
self.run_cases()
self.run_static(use_gpu=True)
def test_name(self):
with fluid.program_guard(fluid.Program()):
x = paddle.nn.data('x', [3, 4])
y = paddle.cumsum(x, name='out')
self.assertTrue('out' in y.name)
class TestSumOp1(OpTest):
......
......@@ -21,7 +21,7 @@ from ..fluid import layers
from ..fluid.framework import core, _varbase_creator, in_dygraph_mode, Variable
from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype
from ..fluid.layers.layer_function_generator import _generate_doc_string_, generate_activation_fn
from ..fluid.layers.layer_function_generator import _generate_doc_string_, generate_activation_fn, generate_layer_fn
import sys
# TODO: define math functions
......@@ -33,7 +33,6 @@ from ..fluid.layers import ceil #DEFINE_ALIAS
from ..fluid.layers import cos #DEFINE_ALIAS
from ..fluid.layers import sinh #DEFINE_ALIAS
from ..fluid.layers import cosh #DEFINE_ALIAS
from ..fluid.layers import cumsum #DEFINE_ALIAS
from ..fluid.layers import elementwise_add #DEFINE_ALIAS
from ..fluid.layers import elementwise_div #DEFINE_ALIAS
from ..fluid.layers import elementwise_floordiv #DEFINE_ALIAS
......@@ -1543,3 +1542,73 @@ ${comment}
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type="kron", inputs={"X": x, "Y": y}, outputs={"Out": out})
return out
def cumsum(x, axis=None, dtype=None, name=None):
"""
:alias_main: paddle.cumsum
:alias: paddle.cumsum,paddle.tensor.cumsum,paddle.tensor.math.cumsum
The cumulative sum of the elements along a given axis. The first element of the result is the same of the first element of the input.
Args:
x (Tensor): Input of cumsum operator, the Tensor needed to be cumsumed.
axis (int, optional): The dimension to accumulate along. -1 means the last dimension. The default (None) is to compute the cumsum over the flattened array.
dtype (str, optional): The data type of the output tensor, can be float32, float64, int32, int64. If specified, the input tensor is casted to dtype before the operation is performed. This is useful for preventing data type overflows. The default value is None.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor, the result of cumsum operator, output of cumsum operator.
Examples:
.. code-block:: python
import paddle
from paddle.imperative import to_variable
import numpy as np
paddle.enable_imperative()
data_np = np.arange(12).reshape(3, 4)
data = to_variable(data_np)
y = paddle.cumsum(data)
print(y.numpy())
# [ 0 1 3 6 10 15 21 28 36 45 55 66]
y = paddle.cumsum(data, axis=0)
print(y.numpy())
# [[ 0 1 2 3]
# [ 4 6 8 10]
# [12 15 18 21]]
y = paddle.cumsum(data, axis=-1)
print(y.numpy())
# [[ 0 1 3 6]
# [ 4 9 15 22]
# [ 8 17 27 38]]
y = paddle.cumsum(data, dtype='float64')
print(y.dtype)
# VarType.FP64
"""
if axis is None:
flatten = True
else:
flatten = False
if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype):
x = layers.cast(x, dtype)
if in_dygraph_mode():
if axis is None:
return core.ops.cumsum(x, 'flatten', flatten)
else:
return core.ops.cumsum(x, 'axis', axis, 'flatten', flatten)
check_type(x, 'x', (Variable), 'cumsum')
locals_var = locals().copy()
kwargs = dict()
for name, val in locals_var.items():
if val is not None:
kwargs[name] = val
_cum_sum_ = generate_layer_fn('cumsum')
return _cum_sum_(**kwargs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册