未验证 提交 9da9b192 编写于 作者: W Wu Yi 提交者: GitHub

[1.1] fix graph num hang (#14072)

* fix graph num hang test=develop

* re-enable tests test=develop

* re-enable graph num check test=develop

* fix multi device pass role check test=develop
上级 8c166b64
......@@ -120,19 +120,25 @@ size_t GraphNum(const Graph &graph) {
std::deque<ir::Node *> q_nodes;
std::vector<std::unordered_set<ir::Node *>> graph_nodes;
std::unordered_set<ir::Node *> g_nodes;
// q_set used to record records in the queue.
std::unordered_set<ir::Node *> q_set;
size_t graph_count = 0;
auto traverse_nodes = [&visited_nodes,
&q_nodes](const std::vector<ir::Node *> &nodes) {
std::copy_if(
nodes.begin(), nodes.end(), std::back_inserter(q_nodes),
[&visited_nodes](Node *node) { return !visited_nodes.count(node); });
auto traverse_nodes = [&visited_nodes, &q_nodes,
&q_set](const std::vector<ir::Node *> &nodes) {
for (auto n : nodes) {
if (visited_nodes.count(n) == 0 && q_set.count(n) == 0) {
q_nodes.push_back(n);
q_set.insert(n);
}
}
};
while (visited_nodes.size() != nodes.size()) {
if (!q_nodes.empty()) {
auto cur_node = q_nodes.front();
q_nodes.pop_front();
q_set.erase(cur_node);
visited_nodes.insert(cur_node);
g_nodes.insert(cur_node);
traverse_nodes(cur_node->inputs);
......@@ -146,6 +152,7 @@ size_t GraphNum(const Graph &graph) {
for (auto &n : nodes) {
if (visited_nodes.count(n) == 0) {
q_nodes.push_back(n);
q_set.insert(n);
break;
}
}
......
......@@ -28,12 +28,12 @@ enum class OpRole {
kBackward = 0x0001,
kOptimize = 0x0002,
// RPC role is for send/recv releated op
kRPC = 0x0003,
kRPC = 0x0004,
// Dist role is for split_byref/split_selected_rows/concat
// used for distributed training.
kDist = 0x0004,
kDist = 0x0008,
// Tag all learning rate scheduler operators.
kLRSched = 0x0005,
kLRSched = 0x0016,
kLoss = 0x0100,
// The default value of op's role. This should be only used for unittests and
......
......@@ -156,6 +156,12 @@ ParallelExecutor::ParallelExecutor(
params, member_->local_scopes_, member_->use_cuda_);
#endif
// If the loss_var_name is given, the number of graph should be only one.
if (loss_var_name.size()) {
PADDLE_ENFORCE_EQ(ir::GraphNum(*graph), 1,
"The number of graph should be only one");
}
if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, std::move(graph)));
......
......@@ -23,9 +23,8 @@ class TestDistCTR2x2(TestDistBase):
self._sync_mode = True
self._enforce_place = "CPU"
def test_dist_ctr(self):
self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False)
def test_dist_ctr(self):
self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False)
if __name__ == "__main__":
......
......@@ -40,8 +40,7 @@ class TestDistMnistAsync(TestDistBase):
self._sync_mode = False
self._use_reduce = False
# FIXME(typhoonzero): fix async mode test later
def no_test_dist_train(self):
def test_dist_train(self):
self.check_with_place("dist_mnist.py", delta=200)
......
......@@ -40,8 +40,7 @@ class TestDistSeResneXt2x2Async(TestDistBase):
self._sync_mode = False
self._use_reader_alloc = False
#FIXME(typhoonzero): fix async mode later
def no_test_dist_train(self):
def test_dist_train(self):
self.check_with_place("dist_se_resnext.py", delta=100)
......
......@@ -79,8 +79,7 @@ class TestDistSimnetBow2x2SparseAsync(TestDistBase):
self._sync_mode = False
self._enforce_place = "CPU"
#FIXME(typhoonzero): fix async tests later
def no_test_simnet_bow(self):
def test_simnet_bow(self):
need_envs = {
"IS_DISTRIBUTED": '0',
"IS_SPARSE": '1',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册