Commit cadc6a97 authored by WangXi, committed by gongweibao

fix dgc test and bug when not set trainers_endpoints_, test=develop (#20617)

Parent: 46797f53
```diff
@@ -465,8 +465,7 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
           new details::SparseAllReduceOpHandle(
               result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
               scopes, places, multi_nccl_ctxs_, is_encoded,
-              static_cast<int>(strategy_.trainers_endpoints_.size()) *
-                  places_.size()));
+              strategy_.num_trainers_ * places_.size()));
     } else {
       result->Get<GraphOps>(kGraphOps).emplace_back(
           new details::AllReduceOpHandle(
...
```
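This is the core fix named in the commit title: the rank count passed to `SparseAllReduceOpHandle` was derived from `strategy_.trainers_endpoints_.size()`, which is 0 whenever `trainers_endpoints_` is never set, so DGC saw a zero world size; deriving it from `num_trainers_` avoids that. A minimal Python sketch of the failure mode (all names here are illustrative stand-ins, not Paddle's API):

```python
# Illustrative sketch of the nranks bug (stand-in names, not Paddle's API).
def nranks_old(trainers_endpoints, num_places):
    # Old behavior: depends on trainers_endpoints being populated.
    # If the caller never sets it, len(...) == 0 and nranks collapses to 0.
    return len(trainers_endpoints) * num_places

def nranks_new(num_trainers, num_places):
    # Fixed behavior: num_trainers is taken from the build strategy.
    return num_trainers * num_places

print(nranks_old([], 8))   # 0  -> broken DGC all-reduce
print(nranks_new(2, 8))    # 16 -> correct world size
```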
```diff
@@ -271,7 +271,6 @@ class CollectiveOptimizer(DistributedOptimizer):
         node_num = self._node_num()
         assert node_num >= 1, "nccl2 node_num must >= 1, now:{}" % node_num
-        self._strategy.fuse_all_reduce_ops = True
         exec_strategy = self._strategy.exec_strategy
         if node_num <= 1:
...
```
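The deleted line forced `fuse_all_reduce_ops = True` for every collective job, but fused all-reduce packs dense gradient tensors into one buffer and cannot handle the sparse, encoded gradients DGC produces; removing it leaves the decision to the caller's strategy. A hedged sketch of caller-side configuration with the fluid-era `BuildStrategy` (the `use_dgc` flag is a stand-in for the caller's own config):

```python
# Sketch: leave gradient fusion to the caller instead of forcing it on.
import paddle.fluid as fluid

use_dgc = True  # stand-in: whether this job uses DGC

build_strategy = fluid.BuildStrategy()
# Fusion concatenates dense gradients into one buffer before all-reduce;
# DGC's encoded gradients are sparse, so fusion must stay off with DGC.
build_strategy.fuse_all_reduce_ops = not use_dgc
```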
```diff
@@ -291,6 +291,10 @@ class TestDistRunnerBase(object):
             build_stra.num_trainers = 1
             build_stra.trainer_id = 0
 
+        if args.use_dgc:
+            # fuse_all_reduce_ops requires that gradients not be sparse types
+            build_stra.fuse_all_reduce_ops = False
+
         print_to_err(type(self).__name__, "begin to compile with data parallel")
         binary = compiler.CompiledProgram(trainer_prog).with_data_parallel(
             loss_name=avg_cost.name,
```
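The `use_dgc` path is what makes the guard above necessary: DGC swaps in an optimizer that communicates compressed, sparse gradients. For context, a hedged sketch of enabling DGC with the fluid 1.x-era API (hyperparameter values are illustrative only):

```python
# Sketch: enabling DGC in fluid 1.x; all values are illustrative.
import paddle.fluid as fluid

optimizer = fluid.optimizer.DGCMomentumOptimizer(
    learning_rate=0.001,
    momentum=0.9,
    rampup_begin_step=0)  # begin gradient compression at step 0
```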
```diff
@@ -852,7 +856,9 @@ class TestDistBase(unittest.TestCase):
         if check_error_log:
             required_envs["GLOG_vmodule"] = \
-                "fused_all_reduce_op_handle=10,all_reduce_op_handle=10,alloc_continuous_space_op=10,fuse_all_reduce_op_pass=10,alloc_continuous_space_for_grad_pass=10,fast_threaded_ssa_graph_executor=10,executor=10,operator=10"
+                "fused_all_reduce_op_handle=10,all_reduce_op_handle=10,alloc_continuous_space_op=10,fuse_all_reduce_op_pass=10," \
+                "alloc_continuous_space_for_grad_pass=10,fast_threaded_ssa_graph_executor=10,executor=10,operator=10," \
+                "sparse_all_reduce_op_handle=10"
             required_envs["GLOG_logtostderr"] = "1"
 
         local_losses \
...
```
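The widened `GLOG_vmodule` list adds `sparse_all_reduce_op_handle=10` so DGC's sparse all-reduce handle logs at the same verbosity as the other handles when a test fails. The same per-module glog settings can be exported when reproducing a failure by hand; a minimal sketch (glog reads these variables at load time, so they must be set before the Paddle core library is imported):

```python
# Sketch: per-module glog verbosity for debugging a DGC run.
import os

os.environ["GLOG_vmodule"] = (
    "fused_all_reduce_op_handle=10,all_reduce_op_handle=10,"
    "alloc_continuous_space_op=10,fuse_all_reduce_op_pass=10,"
    "alloc_continuous_space_for_grad_pass=10,"
    "fast_threaded_ssa_graph_executor=10,executor=10,operator=10,"
    "sparse_all_reduce_op_handle=10")
os.environ["GLOG_logtostderr"] = "1"  # route glog output to stderr

import paddle.fluid  # noqa: E402  -- import only after the env vars are set
```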