From 12db9f3c5f4b9e10d954f24446a098d52f394799 Mon Sep 17 00:00:00 2001 From: superjomn Date: Mon, 22 Apr 2019 16:43:31 +0800 Subject: [PATCH] make the predictor works from a faked model --- paddle/fluid/lite/api/CMakeLists.txt | 2 +- paddle/fluid/lite/api/cxx_api.h | 25 +++++++- paddle/fluid/lite/api/cxx_api_test.cc | 36 ++++------- paddle/fluid/lite/core/mir/CMakeLists.txt | 2 +- .../fluid/lite/core/mir/io_complement_pass.cc | 4 +- paddle/fluid/lite/core/mir/ssa_graph.h | 27 ++++++-- paddle/fluid/lite/core/op_executor_test.cc | 18 +++--- paddle/fluid/lite/core/op_lite.h | 3 + paddle/fluid/lite/core/optimizer.cc | 16 +++++ paddle/fluid/lite/core/optimizer.h | 12 +++- paddle/fluid/lite/core/program.h | 28 ++++++--- paddle/fluid/lite/core/type_system.cc | 39 ++++++++++++ paddle/fluid/lite/core/type_system.h | 9 +++ paddle/fluid/lite/kernels/host/CMakeLists.txt | 4 +- .../fluid/lite/kernels/host/feed_compute.cc | 9 ++- .../fluid/lite/kernels/host/fetch_compute.cc | 50 +++++++++++++++ .../fluid/lite/kernels/host/scale_compute.cc | 4 ++ paddle/fluid/lite/operators/CMakeLists.txt | 2 + paddle/fluid/lite/operators/fc_op.h | 2 + paddle/fluid/lite/operators/feed_op.cc | 4 ++ paddle/fluid/lite/operators/fetch_op.cc | 61 +++++++++++++++++++ paddle/fluid/lite/operators/io_copy_op.h | 2 + paddle/fluid/lite/operators/mul_op.h | 1 + paddle/fluid/lite/operators/op_params.h | 14 +++-- paddle/fluid/lite/operators/relu_op.h | 1 + paddle/fluid/lite/operators/scale_op.cc | 2 + 26 files changed, 315 insertions(+), 62 deletions(-) create mode 100644 paddle/fluid/lite/kernels/host/fetch_compute.cc create mode 100644 paddle/fluid/lite/operators/fetch_op.cc diff --git a/paddle/fluid/lite/api/CMakeLists.txt b/paddle/fluid/lite/api/CMakeLists.txt index 06e49d363bb..c2c40c36bcc 100644 --- a/paddle/fluid/lite/api/CMakeLists.txt +++ b/paddle/fluid/lite/api/CMakeLists.txt @@ -1,3 +1,3 @@ -cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite) +cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite) cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite) diff --git a/paddle/fluid/lite/api/cxx_api.h b/paddle/fluid/lite/api/cxx_api.h index a5c34fad37a..218f24c7a3c 100644 --- a/paddle/fluid/lite/api/cxx_api.h +++ b/paddle/fluid/lite/api/cxx_api.h @@ -17,6 +17,7 @@ #include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/optimizer.h" #include "paddle/fluid/lite/core/program.h" +#include "paddle/fluid/lite/core/types.h" #include "paddle/fluid/lite/model_parser/model_parser.h" namespace paddle { @@ -30,7 +31,6 @@ class Predictor { void Build(const std::string& model_path, const std::vector& valid_places) { - CHECK(!scope_.get()) << "duplicate build found"; framework::proto::ProgramDesc prog; LoadModel(model_path, scope_.get(), &prog); framework::ProgramDesc prog_desc(prog); @@ -38,10 +38,31 @@ class Predictor { Program program(prog_desc, scope_, valid_places); Optimizer optimizer; - optimizer.Run(std::move(program), valid_places); + core::KernelPickFactor factor; + factor.ConsiderTarget(); + optimizer.Run(std::move(program), valid_places, factor); program_ = optimizer.GenRuntimeProgram(); } + // Get offset-th col of feed. + Tensor* GetInput(size_t offset) { + auto* _feed_list = program_->exec_scope()->FindVar("feed"); + CHECK(_feed_list) << "no feed variable in exec_scope"; + auto* feed_list = _feed_list->GetMutable>(); + if (offset >= feed_list->size()) { + feed_list->resize(offset + 1); + } + return &feed_list->at(offset); + } + + const Tensor* GetOutput(size_t offset) { + auto* _fetch_list = program_->exec_scope()->FindVar("fetch"); + CHECK(_fetch_list) << "no fatch variable in exec_scope"; + auto fetch_list = _fetch_list->Get>(); + CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow"; + return &fetch_list.at(offset); + } + void Run() { program_->Run(); } private: diff --git a/paddle/fluid/lite/api/cxx_api_test.cc b/paddle/fluid/lite/api/cxx_api_test.cc index 3f8a15170a0..515ef9c682e 100644 --- a/paddle/fluid/lite/api/cxx_api_test.cc +++ b/paddle/fluid/lite/api/cxx_api_test.cc @@ -14,36 +14,22 @@ #include "paddle/fluid/lite/api/cxx_api.h" #include +#include "paddle/fluid/lite/core/mir/passes.h" #include "paddle/fluid/lite/core/op_executor.h" #include "paddle/fluid/lite/core/op_registry.h" namespace paddle { namespace lite { -TEST(CXXApi, raw) { - Scope scope; - framework::proto::ProgramDesc prog; - LoadModel("/home/chunwei/project2/models/model2", &scope, &prog); - framework::ProgramDesc prog_desc(prog); - - lite::Executor executor(&scope, {Place{TARGET(kHost), PRECISION(kFloat)}}); - - auto x = scope.Var("a")->GetMutable(); - x->Resize({100, 100}); - x->mutable_data(); - - executor.PrepareWorkspace(prog_desc); - executor.Build(prog_desc); - executor.Run(); -} - TEST(CXXApi, test) { lite::Predictor predictor; predictor.Build("/home/chunwei/project2/models/model2", {Place{TARGET(kHost), PRECISION(kFloat)}}); - auto* x = predictor.GetInputTensor("a"); - x->Resize({100, 200}); - x->mutable_data(); + + auto* input_tensor = predictor.GetInput(0); + input_tensor->Resize({100, 100}); + input_tensor->mutable_data(); + predictor.Run(); } } // namespace lite @@ -52,6 +38,10 @@ TEST(CXXApi, test) { USE_LITE_OP(mul); USE_LITE_OP(fc); USE_LITE_OP(scale); -USE_LITE_KERNEL(fc, kHost, kFloat); -USE_LITE_KERNEL(mul, kHost, kFloat); -USE_LITE_KERNEL(scale, kHost, kFloat); +USE_LITE_OP(feed); +USE_LITE_OP(fetch); +USE_LITE_KERNEL(fc, kHost, kFloat, def); +USE_LITE_KERNEL(mul, kHost, kFloat, def); +USE_LITE_KERNEL(scale, kHost, kFloat, def); +USE_LITE_KERNEL(feed, kHost, kFloat, def); +USE_LITE_KERNEL(fetch, kHost, kFloat, def); diff --git a/paddle/fluid/lite/core/mir/CMakeLists.txt b/paddle/fluid/lite/core/mir/CMakeLists.txt index f003470ffcf..1b6e980927c 100644 --- a/paddle/fluid/lite/core/mir/CMakeLists.txt +++ b/paddle/fluid/lite/core/mir/CMakeLists.txt @@ -1,7 +1,7 @@ cc_library(mir_node SRCS node.cc) cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node) cc_library(mir_pass SRCS pass.cc DEPS mir_ssa_graph) -cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph) +cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir_passes) cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager) cc_library(mir_passes SRCS static_kernel_pick_pass.cc diff --git a/paddle/fluid/lite/core/mir/io_complement_pass.cc b/paddle/fluid/lite/core/mir/io_complement_pass.cc index e962c8dae4e..17bbfb948f2 100644 --- a/paddle/fluid/lite/core/mir/io_complement_pass.cc +++ b/paddle/fluid/lite/core/mir/io_complement_pass.cc @@ -36,8 +36,8 @@ void IoComplementPass::Apply(std::unique_ptr& graph) { inst.place, inst.op_type, tmp); CHECK(type) << "no param type found for " << inst.op_type << ":" << name << " " << inst.place; - if (type->tensor_place != inst.place) { - LOG(INFO) << "found IO unmatched tensor"; + if (type->tensor_place != in->AsArgument().place) { + LOG(INFO) << "found IO unmatched tensor: " << in->AsArgument().name; } } } diff --git a/paddle/fluid/lite/core/mir/ssa_graph.h b/paddle/fluid/lite/core/mir/ssa_graph.h index 2f7922cdbd2..81b1d8565ef 100644 --- a/paddle/fluid/lite/core/mir/ssa_graph.h +++ b/paddle/fluid/lite/core/mir/ssa_graph.h @@ -36,7 +36,7 @@ class SSAGraph : GraphBase { // @param program: the op program // @param valid_places: the valid places user set for the system. void Build(const Program &program, const std::vector &valid_places) { - // create inputs + // create temporary nodes. for (const auto &name : program.tmp_vars) { node_storage_.emplace_back(); auto &new_node = node_storage_.back(); @@ -45,20 +45,33 @@ class SSAGraph : GraphBase { arguments_[name] = &new_node; } + // create weight nodes. + for (const auto &name : program.weights) { + node_storage_.emplace_back(); + auto &new_node = node_storage_.back(); + auto &arg = new_node.AsArgument(); + arg.name = name; + arguments_[name] = &new_node; + } + for (auto &op : program.ops) { node_storage_.emplace_back(); // TODO(Superjomn) remove one valid_places here. op->SetValidPlaces(valid_places); auto &new_node = node_storage_.back(); - node_storage_.back().AsInstruct( - op->op_type_, op->CreateKernels(valid_places), op, op->op_info()); + auto kernels = op->CreateKernels(valid_places); + for (auto &kernel : kernels) { + op->AttachKernel(kernel.get()); + } + node_storage_.back().AsInstruct(op->op_type_, std::move(kernels), op, + op->op_info()); CHECK(new_node.inlinks.empty()) << "duplicate Build found"; CHECK(new_node.outlinks.empty()) << "duplicate Build found"; // collect inputs and outputs for (const std::string &name : op->op_info()->input_names()) { - auto *arg = arguments_.at(name); + auto *arg = Argument(name); new_node.inlinks.push_back(arg); arg->outlinks.push_back(&new_node); } @@ -79,6 +92,12 @@ class SSAGraph : GraphBase { MarkArgumentWeights(program); } + mir::Node *Argument(const std::string &name) { + auto it = arguments_.find(name); + CHECK(it != arguments_.end()) << "no argument called " << name; + return it->second; + } + std::vector InstructTopologicalOrder(); // The inputs of the graph. diff --git a/paddle/fluid/lite/core/op_executor_test.cc b/paddle/fluid/lite/core/op_executor_test.cc index 650d2bb1e7f..51912b363a8 100644 --- a/paddle/fluid/lite/core/op_executor_test.cc +++ b/paddle/fluid/lite/core/op_executor_test.cc @@ -20,12 +20,9 @@ namespace paddle { namespace lite { TEST(executor, test) { - std::vector valid_places{ - OpLite::Place{TARGET(kHost), PRECISION(kFloat)}}; + std::vector valid_places{Place{TARGET(kHost), PRECISION(kFloat)}}; - Scope scope; - - Executor executor(&scope, valid_places); + auto scope = std::make_shared(); framework::ProgramDesc program; program.MutableBlock(0)->Var("x"); @@ -42,19 +39,18 @@ TEST(executor, test) { op_desc.SetAttr("in_num_col_dims", static_cast(1)); program.Flush(); - auto* w = scope.Var("w")->GetMutable(); + auto* w = scope->Var("w")->GetMutable(); w->Resize({20, 20}); - auto* x = scope.Var("x")->GetMutable(); + auto* x = scope->Var("x")->GetMutable(); x->Resize({1, 10, 20}); - auto* bias = scope.Var("bias")->GetMutable(); + auto* bias = scope->Var("bias")->GetMutable(); bias->Resize({1, 20}); bias->mutable_data(); w->mutable_data(); x->mutable_data(); - executor.PrepareWorkspace(program); - executor.Build(program); + lite::Executor executor(program, scope, valid_places); executor.Run(); } @@ -62,4 +58,4 @@ TEST(executor, test) { } // namespace paddle USE_LITE_OP(fc); -USE_LITE_KERNEL(fc, kHost, kFloat); +USE_LITE_KERNEL(fc, kHost, kFloat, def); diff --git a/paddle/fluid/lite/core/op_lite.h b/paddle/fluid/lite/core/op_lite.h index 61013667123..f754780ad52 100644 --- a/paddle/fluid/lite/core/op_lite.h +++ b/paddle/fluid/lite/core/op_lite.h @@ -101,6 +101,9 @@ class OpLite : public Registry { virtual bool AttachImpl(const framework::OpDesc &opdesc, lite::Scope *scope) = 0; + // Assign op param to kernel. + virtual void AttachKernel(KernelBase *kernel) = 0; + // Specify the kernel to run by default. This will specify the value of // `kernel_place_`. virtual void StaticPickKernel(const std::vector &valid_targets) { diff --git a/paddle/fluid/lite/core/optimizer.cc b/paddle/fluid/lite/core/optimizer.cc index ce71e4de2b8..cd80f786793 100644 --- a/paddle/fluid/lite/core/optimizer.cc +++ b/paddle/fluid/lite/core/optimizer.cc @@ -11,3 +11,19 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +#include "paddle/fluid/lite/core/optimizer.h" +#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h" + +namespace paddle { +namespace lite { + +void Optimizer::SpecifyKernelPickTactic(core::KernelPickFactor factor) { + auto* pass = mir::PassManager::Global().LookUp( + "static_kernel_pick_pass"); + CHECK(pass); + + *pass->mutable_kernel_pick_factors() = factor; +} +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/optimizer.h b/paddle/fluid/lite/core/optimizer.h index d6e170b775b..7bd6a2476bd 100644 --- a/paddle/fluid/lite/core/optimizer.h +++ b/paddle/fluid/lite/core/optimizer.h @@ -19,6 +19,7 @@ #include "paddle/fluid/lite/core/mir/pass_manager.h" #include "paddle/fluid/lite/core/mir/ssa_graph.h" #include "paddle/fluid/lite/core/program.h" +#include "paddle/fluid/lite/core/types.h" namespace paddle { namespace lite { @@ -30,11 +31,14 @@ namespace lite { class Optimizer { public: void Run(Program&& program, const std::vector& valid_places, + core::KernelPickFactor kernel_pick_factor, const std::vector& passes = {}) { CHECK(!graph_) << "duplicate optimize found"; graph_.reset(new mir::SSAGraph); graph_->Build(program, valid_places); + SpecifyKernelPickTactic(kernel_pick_factor); RunPasses(); + exec_scope_ = program.exec_scope; } // Generate a new program based on the mir graph. @@ -42,7 +46,10 @@ class Optimizer { std::unique_ptr res; auto pass = mir::PassManager::Global().LookUp( "generate_program_pass"); - return pass->GenProgram(); + auto program = pass->GenProgram(); + CHECK(exec_scope_); + program->set_exec_scope(exec_scope_); + return program; } // Generate C++ code which combines the inference program, model and weights. @@ -54,6 +61,8 @@ class Optimizer { } protected: + void SpecifyKernelPickTactic(core::KernelPickFactor factor); + // Run the default passes registered in the PassManager. void RunPasses() { mir::PassManager::Global().Run(graph_); } @@ -62,6 +71,7 @@ class Optimizer { private: std::unique_ptr graph_; + lite::Scope* exec_scope_{}; }; } // namespace lite diff --git a/paddle/fluid/lite/core/program.h b/paddle/fluid/lite/core/program.h index e19ecc0f569..8046a3473fc 100644 --- a/paddle/fluid/lite/core/program.h +++ b/paddle/fluid/lite/core/program.h @@ -35,23 +35,22 @@ struct Program { std::list> ops; // the scope to run the kernels, NOTE not the root scope. std::shared_ptr scope; + std::vector valid_places; // Runtime scope. lite::Scope* exec_scope{}; + const framework::ProgramDesc desc; explicit Program(const std::shared_ptr& root) { scope = root; } Program(const framework::ProgramDesc& desc, const std::shared_ptr& root, - const std::vector& valid_places) { - scope = root; + const std::vector& valid_places) + : scope(root), valid_places(valid_places), desc(desc) { PrepareWorkspace(desc); Build(desc, valid_places); } std::unique_ptr Clone() const { - std::unique_ptr res(new Program(scope)); - res->tmp_vars = tmp_vars; - res->weights = weights; - res->ops = ops; + std::unique_ptr res(new Program(desc, scope, valid_places)); return res; } @@ -64,7 +63,7 @@ struct Program { // Create operators. for (auto* op_desc : program.Block(0).AllOps()) { auto op_type = op_desc->Type(); - if (op_type == "feed" || op_type == "fetch") continue; + // if (op_type == "feed" || op_type == "fetch") continue; LOG(INFO) << "create Op [" << op_type << "]"; ops.emplace_back(LiteOpRegistry::Global().Create(op_type)); // pick initial kernel @@ -77,11 +76,22 @@ struct Program { void PrepareWorkspace(const framework::ProgramDesc& program) { CHECK(!exec_scope) << "Duplicate PrepareWorkspace found"; exec_scope = &scope->NewScope(); + // Create Feed and Fetch var. + scope->Var("feed")->GetMutable>(); + scope->Var("fetch")->GetMutable>(); + tmp_vars.push_back("feed"); + tmp_vars.push_back("fetch"); for (auto var_desc : program.Block(0).AllVars()) { if (!var_desc->Persistable()) { + LOG(INFO) << "get tmp var " << var_desc->Name(); + tmp_vars.push_back(var_desc->Name()); auto* var = exec_scope->Var(var_desc->Name()); LOG(INFO) << "create tmp var " << var_desc->Name() << " " << var; + } else { + if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") continue; + LOG(INFO) << "get weight var " << var_desc->Name(); + weights.push_back(var_desc->Name()); } } } @@ -118,11 +128,15 @@ class RuntimeProgram { } } + void set_exec_scope(lite::Scope* x) { exec_scope_ = x; } + lite::Scope* exec_scope() { return exec_scope_; } + size_t num_instructions() const { return instructions_.size(); } private: RuntimeProgram(const RuntimeProgram&) = delete; std::vector instructions_; + lite::Scope* exec_scope_{}; }; } // namespace lite diff --git a/paddle/fluid/lite/core/type_system.cc b/paddle/fluid/lite/core/type_system.cc index 95e4db7f105..b396a4e147c 100644 --- a/paddle/fluid/lite/core/type_system.cc +++ b/paddle/fluid/lite/core/type_system.cc @@ -48,6 +48,45 @@ const Type* Type::Get(TargetType target) { DataLayoutType::kNCHW>(); } +template +TensorListAnyTy* GetTensorListAnyTy() { + static TensorListAnyTy x(Target); + return &x; +} +template +TensorAnyTy* GetTensorAnyTy() { + static TensorAnyTy x(Target); + return &x; +} + +template <> +const Type* Type::Get(TargetType target) { + switch (target) { + case TargetType::kHost: + return GetTensorListAnyTy(); + case TargetType::kCUDA: + return GetTensorListAnyTy(); + case TargetType::kX86: + return GetTensorListAnyTy(); + default: + LOG(FATAL) << "unsupported type"; + } +} + +template <> +const Type* Type::Get(TargetType target) { + switch (target) { + case TargetType::kHost: + return GetTensorAnyTy(); + case TargetType::kCUDA: + return GetTensorAnyTy(); + case TargetType::kX86: + return GetTensorAnyTy(); + default: + LOG(FATAL) << "unsupported type"; + } +} + template <> const Type* Type::Get(TargetType target) { switch (target) { diff --git a/paddle/fluid/lite/core/type_system.h b/paddle/fluid/lite/core/type_system.h index 5fa1df0ca93..a22c4e59c6f 100644 --- a/paddle/fluid/lite/core/type_system.h +++ b/paddle/fluid/lite/core/type_system.h @@ -60,6 +60,8 @@ class DataTypeBase { // Tensor_Any represents a Tensor with any place, data, layout. It is used // in some IO kernels those doesn't care the data. Tensor_Any, + // Used by feed or fetch op. + TensorList_Any, NumTypes, // Must remains as last defined ID. }; @@ -146,6 +148,13 @@ class TensorAnyTy : public Type { : Type(ID::Tensor_Any, "TensorAny", true, target, PRECISION(kAny), DATALAYOUT(kAny)) {} }; +// A list of tensor, and no assumption on the data layout or data type. +class TensorListAnyTy : public Type { + public: + TensorListAnyTy(TargetType target) + : Type(ID::TensorList_Any, "TensorList_Any", false, target, + PRECISION(kAny), DATALAYOUT(kAny)) {} +}; class TensorFp32NCHWTy : public Type { public: TensorFp32NCHWTy(TargetType target) diff --git a/paddle/fluid/lite/kernels/host/CMakeLists.txt b/paddle/fluid/lite/kernels/host/CMakeLists.txt index 03501dce5a7..539bc04a7d9 100644 --- a/paddle/fluid/lite/kernels/host/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/host/CMakeLists.txt @@ -3,13 +3,15 @@ cc_library(relu_compute_host SRCS relu_compute.cc DEPS ${lite_kernel_deps}) cc_library(mul_compute_host SRCS mul_compute.cc DEPS ${lite_kernel_deps}) cc_library(scale_compute_host SRCS scale_compute.cc DEPS ${lite_kernel_deps}) cc_library(feed_compute_host SRCS feed_compute.cc DEPS ${lite_kernel_deps}) +cc_library(fetch_compute_host SRCS fetch_compute.cc DEPS ${lite_kernel_deps}) cc_library(host_kernels DEPS + feed_compute_host + fetch_compute_host fc_compute_host relu_compute_host mul_compute_host scale_compute_host - feed_compute_host DEPS ${lite_kernel_deps} ) diff --git a/paddle/fluid/lite/kernels/host/feed_compute.cc b/paddle/fluid/lite/kernels/host/feed_compute.cc index 342d5d55573..3f8015614bf 100644 --- a/paddle/fluid/lite/kernels/host/feed_compute.cc +++ b/paddle/fluid/lite/kernels/host/feed_compute.cc @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/type_system.h" @@ -26,9 +25,9 @@ class FeedCompute : public OpKernel { using param_t = operators::FeedParam; void Run() override { - auto &theparam = Param(); - const Tensor &feed_item = theparam.feed_list->at(theparam.col); - theparam.out->CopyDataFrom(feed_item); + auto ¶m = Param(); + const Tensor &feed_item = param.feed_list->at(param.col); + param.out->CopyDataFrom(feed_item); } }; @@ -39,7 +38,7 @@ class FeedCompute : public OpKernel { REGISTER_LITE_KERNEL(feed, kHost, kFloat, paddle::lite::kernels::host::FeedCompute, def) - .BindInput("X", {paddle::lite::Type::Get( + .BindInput("X", {paddle::lite::Type::Get( TARGET(kHost))}) .BindOutput("Out", {paddle::lite::Type::Get( TARGET(kHost))}) diff --git a/paddle/fluid/lite/kernels/host/fetch_compute.cc b/paddle/fluid/lite/kernels/host/fetch_compute.cc new file mode 100644 index 00000000000..4bc71266ed2 --- /dev/null +++ b/paddle/fluid/lite/kernels/host/fetch_compute.cc @@ -0,0 +1,50 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace host { + +class FetchCompute : public OpKernel { + public: + using param_t = operators::FeedParam; + + void Run() override { + auto& param = Param(); + auto* fetch_list = param.fetch_list; + if (fetch_list->size() <= static_cast(param.col)) { + fetch_list->resize(param.col + 1); + } + + auto& dst = fetch_list->at(param.col); + dst.CopyDataFrom(*param.input); + } +}; + +} // namespace host +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(fetch, kHost, kFloat, + paddle::lite::kernels::host::FetchCompute, def) + .BindInput("X", {paddle::lite::Type::Get( + TARGET(kHost))}) + .BindOutput("Out", {paddle::lite::Type::Get( + TARGET(kHost))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/host/scale_compute.cc b/paddle/fluid/lite/kernels/host/scale_compute.cc index 490792be6aa..ebf6f2ff4b3 100644 --- a/paddle/fluid/lite/kernels/host/scale_compute.cc +++ b/paddle/fluid/lite/kernels/host/scale_compute.cc @@ -52,4 +52,8 @@ class ScaleCompute : public OpKernel { REGISTER_LITE_KERNEL(scale, kHost, kFloat, paddle::lite::kernels::host::ScaleCompute, def) + .BindInput("X", {paddle::lite::Type::Get( + TARGET(kHost))}) + .BindOutput("Out", {paddle::lite::Type::Get( + TARGET(kHost))}) .Finalize(); diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt index 10122410659..004178f2f6a 100644 --- a/paddle/fluid/lite/operators/CMakeLists.txt +++ b/paddle/fluid/lite/operators/CMakeLists.txt @@ -3,6 +3,7 @@ cc_library(relu_op_lite SRCS relu_op.cc DEPS op_lite) cc_library(mul_op_lite SRCS mul_op.cc DEPS op_lite) cc_library(scale_op_lite SRCS scale_op.cc DEPS op_lite) cc_library(feed_op_lite SRCS feed_op.cc DEPS op_lite) +cc_library(fetch_op_lite SRCS fetch_op.cc DEPS op_lite) cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS op_lite) cc_library(op_params_lite SRCS op_params.cc DEPS tensor_lite) @@ -12,6 +13,7 @@ cc_library(ops_lite DEPS mul_op_lite scale_op_lite feed_op_lite + fetch_op_lite io_copy_op_lite ) diff --git a/paddle/fluid/lite/operators/fc_op.h b/paddle/fluid/lite/operators/fc_op.h index cd8e5064639..15d1693b719 100644 --- a/paddle/fluid/lite/operators/fc_op.h +++ b/paddle/fluid/lite/operators/fc_op.h @@ -67,6 +67,8 @@ class FcOpLite : public OpLite { return true; } + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "fc"; } private: diff --git a/paddle/fluid/lite/operators/feed_op.cc b/paddle/fluid/lite/operators/feed_op.cc index 45d2f640a35..041fe07f634 100644 --- a/paddle/fluid/lite/operators/feed_op.cc +++ b/paddle/fluid/lite/operators/feed_op.cc @@ -31,6 +31,9 @@ class FeedOp : public OpLite { bool InferShape() const override { return true; } + protected: + void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } + protected: bool AttachImpl(const framework::OpDesc& opdesc, lite::Scope* scope) override { @@ -48,6 +51,7 @@ class FeedOp : public OpLite { // NOTE need boost here // TODO(Superjomn) drop the need of framework::op_desc param_.col = boost::get(opdesc.GetAttr("col")); + kernel_->SetParam(param_); return true; } diff --git a/paddle/fluid/lite/operators/fetch_op.cc b/paddle/fluid/lite/operators/fetch_op.cc new file mode 100644 index 00000000000..f4e53c6699a --- /dev/null +++ b/paddle/fluid/lite/operators/fetch_op.cc @@ -0,0 +1,61 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +class FetchOp : public OpLite { + public: + explicit FetchOp(const std::string& type) : OpLite(type) {} + + bool CheckShape() const override { + CHECK_OR_FALSE(param_.input); + CHECK_OR_FALSE(param_.fetch_list); + return true; + } + + bool InferShape() const override { return true; } + void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } + + protected: + bool AttachImpl(const framework::OpDesc& opdesc, + lite::Scope* scope) override { + auto _x = opdesc.Input("X").front(); + auto* x = scope->FindVar(_x); + CHECK(x); + param_.input = &x->Get(); + + auto _out = opdesc.Output("Out").front(); + auto* out = scope->FindVar(_out); + param_.fetch_list = out->GetMutable>(); + + param_.col = boost::get(opdesc.GetAttr("col")); + return true; + } + + std::string DebugString() const override { return "fetch"; } + + private: + mutable FetchParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(fetch, paddle::lite::operators::FetchOp); diff --git a/paddle/fluid/lite/operators/io_copy_op.h b/paddle/fluid/lite/operators/io_copy_op.h index 0df41501f76..7d07f333576 100644 --- a/paddle/fluid/lite/operators/io_copy_op.h +++ b/paddle/fluid/lite/operators/io_copy_op.h @@ -28,6 +28,8 @@ class IoCopyOp : public OpLite { bool Run() override; std::string DebugString() const override; + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + protected: bool AttachImpl(const framework::OpDesc &opdesc, lite::Scope *scope) override; diff --git a/paddle/fluid/lite/operators/mul_op.h b/paddle/fluid/lite/operators/mul_op.h index effab7b8cbd..36770cb4d56 100644 --- a/paddle/fluid/lite/operators/mul_op.h +++ b/paddle/fluid/lite/operators/mul_op.h @@ -36,6 +36,7 @@ class MulOpLite : public OpLite { bool InferShape() const override; + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } // TODO(Superjomn) replace framework::OpDesc with a lite one. bool AttachImpl(const framework::OpDesc &op_desc, lite::Scope *scope) override { diff --git a/paddle/fluid/lite/operators/op_params.h b/paddle/fluid/lite/operators/op_params.h index 2207e2a093e..8f1d007d85e 100644 --- a/paddle/fluid/lite/operators/op_params.h +++ b/paddle/fluid/lite/operators/op_params.h @@ -25,8 +25,14 @@ namespace lite { namespace operators { struct FeedParam { - const std::vector* feed_list; - Tensor* out; + const std::vector* feed_list{}; + Tensor* out{}; + int col; +}; + +struct FetchParam { + const Tensor* input{}; + std::vector* fetch_list{}; int col; }; @@ -69,8 +75,8 @@ struct IoCopyParam { Tensor* y{}; }; -using param_t = - variant; +using param_t = variant; } // namespace operators } // namespace lite diff --git a/paddle/fluid/lite/operators/relu_op.h b/paddle/fluid/lite/operators/relu_op.h index 3a73b45a273..8fa311373f0 100644 --- a/paddle/fluid/lite/operators/relu_op.h +++ b/paddle/fluid/lite/operators/relu_op.h @@ -34,6 +34,7 @@ class ReluOp : public OpLite { bool AttachImpl(const framework::OpDesc &opdesc, lite::Scope *scope) override; + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } std::string DebugString() const override { return "tanh"; } private: diff --git a/paddle/fluid/lite/operators/scale_op.cc b/paddle/fluid/lite/operators/scale_op.cc index eb99fec3f46..42a6a588914 100644 --- a/paddle/fluid/lite/operators/scale_op.cc +++ b/paddle/fluid/lite/operators/scale_op.cc @@ -43,6 +43,8 @@ class ScaleOp : public OpLite { return true; } + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + // TODO(Superjomn) replace framework::OpDesc with a lite one. bool AttachImpl(const framework::OpDesc &op_desc, lite::Scope *scope) override { -- GitLab