diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc index 5309639e9dfa2194872902270265e63ae5c3f9bd..ecee017b0e4a72618a2756494ef50496b5e1238e 100644 --- a/paddle/fluid/operators/beam_search_op.cc +++ b/paddle/fluid/operators/beam_search_op.cc @@ -235,9 +235,19 @@ class BeamSearchOp : public framework::OperatorWithKernel { static_cast(o)) { PADDLE_THROW("Not Implemented"); } + protected: void InferShape(const framework::InferShapeContext &ctx) const override { - + for (const std::string &arg : + std::vector({"pre_ids", "ids", "scores"})) { + PADDLE_ENFORCE(context->HasInput(arg), + "BeamSearch need input argument '%s'", arg); + } + for (const std::string &arg : + std::vector({"selected_ids", "selected_scores"})) { + PADDLE_ENFORCE(context->HasOutput(arg), + "BeamSearch need output argument '%s'", arg); + } } private: diff --git a/paddle/fluid/operators/beam_search_op.h b/paddle/fluid/operators/beam_search_op.h index e6221313f72f107da44e819fc29849e3f54b121b..07cdfcf5cefa0ea9138a81b78c8fe9100eb2d8c2 100644 --- a/paddle/fluid/operators/beam_search_op.h +++ b/paddle/fluid/operators/beam_search_op.h @@ -194,8 +194,65 @@ std::string ItemToString(const BeamSearch::Item& item); template class BeamSearchKernel : public framework::OpKernel{ + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* ids_var = context.Input("ids"); + auto* scores_var = context.Input("scores"); + auto* pre_ids_var = context.Input("pre_ids"); + PADDLE_ENFORCE_NOT_NULL(ids_var); + PADDLE_ENFORCE_NOT_NULL(scores_var); + PADDLE_ENFORCE_NOT_NULL(pre_ids_var); + + auto& ids = ids_var->Get(); + auto& scores = scores_var->Get(); + auto& pre_ids = pre_ids_var->Get(); + size_t level = Attr("level"); + size_t beam_size = Attr("beam_size"); + int end_id = Attr("end_id"); + BeamSearch alg(ids, scores, level, beam_size, end_id); + + auto* selected_ids_var = context.Output("selected_ids"); + auto* selected_scores_var = context.Output("selected_scores"); + PADDLE_ENFORCE_NOT_NULL(selected_ids_var); + PADDLE_ENFORCE_NOT_NULL(selected_scores_var); + auto& selected_ids_tensor = + *selected_ids_var->GetMutable(); + auto& selected_scores_tensor = + *selected_scores_var->GetMutable(); + alg(pre_ids, &selected_ids_tensor, &selected_scores_tensor); + } } +/* + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { + auto ids_var = scope.FindVar(Input("ids")); + auto scores_var = scope.FindVar(Input("scores")); + auto pre_ids_var = scope.FindVar(Input("pre_ids")); + PADDLE_ENFORCE_NOT_NULL(ids_var); + PADDLE_ENFORCE_NOT_NULL(scores_var); + PADDLE_ENFORCE_NOT_NULL(pre_ids_var); + + auto& ids = ids_var->Get(); + auto& scores = scores_var->Get(); + auto& pre_ids = pre_ids_var->Get(); + size_t level = Attr("level"); + size_t beam_size = Attr("beam_size"); + int end_id = Attr("end_id"); + BeamSearch alg(ids, scores, level, beam_size, end_id); + + auto selected_ids_var = scope.FindVar(Output("selected_ids")); + auto selected_scores_var = scope.FindVar(Output("selected_scores")); + PADDLE_ENFORCE_NOT_NULL(selected_ids_var); + PADDLE_ENFORCE_NOT_NULL(selected_scores_var); + auto& selected_ids_tensor = + *selected_ids_var->GetMutable(); + auto& selected_scores_tensor = + *selected_scores_var->GetMutable(); + alg(pre_ids, &selected_ids_tensor, &selected_scores_tensor); + } +*/ + } // namespace operators } // namespace paddle