From d060b5dfac0da8c76201570a01a44c278d015ac8 Mon Sep 17 00:00:00 2001 From: ktlichkid Date: Thu, 19 Apr 2018 20:33:23 +0800 Subject: [PATCH] Registered beam search op --- paddle/fluid/operators/beam_search_op.cc | 10 +++++++++- paddle/fluid/operators/beam_search_op.h | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc index ecee017b0e..f08c71ee01 100644 --- a/paddle/fluid/operators/beam_search_op.cc +++ b/paddle/fluid/operators/beam_search_op.cc @@ -316,9 +316,17 @@ class BeamSearchInferVarType : public framework::VarTypeInference { */ } // namespace operators } // namespace paddle - +/* REGISTER_OPERATOR(beam_search, paddle::operators::BeamSearchOp, paddle::operators::BeamSearchProtoAndCheckerMaker, paddle::operators::BeamSearchInferShape, paddle::operators::BeamSearchInferVarType, paddle::framework::EmptyGradOpMaker); +*/ +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(beam_search, ops::BeamSearchOp, + ops::BeamSearchOpMaker); +REGISTER_OP_CPU_KERNEL( + beam_search, + ops::BeamSearchOpKernel, + ops::BeamSearchOpKernel); diff --git a/paddle/fluid/operators/beam_search_op.h b/paddle/fluid/operators/beam_search_op.h index 07cdfcf5ce..dfafe12425 100644 --- a/paddle/fluid/operators/beam_search_op.h +++ b/paddle/fluid/operators/beam_search_op.h @@ -193,7 +193,7 @@ std::ostream& operator<<(std::ostream& os, const BeamSearch::Item& item); std::string ItemToString(const BeamSearch::Item& item); template -class BeamSearchKernel : public framework::OpKernel{ +class BeamSearchOpKernel : public framework::OpKernel{ public: void Compute(const framework::ExecutionContext& context) const override { auto* ids_var = context.Input("ids"); -- GitLab