未验证 提交 9da9b192 编写于 作者: W Wu Yi 提交者: GitHub

[1.1] fix graph num hang (#14072)

* fix graph num hang test=develop

* re-enable tests test=develop

* re-enable graph num check test=develop

* fix multi device pass role check test=develop
上级 8c166b64
...@@ -120,19 +120,25 @@ size_t GraphNum(const Graph &graph) { ...@@ -120,19 +120,25 @@ size_t GraphNum(const Graph &graph) {
std::deque<ir::Node *> q_nodes; std::deque<ir::Node *> q_nodes;
std::vector<std::unordered_set<ir::Node *>> graph_nodes; std::vector<std::unordered_set<ir::Node *>> graph_nodes;
std::unordered_set<ir::Node *> g_nodes; std::unordered_set<ir::Node *> g_nodes;
// q_set used to record records in the queue.
std::unordered_set<ir::Node *> q_set;
size_t graph_count = 0; size_t graph_count = 0;
auto traverse_nodes = [&visited_nodes, auto traverse_nodes = [&visited_nodes, &q_nodes,
&q_nodes](const std::vector<ir::Node *> &nodes) { &q_set](const std::vector<ir::Node *> &nodes) {
std::copy_if( for (auto n : nodes) {
nodes.begin(), nodes.end(), std::back_inserter(q_nodes), if (visited_nodes.count(n) == 0 && q_set.count(n) == 0) {
[&visited_nodes](Node *node) { return !visited_nodes.count(node); }); q_nodes.push_back(n);
q_set.insert(n);
}
}
}; };
while (visited_nodes.size() != nodes.size()) { while (visited_nodes.size() != nodes.size()) {
if (!q_nodes.empty()) { if (!q_nodes.empty()) {
auto cur_node = q_nodes.front(); auto cur_node = q_nodes.front();
q_nodes.pop_front(); q_nodes.pop_front();
q_set.erase(cur_node);
visited_nodes.insert(cur_node); visited_nodes.insert(cur_node);
g_nodes.insert(cur_node); g_nodes.insert(cur_node);
traverse_nodes(cur_node->inputs); traverse_nodes(cur_node->inputs);
...@@ -146,6 +152,7 @@ size_t GraphNum(const Graph &graph) { ...@@ -146,6 +152,7 @@ size_t GraphNum(const Graph &graph) {
for (auto &n : nodes) { for (auto &n : nodes) {
if (visited_nodes.count(n) == 0) { if (visited_nodes.count(n) == 0) {
q_nodes.push_back(n); q_nodes.push_back(n);
q_set.insert(n);
break; break;
} }
} }
......
...@@ -28,12 +28,12 @@ enum class OpRole { ...@@ -28,12 +28,12 @@ enum class OpRole {
kBackward = 0x0001, kBackward = 0x0001,
kOptimize = 0x0002, kOptimize = 0x0002,
// RPC role is for send/recv releated op // RPC role is for send/recv releated op
kRPC = 0x0003, kRPC = 0x0004,
// Dist role is for split_byref/split_selected_rows/concat // Dist role is for split_byref/split_selected_rows/concat
// used for distributed training. // used for distributed training.
kDist = 0x0004, kDist = 0x0008,
// Tag all learning rate scheduler operators. // Tag all learning rate scheduler operators.
kLRSched = 0x0005, kLRSched = 0x0016,
kLoss = 0x0100, kLoss = 0x0100,
// The default value of op's role. This should be only used for unittests and // The default value of op's role. This should be only used for unittests and
......
...@@ -156,6 +156,12 @@ ParallelExecutor::ParallelExecutor( ...@@ -156,6 +156,12 @@ ParallelExecutor::ParallelExecutor(
params, member_->local_scopes_, member_->use_cuda_); params, member_->local_scopes_, member_->use_cuda_);
#endif #endif
// If the loss_var_name is given, the number of graph should be only one.
if (loss_var_name.size()) {
PADDLE_ENFORCE_EQ(ir::GraphNum(*graph), 1,
"The number of graph should be only one");
}
if (exec_strategy.type_ == ExecutionStrategy::kDefault) { if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
member_->executor_.reset(new details::ThreadedSSAGraphExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, std::move(graph))); exec_strategy, member_->local_scopes_, places, std::move(graph)));
......
...@@ -23,8 +23,7 @@ class TestDistCTR2x2(TestDistBase): ...@@ -23,8 +23,7 @@ class TestDistCTR2x2(TestDistBase):
self._sync_mode = True self._sync_mode = True
self._enforce_place = "CPU" self._enforce_place = "CPU"
def test_dist_ctr(self):
def test_dist_ctr(self):
self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False) self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False)
......
...@@ -40,8 +40,7 @@ class TestDistMnistAsync(TestDistBase): ...@@ -40,8 +40,7 @@ class TestDistMnistAsync(TestDistBase):
self._sync_mode = False self._sync_mode = False
self._use_reduce = False self._use_reduce = False
# FIXME(typhoonzero): fix async mode test later def test_dist_train(self):
def no_test_dist_train(self):
self.check_with_place("dist_mnist.py", delta=200) self.check_with_place("dist_mnist.py", delta=200)
......
...@@ -40,8 +40,7 @@ class TestDistSeResneXt2x2Async(TestDistBase): ...@@ -40,8 +40,7 @@ class TestDistSeResneXt2x2Async(TestDistBase):
self._sync_mode = False self._sync_mode = False
self._use_reader_alloc = False self._use_reader_alloc = False
#FIXME(typhoonzero): fix async mode later def test_dist_train(self):
def no_test_dist_train(self):
self.check_with_place("dist_se_resnext.py", delta=100) self.check_with_place("dist_se_resnext.py", delta=100)
......
...@@ -79,8 +79,7 @@ class TestDistSimnetBow2x2SparseAsync(TestDistBase): ...@@ -79,8 +79,7 @@ class TestDistSimnetBow2x2SparseAsync(TestDistBase):
self._sync_mode = False self._sync_mode = False
self._enforce_place = "CPU" self._enforce_place = "CPU"
#FIXME(typhoonzero): fix async tests later def test_simnet_bow(self):
def no_test_simnet_bow(self):
need_envs = { need_envs = {
"IS_DISTRIBUTED": '0', "IS_DISTRIBUTED": '0',
"IS_SPARSE": '1', "IS_SPARSE": '1',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册