提交 7bd16fe1 编写于 作者: Y Yancey1989

registry var type infer

上级 6e83c003
...@@ -59,6 +59,16 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel { ...@@ -59,6 +59,16 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel {
} }
}; };
class SplitSelectedRowsOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
for (auto &out_var : op_desc.Output("Out")) {
block->Var(out_var)->SetType(framework::proto::VarType::SELECTED_ROWS);
}
}
};
class SplitSelectedRowsGradMaker : public framework::SingleGradOpDescMaker { class SplitSelectedRowsGradMaker : public framework::SingleGradOpDescMaker {
public: public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
...@@ -80,7 +90,8 @@ class SplitSelectedRowsGradMaker : public framework::SingleGradOpDescMaker { ...@@ -80,7 +90,8 @@ class SplitSelectedRowsGradMaker : public framework::SingleGradOpDescMaker {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(split_selected_rows, ops::SplitSelectedRowsOp, REGISTER_OPERATOR(split_selected_rows, ops::SplitSelectedRowsOp,
ops::SplitSelectedRowsOpMaker, ops::SplitSelectedRowsOpMaker,
ops::SplitSelectedRowsGradMaker); ops::SplitSelectedRowsGradMaker,
ops::SplitSelectedRowsOpInferVarType);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
split_selected_rows, split_selected_rows,
ops::SplitSelectedRowsOpKernel<paddle::platform::CPUPlace, float>); ops::SplitSelectedRowsOpKernel<paddle::platform::CPUPlace, float>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册