// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include "paddle/fluid/lite/core/mir/generate_program_pass.h" #include "paddle/fluid/lite/core/mir/pass_manager.h" #include "paddle/fluid/lite/core/mir/ssa_graph.h" #include "paddle/fluid/lite/core/program.h" #include "paddle/fluid/lite/core/types.h" namespace paddle { namespace lite { /* * lite::Optimizer optimize a program. It utilize the mir passes to analysis the * program and export an optimized program. */ class Optimizer { public: void Run(Program&& program, const std::vector& valid_places, core::KernelPickFactor kernel_pick_factor, const std::vector& passes = {}) { CHECK(!graph_) << "duplicate optimize found"; graph_.reset(new mir::SSAGraph); graph_->Build(program, valid_places); SpecifyKernelPickTactic(kernel_pick_factor); RunPasses(); exec_scope_ = program.exec_scope; } // Generate a new program based on the mir graph. std::unique_ptr GenRuntimeProgram() { std::unique_ptr res; auto pass = mir::PassManager::Global().LookUp( "generate_program_pass"); auto program = pass->GenProgram(); CHECK(exec_scope_); program->set_exec_scope(exec_scope_); return program; } // Generate C++ code which combines the inference program, model and weights. void GenCode(const std::string& code_dir); const mir::SSAGraph& ssa_graph() const { CHECK(graph_); return *graph_; } protected: void SpecifyKernelPickTactic(core::KernelPickFactor factor); // Run the default passes registered in the PassManager. void RunPasses() { mir::PassManager::Global().Run(graph_); } // Specify the passes and run them. void RunPasses(std::vector& passes); private: std::unique_ptr graph_; lite::Scope* exec_scope_{}; }; } // namespace lite } // namespace paddle