From cc157d599008c6faca23aa1a3b5270f6ff64541e Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Tue, 24 Sep 2019 10:18:20 +0800 Subject: [PATCH] add inplace to assign op, test=develop (#19927) --- .../framework/no_need_buffer_vars_inference.h | 19 ++++++++++--------- paddle/fluid/operators/assign_op.cc | 4 +++- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/framework/no_need_buffer_vars_inference.h b/paddle/fluid/framework/no_need_buffer_vars_inference.h index 2c93365984..a63575611b 100644 --- a/paddle/fluid/framework/no_need_buffer_vars_inference.h +++ b/paddle/fluid/framework/no_need_buffer_vars_inference.h @@ -45,15 +45,16 @@ class NoNeedBufferVarsInference { const AttributeMap &attrs_; }; -#define DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(class_type, ...) \ - class class_type : public ::paddle::framework::NoNeedBufferVarsInference { \ - public: \ - using ::paddle::framework::NoNeedBufferVarsInference:: \ - NoNeedBufferVarsInference; \ - \ - std::unordered_set operator()() const override { \ - return {__VA_ARGS__}; \ - } \ +#define DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(class_type, ...) \ + class class_type final \ + : public ::paddle::framework::NoNeedBufferVarsInference { \ + public: \ + using ::paddle::framework::NoNeedBufferVarsInference:: \ + NoNeedBufferVarsInference; \ + \ + std::unordered_set operator()() const final { \ + return {__VA_ARGS__}; \ + } \ } } // namespace framework diff --git a/paddle/fluid/operators/assign_op.cc b/paddle/fluid/operators/assign_op.cc index 871dfe6734..ff423778c5 100644 --- a/paddle/fluid/operators/assign_op.cc +++ b/paddle/fluid/operators/assign_op.cc @@ -144,12 +144,14 @@ class AssignGradMaker : public framework::SingleGradOpDescMaker { } }; +DECLARE_INPLACE_OP_INFERER(AssignOpInplaceInferer, {"X", "Out"}); + } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(assign, ops::AssignOp, ops::AssignGradMaker, - ops::AssignOpProtoMaker); + ops::AssignOpProtoMaker, ops::AssignOpInplaceInferer); REGISTER_OP_CPU_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double, ops::AssignKernel, int, ops::AssignKernel, int64_t, ops::AssignKernel); -- GitLab