Unverified commit cc3306f7, authored by zhupengyang, committed by GitHub

Restructure logsumexp to speed up compiling (#27191)

Parent: 50e60e87
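In brief: the commit replaces the generic ReduceOp / ReduceKernel machinery behind logsumexp with a dedicated LogsumexpOp and dedicated CPU/CUDA kernels, and caps the supported input rank at 4, so far fewer reduce templates have to be instantiated at compile time.

For reference, logsumexp(x) = log(sum_i exp(x_i)), normally evaluated through the max-shift identity logsumexp(x) = max(x) + log(sum_i exp(x_i - max(x))) so that exp() never overflows. A minimal self-contained sketch of that identity (illustration only, not Paddle code):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable logsumexp over a 1-D array:
//   logsumexp(x) = max(x) + log(sum_i exp(x[i] - max(x)))
double LogSumExp(const std::vector<double>& x) {
  const double m = *std::max_element(x.begin(), x.end());
  double s = 0.0;
  for (const double v : x) s += std::exp(v - m);  // shifted, so exp() stays finite
  return m + std::log(s);
}

int main() {
  // Naive log(exp(1000) + exp(1000)) overflows; the shifted form does not.
  std::printf("%f\n", LogSumExp({1000.0, 1000.0}));  // ~1000.693147
}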
@@ -13,18 +13,138 @@
 // limitations under the License.
 
 #include "paddle/fluid/operators/reduce_ops/logsumexp_op.h"
-#include <memory>
+#include <algorithm>
 #include <string>
+#include <utility>
 #include <vector>
 
 namespace paddle {
 namespace operators {
 
-class LogsumexpOpMaker : public ops::ReduceOpMaker {
- protected:
-  virtual std::string GetName() const { return "logsumexp"; }
-  virtual std::string GetOpType() const { return "Reduce logsumexp"; }
+class LogsumexpOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "logsumexp");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "logsumexp");
+    auto x_dims = ctx->GetInputDim("X");
+    auto x_rank = x_dims.size();
+    PADDLE_ENFORCE_LE(x_rank, 4,
+                      platform::errors::InvalidArgument(
+                          "The input tensor X's dimensions of logsumexp "
+                          "should be less than or equal to 4. But received "
+                          "X's dimensions = %d, X's shape = [%s].",
+                          x_rank, x_dims));
+    auto axis = ctx->Attrs().Get<std::vector<int>>("axis");
+    PADDLE_ENFORCE_GT(
+        axis.size(), 0,
+        platform::errors::InvalidArgument(
+            "The size of axis of logsumexp "
+            "should be greater than 0. But received the size of axis "
+            "of logsumexp is %d.",
+            axis.size()));
+
+    for (size_t i = 0; i < axis.size(); i++) {
+      PADDLE_ENFORCE_LT(
+          axis[i], x_rank,
+          platform::errors::InvalidArgument(
+              "axis[%d] should be in the "
+              "range [-dimension(X), dimension(X)) "
+              "where dimension(X) is %d. But received axis[%d] = %d.",
+              i, x_rank, i, axis[i]));
+      PADDLE_ENFORCE_GE(
+          axis[i], -x_rank,
+          platform::errors::InvalidArgument(
+              "axis[%d] should be in the "
+              "range [-dimension(X), dimension(X)) "
+              "where dimension(X) is %d. But received axis[%d] = %d.",
+              i, x_rank, i, axis[i]));
+      if (axis[i] < 0) {
+        axis[i] += x_rank;
+      }
+    }
+
+    bool keepdim = ctx->Attrs().Get<bool>("keepdim");
+    bool reduce_all = ctx->Attrs().Get<bool>("reduce_all");
+    if (reduce_all) {
+      if (keepdim)
+        ctx->SetOutputDim(
+            "Out", framework::make_ddim(std::vector<int64_t>(x_rank, 1)));
+      else
+        ctx->SetOutputDim("Out", {1});
+    } else {
+      auto dims_vector = vectorize(x_dims);
+      if (keepdim) {
+        for (size_t i = 0; i < axis.size(); ++i) {
+          dims_vector[axis[i]] = 1;
+        }
+      } else {
+        const int kDelFlag = -1;
+        for (size_t i = 0; i < axis.size(); ++i) {
+          dims_vector[axis[i]] = kDelFlag;
+        }
+        dims_vector.erase(
+            std::remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
+            dims_vector.end());
+      }
+      if (!keepdim && dims_vector.size() == 0) {
+        dims_vector.push_back(1);
+      }
+      auto out_dims = framework::make_ddim(dims_vector);
+      ctx->SetOutputDim("Out", out_dims);
+      if (axis.size() > 0 && axis[0] != 0) {
+        // Only pass LoD when not reducing on the first dim.
+        ctx->ShareLoD("X", /*->*/ "Out");
+      }
+    }
+  }
+};
+
+class LogsumexpOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "(Tensor) The input tensor. Tensors with rank at most 4 are "
+             "supported.");
+    AddOutput("Out", "(Tensor) The result tensor.");
+    AddAttr<std::vector<int>>(
+        "axis",
+        "(list<int>, default {0}) The dimensions to reduce. "
+        "Must be in the range [-rank(input), rank(input)). "
+        "If `axis[i] < 0`, the dimension to reduce is `rank + axis[i]`. "
+        "Note that reducing on the first dim will drop the LoD info.")
+        .SetDefault({0});
+    AddAttr<bool>("keepdim",
+                  "(bool, default false) "
+                  "If true, retain the reduced dimension with length 1.")
+        .SetDefault(false);
+    AddAttr<bool>("reduce_all",
+                  "(bool, default false) "
+                  "If true, output a scalar reduced along all dimensions.")
+        .SetDefault(false);
+    AddComment(string::Sprintf(R"DOC(
+logsumexp Operator.
+
+This operator computes the logsumexp of the input tensor along the given axis.
+The result tensor has one fewer dimension than the input unless keepdim is true.
+If reduce_all is true, it reduces along all dimensions and outputs a scalar.
+)DOC"));
+  }
+};
+
+class LogsumexpGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "logsumexp_grad");
+    OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "logsumexp_grad");
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   "Out@GRAD", "logsumexp_grad");
+    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+  }
+};
 
 template <typename T>
@@ -32,7 +152,6 @@ class LogsumexpGradOpMaker : public framework::SingleGradOpMaker<T> {
  public:
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
 
- protected:
   void Apply(GradOpPtr<T> op) const override {
     op->SetType("logsumexp_grad");
     op->SetInput("X", this->Input("X"));
@@ -46,18 +165,17 @@ class LogsumexpGradOpMaker : public framework::SingleGradOpMaker<T> {
 }  // namespace operators
 }  // namespace paddle
 
+namespace ops = paddle::operators;
+
-REGISTER_OPERATOR(logsumexp, ops::ReduceOp, ops::LogsumexpOpMaker,
+REGISTER_OPERATOR(logsumexp, ops::LogsumexpOp, ops::LogsumexpOpMaker,
                   ops::LogsumexpGradOpMaker<paddle::framework::OpDesc>,
                   ops::LogsumexpGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OPERATOR(logsumexp_grad, ops::ReduceGradOp);
+REGISTER_OPERATOR(logsumexp_grad, ops::LogsumexpGradOp);
 
-REGISTER_OP_CPU_KERNEL(logsumexp,
-                       ops::ReduceKernel<paddle::platform::CPUDeviceContext,
-                                         float, ops::LogsumexpFunctor>,
-                       ops::ReduceKernel<paddle::platform::CPUDeviceContext,
-                                         double, ops::LogsumexpFunctor>);
-REGISTER_OP_CPU_KERNEL(
-    logsumexp_grad, ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
-                                          float, ops::LogsumexpGradFunctor>,
-    ops::ReduceGradKernel<paddle::platform::CPUDeviceContext, double,
-                          ops::LogsumexpGradFunctor>);
+REGISTER_OP_CPU_KERNEL(
+    logsumexp, ops::LogsumexpKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::LogsumexpKernel<paddle::platform::CPUDeviceContext, double>);
+REGISTER_OP_CPU_KERNEL(
+    logsumexp_grad,
+    ops::LogsumexpGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::LogsumexpGradKernel<paddle::platform::CPUDeviceContext, double>);
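The output-shape rule in LogsumexpOp::InferShape above is easiest to see on a concrete shape: a reduced axis either collapses to length 1 (keepdim) or is erased, and a shape that would become empty turns into [1]. A self-contained sketch of just that rule (hypothetical helper for illustration, not Paddle code; it assumes the axes were already normalized to be non-negative and in range):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors the keepdim / erase logic of LogsumexpOp::InferShape.
std::vector<int64_t> InferOutDims(std::vector<int64_t> dims,
                                  const std::vector<int>& axis, bool keepdim) {
  const int64_t kDelFlag = -1;  // same sentinel trick as the operator code
  for (int a : axis) dims[a] = keepdim ? 1 : kDelFlag;
  if (!keepdim) {
    dims.erase(std::remove(dims.begin(), dims.end(), kDelFlag), dims.end());
    if (dims.empty()) dims.push_back(1);  // never produce an empty shape
  }
  return dims;
}

int main() {
  // x shape [2, 3, 4], axis = {1}: keepdim -> [2, 1, 4]; otherwise -> [2, 4].
  for (bool keepdim : {true, false}) {
    for (int64_t d : InferOutDims({2, 3, 4}, {1}, keepdim))
      std::printf("%lld ", static_cast<long long>(d));
    std::printf("\n");
  }
}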
@@ -14,8 +14,8 @@
 
 #include "paddle/fluid/operators/reduce_ops/logsumexp_op.h"
 
-REGISTER_OP_CUDA_KERNEL(logsumexp,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          float, ops::LogsumexpFunctor>,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          double, ops::LogsumexpFunctor>);
+namespace ops = paddle::operators;
+
+REGISTER_OP_CUDA_KERNEL(
+    logsumexp,
+    ops::LogsumexpKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::LogsumexpKernel<paddle::platform::CUDADeviceContext, double>);
@@ -14,11 +14,20 @@
 
 #pragma once
 
-#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
+#include <algorithm>
+#include <vector>
+
+#include "paddle/fluid/operators/reduce_ops/reduce_op_function.h"
 
 namespace paddle {
 namespace operators {
 
+#define HANDLE_DIM(NDIM, RDIM)                                            \
+  if (ndim == NDIM && rdim == RDIM) {                                     \
+    ReduceFunctor<DeviceContext, OutT, NDIM, RDIM, LogsumexpFunctor>(     \
+        context.template device_context<DeviceContext>(), *input, output, \
+        axis, keepdim);                                                   \
+  }
+
 struct LogsumexpFunctor {
   template <typename DeviceContext, typename X, typename Y, typename Dim>
   void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
@@ -54,5 +63,106 @@ struct LogsumexpGradFunctor {
   }
 };
 
+template <typename DeviceContext, typename OutT>
+class LogsumexpKernel : public framework::OpKernel<OutT> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* input = context.Input<Tensor>("X");
+    auto* output = context.Output<Tensor>("Out");
+    output->mutable_data<OutT>(context.GetPlace());
+
+    auto axis = context.Attr<std::vector<int>>("axis");
+    auto keepdim = context.Attr<bool>("keepdim");
+    auto reduce_all = context.Attr<bool>("reduce_all");
+
+    const auto& input_dim_size = input->dims().size();
+    // If axis covers every dimension, treat the reduction as reduce_all.
+    reduce_all |= (static_cast<const int>(axis.size()) == input_dim_size);
+
+    if (reduce_all) {
+      // Flatten and reduce as a 1-D tensor.
+      auto x = EigenVector<OutT>::Flatten(*input);
+      auto out = EigenScalar<OutT>::From(*output);
+      auto& place =
+          *context.template device_context<DeviceContext>().eigen_device();
+      auto reduce_dim = Eigen::array<int, 1>({{0}});
+      LogsumexpFunctor()(place, &x, &out, reduce_dim);
+    } else {
+      int ndim = input_dim_size;
+      int rdim = axis.size();
+      // Rank-5 and rank-6 cases are commented out for now to speed up
+      // compilation; the op restricts inputs to rank <= 4 anyway.
+      // HANDLE_DIM(6, 5);
+      // HANDLE_DIM(6, 4);
+      // HANDLE_DIM(6, 3);
+      // HANDLE_DIM(6, 2);
+      // HANDLE_DIM(6, 1);
+      // HANDLE_DIM(5, 4);
+      // HANDLE_DIM(5, 3);
+      // HANDLE_DIM(5, 2);
+      // HANDLE_DIM(5, 1);
+      HANDLE_DIM(4, 3);
+      HANDLE_DIM(4, 2);
+      HANDLE_DIM(4, 1);
+      HANDLE_DIM(3, 2);
+      HANDLE_DIM(3, 1);
+      HANDLE_DIM(2, 1);
+    }
+  }
+};
+
+template <typename DeviceContext, typename T>
+class LogsumexpGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* input = context.Input<Tensor>("X");
+    auto* output = context.Input<Tensor>("Out");
+    auto* output_grad = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* input_grad = context.Output<Tensor>(framework::GradVarName("X"));
+    input_grad->mutable_data<T>(context.GetPlace());
+
+    auto axis = context.Attr<std::vector<int>>("axis");
+    auto reduce_all = context.Attr<bool>("reduce_all");
+    const auto input_dim_size = input->dims().size();
+    reduce_all |= (static_cast<const int>(axis.size()) == input_dim_size);
+
+    if (reduce_all) {
+      auto x = EigenVector<T>::Flatten(*input);
+      auto y = EigenVector<T>::Flatten(*output);
+      auto dy = EigenVector<T>::Flatten(*output_grad);
+      auto dx = EigenVector<T>::Flatten(*input_grad);
+      auto& place =
+          *context.template device_context<DeviceContext>().eigen_device();
+      auto broadcast_dim =
+          Eigen::array<int, 1>({{static_cast<int>(input->numel())}});
+      LogsumexpGradFunctor()(place, &x, &y, &dx, &dy, broadcast_dim,
+                             broadcast_dim[0]);
+    } else {
+      int rank = input->dims().size();
+      switch (rank) {
+        case 1:
+          ReduceGradFunctor<DeviceContext, T, 1, LogsumexpGradFunctor>(
+              context.template device_context<DeviceContext>(), *input,
+              *output, *output_grad, input_grad, axis);
+          break;
+        case 2:
+          ReduceGradFunctor<DeviceContext, T, 2, LogsumexpGradFunctor>(
+              context.template device_context<DeviceContext>(), *input,
+              *output, *output_grad, input_grad, axis);
+          break;
+        case 3:
+          ReduceGradFunctor<DeviceContext, T, 3, LogsumexpGradFunctor>(
+              context.template device_context<DeviceContext>(), *input,
+              *output, *output_grad, input_grad, axis);
+          break;
+        case 4:
+          ReduceGradFunctor<DeviceContext, T, 4, LogsumexpGradFunctor>(
+              context.template device_context<DeviceContext>(), *input,
+              *output, *output_grad, input_grad, axis);
+          break;
+      }
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
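Why commenting out the rank-5 and rank-6 HANDLE_DIM lines helps: each (NDIM, RDIM) pair instantiates a distinct ReduceFunctor template, so the full table up to rank 6 costs 15 instantiations per device and dtype, while the rank-4 table kept above costs only 6. A toy self-contained sketch of the same compile-time dispatch pattern (names invented for illustration, not Paddle code):

#include <cstdio>

// Each (NDIM, RDIM) pair used below forces the compiler to build one
// more instantiation of this template.
template <int NDIM, int RDIM>
void ToyReduce() {
  std::printf("reduce a rank-%d tensor over %d axes\n", NDIM, RDIM);
}

#define TOY_HANDLE_DIM(NDIM, RDIM) \
  if (ndim == NDIM && rdim == RDIM) ToyReduce<NDIM, RDIM>();

void Dispatch(int ndim, int rdim) {
  // Keeping only the rank <= 4 pairs: 6 instantiations instead of 15.
  TOY_HANDLE_DIM(4, 3);
  TOY_HANDLE_DIM(4, 2);
  TOY_HANDLE_DIM(4, 1);
  TOY_HANDLE_DIM(3, 2);
  TOY_HANDLE_DIM(3, 1);
  TOY_HANDLE_DIM(2, 1);
}

int main() { Dispatch(4, 2); }  // prints: reduce a rank-4 tensor over 2 axes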
@@ -15,8 +15,9 @@
 
 // .part file used to speed up nvcc compilation
 #include "paddle/fluid/operators/reduce_ops/logsumexp_op.h"
 
-REGISTER_OP_CUDA_KERNEL(
-    logsumexp_grad, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
-                                          float, ops::LogsumexpGradFunctor>,
-    ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, double,
-                          ops::LogsumexpGradFunctor>);
+namespace ops = paddle::operators;
+
+REGISTER_OP_CUDA_KERNEL(
+    logsumexp_grad,
+    ops::LogsumexpGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::LogsumexpGradKernel<paddle::platform::CUDADeviceContext, double>);
@@ -46,8 +46,8 @@ class TestLogsumexp(OpTest):
         self.inputs = {'X': x}
         self.outputs = {'Out': out}
         self.attrs = {
-            'dim': self.axis,
-            'keep_dim': self.keepdim,
+            'axis': self.axis,
+            'keepdim': self.keepdim,
             'reduce_all': self.reduce_all
         }
...
@@ -1194,15 +1194,14 @@ def logsumexp(x, axis=None, keepdim=False, name=None):
         axis = [0]
 
     if in_dygraph_mode():
-        return core.ops.logsumexp(x, 'dim', axis, 'keep_dim', keepdim,
-                                  'reduce_all', reduce_all)
+        return core.ops.logsumexp(x, 'axis', axis, 'keepdim', keepdim, 'reduce_all', reduce_all)
 
     check_variable_and_dtype(x, 'x',
                              ['float32', 'float64'],
                              'logsumexp')
 
     helper = LayerHelper('logsumexp', **locals())
-    attrs = {'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all}
+    attrs = {'axis': axis, 'keepdim': keepdim, 'reduce_all': reduce_all}
     out = helper.create_variable_for_type_inference(x.dtype)
     helper.append_op(
         type='logsumexp', inputs={'X': x}, outputs={'Out': out}, attrs=attrs)
...