diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h
index 60644820df7cd4133c5fd8f24fe693245d68a5f3..e3b45d05d85e9da0d1112fe7dabd06f10225166d 100644
--- a/paddle/fluid/framework/data_type.h
+++ b/paddle/fluid/framework/data_type.h
@@ -47,6 +47,12 @@ struct DataTypeTrait {
   _ForEachDataTypeHelper_(callback, int16_t, INT16); \
   _ForEachDataTypeHelper_(callback, int8_t, INT8)
 
+#define _ForEachDataTypeSmall_(callback)           \
+  _ForEachDataTypeHelper_(callback, float, FP32);  \
+  _ForEachDataTypeHelper_(callback, double, FP64); \
+  _ForEachDataTypeHelper_(callback, int, INT32);   \
+  _ForEachDataTypeHelper_(callback, int64_t, INT64);
+
 #define DefineDataTypeTrait(cpp_type, proto_type) \
   template <>                                     \
   struct DataTypeTrait<cpp_type> {                \
@@ -75,6 +81,20 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
   PADDLE_THROW("Not supported %d", type);
 }
 
+template <typename Visitor>
+inline void VisitDataTypeSmall(proto::VarType::Type type, Visitor visitor) {
+#define VisitDataTypeCallbackSmall(cpp_type, proto_type) \
+  do {                                                   \
+    if (type == proto_type) {                            \
+      visitor.template apply<cpp_type>();                \
+      return;                                            \
+    }                                                    \
+  } while (0)
+
+  _ForEachDataTypeSmall_(VisitDataTypeCallbackSmall);
+#undef VisitDataTypeCallbackSmall
+}
+
 extern std::string DataTypeToString(const proto::VarType::Type type);
 extern size_t SizeOfType(proto::VarType::Type type);
 inline std::ostream& operator<<(std::ostream& out,
diff --git a/paddle/fluid/operators/reduce_ops/cub_reduce.h b/paddle/fluid/operators/reduce_ops/cub_reduce.h
index 7dc78270c21866866851a396c234bd6564064be4..66fea71c635441571e89f34f5ee52865eb431cb9 100644
--- a/paddle/fluid/operators/reduce_ops/cub_reduce.h
+++ b/paddle/fluid/operators/reduce_ops/cub_reduce.h
@@ -65,7 +65,8 @@ __global__ void ReduceKernel2D(const Tx* x, Ty* y, ReduceOp reducer,
   int idx_y = threadIdx.x;
   Ty reduce_var = init;
   for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += BlockDim)
-    reduce_var = reducer(reduce_var, transformer(x[idx_x + idx_y]));
+    reduce_var =
+        reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x + idx_y])));
   __syncthreads();
 
   reduce_var =
@@ -112,7 +113,8 @@ __global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer,
     int idx_x = 0;
     for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]);
-    reduce_var = static_cast<Ty>(reducer(reduce_var, transformer(x[idx_x])));
+    reduce_var = static_cast<Ty>(
+        reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x]))));
   }
   __syncthreads();
 
@@ -341,5 +343,35 @@ void TensorReduce(const framework::Tensor& x, framework::Tensor* y,
 #undef CUB_BLOCK_DIM_CASE
 }
 
+template <typename Tx, typename ReduceOp, typename TransformOp>
+struct TensorReduceFunctor {
+  const framework::Tensor& x;
+  framework::Tensor* y;
+  std::vector<int> origin_reduce_dims;
+  const double& init;
+  const ReduceOp& reducer;
+  const TransformOp& transformer;
+  cudaStream_t stream;
+  TensorReduceFunctor(const framework::Tensor& x, framework::Tensor* y,
+                      std::vector<int> origin_reduce_dims, const double& init,
+                      const ReduceOp& reducer, const TransformOp& transformer,
+                      cudaStream_t stream)
+      : x(x),
+        y(y),
+        origin_reduce_dims(origin_reduce_dims),
+        init(init),
+        reducer(reducer),
+        transformer(transformer),
+        stream(stream) {}
+
+  template <typename Ty>
+  void apply() const {
+    const Ty& init_cast = static_cast<Ty>(init);
+    TensorReduce<Tx, Ty, ReduceOp, TransformOp>(
+        x, y, origin_reduce_dims, init_cast, reducer, transformer, stream);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
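The VisitDataTypeSmall / TensorReduceFunctor pair above is a runtime-to-compile-time dispatch: the visitor matches the runtime proto dtype against the four entries of _ForEachDataTypeSmall_ and calls the functor's templated apply() with the corresponding C++ type. A minimal Python analogue of that dispatch, as a sketch with illustrative names only (none of these are Paddle APIs):

    import numpy as np

    # Mirrors _ForEachDataTypeSmall_: the four dtypes the visitor accepts.
    SMALL_DTYPES = {"FP32": np.float32, "FP64": np.float64,
                    "INT32": np.int32, "INT64": np.int64}

    class SumFunctor:
        """Plays the role of TensorReduceFunctor: captures the arguments,
        then runs the reduction for whichever output type is visited."""
        def __init__(self, x, init):
            self.x, self.init = x, init

        def apply(self, out_type):
            # init is carried as a double and cast per visited type,
            # mirroring init_cast in TensorReduceFunctor::apply().
            return self.x.astype(out_type).sum() + out_type(self.init)

    def visit_data_type_small(dtype_tag, functor):
        if dtype_tag not in SMALL_DTYPES:
            raise ValueError("Not supported %s" % dtype_tag)
        return functor.apply(SMALL_DTYPES[dtype_tag])

    out = visit_data_type_small("FP64",
                                SumFunctor(np.arange(6, dtype=np.float32), 0.0))
    print(out, out.dtype)  # 15.0 float64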
diff --git a/paddle/fluid/operators/reduce_ops/reduce_all_op.cc b/paddle/fluid/operators/reduce_ops/reduce_all_op.cc
index 49d6e72988ee00edc947e1a3fe8bc16067627193..30265b3cc71fc6c587a7f4c716529962e1556f45 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_all_op.cc
+++ b/paddle/fluid/operators/reduce_ops/reduce_all_op.cc
@@ -18,5 +18,5 @@
 // compare and logical ops
 REGISTER_REDUCE_OP_WITHOUT_GRAD(reduce_all, UseInputPlace);
 REGISTER_OP_CPU_KERNEL(reduce_all,
-                       ops::ReduceKernel<paddle::platform::CPUDeviceContext,
-                                         bool, ops::AllFunctor>);
+                       ops::BoolReduceKernel<paddle::platform::CPUDeviceContext,
+                                             bool, ops::AllFunctor>);
diff --git a/paddle/fluid/operators/reduce_ops/reduce_all_op.cu b/paddle/fluid/operators/reduce_ops/reduce_all_op.cu
index bd94ba263d957d0d65506ecd802bf43add6e2fb4..89f3345fcbe42deb572700cb12827d79cb22d3d3 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_all_op.cu
+++ b/paddle/fluid/operators/reduce_ops/reduce_all_op.cu
@@ -14,6 +14,6 @@
 
 #include "paddle/fluid/operators/reduce_ops/reduce_all_op.h"
 
-REGISTER_OP_CUDA_KERNEL(reduce_all,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          bool, ops::AllFunctor>);
+REGISTER_OP_CUDA_KERNEL(
+    reduce_all, ops::BoolReduceKernel<paddle::platform::CUDADeviceContext,
+                                      bool, ops::AllFunctor>);
diff --git a/paddle/fluid/operators/reduce_ops/reduce_any_op.cc b/paddle/fluid/operators/reduce_ops/reduce_any_op.cc
index 516d3183fd614128deec3fefaed4df089305c6c0..cbc18f18b8e5534b37294dbfb8630bac906e8066 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_any_op.cc
+++ b/paddle/fluid/operators/reduce_ops/reduce_any_op.cc
@@ -18,5 +18,5 @@
 // compare and logical ops
 REGISTER_REDUCE_OP_WITHOUT_GRAD(reduce_any, UseInputPlace);
 REGISTER_OP_CPU_KERNEL(reduce_any,
-                       ops::ReduceKernel<paddle::platform::CPUDeviceContext,
-                                         bool, ops::AnyFunctor>);
+                       ops::BoolReduceKernel<paddle::platform::CPUDeviceContext,
+                                             bool, ops::AnyFunctor>);
diff --git a/paddle/fluid/operators/reduce_ops/reduce_any_op.cu b/paddle/fluid/operators/reduce_ops/reduce_any_op.cu
index 66f0c9997ea1e27cf172a6839a68d2eb23395c4d..c0f94098a351ea9042e44b8550b305bb0f9d74c6 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_any_op.cu
+++ b/paddle/fluid/operators/reduce_ops/reduce_any_op.cu
@@ -14,6 +14,6 @@
 
 #include "paddle/fluid/operators/reduce_ops/reduce_any_op.h"
 
-REGISTER_OP_CUDA_KERNEL(reduce_any,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          bool, ops::AnyFunctor>);
+REGISTER_OP_CUDA_KERNEL(
+    reduce_any, ops::BoolReduceKernel<paddle::platform::CUDADeviceContext,
+                                      bool, ops::AnyFunctor>);
diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h
index d17e6b65cdf6e0bee70ce02ae8103e135980b200..a40df3a82716e189237f1ad31f64a2633671ab12 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_op.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_op.h
@@ -18,6 +18,8 @@ limitations under the License. */
 #include <string>
 #include <vector>
+#include "paddle/fluid/framework/data_type_transform.h"
+#include "paddle/fluid/operators/cast_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op_function.h"
 
 namespace paddle {
@@ -25,27 +27,111 @@ namespace operators {
 
 #define HANDLE_DIM(NDIM, RDIM)                                            \
   if (ndim == NDIM && rdim == RDIM) {                                     \
-    ReduceFunctor<DeviceContext, T, NDIM, RDIM, Functor>(                 \
+    ReduceFunctor<DeviceContext, OutT, NDIM, RDIM, Functor>(              \
         context.template device_context<DeviceContext>(), *input, output, \
         dims, keep_dim);                                                  \
   }
 
+using Tensor = framework::Tensor;
+
+template <typename DeviceContext, typename T, typename Functor>
+struct ReduceKernelFunctor {
+  const Tensor* input;
+  Tensor* output;
+  std::vector<int> dims;
+  bool keep_dim;
+  bool reduce_all;
+  const framework::ExecutionContext& context;
+  ReduceKernelFunctor(const Tensor* input, Tensor* output,
+                      const std::vector<int>& dims, bool keep_dim,
+                      bool reduce_all,
+                      const framework::ExecutionContext& context)
+      : input(input),
+        output(output),
+        dims(dims),
+        keep_dim(keep_dim),
+        reduce_all(reduce_all),
+        context(context) {}
+
+  template <typename OutT>
+  void apply() const {
+    output->mutable_data<OutT>(context.GetPlace());
+    if (reduce_all) {
+      // Flatten and reduce 1-D tensor
+      auto x = EigenVector<OutT>::Flatten(*input);
+      auto out = EigenScalar<OutT>::From(*output);
+      auto& place =
+          *context.template device_context<DeviceContext>().eigen_device();
+      auto reduce_dim = Eigen::array<int, 1>({{0}});
+      Functor functor;
+      functor(place, &x, &out, reduce_dim);
+    } else {
+      int ndim = input->dims().size();
+      int rdim = dims.size();
+      HANDLE_DIM(4, 3);
+      HANDLE_DIM(4, 2);
+      HANDLE_DIM(4, 1);
+      HANDLE_DIM(3, 2);
+      HANDLE_DIM(3, 1);
+      HANDLE_DIM(2, 1);
+      HANDLE_DIM(1, 1);
+    }
+  }
+};
 template <typename DeviceContext, typename T, typename Functor>
 class ReduceKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    bool reduce_all = context.Attr<bool>("reduce_all");
+    auto* output = context.Output<Tensor>("Out");
+    auto dims = context.Attr<std::vector<int>>("dim");
+    bool keep_dim = context.Attr<bool>("keep_dim");
+    int out_dtype = context.Attr<int>("out_dtype");
+    framework::proto::VarType::Type cast_out_dtype;
+
+    if (out_dtype < 0) {
+      auto* cast_input = context.Input<Tensor>("X");
+      cast_out_dtype =
+          static_cast<framework::proto::VarType::Type>(cast_input->type());
+      framework::VisitDataType(
+          cast_out_dtype,
+          ReduceKernelFunctor<DeviceContext, T, Functor>(
+              cast_input, output, dims, keep_dim, reduce_all, context));
+    } else {
+      Tensor tmp_tensor;
+      cast_out_dtype = static_cast<framework::proto::VarType::Type>(out_dtype);
+      auto* input = context.Input<Tensor>("X");
+
+      tmp_tensor.Resize(input->dims());
+      framework::VisitDataType(
+          cast_out_dtype,
+          CastOpFunctor<DeviceContext, T>(
+              input, &tmp_tensor,
+              context.template device_context<DeviceContext>()));
+      framework::VisitDataType(
+          cast_out_dtype,
+          ReduceKernelFunctor<DeviceContext, T, Functor>(
+              &tmp_tensor, output, dims, keep_dim, reduce_all, context));
+    }
+  }
+};
+
+template <typename DeviceContext, typename OutT, typename Functor>
+class BoolReduceKernel : public framework::OpKernel<OutT> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     bool reduce_all = context.Attr<bool>("reduce_all");
     auto* input = context.Input<Tensor>("X");
     auto* output = context.Output<Tensor>("Out");
-    output->mutable_data<T>(context.GetPlace());
+    output->mutable_data<OutT>(context.GetPlace());
 
     auto dims = context.Attr<std::vector<int>>("dim");
     bool keep_dim = context.Attr<bool>("keep_dim");
 
     if (reduce_all) {
       // Flatten and reduce 1-D tensor
-      auto x = EigenVector<T>::Flatten(*input);
-      auto out = EigenScalar<T>::From(*output);
+      auto x = EigenVector<OutT>::Flatten(*input);
+      auto out = EigenScalar<OutT>::From(*output);
       auto& place =
           *context.template device_context<DeviceContext>().eigen_device();
       auto reduce_dim = Eigen::array<int, 1>({{0}});
@@ -74,18 +160,17 @@ class ReduceKernel : public framework::OpKernel<T> {
     }
   }
 };
-
 template <typename DeviceContext, typename T, typename Functor,
           bool kNoNeedBufferX = false, bool kNoNeedBufferY = false>
 class ReduceGradKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& context) const override {
+  void ComputeFromInput(const Tensor* input2,
+                        const framework::ExecutionContext& context) const {
     bool reduce_all = context.Attr<bool>("reduce_all");
     auto dims = context.Attr<std::vector<int>>("dim");
-
     auto* input0 = context.Input<Tensor>("X");
     auto* input1 = context.Input<Tensor>("Out");
-    auto* input2 = context.Input<Tensor>(framework::GradVarName("Out"));
+
     auto* output = context.Output<Tensor>(framework::GradVarName("X"));
     output->mutable_data<T>(context.GetPlace());
 
@@ -152,6 +237,26 @@ class ReduceGradKernel : public framework::OpKernel<T> {
       }
     }
   }
+
+  void Compute(const framework::ExecutionContext& context) const override {
+    int in_dtype = context.Attr<int>("in_dtype");
+    if (in_dtype >= 0) {
+      Tensor tmp_tensor;
+      auto* pre_input = context.Input<Tensor>(framework::GradVarName("Out"));
+      auto in_kernel_type =
+          framework::OpKernelType(pre_input->type(), context.GetPlace());
+      auto out_kernel_type = framework::OpKernelType(
+          static_cast<framework::proto::VarType::Type>(in_dtype),
+          context.GetPlace());
+      framework::TransDataType(in_kernel_type, out_kernel_type, *pre_input,
+                               &tmp_tensor);
+      ComputeFromInput(&tmp_tensor, context);
+
+    } else {
+      auto* input2 = context.Input<Tensor>(framework::GradVarName("Out"));
+      ComputeFromInput(input2, context);
+    }
+  }
 };
 
 class ReduceOp : public framework::OperatorWithKernel {
@@ -267,6 +372,12 @@ class ReduceGradOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
+    int in_dtype = ctx.Attr<int>("in_dtype");
+    if (in_dtype >= 0) {
+      return framework::OpKernelType(
+          static_cast<framework::proto::VarType::Type>(in_dtype),
+          ctx.GetPlace());
+    }
     return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                        ctx, framework::GradVarName("Out")),
                                    ctx.GetPlace());
@@ -295,6 +406,16 @@ class ReduceOpMaker : public framework::OpProtoAndCheckerMaker {
               "(bool, default false) "
               "If true, output a scalar reduced along all dimensions.")
         .SetDefault(false);
+    AddAttr<int>("in_dtype",
+                 "(int, default -1)"
+                 "The dtype of input, default value is -1, users should not "
+                 "set this value.")
+        .SetDefault(-1);
+    AddAttr<int>(
+        "out_dtype",
+        "(int, default -1)"
+        "The dtype of output, default value is -1, the dtype is the same as "
+        "the input.")
+        .SetDefault(-1);
     AddComment(string::Sprintf(R"DOC(
 %s Operator.
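ReduceKernel::Compute above now takes one of two paths: with out_dtype unset (< 0) the reduction runs in the input's own dtype, while with out_dtype set the input is first cast into a temporary tensor (CastOpFunctor) and then reduced, so accumulation happens at the requested precision. A NumPy sketch of the same control flow (illustrative only, not the actual kernel):

    import numpy as np

    def reduce_sum(x, dim=None, keep_dim=False, out_dtype=None):
        # out_dtype=None corresponds to out_dtype < 0 in the C++ kernel.
        if out_dtype is not None:
            x = x.astype(out_dtype)  # the CastOpFunctor step into tmp_tensor
        return x.sum(axis=dim, keepdims=keep_dim)

    x = np.random.rand(6, 2, 10).astype(np.float32)
    print(reduce_sum(x, dim=1).dtype)                        # float32
    print(reduce_sum(x, dim=1, out_dtype=np.float64).dtype)  # float64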
diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
index b4ac0714254a0a66269f54a3713009511387ef32..0eb6f2e66ee6a5a09badf56f861cd4b31149d472 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
+++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
@@ -35,9 +35,34 @@ class ReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
     op->SetAttrMap(this->Attrs());
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
   }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const {
+    int in_dtype = ctx.Attr<int>("in_dtype");
+    if (in_dtype >= 0) {
+      return framework::OpKernelType(
+          static_cast<framework::proto::VarType::Type>(in_dtype),
+          ctx.GetPlace());
+    }
+    return framework::OpKernelType(
+        framework::OperatorWithKernel::IndicateVarDataType(
+            ctx, framework::GradVarName("Out")),
+        ctx.GetPlace());
+  }
 };
 
 DECLARE_NO_NEED_BUFFER_VARS_INFERER(ReduceSumGradNoNeedBufferVarInference,
                                     "X");
+
+class ReduceSumVarTypeInference : public paddle::framework::VarTypeInference {
+ public:
+  void operator()(paddle::framework::InferVarTypeContext* ctx) const override {
+    auto data_type = static_cast<paddle::framework::proto::VarType::Type>(
+        boost::get<int>(ctx->GetAttr("out_dtype")));
+    if (data_type >= 0) {
+      auto& out_var_name = ctx->Output("Out").front();
+      ctx->SetDataType(out_var_name, data_type);
+    }
+  }
+};
 }  // namespace operators
 }  // namespace paddle
@@ -49,6 +74,7 @@ class ReduceSumOpMaker : public ops::ReduceOpMaker {
 };
 
 REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, ReduceSumOpMaker,
+                  ops::ReduceSumVarTypeInference,
                   ops::ReduceSumOpGradMaker<paddle::framework::OpDesc>,
                   ops::ReduceSumOpGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp,
diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu
index 9051740e83aabd783750e8f415da09921608e470..e64845a4f74e34e9e3835ed111798b9a89ea2bc7 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu
+++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu
@@ -32,6 +32,7 @@ class ReduceSumKernel : public framework::OpKernel<T> {
     bool reduce_all = context.Attr<bool>("reduce_all");
     auto* input = context.Input<Tensor>("X");
     auto* output = context.Output<Tensor>("Out");
+    auto out_dtype = context.Attr<int>("out_dtype");
     auto dims = context.Attr<std::vector<int>>("dim");
     bool keep_dim = context.Attr<bool>("keep_dim");
 
@@ -52,9 +53,17 @@ class ReduceSumKernel : public framework::OpKernel<T> {
     }
 
     auto stream = context.cuda_device_context().stream();
-    TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
-        *input, output, reduce_dims, static_cast<T>(0), cub::Sum(),
-        IdentityFunctor<T>(), stream);
+    if (out_dtype >= 0) {
+      framework::VisitDataTypeSmall(
+          static_cast<framework::proto::VarType::Type>(out_dtype),
+          TensorReduceFunctor<T, cub::Sum, IdentityFunctor<T>>(
+              *input, output, reduce_dims, static_cast<double>(0.0), cub::Sum(),
+              IdentityFunctor<T>(), stream));
+    } else {
+      TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
+          *input, output, reduce_dims, static_cast<T>(0), cub::Sum(),
+          IdentityFunctor<T>(), stream);
+    }
   }
 };
diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.h b/paddle/fluid/operators/reduce_ops/reduce_sum_op.h
index ceaba30b01fe3c7f9a7eebbc88f6acccc4ce4586..7f61794fbb11b180e6d289a9aa3f7c07fe2ade54 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.h
@@ -26,52 +26,73 @@ template <typename DeviceContext, typename T, typename Functor>
 class ReduceSumGradKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& context) const override {
+  void ComputeFromInput(const Tensor* input2,
+                        const framework::ExecutionContext& context) const {
     auto dims = context.Attr<std::vector<int>>("dim");
-    if (context.GetPlace().type() == typeid(platform::CPUPlace) &&
-        dims.size() == 1) {
-      auto* input0 = context.Input<Tensor>("X");
-      auto* input2 = context.Input<Tensor>(framework::GradVarName("Out"));
-      auto* output = context.Output<Tensor>(framework::GradVarName("X"));
-      output->mutable_data<T>(context.GetPlace());
-      const auto* input2_d = input2->data<T>();
-      auto* output_d = output->data<T>();
+    auto* input0 = context.Input<Tensor>("X");
 
-      // handle reduce_all
-      if (input2->dims().size() == 1 && input2->dims()[0] == 1) {
-        for (int64_t i = 0; i < framework::product(input0->dims()); ++i) {
-          output_d[i] = input2_d[0];
-        }
-        return;
-      }
+    auto* output = context.Output<Tensor>(framework::GradVarName("X"));
+    output->mutable_data<T>(context.GetPlace());
+    const auto* input2_d = input2->data<T>();
+    auto* output_d = output->data<T>();
 
-      // handle reduce by one dimension
-      int reduce_dim_index = dims[0];
-      if (reduce_dim_index < 0) {
-        reduce_dim_index += input0->dims().size();
+    // handle reduce_all
+    if (input2->dims().size() == 1 && input2->dims()[0] == 1) {
+      for (int64_t i = 0; i < framework::product(input0->dims()); ++i) {
+        output_d[i] = input2_d[0];
       }
+      return;
+    }
 
-      auto& input_dim = input0->dims();
-      int64_t before_dim = 1;
-      for (int i = 0; i < reduce_dim_index; ++i) {
-        before_dim *= input_dim[i];
-      }
-      int64_t reduce_dim = input_dim[reduce_dim_index];
-      int64_t after_dim = 1;
-      for (int i = reduce_dim_index + 1; i < input_dim.size(); ++i) {
-        after_dim *= input_dim[i];
-      }
-      for (int64_t i = 0; i < before_dim; ++i) {
-        for (int64_t j = 0; j < reduce_dim; ++j) {
-          for (int64_t k = 0; k < after_dim; ++k) {
-            output_d[i * reduce_dim * after_dim + j * after_dim + k] =
-                input2_d[i * after_dim + k];
-          }
+    // handle reduce by one dimension
+    int reduce_dim_index = dims[0];
+    if (reduce_dim_index < 0) {
+      reduce_dim_index += input0->dims().size();
+    }
+
+    auto& input_dim = input0->dims();
+    int64_t before_dim = 1;
+    for (int i = 0; i < reduce_dim_index; ++i) {
+      before_dim *= input_dim[i];
+    }
+    int64_t reduce_dim = input_dim[reduce_dim_index];
+    int64_t after_dim = 1;
+    for (int i = reduce_dim_index + 1; i < input_dim.size(); ++i) {
+      after_dim *= input_dim[i];
+    }
+    for (int64_t i = 0; i < before_dim; ++i) {
+      for (int64_t j = 0; j < reduce_dim; ++j) {
+        for (int64_t k = 0; k < after_dim; ++k) {
+          output_d[i * reduce_dim * after_dim + j * after_dim + k] =
+              input2_d[i * after_dim + k];
         }
       }
-      return;
     }
+  }
 
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto dims = context.Attr<std::vector<int>>("dim");
+    if (context.GetPlace().type() == typeid(platform::CPUPlace) &&
+        dims.size() == 1) {
+      int in_dtype = context.Attr<int>("in_dtype");
+
+      if (in_dtype >= 0) {
+        Tensor tmp_tensor;
+        auto* pre_input = context.Input<Tensor>(framework::GradVarName("Out"));
+        auto in_kernel_type =
+            framework::OpKernelType(pre_input->type(), context.GetPlace());
+        auto out_kernel_type = framework::OpKernelType(
+            static_cast<framework::proto::VarType::Type>(in_dtype),
+            context.GetPlace());
+        framework::TransDataType(in_kernel_type, out_kernel_type, *pre_input,
+                                 &tmp_tensor);
+        ComputeFromInput(&tmp_tensor, context);
+      } else {
+        auto* input2 = context.Input<Tensor>(framework::GradVarName("Out"));
+        ComputeFromInput(input2, context);
+      }
+      return;
+    }
     // default use Eigen broadcast
     ReduceGradKernel<DeviceContext, T, Functor> kernel;
     kernel.Compute(context);
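The grad kernels mirror the forward cast in reverse: when in_dtype is set, Out@GRAD arrives in out_dtype and is transformed back to the original input dtype (TransDataType) before being broadcast over the reduced dimensions. A NumPy sketch of that flow (illustrative only):

    import numpy as np

    def reduce_sum_grad(x, grad_out, in_dtype=None):
        # in_dtype=None corresponds to in_dtype < 0 in the C++ kernel.
        if in_dtype is not None:
            grad_out = grad_out.astype(in_dtype)  # the TransDataType step
        # grad of sum: broadcast Out@GRAD back to the input's shape
        return np.broadcast_to(grad_out, x.shape)

    x = np.ones((4, 3), dtype=np.float32)
    g = np.ones((4, 1), dtype=np.float64)  # incoming grad in out_dtype
    print(reduce_sum_grad(x, g, in_dtype=np.float32).dtype)  # float32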
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index ca9d1d27f1ddcd5f390923a860cdce20ac5ac79b..1f2aa5cacda94a61f9de2af822833e7444ddb3b6 100644
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -126,7 +126,7 @@ from .tensor.math import sin  #DEFINE_ALIAS
 from .tensor.math import sqrt  #DEFINE_ALIAS
 # from .tensor.math import square        #DEFINE_ALIAS
 # from .tensor.math import stanh        #DEFINE_ALIAS
-# from .tensor.math import sum        #DEFINE_ALIAS
+from .tensor.math import sum  #DEFINE_ALIAS
 # from .tensor.math import sums        #DEFINE_ALIAS
 from .tensor.math import tanh  #DEFINE_ALIAS
 # from .tensor.math import elementwise_sum        #DEFINE_ALIAS
diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py
index efd1b15e920074e65af5a8b6a7a18a19b214e71d..ae55a7844d16b23b05577d0b4959ccd329057a9b 100644
--- a/python/paddle/fluid/tests/unittests/test_reduce_op.py
+++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py
@@ -17,9 +17,11 @@ from __future__ import print_function
 import unittest
 import numpy as np
 from op_test import OpTest, skip_check_grad_ci
+import paddle
 import paddle.fluid.core as core
 import paddle.fluid as fluid
 from paddle.fluid import compiler, Program, program_guard
+from paddle.fluid.framework import convert_np_dtype_to_dtype_
 
 
 class TestSumOp(OpTest):
@@ -426,6 +428,48 @@ class Test1DReduceWithAxes1(OpTest):
         self.check_grad(['X'], 'Out')
 
 
+class TestReduceWithDtype(OpTest):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.inputs = {'X': np.random.random((6, 2, 10)).astype("float64")}
+        self.outputs = {'Out': self.inputs['X'].sum().astype('float64')}
+        self.attrs = {'reduce_all': True}
+        self.attrs.update({
+            'in_dtype': int(convert_np_dtype_to_dtype_(np.float32)),
+            'out_dtype': int(convert_np_dtype_to_dtype_(np.float64))
+        })
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+class TestReduceWithDtype1(TestReduceWithDtype):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.inputs = {'X': np.random.random((6, 2, 10)).astype("float64")}
+        self.outputs = {'Out': self.inputs['X'].sum(axis=1)}
+        self.attrs = {'dim': [1]}
+        self.attrs.update({
+            'in_dtype': int(convert_np_dtype_to_dtype_(np.float32)),
+            'out_dtype': int(convert_np_dtype_to_dtype_(np.float64))
+        })
+
+
+class TestReduceWithDtype2(TestReduceWithDtype):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.inputs = {'X': np.random.random((6, 2, 10)).astype("float64")}
+        self.outputs = {'Out': self.inputs['X'].sum(axis=1, keepdims=True)}
+        self.attrs = {'dim': [1], 'keep_dim': True}
+        self.attrs.update({
+            'in_dtype': int(convert_np_dtype_to_dtype_(np.float32)),
+            'out_dtype': int(convert_np_dtype_to_dtype_(np.float64))
+        })
+
+
 class TestReduceSumOpError(unittest.TestCase):
     def test_errors(self):
         with program_guard(Program(), Program()):
@@ -450,5 +494,85 @@ class TestReduceMeanOpError(unittest.TestCase):
         self.assertRaises(TypeError, fluid.layers.reduce_mean, x2)
 
 
+class API_TestSumOpError(unittest.TestCase):
+    def test_errors(self):
+        def test_dtype1():
+            with fluid.program_guard(fluid.Program(), fluid.Program()):
+                data = fluid.data(name="data", shape=[10], dtype="float32")
+                paddle.sum(data, dtype="int32")
+
+        self.assertRaises(ValueError, test_dtype1)
+
+        def test_dtype2():
+            with fluid.program_guard(fluid.Program(), fluid.Program()):
+                data = fluid.data(name="data", shape=[10], dtype="float32")
+                paddle.sum(data, dtype="float32")
+
+        self.assertRaises(ValueError, test_dtype2)
+
+        def test_dtype3():
+            with fluid.program_guard(fluid.Program(), fluid.Program()):
+                data = fluid.data(name="data", shape=[10], dtype="int32")
+                paddle.sum(data, dtype="bool")
+
+        self.assertRaises(ValueError, test_dtype3)
+
+        def test_dtype4():
+            with fluid.program_guard(fluid.Program(), fluid.Program()):
+                data = fluid.data(name="data", shape=[10], dtype="int32")
+                paddle.sum(data, dtype="int32")
+
+        self.assertRaises(ValueError, test_dtype4)
+
+
+class API_TestSumOp(unittest.TestCase):
+    def test_1(self):
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            data = fluid.data("data", shape=[10, 10], dtype="float32")
+            result_sum = paddle.sum(input=data, dim=1, dtype="float64")
+            place = fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            input_data = np.random.rand(10, 10).astype(np.float32)
+            res, = exe.run(feed={"data": input_data}, fetch_list=[result_sum])
+            self.assertEqual(
+                (res == np.sum(input_data.astype(np.float64), axis=1)).all(),
+                True)
+
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            data = fluid.data("data", shape=[10, 10], dtype="int32")
+            result_sum = paddle.sum(input=data, dim=1, dtype="int64")
+            place = fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            input_data = np.random.randint(10, size=(10, 10)).astype(np.int32)
+            res, = exe.run(feed={"data": input_data}, fetch_list=[result_sum])
+            self.assertEqual(
+                (res == np.sum(input_data.astype(np.int64), axis=1)).all(),
+                True)
+
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            data = fluid.data("data", shape=[10, 10], dtype="int32")
+            result_sum = paddle.sum(input=data, dim=1)
+            place = fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            input_data = np.random.randint(10, size=(10, 10)).astype(np.int32)
+            res, = exe.run(feed={"data": input_data}, fetch_list=[result_sum])
+            self.assertEqual((res == np.sum(input_data, axis=1)).all(), True)
+
+        with fluid.dygraph.guard():
+            np_x = np.array([10, 10]).astype('float64')
+            x = fluid.dygraph.to_variable(np_x)
+            z = paddle.sum(x, dim=0)
+            np_z = z.numpy()
+            z_expected = np.array(np.sum(np_x, axis=0))
+            self.assertEqual((np_z == z_expected).all(), True)
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index 54b60bba7056de661e50d50d111ecda129cfbb08..3a43d95414c360a847afc97725cc708836448bfc 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -102,7 +102,7 @@ from .math import sin  #DEFINE_ALIAS
 from .math import sqrt  #DEFINE_ALIAS
 # from .math import square        #DEFINE_ALIAS
 # from .math import stanh        #DEFINE_ALIAS
-# from .math import sum        #DEFINE_ALIAS
+from .math import sum  #DEFINE_ALIAS
 # from .math import sums        #DEFINE_ALIAS
 from .math import tanh  #DEFINE_ALIAS
 # from .math import elementwise_sum        #DEFINE_ALIAS
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index 6200d728596c5467044c086a4b538c1164e9afc2..db31cc8bdee9a03a8fa5f5bfb78708d25a3187c5 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -60,7 +60,7 @@ __all__ = [
     'sqrt',
 #    'square',
 #    'stanh',
-#    'sum',
+    'sum',
 #    'sums',
     'tanh',
 #    'elementwise_sum',
@@ -647,3 +647,103 @@ for func in [
         additional_args_lines=additional_args_lines,
         skip_attrs_set={"x_data_format", "y_data_format", "axis"
                        }) + """\n""" + str(func.__doc__)
+
+
+def sum(input, dim=None, dtype=None, keep_dim=False, name=None):
+    """
+    Computes the sum of tensor elements over the given dimension.
+
+    Args:
+        input (Variable): The input variable, which is a Tensor. The data type
+            is float32, float64, int32 or int64.
+        dim (list|int, optional): The dimensions along which the sum is
+            performed. If :attr:`None`, sum all elements of :attr:`input` and
+            return a Tensor variable with a single element, otherwise it must
+            be in the range :math:`[-rank(input), rank(input))`. If
+            :math:`dim[i] < 0`, the dimension to reduce is
+            :math:`rank + dim[i]`.
+        dtype (str, optional): The dtype of the output tensor. The default
+            value is None; in that case the dtype of the output is the same as
+            that of the input tensor.
+        keep_dim (bool, optional): Whether to reserve the reduced dimension in
+            the output Tensor. The result tensor will have one fewer dimension
+            than the :attr:`input` unless :attr:`keep_dim` is true. Default
+            value is False.
+        name (str, optional): The default value is None. Normally there is no
+            need for users to set this property. For more information, please
+            refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Variable: Tensor, the result of the summation operation on the
+        specified dim of the input tensor; its data type is the same as the
+        input tensor's.
+
+    Raises:
+        ValueError: The :attr:`dtype` must be float64 or int64.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import paddle.fluid as fluid
+            # x is a Tensor variable with following elements:
+            #    [[0.2, 0.3, 0.5, 0.9]
+            #     [0.1, 0.2, 0.6, 0.7]]
+            # Each example is followed by the corresponding output tensor.
+            x = fluid.data(name='x', shape=[2, 4], dtype='float32')
+            out1 = paddle.sum(x)  # [3.5]
+            out2 = paddle.sum(x, dim=0)  # [0.3, 0.5, 1.1, 1.6]
+            out3 = paddle.sum(x, dim=-1)  # [1.9, 1.6]
+            out4 = paddle.sum(x, dim=1, keep_dim=True)  # [[1.9], [1.6]]
+
+            # y is a Tensor variable with shape [2, 2, 2] and elements as below:
+            #      [[[1, 2], [3, 4]],
+            #       [[5, 6], [7, 8]]]
+            # Each example is followed by the corresponding output tensor.
+            y = fluid.data(name='y', shape=[2, 2, 2], dtype='float32')
+            out5 = paddle.sum(y, dim=[1, 2])  # [10, 26]
+            out6 = paddle.sum(y, dim=[0, 1])  # [16, 20]
+    """
+    if dim is not None and not isinstance(dim, list):
+        dim = [dim]
+    attrs = {
+        'dim': dim if dim is not None and dim != [] else [0],
+        'keep_dim': keep_dim,
+        'reduce_all': True if dim is None or dim == [] else False,
+    }
+    dtype_flag = False
+    if dtype is not None:
+        if dtype in ['float64', 'int64']:
+            if (convert_dtype(input.dtype) == "float32" and dtype == "float64") or \
+                    (convert_dtype(input.dtype) == "int32" and dtype == "int64"):
+                attrs.update({
+                    'in_dtype': input.dtype,
+                    'out_dtype': convert_np_dtype_to_dtype_(dtype)
+                })
+                dtype_flag = True
+        else:
+            raise ValueError(
+                "The value of 'dtype' in sum op must be float64 or int64, "
+                "but received {}".
+                format(dtype))
+
+    if in_dygraph_mode():
+        reduce_all = True if dim is None or dim == [] else False
+        dim = dim if dim is not None and dim != [] else [0]
+        if dtype_flag:
+            return core.ops.reduce_sum(input, 'dim', dim, 'keep_dim', keep_dim,
+                                       'reduce_all', reduce_all, 'in_dtype',
+                                       input.dtype, 'out_dtype',
+                                       convert_np_dtype_to_dtype_(dtype))
+        else:
+            return core.ops.reduce_sum(input, 'dim', dim, 'keep_dim', keep_dim,
+                                       'reduce_all', reduce_all)
+
+    check_variable_and_dtype(
+        input, 'input', ['float32', 'float64', 'int32', 'int64'], 'reduce_sum')
+    helper = LayerHelper('sum', **locals())
+    if dtype_flag:
+        out = helper.create_variable_for_type_inference(
+            dtype=convert_np_dtype_to_dtype_(dtype))
+    else:
+        out = helper.create_variable_for_type_inference(dtype=input.dtype)
+    helper.append_op(
+        type='reduce_sum',
+        inputs={'X': input},
+        outputs={'Out': out},
+        attrs=attrs)
+    return out
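A usage sketch for the new dtype argument, mirroring the unit tests above (fluid 1.x-style APIs, as used throughout this PR):

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    with fluid.program_guard(fluid.Program(), fluid.Program()):
        data = fluid.data(name="data", shape=[10, 10], dtype="float32")
        # float32 input, float64 accumulation and output
        out = paddle.sum(data, dim=1, dtype="float64")
        exe = fluid.Executor(fluid.CPUPlace())
        x = np.random.rand(10, 10).astype(np.float32)
        res, = exe.run(feed={"data": x}, fetch_list=[out])
        print(res.dtype)  # float64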