From df70d5f1ced44f9d79461fdd7dd2e7311f62dd4f Mon Sep 17 00:00:00 2001 From: ktlichkid Date: Fri, 20 Apr 2018 11:33:52 +0800 Subject: [PATCH] Fixed some bugs --- paddle/fluid/operators/beam_search_op.cc | 20 +++++++------ paddle/fluid/operators/beam_search_op.h | 36 ++++++++++++------------ 2 files changed, 29 insertions(+), 27 deletions(-) diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc index f08c71ee016..f9312295b67 100644 --- a/paddle/fluid/operators/beam_search_op.cc +++ b/paddle/fluid/operators/beam_search_op.cc @@ -223,33 +223,37 @@ class BeamSearchOpMaker }; class BeamSearchOp : public framework::OperatorWithKernel { + /* public: BeamSearchOp(const std::string& type, const framework::VariableNameMap& inputs, const framework::VariableNameMap& outputs, const framework::AttributeMap& attrs) - : OperatorBase(type, inputs, outputs, attrs) {} + : OperatorWithKernel(type, inputs, outputs, attrs) {} BeamSearchOp(const BeamSearchOp& o) - : framework::OperatorBase( + : framework::OperatorWithKernel( static_cast(o)) { PADDLE_THROW("Not Implemented"); } + */ + public: + using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { for (const std::string &arg : std::vector({"pre_ids", "ids", "scores"})) { - PADDLE_ENFORCE(context->HasInput(arg), + PADDLE_ENFORCE(ctx->HasInput(arg), "BeamSearch need input argument '%s'", arg); } for (const std::string &arg : std::vector({"selected_ids", "selected_scores"})) { - PADDLE_ENFORCE(context->HasOutput(arg), + PADDLE_ENFORCE(ctx->HasOutput(arg), "BeamSearch need output argument '%s'", arg); } } - +/* private: void RunImpl(const framework::Scope& scope, const platform::Place& dev_place) const override { @@ -278,9 +282,7 @@ class BeamSearchOp : public framework::OperatorWithKernel { *selected_scores_var->GetMutable(); alg(pre_ids, &selected_ids_tensor, &selected_scores_tensor); } - - public: - using framework::OperatorWithKernel::OperatorWithKernel; +*/ }; diff --git a/paddle/fluid/operators/beam_search_op.h b/paddle/fluid/operators/beam_search_op.h index dfafe124252..6e2e2f4daa7 100644 --- a/paddle/fluid/operators/beam_search_op.h +++ b/paddle/fluid/operators/beam_search_op.h @@ -196,33 +196,33 @@ template class BeamSearchOpKernel : public framework::OpKernel{ public: void Compute(const framework::ExecutionContext& context) const override { - auto* ids_var = context.Input("ids"); - auto* scores_var = context.Input("scores"); - auto* pre_ids_var = context.Input("pre_ids"); + auto ids_var = context.Input("ids"); + auto scores_var = context.Input("scores"); + auto pre_ids_var = context.Input("pre_ids"); PADDLE_ENFORCE_NOT_NULL(ids_var); PADDLE_ENFORCE_NOT_NULL(scores_var); PADDLE_ENFORCE_NOT_NULL(pre_ids_var); - auto& ids = ids_var->Get(); - auto& scores = scores_var->Get(); - auto& pre_ids = pre_ids_var->Get(); + //auto& ids = ids_var->Get(); + //auto& scores = scores_var->Get(); + //auto& pre_ids = pre_ids_var->Get(); - size_t level = Attr("level"); - size_t beam_size = Attr("beam_size"); - int end_id = Attr("end_id"); - BeamSearch alg(ids, scores, level, beam_size, end_id); + size_t level = context.Attr("level"); + size_t beam_size = context.Attr("beam_size"); + int end_id = context.Attr("end_id"); + BeamSearch alg(*ids_var, *scores_var, level, beam_size, end_id); - auto* selected_ids_var = context.Output("selected_ids"); - auto* selected_scores_var = context.Output("selected_scores"); + auto selected_ids_var = context.Output("selected_ids"); + auto selected_scores_var = context.Output("selected_scores"); PADDLE_ENFORCE_NOT_NULL(selected_ids_var); PADDLE_ENFORCE_NOT_NULL(selected_scores_var); - auto& selected_ids_tensor = - *selected_ids_var->GetMutable(); - auto& selected_scores_tensor = - *selected_scores_var->GetMutable(); - alg(pre_ids, &selected_ids_tensor, &selected_scores_tensor); + //auto& selected_ids_tensor = + // *selected_ids_var->GetMutable(); + //auto& selected_scores_tensor = + // *selected_scores_var->GetMutable(); + alg(*pre_ids_var, selected_ids_var, selected_scores_var); } -} +}; /* void RunImpl(const framework::Scope& scope, -- GitLab