From 01fdf17e974b696ee19afb73b68fec83e89e4953 Mon Sep 17 00:00:00 2001 From: guosheng Date: Tue, 22 May 2018 13:15:46 +0800 Subject: [PATCH] Fix ElementwiseOpInferVarType in elementwise_op to use the default InferVarType to find var recursively --- paddle/fluid/operators/elementwise_op.h | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/paddle/fluid/operators/elementwise_op.h b/paddle/fluid/operators/elementwise_op.h index d75aa6a6094..f4cec8ad971 100644 --- a/paddle/fluid/operators/elementwise_op.h +++ b/paddle/fluid/operators/elementwise_op.h @@ -42,6 +42,18 @@ class ElementwiseOp : public framework::OperatorWithKernel { } }; +class ElementwiseOpInferVarType : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc& op_desc, + framework::BlockDesc* block) const override { + auto x_name = op_desc.Input("X")[0]; + auto out_name = op_desc.Output("Out")[0]; + auto& x = block->FindRecursiveOrCreateVar(x_name); + auto& out = block->FindRecursiveOrCreateVar(out_name); + out.SetType(x.GetType()); + } +}; + class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() final { @@ -138,5 +150,6 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { }; \ REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \ __ElemwiseOp##op_type##Maker__, \ + ::paddle::operators::ElementwiseOpInferVarType, \ ::paddle::framework::DefaultGradOpDescMaker); \ REGISTER_OPERATOR(op_type##_grad, ::paddle::operators::ElementwiseOpGrad) -- GitLab