diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc index fdab4e92f47c7c8f241d93268a73dcb8c2eb2dc6..cff097cca13f3b92c7efe4b69259fdf7c75b3760 100644 --- a/paddle/fluid/operators/beam_search_op.cc +++ b/paddle/fluid/operators/beam_search_op.cc @@ -195,10 +195,9 @@ std::string ItemToString(const BeamSearch::Item &item) { return stream.str(); } -class BeamSearchProtoAndCheckerMaker - : public framework::OpProtoAndCheckerMaker { +class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker { public: - BeamSearchProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker) + BeamSearchOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { // inputs and outputs stored in proto AddInput("pre_ids", "ids in previous step"); @@ -222,20 +221,32 @@ class BeamSearchProtoAndCheckerMaker } }; -class BeamSearchInferShape : public framework::InferShapeBase { +class BeamSearchOp : public framework::OperatorWithKernel { public: - void operator()(framework::InferShapeContext *context) const override { + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { for (const std::string &arg : std::vector({"pre_ids", "ids", "scores"})) { - PADDLE_ENFORCE(context->HasInput(arg), - "BeamSearch need input argument '%s'", arg); + PADDLE_ENFORCE(ctx->HasInput(arg), "BeamSearch need input argument '%s'", + arg); } for (const std::string &arg : std::vector({"selected_ids", "selected_scores"})) { - PADDLE_ENFORCE(context->HasOutput(arg), + PADDLE_ENFORCE(ctx->HasOutput(arg), "BeamSearch need output argument '%s'", arg); } } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + framework::OpKernelType kt = framework::OpKernelType( + framework::ToDataType( + ctx.Input("pre_ids")->type()), + platform::CPUPlace()); + return kt; + } }; class BeamSearchInferVarType : public framework::VarTypeInference { @@ -254,8 +265,13 @@ class BeamSearchInferVarType : public framework::VarTypeInference { } // namespace operators } // namespace paddle -REGISTER_OPERATOR(beam_search, paddle::operators::BeamSearchOp, - paddle::operators::BeamSearchProtoAndCheckerMaker, - paddle::operators::BeamSearchInferShape, - paddle::operators::BeamSearchInferVarType, - paddle::framework::EmptyGradOpMaker); +namespace ops = paddle::operators; + +REGISTER_OPERATOR(beam_search, ops::BeamSearchOp, ops::BeamSearchOpMaker, + ops::BeamSearchInferVarType); +REGISTER_OP_CPU_KERNEL( + beam_search, + ops::BeamSearchOpKernel, + ops::BeamSearchOpKernel, + ops::BeamSearchOpKernel, + ops::BeamSearchOpKernel); diff --git a/paddle/fluid/operators/beam_search_op.h b/paddle/fluid/operators/beam_search_op.h index 0a481a85ce6fbb582b8c0e12710455aaaac72aa1..9b51db8a45186c2a90cf8b2eb7966d0aaea04028 100644 --- a/paddle/fluid/operators/beam_search_op.h +++ b/paddle/fluid/operators/beam_search_op.h @@ -192,49 +192,29 @@ std::ostream& operator<<(std::ostream& os, const BeamSearch::Item& item); std::string ItemToString(const BeamSearch::Item& item); -class BeamSearchOp : public framework::OperatorBase { +template +class BeamSearchOpKernel : public framework::OpKernel { public: - BeamSearchOp(const std::string& type, - const framework::VariableNameMap& inputs, - const framework::VariableNameMap& outputs, - const framework::AttributeMap& attrs) - : OperatorBase(type, inputs, outputs, attrs) {} - - BeamSearchOp(const BeamSearchOp& o) - : framework::OperatorBase( - static_cast(o)) { - PADDLE_THROW("Not Implemented"); - } - - private: - void RunImpl(const framework::Scope& scope, - const platform::Place& dev_place) const override { - auto ids_var = scope.FindVar(Input("ids")); - auto scores_var = scope.FindVar(Input("scores")); - auto pre_ids_var = scope.FindVar(Input("pre_ids")); + void Compute(const framework::ExecutionContext& context) const override { + auto* ids_var = context.Input("ids"); + auto* scores_var = context.Input("scores"); + auto* pre_ids_var = context.Input("pre_ids"); PADDLE_ENFORCE_NOT_NULL(ids_var); PADDLE_ENFORCE_NOT_NULL(scores_var); PADDLE_ENFORCE_NOT_NULL(pre_ids_var); - auto& ids = ids_var->Get(); - auto& scores = scores_var->Get(); - auto& pre_ids = pre_ids_var->Get(); - size_t level = Attr("level"); - size_t beam_size = Attr("beam_size"); - int end_id = Attr("end_id"); - BeamSearch alg(ids, scores, level, beam_size, end_id); - - auto selected_ids_var = scope.FindVar(Output("selected_ids")); - auto selected_scores_var = scope.FindVar(Output("selected_scores")); + size_t level = context.Attr("level"); + size_t beam_size = context.Attr("beam_size"); + int end_id = context.Attr("end_id"); + BeamSearch alg(*ids_var, *scores_var, level, beam_size, end_id); + auto selected_ids_var = + context.Output("selected_ids"); + auto selected_scores_var = + context.Output("selected_scores"); PADDLE_ENFORCE_NOT_NULL(selected_ids_var); PADDLE_ENFORCE_NOT_NULL(selected_scores_var); - auto& selected_ids_tensor = - *selected_ids_var->GetMutable(); - auto& selected_scores_tensor = - *selected_scores_var->GetMutable(); - alg(pre_ids, &selected_ids_tensor, &selected_scores_tensor); + alg(*pre_ids_var, selected_ids_var, selected_scores_var); } }; - } // namespace operators } // namespace paddle