未验证 提交 1c25c88a 编写于 作者: Z Zeng Jinle 提交者: GitHub

refine memory usage of some operators, test=develop (#19700)

上级 87f13f75
......@@ -163,8 +163,9 @@ class ExpandGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.device_context());
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
......@@ -195,13 +196,16 @@ class ExpandGradOpDescMaker : public framework::SingleGradOpDescMaker {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ExpandGradNoNeedBufVarsInferer, "X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(expand, ops::ExpandOp, ops::ExpandOpMaker,
ops::ExpandGradOpDescMaker);
REGISTER_OPERATOR(expand_grad, ops::ExpandGradOp);
REGISTER_OPERATOR(expand_grad, ops::ExpandGradOp,
ops::ExpandGradNoNeedBufVarsInferer);
REGISTER_OP_CPU_KERNEL(
expand, ops::ExpandKernel<paddle::platform::CPUDeviceContext, float>,
ops::ExpandKernel<paddle::platform::CPUDeviceContext, double>,
......
......@@ -186,7 +186,6 @@ class ExpandGradKernel : public framework::OpKernel<T> {
"reduce dimensions.");
auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
auto x = EigenVector<T>::Flatten(*(context.Input<Tensor>("X")));
out0->mutable_data<T>(context.GetPlace());
auto x_grad = EigenVector<T>::Flatten(*out0);
Eigen::DSizes<int, Dims / MAX_RANK_SUPPORTED + 1> reshape_dims;
......@@ -200,7 +199,9 @@ class ExpandGradKernel : public framework::OpKernel<T> {
auto out_grad = EigenVector<T>::Flatten(*in0);
x_grad.device(
*context.template device_context<DeviceContext>().eigen_device()) =
out_grad.reshape(reshape_dims).sum(reduce_dims).reshape(x.dimensions());
out_grad.reshape(reshape_dims)
.sum(reduce_dims)
.reshape(x_grad.dimensions());
}
};
......
......@@ -85,7 +85,7 @@ class ROIAlignGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
return framework::OpKernelType(ctx.Input<framework::Tensor>("ROIs")->type(),
ctx.device_context());
}
};
......@@ -167,13 +167,16 @@ class ROIAlignGradDescMaker : public framework::SingleGradOpDescMaker {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(RoiAlignGradNoNeedBufVarsInferer, "X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(roi_align, ops::ROIAlignOp, ops::ROIAlignOpMaker,
ops::ROIAlignGradDescMaker);
REGISTER_OPERATOR(roi_align_grad, ops::ROIAlignGradOp);
REGISTER_OPERATOR(roi_align_grad, ops::ROIAlignGradOp,
ops::RoiAlignGradNoNeedBufVarsInferer);
REGISTER_OP_CPU_KERNEL(
roi_align,
ops::CPUROIAlignOpKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -168,6 +168,12 @@ class SigmoidCrossEntropyWithLogitsGradOpDescMaker
}
};
DECLARE_INPLACE_OP_INFERER(SigmoidCrossEntropyWithLogitsInplaceInferer,
{"X", "Out"});
DECLARE_INPLACE_OP_INFERER(SigmoidCrossEntropyWithLogitsGradInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
} // namespace operators
} // namespace paddle
......@@ -175,9 +181,11 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(sigmoid_cross_entropy_with_logits,
ops::SigmoidCrossEntropyWithLogitsOp,
ops::SigmoidCrossEntropyWithLogitsOpMaker,
ops::SigmoidCrossEntropyWithLogitsGradOpDescMaker);
ops::SigmoidCrossEntropyWithLogitsGradOpDescMaker,
ops::SigmoidCrossEntropyWithLogitsInplaceInferer);
REGISTER_OPERATOR(sigmoid_cross_entropy_with_logits_grad,
ops::SigmoidCrossEntropyWithLogitsGradOp);
ops::SigmoidCrossEntropyWithLogitsGradOp,
ops::SigmoidCrossEntropyWithLogitsGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(
sigmoid_cross_entropy_with_logits,
ops::SigmoidCrossEntropyWithLogitsKernel<paddle::platform::CPUDeviceContext,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册