From 0245a2dd7b86f88cfdb3522f6509041bb4ab4f1b Mon Sep 17 00:00:00 2001
From: Superjomn
Date: Sun, 28 Apr 2019 11:13:47 +0800
Subject: [PATCH] add variable inference pass tester and code clean

---
 paddle/fluid/lite/core/kernel.h                    |  5 +-
 paddle/fluid/lite/core/mir/CMakeLists.txt          | 11 +++
 .../fluid/lite/core/mir/io_complement_pass.cc      | 36 +--------
 .../fluid/lite/core/mir/io_complement_pass.h       |  4 -
 paddle/fluid/lite/core/mir/pass_registry.h         |  2 +-
 paddle/fluid/lite/core/mir/ssa_graph.cc            | 11 +--
 paddle/fluid/lite/core/mir/ssa_graph_test.cc       |  2 +-
 .../lite/core/mir/static_kernel_pick_pass.cc       |  3 +-
 .../lite/core/mir/static_kernel_pick_pass.h        |  8 +-
 .../core/mir/variable_place_inference_pass.h       | 25 +++---
 .../mir/variable_place_inference_pass_test.cc      | 80 +++++++++++++++++++
 paddle/fluid/lite/core/op_lite.cc                  |  2 +-
 paddle/fluid/lite/core/op_registry.cc              |  2 +-
 paddle/fluid/lite/core/op_registry.h               | 12 +--
 paddle/fluid/lite/core/optimizer.cc                | 28 -------
 paddle/fluid/lite/core/optimizer.h                 | 32 ++++++--
 paddle/fluid/lite/core/optimizer_test.cc           |  2 +-
 paddle/fluid/lite/core/program.h                   |  2 +-
 paddle/fluid/lite/core/program_fake_utils.h        | 63 +++++++++++++++
 paddle/fluid/lite/core/type_system.h               | 49 +----------
 paddle/fluid/lite/kernels/cuda/use_kernels.h       | 24 ++++++
 paddle/fluid/lite/kernels/host/use_kernels.h       | 22 +++++
 paddle/fluid/lite/model_parser/runtime.h           |  4 +-
 23 files changed, 272 insertions(+), 157 deletions(-)
 create mode 100644 paddle/fluid/lite/core/mir/variable_place_inference_pass_test.cc
 create mode 100644 paddle/fluid/lite/kernels/cuda/use_kernels.h
 create mode 100644 paddle/fluid/lite/kernels/host/use_kernels.h

diff --git a/paddle/fluid/lite/core/kernel.h b/paddle/fluid/lite/core/kernel.h
index c6acf92bb..a984a7a99 100644
--- a/paddle/fluid/lite/core/kernel.h
+++ b/paddle/fluid/lite/core/kernel.h
@@ -96,10 +96,7 @@ class KernelBase {
     return type->type;
   }
 
-  void set_alias(const std::string& x) {
-    alias_ = x;
-    LOG(INFO) << "kernel " << op_type() << " setting alias " << alias();
-  }
+  void set_alias(const std::string& x) { alias_ = x; }
   const std::string& alias() const { return alias_; }
 
   virtual Place place() const = 0;
diff --git a/paddle/fluid/lite/core/mir/CMakeLists.txt b/paddle/fluid/lite/core/mir/CMakeLists.txt
index 1fd049147..7b7bc73ef 100644
--- a/paddle/fluid/lite/core/mir/CMakeLists.txt
+++ b/paddle/fluid/lite/core/mir/CMakeLists.txt
@@ -24,3 +24,14 @@ cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
   mir_pass_manager
   program_fake_utils
   )
+cc_test(test_variable_place_inference_pass SRCS variable_place_inference_pass_test.cc DEPS
+  ops_lite
+  host_kernels
+  kernels_cuda
+  mir_passes
+  mir_pass_manager
+  optimizer_lite
+  program_fake_utils
+  target_wrapper_host
+  target_wrapper_cuda
+  )
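For orientation before the io_complement_pass diff that follows: the pass walks each instruction's inputs and, wherever a variable lives on a different target than the consuming kernel expects, splices in an io_copy instruction. A minimal sketch of that decision, with invented stand-in types rather than the real mir::Node and lite::Type API:

    #include <string>

    // Invented miniatures; the real pass inspects mir::Node and lite::Type.
    enum class Target { kHost, kCUDA };
    struct Var { std::string name; Target target; };
    struct Kernel { Target expects; };

    // The check ComplementInputs performs, in miniature: a mismatch between
    // where an input lives and where the kernel wants it means an io_copy hop.
    bool NeedsIoCopy(const Var& in, const Kernel& consumer) {
      return in.target != consumer.expects;
    }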
diff --git a/paddle/fluid/lite/core/mir/io_complement_pass.cc b/paddle/fluid/lite/core/mir/io_complement_pass.cc
index eac6208a7..d1c7f73a7 100644
--- a/paddle/fluid/lite/core/mir/io_complement_pass.cc
+++ b/paddle/fluid/lite/core/mir/io_complement_pass.cc
@@ -36,10 +36,7 @@ void IoComplementPass::Apply(std::unique_ptr<SSAGraph>& graph) {
       ComplementInputs(graph.get(), node, in);
     }
   }
-
-  // PickIoCopyKernel(graph.get());
-
-  LOG(INFO) << "\n" << Visualize(graph.get());
+  VLOG(3) << "\n" << Visualize(graph.get());
 }
 
 void IoComplementPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
@@ -96,6 +93,7 @@ void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to,
   // create Op and kernels.
   auto io_copy_op = LiteOpRegistry::Global().Create("io_copy");
+  CHECK(io_copy_op) << "create op [io_copy] failed";
   // CHECK(io_copy_op);
   // Create the new var manually.
   inst_node->AsInstruct().op->scope()->Var(io_copy_output_name);
@@ -144,36 +142,6 @@ void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to,
   graph->CheckValid();
 }
 
-void IoComplementPass::PickIoCopyKernel(SSAGraph* graph) {
-  for (auto& node : graph->mutable_nodes()) {
-    if (node.IsInstruct() && node.AsInstruct().op_type == "io_copy") {
-      auto& kernels = node.AsInstruct().valid_kernels;
-      CHECK(!kernels.empty()) << "No valid kernels found for IoCopy Op";
-      for (auto& kernel : kernels) {
-        CHECK_EQ(node.inlinks.size(), 1UL);
-        CHECK_EQ(node.outlinks.size(), 1UL);
-        auto* inty = node.inlinks.front()->AsArgument().type;
-        auto* outy = node.outlinks.front()->AsArgument().type;
-        const Type* in_arg_ty = kernel->GetInputDeclType("Input");
-        if (TypeCompatibleTo(*inty, *in_arg_ty)) {
-          const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
-          // Both the input and output type matches, remove other kernels
-          // directly.
-          if (out_arg_ty->target() == outy->target()) {
-            LOG(INFO) << "get a IOCopy kernel";
-            auto x = std::move(kernel);
-            kernels.clear();
-            kernels.emplace_back(std::move(x));
-            break;
-          }
-        }
-      }
-    }
-  }
-
-  // Check the compatibility.
-}
-
 void IoComplementPass::SetValidPlaces(const std::vector<Place>& valid_places) {
   CHECK(!valid_places.empty());
   valid_places_ = valid_places;
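The PickIoCopyKernel body deleted above selected, among the io_copy kernel candidates, the first whose declared input type was compatible with the incoming argument and whose declared output target matched the outgoing one; the io_copy_kernel_pick_pass named in the optimizer's pass list now owns that job. The same selection rule in miniature (Cand and the integer targets are invented for illustration):

    #include <vector>

    struct Cand { int in_target, out_target; };

    // Keep only the first candidate whose declared I/O targets match the
    // actual argument targets, echoing the TypeCompatibleTo + target check.
    void PickFirstMatch(std::vector<Cand>* cands, int in, int out) {
      for (const Cand& c : *cands) {
        if (c.in_target == in && c.out_target == out) {
          Cand keep = c;
          cands->clear();
          cands->push_back(keep);
          return;
        }
      }
    }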
diff --git a/paddle/fluid/lite/core/mir/io_complement_pass.h b/paddle/fluid/lite/core/mir/io_complement_pass.h
index feca16deb..b1ae18462 100644
--- a/paddle/fluid/lite/core/mir/io_complement_pass.h
+++ b/paddle/fluid/lite/core/mir/io_complement_pass.h
@@ -26,7 +26,6 @@ static void UpdateInputTo(framework::proto::OpDesc* desc,
   for (auto& item : *desc->mutable_inputs()) {
     for (auto& input : *item.mutable_arguments()) {
       if (input == from) {
-        LOG(INFO) << "** update input argument from " << from << " to " << to;
         input = to;
       }
     }
@@ -49,9 +48,6 @@ class IoComplementPass : public ProgramPass {
 
   void SetValidPlaces(const std::vector<Place>& valid_places);
 
-  // Pick the right kernel of IoCopy considering the input and output Type.
-  void PickIoCopyKernel(SSAGraph* graph);
-
   const std::vector<Place>& valid_places() const { return valid_places_; };
 
  private:
diff --git a/paddle/fluid/lite/core/mir/pass_registry.h b/paddle/fluid/lite/core/mir/pass_registry.h
index 4190f96c0..5c213169b 100644
--- a/paddle/fluid/lite/core/mir/pass_registry.h
+++ b/paddle/fluid/lite/core/mir/pass_registry.h
@@ -25,7 +25,7 @@ namespace mir {
 class PassRegistry {
  public:
   PassRegistry(const std::string& name, mir::Pass* pass) {
-    LOG(INFO) << "Registry add MIR pass " << name;
+    VLOG(2) << "Registry add MIR pass " << name;
     PassManager::Global().AddNewPass(name, pass);
   }
diff --git a/paddle/fluid/lite/core/mir/ssa_graph.cc b/paddle/fluid/lite/core/mir/ssa_graph.cc
index 51f9362e5..b808d4421 100644
--- a/paddle/fluid/lite/core/mir/ssa_graph.cc
+++ b/paddle/fluid/lite/core/mir/ssa_graph.cc
@@ -91,7 +91,9 @@ std::vector<mir::Node *> SSAGraph::InstructTopologicalOrder() {
 
 void SSAGraph::GraphCreateTmpVarNodes(const Program &program) {
   for (const auto &name : program.tmp_vars) {
-    LOG(INFO) << "create arg node " << name;
+    CHECK(!arguments_.count(name)) << "duplicate creation of temp variable: "
+                                   << name;
+    VLOG(5) << "create arg node " << name;
     node_storage_.emplace_back();
     auto &new_node = node_storage_.back();
     new_node.AsArgument(name);
@@ -102,7 +104,9 @@ void SSAGraph::GraphCreateTmpVarNodes(const Program &program) {
 void SSAGraph::GraphCreateWeightVarNodes(const Program &program) {
   // create weight nodes.
   for (const auto &name : program.weights) {
-    LOG(INFO) << "create arg node " << name;
+    CHECK(!arguments_.count(name)) << "duplicate creation of weight variable: "
+                                   << name;
+    VLOG(5) << "create arg node " << name;
     node_storage_.emplace_back();
     auto &new_node = node_storage_.back();
     new_node.AsArgument(name);
@@ -134,10 +138,8 @@ void SSAGraph::Build(const Program &program,
 
   for (auto &op : program.ops) {
     auto *op_node = GraphCreateInstructNode(program, op, valid_places);
-    LOG(INFO) << "checking op " << op->op_type_;
     for (const std::string &name : op->op_info()->input_names()) {
       auto *arg = Argument(name);
-      LOG(INFO) << "input " << name;
       CHECK(arg->IsRoleSet());
       DirectedLink(arg, op_node);
     }
@@ -145,7 +147,6 @@
       if (!arguments_.count(name)) {
        NewArgumentNode(name);
       }
-      LOG(INFO) << "output " << name;
       auto *arg = arguments_.at(name);
       CHECK(arg->IsRoleSet());
       DirectedLink(op_node, arg);
diff --git a/paddle/fluid/lite/core/mir/ssa_graph_test.cc b/paddle/fluid/lite/core/mir/ssa_graph_test.cc
index 76f595a91..2bf447593 100644
--- a/paddle/fluid/lite/core/mir/ssa_graph_test.cc
+++ b/paddle/fluid/lite/core/mir/ssa_graph_test.cc
@@ -35,7 +35,7 @@ void BuildFc(framework::ProgramDesc* desc, const std::string& x,
 }
 
 TEST(SSAGraph, test) {
-  auto program = FakeProgram();
+  auto program = ProgramFaker();
   SSAGraph graph;
   std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
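The CHECKs added to GraphCreateTmpVarNodes and GraphCreateWeightVarNodes in ssa_graph.cc above make a repeated variable name fail fast instead of silently creating a second argument node. The guard is plain set membership before insertion; a freestanding sketch:

    #include <cassert>
    #include <set>
    #include <string>

    std::set<std::string> arguments;  // stands in for SSAGraph::arguments_

    void NewArgumentNode(const std::string& name) {
      // Mirrors the added CHECK(!arguments_.count(name)) guard.
      assert(!arguments.count(name) && "duplicate creation of variable");
      arguments.insert(name);
    }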
diff --git a/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc b/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
index a85628f02..2c954be96 100644
--- a/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
+++ b/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
@@ -38,7 +38,6 @@ void StaticKernelPickPass::Apply(std::unique_ptr<SSAGraph>& graph) {
     std::vector<std::pair<size_t, std::unique_ptr<KernelBase>>> scored;
     for (auto&& kernel : instruct.valid_kernels) {
       size_t score = KernelGrade(*kernel);
-      LOG(INFO) << "kernel " << kernel->summary() << " " << score;
       scored.emplace_back(score, std::move(kernel));
     }
 
@@ -49,7 +48,7 @@ void StaticKernelPickPass::Apply(std::unique_ptr<SSAGraph>& graph) {
     // TODO(Superjomn) reconsider this.
     instruct.valid_kernels.clear();
     instruct.valid_kernels.emplace_back(std::move(scored.front().second));
-    LOG(INFO) << "pick " << instruct.valid_kernels.front()->name();
+    VLOG(2) << "pick " << instruct.valid_kernels.front()->name();
   }
 }
diff --git a/paddle/fluid/lite/core/mir/static_kernel_pick_pass.h b/paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
index 8e0aed7da..86b53ce28 100644
--- a/paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
+++ b/paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
@@ -74,10 +74,10 @@ class StaticKernelPickPass : public mir::InstructionPass {
       score += kMax / static_cast<size_t>(
                           core::KernelPickFactor::Factor::DataLayoutFirst);
     }
-    LOG(INFO) << "picker tactic " << kernel_pick_factors_;
-    LOG(INFO) << "kernel place " << kernel.place();
-    LOG(INFO) << "picker place " << place();
-    LOG(INFO) << "score " << score;
+    VLOG(4) << "picker tactic " << kernel_pick_factors_;
+    VLOG(4) << "kernel place " << kernel.place();
+    VLOG(4) << "picker place " << place();
+    VLOG(4) << "score " << score;
 
     // The data layout is not considered, for the input and output arguments
     // might have different data layout.
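KernelGrade above weights every matched factor by kMax divided by the factor's enumerator, so a higher-priority (smaller-valued) factor always outranks any combination of lower-priority ones. A self-contained illustration; the power-of-two enumerator values are an assumption, since this patch never shows the Factor definition:

    #include <cstddef>
    #include <iostream>
    #include <limits>

    // Assumed stand-in for core::KernelPickFactor::Factor.
    enum class Factor : size_t {
      TargetFirst = 1,
      PrecisionFirst = 2,
      DataLayoutFirst = 4,
    };

    int main() {
      const size_t kMax = std::numeric_limits<size_t>::max();
      const size_t target_w = kMax / static_cast<size_t>(Factor::TargetFirst);
      const size_t precision_w = kMax / static_cast<size_t>(Factor::PrecisionFirst);
      const size_t layout_w = kMax / static_cast<size_t>(Factor::DataLayoutFirst);
      // A target match alone beats precision and layout matches combined.
      std::cout << (target_w > precision_w + layout_w) << "\n";  // prints 1
    }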
diff --git a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h
index 2a24ac6e6..daa5a5bb6 100644
--- a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h
+++ b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h
@@ -51,49 +51,54 @@ class VariablePlaceInferencePass : public DebugPass {
     for (auto& node : graph->mutable_nodes()) {
       if (node.IsArgument()) {
         CHECK(node.AsArgument().type) << "node " << node.AsArgument().name
-                                      << " type not determined";
+                                      << " type not determined, " << &node;
       }
     }
   }
 
   void InferenceArgumentPlace(SSAGraph* graph) {
-    LOG(INFO) << "param-type-registry:\n" << ParamTypeRegistry::Global();
+    VLOG(3) << "param-type-registry:\n" << ParamTypeRegistry::Global();
     for (auto& x : graph->InstructTopologicalOrder()) {
       auto& inst = x->AsInstruct();
       // The IoCopyOp is a tool operator, it won't support the type inference.
       if (inst.op_type == "io_copy") continue;
       // LOG(INFO) << "- inferencing type " <<
       // deal with inputs
+      VLOG(4) << "inferencing op " << inst.op_type;
       for (auto& arg_name : inst.op_info()->input_argnames()) {
-        LOG(INFO) << "-- input arg_name " << arg_name;
+        VLOG(3) << "-- input arg_name " << arg_name;
         // check if inputs's place is set, if not set, update them with the
         // kernel's declaration.
         auto type = inst.picked_kernel().GetInputDeclType(arg_name);
         auto arg_names = inst.op_info()->input_argument().at(arg_name);
         for (auto& arg_name : arg_names) {
-          LOG(INFO) << "--- var " << arg_name;
+          VLOG(3) << "--- var " << arg_name;
           auto* node = graph->RetrieveArgument(arg_name);
           CHECK(node) << "argument " << arg_name << " not exists in the graph";
           auto& arg_node = node->AsArgument();
-          if (arg_node.type) continue;
-          arg_node.type = type;
+          if (!arg_node.type) {
+            VLOG(4) << "set type " << *type << " " << node;
+            arg_node.type = type;
+          }
         }
       }
       for (auto& arg_name : inst.op_info()->output_argnames()) {
-        LOG(INFO) << "-- output arg_name " << arg_name;
+        VLOG(3) << "-- output arg_name " << arg_name;
         auto type = inst.picked_kernel().GetOutputDeclType(arg_name);
         auto arg_names = inst.op_info()->output_argument().at(arg_name);
         // check if outputs's place is set, if not set, update them with the
         // kernel's declaration.
         for (auto& arg_name : arg_names) {
-          LOG(INFO) << "--- var " << arg_name;
+          VLOG(3) << "--- var " << arg_name;
           auto* node = graph->RetrieveArgument(arg_name);
           CHECK(node) << "argument " << arg_name << " not exists in the graph";
           auto& arg_node = node->AsArgument();
-          if (arg_node.type) continue;
-          node->AsArgument().type = type;
+          if (!arg_node.type) {
+            node->AsArgument().type = type;
+            VLOG(3) << "set type " << *type;
+          }
         }
       }
     }
   }
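One property of InferenceArgumentPlace worth calling out: a variable's type is only written while it is still unset, so in topological order the first kernel declaration naming the variable wins. A toy model of exactly that rule, seeded with declarations like the tester below produces (a host feed followed by a CUDA mul):

    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    struct Decl { std::string var, type; };

    int main() {
      // Declaration order mimics topological traversal of feed -> mul.
      std::vector<Decl> decls = {
          {"a", "host/any"},     // feed.Out declaration
          {"a", "cuda/float"},   // mul.X declaration: ignored, "a" already typed
          {"W", "cuda/float"},   // mul.Y declaration
          {"a1", "cuda/float"},  // mul.Out declaration
      };
      std::map<std::string, std::string> type;
      for (const Decl& d : decls) type.emplace(d.var, d.type);  // first set wins
      for (const auto& kv : type)
        std::cout << kv.first << " -> " << kv.second << "\n";
    }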
diff --git a/paddle/fluid/lite/core/mir/variable_place_inference_pass_test.cc b/paddle/fluid/lite/core/mir/variable_place_inference_pass_test.cc
new file mode 100644
index 000000000..394aa9ba8
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/variable_place_inference_pass_test.cc
@@ -0,0 +1,80 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/lite/core/mir/passes.h"
+#include "paddle/fluid/lite/core/optimizer.h"
+#include "paddle/fluid/lite/core/program_fake_utils.h"
+#include "paddle/fluid/lite/kernels/cuda/use_kernels.h"
+#include "paddle/fluid/lite/kernels/host/use_kernels.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+TEST(variable_place_inference_pass, test) {
+  std::shared_ptr<lite::Scope> scope(new lite::Scope);
+  ProgramFaker program_faker;
+  program_faker.AddFeed("a", 0);
+  program_faker.AddMul("a", "W", "a1");
+  program_faker.AddMul("a1", "W1", "a2");
+  program_faker.AddFetch("a2", 0);
+  program_faker.CreateVars(scope.get());
+
+  auto* desc = program_faker.program();
+
+  Optimizer optimizer;
+  std::vector<Place> places({
+      Place{
+          TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW),
+      },
+      Place{
+          TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny),
+      },
+      Place{
+          TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW),
+      },
+      Place{
+          TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kAny),
+      },
+  });
+
+  Program program(*desc, scope, places);
+
+  core::KernelPickFactor factor;
+  factor.ConsiderTarget();
+
+  std::vector<std::string> passes({
+      "static_kernel_pick_pass",        //
+      "argument_type_display_pass",     //
+      "variable_place_inference_pass",  //
+      "argument_type_display_pass",     //
+      "io_complement_pass",             //
+  });
+
+  Place preferred_place{
+      TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW),
+  };
+  optimizer.KernelPickPreferPlace(preferred_place);
+  optimizer.Run(std::move(program), places, factor, passes);
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_OP(mul);
+USE_LITE_OP(feed);
+USE_LITE_OP(fetch);
+USE_LITE_OP(io_copy);
diff --git a/paddle/fluid/lite/core/op_lite.cc b/paddle/fluid/lite/core/op_lite.cc
index ba6917c62..f093efb4e 100644
--- a/paddle/fluid/lite/core/op_lite.cc
+++ b/paddle/fluid/lite/core/op_lite.cc
@@ -35,7 +35,7 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
   }
 
   CHECK(!kernels.empty()) << "No kernel found for Op " << op_type_;
-  LOG(INFO) << "op " << op_type_ << " get " << kernels.size() << " kernels";
+  VLOG(2) << "op " << op_type_ << " get " << kernels.size() << " kernels";
   return kernels;
 }
diff --git a/paddle/fluid/lite/core/op_registry.cc b/paddle/fluid/lite/core/op_registry.cc
index b012ca3e2..4ac233ad2 100644
--- a/paddle/fluid/lite/core/op_registry.cc
+++ b/paddle/fluid/lite/core/op_registry.cc
@@ -21,7 +21,7 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
     const std::string &op_type, TargetType target, PrecisionType precision,
     DataLayoutType layout) {
   Place place{target, precision, layout};
-  LOG(INFO) << "creating " << op_type << " kernel for " << place;
+  VLOG(5) << "creating " << op_type << " kernel for " << place;
 #define CREATE_KERNEL1(target__, precision__)                      \
   switch (layout) {                                                \
     case DATALAYOUT(kNCHW):                                        \
diff --git a/paddle/fluid/lite/core/op_registry.h b/paddle/fluid/lite/core/op_registry.h
index 590ba3cac..fb656c3a4 100644
--- a/paddle/fluid/lite/core/op_registry.h
+++ b/paddle/fluid/lite/core/op_registry.h
@@ -81,9 +81,9 @@ class KernelRegistry final {
   void Register(const std::string &name,
                 typename KernelRegistryForTarget<Target, Precision>::creator_t
                     &&creator) {
-    LOG(INFO) << "register for " << TargetToStr(Target) << ":"
-              << PrecisionToStr(Precision) << "//"
-              << GetKernelOffset<Target, Precision>();
+    VLOG(3) << "register for " << TargetToStr(Target) << ":"
+            << PrecisionToStr(Precision) << "//"
+            << GetKernelOffset<Target, Precision>();
     using kernel_registor_t = KernelRegistryForTarget<Target, Precision>;
     auto &varient = registries_[GetKernelOffset<Target, Precision>()];
@@ -144,9 +144,9 @@ class KernelRegistor : public lite::Registor<KernelType> {
  public:
   KernelRegistor(const std::string &op_type, const std::string &alias)
       : Registor<KernelType>([=] {
-          LOG(INFO) << "Register kernel " << op_type << " for "
-                    << TargetToStr(target) << " " << PrecisionToStr(precision)
-                    << " " << DataLayoutToStr(layout) << " alias " << alias;
+          VLOG(3) << "Register kernel " << op_type << " for "
+                  << TargetToStr(target) << " " << PrecisionToStr(precision)
+                  << " " << DataLayoutToStr(layout) << " alias " << alias;
           KernelRegistry::Global().Register<target, precision>(
               op_type, [=]() -> std::unique_ptr<KernelBase> {
                 std::unique_ptr<KernelType> x(new KernelType);
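To connect the two op_registry changes above: registration code hands Register a factory under a target/precision key, and Create later instantiates every kernel matching a concrete place. A rough round trip against the signatures visible in this patch; MulCompute is a made-up kernel class, and the template arguments of Register are abbreviated from what the real macros pass:

    // Registration side (normally emitted by a REGISTER_LITE_KERNEL macro):
    KernelRegistry::Global().Register<TARGET(kHost), PRECISION(kFloat)>(
        "mul", [] { return std::unique_ptr<KernelBase>(new MulCompute); });

    // Lookup side, matching KernelRegistry::Create in op_registry.cc:
    auto kernels = KernelRegistry::Global().Create(
        "mul", TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));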
running pass " << pass_type; - auto* pass = mir::PassManager::Global().LookUp(pass_type); - CHECK(pass); - if (pass->name() == "io_complement_pass") { - auto* _pass = dynamic_cast(pass); - _pass->SetValidPlaces(valid_places_); - CHECK(!_pass->valid_places().empty()); - _pass->Apply(graph_); - } else { - pass->Apply(graph_); - } - } - // mir::PassManager::Global().Run(graph_); -} } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/core/optimizer.h b/paddle/fluid/lite/core/optimizer.h index 977d1ab7e..c72bd7405 100644 --- a/paddle/fluid/lite/core/optimizer.h +++ b/paddle/fluid/lite/core/optimizer.h @@ -41,8 +41,24 @@ class Optimizer { graph_.reset(new mir::SSAGraph); graph_->Build(program, valid_places); SpecifyKernelPickTactic(kernel_pick_factor); - // InitIoComplement(); - RunPasses(); + InitIoComplement(); + + if (passes.empty()) { + RunPasses(std::vector{{ + "static_kernel_pick_pass", // + "variable_place_inference_pass", // + "argument_type_display_pass", // + "io_complement_pass", // + "argument_type_display_pass", // + "variable_place_inference_pass", // + "argument_type_display_pass", // + "io_copy_kernel_pick_pass", // + "variable_place_inference_pass", // + "runtime_context_assign_pass", // + }}); + } else { + RunPasses(passes); + } exec_scope_ = program.exec_scope; } @@ -86,11 +102,15 @@ class Optimizer { protected: void SpecifyKernelPickTactic(core::KernelPickFactor factor); - // Run the default passes registered in the PassManager. - void RunPasses(); - // Specify the passes and run them. - void RunPasses(std::vector& passes); + void RunPasses(const std::vector& passes) { + for (auto& x : passes) { + LOG(INFO) << "== Running pass " << x; + auto* pass = mir::PassManager::Global().LookUp(x); + CHECK(pass); + pass->Apply(graph_); + } + } private: std::unique_ptr graph_; diff --git a/paddle/fluid/lite/core/optimizer_test.cc b/paddle/fluid/lite/core/optimizer_test.cc index a301f996a..8e8827484 100644 --- a/paddle/fluid/lite/core/optimizer_test.cc +++ b/paddle/fluid/lite/core/optimizer_test.cc @@ -25,7 +25,7 @@ namespace lite { TEST(Optimizer, test) { Optimizer optimizer; - auto program = FakeProgram(); + auto program = ProgramFaker(); std::vector places({Place{TARGET(kHost), PRECISION(kFloat)}}); auto* pick_pass = diff --git a/paddle/fluid/lite/core/program.h b/paddle/fluid/lite/core/program.h index 6f945c061..91b789819 100644 --- a/paddle/fluid/lite/core/program.h +++ b/paddle/fluid/lite/core/program.h @@ -64,7 +64,7 @@ struct Program { for (auto* op_desc : program.Block(0).AllOps()) { auto op_type = op_desc->Type(); // if (op_type == "feed" || op_type == "fetch") continue; - LOG(INFO) << "create Op [" << op_type << "]"; + VLOG(4) << "create Op [" << op_type << "]"; ops.emplace_back(LiteOpRegistry::Global().Create(op_type)); // pick initial kernel ops.back()->PickKernel(valid_places); diff --git a/paddle/fluid/lite/core/program_fake_utils.h b/paddle/fluid/lite/core/program_fake_utils.h index 30f40cd9f..b27c90c3c 100644 --- a/paddle/fluid/lite/core/program_fake_utils.h +++ b/paddle/fluid/lite/core/program_fake_utils.h @@ -71,5 +71,68 @@ Program FakeProgram() { return program; } +class ProgramFaker { + public: + ProgramFaker() {} + + framework::ProgramDesc* program() { + desc_.Flush(); + return &desc_; + } + + void CreateVars(lite::Scope* scope) { + for (auto& var : tmp_vars_) { + auto* x = scope->Var(var); + x->GetMutable(); + } + + for (auto& x : tmp_vars_) { + desc_.MutableBlock(0)->Var(x); + } + } + + void AddMul(const std::string& X, const std::string& Y, + 
diff --git a/paddle/fluid/lite/core/program_fake_utils.h b/paddle/fluid/lite/core/program_fake_utils.h
index 30f40cd9f..b27c90c3c 100644
--- a/paddle/fluid/lite/core/program_fake_utils.h
+++ b/paddle/fluid/lite/core/program_fake_utils.h
@@ -71,5 +71,68 @@ Program FakeProgram() {
   return program;
 }
 
+class ProgramFaker {
+ public:
+  ProgramFaker() {}
+
+  framework::ProgramDesc* program() {
+    desc_.Flush();
+    return &desc_;
+  }
+
+  void CreateVars(lite::Scope* scope) {
+    for (auto& var : tmp_vars_) {
+      auto* x = scope->Var(var);
+      x->GetMutable<lite::Tensor>();
+    }
+
+    for (auto& x : tmp_vars_) {
+      desc_.MutableBlock(0)->Var(x);
+    }
+  }
+
+  void AddMul(const std::string& X, const std::string& Y,
+              const std::string& out) {
+    tmp_vars_.insert(X);
+    tmp_vars_.insert(Y);
+    tmp_vars_.insert(out);
+
+    auto* block = desc_.MutableBlock(0);
+    auto* op = block->AppendOp();
+    op->SetType("mul");
+    op->SetInput("X", {X});
+    op->SetInput("Y", {Y});
+    op->SetOutput("Out", {out});
+    op->SetAttr("x_num_col_dims", 1);
+    op->SetAttr("y_num_col_dims", 1);
+  }
+
+  void AddFeed(const std::string& Out, int col) {
+    tmp_vars_.insert(Out);
+
+    auto* block = desc_.MutableBlock(0);
+    auto* op = block->AppendOp();
+    op->SetType("feed");
+    op->SetInput("X", {"feed"});
+    op->SetOutput("Out", {Out});
+    op->SetAttr("col", col);
+  }
+
+  void AddFetch(const std::string& Input, int col) {
+    tmp_vars_.insert(Input);
+    auto* block = desc_.MutableBlock(0);
+    auto* op = block->AppendOp();
+    op->SetType("fetch");
+    op->SetInput("X", {Input});
+    op->SetOutput("Out", {"fetch"});
+    op->SetAttr("col", col);
+  }
+
+ private:
+  std::set<std::string> tmp_vars_;
+  std::vector<std::string> weight_vars_;
+  framework::ProgramDesc desc_;
+};
+
 }  // namespace lite
 }  // namespace paddle
diff --git a/paddle/fluid/lite/core/type_system.h b/paddle/fluid/lite/core/type_system.h
index 33181eeac..597712665 100644
--- a/paddle/fluid/lite/core/type_system.h
+++ b/paddle/fluid/lite/core/type_system.h
@@ -142,6 +142,8 @@ class Type : public DataTypeBase {
   }
   if (other.is_tensor_) {
     os << "(";
@@ -152,47 +154,3 @@
-class TypeSystem {
- private:
-  TypeSystem() {
-    Register<Tensor>("tensor");
-  }
-
- public:
-  static TypeSystem& Global() {
-    static TypeSystem x;
-    return x;
-  }
-
-  template <typename T>
-  void Register(const std::string& type) {
-    size_t hash = typeid(T).hash_code();
-    CHECK(!types_.count(hash)) << "duplicate register type " << type
-                               << " found!";
-    types_[hash] = type;
-    names_.insert(type);
-  }
-
-  template <typename T>
-  bool Contains() const {
-    return types_.count(typeid(T).hash_code());
-  }
-
-  bool Contains(size_t hash) const { return types_.count(hash); }
-
-  bool Contains(const std::string& type) { return names_.count(type); }
-
-  std::string DebugInfo() const {
-    std::stringstream ss;
-    for (const auto& it : types_) {
-      ss << it.second << "\n";
-    }
-    return ss.str();
-  }
-
- private:
-  std::unordered_map<size_t, std::string> types_;
-  TypeSystem(const TypeSystem&) = delete;
-  std::unordered_set<std::string> names_;
-};
-
 /*
  * ParamType is used to represent a data type of a parameter for the kernel. It
  * can represent any Variable data type.
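The two use_kernels.h headers created below let a test binary pull in kernel registrations with a single include instead of repeating USE_LITE_KERNEL lines; the CUDA header guards its body with LITE_WITH_CUDA, so including it unconditionally is safe:

    // In a test's .cc file, as the new pass tester does:
    #include "paddle/fluid/lite/kernels/cuda/use_kernels.h"
    #include "paddle/fluid/lite/kernels/host/use_kernels.h"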
diff --git a/paddle/fluid/lite/kernels/cuda/use_kernels.h b/paddle/fluid/lite/kernels/cuda/use_kernels.h
new file mode 100644
index 000000000..39f6d41ad
--- /dev/null
+++ b/paddle/fluid/lite/kernels/cuda/use_kernels.h
@@ -0,0 +1,24 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/fluid/lite/core/op_registry.h"
+
+// TODO(Superjomn) make this file a library, which will simplify the compile
+// dependencies.
+#ifdef LITE_WITH_CUDA
+USE_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW, def);
+USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, host_to_device);
+USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host);
+#endif
diff --git a/paddle/fluid/lite/kernels/host/use_kernels.h b/paddle/fluid/lite/kernels/host/use_kernels.h
new file mode 100644
index 000000000..e9e9c88c6
--- /dev/null
+++ b/paddle/fluid/lite/kernels/host/use_kernels.h
@@ -0,0 +1,22 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/fluid/lite/core/op_registry.h"
+
+USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def);
+USE_LITE_KERNEL(mul, kHost, kFloat, kNCHW, def);
+USE_LITE_KERNEL(scale, kHost, kFloat, kNCHW, def);
+USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
+USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
diff --git a/paddle/fluid/lite/model_parser/runtime.h b/paddle/fluid/lite/model_parser/runtime.h
index 3a96989e1..ea87eb248 100644
--- a/paddle/fluid/lite/model_parser/runtime.h
+++ b/paddle/fluid/lite/model_parser/runtime.h
@@ -95,7 +95,7 @@ class OpDesc {
   std::string op_type;
   std::map<std::string, std::vector<std::string>> inputs;
   std::map<std::string, std::vector<std::string>> outputs;
-  std::map<...> attrs;
+  std::map<...> attrs;
 };
 
 class BlockDesc {
@@ -112,6 +112,8 @@ class BlockDesc {
 class ProgramDesc {
  public:
   void Parse(const framework::proto::ProgramDesc& desc);
+
+  BlockDesc block;
 };
--
GitLab