未验证 提交 cc157d59 编写于 作者: Z Zeng Jinle 提交者: GitHub

add inplace to assign op, test=develop (#19927)

上级 55ce6969
...@@ -46,12 +46,13 @@ class NoNeedBufferVarsInference { ...@@ -46,12 +46,13 @@ class NoNeedBufferVarsInference {
}; };
#define DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(class_type, ...) \ #define DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(class_type, ...) \
class class_type : public ::paddle::framework::NoNeedBufferVarsInference { \ class class_type final \
: public ::paddle::framework::NoNeedBufferVarsInference { \
public: \ public: \
using ::paddle::framework::NoNeedBufferVarsInference:: \ using ::paddle::framework::NoNeedBufferVarsInference:: \
NoNeedBufferVarsInference; \ NoNeedBufferVarsInference; \
\ \
std::unordered_set<std::string> operator()() const override { \ std::unordered_set<std::string> operator()() const final { \
return {__VA_ARGS__}; \ return {__VA_ARGS__}; \
} \ } \
} }
......
...@@ -144,12 +144,14 @@ class AssignGradMaker : public framework::SingleGradOpDescMaker { ...@@ -144,12 +144,14 @@ class AssignGradMaker : public framework::SingleGradOpDescMaker {
} }
}; };
DECLARE_INPLACE_OP_INFERER(AssignOpInplaceInferer, {"X", "Out"});
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(assign, ops::AssignOp, ops::AssignGradMaker, REGISTER_OPERATOR(assign, ops::AssignOp, ops::AssignGradMaker,
ops::AssignOpProtoMaker); ops::AssignOpProtoMaker, ops::AssignOpInplaceInferer);
REGISTER_OP_CPU_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double, REGISTER_OP_CPU_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
ops::AssignKernel, int, ops::AssignKernel, ops::AssignKernel, int, ops::AssignKernel,
int64_t, ops::AssignKernel); int64_t, ops::AssignKernel);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册