Unverified commit e931c7ba, authored by WangXi, committed by GitHub

Fix multi nccl comm & wait server ready (#28663)

Parent e7caf3b8
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import paddle
 from paddle.fluid.framework import core
 from paddle.fluid import compiler
@@ -51,13 +52,21 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
     # should fix the variable
     def _setup_nccl_op(self, startup_program, main_program, build_strategy):
         trainer_endpoints = self.role_maker._get_trainer_endpoints()
-        trainers = trainer_endpoints
+        other_trainers = copy.copy(trainer_endpoints)
+
         trainer_id = self.role_maker._worker_index()
         current_endpoint = self.role_maker._get_trainer_endpoints()[trainer_id]
+        other_trainers.remove(current_endpoint)
+
         trainer_endpoints_env = ",".join(trainer_endpoints)
         trainers_num = self.role_maker._worker_num()
+
+        if trainer_id == 0:
+            wait_server_ready(other_trainers)
+
         nccl_id_var = startup_program.global_block().create_var(
             name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
         for i in range(1, build_strategy.nccl_comm_num):
             startup_program.global_block().create_var(
                 name="NCCLID_{}".format(i),
                 persistable=True,
                 type=core.VarDesc.VarType.RAW)
@@ -90,7 +99,6 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
             })

     def _try_to_compile(self, startup_program, main_program, loss):
-        import copy
         dist_strategy = self.user_defined_strategy
         local_build_strategy = paddle.fluid.BuildStrategy()
         local_build_strategy.enable_sequential_execution = \
@@ -148,13 +156,12 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
         sync_allreduce = dist_strategy.sync_nccl_allreduce
         if sync_allreduce:
-            exe_strategy.num_threads = local_build_strategy.nccl_comm_num + 1
-            if local_build_strategy.use_hierarchical_allreduce:
-                exe_strategy.num_threads = 2 * local_build_strategy.nccl_comm_num + 1
-            if exe_strategy.num_threads > 4:
+            exe_strategy.num_threads = max(
+                local_build_strategy.nccl_comm_num + 1,
+                exe_strategy.num_threads)
+            if local_build_strategy.nccl_comm_num > 1:
                 logging.warn(
-                    "if you use hierachical_allreduce or "
-                    "with multi nccl comm, please set distributed_strategy.sync_nccl_allreduce=False"
+                    "nccl_comm_num > 1, you may need to set sync_nccl_allreduce=False to ensure that different nccl comms can overlap"
                 )

         sync_batch_norm = local_build_strategy.sync_batch_norm
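The old code overwrote num_threads unconditionally; the new rule never lowers a thread count the executor already carries, it only guarantees at least one thread per nccl comm plus one. A worked example with assumed values:

    nccl_comm_num = 3
    # user already asked for 8 executor threads: keep them
    assert max(nccl_comm_num + 1, 8) == 8
    # executor default of 1 thread: raise to nccl_comm_num + 1
    assert max(nccl_comm_num + 1, 1) == 4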
@@ -167,6 +174,11 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
                 "set num_threads=1, nccl_comm_num=1, hierachical_allreduce=False."
             )

+        # NOTE. compatible with compiler, otherwise these values will be overwritten by compiler
+        main_program._nccl_comm_num = local_build_strategy.nccl_comm_num
+        main_program._use_hierarchical_allreduce = local_build_strategy.use_hierarchical_allreduce
+        main_program._hierarchical_allreduce_inter_nranks = local_build_strategy.hierarchical_allreduce_inter_nranks
+
         # TODO(guru4elephant): should be an independent optimizer
         self._setup_nccl_op(startup_program, main_program, local_build_strategy)
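For context, the knobs this hunk mirrors onto main_program are exposed to users through fleet's DistributedStrategy. A minimal usage sketch with assumed values (and, following the warning above, with synchronous allreduce disabled when several comms are used):

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.nccl_comm_num = 2            # multiple nccl communicators
    strategy.sync_nccl_allreduce = False  # let the comms overlap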
@@ -75,6 +75,9 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase):
                 optimizer, strategy=strategy)
             optimizer.minimize(avg_cost)

+            exe = paddle.fluid.Executor(place=paddle.fluid.CPUPlace())
+            exe.run(paddle.fluid.default_startup_program())
+
         proc_a = launch_func(node_func, node_a)
         proc_a.start()
         proc_b = launch_func(node_func, node_b)
@@ -197,6 +200,9 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase):
                 optimizer, strategy=strategy)
             optimizer.minimize(avg_cost)

+            exe = paddle.fluid.Executor(place=paddle.fluid.CPUPlace())
+            exe.run(paddle.fluid.default_startup_program())
+
         proc_a = launch_func(node_func, node_a)
         proc_a.start()
         proc_b = launch_func(node_func, node_b)
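In both tests the startup program now runs inside each trainer process, so the new wait-server-ready path is actually exercised. launch_func (defined elsewhere in this test file) presumably wraps node_func in its own process with that node's environment applied; a rough sketch of the pattern, assuming a fork-based multiprocessing start:

    import os
    from multiprocessing import Process

    def launch_func(func, env_dict):
        # Sketch only: run `func` in a child process after exporting this
        # node's trainer endpoints / role variables into its environment.
        def wrapper(func, env_dict):
            os.environ.update(env_dict)
            func()
        return Process(target=wrapper, args=(func, env_dict))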