diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc index a27d197d1c2e7356cc3adc43c8686c3498384b02..cff097cca13f3b92c7efe4b69259fdf7c75b3760 100644 --- a/paddle/fluid/operators/beam_search_op.cc +++ b/paddle/fluid/operators/beam_search_op.cc @@ -21,8 +21,6 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" -#include - namespace paddle { namespace operators { @@ -239,17 +237,14 @@ class BeamSearchOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasOutput(arg), "BeamSearch need output argument '%s'", arg); } - std::cout << "Done Infer Shape\n"; } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - std::cout << "Get Expected type 1\n"; framework::OpKernelType kt = framework::OpKernelType( framework::ToDataType( ctx.Input("pre_ids")->type()), platform::CPUPlace()); - std::cout << "Get Expected type 2\n"; return kt; } }; diff --git a/paddle/fluid/operators/beam_search_op.h b/paddle/fluid/operators/beam_search_op.h index 55bf48cb625b018a38350e15c54d87763dc782ae..97b039038d15c730a190fc6588b18a6a8d36bba4 100644 --- a/paddle/fluid/operators/beam_search_op.h +++ b/paddle/fluid/operators/beam_search_op.h @@ -23,8 +23,6 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/operator.h" -#include - namespace paddle { namespace operators { @@ -198,79 +196,25 @@ template class BeamSearchOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - std::cout << "Compute 1\n"; auto ids_var = context.Input("ids"); - std::cout << "Compute 2\n"; auto scores_var = context.Input("scores"); - std::cout << "Compute 3\n"; auto pre_ids_var = context.Input("pre_ids"); - std::cout << "Compute 4\n"; PADDLE_ENFORCE_NOT_NULL(ids_var); - std::cout << "Compute 5\n"; PADDLE_ENFORCE_NOT_NULL(scores_var); - std::cout << "Compute 6\n"; PADDLE_ENFORCE_NOT_NULL(pre_ids_var); - std::cout << "Compute 7\n"; - // auto& ids = ids_var->Get(); - // auto& scores = scores_var->Get(); - // auto& pre_ids = pre_ids_var->Get(); size_t level = context.Attr("level"); - std::cout << "Compute 8\n"; size_t beam_size = context.Attr("beam_size"); - std::cout << "Compute 9\n"; int end_id = context.Attr("end_id"); - std::cout << "Compute 10\n"; BeamSearch alg(*ids_var, *scores_var, level, beam_size, end_id); - std::cout << "Compute 11\n"; auto selected_ids_var = context.Output("selected_ids"); - std::cout << "Compute 12\n"; auto selected_scores_var = context.Output("selected_scores"); - std::cout << "Compute 13\n"; PADDLE_ENFORCE_NOT_NULL(selected_ids_var); - std::cout << "Compute 14\n"; PADDLE_ENFORCE_NOT_NULL(selected_scores_var); - std::cout << "Compute 15\n"; - // auto& selected_ids_tensor = - // *selected_ids_var->GetMutable(); - // auto& selected_scores_tensor = - // *selected_scores_var->GetMutable(); alg(*pre_ids_var, selected_ids_var, selected_scores_var); - std::cout << "Compute 16\n"; } }; - -/* - void RunImpl(const framework::Scope& scope, - const platform::Place& dev_place) const override { - auto ids_var = scope.FindVar(Input("ids")); - auto scores_var = scope.FindVar(Input("scores")); - auto pre_ids_var = scope.FindVar(Input("pre_ids")); - PADDLE_ENFORCE_NOT_NULL(ids_var); - PADDLE_ENFORCE_NOT_NULL(scores_var); - PADDLE_ENFORCE_NOT_NULL(pre_ids_var); - - auto& ids = ids_var->Get(); - auto& scores = scores_var->Get(); - auto& pre_ids = pre_ids_var->Get(); - size_t level = Attr("level"); - size_t beam_size = Attr("beam_size"); - int end_id = Attr("end_id"); - BeamSearch alg(ids, scores, level, beam_size, end_id); - - auto selected_ids_var = scope.FindVar(Output("selected_ids")); - auto selected_scores_var = scope.FindVar(Output("selected_scores")); - PADDLE_ENFORCE_NOT_NULL(selected_ids_var); - PADDLE_ENFORCE_NOT_NULL(selected_scores_var); - auto& selected_ids_tensor = - *selected_ids_var->GetMutable(); - auto& selected_scores_tensor = - *selected_scores_var->GetMutable(); - alg(pre_ids, &selected_ids_tensor, &selected_scores_tensor); - } -*/ - } // namespace operators } // namespace paddle