From 6f06b32258c97273fbb998e67180523fa71621a4 Mon Sep 17 00:00:00 2001 From: ktlichkid Date: Fri, 20 Apr 2018 15:51:45 +0800 Subject: [PATCH] Added GetExpectedKernelType and Debug message --- paddle/fluid/operators/beam_search_op.cc | 13 +++++++++ paddle/fluid/operators/beam_search_op.h | 36 ++++++++++++++++++------ 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc index f9312295b6..0499d8cbef 100644 --- a/paddle/fluid/operators/beam_search_op.cc +++ b/paddle/fluid/operators/beam_search_op.cc @@ -21,6 +21,8 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" +#include + namespace paddle { namespace operators { @@ -252,6 +254,17 @@ class BeamSearchOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasOutput(arg), "BeamSearch need output argument '%s'", arg); } + std::cout << "Done Infer Shape\n"; + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + std::cout << "Get Expected type 1\n"; + framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); + std::cout << "Get Expected type 2\n"; + kt.place_ = ctx.Input("pre_ids")->place(); + std::cout << "Get Expected type 3\n"; + return kt; } /* private: diff --git a/paddle/fluid/operators/beam_search_op.h b/paddle/fluid/operators/beam_search_op.h index 6e2e2f4daa..1487905ce8 100644 --- a/paddle/fluid/operators/beam_search_op.h +++ b/paddle/fluid/operators/beam_search_op.h @@ -23,6 +23,8 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/operator.h" +#include + namespace paddle { namespace operators { @@ -196,31 +198,47 @@ template class BeamSearchOpKernel : public framework::OpKernel{ public: void Compute(const framework::ExecutionContext& context) const override { + std::cout << "Compute 1\n"; auto ids_var = context.Input("ids"); + std::cout << "Compute 2\n"; auto scores_var = context.Input("scores"); + std::cout << "Compute 3\n"; auto pre_ids_var = context.Input("pre_ids"); + std::cout << "Compute 4\n"; PADDLE_ENFORCE_NOT_NULL(ids_var); + std::cout << "Compute 5\n"; PADDLE_ENFORCE_NOT_NULL(scores_var); + std::cout << "Compute 6\n"; PADDLE_ENFORCE_NOT_NULL(pre_ids_var); - - //auto& ids = ids_var->Get(); - //auto& scores = scores_var->Get(); - //auto& pre_ids = pre_ids_var->Get(); + std::cout << "Compute 7\n"; + // auto& ids = ids_var->Get(); + // auto& scores = scores_var->Get(); + // auto& pre_ids = pre_ids_var->Get(); size_t level = context.Attr("level"); + std::cout << "Compute 8\n"; size_t beam_size = context.Attr("beam_size"); + std::cout << "Compute 9\n"; int end_id = context.Attr("end_id"); + std::cout << "Compute 10\n"; BeamSearch alg(*ids_var, *scores_var, level, beam_size, end_id); - - auto selected_ids_var = context.Output("selected_ids"); - auto selected_scores_var = context.Output("selected_scores"); + std::cout << "Compute 11\n"; + auto selected_ids_var = + context.Output("selected_ids"); + std::cout << "Compute 12\n"; + auto selected_scores_var = + context.Output("selected_scores"); + std::cout << "Compute 13\n"; PADDLE_ENFORCE_NOT_NULL(selected_ids_var); + std::cout << "Compute 14\n"; PADDLE_ENFORCE_NOT_NULL(selected_scores_var); - //auto& selected_ids_tensor = + std::cout << "Compute 15\n"; + // auto& selected_ids_tensor = // *selected_ids_var->GetMutable(); - //auto& selected_scores_tensor = + // auto& selected_scores_tensor = // *selected_scores_var->GetMutable(); alg(*pre_ids_var, selected_ids_var, selected_scores_var); + std::cout << "Compute 16\n"; } }; -- GitLab