未验证 提交 1c25c88a 编写于 作者: Z Zeng Jinle 提交者: GitHub

refine memory usage of some operators, test=develop (#19700)

上级 87f13f75
...@@ -163,7 +163,8 @@ class ExpandGradOp : public framework::OperatorWithKernel { ...@@ -163,7 +163,8 @@ class ExpandGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
ctx.device_context()); ctx.device_context());
} }
...@@ -195,13 +196,16 @@ class ExpandGradOpDescMaker : public framework::SingleGradOpDescMaker { ...@@ -195,13 +196,16 @@ class ExpandGradOpDescMaker : public framework::SingleGradOpDescMaker {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ExpandGradNoNeedBufVarsInferer, "X");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(expand, ops::ExpandOp, ops::ExpandOpMaker, REGISTER_OPERATOR(expand, ops::ExpandOp, ops::ExpandOpMaker,
ops::ExpandGradOpDescMaker); ops::ExpandGradOpDescMaker);
REGISTER_OPERATOR(expand_grad, ops::ExpandGradOp); REGISTER_OPERATOR(expand_grad, ops::ExpandGradOp,
ops::ExpandGradNoNeedBufVarsInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
expand, ops::ExpandKernel<paddle::platform::CPUDeviceContext, float>, expand, ops::ExpandKernel<paddle::platform::CPUDeviceContext, float>,
ops::ExpandKernel<paddle::platform::CPUDeviceContext, double>, ops::ExpandKernel<paddle::platform::CPUDeviceContext, double>,
......
...@@ -186,7 +186,6 @@ class ExpandGradKernel : public framework::OpKernel<T> { ...@@ -186,7 +186,6 @@ class ExpandGradKernel : public framework::OpKernel<T> {
"reduce dimensions."); "reduce dimensions.");
auto* in0 = context.Input<Tensor>(framework::GradVarName("Out")); auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X")); auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
auto x = EigenVector<T>::Flatten(*(context.Input<Tensor>("X")));
out0->mutable_data<T>(context.GetPlace()); out0->mutable_data<T>(context.GetPlace());
auto x_grad = EigenVector<T>::Flatten(*out0); auto x_grad = EigenVector<T>::Flatten(*out0);
Eigen::DSizes<int, Dims / MAX_RANK_SUPPORTED + 1> reshape_dims; Eigen::DSizes<int, Dims / MAX_RANK_SUPPORTED + 1> reshape_dims;
...@@ -200,7 +199,9 @@ class ExpandGradKernel : public framework::OpKernel<T> { ...@@ -200,7 +199,9 @@ class ExpandGradKernel : public framework::OpKernel<T> {
auto out_grad = EigenVector<T>::Flatten(*in0); auto out_grad = EigenVector<T>::Flatten(*in0);
x_grad.device( x_grad.device(
*context.template device_context<DeviceContext>().eigen_device()) = *context.template device_context<DeviceContext>().eigen_device()) =
out_grad.reshape(reshape_dims).sum(reduce_dims).reshape(x.dimensions()); out_grad.reshape(reshape_dims)
.sum(reduce_dims)
.reshape(x_grad.dimensions());
} }
}; };
......
...@@ -85,7 +85,7 @@ class ROIAlignGradOp : public framework::OperatorWithKernel { ...@@ -85,7 +85,7 @@ class ROIAlignGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(), return framework::OpKernelType(ctx.Input<framework::Tensor>("ROIs")->type(),
ctx.device_context()); ctx.device_context());
} }
}; };
...@@ -167,13 +167,16 @@ class ROIAlignGradDescMaker : public framework::SingleGradOpDescMaker { ...@@ -167,13 +167,16 @@ class ROIAlignGradDescMaker : public framework::SingleGradOpDescMaker {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(RoiAlignGradNoNeedBufVarsInferer, "X");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(roi_align, ops::ROIAlignOp, ops::ROIAlignOpMaker, REGISTER_OPERATOR(roi_align, ops::ROIAlignOp, ops::ROIAlignOpMaker,
ops::ROIAlignGradDescMaker); ops::ROIAlignGradDescMaker);
REGISTER_OPERATOR(roi_align_grad, ops::ROIAlignGradOp); REGISTER_OPERATOR(roi_align_grad, ops::ROIAlignGradOp,
ops::RoiAlignGradNoNeedBufVarsInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
roi_align, roi_align,
ops::CPUROIAlignOpKernel<paddle::platform::CPUDeviceContext, float>, ops::CPUROIAlignOpKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -168,6 +168,12 @@ class SigmoidCrossEntropyWithLogitsGradOpDescMaker ...@@ -168,6 +168,12 @@ class SigmoidCrossEntropyWithLogitsGradOpDescMaker
} }
}; };
DECLARE_INPLACE_OP_INFERER(SigmoidCrossEntropyWithLogitsInplaceInferer,
{"X", "Out"});
DECLARE_INPLACE_OP_INFERER(SigmoidCrossEntropyWithLogitsGradInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -175,9 +181,11 @@ namespace ops = paddle::operators; ...@@ -175,9 +181,11 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(sigmoid_cross_entropy_with_logits, REGISTER_OPERATOR(sigmoid_cross_entropy_with_logits,
ops::SigmoidCrossEntropyWithLogitsOp, ops::SigmoidCrossEntropyWithLogitsOp,
ops::SigmoidCrossEntropyWithLogitsOpMaker, ops::SigmoidCrossEntropyWithLogitsOpMaker,
ops::SigmoidCrossEntropyWithLogitsGradOpDescMaker); ops::SigmoidCrossEntropyWithLogitsGradOpDescMaker,
ops::SigmoidCrossEntropyWithLogitsInplaceInferer);
REGISTER_OPERATOR(sigmoid_cross_entropy_with_logits_grad, REGISTER_OPERATOR(sigmoid_cross_entropy_with_logits_grad,
ops::SigmoidCrossEntropyWithLogitsGradOp); ops::SigmoidCrossEntropyWithLogitsGradOp,
ops::SigmoidCrossEntropyWithLogitsGradInplaceInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
sigmoid_cross_entropy_with_logits, sigmoid_cross_entropy_with_logits,
ops::SigmoidCrossEntropyWithLogitsKernel<paddle::platform::CPUDeviceContext, ops::SigmoidCrossEntropyWithLogitsKernel<paddle::platform::CPUDeviceContext,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册