From cc3306f7c8d62e42196ac3d61e744c0e9d1a1563 Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Thu, 10 Sep 2020 10:20:04 +0800 Subject: [PATCH] restruct logsumexp to speed up compiling (#27191) --- .../operators/reduce_ops/logsumexp_op.cc | 154 ++++++++++++++++-- .../operators/reduce_ops/logsumexp_op.cu | 10 +- .../fluid/operators/reduce_ops/logsumexp_op.h | 112 ++++++++++++- .../operators/reduce_ops/logsumexp_op.part.cu | 9 +- .../fluid/tests/unittests/test_logsumexp.py | 4 +- python/paddle/tensor/math.py | 5 +- 6 files changed, 261 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/operators/reduce_ops/logsumexp_op.cc b/paddle/fluid/operators/reduce_ops/logsumexp_op.cc index 322a1637f5..7cd164bfd3 100644 --- a/paddle/fluid/operators/reduce_ops/logsumexp_op.cc +++ b/paddle/fluid/operators/reduce_ops/logsumexp_op.cc @@ -13,18 +13,138 @@ // limitations under the License. #include "paddle/fluid/operators/reduce_ops/logsumexp_op.h" -#include +#include #include -#include #include namespace paddle { namespace operators { -class LogsumexpOpMaker : public ops::ReduceOpMaker { - protected: - virtual std::string GetName() const { return "logsumexp"; } - virtual std::string GetOpType() const { return "Reduce logsumexp"; } +class LogsumexpOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "logsumexp"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "logsumexp"); + auto x_dims = ctx->GetInputDim("X"); + auto x_rank = x_dims.size(); + PADDLE_ENFORCE_LE(x_rank, 4, + platform::errors::InvalidArgument( + "The input tensor X's dimensions of logsumexp " + "should be less equal than 4. But received X's " + "dimensions = %d, X's shape = [%s].", + x_rank, x_dims)); + auto axis = ctx->Attrs().Get>("axis"); + PADDLE_ENFORCE_GT( + axis.size(), 0, + platform::errors::InvalidArgument( + "The size of axis of logsumexp " + "should be greater than 0. But received the size of axis " + "of logsumexp is %d.", + axis.size())); + + for (size_t i = 0; i < axis.size(); i++) { + PADDLE_ENFORCE_LT( + axis[i], x_rank, + platform::errors::InvalidArgument( + "axis[%d] should be in the " + "range [-dimension(X), dimension(X)] " + "where dimesion(X) is %d. But received axis[i] = %d.", + i, x_rank, axis[i])); + PADDLE_ENFORCE_GE( + axis[i], -x_rank, + platform::errors::InvalidArgument( + "axis[%d] should be in the " + "range [-dimension(X), dimension(X)] " + "where dimesion(X) is %d. But received axis[i] = %d.", + i, x_rank, axis[i])); + if (axis[i] < 0) { + axis[i] += x_rank; + } + } + + bool keepdim = ctx->Attrs().Get("keepdim"); + bool reduce_all = ctx->Attrs().Get("reduce_all"); + auto dims_vector = vectorize(x_dims); + if (reduce_all) { + if (keepdim) + ctx->SetOutputDim( + "Out", framework::make_ddim(std::vector(x_rank, 1))); + else + ctx->SetOutputDim("Out", {1}); + } else { + auto dims_vector = vectorize(x_dims); + if (keepdim) { + for (size_t i = 0; i < axis.size(); ++i) { + dims_vector[axis[i]] = 1; + } + } else { + const int kDelFlag = -1; + for (size_t i = 0; i < axis.size(); ++i) { + dims_vector[axis[i]] = kDelFlag; + } + dims_vector.erase( + std::remove(dims_vector.begin(), dims_vector.end(), kDelFlag), + dims_vector.end()); + } + if (!keepdim && dims_vector.size() == 0) { + dims_vector.push_back(1); + } + auto out_dims = framework::make_ddim(dims_vector); + ctx->SetOutputDim("Out", out_dims); + if (axis.size() > 0 && axis[0] != 0) { + // Only pass LoD when not reducing on the first dim. + ctx->ShareLoD("X", /*->*/ "Out"); + } + } + } +}; + +class LogsumexpOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor) The input tensor. Tensors with rank at most 4 are " + "supported."); + AddOutput("Out", "(Tensor) The result tensor."); + AddAttr>( + "axis", + "(list, default {0}) The dimensions to reduce. " + "Must be in the range [-rank(input), rank(input)). " + "If `axis[i] < 0`, the axis[i] to reduce is `rank + axis[i]`. " + "Note that reducing on the first dim will make the LoD info lost.") + .SetDefault({0}); + AddAttr("keepdim", + "(bool, default false) " + "If true, retain the reduced dimension with length 1.") + .SetDefault(false); + AddAttr("reduce_all", + "(bool, default false) " + "If true, output a scalar reduced along all dimensions.") + .SetDefault(false); + AddComment(string::Sprintf(R"DOC( +logsumexp Operator. + +This operator computes the logsumexp of input tensor along the given axis. +The result tensor has 1 fewer dimension than the input unless keep_dim is true. +If reduce_all is true, just reduce along all dimensions and output a scalar. + +)DOC")); + } +}; + +class LogsumexpGrapOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "logsumexp"); + OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "logsumexp"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out@GRAD", "logsumexp"); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } }; template @@ -32,7 +152,6 @@ class LogsumexpGradOpMaker : public framework::SingleGradOpMaker { public: using framework::SingleGradOpMaker::SingleGradOpMaker; - protected: void Apply(GradOpPtr op) const override { op->SetType("logsumexp_grad"); op->SetInput("X", this->Input("X")); @@ -46,18 +165,17 @@ class LogsumexpGradOpMaker : public framework::SingleGradOpMaker { } // namespace operators } // namespace paddle -REGISTER_OPERATOR(logsumexp, ops::ReduceOp, ops::LogsumexpOpMaker, +namespace ops = paddle::operators; + +REGISTER_OPERATOR(logsumexp, ops::LogsumexpOp, ops::LogsumexpOpMaker, ops::LogsumexpGradOpMaker, ops::LogsumexpGradOpMaker); -REGISTER_OPERATOR(logsumexp_grad, ops::ReduceGradOp); +REGISTER_OPERATOR(logsumexp_grad, ops::LogsumexpGrapOp); -REGISTER_OP_CPU_KERNEL(logsumexp, - ops::ReduceKernel, - ops::ReduceKernel); REGISTER_OP_CPU_KERNEL( - logsumexp_grad, ops::ReduceGradKernel, - ops::ReduceGradKernel); + logsumexp, ops::LogsumexpKernel, + ops::LogsumexpKernel); +REGISTER_OP_CPU_KERNEL( + logsumexp_grad, + ops::LogsumexpGradKernel, + ops::LogsumexpGradKernel); diff --git a/paddle/fluid/operators/reduce_ops/logsumexp_op.cu b/paddle/fluid/operators/reduce_ops/logsumexp_op.cu index c9ad1075c0..86a31595eb 100644 --- a/paddle/fluid/operators/reduce_ops/logsumexp_op.cu +++ b/paddle/fluid/operators/reduce_ops/logsumexp_op.cu @@ -14,8 +14,8 @@ #include "paddle/fluid/operators/reduce_ops/logsumexp_op.h" -REGISTER_OP_CUDA_KERNEL(logsumexp, - ops::ReduceKernel, - ops::ReduceKernel); +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + logsumexp, ops::LogsumexpKernel, + ops::LogsumexpKernel); diff --git a/paddle/fluid/operators/reduce_ops/logsumexp_op.h b/paddle/fluid/operators/reduce_ops/logsumexp_op.h index 1d0e00262a..a478690976 100644 --- a/paddle/fluid/operators/reduce_ops/logsumexp_op.h +++ b/paddle/fluid/operators/reduce_ops/logsumexp_op.h @@ -14,11 +14,20 @@ #pragma once -#include "paddle/fluid/operators/reduce_ops/reduce_op.h" +#include +#include +#include "paddle/fluid/operators/reduce_ops/reduce_op_function.h" namespace paddle { namespace operators { +#define HANDLE_DIM(NDIM, RDIM) \ + if (ndim == NDIM && rdim == RDIM) { \ + ReduceFunctor( \ + context.template device_context(), *input, output, \ + axis, keepdim); \ + } + struct LogsumexpFunctor { template void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { @@ -54,5 +63,106 @@ struct LogsumexpGradFunctor { } }; +template +class LogsumexpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* input = context.Input("X"); + auto* output = context.Output("Out"); + output->mutable_data(context.GetPlace()); + + auto axis = context.Attr>("axis"); + auto keepdim = context.Attr("keepdim"); + auto reduce_all = context.Attr("reduce_all"); + + const auto& input_dim_size = input->dims().size(); + // The dims has full dim, set the reduce_all is True + reduce_all |= (static_cast(axis.size()) == input_dim_size); + + if (reduce_all) { + // Flatten and reduce 1-D tensor + auto x = EigenVector::Flatten(*input); + auto out = EigenScalar::From(*output); + auto& place = + *context.template device_context().eigen_device(); + auto reduce_dim = Eigen::array({{0}}); + LogsumexpFunctor()(place, &x, &out, reduce_dim); + } else { + int ndim = input_dim_size; + int rdim = axis.size(); + // comments for accelerating compiling temporarily. + // HANDLE_DIM(6, 5); + // HANDLE_DIM(6, 4); + // HANDLE_DIM(6, 3); + // HANDLE_DIM(6, 2); + // HANDLE_DIM(6, 1); + // HANDLE_DIM(5, 4); + // HANDLE_DIM(5, 3); + // HANDLE_DIM(5, 2); + // HANDLE_DIM(5, 1); + HANDLE_DIM(4, 3); + HANDLE_DIM(4, 2); + HANDLE_DIM(4, 1); + HANDLE_DIM(3, 2); + HANDLE_DIM(3, 1); + HANDLE_DIM(2, 1); + } + } +}; + +template +class LogsumexpGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* input = context.Input("X"); + auto* output = context.Input("Out"); + auto* output_grad = context.Input(framework::GradVarName("Out")); + auto* input_grad = context.Output(framework::GradVarName("X")); + input_grad->mutable_data(context.GetPlace()); + + auto axis = context.Attr>("axis"); + auto reduce_all = context.Attr("reduce_all"); + const auto input_dim_size = context.Input("X")->dims().size(); + reduce_all |= (static_cast(axis.size()) == input_dim_size); + + if (reduce_all) { + auto x = EigenVector::Flatten(*input); + auto y = EigenVector::Flatten(*output); + auto dy = EigenVector::Flatten(*output_grad); + auto dx = EigenVector::Flatten(*input_grad); + auto& place = + *context.template device_context().eigen_device(); + auto broadcast_dim = + Eigen::array({{static_cast(input->numel())}}); + LogsumexpGradFunctor()(place, &x, &y, &dx, &dy, broadcast_dim, + broadcast_dim[0]); + } else { + int rank = input->dims().size(); + switch (rank) { + case 1: + ReduceGradFunctor( + context.template device_context(), *input, *output, + *output_grad, input_grad, axis); + break; + case 2: + ReduceGradFunctor( + context.template device_context(), *input, *output, + *output_grad, input_grad, axis); + break; + case 3: + ReduceGradFunctor( + context.template device_context(), *input, *output, + *output_grad, input_grad, axis); + break; + case 4: + ReduceGradFunctor( + context.template device_context(), *input, *output, + *output_grad, input_grad, axis); + break; + } + } + } +}; + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/reduce_ops/logsumexp_op.part.cu b/paddle/fluid/operators/reduce_ops/logsumexp_op.part.cu index d6ad486309..81124e4f07 100644 --- a/paddle/fluid/operators/reduce_ops/logsumexp_op.part.cu +++ b/paddle/fluid/operators/reduce_ops/logsumexp_op.part.cu @@ -15,8 +15,9 @@ // .part used to speed up nvcc compile #include "paddle/fluid/operators/reduce_ops/logsumexp_op.h" +namespace ops = paddle::operators; + REGISTER_OP_CUDA_KERNEL( - logsumexp_grad, ops::ReduceGradKernel, - ops::ReduceGradKernel); + logsumexp_grad, + ops::LogsumexpGradKernel, + ops::LogsumexpGradKernel); diff --git a/python/paddle/fluid/tests/unittests/test_logsumexp.py b/python/paddle/fluid/tests/unittests/test_logsumexp.py index c2201a5260..cf9203dffc 100644 --- a/python/paddle/fluid/tests/unittests/test_logsumexp.py +++ b/python/paddle/fluid/tests/unittests/test_logsumexp.py @@ -46,8 +46,8 @@ class TestLogsumexp(OpTest): self.inputs = {'X': x} self.outputs = {'Out': out} self.attrs = { - 'dim': self.axis, - 'keep_dim': self.keepdim, + 'axis': self.axis, + 'keepdim': self.keepdim, 'reduce_all': self.reduce_all } diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index ed2bbe03a3..079178e1cf 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -1194,15 +1194,14 @@ def logsumexp(x, axis=None, keepdim=False, name=None): axis = [0] if in_dygraph_mode(): - return core.ops.logsumexp(x, 'dim', axis, 'keep_dim', keepdim, - 'reduce_all', reduce_all) + return core.ops.logsumexp(x, 'axis', axis, 'keepdim', keepdim, 'reduce_all', reduce_all) check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'logsumexp') helper = LayerHelper('logsumexp', **locals()) - attrs = {'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all} + attrs = {'axis': axis, 'keepdim': keepdim, 'reduce_all':reduce_all} out = helper.create_variable_for_type_inference(x.dtype) helper.append_op( type='logsumexp', inputs={'X': x}, outputs={'Out': out}, attrs=attrs) -- GitLab