未验证 提交 d96ee24f 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #12697 from panyx0718/ir2

test and doc IR Graph
...@@ -28,6 +28,38 @@ namespace paddle { ...@@ -28,6 +28,38 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
/*
* The graph is a Directed Acyclic Single Static Assignment Graph.
*
* In more detail, the following properties must hold:
*
* The graph shouldn't contain cycle. Each node is a black-box to the graph
* so the node itself could be a loop operator.
*
* Each Variable-type node has only one input (thus single static assignment).
*
* The output/input of operator is variable and the output/input of variable
* is operator.
*
* The following data harzards in Program are addressed in the Graph:
*
* Write-After-Read
* a = op1(x)
* x = op2(b)
* A control-dependency connection is created bettwen op1 and op2 such that
* op1->op2, so as to ensure correct order.
*
* Write-After-Write
* x = op1(a)
* x = op2(b)
* A control-dependency connection is created between op1 and op2 such that
* op1->op2, so as to ensure correct order.
*
* Other properties currently hold, but is not enforced yet:
*
* Variable-type node (not control dep) with the same variable name share
* the same underlying VarDesc.
*/
class Graph { class Graph {
public: public:
explicit Graph(const ProgramDesc &program); explicit Graph(const ProgramDesc &program);
......
...@@ -36,7 +36,7 @@ class SumOpMaker : public OpProtoAndCheckerMaker { ...@@ -36,7 +36,7 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
public: public:
void Make() { void Make() {
AddInput("X", "").AsDuplicable(); AddInput("X", "").AsDuplicable();
AddOutput("Out", ""); AddOutput("Out", "").AsDuplicable();
AddComment(""); AddComment("");
} }
}; };
...@@ -59,11 +59,27 @@ class SumOpVarTypeInference : public VarTypeInference { ...@@ -59,11 +59,27 @@ class SumOpVarTypeInference : public VarTypeInference {
block->Var(out_var_name)->SetType(default_var_type); block->Var(out_var_name)->SetType(default_var_type);
} }
}; };
class DummyOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "").AsDuplicable();
AddOutput("Out", "").AsDuplicable();
AddComment("");
}
};
class DummyOpVarTypeInference : public VarTypeInference {
public:
void operator()(const OpDesc &op_desc, BlockDesc *block) const override {}
};
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_OPERATOR(sum, paddle::framework::NOP, paddle::framework::SumOpMaker, REGISTER_OPERATOR(sum, paddle::framework::NOP, paddle::framework::SumOpMaker,
paddle::framework::SumOpVarTypeInference); paddle::framework::SumOpVarTypeInference);
REGISTER_OPERATOR(dummy, paddle::framework::NOP, paddle::framework::SumOpMaker,
paddle::framework::SumOpVarTypeInference);
REGISTER_OPERATOR(sum_without_infer_var_type, paddle::framework::NOP, REGISTER_OPERATOR(sum_without_infer_var_type, paddle::framework::NOP,
paddle::framework::SumOpMaker); paddle::framework::SumOpMaker);
...@@ -110,5 +126,83 @@ TEST(GraphTest, Basic) { ...@@ -110,5 +126,83 @@ TEST(GraphTest, Basic) {
} }
ASSERT_EQ(nodes.size(), 5); ASSERT_EQ(nodes.size(), 5);
} }
TEST(GraphTest, WriteAfterRead) {
// void Test() {
ProgramDesc prog;
auto *op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"a"});
op->SetOutput("Out", {"b"});
op->SetAttr("op_role", 1);
op = prog.MutableBlock(0)->AppendOp();
op->SetType("dummy");
op->SetInput("X", {"c"});
op->SetOutput("Out", {"a"});
op->SetAttr("op_role", 1);
prog.MutableBlock(0)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
ir::Node *control_dep1 = nullptr;
ir::Node *control_dep2 = nullptr;
for (ir::Node *n : g->Nodes()) {
if (n->Name() == "sum") {
ASSERT_EQ(n->outputs[0]->Name(), "b");
ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
control_dep1 = n->outputs[1];
ASSERT_EQ(n->outputs.size(), 2);
}
if (n->Name() == "dummy") {
ASSERT_EQ(n->inputs[0]->Name(), "c");
ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
control_dep2 = n->inputs[1];
ASSERT_EQ(n->inputs.size(), 2);
}
}
ASSERT_EQ(control_dep1, control_dep2);
}
TEST(GraphTest, WriteAfterWrite) {
// void Test() {
ProgramDesc prog;
auto *op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"a"});
op->SetOutput("Out", {"b"});
op->SetAttr("op_role", 1);
op = prog.MutableBlock(0)->AppendOp();
op->SetType("dummy");
op->SetInput("X", {"c"});
op->SetOutput("Out", {"b"});
op->SetAttr("op_role", 1);
prog.MutableBlock(0)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
ir::Node *control_dep1 = nullptr;
ir::Node *control_dep2 = nullptr;
for (ir::Node *n : g->Nodes()) {
if (n->Name() == "sum") {
ASSERT_EQ(n->outputs[0]->Name(), "b");
ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
ASSERT_EQ(n->outputs.size(), 2);
control_dep1 = n->outputs[1];
}
if (n->Name() == "dummy") {
ASSERT_EQ(n->inputs[0]->Name(), "c");
ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
control_dep2 = n->inputs[1];
ASSERT_EQ(n->inputs.size(), 2);
ASSERT_EQ(control_dep1, control_dep2);
}
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册