未验证 提交 21b35167 编写于 作者: L lilong12 提交者: GitHub

fix bug for heter (#42590)

上级 df96d1ed
...@@ -122,4 +122,5 @@ REGISTER_OP_CUDA_KERNEL(recv_v2, ops::RecvOpV2CUDAKernel<float>, ...@@ -122,4 +122,5 @@ REGISTER_OP_CUDA_KERNEL(recv_v2, ops::RecvOpV2CUDAKernel<float>,
ops::RecvOpV2CUDAKernel<double>, ops::RecvOpV2CUDAKernel<double>,
ops::RecvOpV2CUDAKernel<int>, ops::RecvOpV2CUDAKernel<int>,
ops::RecvOpV2CUDAKernel<int64_t>, ops::RecvOpV2CUDAKernel<int64_t>,
ops::RecvOpV2CUDAKernel<int8_t>,
ops::RecvOpV2CUDAKernel<plat::float16>); ops::RecvOpV2CUDAKernel<plat::float16>);
...@@ -109,4 +109,5 @@ REGISTER_OP_CUDA_KERNEL(send_v2, ops::SendOpV2CUDAKernel<float>, ...@@ -109,4 +109,5 @@ REGISTER_OP_CUDA_KERNEL(send_v2, ops::SendOpV2CUDAKernel<float>,
ops::SendOpV2CUDAKernel<double>, ops::SendOpV2CUDAKernel<double>,
ops::SendOpV2CUDAKernel<int>, ops::SendOpV2CUDAKernel<int>,
ops::SendOpV2CUDAKernel<int64_t>, ops::SendOpV2CUDAKernel<int64_t>,
ops::SendOpV2CUDAKernel<int8_t>,
ops::SendOpV2CUDAKernel<plat::float16>); ops::SendOpV2CUDAKernel<plat::float16>);
...@@ -41,7 +41,6 @@ class CSendOpASCENDKernel : public framework::OpKernel<T> { ...@@ -41,7 +41,6 @@ class CSendOpASCENDKernel : public framework::OpKernel<T> {
// Use ProcessGroup // Use ProcessGroup
distributed::ProcessGroup* pg = map->get(ring_id); distributed::ProcessGroup* pg = map->get(ring_id);
std::vector<phi::DenseTensor> in_tensor; std::vector<phi::DenseTensor> in_tensor;
auto x = ctx.Input<framework::LoDTensor>("X");
in_tensor.push_back(*x); in_tensor.push_back(*x);
auto task = pg->Send(in_tensor, 1); auto task = pg->Send(in_tensor, 1);
return; return;
......
...@@ -50,6 +50,8 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) { ...@@ -50,6 +50,8 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) {
return ncclInt64; return ncclInt64;
} else if (type == framework::proto::VarType::FP16) { } else if (type == framework::proto::VarType::FP16) {
return ncclFloat16; return ncclFloat16;
} else if (type == framework::proto::VarType::INT8) {
return ncclInt8;
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"This datatype in nccl is not supported.")); "This datatype in nccl is not supported."));
......
...@@ -226,9 +226,15 @@ def _new_process_group_impl(backend, ...@@ -226,9 +226,15 @@ def _new_process_group_impl(backend,
world_size, world_size,
group_name, group_name,
pg_options, pg_options,
group_id=0): group_id=0,
src_rank=None,
dst_rank=None):
pg = None pg = None
genv = _get_global_env() genv = _get_global_env()
if backend != 'heter':
assert src_rank is None and dst_rank is None, (
"src_rank and dst_rank "
"can only be set for heter backend.")
assert backend in _valid_backend_list, "Unsupported backend: %s." % backend assert backend in _valid_backend_list, "Unsupported backend: %s." % backend
if backend == "gloo": if backend == "gloo":
place = core.CPUPlace() place = core.CPUPlace()
...@@ -269,7 +275,9 @@ def _new_process_group_impl(backend, ...@@ -269,7 +275,9 @@ def _new_process_group_impl(backend,
gloo_rank=cluster_id, gloo_rank=cluster_id,
gloo_size=len(cluster_size), gloo_size=len(cluster_size),
with_switch=True, with_switch=True,
switch_endpoint=switch_ep) switch_endpoint=switch_ep,
src_rank=src_rank,
dst_rank=dst_rank)
return pg return pg
...@@ -322,6 +330,16 @@ def barrier(group=None): ...@@ -322,6 +330,16 @@ def barrier(group=None):
attrs={'ring_id': ring_id}) attrs={'ring_id': ring_id})
# _custom_gid provides a way for users to
# set the group id, which is usually useful
# to be compatible with the static mode.
_custom_gid = None
def _set_custom_gid(gid):
_custom_gid = gid
def new_group(ranks=None, backend=None): def new_group(ranks=None, backend=None):
""" """
...@@ -348,9 +366,9 @@ def new_group(ranks=None, backend=None): ...@@ -348,9 +366,9 @@ def new_group(ranks=None, backend=None):
global _group_map global _group_map
if in_dygraph_mode(): if in_dygraph_mode():
global _default_group_name global _default_group_name
gid = _new_ring_id() gid = _custom_gid if _custom_gid else _new_ring_id()
group_name = _default_group_name + str(gid) group_name = _default_group_name + str(gid)
if ranks is None or len(ranks) > 1: if backend != 'heter' and (ranks is None or len(ranks) > 1):
global_group = _get_default_group() global_group = _get_default_group()
global_rank = global_group.rank global_rank = global_group.rank
global_ranks = global_group.ranks global_ranks = global_group.ranks
...@@ -362,8 +380,10 @@ def new_group(ranks=None, backend=None): ...@@ -362,8 +380,10 @@ def new_group(ranks=None, backend=None):
"equal to that of the default global group.") "equal to that of the default global group.")
size = len(ranks) size = len(ranks)
ranks = sorted(ranks) ranks = sorted(ranks)
if size > 1 and global_rank in ranks: if backend == 'heter' or (size > 1 and global_rank in ranks):
rank = ranks.index(global_rank) rank = 0 if backend == 'heter' else ranks.index(global_rank)
src_rank = ranks[0] if backend == 'heter' else None
dst_rank = ranks[1] if backend == 'heter' else None
pg = _new_process_group_impl( pg = _new_process_group_impl(
backend, backend,
_default_store, _default_store,
...@@ -371,7 +391,9 @@ def new_group(ranks=None, backend=None): ...@@ -371,7 +391,9 @@ def new_group(ranks=None, backend=None):
size, size,
group_name, group_name,
pg_options=None, pg_options=None,
group_id=gid) group_id=gid,
src_rank=src_rank,
dst_rank=dst_rank)
else: else:
rank = -1 rank = -1
pg = None pg = None
......
...@@ -138,9 +138,16 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -138,9 +138,16 @@ class ShardingOptimizer(MetaOptimizerBase):
if pp_degree > 1: if pp_degree > 1:
assert strategy.pipeline is True assert strategy.pipeline is True
assert global_world_size == mp_degree * sharding_degree * pp_degree * dp_degree, \ if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None):
"global work size [{}], mp_degree [{}], sharding_degree [{}], pp_degree [{}], dp_degree [{}].".format( assert pp_degree == 2, ("For manually set pipeline, only "
global_world_size, mp_degree, sharding_degree, pp_degree, dp_degree) "pp_degree = 2 is supported.")
assert global_world_size == mp_degree * sharding_degree * dp_degree, \
"global work size [{}], mp_degree [{}], sharding_degree [{}], dp_degree [{}].".format(
global_world_size, mp_degree, sharding_degree, dp_degree)
else:
assert global_world_size == mp_degree * sharding_degree * pp_degree * dp_degree, \
"global work size [{}], mp_degree [{}], sharding_degree [{}], pp_degree [{}], dp_degree [{}].".format(
global_world_size, mp_degree, sharding_degree, pp_degree, dp_degree)
# FIXME (JZ-LIANG) deprecated hybrid_dp # FIXME (JZ-LIANG) deprecated hybrid_dp
if sharding_configs["hybrid_dp"]: if sharding_configs["hybrid_dp"]:
...@@ -268,7 +275,11 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -268,7 +275,11 @@ class ShardingOptimizer(MetaOptimizerBase):
if self.pp_degree > 1: if self.pp_degree > 1:
startup_program = startup_program._pipeline_opt['startup_program'] startup_program = startup_program._pipeline_opt['startup_program']
print("pp_rank:", self.pp_rank) print("pp_rank:", self.pp_rank)
main_program = program_list[self.pp_rank] if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None):
main_program = program_list[int(
os.getenv("PADDLE_MANUAL_PIPELINE_STAGE"))]
else:
main_program = program_list[self.pp_rank]
with open("main_%d" % self.role_maker._worker_index(), 'w') as f: with open("main_%d" % self.role_maker._worker_index(), 'w') as f:
f.writelines(str(main_program)) f.writelines(str(main_program))
main_block = main_program.global_block() main_block = main_program.global_block()
...@@ -633,14 +644,15 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -633,14 +644,15 @@ class ShardingOptimizer(MetaOptimizerBase):
self.pp_group_endpoints[pair[1]], self.pp_group_endpoints[pair[1]],
] ]
pp_rank = 0 if self.pp_rank == pair[0] else 1 pp_rank = 0 if self.pp_rank == pair[0] else 1
self._collective_helper._init_communicator( if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None) is None:
self._startup_program, self._collective_helper._init_communicator(
self.current_endpoint, self._startup_program,
pp_group_endpoints, self.current_endpoint,
pp_rank, pp_group_endpoints,
ring_id, pp_rank,
False, ring_id,
sync=False) False,
sync=False)
def _init_npu_pipeline_comm(self, startup_block): def _init_npu_pipeline_comm(self, startup_block):
# NOTE(wangxi): some bug with hccl, must set pp_degree be even number # NOTE(wangxi): some bug with hccl, must set pp_degree be even number
...@@ -714,14 +726,15 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -714,14 +726,15 @@ class ShardingOptimizer(MetaOptimizerBase):
def _init_pipeline_comm(self, startup_block): def _init_pipeline_comm(self, startup_block):
# TODO (JZ-LIANG) to unify pp_rank_ and pp_rank # TODO (JZ-LIANG) to unify pp_rank_ and pp_rank
self._collective_helper._init_communicator( if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None) is None:
self._startup_program, self._collective_helper._init_communicator(
self.current_endpoint, self._startup_program,
self.pp_group_endpoints, self.current_endpoint,
self.pp_rank, self.pp_group_endpoints,
self.pp_ring_id, self.pp_rank,
False, self.pp_ring_id,
sync=False) False,
sync=False)
if core.is_compiled_with_npu(): if core.is_compiled_with_npu():
self._init_npu_pipeline_comm(startup_block) self._init_npu_pipeline_comm(startup_block)
...@@ -1387,17 +1400,27 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1387,17 +1400,27 @@ class ShardingOptimizer(MetaOptimizerBase):
# NOTE (JZ-LIANG) support outter-pure-dp to scale the throughput in 3D parallelism # NOTE (JZ-LIANG) support outter-pure-dp to scale the throughput in 3D parallelism
# e.g. mp-sharding-pp-dp # e.g. mp-sharding-pp-dp
# sharding-hybrid-dp as one senario of outter-pure-dp # sharding-hybrid-dp as one senario of outter-pure-dp
assert self.global_word_size == self.mp_degree * self.sharding_degree * self.pp_degree * self.dp_degree, "mp_degree: [{}], sharding_degree: [{}], pp_degree: [{}], dp_degree: [{}]; BUT global nrank: [{}]".format( local_pp_degree = self.pp_degree
self.mp_degree, self.sharding_degree, self.pp_degree, if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None):
self.dp_degree, self.global_word_size) assert self.pp_degree == 2, ("For manually set pipeline, only "
"pp_degree = 2 is supported.")
assert self.global_word_size == self.mp_degree * self.sharding_degree * self.dp_degree, \
"global work size [{}], mp_degree [{}], sharding_degree [{}], dp_degree [{}].".format(
self.global_word_size, self.mp_degree, self.sharding_degree, self.dp_degree)
local_pp_degree = 1
else:
assert self.global_word_size == self.mp_degree * self.sharding_degree * self.pp_degree * self.dp_degree, "mp_degree: [{}], sharding_degree: [{}], pp_degree: [{}], dp_degree: [{}]; BUT global nrank: [{}]".format(
self.mp_degree, self.sharding_degree, self.pp_degree,
self.dp_degree, self.global_word_size)
if self.dp_degree > 1: if self.dp_degree > 1:
self.dp_ring_id = 2 self.dp_ring_id = 2
self.dp_rank = self.global_rank // (self.sharding_degree * self.dp_rank = self.global_rank // (
self.mp_degree * self.pp_degree) self.sharding_degree * self.mp_degree * local_pp_degree)
dp_first_rank_idx = self.global_rank % ( dp_first_rank_idx = self.global_rank % (
self.sharding_degree * self.mp_degree * self.pp_degree) self.sharding_degree * self.mp_degree * local_pp_degree)
dp_offset = (self.sharding_degree * self.mp_degree * self.pp_degree) dp_offset = (self.sharding_degree * self.mp_degree *
local_pp_degree)
self.dp_group_endpoints = [] self.dp_group_endpoints = []
for i in range(self.dp_degree): for i in range(self.dp_degree):
self.dp_group_endpoints.append(self.global_endpoints[ self.dp_group_endpoints.append(self.global_endpoints[
......
...@@ -6005,7 +6005,14 @@ class PipelineOptimizer(object): ...@@ -6005,7 +6005,14 @@ class PipelineOptimizer(object):
for p in program_list: for p in program_list:
self._create_vars(p.global_block(), main_block) self._create_vars(p.global_block(), main_block)
self.local_rank %= len(device_list) if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None):
self.local_rank = int(os.getenv("PADDLE_MANUAL_PIPELINE_STAGE"))
assert self.local_rank < len(device_list), (
"Manually specified "
"pipeline stage must be less than total number of pipeline "
"stages.")
else:
self.local_rank %= len(device_list)
# Step3.5: optimize forward send sync_comm to overlap send and recv # Step3.5: optimize forward send sync_comm to overlap send and recv
self._optimize_forward_send_sync(program_list[self.local_rank]) self._optimize_forward_send_sync(program_list[self.local_rank])
......
...@@ -63,7 +63,7 @@ class TestPipeline(TestDistBase): ...@@ -63,7 +63,7 @@ class TestPipeline(TestDistBase):
"pipeline_mnist_one_device.py", "pipeline_mnist_one_device.py",
check_error_log=True, check_error_log=True,
log_name=flag_name, log_name=flag_name,
need_envs=self.need_envs()) need_envs={"PADDLE_MANUAL_PIPELINE_STAGE": "0"})
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册