提交 df80b6ea 编写于 作者: K ktlichkid

Added InferVarType

上级 f57efeb6
......@@ -318,7 +318,7 @@ class BeamSearchInferShape : public framework::InferShapeBase {
}
}
};
*/
class BeamSearchInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
......@@ -331,7 +331,7 @@ class BeamSearchInferVarType : public framework::VarTypeInference {
}
}
};
*/
} // namespace operators
} // namespace paddle
/*
......@@ -343,7 +343,8 @@ REGISTER_OPERATOR(beam_search, paddle::operators::BeamSearchOp,
*/
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(beam_search, ops::BeamSearchOp,
ops::BeamSearchOpMaker);
ops::BeamSearchOpMaker,
ops::BeamSearchInferVarType);
REGISTER_OP_CPU_KERNEL(
beam_search,
ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, float>,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册