Unverified · Commit e931c7ba authored by WangXi, committed by GitHub

Fix multi nccl comm & wait server ready (#28663)

Parent commit e7caf3b8
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
+import copy
 import paddle
 from paddle.fluid.framework import core
 from paddle.fluid import compiler
@@ -51,13 +52,21 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
     # should fix the variable
     def _setup_nccl_op(self, startup_program, main_program, build_strategy):
         trainer_endpoints = self.role_maker._get_trainer_endpoints()
-        trainers = trainer_endpoints
+        other_trainers = copy.copy(trainer_endpoints)
         trainer_id = self.role_maker._worker_index()
         current_endpoint = self.role_maker._get_trainer_endpoints()[trainer_id]
+        other_trainers.remove(current_endpoint)
         trainer_endpoints_env = ",".join(trainer_endpoints)
         trainers_num = self.role_maker._worker_num()
+        if trainer_id == 0:
+            wait_server_ready(other_trainers)
         nccl_id_var = startup_program.global_block().create_var(
             name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
         for i in range(1, build_strategy.nccl_comm_num):
             startup_program.global_block().create_var(
                 name="NCCLID_{}".format(i),
@@ -90,7 +99,6 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
             })

     def _try_to_compile(self, startup_program, main_program, loss):
-        import copy
         dist_strategy = self.user_defined_strategy
         local_build_strategy = paddle.fluid.BuildStrategy()
         local_build_strategy.enable_sequential_execution = \
@@ -148,13 +156,12 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
         sync_allreduce = dist_strategy.sync_nccl_allreduce
         if sync_allreduce:
-            exe_strategy.num_threads = local_build_strategy.nccl_comm_num + 1
-            if local_build_strategy.use_hierarchical_allreduce:
-                exe_strategy.num_threads = 2 * local_build_strategy.nccl_comm_num + 1
-            if exe_strategy.num_threads > 4:
+            exe_strategy.num_threads = max(
+                local_build_strategy.nccl_comm_num + 1,
+                exe_strategy.num_threads)
+            if local_build_strategy.nccl_comm_num > 1:
                 logging.warn(
-                    "if you use hierachical_allreduce or "
-                    "with multi nccl comm, please set distributed_strategy.sync_nccl_allreduce=False"
+                    "nccl_comm_num > 1, you may need to set sync_nccl_allreduce=False to ensure that different nccl comms can overlap"
                 )

         sync_batch_norm = local_build_strategy.sync_batch_norm
@@ -167,6 +174,11 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
                 "set num_threads=1, nccl_comm_num=1, hierachical_allreduce=False."
             )

+        # NOTE. compatible with compiler, otherwise these values will be overwritten by compiler
+        main_program._nccl_comm_num = local_build_strategy.nccl_comm_num
+        main_program._use_hierarchical_allreduce = local_build_strategy.use_hierarchical_allreduce
+        main_program._hierarchical_allreduce_inter_nranks = local_build_strategy.hierarchical_allreduce_inter_nranks
+
         # TODO(guru4elephant): should be an independent optimizer
         self._setup_nccl_op(startup_program, main_program, local_build_strategy)
......
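The reworked _setup_nccl_op above makes trainer 0 block until every other trainer endpoint is reachable before the NCCLID variables are created and broadcast, so the NCCL ID is never sent to a peer whose listening server has not started yet; the same file also raises exe_strategy.num_threads with max(...) instead of overwriting it, and warns on nccl_comm_num > 1 rather than on a thread-count threshold. Below is a minimal, self-contained sketch of the wait-then-broadcast ordering; the socket-polling wait_server_ready here is a hypothetical stand-in (the patch uses Paddle's own helper, whose import is outside this hunk), and the endpoint list is invented for illustration.

    import socket
    import time

    def wait_server_ready(endpoints):
        # Hypothetical stand-in: poll each "ip:port" until a TCP connect succeeds.
        pending = list(endpoints)
        while pending:
            still_waiting = []
            for ep in pending:
                ip, port = ep.split(":")
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                    sock.settimeout(2)
                    if sock.connect_ex((ip, int(port))) != 0:
                        still_waiting.append(ep)
            if still_waiting:
                time.sleep(3)  # peers are not up yet; retry
            pending = still_waiting

    # Illustrative values; in the patch these come from the role maker.
    trainer_endpoints = ["127.0.0.1:6170", "127.0.0.1:6171"]
    trainer_id = 0
    other_trainers = [ep for i, ep in enumerate(trainer_endpoints) if i != trainer_id]

    if trainer_id == 0:
        # Only the rank that generates the NCCL ID blocks here; the other ranks
        # go on to start their listening servers so trainer 0 can reach them.
        wait_server_ready(other_trainers)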
@@ -75,6 +75,9 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase):
                 optimizer, strategy=strategy)
             optimizer.minimize(avg_cost)

+            exe = paddle.fluid.Executor(place=paddle.fluid.CPUPlace())
+            exe.run(paddle.fluid.default_startup_program())
+
         proc_a = launch_func(node_func, node_a)
         proc_a.start()
         proc_b = launch_func(node_func, node_b)
@@ -197,6 +200,9 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase):
                 optimizer, strategy=strategy)
             optimizer.minimize(avg_cost)

+            exe = paddle.fluid.Executor(place=paddle.fluid.CPUPlace())
+            exe.run(paddle.fluid.default_startup_program())
+
         proc_a = launch_func(node_func, node_a)
         proc_a.start()
         proc_b = launch_func(node_func, node_b)
......
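The test-side change simply initializes parameters with a CPU executor before the two trainer processes are launched. A minimal sketch of that pattern follows; the tiny network is made up for illustration and only stands in for the test's real model.

    import paddle.fluid as fluid

    # Hypothetical toy program standing in for the test's network.
    startup_prog = fluid.Program()
    main_prog = fluid.Program()
    with fluid.program_guard(main_prog, startup_prog):
        x = fluid.data(name="x", shape=[None, 1], dtype="float32")
        y = fluid.layers.fc(input=x, size=1)

    # Initialize parameters on CPU before any distributed processes start,
    # mirroring the two lines added to the test above.
    exe = fluid.Executor(place=fluid.CPUPlace())
    exe.run(startup_prog)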