未验证 提交 7ae73e33 编写于 作者: Z Zhaolong Xing 提交者: GitHub

Merge pull request #12432 from Superjomn/fea/analysis-ssa

inference analysis supports SSA
...@@ -36,6 +36,8 @@ namespace analysis { ...@@ -36,6 +36,8 @@ namespace analysis {
/* /*
* DataFlowGraph - A container of Value and Function Nodes. * DataFlowGraph - A container of Value and Function Nodes.
*
* This is the base graph for any other type of graphs, such as SSA or CFG.
*/ */
struct DataFlowGraph { struct DataFlowGraph {
NodeMap nodes; NodeMap nodes;
......
...@@ -40,6 +40,8 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) { ...@@ -40,6 +40,8 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE(graph);
PADDLE_ENFORCE(desc_); PADDLE_ENFORCE(desc_);
// insert vars // insert vars
// The `var2id` keeps a map from a variable's name to its Node-id, the Node-id
// will keep updating to its latest alias during the graph-building.
std::unordered_map<std::string, size_t> var2id; std::unordered_map<std::string, size_t> var2id;
auto &main_block = desc_->blocks(framework::kRootBlockIndex); auto &main_block = desc_->blocks(framework::kRootBlockIndex);
for (int i = 0; i < main_block.vars_size(); i++) { for (int i = 0; i < main_block.vars_size(); i++) {
...@@ -51,6 +53,15 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) { ...@@ -51,6 +53,15 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
var2id[var.name()] = v->id(); var2id[var.name()] = v->id();
} }
// The variables in a SSA can only write once, so if a variable is written
// multiple times(quite common in our ProgramDesc design), multiple alias
// Nodes of this variable will be created, and each will just write once.
// An set that keep all the names of the variables(the original, not alias)
// that have been written(as outputs). Once an Op's output variable hit the
// set, it should create a new alias and update the global alias for this
// variable. And that make a Data Flow Graph a SSA.
std::unordered_set<Node *> unique_written_vars;
for (int i = 0; i < main_block.ops_size(); i++) { for (int i = 0; i < main_block.ops_size(); i++) {
const auto &op = main_block.ops(i); const auto &op = main_block.ops(i);
auto *o = graph->nodes.Create(Node::Type::kFunction); auto *o = graph->nodes.Create(Node::Type::kFunction);
...@@ -62,33 +73,33 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) { ...@@ -62,33 +73,33 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
o->SetPbMsg(op.SerializeAsString()); o->SetPbMsg(op.SerializeAsString());
// set inputs and outputs // set inputs and outputs
std::unordered_set<Node *> inlinks;
for (int j = 0; j < op.inputs_size(); j++) { for (int j = 0; j < op.inputs_size(); j++) {
auto &in_var = op.inputs(j); auto &in_var = op.inputs(j);
for (int k = 0; k < in_var.arguments_size(); k++) { for (int k = 0; k < in_var.arguments_size(); k++) {
auto *in = graph->nodes.GetMutable(var2id.at(in_var.arguments(k))); auto *in = graph->nodes.GetMutable(var2id.at(in_var.arguments(k)));
in->outlinks.push_back(o); in->outlinks.push_back(o);
o->inlinks.push_back(in); o->inlinks.push_back(in);
inlinks.insert(in);
} }
} }
for (int j = 0; j < op.outputs_size(); j++) { for (int j = 0; j < op.outputs_size(); j++) {
auto &out_var = op.outputs(j); auto &out_var = op.outputs(j);
for (int k = 0; k < out_var.arguments_size(); k++) { for (int k = 0; k < out_var.arguments_size(); k++) {
auto *out = graph->nodes.GetMutable(var2id[out_var.arguments(k)]); auto *out = graph->nodes.GetMutable(var2id[out_var.arguments(k)]);
if (inlinks.count(out)) { if (unique_written_vars.count(out)) {
// Loop found, for example, a = op(a), use SSA, change to a1 = op(a). // Loop found, for example, a = op(a), use SSA, change to a1 = op(a).
auto *out_alias = graph->nodes.Create(Node::Type::kValue); auto *out_alias = graph->nodes.Create(Node::Type::kValue);
out_alias->SetName(out->name()); out_alias->SetName(out->name());
out_alias->SetPbDesc(out->pb_desc()); out_alias->SetPbDesc(out->pb_desc());
out_alias->SetPbMsg(out->pb_msg()); out_alias->SetPbMsg(out->pb_msg());
var2id[out_alias->name()] = out_alias->id(); // update a -> a0 var2id[out_alias->name()] =
out_alias->id(); // update variable's alias Node
LOG(INFO) << "loop found in graph, create SSA alias node [" LOG(INFO) << "loop found in graph, create SSA alias node ["
<< out_alias->repr() << "] for [" << out->repr() << "]"; << out_alias->repr() << "] for [" << out->repr() << "]";
out = out_alias; out = out_alias;
} }
out->inlinks.push_back(o); out->inlinks.push_back(o);
o->outlinks.push_back(out); o->outlinks.push_back(out);
unique_written_vars.insert(out);
} }
} }
} }
......
...@@ -30,7 +30,7 @@ namespace inference { ...@@ -30,7 +30,7 @@ namespace inference {
namespace analysis { namespace analysis {
/* /*
* Transform a FluidDesc to a data flow graph. * Transform a FluidDesc to a SSA.
*/ */
class FluidToDataFlowGraphPass final : public DataFlowGraphPass { class FluidToDataFlowGraphPass final : public DataFlowGraphPass {
public: public:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册