ascend_communicate (#31708)

Parent faf40da5
@@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import print_function
import os
import paddle.fluid as fluid
from paddle.fluid import core, unique_name
@@ -70,10 +71,10 @@ class CollectiveHelper(object):
nranks = len(endpoints)
other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint)
block = program.global_block()
if core.is_compiled_with_cuda():
if rank == 0 and wait_port:
wait_server_ready(other_endpoints)
block = program.global_block()
nccl_id_var = block.create_var(
name=unique_name.generate('nccl_id'),
persistable=True,
@@ -98,6 +99,22 @@ class CollectiveHelper(object):
'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Forward
})
elif core.is_compiled_with_npu():
endpoint_to_index_map = {
e: idx for idx, e in enumerate(endpoints)
}
block.append_op(
type='c_comm_init_hcom',
inputs={},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
'device_id': int(os.getenv("FLAGS_selected_npus")),
'rank_ids': [endpoint_to_index_map[e] for e in endpoints],
OP_ROLE_KEY: OpRole.Forward
})
def _wait(self, current_endpoint, endpoints):
assert (self.wait_port)
......
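For reference, the NPU branch added above can be read as a standalone sketch. Everything in the snippet below is illustrative only: the endpoint list and ring id are made up, the default value for FLAGS_selected_npus exists only so the code runs outside a launcher, and OP_ROLE_KEY is omitted as in the hapi variant further down.

import os
import paddle.fluid as fluid
from paddle.fluid import core

# Illustrative endpoints; in practice they come from the fleet launcher.
endpoints = ["127.0.0.1:6170", "127.0.0.1:6171"]
current_endpoint = endpoints[0]
rank = endpoints.index(current_endpoint)
nranks = len(endpoints)

startup_program = fluid.Program()
block = startup_program.global_block()

if core.is_compiled_with_npu():
    # rank_ids is each endpoint's position in the list, i.e. 0..nranks-1
    # in endpoint order.
    endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)}
    block.append_op(
        type='c_comm_init_hcom',
        inputs={},
        outputs={},
        attrs={
            'nranks': nranks,
            'rank': rank,
            'ring_id': 0,
            # FLAGS_selected_npus is exported per process by the launcher;
            # the "0" default is for this sketch only.
            'device_id': int(os.getenv("FLAGS_selected_npus", "0")),
            'rank_ids': [endpoint_to_index_map[e] for e in endpoints],
        })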
@@ -13,6 +13,7 @@
from __future__ import print_function
from __future__ import division
import os
import paddle.fluid as fluid
from paddle.fluid import core, unique_name
@@ -78,10 +79,10 @@ class PipelineHelper(object):
nranks = len(endpoints)
other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint)
block = program.global_block()
if core.is_compiled_with_cuda():
if rank == 0 and wait_port:
wait_server_ready(other_endpoints)
block = program.global_block()
nccl_id_var = block.create_var(
name=unique_name.generate('nccl_id'),
persistable=True,
@@ -106,6 +107,22 @@ class PipelineHelper(object):
'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Forward,
})
elif core.is_compiled_with_npu():
endpoint_to_index_map = {
e: idx for idx, e in enumerate(endpoints)
}
block.append_op(
type='c_comm_init_hcom',
inputs={},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
'device_id': int(os.getenv("FLAGS_selected_npus")),
'rank_ids': [endpoint_to_index_map[e] for e in endpoints],
OP_ROLE_KEY: OpRole.Forward
})
def _broadcast_params(self, ring_id):
block = self.startup_program.global_block()
......
@@ -265,7 +265,7 @@ class ShardingOptimizer(MetaOptimizerBase):
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in [
"c_allreduce_sum", "c_sync_comm_stream",
"c_calc_comm_stream", "c_gen_nccl_id", "c_comm_init"
"c_calc_comm_stream", "c_gen_nccl_id", "c_comm_init, c_comm_init_hcom"
]:
pass
elif op.type == "conditional_block":
......
@@ -2053,7 +2053,7 @@ class Operator(object):
'feed', 'fetch', 'recurrent', 'go', 'rnn_memory_helper_grad',
'conditional_block', 'while', 'send', 'recv', 'listen_and_serv',
'fl_listen_and_serv', 'ncclInit', 'select', 'checkpoint_notify',
'gen_nccl_id', 'c_gen_nccl_id', 'c_comm_init', 'c_sync_calc_stream',
'gen_nccl_id', 'c_gen_nccl_id', 'c_comm_init', 'c_comm_init_hcom', 'c_sync_calc_stream',
'c_sync_comm_stream', 'queue_generator', 'dequeue', 'enqueue',
'heter_listen_and_serv'
}
......
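The set edited here is the whitelist of operators that have no compute kernel; adding c_comm_init_hcom alongside c_gen_nccl_id and c_comm_init keeps the framework from trying to select a kernel for the new init op. A minimal sanity check, assuming the set is exposed as Operator.OP_WITHOUT_KERNEL_SET (the attribute name is an assumption; only the set's entries appear in this diff):

from paddle.fluid import framework

# Assumed attribute name; the diff only shows the entries of the set.
no_kernel_ops = framework.Operator.OP_WITHOUT_KERNEL_SET
assert 'c_comm_init' in no_kernel_ops
assert 'c_comm_init_hcom' in no_kernel_ops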
@@ -17,6 +17,7 @@ from __future__ import print_function
import sys
import math
from functools import reduce
import os
import collections
import six
@@ -101,10 +102,10 @@ class Collective(object):
nranks = len(endpoints)
other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint)
block = program.global_block()
if core.is_compiled_with_cuda():
if rank == 0 and wait_port:
wait_server_ready(other_endpoints)
block = program.global_block()
nccl_id_var = block.create_var(
name=unique_name.generate('nccl_id'),
persistable=True,
@@ -129,6 +130,22 @@ class Collective(object):
'ring_id': ring_id,
self.op_role_key: OpRole.Forward
})
elif core.is_compiled_with_npu():
endpoint_to_index_map = {
e: idx for idx, e in enumerate(endpoints)
}
block.append_op(
type='c_comm_init_hcom',
inputs={},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
'device_id': int(os.getenv("FLAGS_selected_npus")),
'rank_ids': [endpoint_to_index_map[e] for e in endpoints],
self.op_role_key: OpRole.Forward
})
def _broadcast_params(self):
block = self.startup_program.global_block()
......
@@ -133,9 +133,10 @@ def init_communicator(program, rank, nranks, wait_port, current_endpoint,
return
other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint)
block = program.global_block()
if core.is_compiled_with_cuda():
if rank == 0 and wait_port:
wait_server_ready(other_endpoints)
block = program.global_block()
nccl_id_var = block.create_var(
name=fluid.unique_name.generate('nccl_id'),
persistable=True,
@@ -160,6 +161,21 @@
'rank': rank,
'ring_id': 0,
})
elif core.is_compiled_with_npu():
endpoint_to_index_map = {
e: idx for idx, e in enumerate(endpoints)
}
block.append_op(
type='c_comm_init_hcom',
inputs={},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': 0,
'device_id': int(os.getenv("FLAGS_selected_npus")),
'rank_ids': [endpoint_to_index_map[e] for e in endpoints],
})
def prepare_distributed_context(place=None):
......
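Finally, a hedged sketch of running a startup program that contains the op appended by init_communicator. The core.NPUPlace constructor and the device-id default are assumptions about the runtime, not part of this diff; FLAGS_selected_npus is the same environment variable the new branch reads.

import os
import paddle.fluid as fluid
from paddle.fluid import core

def run_startup(startup_program):
    # Run a startup program whose global block already contains the
    # c_comm_init_hcom op appended by init_communicator above.
    if core.is_compiled_with_npu():
        # Assumption: NPU builds expose core.NPUPlace, mirroring CUDAPlace.
        device_id = int(os.getenv("FLAGS_selected_npus", "0"))
        place = core.NPUPlace(device_id)
    else:
        place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup_program)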