Commit eada00c2 authored by superjomn

init optimizer

Parent 239d716b
......@@ -9,8 +9,14 @@ cc_library(op_executor_lite SRCS op_executor.cc DEPS scope_lite tensor_lite op_l
#TODO(Superjomn) remove these dependencies from original framework
proto_desc)
cc_library(kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite)
cc_library(types_lite SRCS types.cc)
cc_library(type_system SRCS type_system.cc DEPS tensor_lite)
cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager)
cc_library(program_fake_utils SRCS program_fake_utils.cc DEPS mir_ssa_graph
scope_lite op_registry_lite proto_desc op_lite
ops_lite
host_kernels
)
cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite)
cc_test(test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86)
......@@ -18,5 +24,6 @@ cc_test(test_op_lite SRCS op_lite_test.cc DEPS op_lite)
cc_test(test_tensor_lite SRCS tensor_test.cc)
cc_test(test_op_executor_lite SRCS op_executor_test.cc DEPS op_executor_lite ops_lite host_kernels)
cc_test(test_type_system SRCS type_system_test.cc DEPS type_system)
cc_test(test_optimizer_lite SRCS optimizer_test.cc DEPS mir_pass_manager program_fake_utils mir_passes)
add_subdirectory(mir)
......@@ -15,7 +15,9 @@
#pragma once
#include <map>
#include <set>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
......@@ -48,17 +50,24 @@ class KernelBase {
return param_.get<Param>();
}
void set_op_type(const std::string& type) { op_type_ = type; }
const std::string& op_type() const { return op_type_; }
void Torch() {}
virtual TargetType target() const = 0;
virtual PrecisionType precision() const = 0;
virtual DataLayoutType layout() const = 0;
virtual std::string name() const = 0;
virtual ~KernelBase() = default;
protected:
core::any_context_t context_;
mutable operators::param_t param_;
// The corresponding op type.
std::string op_type_;
};
/*
......@@ -73,8 +82,9 @@ struct ParamType {
Place tensor_place{};
const Type* type_;
ParamType() = default;
ParamType(size_t element_type_hash) : element_type_hash(element_type_hash) {}
explicit ParamType() = default;
explicit ParamType(size_t element_type_hash)
: element_type_hash(element_type_hash) {}
ParamType(size_t element_type_hash, const Place& place)
: element_type_hash(element_type_hash), tensor_place(place) {}
ParamType(const Type* type) : type_(type) {}
......@@ -135,7 +145,8 @@ class ParamTypeRegistry {
* PRECISION(kFloat)});
*/
struct NewInstance {
NewInstance(const std::string& kernel_type) : kernel_type_(kernel_type) {}
explicit NewInstance(const std::string& kernel_type)
: kernel_type_(kernel_type) {}
NewInstance& BindInput(int offset, const ParamType& ptype) {
ParamTypeRegistry::Global().Register<IO::kInput>(
......@@ -205,6 +216,10 @@ class OpKernel : public KernelBase {
TargetType target() const override { return Target; }
PrecisionType precision() const override { return Precision; }
DataLayoutType layout() const override { return DataLayout; }
std::string name() const override {
return op_type() + ":" + TargetToStr(Target) + "/" +
PrecisionToStr(Precision) + "/" + DataLayoutToStr(DataLayout);
}
void Touch() {}
......
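For reference, the new OpKernel::name() composes a readable identifier from the op type set at registration time plus the kernel's static Place. A hedged sketch of what it yields (the exact strings depend on the *ToStr helpers, which this commit does not show):

    // Assuming a KernelBase* kernel obtained from the registry for an "fc"
    // kernel registered on kHost/kFloat (as host_kernels provides):
    //   name() == op_type() + ":" + TargetToStr(Target) + "/"
    //           + PrecisionToStr(Precision) + "/" + DataLayoutToStr(DataLayout)
    // so logging it might print something like "fc:host/float/NCHW".
    LOG(INFO) << kernel->name();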
......@@ -9,13 +9,14 @@ cc_library(mir_passes
graph_visualize_pass.cc
generate_program_pass.cc
demo_pass.cc
DEPS mir_pass)
DEPS mir_pass types_lite)
cc_test(test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes)
cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
mir_ssa_graph scope_lite op_lite
proto_desc ops_lite
ops_lite
host_kernels
mir_passes
mir_pass_manager
program_fake_utils
)
......@@ -11,3 +11,17 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/generate_program_pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void GenerateProgramPass::Apply(std::unique_ptr<mir::SSAGraph> &graph) {}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(generate_program_pass,
paddle::lite::mir::GenerateProgramPass);
......@@ -24,8 +24,9 @@ namespace mir {
* GenerateProgramPass builds the execution program for the executor from a MIR
* graph.
*/
class GenerateProgramPass : public Pass {
class GenerateProgramPass : public ProgramPass {
public:
void Apply(std::unique_ptr<mir::SSAGraph> &graph) override;
};
} // namespace mir
......
......@@ -11,3 +11,18 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/io_complement_pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph> &graph) {}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(io_complement_pass, paddle::lite::mir::IoComplementPass);
......@@ -24,8 +24,9 @@ namespace mir {
* IoComplementPass inserts the necessary instructions to transfer or transform
* data between different places.
*/
class IoComplementPass : public Pass {
class IoComplementPass : public ProgramPass {
public:
void Apply(std::unique_ptr<mir::SSAGraph> &graph) override;
};
} // namespace mir
......
......@@ -13,7 +13,6 @@
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
......
......@@ -32,16 +32,17 @@ class PassManager {
PassManager();
void Run() {
void Run(std::unique_ptr<SSAGraph>& graph) {
for (auto& pass : passes_) {
LOG(INFO) << "Running MIR pass " << pass->name();
pass->Apply(graph_);
pass->Apply(graph);
}
}
bool AddNewPass(const std::string& name, Pass* pass) {
passes_.emplace_back(pass);
pass_map_.emplace(name, passes_.back().get());
passes_.back()->set_name(name);
return true;
}
......@@ -65,12 +66,18 @@ class PassManager {
Pass* LookUp(const std::string& key) {
auto it = pass_map_.find(key);
CHECK(it != pass_map_.end());
return it->second;
if (it != pass_map_.end()) return it->second;
return nullptr;
}
template <typename PassTy>
PassTy* LookUp(const std::string& key) {
auto it = pass_map_.find(key);
if (it != pass_map_.end()) return dynamic_cast<PassTy*>(it->second);
return nullptr;
}
private:
std::unique_ptr<mir::SSAGraph> graph_;
std::list<std::unique_ptr<mir::Pass>> passes_;
std::map<std::string, mir::Pass*> pass_map_;
};
......
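A hedged usage sketch of the new graph-passing Run() together with the templated LookUp<PassTy>(). The pass name matches the REGISTER_MIR_PASS call in this commit, while the surrounding setup (building the graph from a Program) is assumed:

    #include "paddle/fluid/lite/core/mir/pass_manager.h"
    #include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"

    using namespace paddle::lite;  // for brevity in this sketch

    std::unique_ptr<mir::SSAGraph> graph(new mir::SSAGraph);
    // graph->Build(program, valid_places);  // as Optimizer::Run does

    // Configure the statically registered pick pass before running it.
    auto* pick = mir::PassManager::Global().LookUp<mir::StaticKernelPickPass>(
        "static_kernel_pick_pass");
    if (pick) pick->mutable_kernel_pick_factors()->ConsiderTarget();

    // Run every registered pass over the graph; an unknown name now returns
    // nullptr from LookUp instead of CHECK-failing.
    mir::PassManager::Global().Run(graph);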
......@@ -16,7 +16,9 @@
#include <gtest/gtest.h>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/program_fake_utils.h"
namespace paddle {
namespace lite {
......@@ -32,58 +34,6 @@ void BuildFc(framework::ProgramDesc* desc, const std::string& x,
fc->SetOutput("Out", {out});
}
Program FakeProgram() {
Program program;
program.scope = new lite::Scope;
auto add_fc = [&](int id, std::string x) {
// create variables
std::string w1 = "w" + std::to_string(id);
std::string b1 = "b" + std::to_string(id);
std::string out1 = "out" + std::to_string(id);
auto w1v = program.scope->Var(w1)->GetMutable<Tensor>();
auto b1v = program.scope->Var(b1)->GetMutable<Tensor>();
auto out1v = program.scope->Var(out1)->GetMutable<Tensor>();
framework::OpDesc desc;
desc.SetInput("Input", {x});
desc.SetInput("W", {w1});
desc.SetInput("Bias", {b1});
desc.SetOutput("Out", {out1});
desc.SetType("fc");
desc.SetAttr("in_num_col_dims", 1);
desc.Flush();
// add to input
program.tmp_vars.push_back(w1);
program.tmp_vars.push_back(b1);
auto fc_op = LiteOpRegistry::Global().Create("fc");
fc_op->PickKernel({Place{TARGET(kHost), PRECISION(kFloat)}});
fc_op->Attach(desc, program.scope);
program.ops.emplace_back(std::move(fc_op));
w1v->Resize({100, 100});
b1v->Resize({100, 1});
out1v->Resize({100, 100});
return out1;
};
// x1, w1, b1 -fc-> out1
// out1, w2, b2 -fc-> out2
std::string x = "x";
program.tmp_vars.push_back(x);
auto* xv = program.scope->Var(x)->GetMutable<Tensor>();
xv->Resize({100, 100});
for (int i = 0; i < 3; i++) {
x = add_fc(i, x);
}
return program;
}
TEST(SSAGraph, test) {
auto program = FakeProgram();
SSAGraph graph;
......
......@@ -11,3 +11,49 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
bool KernelScoreCmp(const std::pair<size_t, std::unique_ptr<KernelBase>>& a,
const std::pair<size_t, std::unique_ptr<KernelBase>>& b) {
return a.first > b.first;
}
void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
CHECK(kernel_pick_factors_.AnyFactorConsidered())
<< "kernel_pick_factors should be specified first";
CHECK(graph) << "graph not valid";
// sort kernels by the factors.
for (auto& node : graph->mutable_nodes()) {
if (!node.IsInstruct()) continue;
auto& instruct = node.AsInstruct();
std::vector<std::pair<size_t, std::unique_ptr<KernelBase>>> scored;
for (auto&& kernel : instruct.valid_kernels) {
scored.emplace_back(KernelGrade(*kernel), std::move(kernel));
}
std::sort(scored.begin(), scored.end(), KernelScoreCmp);
// Move kernel back
// Just keep a single best kernel.
// TODO(Superjomn) reconsider this.
instruct.valid_kernels.clear();
instruct.valid_kernels.emplace_back(std::move(scored.front().second));
LOG(INFO) << "pick " << instruct.valid_kernels.front()->name();
}
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(static_kernel_pick_pass,
paddle::lite::mir::StaticKernelPickPass);
......@@ -14,13 +14,65 @@
#pragma once
#include <limits>
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/types.h"
namespace paddle {
namespace lite {
namespace mir {
class StaticKernelPickPass : public mir::Pass {};
/*
* StaticKernelPickPass is a simple strategy for picking a kernel for each
* Operator using rules defined by the operator developer; other tactics, such
* as considering IO or kernel execution latency, will be implemented later.
*
* There are two arguments for this pass:
* - place, the target place.
* - kernel_pick_factors, the factors to consider when picking kernels.
* Set both before executing the pass.
*/
class StaticKernelPickPass : public mir::InstructionPass {
public:
void Apply(std::unique_ptr<mir::SSAGraph>& graph) override;
const Place& place() const { return place_; }
const core::KernelPickFactor& kernel_pick_factors() const {
return kernel_pick_factors_;
}
core::KernelPickFactor* mutable_kernel_pick_factors() {
return &kernel_pick_factors_;
}
private:
// Score the kernel.
size_t KernelGrade(const lite::KernelBase& kernel) {
size_t score{};
const int kMax =
std::numeric_limits<core::KernelPickFactor::value_type>::max();
if (kernel_pick_factors_.IsTargetConsidered() &&
place().target == kernel.target()) {
score +=
kMax / static_cast<int>(core::KernelPickFactor::Factor::TargetFirst);
}
if (kernel_pick_factors_.IsPrecisionConsidered() &&
place().precision == kernel.precision()) {
score += kMax /
static_cast<int>(core::KernelPickFactor::Factor::PrecisionFirst);
}
// The data layout is not considered, for the input and output arguments
// might have different data layouts.
// TODO(Superjomn) reconsider the idea of taking the data layout as a kernel
// specification.
return score;
}
private:
core::KernelPickFactor kernel_pick_factors_;
Place place_;
};
} // namespace mir
} // namespace lite
......
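Worked numbers for KernelGrade() under the declared value_type of unsigned char (so kMax is 255), showing why the factor ordering doubles as the pick priority:

    const int kMax = std::numeric_limits<unsigned char>::max();  // 255
    size_t target_match    = kMax / 1;  // Factor::TargetFirst    -> +255
    size_t precision_match = kMax / 2;  // Factor::PrecisionFirst -> +127
    // A kernel matching only the target (255) always outranks one matching
    // only the precision (127), so higher-priority factors dominate the score.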
......@@ -120,8 +120,10 @@ class KernelRegistor : public lite::Registor<KernelType> {
LOG(INFO) << "Register kernel " << op_type << " for "
<< TargetToStr(target) << " " << PrecisionToStr(precision);
KernelRegistry::Global().Register<target, precision>(
op_type, [&]() -> std::unique_ptr<KernelType> {
return std::unique_ptr<KernelType>(new KernelType);
op_type, [&, op_type]() -> std::unique_ptr<KernelType> {
std::unique_ptr<KernelType> x(new KernelType);
x->set_op_type(op_type);
return x;
});
}) {}
};
......
......@@ -16,7 +16,6 @@
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
namespace paddle {
......@@ -28,17 +27,19 @@ namespace lite {
*/
class Optimizer {
public:
void Run(std::unique_ptr<mir::Program>&& program,
const std::vector<Place>& valid_places,
void Run(mir::Program&& program, const std::vector<Place>& valid_places,
const std::vector<std::string>& passes = {}) {
CHECK(!graph_) << "duplicate optimize found";
graph_.reset(new mir::SSAGraph);
graph_->Build(*program, valid_places);
graph_->Build(program, valid_places);
RunPasses();
}
// Generate a new program based on the mir graph.
std::unique_ptr<mir::Program> GenProgram() {}
std::unique_ptr<mir::Program> GenProgram() {
std::unique_ptr<mir::Program> res;
return res;
}
// Generate C++ code which combines the inference program, model and weights.
void GenCode(const std::string& code_dir);
......@@ -50,7 +51,7 @@ class Optimizer {
protected:
// Run the default passes registered in the PassManager.
void RunPasses() { mir::PassManager::Global().Run(); }
void RunPasses() { mir::PassManager::Global().Run(graph_); }
// Specify the passes and run them.
void RunPasses(std::vector<std::string>& passes);
......
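A hedged end-to-end sketch of how the Optimizer introduced here might be driven; the Program construction is assumed (e.g. via the helpers now living in program_fake_utils):

    lite::Optimizer optimizer;
    std::vector<lite::Place> valid_places{
        lite::Place{TARGET(kHost), PRECISION(kFloat)}};
    // `program` is a mir::Program prepared elsewhere (assumed here).
    optimizer.Run(std::move(program), valid_places);
    // GenProgram() is still a stub in this commit and returns an empty pointer.
    auto runtime_program = optimizer.GenProgram();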
......@@ -25,6 +25,38 @@ using any_context_t = variant<Context<TARGET(kX86)>, //
Context<TARGET(kCUDA)> //
>;
// Factors that impact the kernel-picking strategy. Multiple factors can be
// considered together with an expression like 'factor1 | factor2'.
class KernelPickFactor {
public:
using value_type = unsigned char;
enum class Factor : int {
// The following factors are sorted by priority.
TargetFirst = 1,
PrecisionFirst = 1 << 1,
DataLayoutFirst = 1 << 2,
DeviceFirst = 1 << 3,
};
// Whether any factor has been considered.
bool AnyFactorConsidered() const { return data_; }
KernelPickFactor& ConsiderTarget();
KernelPickFactor& ConsiderPrecision();
KernelPickFactor& ConsiderDataLayout();
KernelPickFactor& ConsiderDevice();
bool IsTargetConsidered() const;
bool IsPrecisionConsidered() const;
bool IsDataLayoutConsidered() const;
bool IsDeviceConsidered() const {
return data_ & static_cast<int>(Factor::DeviceFirst);
}
private:
unsigned char data_{};
};
struct dim2 {
int x{};
int y{};
......
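The Consider*/Is*Considered bodies live in types.cc, which this commit does not show; a hedged sketch of the bit-mask pattern implied by the class and the inline IsDeviceConsidered():

    KernelPickFactor& KernelPickFactor::ConsiderTarget() {
      data_ |= static_cast<value_type>(Factor::TargetFirst);
      return *this;  // allows chaining: factors.ConsiderTarget().ConsiderPrecision();
    }
    bool KernelPickFactor::IsTargetConsidered() const {
      return data_ & static_cast<int>(Factor::TargetFirst);
    }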