Commit af692c91 authored by Leo Chen, committed by Zeng Jinle

update reduce_sum and reduce_mean to save memory, test=develop (#19608)

Parent e3e98ed6
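The memory saving comes from declaring the forward input X of reduce_sum_grad and reduce_mean_grad as a no-need-buffer variable: the grad kernels only use X's dims, so the framework may release X's data buffer once the forward pass finishes. Below is a minimal, self-contained sketch of the dispatch pattern (hypothetical Tensor and function names, not the Paddle API):

#include <cassert>
#include <cstdio>
#include <vector>

// Stand-in tensor: dims are always valid, data may have been released.
struct Tensor {
  std::vector<int64_t> dims;
  const float* data = nullptr;
};

// kNoNeedBufferX mirrors the new bool template parameter on
// ReduceGradKernel: when true, the kernel must not read x.data and uses a
// same-shaped variable (here the X-grad output) purely for its dims.
template <bool kNoNeedBufferX>
void ReduceGradSketch(const Tensor& x, const Tensor& out_grad,
                      Tensor* x_grad) {
  const Tensor* x_shape = &x;
  if (kNoNeedBufferX) {
    x_shape = x_grad;  // fake var with the same dims; no data access
  }
  assert(x_shape->dims == x_grad->dims);
  // ... broadcast out_grad over x_shape->dims into x_grad ...
  std::printf("broadcasting a rank-%zu grad over a rank-%zu shape\n",
              out_grad.dims.size(), x_shape->dims.size());
}

int main() {
  Tensor x{{4, 8}, nullptr};      // X's buffer already freed by the framework
  Tensor out_grad{{4}, nullptr};  // d(Out), Out = reduce over dim 1
  Tensor x_grad{{4, 8}, nullptr};
  ReduceGradSketch<true>(x, out_grad, &x_grad);  // safe: only dims of X used
  return 0;
}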
@@ -20,7 +20,6 @@ rank_loss
reduce_max
reduce_min
reduce_prod
reduce_sum
reshape
rnn_memory_helper
sequence_softmax
......
@@ -61,6 +61,8 @@ class ReduceMeanDoubleGradMaker : public framework::GradOpDescMakerBase {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ReduceMeanGradNoNeedBufferVarInference,
"X");
} // namespace operators
} // namespace paddle
@@ -73,7 +75,8 @@ class __reduce_meanMaker__ : public ops::ReduceOpMaker {
REGISTER_OPERATOR(reduce_mean, ops::ReduceOp, __reduce_meanMaker__,
ops::ReduceMeanOpGradDescMaker);
REGISTER_OPERATOR(reduce_mean_grad, ops::ReduceGradOp,
ops::ReduceMeanDoubleGradMaker);
ops::ReduceMeanDoubleGradMaker,
ops::ReduceMeanGradNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL(reduce_mean,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
float, ops::MeanFunctor>,
@@ -83,12 +86,13 @@ REGISTER_OP_CPU_KERNEL(reduce_mean,
int, ops::MeanFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::MeanFunctor>);
REGISTER_OP_CPU_KERNEL(reduce_mean_grad,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
float, ops::MeanGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
double, ops::MeanGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
int, ops::MeanGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::MeanGradFunctor>);
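// The trailing "true" template argument enables kNoNeedBufferX, so this
// grad kernel never reads X's data buffer.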
template <typename T>
using CPUReduceMeanGradKernel =
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext, T,
ops::MeanGradFunctor, true>;
REGISTER_OP_CPU_KERNEL(reduce_mean_grad, CPUReduceMeanGradKernel<float>,
CPUReduceMeanGradKernel<double>,
CPUReduceMeanGradKernel<int>,
CPUReduceMeanGradKernel<int64_t>);
@@ -15,12 +15,12 @@
// .part used to speed up nvcc compile
#include "paddle/fluid/operators/reduce_ops/reduce_mean_op.h"
REGISTER_OP_CUDA_KERNEL(
reduce_mean_grad, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
float, ops::MeanGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, double,
ops::MeanGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, int,
ops::MeanGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, int64_t,
ops::MeanGradFunctor>);
template <typename T>
using CUDAReduceMeanGradKernel =
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, T,
ops::MeanGradFunctor, true>;
REGISTER_OP_CUDA_KERNEL(reduce_mean_grad, CUDAReduceMeanGradKernel<float>,
CUDAReduceMeanGradKernel<double>,
CUDAReduceMeanGradKernel<int>,
CUDAReduceMeanGradKernel<int64_t>);
@@ -75,7 +75,8 @@ class ReduceKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T, typename Functor>
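// kNoNeedBufferX / kNoNeedBufferY mirror the op's no-need-buffer
// declarations for X / Out; when true, the kernel relies only on dims.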
template <typename DeviceContext, typename T, typename Functor,
bool kNoNeedBufferX = false, bool kNoNeedBufferY = false>
class ReduceGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
@@ -88,6 +89,17 @@ class ReduceGradKernel : public framework::OpKernel<T> {
auto* output = context.Output<Tensor>(framework::GradVarName("X"));
output->mutable_data<T>(context.GetPlace());
// NOTE: EigenTensor::From() uses tensor->data().
// If an op declares NoNeedBufferVarsInference, the corresponding
// kNoNeedBufferX or kNoNeedBufferY should be set to true, and a fake
// variable with the same dims is used in place of the real buffer.
if (kNoNeedBufferX) {
input0 = output;
}
if (kNoNeedBufferY) {
input1 = input2;
}
// NOTE(dengkaipeng): Out is unnecessary in some reduce kernels and is
// not set as an Input in the grad maker, so Out_grad is used in its place.
if (!input1) input1 = input2;
@@ -220,6 +232,14 @@ class ReduceGradOp : public framework::OperatorWithKernel {
ctx->ShareLoD("X", /*->*/ x_grad_name);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
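// X may now be a no-need-buffer variable without an initialized buffer,
// so infer the kernel's data type from Out@GRAD instead.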
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
ctx.GetPlace());
}
};
class ReduceOpMaker : public framework::OpProtoAndCheckerMaker {
......
@@ -13,8 +13,47 @@
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
#include <memory>
#include <string>
namespace paddle {
namespace operators {
// NOTE: Input(Out) is unnecessary in reduce_sum_grad, and Input(X) needs no
// buffer
class ReduceSumOpGradDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("reduce_sum_grad");
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
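// Input("Out") is deliberately not set; the grad kernel falls back to
// Out@GRAD when Out is absent (see the NOTE above).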
op->SetAttrMap(Attrs());
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
return op;
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ReduceSumGradNoNeedBufferVarInference,
"X");
} // namespace operators
} // namespace paddle
class ReduceSumOpMaker : public ops::ReduceOpMaker {
protected:
virtual std::string GetName() const { return "reduce_sum"; }
virtual std::string GetOpType() const { return "Reduce reduce_sum"; }
};
REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, ReduceSumOpMaker,
ops::ReduceSumOpGradDescMaker);
REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp,
ops::ReduceSumGradNoNeedBufferVarInference);
REGISTER_REDUCE_OP(reduce_sum);
REGISTER_OP_CPU_KERNEL(
reduce_sum, ops::ReduceKernel<paddle::platform::CPUDeviceContext, float,
ops::SumFunctor>,
@@ -23,13 +62,13 @@ REGISTER_OP_CPU_KERNEL(
ops::ReduceKernel<paddle::platform::CPUDeviceContext, int, ops::SumFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, int64_t,
ops::SumFunctor>);
REGISTER_OP_CPU_KERNEL(
reduce_sum_grad,
ops::ReduceSumGradKernel<paddle::platform::CPUDeviceContext, float,
ops::SumGradFunctor>,
ops::ReduceSumGradKernel<paddle::platform::CPUDeviceContext, double,
ops::SumGradFunctor>,
ops::ReduceSumGradKernel<paddle::platform::CPUDeviceContext, int,
ops::SumGradFunctor>,
ops::ReduceSumGradKernel<paddle::platform::CPUDeviceContext, int64_t,
ops::SumGradFunctor>);
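// As with reduce_mean_grad, the trailing "true" enables kNoNeedBufferX.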
template <typename T>
using CPUReduceSumGradKernel =
ops::ReduceSumGradKernel<paddle::platform::CPUDeviceContext, T,
ops::SumGradFunctor, true>;
REGISTER_OP_CPU_KERNEL(reduce_sum_grad, CPUReduceSumGradKernel<float>,
CPUReduceSumGradKernel<double>,
CPUReduceSumGradKernel<int>,
CPUReduceSumGradKernel<int64_t>);
@@ -22,7 +22,8 @@ namespace paddle {
namespace operators {
// A for loop is used to speed up Eigen broadcast; it is 4 times faster
// than broadcast.
template <typename DeviceContext, typename T, typename Functor>
template <typename DeviceContext, typename T, typename Functor,
bool kNoNeedBufferX = false>
class ReduceSumGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
@@ -72,7 +73,7 @@ class ReduceSumGradKernel : public framework::OpKernel<T> {
}
// default use Eigen broadcast
ReduceGradKernel<DeviceContext, T, Functor> kernel;
ReduceGradKernel<DeviceContext, T, Functor, kNoNeedBufferX> kernel;
kernel.Compute(context);
}
};
......
@@ -15,12 +15,12 @@
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
REGISTER_OP_CUDA_KERNEL(
reduce_sum_grad, ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
float, ops::SumGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, double,
ops::SumGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, int,
ops::SumGradFunctor>,
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, int64_t,
ops::SumGradFunctor>);
template <typename T>
using CUDAReduceSumGradKernel =
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, T,
ops::SumGradFunctor, true>;
REGISTER_OP_CUDA_KERNEL(reduce_sum_grad, CUDAReduceSumGradKernel<float>,
CUDAReduceSumGradKernel<double>,
CUDAReduceSumGradKernel<int>,
CUDAReduceSumGradKernel<int64_t>);