From eada00c254440064976ff760c322d464068cf7c2 Mon Sep 17 00:00:00 2001
From: superjomn
Date: Wed, 17 Apr 2019 19:04:54 +0800
Subject: [PATCH] init optimizer

---
 paddle/fluid/lite/core/CMakeLists.txt         |  7 +++
 paddle/fluid/lite/core/kernel.h               | 21 ++++++--
 paddle/fluid/lite/core/mir/CMakeLists.txt     |  5 +-
 .../lite/core/mir/generate_program_pass.cc    | 14 +++++
 .../lite/core/mir/generate_program_pass.h     |  3 +-
 .../fluid/lite/core/mir/io_complement_pass.cc | 15 ++++++
 .../fluid/lite/core/mir/io_complement_pass.h  |  3 +-
 paddle/fluid/lite/core/mir/pass_manager.cc    |  1 -
 paddle/fluid/lite/core/mir/pass_manager.h     | 17 ++++--
 paddle/fluid/lite/core/mir/ssa_graph_test.cc  | 54 +------------------
 .../lite/core/mir/static_kernel_pick_pass.cc  | 46 ++++++++++++++++
 .../lite/core/mir/static_kernel_pick_pass.h   | 54 ++++++++++++++++++-
 paddle/fluid/lite/core/op_registry.h          |  6 ++-
 paddle/fluid/lite/core/optimizer.h            | 13 ++---
 paddle/fluid/lite/core/types.h                | 32 +++++++++++
 15 files changed, 217 insertions(+), 74 deletions(-)

diff --git a/paddle/fluid/lite/core/CMakeLists.txt b/paddle/fluid/lite/core/CMakeLists.txt
index 6fd696d5d8c..c7e86290853 100644
--- a/paddle/fluid/lite/core/CMakeLists.txt
+++ b/paddle/fluid/lite/core/CMakeLists.txt
@@ -9,8 +9,14 @@ cc_library(op_executor_lite SRCS op_executor.cc DEPS scope_lite tensor_lite op_l
 #TODO(Superjomn) remove these dependencies from original framework
     proto_desc)
 cc_library(kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite)
+cc_library(types_lite SRCS types.cc)
 cc_library(type_system SRCS type_system.cc DEPS tensor_lite)
 cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager)
+cc_library(program_fake_utils SRCS program_fake_utils.cc DEPS mir_ssa_graph
+  scope_lite op_registry_lite proto_desc op_lite
+  ops_lite
+  host_kernels
+  )
 
 cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite)
 cc_test(test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86)
@@ -18,5 +24,6 @@ cc_test(test_op_lite SRCS op_lite_test.cc DEPS op_lite)
 cc_test(test_tensor_lite SRCS tensor_test.cc)
 cc_test(test_op_executor_lite SRCS op_executor_test.cc DEPS op_executor_lite ops_lite host_kernels)
 cc_test(test_type_system SRCS type_system_test.cc DEPS type_system)
+cc_test(test_optimizer_lite SRCS optimizer_test.cc DEPS mir_pass_manager program_fake_utils mir_passes)
 
 add_subdirectory(mir)
diff --git a/paddle/fluid/lite/core/kernel.h b/paddle/fluid/lite/core/kernel.h
index ee6890aea96..f83ae0a5222 100644
--- a/paddle/fluid/lite/core/kernel.h
+++ b/paddle/fluid/lite/core/kernel.h
@@ -15,7 +15,9 @@
 #pragma once
 #include
+#include
 #include
+#include
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/lite/core/context.h"
 #include "paddle/fluid/lite/core/target_wrapper.h"
@@ -48,17 +50,24 @@ class KernelBase {
     return param_.get();
   }
 
+  void set_op_type(const std::string& type) { op_type_ = type; }
+  const std::string& op_type() const { return op_type_; }
+
   void Torch() {}
 
   virtual TargetType target() const = 0;
   virtual PrecisionType precision() const = 0;
   virtual DataLayoutType layout() const = 0;
+  virtual std::string name() const = 0;
+
   virtual ~KernelBase() = default;
 
  protected:
   core::any_context_t context_;
   mutable operators::param_t param_;
+  // The corresponding op type.
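+  // It is set by the kernel registry when the kernel object is created (see
+  // KernelRegistor in op_registry.h below), and name() combines it with the
+  // kernel's target/precision/layout.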
+  std::string op_type_;
 };
 
 /*
@@ -73,8 +82,9 @@ struct ParamType {
   Place tensor_place{};
   const Type* type_;
 
-  ParamType() = default;
-  ParamType(size_t element_type_hash) : element_type_hash(element_type_hash) {}
+  explicit ParamType() = default;
+  explicit ParamType(size_t element_type_hash)
+      : element_type_hash(element_type_hash) {}
   ParamType(size_t element_type_hash, const Place& place)
       : element_type_hash(element_type_hash), tensor_place(place) {}
   ParamType(const Type* type) : type_(type) {}
@@ -135,7 +145,8 @@ class ParamTypeRegistry {
  * PRECISION(kFloat)});
  */
   struct NewInstance {
-    NewInstance(const std::string& kernel_type) : kernel_type_(kernel_type) {}
+    explicit NewInstance(const std::string& kernel_type)
+        : kernel_type_(kernel_type) {}
 
     NewInstance& BindInput(int offset, const ParamType& ptype) {
       ParamTypeRegistry::Global().Register(
@@ -205,6 +216,10 @@ class OpKernel : public KernelBase {
   TargetType target() const override { return Target; }
   PrecisionType precision() const override { return Precision; }
   DataLayoutType layout() const override { return DataLayout; }
+  std::string name() const override {
+    return op_type() + ":" + TargetToStr(Target) + "/" +
+           PrecisionToStr(Precision) + "/" + DataLayoutToStr(DataLayout);
+  }
 
   void Touch() {}
diff --git a/paddle/fluid/lite/core/mir/CMakeLists.txt b/paddle/fluid/lite/core/mir/CMakeLists.txt
index 5e789829abb..757e5e141fd 100644
--- a/paddle/fluid/lite/core/mir/CMakeLists.txt
+++ b/paddle/fluid/lite/core/mir/CMakeLists.txt
@@ -9,13 +9,14 @@ cc_library(mir_passes
     graph_visualize_pass.cc
     generate_program_pass.cc
     demo_pass.cc
-    DEPS mir_pass)
+    DEPS mir_pass types_lite)
 
 cc_test(test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes)
 cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
         mir_ssa_graph scope_lite op_lite
-        proto_desc ops_lite
+        ops_lite
         host_kernels
         mir_passes
         mir_pass_manager
+        program_fake_utils
         )
diff --git a/paddle/fluid/lite/core/mir/generate_program_pass.cc b/paddle/fluid/lite/core/mir/generate_program_pass.cc
index ce71e4de2b8..0be516489db 100644
--- a/paddle/fluid/lite/core/mir/generate_program_pass.cc
+++ b/paddle/fluid/lite/core/mir/generate_program_pass.cc
@@ -11,3 +11,17 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+
+#include "paddle/fluid/lite/core/mir/generate_program_pass.h"
+#include "paddle/fluid/lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+void GenerateProgramPass::Apply(std::unique_ptr<SSAGraph>& graph) {}
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(generate_program_pass,
+                  paddle::lite::mir::GenerateProgramPass);
diff --git a/paddle/fluid/lite/core/mir/generate_program_pass.h b/paddle/fluid/lite/core/mir/generate_program_pass.h
index bee6a3c8f66..bc78370b08d 100644
--- a/paddle/fluid/lite/core/mir/generate_program_pass.h
+++ b/paddle/fluid/lite/core/mir/generate_program_pass.h
@@ -24,8 +24,9 @@ namespace mir {
 * GenerateProgramPass will build the execution program for the executor from
 * a MIR graph.
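 * For example, once StaticKernelPickPass (below) has reduced every
 * instruction node to a single kernel, this pass can emit the remaining
 * (op, kernel) pairs in topological order as the executor's program.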
 */
-class GenerateProgramPass : public Pass {
+class GenerateProgramPass : public ProgramPass {
  public:
+  void Apply(std::unique_ptr<SSAGraph>& graph) override;
 };
 
 }  // namespace mir
diff --git a/paddle/fluid/lite/core/mir/io_complement_pass.cc b/paddle/fluid/lite/core/mir/io_complement_pass.cc
index ce71e4de2b8..8511a3921f5 100644
--- a/paddle/fluid/lite/core/mir/io_complement_pass.cc
+++ b/paddle/fluid/lite/core/mir/io_complement_pass.cc
@@ -11,3 +11,18 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+
+#include "paddle/fluid/lite/core/mir/io_complement_pass.h"
+#include "paddle/fluid/lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+void IoComplementPass::Apply(std::unique_ptr<SSAGraph>& graph) {}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(io_complement_pass, paddle::lite::mir::IoComplementPass);
diff --git a/paddle/fluid/lite/core/mir/io_complement_pass.h b/paddle/fluid/lite/core/mir/io_complement_pass.h
index ff071ce037e..b44fde7b450 100644
--- a/paddle/fluid/lite/core/mir/io_complement_pass.h
+++ b/paddle/fluid/lite/core/mir/io_complement_pass.h
@@ -24,8 +24,9 @@ namespace mir {
 * IoComplementPass complements the necessary instructions to enable data
 * transfer or transformation between different places.
 */
-class IoComplementPass : public Pass {
+class IoComplementPass : public ProgramPass {
  public:
+  void Apply(std::unique_ptr<SSAGraph>& graph) override;
 };
 
 }  // namespace mir
diff --git a/paddle/fluid/lite/core/mir/pass_manager.cc b/paddle/fluid/lite/core/mir/pass_manager.cc
index 6e1c0cd9f5f..508c2fd5522 100644
--- a/paddle/fluid/lite/core/mir/pass_manager.cc
+++ b/paddle/fluid/lite/core/mir/pass_manager.cc
@@ -13,7 +13,6 @@
 // limitations under the License.
#include "paddle/fluid/lite/core/mir/pass_manager.h" -#include "paddle/fluid/lite/core/mir/pass_registry.h" namespace paddle { namespace lite { diff --git a/paddle/fluid/lite/core/mir/pass_manager.h b/paddle/fluid/lite/core/mir/pass_manager.h index e2ff6549bd5..ba3fad4d9f2 100644 --- a/paddle/fluid/lite/core/mir/pass_manager.h +++ b/paddle/fluid/lite/core/mir/pass_manager.h @@ -32,16 +32,17 @@ class PassManager { PassManager(); - void Run() { + void Run(std::unique_ptr& graph) { for (auto& pass : passes_) { LOG(INFO) << "Running MIR pass " << pass->name(); - pass->Apply(graph_); + pass->Apply(graph); } } bool AddNewPass(const std::string& name, Pass* pass) { passes_.emplace_back(pass); pass_map_.emplace(name, passes_.back().get()); + passes_.back()->set_name(name); return true; } @@ -65,12 +66,18 @@ class PassManager { Pass* LookUp(const std::string& key) { auto it = pass_map_.find(key); - CHECK(it != pass_map_.end()); - return it->second; + if (it != pass_map_.end()) return it->second; + return nullptr; + } + + template + PassTy* LookUp(const std::string& key) { + auto it = pass_map_.find(key); + if (it != pass_map_.end()) return dynamic_cast(it->second); + return nullptr; } private: - std::unique_ptr graph_; std::list> passes_; std::map pass_map_; }; diff --git a/paddle/fluid/lite/core/mir/ssa_graph_test.cc b/paddle/fluid/lite/core/mir/ssa_graph_test.cc index 0db9164e7e2..76f595a913a 100644 --- a/paddle/fluid/lite/core/mir/ssa_graph_test.cc +++ b/paddle/fluid/lite/core/mir/ssa_graph_test.cc @@ -16,7 +16,9 @@ #include #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/lite/core/mir/graph_visualize_pass.h" +#include "paddle/fluid/lite/core/mir/passes.h" #include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/core/program_fake_utils.h" namespace paddle { namespace lite { @@ -32,58 +34,6 @@ void BuildFc(framework::ProgramDesc* desc, const std::string& x, fc->SetOutput("Out", {out}); } -Program FakeProgram() { - Program program; - program.scope = new lite::Scope; - - auto add_fc = [&](int id, std::string x) { - // create variables - std::string w1 = "w" + std::to_string(id); - std::string b1 = "b" + std::to_string(id); - std::string out1 = "out" + std::to_string(id); - auto w1v = program.scope->Var(w1)->GetMutable(); - auto b1v = program.scope->Var(b1)->GetMutable(); - auto out1v = program.scope->Var(out1)->GetMutable(); - - framework::OpDesc desc; - desc.SetInput("Input", {x}); - desc.SetInput("W", {w1}); - desc.SetInput("Bias", {b1}); - desc.SetOutput("Out", {out1}); - desc.SetType("fc"); - desc.SetAttr("in_num_col_dims", 1); - desc.Flush(); - - // add to input - program.tmp_vars.push_back(w1); - program.tmp_vars.push_back(b1); - - auto fc_op = LiteOpRegistry::Global().Create("fc"); - fc_op->PickKernel({Place{TARGET(kHost), PRECISION(kFloat)}}); - fc_op->Attach(desc, program.scope); - program.ops.emplace_back(std::move(fc_op)); - - w1v->Resize({100, 100}); - b1v->Resize({100, 1}); - out1v->Resize({100, 100}); - - return out1; - }; - - // x1, w1, b1 -fc-> out1 - // out1, w2, b2 -fc-> out2 - - std::string x = "x"; - program.tmp_vars.push_back(x); - auto* xv = program.scope->Var(x)->GetMutable(); - xv->Resize({100, 100}); - - for (int i = 0; i < 3; i++) { - x = add_fc(i, x); - } - return program; -} - TEST(SSAGraph, test) { auto program = FakeProgram(); SSAGraph graph; diff --git a/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc b/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc index ce71e4de2b8..8a324a6cca2 100644 --- 
+++ b/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
@@ -11,3 +11,49 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+
+#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
+#include <algorithm>
+#include <memory>
+#include <vector>
+#include "paddle/fluid/lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+bool KernelScoreCmp(const std::pair<size_t, std::unique_ptr<KernelBase>>& a,
+                    const std::pair<size_t, std::unique_ptr<KernelBase>>& b) {
+  return a.first > b.first;
+}
+
+void StaticKernelPickPass::Apply(std::unique_ptr<SSAGraph>& graph) {
+  CHECK(kernel_pick_factors_.AnyFactorConsidered())
+      << "kernel_pick_factors should be specified first";
+  CHECK(graph) << "graph not valid";
+  // Sort the kernels by the factors.
+  for (auto& node : graph->mutable_nodes()) {
+    if (!node.IsInstruct()) continue;
+    auto& instruct = node.AsInstruct();
+    std::vector<std::pair<size_t, std::unique_ptr<KernelBase>>> scored;
+    for (auto&& kernel : instruct.valid_kernels) {
+      scored.emplace_back(KernelGrade(*kernel), std::move(kernel));
+    }
+
+    std::sort(scored.begin(), scored.end(), KernelScoreCmp);
+
+    // Move the top-scoring kernel back and keep just that single best kernel.
+    // TODO(Superjomn) reconsider this.
+    instruct.valid_kernels.clear();
+    instruct.valid_kernels.emplace_back(std::move(scored.front().second));
+    LOG(INFO) << "pick " << instruct.valid_kernels.front()->name();
+  }
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(static_kernel_pick_pass,
+                  paddle::lite::mir::StaticKernelPickPass);
diff --git a/paddle/fluid/lite/core/mir/static_kernel_pick_pass.h b/paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
index becdd50dd91..9cf3ee2d439 100644
--- a/paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
+++ b/paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
@@ -14,13 +14,65 @@
 
 #pragma once
 
+#include <limits>
 #include "paddle/fluid/lite/core/mir/pass.h"
+#include "paddle/fluid/lite/core/types.h"
 
 namespace paddle {
 namespace lite {
 namespace mir {
 
-class StaticKernelPickPass : public mir::Pass {};
+/*
+ * StaticKernelPickPass is a simple strategy that picks a kernel for each
+ * operator using rules defined by the operator's developers; there are many
+ * other tactics, such as considering IO or kernel execution latency, which we
+ * will implement later.
+ *
+ * There are two arguments for this pass:
+ * - place, the target place.
+ * - kernel_pick_factors, the factors to consider in picking kernels.
+ * Set both before executing the pass.
+ */
+class StaticKernelPickPass : public mir::InstructionPass {
+ public:
+  void Apply(std::unique_ptr<SSAGraph>& graph) override;
+
+  const Place& place() const { return place_; }
+  const core::KernelPickFactor& kernel_pick_factors() const {
+    return kernel_pick_factors_;
+  }
+  core::KernelPickFactor* mutable_kernel_pick_factors() {
+    return &kernel_pick_factors_;
+  }
+
+ private:
+  // Score the kernel.
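+  // Each considered factor contributes a slice of the score range; a
+  // higher-priority factor (a smaller Factor value, e.g. TargetFirst) uses a
+  // smaller divisor and therefore contributes more, so a target match always
+  // outweighs a precision match.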
+  size_t KernelGrade(const lite::KernelBase& kernel) {
+    size_t score{};
+    const int kMax =
+        std::numeric_limits<core::KernelPickFactor::value_type>::max();
+    if (kernel_pick_factors_.IsTargetConsidered() &&
+        place().target == kernel.target()) {
+      score +=
+          kMax / static_cast<int>(core::KernelPickFactor::Factor::TargetFirst);
+    }
+    if (kernel_pick_factors_.IsPrecisionConsidered() &&
+        place().precision == kernel.precision()) {
+      score += kMax /
+               static_cast<int>(core::KernelPickFactor::Factor::PrecisionFirst);
+    }
+
+    // The data layout is not considered, for the input and output arguments
+    // might have different data layouts.
+    // TODO(Superjomn) reconsider the idea of taking the data layout as a
+    // kernel specification.
+    return score;
+  }
+
+ private:
+  core::KernelPickFactor kernel_pick_factors_;
+  Place place_;
+};
 
 }  // namespace mir
 }  // namespace lite
diff --git a/paddle/fluid/lite/core/op_registry.h b/paddle/fluid/lite/core/op_registry.h
index be85cbe8419..95d62f2fee4 100644
--- a/paddle/fluid/lite/core/op_registry.h
+++ b/paddle/fluid/lite/core/op_registry.h
@@ -120,8 +120,10 @@ class KernelRegistor : public lite::Registor {
           LOG(INFO) << "Register kernel " << op_type << " for "
                     << TargetToStr(target) << " " << PrecisionToStr(precision);
           KernelRegistry::Global().Register(
-              op_type, [&]() -> std::unique_ptr<KernelBase> {
-                return std::unique_ptr<KernelBase>(new KernelType);
+              op_type, [&, op_type]() -> std::unique_ptr<KernelBase> {
+                std::unique_ptr<KernelBase> x(new KernelType);
+                x->set_op_type(op_type);
+                return x;
               });
         }) {}
 };
diff --git a/paddle/fluid/lite/core/optimizer.h b/paddle/fluid/lite/core/optimizer.h
index 76bfb6c7b38..dabe75f183a 100644
--- a/paddle/fluid/lite/core/optimizer.h
+++ b/paddle/fluid/lite/core/optimizer.h
@@ -16,7 +16,6 @@
 #include <string>
 #include <vector>
 #include "paddle/fluid/lite/core/mir/pass_manager.h"
-#include "paddle/fluid/lite/core/mir/pass_manager.h"
 #include "paddle/fluid/lite/core/mir/ssa_graph.h"
 
 namespace paddle {
 namespace lite {
@@ -28,17 +27,19 @@
  */
 class Optimizer {
  public:
-  void Run(std::unique_ptr<mir::Program>&& program,
-           const std::vector<Place>& valid_places,
+  void Run(mir::Program&& program, const std::vector<Place>& valid_places,
            const std::vector<std::string>& passes = {}) {
     CHECK(!graph_) << "duplicate optimize found";
     graph_.reset(new mir::SSAGraph);
-    graph_->Build(*program, valid_places);
+    graph_->Build(program, valid_places);
     RunPasses();
   }
 
   // Generate a new program based on the mir graph.
-  std::unique_ptr<mir::Program> GenProgram() {}
+  std::unique_ptr<mir::Program> GenProgram() {
+    std::unique_ptr<mir::Program> res;
+    return res;
+  }
 
   // Generate C++ code which combines the inference program, model and weights.
   void GenCode(const std::string& code_dir);
@@ -50,7 +51,7 @@ class Optimizer {
 
  protected:
   // Run the default passes registered in the PassManager.
-  void RunPasses() { mir::PassManager::Global().Run(); }
+  void RunPasses() { mir::PassManager::Global().Run(graph_); }
 
   // Specify the passes and run them.
   void RunPasses(std::vector<std::string>& passes);
diff --git a/paddle/fluid/lite/core/types.h b/paddle/fluid/lite/core/types.h
index 52b5ea7d02a..990ffe0007b 100644
--- a/paddle/fluid/lite/core/types.h
+++ b/paddle/fluid/lite/core/types.h
@@ -25,6 +25,38 @@
 using any_context_t = variant<Context<TARGET(kX86)>,  //
                               Context<TARGET(kCUDA)>  //
                               >;
 
+// Factors that impact the kernel picking strategy. Multiple factors can be
+// considered together by combining them with an expression like
+// 'factor1 | factor2'.
+class KernelPickFactor {
+ public:
+  using value_type = unsigned char;
+  enum class Factor : int {
+    // The following factors are sorted by priority.
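+    // Each factor is a distinct bit, so several of them can be OR-ed into one
+    // mask, and StaticKernelPickPass::KernelGrade divides by the factor value,
+    // giving the low-bit (high-priority) factors the larger score share.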
+    TargetFirst = 1,
+    PrecisionFirst = 1 << 1,
+    DataLayoutFirst = 1 << 2,
+    DeviceFirst = 1 << 3,
+  };
+
+  // Tell whether any factor has been considered.
+  bool AnyFactorConsidered() const { return data_; }
+
+  KernelPickFactor& ConsiderTarget();
+  KernelPickFactor& ConsiderPrecision();
+  KernelPickFactor& ConsiderDataLayout();
+  KernelPickFactor& ConsiderDevice();
+
+  bool IsTargetConsidered() const;
+  bool IsPrecisionConsidered() const;
+  bool IsDataLayoutConsidered() const;
+  bool IsDeviceConsidered() const {
+    return data_ & static_cast<int>(Factor::DeviceFirst);
+  }
+
+ private:
+  unsigned char data_{};
+};
+
 struct dim2 {
   int x{};
   int y{};
-- 
GitLab
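
Note: the patch adds types.cc to the build (cc_library(types_lite SRCS types.cc)) but the diff does not include that file, so the bodies of the Consider*/Is*Considered members declared in types.h are not shown. Below is a minimal sketch of how they could be implemented, assuming the same one-bit-per-factor mask that the inline IsDeviceConsidered() uses; it is an illustration, not the commit's actual types.cc.

// Hypothetical sketch of paddle/fluid/lite/core/types.cc -- not part of the
// patch above.
#include "paddle/fluid/lite/core/types.h"

namespace paddle {
namespace lite {
namespace core {

// Set the bit for a factor and return *this so calls can be chained, e.g.
//   factor.ConsiderTarget().ConsiderPrecision();
KernelPickFactor& KernelPickFactor::ConsiderTarget() {
  data_ |= static_cast<int>(Factor::TargetFirst);
  return *this;
}
KernelPickFactor& KernelPickFactor::ConsiderPrecision() {
  data_ |= static_cast<int>(Factor::PrecisionFirst);
  return *this;
}
KernelPickFactor& KernelPickFactor::ConsiderDataLayout() {
  data_ |= static_cast<int>(Factor::DataLayoutFirst);
  return *this;
}
KernelPickFactor& KernelPickFactor::ConsiderDevice() {
  data_ |= static_cast<int>(Factor::DeviceFirst);
  return *this;
}

// Test the corresponding bit, mirroring the inline IsDeviceConsidered().
bool KernelPickFactor::IsTargetConsidered() const {
  return data_ & static_cast<int>(Factor::TargetFirst);
}
bool KernelPickFactor::IsPrecisionConsidered() const {
  return data_ & static_cast<int>(Factor::PrecisionFirst);
}
bool KernelPickFactor::IsDataLayoutConsidered() const {
  return data_ & static_cast<int>(Factor::DataLayoutFirst);
}

}  // namespace core
}  // namespace lite
}  // namespace paddle

With the templated LookUp added to PassManager in this patch, a caller inside namespace paddle::lite could then configure the pass before running it, e.g. (hypothetical usage):

  auto* pass = mir::PassManager::Global().LookUp<mir::StaticKernelPickPass>(
      "static_kernel_pick_pass");
  pass->mutable_kernel_pick_factors()->ConsiderTarget().ConsiderPrecision();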