提交 d060b5df 编写于 作者: K ktlichkid

Registered beam search op

上级 b94c5188
...@@ -316,9 +316,17 @@ class BeamSearchInferVarType : public framework::VarTypeInference { ...@@ -316,9 +316,17 @@ class BeamSearchInferVarType : public framework::VarTypeInference {
*/ */
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
/*
REGISTER_OPERATOR(beam_search, paddle::operators::BeamSearchOp, REGISTER_OPERATOR(beam_search, paddle::operators::BeamSearchOp,
paddle::operators::BeamSearchProtoAndCheckerMaker, paddle::operators::BeamSearchProtoAndCheckerMaker,
paddle::operators::BeamSearchInferShape, paddle::operators::BeamSearchInferShape,
paddle::operators::BeamSearchInferVarType, paddle::operators::BeamSearchInferVarType,
paddle::framework::EmptyGradOpMaker); paddle::framework::EmptyGradOpMaker);
*/
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(beam_search, ops::BeamSearchOp,
ops::BeamSearchOpMaker);
REGISTER_OP_CPU_KERNEL(
beam_search,
ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, double>);
...@@ -193,7 +193,7 @@ std::ostream& operator<<(std::ostream& os, const BeamSearch::Item& item); ...@@ -193,7 +193,7 @@ std::ostream& operator<<(std::ostream& os, const BeamSearch::Item& item);
std::string ItemToString(const BeamSearch::Item& item); std::string ItemToString(const BeamSearch::Item& item);
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class BeamSearchKernel : public framework::OpKernel<T>{ class BeamSearchOpKernel : public framework::OpKernel<T>{
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* ids_var = context.Input<framework::Tensor>("ids"); auto* ids_var = context.Input<framework::Tensor>("ids");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册