提交 53d19f5b 编写于 作者: C chengduoZH

Add ElementwiseOpInferVarType

上级 0d49b921
......@@ -29,8 +29,11 @@ class ElementwiseAddOpMaker : public ElementwiseOpMaker {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(elementwise_add, ops::ElementwiseOp, ops::ElementwiseAddOpMaker,
elementwise_add_grad, ops::ElementwiseOpGrad);
REGISTER_OPERATOR(elementwise_add, ops::ElementwiseOp,
ops::ElementwiseAddOpMaker, ops::ElementwiseOpInferVarType,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_add_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_add,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -41,6 +41,16 @@ class ElementwiseOp : public framework::OperatorWithKernel {
}
};
class ElementwiseOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto x_var = op_desc.Input("X")[0];
auto out_var = op_desc.Output("Out")[0];
block->Var(out_var)->SetType(block->Var(x_var)->GetType());
}
};
class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ElementwiseOpMaker(OpProto* proto, OpAttrChecker* op_checker)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册