From 7bd16fe13b672a9744afce1fdfad0891baddcf36 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Wed, 28 Feb 2018 19:05:33 +0800 Subject: [PATCH] registry var type infer --- paddle/fluid/operators/split_selected_rows_op.cc | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/split_selected_rows_op.cc b/paddle/fluid/operators/split_selected_rows_op.cc index b0e21e01ec..e1ce3d0c1b 100644 --- a/paddle/fluid/operators/split_selected_rows_op.cc +++ b/paddle/fluid/operators/split_selected_rows_op.cc @@ -59,6 +59,16 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel { } }; +class SplitSelectedRowsOpInferVarType : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { + for (auto &out_var : op_desc.Output("Out")) { + block->Var(out_var)->SetType(framework::proto::VarType::SELECTED_ROWS); + } + } +}; + class SplitSelectedRowsGradMaker : public framework::SingleGradOpDescMaker { public: using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; @@ -80,7 +90,8 @@ class SplitSelectedRowsGradMaker : public framework::SingleGradOpDescMaker { namespace ops = paddle::operators; REGISTER_OPERATOR(split_selected_rows, ops::SplitSelectedRowsOp, ops::SplitSelectedRowsOpMaker, - ops::SplitSelectedRowsGradMaker); + ops::SplitSelectedRowsGradMaker, + ops::SplitSelectedRowsOpInferVarType); REGISTER_OP_CPU_KERNEL( split_selected_rows, ops::SplitSelectedRowsOpKernel); -- GitLab