diff --git a/paddle/fluid/operators/split_selected_rows_op.cc b/paddle/fluid/operators/split_selected_rows_op.cc index b0e21e01ecc55d666f43aee621e9b5a05b347e1d..e1ce3d0c1bf11e9a623e4e9adc8f08f5069f4d94 100644 --- a/paddle/fluid/operators/split_selected_rows_op.cc +++ b/paddle/fluid/operators/split_selected_rows_op.cc @@ -59,6 +59,16 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel { } }; +class SplitSelectedRowsOpInferVarType : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { + for (auto &out_var : op_desc.Output("Out")) { + block->Var(out_var)->SetType(framework::proto::VarType::SELECTED_ROWS); + } + } +}; + class SplitSelectedRowsGradMaker : public framework::SingleGradOpDescMaker { public: using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; @@ -80,7 +90,8 @@ class SplitSelectedRowsGradMaker : public framework::SingleGradOpDescMaker { namespace ops = paddle::operators; REGISTER_OPERATOR(split_selected_rows, ops::SplitSelectedRowsOp, ops::SplitSelectedRowsOpMaker, - ops::SplitSelectedRowsGradMaker); + ops::SplitSelectedRowsGradMaker, + ops::SplitSelectedRowsOpInferVarType); REGISTER_OP_CPU_KERNEL( split_selected_rows, ops::SplitSelectedRowsOpKernel);