Unverified commit 5175b3cb, authored by chengduo, committed by GitHub

Add GraphChecker (#13580)

* add GraphNum

test=develop

* add graph number check in ParallelExecutor

test=develop

* fix transformer_model bug

test=develop

* fix graph num
Parent 7cd27617
@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/graph_helper.h"
#include <algorithm>
#include <deque>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace ir {
@@ -113,6 +113,74 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
return adj_list;
}
// Count the connected components (independent sub-graphs) of `graph` by
// breadth-first search over node inputs and outputs.
size_t GraphNum(const Graph &graph) {
std::unordered_set<ir::Node *> nodes = graph.Nodes();
std::unordered_set<ir::Node *> visited_nodes;
visited_nodes.reserve(nodes.size());
std::deque<ir::Node *> q_nodes;
std::vector<std::unordered_set<ir::Node *>> graph_nodes;
std::unordered_set<ir::Node *> g_nodes;
size_t graph_count = 0;
auto traverse_nodes = [&visited_nodes,
&q_nodes](const std::vector<ir::Node *> &nodes) {
std::copy_if(
nodes.begin(), nodes.end(), std::back_inserter(q_nodes),
[&visited_nodes](Node *node) { return !visited_nodes.count(node); });
};
while (visited_nodes.size() != nodes.size()) {
if (!q_nodes.empty()) {
auto cur_node = q_nodes.front();
q_nodes.pop_front();
visited_nodes.insert(cur_node);
g_nodes.insert(cur_node);
traverse_nodes(cur_node->inputs);
traverse_nodes(cur_node->outputs);
} else {
++graph_count;
if (g_nodes.size()) {
graph_nodes.emplace_back(g_nodes);
}
g_nodes.clear();
for (auto &n : nodes) {
if (visited_nodes.count(n) == 0) {
q_nodes.push_back(n);
break;
}
}
}
}
if (g_nodes.size()) {
graph_nodes.emplace_back(g_nodes);
}
if (VLOG_IS_ON(10)) {
VLOG(10) << "graph_num: " << graph_nodes.size();
for (auto &g_n : graph_nodes) {
VLOG(10) << "graph_nodes: " << g_n.size();
if (g_n.size() < 10) {
std::stringstream out;
for (auto &node : g_n) {
out << "\nNode: " << node->Name() << " in [";
for (auto &n : node->inputs) {
out << n->Name() << ", ";
}
out << "], out[";
for (auto &n : node->outputs) {
out << n->Name() << ", ";
}
out << "]";
}
VLOG(10) << out.str();
}
}
}
return graph_count;
}
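GraphNum above is, in effect, a connected-component counter: it repeatedly seeds a BFS from a not-yet-visited node and walks both inputs and outputs. Below is a minimal standalone sketch of the same idea, assuming an undirected adjacency list over integer node ids; CountComponents and adj are illustrative names, not Paddle API.

#include <deque>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Count connected components by BFS. Assumes every node appears as a key in
// `adj` and every edge is stored in both directions.
size_t CountComponents(const std::unordered_map<int, std::vector<int>> &adj) {
  std::unordered_set<int> visited;
  size_t count = 0;
  for (const auto &kv : adj) {
    if (visited.count(kv.first)) continue;
    ++count;  // an unvisited node starts a new component
    std::deque<int> queue{kv.first};
    while (!queue.empty()) {
      int cur = queue.front();
      queue.pop_front();
      if (!visited.insert(cur).second) continue;  // already reached another way
      for (int next : adj.at(cur)) {
        if (!visited.count(next)) queue.push_back(next);
      }
    }
  }
  return count;
}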
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -27,6 +27,8 @@ namespace ir {
// Test if the graph contains circle.
bool HasCircle(const Graph &graph);
// Return the number of connected components (sub-graphs) in the graph.
size_t GraphNum(const Graph &graph);
// Topology Sort the operations in the graph from inputs to outputs.
// `graph` cannot contain circle.
std::vector<ir::Node *> TopologySortOperations(const Graph &graph);
......
@@ -120,6 +120,97 @@ TEST(GraphHelperTest, Basic) {
ASSERT_EQ(node_map.at("op2"), 1UL);
ASSERT_TRUE(node_map.at("op3") < node_map.at("op5"));
}
void BuildZeroGraph(Graph* g) {}
void BuildOneGraph(Graph* g) {
ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation);
ir::Node* o3 = g->CreateEmptyNode("op3", Node::Type::kOperation);
ir::Node* o4 = g->CreateEmptyNode("op4", Node::Type::kOperation);
ir::Node* o5 = g->CreateEmptyNode("op5", Node::Type::kOperation);
ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable);
ir::Node* v3 = g->CreateEmptyNode("var3", Node::Type::kVariable);
ir::Node* v4 = g->CreateEmptyNode("var4", Node::Type::kVariable);
// o1->v1->o2
o1->outputs.push_back(v1);
o2->inputs.push_back(v1);
v1->inputs.push_back(o1);
v1->outputs.push_back(o2);
// o2->v2->o3
// o2->v2->o4
o2->outputs.push_back(v2);
o3->inputs.push_back(v2);
o4->inputs.push_back(v2);
v2->inputs.push_back(o2);
v2->outputs.push_back(o3);
v2->outputs.push_back(o4);
// o2->v3->o5
o2->outputs.push_back(v3);
o5->inputs.push_back(v3);
v3->inputs.push_back(o2);
v3->outputs.push_back(o5);
// o3->v4->o5
o3->outputs.push_back(v4);
o5->inputs.push_back(v4);
v4->inputs.push_back(o3);
v4->outputs.push_back(o5);
}
void BuildTwoGraphs(Graph* g) {
ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation);
ir::Node* o3 = g->CreateEmptyNode("op3", Node::Type::kOperation);
ir::Node* o4 = g->CreateEmptyNode("op4", Node::Type::kOperation);
ir::Node* o5 = g->CreateEmptyNode("op5", Node::Type::kOperation);
ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable);
ir::Node* v3 = g->CreateEmptyNode("var3", Node::Type::kVariable);
ir::Node* v4 = g->CreateEmptyNode("var4", Node::Type::kVariable);
// o1->v1->o2
o1->outputs.push_back(v1);
o2->inputs.push_back(v1);
v1->inputs.push_back(o1);
v1->outputs.push_back(o2);
// o2->v2->o3
// o2->v2->o4
o2->outputs.push_back(v2);
o3->inputs.push_back(v2);
o4->inputs.push_back(v2);
v2->inputs.push_back(o2);
v2->outputs.push_back(o3);
v2->outputs.push_back(o4);
// o2->v3->o5
// o2->outputs.push_back(v3);
o5->inputs.push_back(v3);
// v3->inputs.push_back(o2);
v3->outputs.push_back(o5);
// o3->v4->o5
o3->outputs.push_back(v4);
// o5->inputs.push_back(v4);
v4->inputs.push_back(o3);
// v4->outputs.push_back(o5);
}
TEST(GraphHelperTest, GraphNum) {
ProgramDesc prog;
Graph g(prog);
BuildZeroGraph(&g);
ASSERT_EQ(GraphNum(g), 0);
Graph g2(prog);
BuildOneGraph(&g2);
ASSERT_EQ(GraphNum(g2), 1);
Graph g3(prog);
BuildTwoGraphs(&g3);
ASSERT_EQ(GraphNum(g3), 2);
}
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/parallel_executor.h"
#include <string>
#include <tuple>
#include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
@@ -156,6 +156,12 @@ ParallelExecutor::ParallelExecutor(
params, member_->local_scopes_, member_->use_cuda_);
#endif
// If the loss_var_name is given, the number of graphs should be exactly one.
if (loss_var_name.size()) {
PADDLE_ENFORCE_EQ(ir::GraphNum(*graph), 1,
"The number of graph should be only one");
}
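The new check expects the whole program to form a single connected graph whenever a loss variable is named; otherwise construction fails with the message above. A rough standalone sketch of the same guard, using a plain exception instead of Paddle's PADDLE_ENFORCE machinery (CheckSingleGraph is an illustrative name):

#include <stdexcept>
#include <string>

// Reject programs that split into several disconnected graphs when a loss
// variable is specified.
void CheckSingleGraph(const std::string &loss_var_name, size_t graph_num) {
  if (!loss_var_name.empty() && graph_num != 1) {
    throw std::invalid_argument("The number of graph should be only one");
  }
}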
if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, std::move(graph)));
......
@@ -246,6 +246,7 @@ def prepare_encoder(src_word,
padding_idx=pos_pad_idx,
param_attr=fluid.ParamAttr(
name=pos_enc_param_name, trainable=False))
# The positional encoding table is created with trainable=False, so mark its
# lookup result as stop_gradient to keep it out of gradient computation.
src_pos_enc.stop_gradient = True
enc_input = src_word_emb + src_pos_enc
# FIXME(guosheng): Decouple the program desc with batch_size.
......