From 294b58a9bae76366b6e4117a6ae8dc44e4311ad2 Mon Sep 17 00:00:00 2001 From: ktlichkid Date: Mon, 23 Apr 2018 17:22:05 +0800 Subject: [PATCH] Changed registered type --- paddle/fluid/operators/beam_search_op.cc | 93 ++++-------------------- paddle/fluid/operators/beam_search_op.h | 2 +- 2 files changed, 14 insertions(+), 81 deletions(-) diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc index b0e284a2603..c1ff262169a 100644 --- a/paddle/fluid/operators/beam_search_op.cc +++ b/paddle/fluid/operators/beam_search_op.cc @@ -197,8 +197,7 @@ std::string ItemToString(const BeamSearch::Item &item) { return stream.str(); } -class BeamSearchOpMaker - : public framework::OpProtoAndCheckerMaker { +class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker { public: BeamSearchOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { @@ -225,29 +224,15 @@ class BeamSearchOpMaker }; class BeamSearchOp : public framework::OperatorWithKernel { - /* - public: - BeamSearchOp(const std::string& type, - const framework::VariableNameMap& inputs, - const framework::VariableNameMap& outputs, - const framework::AttributeMap& attrs) - : OperatorWithKernel(type, inputs, outputs, attrs) {} - - BeamSearchOp(const BeamSearchOp& o) - : framework::OperatorWithKernel( - static_cast(o)) { - PADDLE_THROW("Not Implemented"); - } - */ public: using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContext* ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { for (const std::string &arg : std::vector({"pre_ids", "ids", "scores"})) { - PADDLE_ENFORCE(ctx->HasInput(arg), - "BeamSearch need input argument '%s'", arg); + PADDLE_ENFORCE(ctx->HasInput(arg), "BeamSearch need input argument '%s'", + arg); } for (const std::string &arg : std::vector({"selected_ids", "selected_scores"})) { @@ -263,62 +248,13 @@ class BeamSearchOp : public framework::OperatorWithKernel { framework::OpKernelType kt = framework::OpKernelType( framework::ToDataType( ctx.Input("pre_ids")->type()), - platform::CPUPlace()); + platform::CPUPlace()); std::cout << "Get Expected type 2\n"; - // kt.place_ = ctx.Input("pre_ids")->place(); - // std::cout << "Get Expected type 3\n"; return kt; } -/* - private: - void RunImpl(const framework::Scope& scope, - const platform::Place& dev_place) const override { - auto ids_var = scope.FindVar(Input("ids")); - auto scores_var = scope.FindVar(Input("scores")); - auto pre_ids_var = scope.FindVar(Input("pre_ids")); - PADDLE_ENFORCE_NOT_NULL(ids_var); - PADDLE_ENFORCE_NOT_NULL(scores_var); - PADDLE_ENFORCE_NOT_NULL(pre_ids_var); - - auto& ids = ids_var->Get(); - auto& scores = scores_var->Get(); - auto& pre_ids = pre_ids_var->Get(); - size_t level = Attr("level"); - size_t beam_size = Attr("beam_size"); - int end_id = Attr("end_id"); - BeamSearch alg(ids, scores, level, beam_size, end_id); - - auto selected_ids_var = scope.FindVar(Output("selected_ids")); - auto selected_scores_var = scope.FindVar(Output("selected_scores")); - PADDLE_ENFORCE_NOT_NULL(selected_ids_var); - PADDLE_ENFORCE_NOT_NULL(selected_scores_var); - auto& selected_ids_tensor = - *selected_ids_var->GetMutable(); - auto& selected_scores_tensor = - *selected_scores_var->GetMutable(); - alg(pre_ids, &selected_ids_tensor, &selected_scores_tensor); - } -*/ }; -/* -class BeamSearchInferShape : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *context) const override { - for (const std::string &arg : - std::vector({"pre_ids", "ids", "scores"})) { - PADDLE_ENFORCE(context->HasInput(arg), - "BeamSearch need input argument '%s'", arg); - } - for (const std::string &arg : - std::vector({"selected_ids", "selected_scores"})) { - PADDLE_ENFORCE(context->HasOutput(arg), - "BeamSearch need output argument '%s'", arg); - } - } -}; -*/ class BeamSearchInferVarType : public framework::VarTypeInference { public: void operator()(const framework::OpDesc &op_desc, @@ -334,18 +270,15 @@ class BeamSearchInferVarType : public framework::VarTypeInference { } // namespace operators } // namespace paddle -/* -REGISTER_OPERATOR(beam_search, paddle::operators::BeamSearchOp, - paddle::operators::BeamSearchProtoAndCheckerMaker, - paddle::operators::BeamSearchInferShape, - paddle::operators::BeamSearchInferVarType, - paddle::framework::EmptyGradOpMaker); -*/ + + namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(beam_search, ops::BeamSearchOp, - ops::BeamSearchOpMaker, - ops::BeamSearchInferVarType); + +REGISTER_OPERATOR(beam_search, ops::BeamSearchOp, ops::BeamSearchOpMaker, + ops::BeamSearchInferVarType); REGISTER_OP_CPU_KERNEL( beam_search, ops::BeamSearchOpKernel, - ops::BeamSearchOpKernel); + ops::BeamSearchOpKernel, + ops::BeamSearchOpKernel, + ops::BeamSearchOpKernel); diff --git a/paddle/fluid/operators/beam_search_op.h b/paddle/fluid/operators/beam_search_op.h index 1487905ce80..55bf48cb625 100644 --- a/paddle/fluid/operators/beam_search_op.h +++ b/paddle/fluid/operators/beam_search_op.h @@ -195,7 +195,7 @@ std::ostream& operator<<(std::ostream& os, const BeamSearch::Item& item); std::string ItemToString(const BeamSearch::Item& item); template -class BeamSearchOpKernel : public framework::OpKernel{ +class BeamSearchOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { std::cout << "Compute 1\n"; -- GitLab