diff --git a/paddle/fluid/lite/core/CMakeLists.txt b/paddle/fluid/lite/core/CMakeLists.txt
index 6fd696d5d8ccee31653201fb7bcd3707bca25fe7..c7e86290853ebb8af6dff2805315bfc7d3eec1af 100644
--- a/paddle/fluid/lite/core/CMakeLists.txt
+++ b/paddle/fluid/lite/core/CMakeLists.txt
@@ -9,8 +9,14 @@ cc_library(op_executor_lite SRCS op_executor.cc DEPS scope_lite tensor_lite op_l
 #TODO(Superjomn) remove these dependencies from original framework
 proto_desc)
 cc_library(kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite)
+cc_library(types_lite SRCS types.cc)
 cc_library(type_system SRCS type_system.cc DEPS tensor_lite)
 cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager)
+cc_library(program_fake_utils SRCS program_fake_utils.cc DEPS mir_ssa_graph
+  scope_lite op_registry_lite proto_desc op_lite
+  ops_lite
+  host_kernels
+  )
 
 cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite)
 cc_test(test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86)
@@ -18,5 +24,6 @@ cc_test(test_op_lite SRCS op_lite_test.cc DEPS op_lite)
 cc_test(test_tensor_lite SRCS tensor_test.cc)
 cc_test(test_op_executor_lite SRCS op_executor_test.cc DEPS op_executor_lite ops_lite host_kernels)
 cc_test(test_type_system SRCS type_system_test.cc DEPS type_system)
+cc_test(test_optimizer_lite SRCS optimizer_test.cc DEPS mir_pass_manager program_fake_utils mir_passes)
 
 add_subdirectory(mir)
diff --git a/paddle/fluid/lite/core/kernel.h b/paddle/fluid/lite/core/kernel.h
index ee6890aea963132220d863316fcfb0492c9a7bf1..f83ae0a522248d2d013eb9c5797448abcd6733ac 100644
--- a/paddle/fluid/lite/core/kernel.h
+++ b/paddle/fluid/lite/core/kernel.h
@@ -15,7 +15,9 @@
 #pragma once
 
 #include <map>
+#include <memory>
 #include <set>
+#include <string>
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/lite/core/context.h"
 #include "paddle/fluid/lite/core/target_wrapper.h"
@@ -48,17 +50,24 @@ class KernelBase {
     return param_.get();
   }
 
+  void set_op_type(const std::string& type) { op_type_ = type; }
+  const std::string& op_type() const { return op_type_; }
+
   void Torch() {}
 
   virtual TargetType target() const = 0;
   virtual PrecisionType precision() const = 0;
   virtual DataLayoutType layout() const = 0;
+  virtual std::string name() const = 0;
+
   virtual ~KernelBase() = default;
 
 protected:
   core::any_context_t context_;
   mutable operators::param_t param_;
+  // The corresponding op type.
+  std::string op_type_;
 };
 
 /*
@@ -73,8 +82,9 @@ struct ParamType {
   Place tensor_place{};
   const Type* type_;
 
-  ParamType() = default;
-  ParamType(size_t element_type_hash) : element_type_hash(element_type_hash) {}
+  explicit ParamType() = default;
+  explicit ParamType(size_t element_type_hash)
+      : element_type_hash(element_type_hash) {}
   ParamType(size_t element_type_hash, const Place& place)
       : element_type_hash(element_type_hash), tensor_place(place) {}
   ParamType(const Type* type) : type_(type) {}
@@ -135,7 +145,8 @@ class ParamTypeRegistry {
    *                   PRECISION(kFloat)});
    */
   struct NewInstance {
-    NewInstance(const std::string& kernel_type) : kernel_type_(kernel_type) {}
+    explicit NewInstance(const std::string& kernel_type)
+        : kernel_type_(kernel_type) {}
 
     NewInstance& BindInput(int offset, const ParamType& ptype) {
       ParamTypeRegistry::Global().Register(
@@ -205,6 +216,10 @@ class OpKernel : public KernelBase {
   TargetType target() const override { return Target; }
   PrecisionType precision() const override { return Precision; }
   DataLayoutType layout() const override { return DataLayout; }
+  std::string name() const override {
+    return op_type() + ":" + TargetToStr(Target) + "/" +
+           PrecisionToStr(Precision) + "/" + DataLayoutToStr(DataLayout);
+  }
 
   void Touch() {}
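The name() override added above composes a kernel's identity as "<op_type>:<target>/<precision>/<layout>". A minimal standalone sketch of that composition follows; the helper strings ("host", "float", "NCHW") are assumptions, since the real TargetToStr/PrecisionToStr/DataLayoutToStr conversions live in target_wrapper.h and are not part of this diff:

    #include <iostream>
    #include <string>

    // Stand-ins for TargetToStr/PrecisionToStr/DataLayoutToStr; the exact
    // strings the real helpers return may differ.
    std::string TargetToStr(int) { return "host"; }
    std::string PrecisionToStr(int) { return "float"; }
    std::string DataLayoutToStr(int) { return "NCHW"; }

    // Mirrors OpKernel::name(): "<op_type>:<target>/<precision>/<layout>".
    std::string KernelName(const std::string& op_type, int target,
                           int precision, int layout) {
      return op_type + ":" + TargetToStr(target) + "/" +
             PrecisionToStr(precision) + "/" + DataLayoutToStr(layout);
    }

    int main() {
      std::cout << KernelName("fc", 0, 0, 0) << "\n";  // fc:host/float/NCHW
      return 0;
    }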
diff --git a/paddle/fluid/lite/core/mir/CMakeLists.txt b/paddle/fluid/lite/core/mir/CMakeLists.txt
index 5e789829abb51a9caf52b8f3dd5c3f76c78030b8..757e5e141fdad383c322a7eb940979dfe5ae8b4c 100644
--- a/paddle/fluid/lite/core/mir/CMakeLists.txt
+++ b/paddle/fluid/lite/core/mir/CMakeLists.txt
@@ -9,13 +9,14 @@ cc_library(mir_passes
   graph_visualize_pass.cc
   generate_program_pass.cc
   demo_pass.cc
-  DEPS mir_pass)
+  DEPS mir_pass types_lite)
 
 cc_test(test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes)
 cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
   mir_ssa_graph scope_lite op_lite
-  proto_desc ops_lite
+  ops_lite
   host_kernels
   mir_passes
   mir_pass_manager
+  program_fake_utils
   )
diff --git a/paddle/fluid/lite/core/mir/generate_program_pass.cc b/paddle/fluid/lite/core/mir/generate_program_pass.cc
index ce71e4de2b8f6ee4332a47e7d26876f75e0d75fd..0be516489db6a0a66a66db6e87cdf9dc875f75ba 100644
--- a/paddle/fluid/lite/core/mir/generate_program_pass.cc
+++ b/paddle/fluid/lite/core/mir/generate_program_pass.cc
@@ -11,3 +11,17 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+
+#include "paddle/fluid/lite/core/mir/generate_program_pass.h"
+#include "paddle/fluid/lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+void GenerateProgramPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {}
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(generate_program_pass,
+                  paddle::lite::mir::GenerateProgramPass);
diff --git a/paddle/fluid/lite/core/mir/generate_program_pass.h b/paddle/fluid/lite/core/mir/generate_program_pass.h
index bee6a3c8f66af86c0f656995aac3f99056028ccf..bc78370b08d8785586ab5bb4babd883c9377988c 100644
--- a/paddle/fluid/lite/core/mir/generate_program_pass.h
+++ b/paddle/fluid/lite/core/mir/generate_program_pass.h
@@ -24,8 +24,9 @@ namespace mir {
 * GenerateProgramPass will build the execution program for the executor from
 * a MIR graph.
 */
-class GenerateProgramPass : public Pass {
+class GenerateProgramPass : public ProgramPass {
 public:
+  void Apply(std::unique_ptr<mir::SSAGraph>& graph) override;
 };
 
 }  // namespace mir
diff --git a/paddle/fluid/lite/core/mir/io_complement_pass.cc b/paddle/fluid/lite/core/mir/io_complement_pass.cc
index ce71e4de2b8f6ee4332a47e7d26876f75e0d75fd..8511a3921f59539e8e83debdb6a0819110a86e98 100644
--- a/paddle/fluid/lite/core/mir/io_complement_pass.cc
+++ b/paddle/fluid/lite/core/mir/io_complement_pass.cc
@@ -11,3 +11,18 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+
+#include "paddle/fluid/lite/core/mir/io_complement_pass.h"
+#include "paddle/fluid/lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(io_complement_pass, paddle::lite::mir::IoComplementPass);
diff --git a/paddle/fluid/lite/core/mir/io_complement_pass.h b/paddle/fluid/lite/core/mir/io_complement_pass.h
index ff071ce037e696acb9695c52291e0869f10a3d92..b44fde7b4501b5739a5c68902a7699ca43fb452f 100644
--- a/paddle/fluid/lite/core/mir/io_complement_pass.h
+++ b/paddle/fluid/lite/core/mir/io_complement_pass.h
@@ -24,8 +24,9 @@ namespace mir {
 * IoComplementPass complements the necessary instructions to perform data
 * transfer or transformation between different places.
 */
-class IoComplementPass : public Pass {
+class IoComplementPass : public ProgramPass {
 public:
+  void Apply(std::unique_ptr<mir::SSAGraph>& graph) override;
 };
 
 }  // namespace mir
#include "paddle/fluid/lite/core/mir/pass_manager.h" -#include "paddle/fluid/lite/core/mir/pass_registry.h" namespace paddle { namespace lite { diff --git a/paddle/fluid/lite/core/mir/pass_manager.h b/paddle/fluid/lite/core/mir/pass_manager.h index e2ff6549bd53fbc33e67011845355cd5104841d7..ba3fad4d9f242e0d921616ba8eced13d3df180dc 100644 --- a/paddle/fluid/lite/core/mir/pass_manager.h +++ b/paddle/fluid/lite/core/mir/pass_manager.h @@ -32,16 +32,17 @@ class PassManager { PassManager(); - void Run() { + void Run(std::unique_ptr& graph) { for (auto& pass : passes_) { LOG(INFO) << "Running MIR pass " << pass->name(); - pass->Apply(graph_); + pass->Apply(graph); } } bool AddNewPass(const std::string& name, Pass* pass) { passes_.emplace_back(pass); pass_map_.emplace(name, passes_.back().get()); + passes_.back()->set_name(name); return true; } @@ -65,12 +66,18 @@ class PassManager { Pass* LookUp(const std::string& key) { auto it = pass_map_.find(key); - CHECK(it != pass_map_.end()); - return it->second; + if (it != pass_map_.end()) return it->second; + return nullptr; + } + + template + PassTy* LookUp(const std::string& key) { + auto it = pass_map_.find(key); + if (it != pass_map_.end()) return dynamic_cast(it->second); + return nullptr; } private: - std::unique_ptr graph_; std::list> passes_; std::map pass_map_; }; diff --git a/paddle/fluid/lite/core/mir/ssa_graph_test.cc b/paddle/fluid/lite/core/mir/ssa_graph_test.cc index 0db9164e7e2efbb59c2cf075a75aa32f29534f91..76f595a913a01bb436420fa722ea6949df69fc37 100644 --- a/paddle/fluid/lite/core/mir/ssa_graph_test.cc +++ b/paddle/fluid/lite/core/mir/ssa_graph_test.cc @@ -16,7 +16,9 @@ #include #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/lite/core/mir/graph_visualize_pass.h" +#include "paddle/fluid/lite/core/mir/passes.h" #include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/core/program_fake_utils.h" namespace paddle { namespace lite { @@ -32,58 +34,6 @@ void BuildFc(framework::ProgramDesc* desc, const std::string& x, fc->SetOutput("Out", {out}); } -Program FakeProgram() { - Program program; - program.scope = new lite::Scope; - - auto add_fc = [&](int id, std::string x) { - // create variables - std::string w1 = "w" + std::to_string(id); - std::string b1 = "b" + std::to_string(id); - std::string out1 = "out" + std::to_string(id); - auto w1v = program.scope->Var(w1)->GetMutable(); - auto b1v = program.scope->Var(b1)->GetMutable(); - auto out1v = program.scope->Var(out1)->GetMutable(); - - framework::OpDesc desc; - desc.SetInput("Input", {x}); - desc.SetInput("W", {w1}); - desc.SetInput("Bias", {b1}); - desc.SetOutput("Out", {out1}); - desc.SetType("fc"); - desc.SetAttr("in_num_col_dims", 1); - desc.Flush(); - - // add to input - program.tmp_vars.push_back(w1); - program.tmp_vars.push_back(b1); - - auto fc_op = LiteOpRegistry::Global().Create("fc"); - fc_op->PickKernel({Place{TARGET(kHost), PRECISION(kFloat)}}); - fc_op->Attach(desc, program.scope); - program.ops.emplace_back(std::move(fc_op)); - - w1v->Resize({100, 100}); - b1v->Resize({100, 1}); - out1v->Resize({100, 100}); - - return out1; - }; - - // x1, w1, b1 -fc-> out1 - // out1, w2, b2 -fc-> out2 - - std::string x = "x"; - program.tmp_vars.push_back(x); - auto* xv = program.scope->Var(x)->GetMutable(); - xv->Resize({100, 100}); - - for (int i = 0; i < 3; i++) { - x = add_fc(i, x); - } - return program; -} - TEST(SSAGraph, test) { auto program = FakeProgram(); SSAGraph graph; diff --git a/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc 
diff --git a/paddle/fluid/lite/core/mir/ssa_graph_test.cc b/paddle/fluid/lite/core/mir/ssa_graph_test.cc
index 0db9164e7e2efbb59c2cf075a75aa32f29534f91..76f595a913a01bb436420fa722ea6949df69fc37 100644
--- a/paddle/fluid/lite/core/mir/ssa_graph_test.cc
+++ b/paddle/fluid/lite/core/mir/ssa_graph_test.cc
@@ -16,7 +16,9 @@
 #include <gtest/gtest.h>
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
+#include "paddle/fluid/lite/core/mir/passes.h"
 #include "paddle/fluid/lite/core/op_registry.h"
+#include "paddle/fluid/lite/core/program_fake_utils.h"
 
 namespace paddle {
 namespace lite {
@@ -32,58 +34,6 @@ void BuildFc(framework::ProgramDesc* desc, const std::string& x,
   fc->SetOutput("Out", {out});
 }
 
-Program FakeProgram() {
-  Program program;
-  program.scope = new lite::Scope;
-
-  auto add_fc = [&](int id, std::string x) {
-    // create variables
-    std::string w1 = "w" + std::to_string(id);
-    std::string b1 = "b" + std::to_string(id);
-    std::string out1 = "out" + std::to_string(id);
-    auto w1v = program.scope->Var(w1)->GetMutable<lite::Tensor>();
-    auto b1v = program.scope->Var(b1)->GetMutable<lite::Tensor>();
-    auto out1v = program.scope->Var(out1)->GetMutable<lite::Tensor>();
-
-    framework::OpDesc desc;
-    desc.SetInput("Input", {x});
-    desc.SetInput("W", {w1});
-    desc.SetInput("Bias", {b1});
-    desc.SetOutput("Out", {out1});
-    desc.SetType("fc");
-    desc.SetAttr("in_num_col_dims", 1);
-    desc.Flush();
-
-    // add to input
-    program.tmp_vars.push_back(w1);
-    program.tmp_vars.push_back(b1);
-
-    auto fc_op = LiteOpRegistry::Global().Create("fc");
-    fc_op->PickKernel({Place{TARGET(kHost), PRECISION(kFloat)}});
-    fc_op->Attach(desc, program.scope);
-    program.ops.emplace_back(std::move(fc_op));
-
-    w1v->Resize({100, 100});
-    b1v->Resize({100, 1});
-    out1v->Resize({100, 100});
-
-    return out1;
-  };
-
-  // x1, w1, b1 -fc-> out1
-  // out1, w2, b2 -fc-> out2
-
-  std::string x = "x";
-  program.tmp_vars.push_back(x);
-  auto* xv = program.scope->Var(x)->GetMutable<lite::Tensor>();
-  xv->Resize({100, 100});
-
-  for (int i = 0; i < 3; i++) {
-    x = add_fc(i, x);
-  }
-  return program;
-}
-
 TEST(SSAGraph, test) {
   auto program = FakeProgram();
   SSAGraph graph;
diff --git a/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc b/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
index ce71e4de2b8f6ee4332a47e7d26876f75e0d75fd..8a324a6cca2618c06d69e86953618087409bd974 100644
--- a/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
+++ b/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
@@ -11,3 +11,49 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+
+#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
+#include <algorithm>
+#include <utility>
+#include <vector>
+#include "paddle/fluid/lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+bool KernelScoreCmp(const std::pair<size_t, std::unique_ptr<KernelBase>>& a,
+                    const std::pair<size_t, std::unique_ptr<KernelBase>>& b) {
+  return a.first > b.first;
+}
+
+void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
+  CHECK(kernel_pick_factors_.AnyFactorConsidered())
+      << "kernel_pick_factors should be specified first";
+  CHECK(graph) << "graph not valid";
+  // Sort the kernels by the factors.
+  for (auto& node : graph->mutable_nodes()) {
+    if (!node.IsInstruct()) continue;
+    auto& instruct = node.AsInstruct();
+    std::vector<std::pair<size_t, std::unique_ptr<KernelBase>>> scored;
+    for (auto&& kernel : instruct.valid_kernels) {
+      scored.emplace_back(KernelGrade(*kernel), std::move(kernel));
+    }
+
+    std::sort(scored.begin(), scored.end(), KernelScoreCmp);
+
+    // Move the best kernel back.
+    // Just keep a single best kernel.
+    // TODO(Superjomn) reconsider this.
+    instruct.valid_kernels.clear();
+    instruct.valid_kernels.emplace_back(std::move(scored.front().second));
+    LOG(INFO) << "pick " << instruct.valid_kernels.front()->name();
+  }
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(static_kernel_pick_pass,
+                  paddle::lite::mir::StaticKernelPickPass);
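StaticKernelPickPass::Apply grades every candidate kernel of an instruction, sorts the (score, kernel) pairs in descending order, and keeps only the front element. The same idiom in isolation, with a plain struct standing in for KernelBase and hard-coded grades:

    #include <algorithm>
    #include <iostream>
    #include <memory>
    #include <string>
    #include <utility>
    #include <vector>

    struct Kernel {
      std::string name;
    };

    using Scored = std::pair<size_t, std::unique_ptr<Kernel>>;

    // Descending by score, mirroring KernelScoreCmp in the pass above.
    bool KernelScoreCmp(const Scored& a, const Scored& b) {
      return a.first > b.first;
    }

    int main() {
      std::vector<std::unique_ptr<Kernel>> valid_kernels;
      valid_kernels.emplace_back(new Kernel{"fc:x86/float/NCHW"});
      valid_kernels.emplace_back(new Kernel{"fc:host/float/NCHW"});

      // Grade each candidate; the grades are hard-coded here for illustration.
      std::vector<size_t> grades = {10, 255};
      std::vector<Scored> scored;
      for (size_t i = 0; i < valid_kernels.size(); ++i) {
        scored.emplace_back(grades[i], std::move(valid_kernels[i]));
      }
      std::sort(scored.begin(), scored.end(), KernelScoreCmp);

      // Keep only the single best kernel.
      valid_kernels.clear();
      valid_kernels.emplace_back(std::move(scored.front().second));
      std::cout << "pick " << valid_kernels.front()->name << "\n";
      return 0;
    }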
diff --git a/paddle/fluid/lite/core/mir/static_kernel_pick_pass.h b/paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
index becdd50dd91584bef2a8d7c3bc8d4aac2b01a7a7..9cf3ee2d4396ea6fe9a629f944ca88af754d155c 100644
--- a/paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
+++ b/paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
@@ -14,13 +14,65 @@
 
 #pragma once
 
+#include <limits>
 #include "paddle/fluid/lite/core/mir/pass.h"
+#include "paddle/fluid/lite/core/types.h"
 
 namespace paddle {
 namespace lite {
 namespace mir {
 
-class StaticKernelPickPass : public mir::Pass {};
+/*
+ * StaticKernelPickPass is a simple strategy for picking the kernel for each
+ * operator using rules defined by the operator developers; there are many
+ * other tactics, such as considering IO or kernel execution latency, that we
+ * will implement later.
+ *
+ * There are two arguments for this pass:
+ * - place, the target place;
+ * - kernel_pick_factors, the factors to consider in picking kernels.
+ * Set both before executing the pass.
+ */
+class StaticKernelPickPass : public mir::InstructionPass {
+ public:
+  void Apply(std::unique_ptr<mir::SSAGraph>& graph) override;
+
+  const Place& place() const { return place_; }
+  const core::KernelPickFactor& kernel_pick_factors() const {
+    return kernel_pick_factors_;
+  }
+  core::KernelPickFactor* mutable_kernel_pick_factors() {
+    return &kernel_pick_factors_;
+  }
+
+ private:
+  // Score the kernel.
+  size_t KernelGrade(const lite::KernelBase& kernel) {
+    size_t score{};
+    const int kMax =
+        std::numeric_limits<core::KernelPickFactor::value_type>::max();
+    if (kernel_pick_factors_.IsTargetConsidered() &&
+        place().target == kernel.target()) {
+      score +=
+          kMax / static_cast<int>(core::KernelPickFactor::Factor::TargetFirst);
+    }
+    if (kernel_pick_factors_.IsPrecisionConsidered() &&
+        place().precision == kernel.precision()) {
+      score += kMax /
+               static_cast<int>(core::KernelPickFactor::Factor::PrecisionFirst);
+    }
+
+    // The data layout is not considered, for the input and output arguments
+    // might have different data layouts.
+    // TODO(Superjomn) reconsider the idea of taking the data layout as a
+    // kernel specification.
+    return score;
+  }
+
+ private:
+  core::KernelPickFactor kernel_pick_factors_;
+  Place place_;
+};
 
 }  // namespace mir
 }  // namespace lite
diff --git a/paddle/fluid/lite/core/op_registry.h b/paddle/fluid/lite/core/op_registry.h
index be85cbe8419397a8c5572548f2684c97bc88b465..95d62f2fee4fea03efd95e80101fca3f20e9089f 100644
--- a/paddle/fluid/lite/core/op_registry.h
+++ b/paddle/fluid/lite/core/op_registry.h
@@ -120,8 +120,10 @@ class KernelRegistor : public lite::Registor<KernelType> {
         LOG(INFO) << "Register kernel " << op_type << " for "
                   << TargetToStr(target) << " " << PrecisionToStr(precision);
         KernelRegistry::Global().Register(
-            op_type, [&]() -> std::unique_ptr<KernelType> {
-              return std::unique_ptr<KernelType>(new KernelType);
+            op_type, [&, op_type]() -> std::unique_ptr<KernelType> {
+              std::unique_ptr<KernelType> x(new KernelType);
+              x->set_op_type(op_type);
+              return x;
             });
       }) {}
 };
diff --git a/paddle/fluid/lite/core/optimizer.h b/paddle/fluid/lite/core/optimizer.h
index 76bfb6c7b38b1e5a039d9b4b78bfe245c9531d12..dabe75f183ac232c7d0fe9a4cf61eadf2fbe6421 100644
--- a/paddle/fluid/lite/core/optimizer.h
+++ b/paddle/fluid/lite/core/optimizer.h
@@ -16,7 +16,6 @@
 #include <string>
 #include <vector>
 #include "paddle/fluid/lite/core/mir/pass_manager.h"
-#include "paddle/fluid/lite/core/mir/pass_manager.h"
 #include "paddle/fluid/lite/core/mir/ssa_graph.h"
 
 namespace paddle {
 namespace lite {
@@ -28,17 +27,19 @@ namespace lite {
  */
 class Optimizer {
 public:
-  void Run(std::unique_ptr<mir::Program>&& program,
-           const std::vector<Place>& valid_places,
+  void Run(mir::Program&& program, const std::vector<Place>& valid_places,
           const std::vector<std::string>& passes = {}) {
     CHECK(!graph_) << "duplicate optimize found";
     graph_.reset(new mir::SSAGraph);
-    graph_->Build(*program, valid_places);
+    graph_->Build(program, valid_places);
     RunPasses();
   }
 
   // Generate a new program based on the mir graph.
-  std::unique_ptr<mir::Program> GenProgram() {}
+  std::unique_ptr<mir::Program> GenProgram() {
+    std::unique_ptr<mir::Program> res;
+    return res;
+  }
 
   // Generate C++ code which combines the inference program, model and weights.
   void GenCode(const std::string& code_dir);
@@ -50,7 +51,7 @@ class Optimizer {
 
 protected:
   // Run the default passes registered in the PassManager.
-  void RunPasses() { mir::PassManager::Global().Run(); }
+  void RunPasses() { mir::PassManager::Global().Run(graph_); }
 
   // Specify the passes and run them.
   void RunPasses(std::vector<std::string>& passes);
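Since KernelPickFactor::value_type is unsigned char (see types.h below), kMax in KernelGrade is 255: a target match contributes 255 / 1 = 255 and a precision match 255 / 2 = 127 under integer division, so a kernel matching only the target always outranks one matching only the precision. A quick standalone check of that arithmetic:

    #include <iostream>
    #include <limits>

    int main() {
      // value_type is unsigned char, so kMax is 255.
      const int kMax = std::numeric_limits<unsigned char>::max();
      const int kTargetFirst = 1;     // Factor::TargetFirst
      const int kPrecisionFirst = 2;  // Factor::PrecisionFirst

      int target_score = kMax / kTargetFirst;        // 255
      int precision_score = kMax / kPrecisionFirst;  // 127 (integer division)
      std::cout << target_score << " " << precision_score << " "
                << target_score + precision_score << "\n";  // 255 127 382
      return 0;
    }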
diff --git a/paddle/fluid/lite/core/types.h b/paddle/fluid/lite/core/types.h
index 52b5ea7d02abe5800bbeea0082874d5413a1882a..990ffe0007bfbb8d62190d63e1df78fa0bb76718 100644
--- a/paddle/fluid/lite/core/types.h
+++ b/paddle/fluid/lite/core/types.h
@@ -25,6 +25,38 @@
 using any_context_t = variant<Context<TARGET(kX86)>,   //
                               Context<TARGET(kCUDA)>   //
                               >;
 
+// Factors that impact the kernel picking strategy. Multiple factors can be
+// considered together with an expression like 'factor1 | factor2'.
+class KernelPickFactor {
+ public:
+  using value_type = unsigned char;
+  enum class Factor : int {
+    // The following factors are sorted by priority.
+    TargetFirst = 1,
+    PrecisionFirst = 1 << 1,
+    DataLayoutFirst = 1 << 2,
+    DeviceFirst = 1 << 3,
+  };
+
+  // Whether any factor has been considered.
+  bool AnyFactorConsidered() const { return data_; }
+
+  KernelPickFactor& ConsiderTarget();
+  KernelPickFactor& ConsiderPrecision();
+  KernelPickFactor& ConsiderDataLayout();
+  KernelPickFactor& ConsiderDevice();
+
+  bool IsTargetConsidered() const;
+  bool IsPrecisionConsidered() const;
+  bool IsDataLayoutConsidered() const;
+  bool IsDeviceConsidered() const {
+    return data_ & static_cast<int>(Factor::DeviceFirst);
+  }
+
+ private:
+  unsigned char data_{};
+};
+
 struct dim2 {
   int x{};
   int y{};
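The Consider* mutators and Is*Considered accessors operate on the data_ bitmask. Their bodies live in types.cc, which is not part of this diff, so the OR-into-mask implementation sketched below is an assumption inferred from the IsDeviceConsidered accessor shown above:

    #include <iostream>

    // A trimmed copy of KernelPickFactor with assumed Consider* bodies.
    class KernelPickFactor {
     public:
      enum class Factor : int {
        TargetFirst = 1,
        PrecisionFirst = 1 << 1,
      };

      // Assumed implementation: OR the factor's bit into the mask and return
      // *this so calls can be chained.
      KernelPickFactor& ConsiderTarget() {
        data_ |= static_cast<int>(Factor::TargetFirst);
        return *this;
      }
      KernelPickFactor& ConsiderPrecision() {
        data_ |= static_cast<int>(Factor::PrecisionFirst);
        return *this;
      }

      bool AnyFactorConsidered() const { return data_; }
      bool IsTargetConsidered() const {
        return data_ & static_cast<int>(Factor::TargetFirst);
      }
      bool IsPrecisionConsidered() const {
        return data_ & static_cast<int>(Factor::PrecisionFirst);
      }

     private:
      unsigned char data_{};
    };

    int main() {
      KernelPickFactor factors;
      factors.ConsiderTarget().ConsiderPrecision();
      std::cout << factors.IsTargetConsidered() << factors.IsPrecisionConsidered()
                << factors.AnyFactorConsidered() << "\n";  // 111
      return 0;
    }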