diff --git a/paddle/fluid/operators/elementwise_add_op.cc b/paddle/fluid/operators/elementwise_add_op.cc index e9068fcd50ba9309a37939788ca8f67f1f6e25cd..4aab54f60236ecc5fa7f70e22f1553c3bfe68198 100644 --- a/paddle/fluid/operators/elementwise_add_op.cc +++ b/paddle/fluid/operators/elementwise_add_op.cc @@ -29,8 +29,11 @@ class ElementwiseAddOpMaker : public ElementwiseOpMaker { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(elementwise_add, ops::ElementwiseOp, ops::ElementwiseAddOpMaker, - elementwise_add_grad, ops::ElementwiseOpGrad); +REGISTER_OPERATOR(elementwise_add, ops::ElementwiseOp, + ops::ElementwiseAddOpMaker, ops::ElementwiseOpInferVarType, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(elementwise_add_grad, ops::ElementwiseOpGrad); + REGISTER_OP_CPU_KERNEL( elementwise_add, ops::ElementwiseAddKernel, diff --git a/paddle/fluid/operators/elementwise_op.h b/paddle/fluid/operators/elementwise_op.h index fe31bbaed44fced68b7b51dd2c2031950ec4247d..f04d8d8fd82ed2336dff9c5b88808dc32de6630a 100644 --- a/paddle/fluid/operators/elementwise_op.h +++ b/paddle/fluid/operators/elementwise_op.h @@ -41,6 +41,16 @@ class ElementwiseOp : public framework::OperatorWithKernel { } }; +class ElementwiseOpInferVarType : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc& op_desc, + framework::BlockDesc* block) const override { + auto x_var = op_desc.Input("X")[0]; + auto out_var = op_desc.Output("Out")[0]; + block->Var(out_var)->SetType(block->Var(x_var)->GetType()); + } +}; + class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker { public: ElementwiseOpMaker(OpProto* proto, OpAttrChecker* op_checker)