From 881ea62bbf071284f89fcd2e64cd79157ccb8a53 Mon Sep 17 00:00:00 2001 From: ktlichkid Date: Thu, 19 Apr 2018 15:14:31 +0800 Subject: [PATCH] Added BeamSearchOpMaker class --- paddle/fluid/operators/beam_search_op.cc | 59 ++++++++++++++++++++++-- paddle/fluid/operators/beam_search_op.h | 52 ++------------------- 2 files changed, 59 insertions(+), 52 deletions(-) diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc index fdab4e92f47..5309639e9df 100644 --- a/paddle/fluid/operators/beam_search_op.cc +++ b/paddle/fluid/operators/beam_search_op.cc @@ -195,10 +195,10 @@ std::string ItemToString(const BeamSearch::Item &item) { return stream.str(); } -class BeamSearchProtoAndCheckerMaker +class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker { public: - BeamSearchProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker) + BeamSearchOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { // inputs and outputs stored in proto AddInput("pre_ids", "ids in previous step"); @@ -222,6 +222,59 @@ class BeamSearchProtoAndCheckerMaker } }; +class BeamSearchOp : public framework::OperatorWithKernel { + public: + BeamSearchOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + BeamSearchOp(const BeamSearchOp& o) + : framework::OperatorBase( + static_cast(o)) { + PADDLE_THROW("Not Implemented"); + } + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + + } + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { + auto ids_var = scope.FindVar(Input("ids")); + auto scores_var = scope.FindVar(Input("scores")); + auto pre_ids_var = scope.FindVar(Input("pre_ids")); + PADDLE_ENFORCE_NOT_NULL(ids_var); + PADDLE_ENFORCE_NOT_NULL(scores_var); + PADDLE_ENFORCE_NOT_NULL(pre_ids_var); + + auto& ids = ids_var->Get(); + auto& scores = scores_var->Get(); + auto& pre_ids = pre_ids_var->Get(); + size_t level = Attr("level"); + size_t beam_size = Attr("beam_size"); + int end_id = Attr("end_id"); + BeamSearch alg(ids, scores, level, beam_size, end_id); + + auto selected_ids_var = scope.FindVar(Output("selected_ids")); + auto selected_scores_var = scope.FindVar(Output("selected_scores")); + PADDLE_ENFORCE_NOT_NULL(selected_ids_var); + PADDLE_ENFORCE_NOT_NULL(selected_scores_var); + auto& selected_ids_tensor = + *selected_ids_var->GetMutable(); + auto& selected_scores_tensor = + *selected_scores_var->GetMutable(); + alg(pre_ids, &selected_ids_tensor, &selected_scores_tensor); + } + + public: + using framework::OperatorWithKernel::OperatorWithKernel; +}; + + +/* class BeamSearchInferShape : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext *context) const override { @@ -250,7 +303,7 @@ class BeamSearchInferVarType : public framework::VarTypeInference { } } }; - +*/ } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/beam_search_op.h b/paddle/fluid/operators/beam_search_op.h index 11ca9b15c59..e6221313f72 100644 --- a/paddle/fluid/operators/beam_search_op.h +++ b/paddle/fluid/operators/beam_search_op.h @@ -192,56 +192,10 @@ std::ostream& operator<<(std::ostream& os, const BeamSearch::Item& item); std::string ItemToString(const BeamSearch::Item& item); -class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker{ - public: - MulOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker){ - } -} - -class BeamSearchOp : public framework::OperatorBase { - public: - BeamSearchOp(const std::string& type, - const framework::VariableNameMap& inputs, - const framework::VariableNameMap& outputs, - const framework::AttributeMap& attrs) - : OperatorBase(type, inputs, outputs, attrs) {} - - BeamSearchOp(const BeamSearchOp& o) - : framework::OperatorBase( - static_cast(o)) { - PADDLE_THROW("Not Implemented"); - } +template +class BeamSearchKernel : public framework::OpKernel{ - private: - void RunImpl(const framework::Scope& scope, - const platform::Place& dev_place) const override { - auto ids_var = scope.FindVar(Input("ids")); - auto scores_var = scope.FindVar(Input("scores")); - auto pre_ids_var = scope.FindVar(Input("pre_ids")); - PADDLE_ENFORCE_NOT_NULL(ids_var); - PADDLE_ENFORCE_NOT_NULL(scores_var); - PADDLE_ENFORCE_NOT_NULL(pre_ids_var); - - auto& ids = ids_var->Get(); - auto& scores = scores_var->Get(); - auto& pre_ids = pre_ids_var->Get(); - size_t level = Attr("level"); - size_t beam_size = Attr("beam_size"); - int end_id = Attr("end_id"); - BeamSearch alg(ids, scores, level, beam_size, end_id); - - auto selected_ids_var = scope.FindVar(Output("selected_ids")); - auto selected_scores_var = scope.FindVar(Output("selected_scores")); - PADDLE_ENFORCE_NOT_NULL(selected_ids_var); - PADDLE_ENFORCE_NOT_NULL(selected_scores_var); - auto& selected_ids_tensor = - *selected_ids_var->GetMutable(); - auto& selected_scores_tensor = - *selected_scores_var->GetMutable(); - alg(pre_ids, &selected_ids_tensor, &selected_scores_tensor); - } -}; +} } // namespace operators } // namespace paddle -- GitLab