Commit cdb12e59 authored by: S superjomn

add ssa_graph test

Parent 8b950a4f
cc_library(memory_lite SRCS memory.cc)
cc_library(tensor_lite SRCS tensor.cc DEPS memory_lite)
cc_library(kernel_lite SRCS kernel.cc)
cc_library(kernel_lite SRCS kernel.cc DEPS type_system)
cc_library(variable_lite SRCS variable.cc)
cc_library(op_registry_lite SRCS op_registry.cc)
cc_library(scope_lite SRCS scope.cc)
......
......@@ -3,6 +3,18 @@ cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node)
cc_library(mir_pass SRCS pass.cc DEPS mir_ssa_graph)
cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph)
cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
cc_library(mir_demo_pass SRCS demo_pass.cc DEPS mir_pass)
cc_library(mir_passes
SRCS static_kernel_pick_pass.cc
io_complement_pass.cc
graph_visualize_pass.cc
generate_program_pass.cc
demo_pass.cc
DEPS mir_pass)
cc_test(test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_demo_pass)
cc_test(test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes)
cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
mir_ssa_graph scope_lite op_lite
proto_desc ops_lite
host_kernels
mir_passes
)
......@@ -19,15 +19,19 @@ namespace paddle {
namespace lite {
namespace mir {
class DemoPass : public mir::Pass {
class DemoPass : public mir::DebugPass {
public:
void Apply(std::unique_ptr<mir::SSAGraph>& graph) override {}
};
/*
bool RegisterDemoPass() {
return PassManager::Global().AddNewPass("demo", new DemoPass);
}
*/
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(demo, paddle::lite::mir::DemoPass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
/*
* GenerateProgramPass builds the execution program for the executor from a MIR
* graph.
*/
class GenerateProgramPass : public Pass {
public:
};
} // namespace mir
} // namespace lite
} // namespace paddle
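GenerateProgramPass is added here as an empty placeholder. As a point of reference only, the following is a minimal sketch of the kind of Apply it would eventually need, assuming it sits inside namespace paddle::lite::mir next to the other passes, derives from ProgramPass, and uses the TopoloticalOrder()/IsInstruct() API declared in ssa_graph.h and node.h; the class name and body are hypothetical and not part of this commit.
// Hypothetical sketch, not committed code: collect instruction nodes in
// topological order so a later stage can turn them into an executable program.
// (Assumes <vector> and ssa_graph.h are included, as in the other passes.)
class GenerateProgramPassSketch : public ProgramPass {
 public:
  void Apply(std::unique_ptr<mir::SSAGraph>& graph) override {
    std::vector<mir::Node*> instructions;
    for (mir::Node* node : graph->TopoloticalOrder()) {
      if (node->IsInstruct()) instructions.push_back(node);
    }
    // Building the actual runnable program from `instructions` is omitted here.
  }
};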
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include <set>
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void GraphVisualizePass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
Visualize(graph.get());
}
std::string Visualize(mir::SSAGraph* graph) {
inference::analysis::Dot dot;
int id = 0;
std::set<std::string> exists_args;
for (auto& node : graph->mutable_nodes()) {
std::string key;
if (node.IsArgument()) {
key = node.AsArgument().name;
} else {
key = node.AsInstruct().op_type + std::to_string(id++);
}
if (node.IsInstruct()) {
dot.AddNode(key, {});
for (auto& x : node.inlinks) {
auto name = x->AsArgument().name;
if (!exists_args.count(name)) {
dot.AddNode(name, {});
}
dot.AddEdge(name, key, {});
exists_args.insert(name);
}
for (auto& x : node.outlinks) {
auto name = x->AsArgument().name;
if (!exists_args.count(name)) {
dot.AddNode(name, {});
}
dot.AddEdge(key, name, {});
exists_args.insert(name);
}
}
}
auto res = dot.Build();
LOG(INFO) << "dot:\n" << res;
return res;
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(graph_visualze, paddle::lite::mir::GraphVisualizePass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
/*
* GraphVisualizePass helps visualize a MIR graph by exporting a DOT file.
*/
class GraphVisualizePass : public DebugPass {
public:
void Apply(std::unique_ptr<mir::SSAGraph>& graph) override;
};
std::string Visualize(mir::SSAGraph* graph);
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
/*
* IoComplementPass inserts the instructions necessary to perform data
* transfer or transformation between different places.
*/
class IoComplementPass : public Pass {
public:
};
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -51,6 +51,18 @@ class Node {
Place place;
};
Argument& AsArgument(const std::string& name) {
auto& x = AsArgument();
x.name = name;
return x;
}
Instruct& AsInstruct(const std::string& op_type) {
auto& x = AsInstruct();
x.op_type = op_type;
return x;
}
// Set roles.
Argument& AsArgument() {
if (role_ != Role::kUnk) {
......
......@@ -22,14 +22,53 @@ namespace mir {
class Pass {
public:
// A convention here: each pass should be exactly one of the following kinds.
enum class Kind {
// Will modify the program/graph topology.
kProgramWise = 0,
// Will modify the instruction, with the graph topology fixed.
kInstructionWise,
// Will not modify the IR; only collects information or produces a visualization.
kDebug,
};
Pass(Kind kind) : kind_(kind) {}
virtual void Apply(std::unique_ptr<mir::SSAGraph>& graph) = 0;
void set_name(const std::string& name) { name_ = name; }
const std::string& name() const { return name_; }
void set_doc(const std::string& doc) { doc_ = doc; }
const std::string& doc() const { return doc_; }
Kind kind() const { return kind_; }
bool is_debug_pass() const { return kind_ == Kind::kDebug; }
bool is_program_pass() const { return kind_ == Kind::kProgramWise; }
bool is_instruction_pass() const { return kind_ == Kind::kInstructionWise; }
virtual ~Pass() = default;
private:
const Kind kind_;
std::string name_;
std::string doc_;
};
// Different kinds.
class ProgramPass : public Pass {
public:
ProgramPass() : Pass(Kind::kProgramWise) {}
};
class InstructionPass : public Pass {
public:
InstructionPass() : Pass(Kind::kInstructionWise) {}
};
class DebugPass : public Pass {
public:
DebugPass() : Pass(Kind::kDebug) {}
};
} // namespace mir
......
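The three thin subclasses are the intended way to declare a pass's kind: a concrete pass inherits from ProgramPass, InstructionPass, or DebugPass instead of passing a Kind to the Pass constructor, which is why DemoPass above is switched from mir::Pass to mir::DebugPass in this commit. A minimal sketch follows; the pass name is invented purely for illustration.
// Sketch only: a pass that rewrites graph topology picks its kind by choosing
// the ProgramPass base, so kind() reports Kind::kProgramWise automatically.
class PruneUnusedArgumentsSketchPass : public ProgramPass {
 public:
  void Apply(std::unique_ptr<mir::SSAGraph>& graph) override {
    // Would erase argument nodes with no inlinks/outlinks here; removing nodes
    // changes the topology, hence a program-wise pass.
  }
};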
......@@ -21,10 +21,8 @@ namespace mir {
PassManager::PassManager() {}
// Manually register here.
extern bool RegisterDemoPass();
static bool xx __attribute__((unused)) = RegisterDemoPass();
} // namespace mir
} // namespace lite
} // namespace paddle
USE_MIR_PASS(demo);
......@@ -14,6 +14,7 @@
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
......@@ -28,3 +29,5 @@ TEST(PassManager, test) {
} // namespace mir
} // namespace lite
} // namespace paddle
USE_MIR_PASS(demo);
......@@ -32,6 +32,18 @@ class PassRegistry {
bool Touch() const { return true; }
};
#define REGISTER_MIR_PASS(name__, class__) \
paddle::lite::mir::PassRegistry mir_pass_registry##name__(#name__, \
new class__); \
bool mir_pass_registry##name__##_fake() { \
return mir_pass_registry##name__.Touch(); \
}
#define USE_MIR_PASS(name__) \
extern bool mir_pass_registry##name__##_fake(); \
static bool mir_pass_usage##name__ __attribute__((unused)) = \
mir_pass_registry##name__##_fake();
} // namespace mir
} // namespace lite
} // namespace paddle
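For reference, with the two macros above, REGISTER_MIR_PASS(demo, paddle::lite::mir::DemoPass) in demo_pass.cc and USE_MIR_PASS(demo) in a consumer file expand roughly as follows (whitespace adjusted). The extern/static pair exists so that merely referencing mir_pass_registrydemo_fake() from the test binary forces the linker to keep demo_pass.cc's registration object alive, which is why pass_manager.cc drops the manual RegisterDemoPass call and pass_manager_test.cc now ends with USE_MIR_PASS(demo).
// Expansion of REGISTER_MIR_PASS(demo, paddle::lite::mir::DemoPass):
paddle::lite::mir::PassRegistry mir_pass_registrydemo("demo",
                                                      new paddle::lite::mir::DemoPass);
bool mir_pass_registrydemo_fake() { return mir_pass_registrydemo.Touch(); }

// Expansion of USE_MIR_PASS(demo) in the consuming translation unit:
extern bool mir_pass_registrydemo_fake();
static bool mir_pass_usagedemo __attribute__((unused)) =
    mir_pass_registrydemo_fake();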
......@@ -30,8 +30,9 @@ namespace mir {
// - main block, which is a list of OpLite
// - scope: which contains all the weights
struct Program {
std::list<std::string> inputs;
std::list<std::unique_ptr<OpLite>> ops;
lite::Scope *scope;
std::unique_ptr<lite::Scope> scope;
};
// A Graph for MIR. It is built from a list of Ops and a scope.
......@@ -42,21 +43,38 @@ class SSAGraph : GraphBase {
// @param program: the op program
// @param valid_places: the valid places user set for the system.
void Build(const Program &program, const std::vector<Place> &valid_places) {
// create inputs
for (const auto &name : program.inputs) {
node_storage_.emplace_back();
auto &new_node = node_storage_.back();
auto &arg = new_node.AsArgument();
arg.name = name;
arguments_[name] = &new_node;
}
for (auto &op : program.ops) {
node_storage_.emplace_back();
// TODO(Superjomn) remove one valid_places here.
op->SetValidPlaces(valid_places);
auto &new_node = node_storage_.back();
auto &new_kernel = node_storage_.back().AsInstruct();
auto &new_kernel = node_storage_.back().AsInstruct(op->op_type_);
new_kernel.valid_kernels = op->CreateKernels(valid_places);
CHECK(new_node.inlinks.empty()) << "duplicate Build found";
CHECK(new_node.outlinks.empty()) << "duplicate Build found";
// collect inputs and outputs
for (const std::string &name : op->input_names()) {
new_node.inlinks.push_back(arguments_.at(name));
}
for (const std::string &name : op->output_names()) {
if (!arguments_.count(name)) {
node_storage_.emplace_back();
auto &new_node = node_storage_.back();
auto &arg = new_node.AsArgument(name);
arg.name = name;
arguments_.emplace(name, &new_node);
}
new_node.outlinks.push_back(arguments_.at(name));
}
}
......@@ -64,6 +82,9 @@ class SSAGraph : GraphBase {
std::vector<mir::Node *> TopoloticalOrder() const;
const std::list<mir::Node> &nodes() const { return node_storage_; }
std::list<mir::Node> &mutable_nodes() { return node_storage_; }
private:
std::list<mir::Node> node_storage_;
std::map<std::string, mir::Node *> arguments_;
......
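TopoloticalOrder() is only declared in this header; its definition lives in ssa_graph.cc and is not part of this diff. As a hedged sketch of one standard way to produce such an ordering (reverse post-order DFS over outlinks), using only the public mutable_nodes() accessor and the inlinks/outlinks members shown above; the free-function name and includes are illustrative, not the committed implementation.
#include <algorithm>
#include <functional>
#include <set>
#include <vector>
#include "paddle/fluid/lite/core/mir/ssa_graph.h"

using paddle::lite::mir::Node;
using paddle::lite::mir::SSAGraph;

// Sketch only: reverse post-order DFS over outlinks yields a topological order
// for a DAG such as the one Build() constructs.
std::vector<Node*> TopologicalOrderSketch(SSAGraph* graph) {
  std::set<Node*> visited;
  std::vector<Node*> order;
  std::function<void(Node*)> dfs = [&](Node* node) {
    if (!visited.insert(node).second) return;   // already visited
    for (Node* out : node->outlinks) dfs(out);
    order.push_back(node);                      // post-order: successors first
  };
  for (Node& node : graph->mutable_nodes()) dfs(&node);
  std::reverse(order.begin(), order.end());     // reverse post-order
  return order;
}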
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void BuildFc(framework::ProgramDesc* desc, const std::string& x,
const std::string& w, const std::string& b,
const std::string& out) {
auto* fc = desc->MutableBlock(0)->AppendOp();
fc->SetInput("Input", {x});
fc->SetInput("W", {w});
fc->SetInput("Bias", {b});
fc->SetOutput("Out", {out});
}
Program FakeProgram() {
Program program;
program.scope.reset(new lite::Scope);
auto add_fc = [&](int id, std::string x) {
// create variables
std::string w1 = "w" + std::to_string(id);
std::string b1 = "b" + std::to_string(id);
std::string out1 = "out" + std::to_string(id);
auto w1v = program.scope->Var(w1)->GetMutable<Tensor>();
auto b1v = program.scope->Var(b1)->GetMutable<Tensor>();
auto out1v = program.scope->Var(out1)->GetMutable<Tensor>();
framework::OpDesc desc;
desc.SetInput("Input", {x});
desc.SetInput("W", {w1});
desc.SetInput("Bias", {b1});
desc.SetOutput("Out", {out1});
desc.SetType("fc");
desc.SetAttr("in_num_col_dims", 1);
desc.Flush();
// add to input
program.inputs.push_back(w1);
program.inputs.push_back(b1);
auto fc_op = LiteOpRegistry::Global().Create("fc");
fc_op->PickKernel({Place{TARGET(kHost), PRECISION(kFloat)}});
fc_op->Attach(desc, program.scope.get());
program.ops.emplace_back(std::move(fc_op));
w1v->Resize({100, 100});
b1v->Resize({100, 1});
out1v->Resize({100, 100});
return out1;
};
// x1, w1, b1 -fc-> out1
// out1, w2, b2 -fc-> out2
std::string x = "x";
program.inputs.push_back(x);
auto* xv = program.scope->Var(x)->GetMutable<Tensor>();
xv->Resize({100, 100});
for (int i = 0; i < 3; i++) {
x = add_fc(i, x);
}
return program;
}
TEST(SSAGraph, test) {
auto program = FakeProgram();
SSAGraph graph;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
graph.Build(program, places);
Visualize(&graph);
}
} // namespace mir
} // namespace lite
} // namespace paddle
USE_LITE_OP(fc);
USE_LITE_KERNEL(fc, kHost, kFloat);
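A possible follow-up assertion for this test, sketched under the assumption that the fc op reports exactly Input/W/Bias as its inputs and Out as its only output: FakeProgram pushes x plus w0/b0, w1/b1, w2/b2 into program.inputs (7 argument nodes), Build adds one instruction node per op (3) and creates out0..out2 as new argument nodes (3), so the graph should hold 13 nodes in total.
// Hypothetical extra check, not part of the committed test body:
//   7 input arguments + 3 fc instructions + 3 created outputs = 13 nodes.
ASSERT_EQ(graph.nodes().size(), 13UL);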
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class StaticKernelPickPass : public mir::Pass {};
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -35,7 +35,7 @@ Type::Get<false /*is_unsupported*/, true /*is_tensor*/, TargetType::kX86,
}
template <>
const Type* Type::Get<UnsupportedTy>(TargetType target, int device) {
const Type* Type::Get<UnsupportedTy>(TargetType target) {
return Get<false, false, TargetType::kHost, PrecisionType::kFloat,
DataLayoutType::kNCHW>();
}
......
......@@ -8,6 +8,7 @@ cc_library(host_kernels DEPS
relu_compute_host
mul_compute_host
scale_compute_host
DEPS kernel_lite
)
cc_test(test_fc_compute SRCS fc_compute_test.cc DEPS fc_compute_host fc_op_lite)