提交 01fdf17e 编写于 作者: G guosheng

Fix ElementwiseOpInferVarType in elementwise_op to use the default...

Fix ElementwiseOpInferVarType in elementwise_op to use the default InferVarType to find var recursively
上级 f176a9cf
...@@ -42,6 +42,18 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -42,6 +42,18 @@ class ElementwiseOp : public framework::OperatorWithKernel {
} }
}; };
class ElementwiseOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto x_name = op_desc.Input("X")[0];
auto out_name = op_desc.Output("Out")[0];
auto& x = block->FindRecursiveOrCreateVar(x_name);
auto& out = block->FindRecursiveOrCreateVar(out_name);
out.SetType(x.GetType());
}
};
class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker { class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() final { void Make() final {
...@@ -138,5 +150,6 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { ...@@ -138,5 +150,6 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
}; \ }; \
REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \ REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \
__ElemwiseOp##op_type##Maker__, \ __ElemwiseOp##op_type##Maker__, \
::paddle::operators::ElementwiseOpInferVarType, \
::paddle::framework::DefaultGradOpDescMaker<true>); \ ::paddle::framework::DefaultGradOpDescMaker<true>); \
REGISTER_OPERATOR(op_type##_grad, ::paddle::operators::ElementwiseOpGrad) REGISTER_OPERATOR(op_type##_grad, ::paddle::operators::ElementwiseOpGrad)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册