提交 0245a2dd 编写于 作者: S Superjomn

add variable inference pass tester

and code clean
上级 28d27145
...@@ -96,10 +96,7 @@ class KernelBase { ...@@ -96,10 +96,7 @@ class KernelBase {
return type->type; return type->type;
} }
void set_alias(const std::string& x) { void set_alias(const std::string& x) { alias_ = x; }
alias_ = x;
LOG(INFO) << "kernel " << op_type() << " setting alias " << alias();
}
const std::string& alias() const { return alias_; } const std::string& alias() const { return alias_; }
virtual Place place() const = 0; virtual Place place() const = 0;
......
...@@ -24,3 +24,14 @@ cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS ...@@ -24,3 +24,14 @@ cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
mir_pass_manager mir_pass_manager
program_fake_utils program_fake_utils
) )
cc_test(test_variable_place_infrence_pass SRCS variable_place_inference_pass_test.cc DEPS
ops_lite
host_kernels
kernels_cuda
mir_passes
mir_pass_manager
optimizer_lite
program_fake_utils
target_wrapper_host
target_wrapper_cuda
)
...@@ -36,10 +36,7 @@ void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) { ...@@ -36,10 +36,7 @@ void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
ComplementInputs(graph.get(), node, in); ComplementInputs(graph.get(), node, in);
} }
} }
VLOG(3) << "\n" << Visualize(graph.get());
// PickIoCopyKernel(graph.get());
LOG(INFO) << "\n" << Visualize(graph.get());
} }
void IoComplementPass::ComplementInputs(SSAGraph* graph, Node* inst_node, void IoComplementPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
...@@ -96,6 +93,7 @@ void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to, ...@@ -96,6 +93,7 @@ void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to,
// create Op and kernels. // create Op and kernels.
auto io_copy_op = LiteOpRegistry::Global().Create("io_copy"); auto io_copy_op = LiteOpRegistry::Global().Create("io_copy");
CHECK(io_copy_op) << "create op [" << io_copy_op << "] failed";
// CHECK(io_copy_op); // CHECK(io_copy_op);
// Create the new var manually. // Create the new var manually.
inst_node->AsInstruct().op->scope()->Var(io_copy_output_name); inst_node->AsInstruct().op->scope()->Var(io_copy_output_name);
...@@ -144,36 +142,6 @@ void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to, ...@@ -144,36 +142,6 @@ void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to,
graph->CheckValid(); graph->CheckValid();
} }
void IoComplementPass::PickIoCopyKernel(SSAGraph* graph) {
for (auto& node : graph->mutable_nodes()) {
if (node.IsInstruct() && node.AsInstruct().op_type == "io_copy") {
auto& kernels = node.AsInstruct().valid_kernels;
CHECK(!kernels.empty()) << "No valid kernels found for IoCopy Op";
for (auto& kernel : kernels) {
CHECK_EQ(node.inlinks.size(), 1UL);
CHECK_EQ(node.outlinks.size(), 1UL);
auto* inty = node.inlinks.front()->AsArgument().type;
auto* outy = node.outlinks.front()->AsArgument().type;
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
if (TypeCompatibleTo(*inty, *in_arg_ty)) {
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
// Both the input and output type matches, remove other kernels
// directly.
if (out_arg_ty->target() == outy->target()) {
LOG(INFO) << "get a IOCopy kernel";
auto x = std::move(kernel);
kernels.clear();
kernels.emplace_back(std::move(x));
break;
}
}
}
}
}
// Check the compatiblity.
}
void IoComplementPass::SetValidPlaces(const std::vector<Place>& valid_places) { void IoComplementPass::SetValidPlaces(const std::vector<Place>& valid_places) {
CHECK(!valid_places.empty()); CHECK(!valid_places.empty());
valid_places_ = valid_places; valid_places_ = valid_places;
......
...@@ -26,7 +26,6 @@ static void UpdateInputTo(framework::proto::OpDesc* desc, ...@@ -26,7 +26,6 @@ static void UpdateInputTo(framework::proto::OpDesc* desc,
for (auto& item : *desc->mutable_inputs()) { for (auto& item : *desc->mutable_inputs()) {
for (auto& input : *item.mutable_arguments()) { for (auto& input : *item.mutable_arguments()) {
if (input == from) { if (input == from) {
LOG(INFO) << "** update input argument from " << from << " to " << to;
input = to; input = to;
} }
} }
...@@ -49,9 +48,6 @@ class IoComplementPass : public ProgramPass { ...@@ -49,9 +48,6 @@ class IoComplementPass : public ProgramPass {
void SetValidPlaces(const std::vector<Place>& valid_places); void SetValidPlaces(const std::vector<Place>& valid_places);
// Pick the right kernel of IoCopy considering the input and output Type.
void PickIoCopyKernel(SSAGraph* graph);
const std::vector<Place>& valid_places() const { return valid_places_; }; const std::vector<Place>& valid_places() const { return valid_places_; };
private: private:
......
...@@ -25,7 +25,7 @@ namespace mir { ...@@ -25,7 +25,7 @@ namespace mir {
class PassRegistry { class PassRegistry {
public: public:
PassRegistry(const std::string& name, mir::Pass* pass) { PassRegistry(const std::string& name, mir::Pass* pass) {
LOG(INFO) << "Registry add MIR pass " << name; VLOG(2) << "Registry add MIR pass " << name;
PassManager::Global().AddNewPass(name, pass); PassManager::Global().AddNewPass(name, pass);
} }
......
...@@ -91,7 +91,9 @@ std::vector<mir::Node *> SSAGraph::InstructTopologicalOrder() { ...@@ -91,7 +91,9 @@ std::vector<mir::Node *> SSAGraph::InstructTopologicalOrder() {
void SSAGraph::GraphCreateTmpVarNodes(const Program &program) { void SSAGraph::GraphCreateTmpVarNodes(const Program &program) {
for (const auto &name : program.tmp_vars) { for (const auto &name : program.tmp_vars) {
LOG(INFO) << "create arg node " << name; CHECK(!arguments_.count(name)) << "duplicate creating temp variable: "
<< name;
VLOG(5) << "create arg node " << name;
node_storage_.emplace_back(); node_storage_.emplace_back();
auto &new_node = node_storage_.back(); auto &new_node = node_storage_.back();
new_node.AsArgument(name); new_node.AsArgument(name);
...@@ -102,7 +104,9 @@ void SSAGraph::GraphCreateTmpVarNodes(const Program &program) { ...@@ -102,7 +104,9 @@ void SSAGraph::GraphCreateTmpVarNodes(const Program &program) {
void SSAGraph::GraphCreateWeightVarNodes(const Program &program) { void SSAGraph::GraphCreateWeightVarNodes(const Program &program) {
// create weight nodes. // create weight nodes.
for (const auto &name : program.weights) { for (const auto &name : program.weights) {
LOG(INFO) << "create arg node " << name; CHECK(!arguments_.count(name)) << "duplicate creating weight variable: "
<< name;
VLOG(5) << "create arg node " << name;
node_storage_.emplace_back(); node_storage_.emplace_back();
auto &new_node = node_storage_.back(); auto &new_node = node_storage_.back();
new_node.AsArgument(name); new_node.AsArgument(name);
...@@ -134,10 +138,8 @@ void SSAGraph::Build(const Program &program, ...@@ -134,10 +138,8 @@ void SSAGraph::Build(const Program &program,
for (auto &op : program.ops) { for (auto &op : program.ops) {
auto *op_node = GraphCreateInstructNode(program, op, valid_places); auto *op_node = GraphCreateInstructNode(program, op, valid_places);
LOG(INFO) << "checking op " << op->op_type_;
for (const std::string &name : op->op_info()->input_names()) { for (const std::string &name : op->op_info()->input_names()) {
auto *arg = Argument(name); auto *arg = Argument(name);
LOG(INFO) << "input " << name;
CHECK(arg->IsRoleSet()); CHECK(arg->IsRoleSet());
DirectedLink(arg, op_node); DirectedLink(arg, op_node);
} }
...@@ -145,7 +147,6 @@ void SSAGraph::Build(const Program &program, ...@@ -145,7 +147,6 @@ void SSAGraph::Build(const Program &program,
if (!arguments_.count(name)) { if (!arguments_.count(name)) {
NewArgumentNode(name); NewArgumentNode(name);
} }
LOG(INFO) << "output " << name;
auto *arg = arguments_.at(name); auto *arg = arguments_.at(name);
CHECK(arg->IsRoleSet()); CHECK(arg->IsRoleSet());
DirectedLink(op_node, arg); DirectedLink(op_node, arg);
......
...@@ -35,7 +35,7 @@ void BuildFc(framework::ProgramDesc* desc, const std::string& x, ...@@ -35,7 +35,7 @@ void BuildFc(framework::ProgramDesc* desc, const std::string& x,
} }
TEST(SSAGraph, test) { TEST(SSAGraph, test) {
auto program = FakeProgram(); auto program = ProgramFaker();
SSAGraph graph; SSAGraph graph;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}}; std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
......
...@@ -38,7 +38,6 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) { ...@@ -38,7 +38,6 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
std::vector<std::pair<size_t, std::unique_ptr<KernelBase>>> scored; std::vector<std::pair<size_t, std::unique_ptr<KernelBase>>> scored;
for (auto&& kernel : instruct.valid_kernels) { for (auto&& kernel : instruct.valid_kernels) {
size_t score = KernelGrade(*kernel); size_t score = KernelGrade(*kernel);
LOG(INFO) << "kernel " << kernel->summary() << " " << score;
scored.emplace_back(score, std::move(kernel)); scored.emplace_back(score, std::move(kernel));
} }
...@@ -49,7 +48,7 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) { ...@@ -49,7 +48,7 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
// TODO(Superjomn) reconsider this. // TODO(Superjomn) reconsider this.
instruct.valid_kernels.clear(); instruct.valid_kernels.clear();
instruct.valid_kernels.emplace_back(std::move(scored.front().second)); instruct.valid_kernels.emplace_back(std::move(scored.front().second));
LOG(INFO) << "pick " << instruct.valid_kernels.front()->name(); VLOG(2) << "pick " << instruct.valid_kernels.front()->name();
} }
} }
......
...@@ -74,10 +74,10 @@ class StaticKernelPickPass : public mir::InstructionPass { ...@@ -74,10 +74,10 @@ class StaticKernelPickPass : public mir::InstructionPass {
score += kMax / static_cast<int>( score += kMax / static_cast<int>(
core::KernelPickFactor::Factor::DataLayoutFirst); core::KernelPickFactor::Factor::DataLayoutFirst);
} }
LOG(INFO) << "picker tactic " << kernel_pick_factors_; VLOG(4) << "picker tactic " << kernel_pick_factors_;
LOG(INFO) << "kernel place " << kernel.place(); VLOG(4) << "kernel place " << kernel.place();
LOG(INFO) << "picker place " << place(); VLOG(4) << "picker place " << place();
LOG(INFO) << "score " << score; VLOG(4) << "score " << score;
// The data layout is not considered, for the input and output arguments // The data layout is not considered, for the input and output arguments
// might have different data layout. // might have different data layout.
......
...@@ -51,49 +51,54 @@ class VariablePlaceInferencePass : public DebugPass { ...@@ -51,49 +51,54 @@ class VariablePlaceInferencePass : public DebugPass {
for (auto& node : graph->mutable_nodes()) { for (auto& node : graph->mutable_nodes()) {
if (node.IsArgument()) { if (node.IsArgument()) {
CHECK(node.AsArgument().type) << "node " << node.AsArgument().name CHECK(node.AsArgument().type) << "node " << node.AsArgument().name
<< " type not determined"; << " type not determined, " << &node;
} }
} }
} }
void InferenceArgumentPlace(SSAGraph* graph) { void InferenceArgumentPlace(SSAGraph* graph) {
LOG(INFO) << "param-type-registry:\n" << ParamTypeRegistry::Global(); VLOG(3) << "param-type-registry:\n" << ParamTypeRegistry::Global();
for (auto& x : graph->InstructTopologicalOrder()) { for (auto& x : graph->InstructTopologicalOrder()) {
auto& inst = x->AsInstruct(); auto& inst = x->AsInstruct();
// The IoCopyOp is a tool operator, it won't support the type inference. // The IoCopyOp is a tool operator, it won't support the type inference.
if (inst.op_type == "io_copy") continue; if (inst.op_type == "io_copy") continue;
// LOG(INFO) << "- inferencing type " << // LOG(INFO) << "- inferencing type " <<
// deal with inputs // deal with inputs
VLOG(4) << "inferencing op " << inst.op_type;
for (auto& arg_name : inst.op_info()->input_argnames()) { for (auto& arg_name : inst.op_info()->input_argnames()) {
LOG(INFO) << "-- input arg_name " << arg_name; VLOG(3) << "-- input arg_name " << arg_name;
// check if inputs's place is set, if not set, update them with the // check if inputs's place is set, if not set, update them with the
// kernel's declaration. // kernel's declaration.
auto type = inst.picked_kernel().GetInputDeclType(arg_name); auto type = inst.picked_kernel().GetInputDeclType(arg_name);
auto arg_names = inst.op_info()->input_argument().at(arg_name); auto arg_names = inst.op_info()->input_argument().at(arg_name);
for (auto& arg_name : arg_names) { for (auto& arg_name : arg_names) {
LOG(INFO) << "--- var " << arg_name; VLOG(3) << "--- var " << arg_name;
auto* node = graph->RetrieveArgument(arg_name); auto* node = graph->RetrieveArgument(arg_name);
CHECK(node) << "argument " << arg_name << " not exists in the graph"; CHECK(node) << "argument " << arg_name << " not exists in the graph";
auto& arg_node = node->AsArgument(); auto& arg_node = node->AsArgument();
if (arg_node.type) continue; if (!arg_node.type) {
VLOG(4) << "set type " << *type << " " << node;
arg_node.type = type; arg_node.type = type;
} }
} }
}
for (auto& arg_name : inst.op_info()->output_argnames()) { for (auto& arg_name : inst.op_info()->output_argnames()) {
LOG(INFO) << "-- output arg_name " << arg_name; VLOG(3) << "-- output arg_name " << arg_name;
auto type = inst.picked_kernel().GetOutputDeclType(arg_name); auto type = inst.picked_kernel().GetOutputDeclType(arg_name);
auto arg_names = inst.op_info()->output_argument().at(arg_name); auto arg_names = inst.op_info()->output_argument().at(arg_name);
// check if outputs's place is set, if not set, update them with the // check if outputs's place is set, if not set, update them with the
// kernel's declaration. // kernel's declaration.
for (auto& arg_name : arg_names) { for (auto& arg_name : arg_names) {
LOG(INFO) << "--- var " << arg_name; VLOG(3) << "--- var " << arg_name;
auto* node = graph->RetrieveArgument(arg_name); auto* node = graph->RetrieveArgument(arg_name);
CHECK(node) << "argument " << arg_name << " not exists in the graph"; CHECK(node) << "argument " << arg_name << " not exists in the graph";
auto& arg_node = node->AsArgument(); auto& arg_node = node->AsArgument();
if (arg_node.type) continue; if (!arg_node.type) {
node->AsArgument().type = type; node->AsArgument().type = type;
VLOG(3) << "set type " << *type;
}
} }
} }
} }
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/optimizer.h"
#include "paddle/fluid/lite/core/program_fake_utils.h"
#include "paddle/fluid/lite/kernels/cuda/use_kernels.h"
#include "paddle/fluid/lite/kernels/host/use_kernels.h"
namespace paddle {
namespace lite {
namespace mir {
TEST(variable_place_inference_pass, test) {
std::shared_ptr<Scope> scope(new lite::Scope);
ProgramFaker program_faker;
program_faker.AddFeed("a", 0);
program_faker.AddMul("a", "W", "a1");
program_faker.AddMul("a1", "W1", "a2");
program_faker.AddFetch("a2", 0);
program_faker.CreateVars(scope.get());
auto* desc = program_faker.program();
Optimizer optimizer;
std::vector<Place> places({
Place{
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW),
},
Place{
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny),
},
Place{
TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW),
},
Place{
TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kAny),
},
});
Program program(*desc, scope, places);
core::KernelPickFactor factor;
factor.ConsiderTarget();
std::vector<std::string> passes({
"static_kernel_pick_pass", //
"argument_type_display_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"io_complement_pass", //
});
Place prefered_place{
TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW),
};
optimizer.KernelPickPreferPlace(prefered_place);
optimizer.Run(std::move(program), places, factor, passes);
}
} // namespace mir
} // namespace lite
} // namespace paddle
USE_LITE_OP(mul);
USE_LITE_OP(feed);
USE_LITE_OP(fetch);
USE_LITE_OP(io_copy);
...@@ -35,7 +35,7 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels( ...@@ -35,7 +35,7 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
} }
CHECK(!kernels.empty()) << "No kernel found for Op " << op_type_; CHECK(!kernels.empty()) << "No kernel found for Op " << op_type_;
LOG(INFO) << "op " << op_type_ << " get " << kernels.size() << " kernels"; VLOG(2) << "op " << op_type_ << " get " << kernels.size() << " kernels";
return kernels; return kernels;
} }
......
...@@ -21,7 +21,7 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create( ...@@ -21,7 +21,7 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
const std::string &op_type, TargetType target, PrecisionType precision, const std::string &op_type, TargetType target, PrecisionType precision,
DataLayoutType layout) { DataLayoutType layout) {
Place place{target, precision, layout}; Place place{target, precision, layout};
LOG(INFO) << "creating " << op_type << " kernel for " << place; VLOG(5) << "creating " << op_type << " kernel for " << place;
#define CREATE_KERNEL1(target__, precision__) \ #define CREATE_KERNEL1(target__, precision__) \
switch (layout) { \ switch (layout) { \
case DATALAYOUT(kNCHW): \ case DATALAYOUT(kNCHW): \
......
...@@ -81,7 +81,7 @@ class KernelRegistry final { ...@@ -81,7 +81,7 @@ class KernelRegistry final {
void Register(const std::string &name, void Register(const std::string &name,
typename KernelRegistryForTarget<Target, Precision, typename KernelRegistryForTarget<Target, Precision,
Layout>::creator_t &&creator) { Layout>::creator_t &&creator) {
LOG(INFO) << "register for " << TargetToStr(Target) << ":" VLOG(3) << "register for " << TargetToStr(Target) << ":"
<< PrecisionToStr(Precision) << "//" << PrecisionToStr(Precision) << "//"
<< GetKernelOffset<Target, Precision, Layout>(); << GetKernelOffset<Target, Precision, Layout>();
using kernel_registor_t = using kernel_registor_t =
...@@ -144,7 +144,7 @@ class KernelRegistor : public lite::Registor<KernelType> { ...@@ -144,7 +144,7 @@ class KernelRegistor : public lite::Registor<KernelType> {
public: public:
KernelRegistor(const std::string &op_type, const std::string &alias) KernelRegistor(const std::string &op_type, const std::string &alias)
: Registor<KernelType>([=] { : Registor<KernelType>([=] {
LOG(INFO) << "Register kernel " << op_type << " for " VLOG(3) << "Register kernel " << op_type << " for "
<< TargetToStr(target) << " " << PrecisionToStr(precision) << TargetToStr(target) << " " << PrecisionToStr(precision)
<< " " << DataLayoutToStr(layout) << " alias " << alias; << " " << DataLayoutToStr(layout) << " alias " << alias;
KernelRegistry::Global().Register<target, precision, layout>( KernelRegistry::Global().Register<target, precision, layout>(
......
...@@ -27,33 +27,5 @@ void Optimizer::SpecifyKernelPickTactic(core::KernelPickFactor factor) { ...@@ -27,33 +27,5 @@ void Optimizer::SpecifyKernelPickTactic(core::KernelPickFactor factor) {
*pass->mutable_kernel_pick_factors() = factor; *pass->mutable_kernel_pick_factors() = factor;
} }
void Optimizer::RunPasses() {
std::vector<std::string> passes({
"static_kernel_pick_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"io_complement_pass", //
"argument_type_display_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"io_copy_kernel_pick_pass", //
"variable_place_inference_pass", //
"runtime_context_assign_pass", //
});
for (auto& pass_type : passes) {
LOG(INFO) << ".. running pass " << pass_type;
auto* pass = mir::PassManager::Global().LookUp(pass_type);
CHECK(pass);
if (pass->name() == "io_complement_pass") {
auto* _pass = dynamic_cast<mir::IoComplementPass*>(pass);
_pass->SetValidPlaces(valid_places_);
CHECK(!_pass->valid_places().empty());
_pass->Apply(graph_);
} else {
pass->Apply(graph_);
}
}
// mir::PassManager::Global().Run(graph_);
}
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -41,8 +41,24 @@ class Optimizer { ...@@ -41,8 +41,24 @@ class Optimizer {
graph_.reset(new mir::SSAGraph); graph_.reset(new mir::SSAGraph);
graph_->Build(program, valid_places); graph_->Build(program, valid_places);
SpecifyKernelPickTactic(kernel_pick_factor); SpecifyKernelPickTactic(kernel_pick_factor);
// InitIoComplement(); InitIoComplement();
RunPasses();
if (passes.empty()) {
RunPasses(std::vector<std::string>{{
"static_kernel_pick_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"io_complement_pass", //
"argument_type_display_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"io_copy_kernel_pick_pass", //
"variable_place_inference_pass", //
"runtime_context_assign_pass", //
}});
} else {
RunPasses(passes);
}
exec_scope_ = program.exec_scope; exec_scope_ = program.exec_scope;
} }
...@@ -86,11 +102,15 @@ class Optimizer { ...@@ -86,11 +102,15 @@ class Optimizer {
protected: protected:
void SpecifyKernelPickTactic(core::KernelPickFactor factor); void SpecifyKernelPickTactic(core::KernelPickFactor factor);
// Run the default passes registered in the PassManager.
void RunPasses();
// Specify the passes and run them. // Specify the passes and run them.
void RunPasses(std::vector<std::string>& passes); void RunPasses(const std::vector<std::string>& passes) {
for (auto& x : passes) {
LOG(INFO) << "== Running pass " << x;
auto* pass = mir::PassManager::Global().LookUp(x);
CHECK(pass);
pass->Apply(graph_);
}
}
private: private:
std::unique_ptr<mir::SSAGraph> graph_; std::unique_ptr<mir::SSAGraph> graph_;
......
...@@ -25,7 +25,7 @@ namespace lite { ...@@ -25,7 +25,7 @@ namespace lite {
TEST(Optimizer, test) { TEST(Optimizer, test) {
Optimizer optimizer; Optimizer optimizer;
auto program = FakeProgram(); auto program = ProgramFaker();
std::vector<Place> places({Place{TARGET(kHost), PRECISION(kFloat)}}); std::vector<Place> places({Place{TARGET(kHost), PRECISION(kFloat)}});
auto* pick_pass = auto* pick_pass =
......
...@@ -64,7 +64,7 @@ struct Program { ...@@ -64,7 +64,7 @@ struct Program {
for (auto* op_desc : program.Block(0).AllOps()) { for (auto* op_desc : program.Block(0).AllOps()) {
auto op_type = op_desc->Type(); auto op_type = op_desc->Type();
// if (op_type == "feed" || op_type == "fetch") continue; // if (op_type == "feed" || op_type == "fetch") continue;
LOG(INFO) << "create Op [" << op_type << "]"; VLOG(4) << "create Op [" << op_type << "]";
ops.emplace_back(LiteOpRegistry::Global().Create(op_type)); ops.emplace_back(LiteOpRegistry::Global().Create(op_type));
// pick initial kernel // pick initial kernel
ops.back()->PickKernel(valid_places); ops.back()->PickKernel(valid_places);
......
...@@ -71,5 +71,68 @@ Program FakeProgram() { ...@@ -71,5 +71,68 @@ Program FakeProgram() {
return program; return program;
} }
class ProgramFaker {
public:
ProgramFaker() {}
framework::ProgramDesc* program() {
desc_.Flush();
return &desc_;
}
void CreateVars(lite::Scope* scope) {
for (auto& var : tmp_vars_) {
auto* x = scope->Var(var);
x->GetMutable<lite::Tensor>();
}
for (auto& x : tmp_vars_) {
desc_.MutableBlock(0)->Var(x);
}
}
void AddMul(const std::string& X, const std::string& Y,
const std::string& out) {
tmp_vars_.insert(X);
tmp_vars_.insert(Y);
tmp_vars_.insert(out);
auto* block = desc_.MutableBlock(0);
auto* op = block->AppendOp();
op->SetType("mul");
op->SetInput("X", {X});
op->SetInput("Y", {Y});
op->SetOutput("Out", {Y});
op->SetAttr("x_num_col_dims", 1);
op->SetAttr("y_num_col_dims", 1);
}
void AddFeed(const std::string& Out, int col) {
tmp_vars_.insert(Out);
auto* block = desc_.MutableBlock(0);
auto* op = block->AppendOp();
op->SetType("feed");
op->SetInput("X", {"feed"});
op->SetOutput("Out", {Out});
op->SetAttr("col", col);
}
void AddFetch(const std::string& Input, int col) {
tmp_vars_.insert(Input);
auto* block = desc_.MutableBlock(0);
auto* op = block->AppendOp();
op->SetType("fetch");
op->SetInput("X", {Input});
op->SetOutput("Out", {"fetch"});
op->SetAttr("col", col);
}
private:
std::set<std::string> tmp_vars_;
std::vector<std::string> weight_vars_;
framework::ProgramDesc desc_;
};
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -142,6 +142,8 @@ class Type : public DataTypeBase { ...@@ -142,6 +142,8 @@ class Type : public DataTypeBase {
} }
if (other.is_tensor_) { if (other.is_tensor_) {
os << "<Tensor:"; os << "<Tensor:";
} else {
os << "<";
} }
os << TargetToStr(other.target()) << "/" os << TargetToStr(other.target()) << "/"
<< PrecisionToStr(other.precision()) << "/" << PrecisionToStr(other.precision()) << "/"
...@@ -256,53 +258,6 @@ const Type* LookupType(DataTypeBase::ID type_id, bool is_unknown, ...@@ -256,53 +258,6 @@ const Type* LookupType(DataTypeBase::ID type_id, bool is_unknown,
bool is_tensor, Place place); bool is_tensor, Place place);
// ------------------------- end predefined types --------------------------- // ------------------------- end predefined types ---------------------------
// NOTE TypeSystem has some overhead, and better to be used in analysis phase.
class TypeSystem {
private:
// Put all valid types for Variables here!
TypeSystem() {
// Tensor is a valid data type for Variable.
Register<Tensor>("tensor");
}
public:
static TypeSystem& Global() {
static TypeSystem x;
return x;
}
template <typename T>
void Register(const std::string& type) {
size_t hash = typeid(T).hash_code();
CHECK(!types_.count(hash)) << "duplicate register type " << type
<< " found!";
types_[hash] = type;
names_.insert(type);
}
template <typename T>
bool Contains() const {
return types_.count(typeid(T).hash_code());
}
bool Contains(size_t hash) const { return types_.count(hash); }
bool Contains(const std::string& type) { return names_.count(type); }
std::string DebugInfo() const {
std::stringstream ss;
for (const auto& it : types_) {
ss << it.second << "\n";
}
return ss.str();
}
private:
std::unordered_map<size_t /*hash*/, std::string /*name*/> types_;
TypeSystem(const TypeSystem&) = delete;
std::unordered_set<std::string> names_;
};
/* /*
* ParamType is used to represent a data type of a parameter for the kernel. It * ParamType is used to represent a data type of a parameter for the kernel. It
* can represent any Variable data type. * can represent any Variable data type.
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/op_registry.h"
// TODO(Superjomn) make this file a library, that will make compile dependency
// easier.
#ifdef LITE_WITH_CUDA
USE_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW, def);
USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, host_to_device);
USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host);
#endif
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/op_registry.h"
USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def);
USE_LITE_KERNEL(mul, kHost, kFloat, kNCHW, def);
USE_LITE_KERNEL(scale, kHost, kFloat, kNCHW, def);
USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
...@@ -95,7 +95,7 @@ class OpDesc { ...@@ -95,7 +95,7 @@ class OpDesc {
std::string op_type; std::string op_type;
std::map<std::string, std::vector<std::string>> inputs; std::map<std::string, std::vector<std::string>> inputs;
std::map<std::string, std::vector<std::string>> outputs; std::map<std::string, std::vector<std::string>> outputs;
std::map<std::string, variant<int, std::string>> attrs; std::map<std::string, variant<int, float, std::string>> attrs;
}; };
class BlockDesc { class BlockDesc {
...@@ -112,6 +112,8 @@ class BlockDesc { ...@@ -112,6 +112,8 @@ class BlockDesc {
class ProgramDesc { class ProgramDesc {
public: public:
void Parse(const framework::proto::ProgramDesc& desc); void Parse(const framework::proto::ProgramDesc& desc);
BlockDesc block;
}; };
} // namespace lite } // namespace lite
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册