diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc index ecee017b0e4a72618a2756494ef50496b5e1238e..f08c71ee01674d4b3863bc3432bc35a781d28525 100644 --- a/paddle/fluid/operators/beam_search_op.cc +++ b/paddle/fluid/operators/beam_search_op.cc @@ -316,9 +316,17 @@ class BeamSearchInferVarType : public framework::VarTypeInference { */ } // namespace operators } // namespace paddle - +/* REGISTER_OPERATOR(beam_search, paddle::operators::BeamSearchOp, paddle::operators::BeamSearchProtoAndCheckerMaker, paddle::operators::BeamSearchInferShape, paddle::operators::BeamSearchInferVarType, paddle::framework::EmptyGradOpMaker); +*/ +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(beam_search, ops::BeamSearchOp, + ops::BeamSearchOpMaker); +REGISTER_OP_CPU_KERNEL( + beam_search, + ops::BeamSearchOpKernel, + ops::BeamSearchOpKernel); diff --git a/paddle/fluid/operators/beam_search_op.h b/paddle/fluid/operators/beam_search_op.h index 07cdfcf5cefa0ea9138a81b78c8fe9100eb2d8c2..dfafe1242523686f4745e34c069a0c2d78ec9bf9 100644 --- a/paddle/fluid/operators/beam_search_op.h +++ b/paddle/fluid/operators/beam_search_op.h @@ -193,7 +193,7 @@ std::ostream& operator<<(std::ostream& os, const BeamSearch::Item& item); std::string ItemToString(const BeamSearch::Item& item); template -class BeamSearchKernel : public framework::OpKernel{ +class BeamSearchOpKernel : public framework::OpKernel{ public: void Compute(const framework::ExecutionContext& context) const override { auto* ids_var = context.Input("ids");