Unverified commit d12588d2 authored by Yiqun Liu, committed by GitHub

Broadcast the master weight along with param for distributed training. (#52638)

* Broadcast the master weight along with param for distributed training.

* Fix codestyle.
Parent ba9a22db
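The change itself sits in the startup-program broadcast logic near the end of this diff: a new _find_master_param helper looks up the fp32 master weight that AMP-O2 keeps for a parameter (its variable name contains the parameter name plus "fp32_master"), and a second c_broadcast op is appended for that master weight right after the one for the parameter. Below is a minimal sketch of that pattern; the helper name broadcast_param_with_master is made up for illustration, and the op-role attribute set in the real code is omitted for brevity.

def _find_master_param(all_vars_name, param_name):
    # AMP-O2 master weights carry both the parameter name and the
    # "fp32_master" marker in their variable name.
    for var_name in all_vars_name:
        if param_name in var_name and "fp32_master" in var_name:
            return var_name
    return None

def broadcast_param_with_master(startup_block, param_name, ring_id):
    # Broadcast the parameter itself, then its fp32 master weight when one
    # exists, so every rank in the ring starts from identical copies.
    for name in (param_name, _find_master_param(startup_block.vars, param_name)):
        if name is None:
            continue
        startup_block.append_op(
            type='c_broadcast',
            inputs={'X': name},
            outputs={'Out': name},
            attrs={'ring_id': ring_id, 'root': 0, 'use_calc_stream': True},
        )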
......@@ -28,6 +28,7 @@ from .sharding.gradient_clip_helper import GradientClipHelper
from .sharding.offload_helper import OffloadHelper
from .sharding.prune import ProgramDeps
from .sharding import utils
# FIXME: import *
from .sharding.utils import *
import logging
......@@ -84,7 +85,7 @@ class ShardingOptimizer(MetaOptimizerBase):
dist_strategy.sharding_configs = {"segment_broadcast_MB": 32}
def _get_sharding_segment_strategy(self):
""" get
"""get
self._sharding_segment_strategy
1. if by_size: self._broadcast_MB
2. if by_anchors: self._sharding_segment_anchors
......@@ -97,21 +98,26 @@ class ShardingOptimizer(MetaOptimizerBase):
if segment_strategy == "segment_broadcast_MB":
self._broadcast_MB = sharding_configs["segment_broadcast_MB"]
assert self._broadcast_MB > 0, "segment size should larger than zero !"
assert (
self._broadcast_MB > 0
), "segment size should larger than zero !"
elif segment_strategy == "segment_anchors":
self._sharding_segment_anchors = sharding_configs["segment_anchors"]
assert len(self._sharding_segment_anchors
) > 0, "you should set the sharding segment anchors !"
assert (
len(self._sharding_segment_anchors) > 0
), "you should set the sharding segment anchors !"
self._backward_remain_anchors = self._sharding_segment_anchors[:]
self._forward_remain_anchors = []
else:
raise NotImplementedError(
"the sharding segment strategy [{}] is not implemented".format(
str(segment_strategy)))
str(segment_strategy)
)
)
self._sharding_segment_strategy = segment_strategy
def _get_hybrid_degree(self):
""" get
"""get
self.hybrid_dp
self.sharding_degree
self.mp_degree
......@@ -135,21 +141,32 @@ class ShardingOptimizer(MetaOptimizerBase):
assert strategy.pipeline is True
if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None):
assert pp_degree == 2, ("For manually set pipeline, only "
"pp_degree = 2 is supported.")
assert global_world_size == mp_degree * sharding_degree * dp_degree, \
"global work size [{}], mp_degree [{}], sharding_degree [{}], dp_degree [{}].".format(
global_world_size, mp_degree, sharding_degree, dp_degree)
assert pp_degree == 2, (
"For manually set pipeline, only " "pp_degree = 2 is supported."
)
assert (
global_world_size == mp_degree * sharding_degree * dp_degree
), "global work size [{}], mp_degree [{}], sharding_degree [{}], dp_degree [{}].".format(
global_world_size, mp_degree, sharding_degree, dp_degree
)
else:
assert global_world_size == mp_degree * sharding_degree * pp_degree * dp_degree, \
"global work size [{}], mp_degree [{}], sharding_degree [{}], pp_degree [{}], dp_degree [{}].".format(
global_world_size, mp_degree, sharding_degree, pp_degree, dp_degree)
assert (
global_world_size
== mp_degree * sharding_degree * pp_degree * dp_degree
), "global work size [{}], mp_degree [{}], sharding_degree [{}], pp_degree [{}], dp_degree [{}].".format(
global_world_size,
mp_degree,
sharding_degree,
pp_degree,
dp_degree,
)
# FIXME (JZ-LIANG) deprecated hybrid_dp
if sharding_configs["hybrid_dp"]:
logger.warning(
"[hybrid_dp] API setting is deprecated. Now when "
"dp_degree >= 2, its will be in hybrid dp mode automatically")
"dp_degree >= 2, its will be in hybrid dp mode automatically"
)
assert dp_degree >= 1
self.hybrid_dp = True if dp_degree > 1 else False
......@@ -159,7 +176,7 @@ class ShardingOptimizer(MetaOptimizerBase):
self.dp_degree = dp_degree
def _get_hybrid_dp_mode(self):
""" get
"""get
self.hybrid_dp_mode = 'pp_hybrid_dp' or 'sharding_hybrid_dp'
self.gradient_merge_mode = 'pp_gm' or 'sharding_gm'
self._gradient_merge_acc_step
......@@ -183,9 +200,10 @@ class ShardingOptimizer(MetaOptimizerBase):
if self.pp_degree > 1:
dp_mode = "pp_hybrid_dp"
else:
assert self.sharding_degree > 1, \
"by now we only support five kind of hybrid dp: sharding_hybrid_dp, " \
assert self.sharding_degree > 1, (
"by now we only support five kind of hybrid dp: sharding_hybrid_dp, "
"mp_sharding_hybrid_dp, pp_hybrid_dp, mp_sharding_pp_hybrid_dp, sharding_pp_hybrid_dp."
)
dp_mode = "sharding_hybrid_dp"
# gradient merge
......@@ -198,23 +216,33 @@ class ShardingOptimizer(MetaOptimizerBase):
gm_mode = "pp_gm"
gm_acc_step = strategy.pipeline_configs['accumulate_steps']
gradient_scale_configs = strategy.gradient_scale_configs
assert gradient_scale_configs['scale_strategy'] == 'avg', \
'For pipeline mode, the ' 'gradient scale mode should ' \
'be "avg", but got {}'.format(gradient_scale_configs['scale_strategy'])
assert gradient_scale_configs['scale_strategy'] == 'avg', (
'For pipeline mode, the '
'gradient scale mode should '
'be "avg", but got {}'.format(
gradient_scale_configs['scale_strategy']
)
)
# Note (Yuang Liu): this avg_loss flag determines where to do the average op for grad merge.
# If True, will do sum firstly for gradient merge, then do scale by gm_acc_step.
# If False, will scale loss by gm_acc_step first, then do sum for gradient merge.
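# For illustration: with gm_acc_step = 4 and per-step gradients g1..g4, the
# sum-first path computes (g1 + g2 + g3 + g4) / 4, while the scale-first path
# computes g1/4 + g2/4 + g3/4 + g4/4; both yield the same average, only the
# point where the 1/gm_acc_step factor is applied differs.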
self.scale_gradient = gradient_scale_configs['scale_gradient']
if gm_acc_step > 1:
logger.info("Gradient merge in [{}], acc step = [{}]".format(
gm_mode, gm_acc_step))
logger.info(
"Gradient merge in [{}], acc step = [{}]".format(
gm_mode, gm_acc_step
)
)
optimizer_sharding = False
# TODO(wangxi): need support dp_as_opt_sharding with sharding
# need support without pp in future
if self.sharding_degree == 1 and self.dp_degree > 1 \
and sharding_configs['_dp_as_optimizer_sharding'] \
and self.pp_degree > 1:
if (
self.sharding_degree == 1
and self.dp_degree > 1
and sharding_configs['_dp_as_optimizer_sharding']
and self.pp_degree > 1
):
optimizer_sharding = True
self.hybrid_dp_mode = dp_mode
......@@ -224,19 +252,23 @@ class ShardingOptimizer(MetaOptimizerBase):
# this feature is designed for ascend, and should NOT be used in GPU training
self.pp_allreduce_in_optimize = sharding_configs[
"pp_allreduce_in_optimize"]
"pp_allreduce_in_optimize"
]
def _inner_opt_minimize(self, loss, startup_program, parameter_list,
no_grad_set):
def _inner_opt_minimize(
self, loss, startup_program, parameter_list, no_grad_set
):
pipeline_configs = self.user_defined_strategy.pipeline_configs
if self.inner_opt is None:
raise ValueError(
"self.inner_opt of ShardingOptimizer should not be None.")
"self.inner_opt of ShardingOptimizer should not be None."
)
if self.pp_degree > 1:
pp_optimizer = fluid.optimizer.PipelineOptimizer(
self.inner_opt, self._gradient_merge_acc_step)
self.inner_opt, self._gradient_merge_acc_step
)
self._pp_optimizer = pp_optimizer
global_rank = self.role_maker._worker_index()
......@@ -253,17 +285,25 @@ class ShardingOptimizer(MetaOptimizerBase):
'global_ring_id': 3,
'mp_degree': self.mp_degree,
'mp_rank': global_rank % self.mp_degree,
'scale_gradient': self.scale_gradient
'scale_gradient': self.scale_gradient,
}
main_program = loss.block.program
main_program._pipeline_opt = pipeline_opt
optimize_ops, params_grads, program_list, self.pipeline_pair, self.pp_ring_map = pp_optimizer.minimize(
loss, startup_program, parameter_list, no_grad_set)
(
optimize_ops,
params_grads,
program_list,
self.pipeline_pair,
self.pp_ring_map,
) = pp_optimizer.minimize(
loss, startup_program, parameter_list, no_grad_set
)
assert self.pp_degree == len(program_list)
else:
optimize_ops, params_grads = self.inner_opt.minimize(
loss, startup_program, parameter_list, no_grad_set)
loss, startup_program, parameter_list, no_grad_set
)
if startup_program is None:
startup_program = default_startup_program()
......@@ -272,8 +312,9 @@ class ShardingOptimizer(MetaOptimizerBase):
startup_program = startup_program._pipeline_opt['startup_program']
print("pp_rank:", self.pp_rank)
if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None):
main_program = program_list[int(
os.getenv("PADDLE_MANUAL_PIPELINE_STAGE"))]
main_program = program_list[
int(os.getenv("PADDLE_MANUAL_PIPELINE_STAGE"))
]
else:
main_program = program_list[self.pp_rank]
with open("main_%d" % self.role_maker._worker_index(), 'w') as f:
......@@ -299,14 +340,16 @@ class ShardingOptimizer(MetaOptimizerBase):
return optimize_ops, params_grads
def _apply_sharding_pass(self, params_grads):
if self.sharding_degree == 1: return
if self.sharding_degree == 1:
return
main_block = self._main_program.global_block()
startup_block = self._startup_program.global_block()
# step1: build shard
self._build_shard(params_grads, self.sharding_rank,
self.sharding_degree)
self._build_shard(
params_grads, self.sharding_rank, self.sharding_degree
)
# step2: split_program
self._split_program(main_block)
......@@ -318,13 +361,16 @@ class ShardingOptimizer(MetaOptimizerBase):
# step4: remove unneeded ops and vars from block
self._prune_main_program(
main_block, self._shard,
[self.mp_ring_id, self.sharding_ring_id, self.pp_ring_id])
main_block,
self._shard,
[self.mp_ring_id, self.sharding_ring_id, self.pp_ring_id],
)
self._prune_startup_program(startup_block, self._shard)
def _apply_opt_sharding_pass(self, params_grads):
""" outer dp as optimizer sharding """
if self._optimizer_sharding is False: return
"""outer dp as optimizer sharding"""
if self._optimizer_sharding is False:
return
main_block = self._main_program.global_block()
startup_block = self._startup_program.global_block()
......@@ -338,12 +384,15 @@ class ShardingOptimizer(MetaOptimizerBase):
# step4: remove unneeded ops and vars from block
self._prune_main_program(
main_block, self._shard,
[self.mp_ring_id, self.pp_ring_id, self.dp_ring_id])
main_block,
self._shard,
[self.mp_ring_id, self.pp_ring_id, self.dp_ring_id],
)
self._prune_startup_program(startup_block, self._shard)
def _insert_allreduce_for_pp(self, params_grads):
if self.pp_degree == 1: return
if self.pp_degree == 1:
return
strategy = self.user_defined_strategy
sharding_configs = strategy.sharding_configs
......@@ -363,10 +412,12 @@ class ShardingOptimizer(MetaOptimizerBase):
main_block._remove_op(idx)
for idx, op in reversed(list(enumerate(main_block.ops))):
if op.type != 'cast': continue
if op.type != 'cast':
continue
in_name = op.input_arg_names[0]
if in_name not in self._params: continue
#if self._shard.has_param(param_name): continue
if in_name not in self._params:
continue
# if self._shard.has_param(param_name): continue
if in_name not in main_block.vars:
main_block._remove_op(idx)
......@@ -376,7 +427,8 @@ class ShardingOptimizer(MetaOptimizerBase):
shard = self._shard if self._optimizer_sharding else None
accumulated_grad_names = self._pp_optimizer._accumulate_gradients(
main_block, strategy=strategy, shard=shard)
main_block, strategy=strategy, shard=shard
)
len_of_ops = len(main_block.ops)
if self.scale_gradient:
......@@ -384,8 +436,9 @@ class ShardingOptimizer(MetaOptimizerBase):
first_optimize_op_index = get_first_optimize_op_idx(main_block)
if self.pp_allreduce_in_optimize:
logger.info("Pipeline Persistable grad is {}".format(
accumulated_grad_names))
logger.info(
"Pipeline Persistable grad is {}".format(accumulated_grad_names)
)
# FIXME(wangxi): accumulated_grad obtained from pipeline does not
# include sharding's param@BroadCast grad when
# pp_allreduce_in_optimize
......@@ -397,10 +450,11 @@ class ShardingOptimizer(MetaOptimizerBase):
self._shard,
core.op_proto_and_checker_maker.OpRole.Optimize,
use_calc_stream=True,
rank=self.sharding_rank)
rank=self.sharding_rank,
)
logger.info("PP-Sharding grad is {}".format(accumulated_grad_names))
first_optimize_op_index += (len(main_block.ops) - len_of_ops)
first_optimize_op_index += len(main_block.ops) - len_of_ops
len_of_ops = len(main_block.ops)
if self._optimizer_sharding:
......@@ -413,10 +467,12 @@ class ShardingOptimizer(MetaOptimizerBase):
OpRole.Optimize,
use_calc_stream=True,
rank=self.dp_rank,
strategy=strategy)
strategy=strategy,
)
logger.info(
"Optimizer grad in this rank {}".format(accumulated_grad_names))
first_optimize_op_index += (len(main_block.ops) - len_of_ops)
"Optimizer grad in this rank {}".format(accumulated_grad_names)
)
first_optimize_op_index += len(main_block.ops) - len_of_ops
len_of_ops = len(main_block.ops)
# NOTE(wangxi): we fused after optimize_cast
......@@ -424,14 +480,17 @@ class ShardingOptimizer(MetaOptimizerBase):
optimizer_param = utils.insert_broadcast_param_ops(
main_block,
len_of_ops,
self.dp_ring_id, [x[0].name for x in params_grads],
self.dp_ring_id,
[x[0].name for x in params_grads],
self._shard,
OpRole.Optimize,
use_calc_stream=True,
rank=self.dp_rank,
strategy=None if optimize_cast else strategy)
strategy=None if optimize_cast else strategy,
)
logger.info(
"Optimizer param in this rank {}".format(optimizer_param))
"Optimizer param in this rank {}".format(optimizer_param)
)
if not strategy.fuse_grad_merge and not optimize_cast:
assert len(accumulated_grad_names) == len(optimizer_param)
elif self.hybrid_dp and self.hybrid_dp_mode == "pp_hybrid_dp":
......@@ -442,15 +501,20 @@ class ShardingOptimizer(MetaOptimizerBase):
accumulated_grad_names,
core.op_proto_and_checker_maker.OpRole.Optimize,
use_calc_stream=True,
user_defined_strategy=strategy)
first_optimize_op_index += (len(main_block.ops) - len_of_ops)
user_defined_strategy=strategy,
)
first_optimize_op_index += len(main_block.ops) - len_of_ops
len_of_ops = len(main_block.ops)
# FIXME(wangxi): if fp16_allreduce, put cast fp16->fp32 to there?
def _avg_grad_merge_after_sum(self, main_block, accumulated_grad_names):
if self.user_defined_strategy.amp and \
self.user_defined_strategy.amp_configs['use_dynamic_loss_scaling']:
if (
self.user_defined_strategy.amp
and self.user_defined_strategy.amp_configs[
'use_dynamic_loss_scaling'
]
):
# For AMP, if using dynamic loss scaling the avg
# operation can be simply done by modifying the LossScaling op.
for idx, op in enumerate(main_block.ops):
......@@ -461,7 +525,8 @@ class ShardingOptimizer(MetaOptimizerBase):
loss_scale_tmp_var = main_block.create_var(
name=loss_scale_tmp_var_name,
shape=loss_scaling_var.shape,
dtype=loss_scaling_var.dtype)
dtype=loss_scaling_var.dtype,
)
main_block._insert_op_without_sync(
idx,
type='scale',
......@@ -471,8 +536,9 @@ class ShardingOptimizer(MetaOptimizerBase):
'scale': self._gradient_merge_acc_step,
'bias': 0.0,
'bias_after_scale': False,
OP_ROLE_KEY: OpRole.Optimize
})
OP_ROLE_KEY: OpRole.Optimize,
},
)
op._rename_input(loss_scale_name, loss_scale_tmp_var_name)
break
else:
......@@ -483,7 +549,9 @@ class ShardingOptimizer(MetaOptimizerBase):
if is_optimizer_op(op) and op.type != 'c_sync_comm_stream':
tmp_first_opt_idx = idx
break
assert tmp_first_opt_idx is not None, 'Some error occurred: no optimize ops found'
assert (
tmp_first_opt_idx is not None
), 'Some error occurred: no optimize ops found'
for grad in accumulated_grad_names:
main_block._insert_op_without_sync(
tmp_first_opt_idx,
......@@ -494,14 +562,17 @@ class ShardingOptimizer(MetaOptimizerBase):
'scale': 1.0 / self._gradient_merge_acc_step,
'bias': 0.0,
'bias_after_scale': False,
OP_ROLE_KEY: OpRole.Optimize
})
OP_ROLE_KEY: OpRole.Optimize,
},
)
def _adapt_amp_clip_without_sharding(self):
# if sharding is not used, adapt amp/clip for the remaining parallelism.
# cast --> amp --> clip --> opt
if self.sharding_degree > 1: return
if self._optimizer_sharding: return
if self.sharding_degree > 1:
return
if self._optimizer_sharding:
return
main_block = self._main_program.global_block()
startup_block = self._startup_program.global_block()
......@@ -515,9 +586,9 @@ class ShardingOptimizer(MetaOptimizerBase):
FP16Utils.sync_amp_check_nan_inf(main_block, rings)
gradientclip_helper = GradientClipHelper(None)
gradientclip_helper.sync_global_norm(main_block,
[self.mp_ring_id, self.pp_ring_id],
self.mp_rank)
gradientclip_helper.sync_global_norm(
main_block, [self.mp_ring_id, self.pp_ring_id], self.mp_rank
)
def _insert_loss_grad_scale_op(self):
main_block = self._main_program.global_block()
......@@ -538,8 +609,9 @@ class ShardingOptimizer(MetaOptimizerBase):
mp_ring_id = self.mp_ring_id if self.mp_degree > 1 else None
dp_ring_id = self.dp_ring_id if self.dp_degree > 1 else None
offload_helper = OffloadHelper(mp_ring_id=mp_ring_id,
dp_ring_id=dp_ring_id)
offload_helper = OffloadHelper(
mp_ring_id=mp_ring_id, dp_ring_id=dp_ring_id
)
# optimize offload should be enabled while gradient merge is enabled and
# acc_step is quite large (e.g. >> 100). Since its memcpy could not be
......@@ -555,32 +627,32 @@ class ShardingOptimizer(MetaOptimizerBase):
# will take more memory, but will be faster. Trade space for time.
if self._optimizer_sharding:
offload_helper.opt_sharding_cast_fp32param(
main_block, startup_block,
[x[0].name for x in params_grads])
main_block, startup_block, [x[0].name for x in params_grads]
)
# NOTE(wangxi): fused after optimize_cast
utils.fuse_opt_broadcast_param_ops(main_block,
dp_ring_id,
self._shard,
strategy=strategy)
utils.fuse_opt_broadcast_param_ops(
main_block, dp_ring_id, self._shard, strategy=strategy
)
else:
offload_helper.cast_fp32param_in_optimize(
main_block, startup_block)
main_block, startup_block
)
def _dump_program_for_debug(self):
main_block = self._main_program.global_block()
startup_block = self._startup_program.global_block()
with open("start_sharding_%d" % self.role_maker._worker_index(),
'w') as f:
with open(
"start_sharding_%d" % self.role_maker._worker_index(), 'w'
) as f:
f.writelines(str(startup_block.program))
with open("main_sharding_%d" % self.role_maker._worker_index(),
'w') as f:
with open(
"main_sharding_%d" % self.role_maker._worker_index(), 'w'
) as f:
f.writelines(str(main_block.program))
def minimize_impl(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
def minimize_impl(
self, loss, startup_program=None, parameter_list=None, no_grad_set=None
):
# TODO: (JZ-LIANG) support multiple comm in future
# self._nrings = self.user_defined_strategy.nccl_comm_num
self._nrings_sharding = 1
......@@ -595,7 +667,8 @@ class ShardingOptimizer(MetaOptimizerBase):
# inner optimize minimize
optimize_ops, params_grads = self._inner_opt_minimize(
loss, startup_program, parameter_list, no_grad_set)
loss, startup_program, parameter_list, no_grad_set
)
self._init_comm()
......@@ -644,13 +717,15 @@ class ShardingOptimizer(MetaOptimizerBase):
]
pp_rank = 0 if self.pp_rank == pair[0] else 1
if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None) is None:
self._collective_helper._init_communicator(self._startup_program,
self.current_endpoint,
pp_group_endpoints,
pp_rank,
ring_id,
False,
sync=False)
self._collective_helper._init_communicator(
self._startup_program,
self.current_endpoint,
pp_group_endpoints,
pp_rank,
ring_id,
False,
sync=False,
)
def _init_npu_pipeline_comm(self, startup_block):
# NOTE(wangxi): some bug with hccl, must set pp_degree be even number
......@@ -668,15 +743,22 @@ class ShardingOptimizer(MetaOptimizerBase):
my_pair.append(pair)
# for example: self.pp_rank=2, self.pp_degree=4
send_to_next_pair = (self.pp_rank, (self.pp_rank + 1) % self.pp_degree
) # 2->3
send_to_next_pair = (
self.pp_rank,
(self.pp_rank + 1) % self.pp_degree,
) # 2->3
recv_from_next_pair = (
(self.pp_rank + 1) % self.pp_degree, self.pp_rank) # 3->2
(self.pp_rank + 1) % self.pp_degree,
self.pp_rank,
) # 3->2
recv_from_prev_pair = (
(self.pp_rank - 1 + self.pp_degree) % self.pp_degree, self.pp_rank
(self.pp_rank - 1 + self.pp_degree) % self.pp_degree,
self.pp_rank,
) # 1->2
send_to_prev_pair = (self.pp_rank, (self.pp_rank - 1 + self.pp_degree) %
self.pp_degree) # 2->1
send_to_prev_pair = (
self.pp_rank,
(self.pp_rank - 1 + self.pp_degree) % self.pp_degree,
) # 2->1
even = (self.pp_rank % 2) == 0
......@@ -685,54 +767,66 @@ class ShardingOptimizer(MetaOptimizerBase):
ring_id = self.pp_ring_map[pair[0] * 1000 + pair[1]]
self._init_pair_comm(pair, ring_id)
my_pair.remove(pair)
logger.info("pair0(even->odd): pp pair:{}, ring_id: {}".format(
pair, ring_id))
logger.info(
"pair0(even->odd): pp pair:{}, ring_id: {}".format(pair, ring_id)
)
# 2. even recv from next, odd send to prev, 1->0, 3->2
pair = recv_from_next_pair if even else send_to_prev_pair
ring_id = self.pp_ring_map[pair[0] * 1000 + pair[1]]
self._init_pair_comm(pair, ring_id)
my_pair.remove(pair)
logger.info("pair1(even<-odd): pp pair:{}, ring_id: {}".format(
pair, ring_id))
logger.info(
"pair1(even<-odd): pp pair:{}, ring_id: {}".format(pair, ring_id)
)
# if pp_degree is 2, only need pair(0->1, 1->0)
if self.pp_degree > 2:
# 3. odd send to next, even recv from prev, 1->2, 3->0
pair = send_to_next_pair if not even else recv_from_prev_pair
ring_id = self.pp_ring_map.get(pair[0] * 1000 + pair[1],
max_ring_id +
1) # 3->0 not in pp_ring_map
ring_id = self.pp_ring_map.get(
pair[0] * 1000 + pair[1], max_ring_id + 1
) # 3->0 not in pp_ring_map
self._init_pair_comm(pair, ring_id)
if self.pp_rank != 0 and self.pp_rank != self.pp_degree - 1:
my_pair.remove(pair)
logger.info("pair2(odd->even): pp pair:{}, ring_id: {}".format(
pair, ring_id))
logger.info(
"pair2(odd->even): pp pair:{}, ring_id: {}".format(
pair, ring_id
)
)
# 4. odd recv from next, even send to prev, 2->1, 0->3
pair = recv_from_next_pair if not even else send_to_prev_pair
ring_id = self.pp_ring_map.get(pair[0] * 1000 + pair[1],
max_ring_id +
2) # 0->3 not in pp_ring_map
ring_id = self.pp_ring_map.get(
pair[0] * 1000 + pair[1], max_ring_id + 2
) # 0->3 not in pp_ring_map
self._init_pair_comm(pair, ring_id)
if self.pp_rank != 0 and self.pp_rank != self.pp_degree - 1:
my_pair.remove(pair)
logger.info("pair3(odd<-even): pp pair:{}, ring_id: {}".format(
pair, ring_id))
logger.info(
"pair3(odd<-even): pp pair:{}, ring_id: {}".format(
pair, ring_id
)
)
assert len(my_pair) == 0, "Current pipeline does not support cross stage communication, " \
"please check unexpected pair {}".format(my_pair)
assert len(my_pair) == 0, (
"Current pipeline does not support cross stage communication, "
"please check unexpected pair {}".format(my_pair)
)
def _init_pipeline_comm(self, startup_block):
# TODO (JZ-LIANG) to unify pp_rank_ and pp_rank
if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None) is None:
self._collective_helper._init_communicator(self._startup_program,
self.current_endpoint,
self.pp_group_endpoints,
self.pp_rank,
self.pp_ring_id,
False,
sync=False)
self._collective_helper._init_communicator(
self._startup_program,
self.current_endpoint,
self.pp_group_endpoints,
self.pp_rank,
self.pp_ring_id,
False,
sync=False,
)
if core.is_compiled_with_npu():
self._init_npu_pipeline_comm(startup_block)
......@@ -752,13 +846,15 @@ class ShardingOptimizer(MetaOptimizerBase):
# mp ring
if self.mp_degree > 1:
self._collective_helper._init_communicator(self._startup_program,
self.current_endpoint,
self.mp_group_endpoints,
self.mp_rank,
self.mp_ring_id,
False,
sync=False)
self._collective_helper._init_communicator(
self._startup_program,
self.current_endpoint,
self.mp_group_endpoints,
self.mp_rank,
self.mp_ring_id,
False,
sync=False,
)
# sharding ring
if self.sharding_degree > 1:
......@@ -769,7 +865,8 @@ class ShardingOptimizer(MetaOptimizerBase):
self.sharding_rank,
self.sharding_ring_id,
False,
sync=False)
sync=False,
)
# pp ring
if self.pp_degree > 1:
......@@ -777,13 +874,15 @@ class ShardingOptimizer(MetaOptimizerBase):
# pure dp ring
if self.dp_degree > 1:
self._collective_helper._init_communicator(self._startup_program,
self.current_endpoint,
self.dp_group_endpoints,
self.dp_rank,
self.dp_ring_id,
False,
sync=False)
self._collective_helper._init_communicator(
self._startup_program,
self.current_endpoint,
self.dp_group_endpoints,
self.dp_rank,
self.dp_ring_id,
False,
sync=False,
)
startup_block._sync_with_cpp()
......@@ -794,9 +893,12 @@ class ShardingOptimizer(MetaOptimizerBase):
# step 3: get broadcast vars
self._broadcast_vars = self._shard.find_broadcast_params(
self._main_program.global_block())
self._main_program.global_block()
)
def _wait(self, ):
def _wait(
self,
):
endpoints = self.global_endpoints[:]
current_endpoint = endpoints[self.global_rank]
if self.global_rank == 0:
......@@ -821,7 +923,7 @@ class ShardingOptimizer(MetaOptimizerBase):
segment._end_idx = last_backward_op_idx
for op_idx in reversed(range(last_backward_op_idx)):
op = block.ops[op_idx]
assert (int(op.attr('op_role')) != int(OpRole.Optimize))
assert int(op.attr('op_role')) != int(OpRole.Optimize)
if self._sharding_segment_strategy == "segment_broadcast_MB":
if segment._param_mem >= self._broadcast_MB:
segment = self.collect_segment(segment, op_idx, block)
......@@ -835,21 +937,27 @@ class ShardingOptimizer(MetaOptimizerBase):
if ".cast_fp16@GRAD" not in input_name:
continue
else:
input_name = input_name[:input_name.
find(".cast_fp16@GRAD")]
input_name = input_name[
: input_name.find(".cast_fp16@GRAD")
]
if input_name in self._backward_remain_anchors:
segment = self.collect_segment(
segment, op_idx, block)
assert input_name not in self._forward_remain_anchors, "segment anchor [{}] met twice !".format(
input_name)
segment, op_idx, block
)
assert (
input_name not in self._forward_remain_anchors
), "segment anchor [{}] met twice !".format(
input_name
)
self._backward_remain_anchors.remove(input_name)
self._forward_remain_anchors.append(input_name)
elif int(op.attr('op_role')) == int(OpRole.Forward):
for output_name in op.desc.output_arg_names():
if output_name in self._forward_remain_anchors:
segment = self.collect_segment(
segment, op_idx, block)
segment, op_idx, block
)
self._forward_remain_anchors.remove(output_name)
# find broadcast vars
......@@ -865,47 +973,49 @@ class ShardingOptimizer(MetaOptimizerBase):
if self._shard.has_param(input_name):
broadcast_var_name = input_name
else:
broadcast_var_name = unique_name.generate(input_name +
"@BroadCast")
broadcast_var_name = unique_name.generate(
input_name + "@BroadCast"
)
segment._fill_constant_vars.append(broadcast_var_name)
# (JZ-LIANG) should use Param base name ?
broadcast_var_base_name = input_name
if "subprog" in broadcast_var_base_name:
# remove suffix
broadcast_var_base_name = broadcast_var_base_name[:
broadcast_var_base_name
.find(
".subprog"
)]
broadcast_var_base_name = broadcast_var_base_name[
: broadcast_var_base_name.find(".subprog")
]
var2broadcast_time[
broadcast_var_base_name] = var2broadcast_time.get(
broadcast_var_base_name, 0) + 1
var2broadcast_time[broadcast_var_base_name] = (
var2broadcast_time.get(broadcast_var_base_name, 0) + 1
)
segment._param2broadcast[input_name] = broadcast_var_name
segment._broadcast_vars.append(
(broadcast_var_name, self._shard.device(input_name)))
(broadcast_var_name, self._shard.device(input_name))
)
segment._param_mem += get_var_size(
self._main_program.global_block().var(input_name))
self._main_program.global_block().var(input_name)
)
# find reduce vars
if self.pp_degree > 1 and self.pp_allreduce_in_optimize:
# place pipeline gradient allreduce in optimize
pass
else:
if is_backward_op(op) and \
OP_ROLE_VAR_KEY in op.attr_names:
if is_backward_op(op) and OP_ROLE_VAR_KEY in op.attr_names:
op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
if len(op_role_var) != 0:
assert len(op_role_var) % 2 == 0
for i in range(0, len(op_role_var), 2):
param, reduced_grad = op_role_var[i], op_role_var[i
+
1]
param, reduced_grad = (
op_role_var[i],
op_role_var[i + 1],
)
segment._allreduce_vars.append(reduced_grad)
assert (reduced_grad
not in self._reduced_grads_to_param)
assert (
reduced_grad not in self._reduced_grads_to_param
)
self._reduced_grads_to_param[reduced_grad] = param
# find cast op
......@@ -920,29 +1030,40 @@ class ShardingOptimizer(MetaOptimizerBase):
self._segments.insert(0, segment)
if self._sharding_segment_strategy == "segment_anchors":
assert len(
self._forward_remain_anchors) == 0, "remain anchors {}".format(
self._forward_remain_anchors)
assert len(
self._backward_remain_anchors) == 0, "remain anchors {}".format(
self._backward_remain_anchors)
assert (
len(self._forward_remain_anchors) == 0
), "remain anchors {}".format(self._forward_remain_anchors)
assert (
len(self._backward_remain_anchors) == 0
), "remain anchors {}".format(self._backward_remain_anchors)
if self._verbose:
for varname in sorted(var2broadcast_time,
key=var2broadcast_time.get,
reverse=True):
logger.info("Sharding broadcast: [{}] times [{}]".format(
var2broadcast_time[varname], varname))
for varname in sorted(
var2broadcast_time, key=var2broadcast_time.get, reverse=True
):
logger.info(
"Sharding broadcast: [{}] times [{}]".format(
var2broadcast_time[varname], varname
)
)
for idx_ in range(len(self._segments)):
logger.info("segment [{}] :".format(idx_))
logger.info("start op: [{}] [{}]".format(
block.ops[self._segments[idx_]._start_idx].desc.type(),
block.ops[self._segments[idx_].
_start_idx].desc.input_arg_names()))
logger.info("end op: [{}] [{}]".format(
block.ops[self._segments[idx_]._end_idx].desc.type(),
block.ops[
self._segments[idx_]._end_idx].desc.input_arg_names()))
logger.info(
"start op: [{}] [{}]".format(
block.ops[self._segments[idx_]._start_idx].desc.type(),
block.ops[
self._segments[idx_]._start_idx
].desc.input_arg_names(),
)
)
logger.info(
"end op: [{}] [{}]".format(
block.ops[self._segments[idx_]._end_idx].desc.type(),
block.ops[
self._segments[idx_]._end_idx
].desc.input_arg_names(),
)
)
return
def _prune_main_program(self, block, shard, rings):
......@@ -954,7 +1075,7 @@ class ShardingOptimizer(MetaOptimizerBase):
2. prune cast_fp32_to_fp16; update amp_infine_checking
3. prune gradient_clip related; update global_norm_sum
4. prune optimizer op + param + gradient
"""
weightdecay_helper = WeightDecayHelper()
weightdecay_helper.prune_weight_decay(block, shard)
......@@ -975,17 +1096,18 @@ class ShardingOptimizer(MetaOptimizerBase):
input_names = op.desc.input_arg_names()
output_names = op.desc.output_arg_names()
# FIXME(wangxi): need use grads, pipeline grad is @GRAD@MERGE
if op.type == "c_allreduce_sum" and \
op.attr('use_model_parallel') is False:
assert (len(output_names) == 1)
if (
op.type == "c_allreduce_sum"
and op.attr('use_model_parallel') is False
):
assert len(output_names) == 1
output_name = output_names[0]
reduced_grads.append(output_name)
# prune optimizer state and param
pruned_opti_vars = []
for var_name in list(block.vars.keys()):
if shard.is_opti_var(var_name) and \
not shard.has_opt_var(var_name):
if shard.is_opti_var(var_name) and not shard.has_opt_var(var_name):
pruned_opti_vars.append(var_name)
program_deps = ProgramDeps(block, reduced_grads, pruned_opti_vars)
......@@ -996,17 +1118,17 @@ class ShardingOptimizer(MetaOptimizerBase):
# Prune
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in [
"c_allreduce_sum",
"c_sync_comm_stream",
"c_calc_comm_stream",
"c_gen_nccl_id",
"c_comm_init",
'send_v2',
'recv_v2',
"c_allreduce_sum",
"c_sync_comm_stream",
"c_calc_comm_stream",
"c_gen_nccl_id",
"c_comm_init",
'send_v2',
'recv_v2',
]:
pass
elif op.type == "conditional_block":
assert (op.desc.has_attr("sub_block"))
assert op.desc.has_attr("sub_block")
subblock_idx = op.desc.attr("sub_block").id
subblock_deps = program_deps.get_sub_block_deps(subblock_idx)
# only prune amp subblock
......@@ -1022,7 +1144,8 @@ class ShardingOptimizer(MetaOptimizerBase):
reversed_output_vars.append(output_name)
# prune
for sub_op_idx, _ in reversed(
list(enumerate(subblock_deps._block.ops))):
list(enumerate(subblock_deps._block.ops))
):
if subblock_deps.should_remove_op(sub_op_idx):
subblock_deps.remove_op(sub_op_idx)
reversed_input_vars = []
......@@ -1038,7 +1161,9 @@ class ShardingOptimizer(MetaOptimizerBase):
# _should_removed_var: opt state not cur shard
if program_deps.should_remove_op(idx):
# NOTE(wangxi): need reserve all param in optimizer_sharding
reserved_vars = self._params if self._optimizer_sharding else None
reserved_vars = (
self._params if self._optimizer_sharding else None
)
program_deps.remove_op(idx, reserved_vars)
# NOTE (JZ-LIANG) revise and unify logic here
......@@ -1049,7 +1174,8 @@ class ShardingOptimizer(MetaOptimizerBase):
# remove inputs that not on this card
reserved_x = []
for var_name in op.desc.input("X"):
if block.has_var(var_name): reserved_x.append(var_name)
if block.has_var(var_name):
reserved_x.append(var_name)
op.desc.set_input('X', reserved_x)
block._sync_with_cpp()
return
......@@ -1059,7 +1185,7 @@ class ShardingOptimizer(MetaOptimizerBase):
add broadcast allreduce op
if enable gradient_merge, insert related ops
if combined with pipeline(grad accumulate),
if combined with pipeline(grad accumulate),
the grad allreduce should be done in optimize role
"""
if len(self._segments) < 1:
......@@ -1072,175 +1198,280 @@ class ShardingOptimizer(MetaOptimizerBase):
# NOTE (JZ-LIANG) revise and unify logic here
# fix the _end_idx for segments[-1] if pp is used.
new_end_idx = self._segments[-1]._end_idx
for idx in range(self._segments[-1]._end_idx - 1,
self._segments[-1]._start_idx - 1, -1):
for idx in range(
self._segments[-1]._end_idx - 1,
self._segments[-1]._start_idx - 1,
-1,
):
op = block.ops[idx]
if op.type == "fill_constant" or op.type == "sum":
if "MERGED" in op.output_arg_names[0]: new_end_idx = idx + 1
if "MERGED" in op.output_arg_names[0]:
new_end_idx = idx + 1
elif op.type == "cast":
if "@TMP" in op.output_arg_names[0]: new_end_idx = idx + 1
if "@TMP" in op.output_arg_names[0]:
new_end_idx = idx + 1
self._segments[-1]._end_idx = new_end_idx
if self._segments[-1]._allreduce_vars:
shard_allredue_vars = self._shard.filter_grads(
self._segments[-1]._allreduce_vars)
if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1:
if self.hybrid_dp and self.hybrid_dp_mode == "sharding_hybrid_dp" and len(
shard_allredue_vars) >= 1:
insert_sync_comm_ops(block, self._segments[-1]._end_idx,
self.dp_ring_id, shard_allredue_vars)
self._segments[-1]._allreduce_vars
)
if (
self.gradient_merge_mode != "sharding_gm"
or self._gradient_merge_acc_step <= 1
):
if (
self.hybrid_dp
and self.hybrid_dp_mode == "sharding_hybrid_dp"
and len(shard_allredue_vars) >= 1
):
insert_sync_comm_ops(
block,
self._segments[-1]._end_idx,
self.dp_ring_id,
shard_allredue_vars,
)
insert_allreduce_ops(
block,
self._segments[-1]._end_idx,
self.dp_ring_id,
shard_allredue_vars,
user_defined_strategy=self.user_defined_strategy)
user_defined_strategy=self.user_defined_strategy,
)
# gradient merge
elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1:
elif (
self.gradient_merge_mode == "sharding_gm"
and self._gradient_merge_acc_step > 1
):
self.create_persistable_gradients_and_insert_merge_ops(
block, self._startup_program.global_block(),
self._segments[-1]._end_idx, shard_allredue_vars,
self._shard)
insert_sync_comm_ops(block, self._segments[-1]._end_idx,
self.sharding_ring_id,
self._segments[-1]._allreduce_vars)
block,
self._startup_program.global_block(),
self._segments[-1]._end_idx,
shard_allredue_vars,
self._shard,
)
insert_sync_comm_ops(
block,
self._segments[-1]._end_idx,
self.sharding_ring_id,
self._segments[-1]._allreduce_vars,
)
# allreduce --> reduce
insert_reduce_ops(block,
self._segments[-1]._end_idx,
self.sharding_ring_id,
self._segments[-1]._allreduce_vars,
self._shard,
op_role=OpRole.Backward,
use_calc_stream=False)
insert_reduce_ops(
block,
self._segments[-1]._end_idx,
self.sharding_ring_id,
self._segments[-1]._allreduce_vars,
self._shard,
op_role=OpRole.Backward,
use_calc_stream=False,
)
for idx, segment in reversed(list(enumerate(self._segments))):
allreduce_vars = self._segments[
idx - 1]._allreduce_vars if idx > 0 else []
broadcast_vars = self._segments[
idx +
1]._broadcast_vars if idx < len(self._segments) - 1 else []
fill_constant_vars = self._segments[
idx +
2]._fill_constant_vars if idx < len(self._segments) - 2 else []
cast_ops = self._segments[
idx + 2]._cast_ops if idx < len(self._segments) - 2 else {}
allreduce_vars = (
self._segments[idx - 1]._allreduce_vars if idx > 0 else []
)
broadcast_vars = (
self._segments[idx + 1]._broadcast_vars
if idx < len(self._segments) - 1
else []
)
fill_constant_vars = (
self._segments[idx + 2]._fill_constant_vars
if idx < len(self._segments) - 2
else []
)
cast_ops = (
self._segments[idx + 2]._cast_ops
if idx < len(self._segments) - 2
else {}
)
for op_idx in reversed(range(segment._start_idx, segment._end_idx)):
op = block.ops[op_idx]
for input_name in op.desc.input_arg_names():
if input_name in segment._param2broadcast and \
input_name != segment._param2broadcast[input_name]:
op._rename_input(input_name,
segment._param2broadcast[input_name])
if (
input_name in segment._param2broadcast
and input_name != segment._param2broadcast[input_name]
):
op._rename_input(
input_name, segment._param2broadcast[input_name]
)
for param_name, broadcast_name in segment._param2broadcast.items():
if param_name != broadcast_name:
block.create_var(
name=broadcast_name,
shape=self._main_program.global_block().var(
param_name).shape,
dtype=self._main_program.global_block().var(
param_name).dtype,
persistable=False)
shape=self._main_program.global_block()
.var(param_name)
.shape,
dtype=self._main_program.global_block()
.var(param_name)
.dtype,
persistable=False,
)
# step1: remove cast ops
block._sync_with_cpp()
segment._end_idx += FP16Utils.remove_cast_op(
block, self._params, segment, 0)
block, self._params, segment, 0
)
# step2: add Sync ops
shard_allredue_vars = self._shard.filter_grads(allreduce_vars)
if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1:
if self.hybrid_dp and self.hybrid_dp_mode == "sharding_hybrid_dp" and len(
shard_allredue_vars) >= 1:
insert_sync_comm_ops(block, segment._end_idx,
self.dp_ring_id, shard_allredue_vars)
if (
self.gradient_merge_mode != "sharding_gm"
or self._gradient_merge_acc_step <= 1
):
if (
self.hybrid_dp
and self.hybrid_dp_mode == "sharding_hybrid_dp"
and len(shard_allredue_vars) >= 1
):
insert_sync_comm_ops(
block,
segment._end_idx,
self.dp_ring_id,
shard_allredue_vars,
)
broad_cast_vars = [x[0] for x in broadcast_vars]
if len(broad_cast_vars) > 0:
insert_sync_comm_ops(block, segment._end_idx,
self.sharding_ring_id,
broad_cast_vars)
insert_sync_comm_ops(
block,
segment._end_idx,
self.sharding_ring_id,
broad_cast_vars,
)
else:
comm_dep_vars = allreduce_vars + [
x[0] for x in broadcast_vars
]
if len(comm_dep_vars) > 0:
insert_sync_comm_ops(block, segment._end_idx,
self.sharding_ring_id,
comm_dep_vars)
insert_sync_comm_ops(
block,
segment._end_idx,
self.sharding_ring_id,
comm_dep_vars,
)
# gradient merge
elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1:
elif (
self.gradient_merge_mode == "sharding_gm"
and self._gradient_merge_acc_step > 1
):
broad_cast_vars = [x[0] for x in broadcast_vars]
if len(broad_cast_vars) > 0:
insert_sync_comm_ops(block, segment._end_idx,
self.sharding_ring_id, broad_cast_vars)
insert_sync_comm_ops(
block,
segment._end_idx,
self.sharding_ring_id,
broad_cast_vars,
)
calc_dep_vars = fill_constant_vars + [
k for k, v in cast_ops.items()
] + self._segments[idx]._allreduce_vars
calc_dep_vars = (
fill_constant_vars
+ [k for k, v in cast_ops.items()]
+ self._segments[idx]._allreduce_vars
)
if len(calc_dep_vars) > 0:
insert_sync_calc_op(block, segment._end_idx,
[calc_dep_vars[-1]])
insert_sync_calc_op(
block, segment._end_idx, [calc_dep_vars[-1]]
)
# step3: insert `fill_constant` ops
insert_fill_constant_ops(block, segment._end_idx,
fill_constant_vars)
insert_fill_constant_ops(
block, segment._end_idx, fill_constant_vars
)
# step4: add `cast` ops
insert_cast_ops(block, segment._end_idx, cast_ops)
# step5: add broadcast ops
# gradient merge
if self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1:
if (
self.gradient_merge_mode == "sharding_gm"
and self._gradient_merge_acc_step > 1
):
self.create_persistable_gradients_and_insert_merge_ops(
block, self._startup_program.global_block(),
segment._start_idx, shard_allredue_vars, self._shard)
block,
self._startup_program.global_block(),
segment._start_idx,
shard_allredue_vars,
self._shard,
)
insert_broadcast_ops(block, segment._start_idx,
self.sharding_ring_id, broadcast_vars)
insert_broadcast_ops(
block, segment._start_idx, self.sharding_ring_id, broadcast_vars
)
# step6: add all_reduce ops
# dp
if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1:
if self.hybrid_dp and self.hybrid_dp_mode == "sharding_hybrid_dp" and len(
shard_allredue_vars) >= 1:
if (
self.gradient_merge_mode != "sharding_gm"
or self._gradient_merge_acc_step <= 1
):
if (
self.hybrid_dp
and self.hybrid_dp_mode == "sharding_hybrid_dp"
and len(shard_allredue_vars) >= 1
):
insert_allreduce_ops(
block,
segment._start_idx,
self.dp_ring_id,
shard_allredue_vars,
user_defined_strategy=self.user_defined_strategy)
insert_sync_comm_ops(block, segment._start_idx,
self.sharding_ring_id, allreduce_vars)
user_defined_strategy=self.user_defined_strategy,
)
insert_sync_comm_ops(
block,
segment._start_idx,
self.sharding_ring_id,
allreduce_vars,
)
# gradient merge
elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1:
insert_sync_comm_ops(block, segment._start_idx,
self.sharding_ring_id, allreduce_vars)
elif (
self.gradient_merge_mode == "sharding_gm"
and self._gradient_merge_acc_step > 1
):
insert_sync_comm_ops(
block,
segment._start_idx,
self.sharding_ring_id,
allreduce_vars,
)
# sharding
# allreduce --> reduce
# TODO temp change
if len(allreduce_vars) > 0:
insert_reduce_ops(block,
segment._start_idx,
self.sharding_ring_id,
allreduce_vars,
self._shard,
op_role=OpRole.Backward,
use_calc_stream=False)
insert_reduce_ops(
block,
segment._start_idx,
self.sharding_ring_id,
allreduce_vars,
self._shard,
op_role=OpRole.Backward,
use_calc_stream=False,
)
block._sync_with_cpp()
if self._segments[0]._broadcast_vars:
broadcast_vars = [x[0] for x in self._segments[0]._broadcast_vars]
insert_sync_comm_ops(block, self._segments[0]._start_idx,
self.sharding_ring_id, broadcast_vars)
insert_broadcast_ops(block, self._segments[0]._start_idx,
self.sharding_ring_id,
self._segments[0]._broadcast_vars)
insert_sync_comm_ops(
block,
self._segments[0]._start_idx,
self.sharding_ring_id,
broadcast_vars,
)
insert_broadcast_ops(
block,
self._segments[0]._start_idx,
self.sharding_ring_id,
self._segments[0]._broadcast_vars,
)
fill_constant_vars = []
for x in self._segments[:2]:
......@@ -1254,12 +1485,14 @@ class ShardingOptimizer(MetaOptimizerBase):
calc_deps_vars = fill_constant_vars + [k for k, v in cast_ops.items()]
if fill_constant_vars or cast_ops:
insert_sync_calc_op(block, self._segments[0]._start_idx,
[calc_deps_vars[-1]])
insert_sync_calc_op(
block, self._segments[0]._start_idx, [calc_deps_vars[-1]]
)
if fill_constant_vars:
insert_fill_constant_ops(block, self._segments[0]._start_idx,
fill_constant_vars)
insert_fill_constant_ops(
block, self._segments[0]._start_idx, fill_constant_vars
)
if cast_ops:
insert_cast_ops(block, self._segments[0]._start_idx, cast_ops)
......@@ -1273,7 +1506,7 @@ class ShardingOptimizer(MetaOptimizerBase):
continue
if self._optimizer_sharding and shard.is_param(output_name):
continue
#TODO why do we remove op, when only one var is removed
# TODO why do we remove op, when only one var is removed
block._remove_op(idx, sync=False)
break
......@@ -1295,23 +1528,36 @@ class ShardingOptimizer(MetaOptimizerBase):
pp: 4
pp-pair: >= 20
if one parallelism is not enabled: -1
and only support parallelism hierarchy: mp --> sharding --> pp --> dp
and only support parallelism hierarchy: mp --> sharding --> pp --> dp
"""
# step 1: initialize nccl
self.global_word_size = self.role_maker._worker_num()
self.global_rank = self.role_maker._worker_index()
self.global_endpoints = self.role_maker._get_trainer_endpoints()
self.current_endpoint = self.global_endpoints[self.global_rank]
self._collective_helper = CollectiveHelper(self.role_maker,
nrings=self._nrings_sharding)
assert self.global_word_size % self.mp_degree == 0, \
"global_word_size: {} should be divisible to the mp_degree: {}".format(self.global_word_size, self.mp_degree)
assert self.global_word_size % self.sharding_degree == 0, \
"global_word_size: {} should be divisible to the sharding_degree: {}".format(self.global_word_size, self.sharding_degree)
assert self.global_word_size % self.pp_degree == 0, \
"global_word_size: {} should be divisible to the pp_degree: {}".format(self.global_word_size, self.pp_degree)
assert self.global_word_size % self.dp_degree == 0, \
"global_word_size: {} should be divisible to the dp_degree: {}".format(self.global_word_size, self.dp_degree)
self._collective_helper = CollectiveHelper(
self.role_maker, nrings=self._nrings_sharding
)
assert (
self.global_word_size % self.mp_degree == 0
), "global_word_size: {} should be divisible to the mp_degree: {}".format(
self.global_word_size, self.mp_degree
)
assert (
self.global_word_size % self.sharding_degree == 0
), "global_word_size: {} should be divisible to the sharding_degree: {}".format(
self.global_word_size, self.sharding_degree
)
assert (
self.global_word_size % self.pp_degree == 0
), "global_word_size: {} should be divisible to the pp_degree: {}".format(
self.global_word_size, self.pp_degree
)
assert (
self.global_word_size % self.dp_degree == 0
), "global_word_size: {} should be divisible to the dp_degree: {}".format(
self.global_word_size, self.dp_degree
)
# mp group
if self.mp_degree > 1:
......@@ -1319,14 +1565,16 @@ class ShardingOptimizer(MetaOptimizerBase):
self.mp_rank = self.global_rank % self.mp_degree
self.mp_group_id = self.global_rank // self.mp_degree
self.mp_group_endpoints = [
ep for idx, ep in enumerate(self.global_endpoints)
ep
for idx, ep in enumerate(self.global_endpoints)
if idx // self.mp_degree == self.mp_group_id
]
assert self.current_endpoint in self.mp_group_endpoints
assert len(
self.mp_group_endpoints
) == self.mp_degree, "num of mp worker in group is [{}], but mp group size is [{}]".format(
len(self.mp_group_endpoints), self.mp_degree)
assert (
len(self.mp_group_endpoints) == self.mp_degree
), "num of mp worker in group is [{}], but mp group size is [{}]".format(
len(self.mp_group_endpoints), self.mp_degree
)
else:
self.mp_degree = 1
self.mp_ring_id = -1
......@@ -1337,23 +1585,28 @@ class ShardingOptimizer(MetaOptimizerBase):
# sharding
if self.sharding_degree > 1:
self.sharding_ring_id = 1
self.sharding_rank = (self.global_rank //
self.mp_degree) % self.sharding_degree
self.sharding_group_id = self.global_rank // (self.mp_degree *
self.sharding_degree)
self.sharding_rank = (
self.global_rank // self.mp_degree
) % self.sharding_degree
self.sharding_group_id = self.global_rank // (
self.mp_degree * self.sharding_degree
)
# mp + sharding + ...
if self.mp_degree > 1:
self.sharding_group_endpoints = [
ep for idx, ep in enumerate(self.global_endpoints)
if (idx // (self.mp_degree * self.sharding_degree)) == self.
sharding_group_id and idx % self.mp_degree == self.mp_rank
ep
for idx, ep in enumerate(self.global_endpoints)
if (idx // (self.mp_degree * self.sharding_degree))
== self.sharding_group_id
and idx % self.mp_degree == self.mp_rank
]
# sharding + ...
else:
self.sharding_group_endpoints = [
ep for idx, ep in enumerate(self.global_endpoints)
if (idx // (self.mp_degree * self.sharding_degree)
) == self.sharding_group_id
ep
for idx, ep in enumerate(self.global_endpoints)
if (idx // (self.mp_degree * self.sharding_degree))
== self.sharding_group_id
]
assert self.current_endpoint in self.sharding_group_endpoints
else:
......@@ -1368,20 +1621,28 @@ class ShardingOptimizer(MetaOptimizerBase):
self.pp_pair_ring_id = 20
# pipeline global ring_id set to 4 for sharding0, mp1, dp2, global3
self.pp_ring_id = 4
self.pp_rank = self.global_rank // (self.sharding_degree *
self.mp_degree) % self.pp_degree
self.pp_rank = (
self.global_rank
// (self.sharding_degree * self.mp_degree)
% self.pp_degree
)
# (NOTE): Already adjusted for (outer-pure) dp
self.pp_group_id = self.global_rank // (
self.mp_degree * self.sharding_degree * self.pp_degree)
self.mp_degree * self.sharding_degree * self.pp_degree
)
pp_first_stage_idx = self.global_rank % (
self.sharding_degree * self.mp_degree) + self.pp_group_id * (
self.mp_degree * self.sharding_degree * self.pp_degree)
self.sharding_degree * self.mp_degree
) + self.pp_group_id * (
self.mp_degree * self.sharding_degree * self.pp_degree
)
pp_stage_offset = self.sharding_degree * self.mp_degree
self.pp_group_endpoints = []
for i in range(self.pp_degree):
self.pp_group_endpoints.append(
self.global_endpoints[pp_first_stage_idx +
pp_stage_offset * i])
self.global_endpoints[
pp_first_stage_idx + pp_stage_offset * i
]
)
assert self.current_endpoint in self.pp_group_endpoints
else:
self.pp_ring_id = -1
......@@ -1397,29 +1658,48 @@ class ShardingOptimizer(MetaOptimizerBase):
# sharding-hybrid-dp as one scenario of outer-pure-dp
local_pp_degree = self.pp_degree
if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None):
assert self.pp_degree == 2, ("For manually set pipeline, only "
"pp_degree = 2 is supported.")
assert self.global_word_size == self.mp_degree * self.sharding_degree * self.dp_degree, \
"global work size [{}], mp_degree [{}], sharding_degree [{}], dp_degree [{}].".format(
self.global_word_size, self.mp_degree, self.sharding_degree, self.dp_degree)
assert self.pp_degree == 2, (
"For manually set pipeline, only " "pp_degree = 2 is supported."
)
assert (
self.global_word_size
== self.mp_degree * self.sharding_degree * self.dp_degree
), "global work size [{}], mp_degree [{}], sharding_degree [{}], dp_degree [{}].".format(
self.global_word_size,
self.mp_degree,
self.sharding_degree,
self.dp_degree,
)
local_pp_degree = 1
else:
assert self.global_word_size == self.mp_degree * self.sharding_degree * self.pp_degree * self.dp_degree, "mp_degree: [{}], sharding_degree: [{}], pp_degree: [{}], dp_degree: [{}]; BUT global nrank: [{}]".format(
self.mp_degree, self.sharding_degree, self.pp_degree,
self.dp_degree, self.global_word_size)
assert (
self.global_word_size
== self.mp_degree
* self.sharding_degree
* self.pp_degree
* self.dp_degree
), "mp_degree: [{}], sharding_degree: [{}], pp_degree: [{}], dp_degree: [{}]; BUT global nrank: [{}]".format(
self.mp_degree,
self.sharding_degree,
self.pp_degree,
self.dp_degree,
self.global_word_size,
)
if self.dp_degree > 1:
self.dp_ring_id = 2
self.dp_rank = self.global_rank // (
self.sharding_degree * self.mp_degree * local_pp_degree)
self.sharding_degree * self.mp_degree * local_pp_degree
)
dp_first_rank_idx = self.global_rank % (
self.sharding_degree * self.mp_degree * local_pp_degree)
dp_offset = (self.sharding_degree * self.mp_degree *
local_pp_degree)
self.sharding_degree * self.mp_degree * local_pp_degree
)
dp_offset = self.sharding_degree * self.mp_degree * local_pp_degree
self.dp_group_endpoints = []
for i in range(self.dp_degree):
self.dp_group_endpoints.append(
self.global_endpoints[dp_first_rank_idx + dp_offset * i])
self.global_endpoints[dp_first_rank_idx + dp_offset * i]
)
assert self.current_endpoint in self.dp_group_endpoints
logger.info("Hybrid DP mode turn on !")
else:
......@@ -1448,8 +1728,9 @@ class ShardingOptimizer(MetaOptimizerBase):
logger.info("sharding group size: {}".format(self.sharding_degree))
logger.info("sharding rank: {}".format(self.sharding_rank))
logger.info("sharding group id: {}".format(self.sharding_group_id))
logger.info("sharding group endpoints: {}".format(
self.sharding_group_endpoints))
logger.info(
"sharding group endpoints: {}".format(self.sharding_group_endpoints)
)
logger.info("sharding ring id: {}".format(self.sharding_ring_id))
logger.info("#####" * 6)
......@@ -1462,15 +1743,15 @@ class ShardingOptimizer(MetaOptimizerBase):
logger.info("pure dp group size: {}".format(self.dp_degree))
logger.info("pure dp rank: {}".format(self.dp_rank))
logger.info("pure dp group endpoints: {}".format(
self.dp_group_endpoints))
logger.info(
"pure dp group endpoints: {}".format(self.dp_group_endpoints)
)
logger.info("pure dp ring id: {}".format(self.dp_ring_id))
logger.info("#####" * 6)
return
def _recreate_not_persist_param_as_var(self):
def recreate_not_persist_param_as_var(program):
block = program.global_block()
params = block.all_parameters()
......@@ -1494,14 +1775,16 @@ class ShardingOptimizer(MetaOptimizerBase):
is_distributed = param.is_distributed
block._remove_var(name, sync=False)
var = block.create_var(name=name,
shape=shape,
dtype=dtype,
type=type,
lod_level=lod_level,
stop_gradient=stop_gradient,
trainable=trainable,
persistable=False)
var = block.create_var(
name=name,
shape=shape,
dtype=dtype,
type=type,
lod_level=lod_level,
stop_gradient=stop_gradient,
trainable=trainable,
persistable=False,
)
if have_dist_attr:
var.is_distributed = is_distributed
......@@ -1516,6 +1799,13 @@ class ShardingOptimizer(MetaOptimizerBase):
identical when hybrid-dp is used, and the initialization of
non-distributed params across the mp group to be identical.
"""
def _find_master_param(all_vars_name, param_name):
for var_name in all_vars_name:
if param_name in var_name and "fp32_master" in var_name:
return var_name
return None
if self.dp_degree <= 1 and self.mp_degree <= 1:
return
......@@ -1536,8 +1826,10 @@ class ShardingOptimizer(MetaOptimizerBase):
if op.type == 'c_broadcast':
broadcast_params.add(op.desc.output_arg_names()[0])
all_vars_name = startup_block.vars
for param in params_name:
if param in broadcast_params: continue
if param in broadcast_params:
continue
rings = []
# need sync not distributed param in mp group
......@@ -1547,30 +1839,51 @@ class ShardingOptimizer(MetaOptimizerBase):
rings.append(self.dp_ring_id)
for ring in rings:
startup_block.append_op(type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': ring,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward
})
startup_block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': ring,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward,
},
)
# Broadcast the master weight at the same time for AMP-O2 training.
master_param = _find_master_param(all_vars_name, param)
if master_param is not None:
startup_block.append_op(
type='c_broadcast',
inputs={'X': master_param},
outputs={'Out': master_param},
attrs={
'ring_id': ring,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward,
},
)
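# Note: without this extra broadcast only the parameter itself is synchronized
# at startup; under AMP-O2 each rank also holds an fp32 master copy, and if
# those copies start out different across the dp/mp ranks the replicas can
# drift even though the parameters were initialized identically.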
startup_block._sync_with_cpp()
# sharding gradient merge
def create_persistable_gradients_and_insert_merge_ops(
self, main_block, startup_block, insert_idx, grad_names, shard):
self, main_block, startup_block, insert_idx, grad_names, shard
):
for grad_name in grad_names:
assert get_grad_device(
grad_name, shard
) == shard.worker_idx, "try to merge gradient not belong to current shard: [{}]".format(
grad_name)
assert (
get_grad_device(grad_name, shard) == shard.worker_idx
), "try to merge gradient not belong to current shard: [{}]".format(
grad_name
)
persistable_grad_name = grad_name + '@GradiantMerge'
assert grad_name not in self._grad2merged_grad, "grad [{}] already in grad2merged_grad, maybe you meet sharing weight case !".format(
grad_name)
assert (
grad_name not in self._grad2merged_grad
), "grad [{}] already in grad2merged_grad, maybe you meet sharing weight case !".format(
grad_name
)
self._grad2merged_grad[grad_name] = persistable_grad_name
grad_var = main_block.var(grad_name)
# create var
......@@ -1578,36 +1891,38 @@ class ShardingOptimizer(MetaOptimizerBase):
name=persistable_grad_name,
shape=grad_var.shape,
dtype=grad_var.dtype,
persistable=True)
persistable=True,
)
startup_gradient_merge_var = startup_block.create_var(
name=persistable_grad_name,
shape=grad_var.shape,
dtype=grad_var.dtype,
persistable=True)
persistable=True,
)
# merge gradient
main_block._insert_op_without_sync(
insert_idx,
type="elementwise_add",
inputs={
'X': grad_name,
'Y': gradient_merge_var
},
inputs={'X': grad_name, 'Y': gradient_merge_var},
outputs={'Out': gradient_merge_var},
attrs={
'axis': -1,
'use_mkldnn': False,
OP_ROLE_KEY: OpRole.Backward
})
OP_ROLE_KEY: OpRole.Backward,
},
)
# startup initialization
startup_block.append_op(type="fill_constant",
outputs={"Out": startup_gradient_merge_var},
attrs={
"shape": grad_var.shape,
"dtype": grad_var.dtype,
"value": float(0),
})
startup_block.append_op(
type="fill_constant",
outputs={"Out": startup_gradient_merge_var},
attrs={
"shape": grad_var.shape,
"dtype": grad_var.dtype,
"value": float(0),
},
)
main_block._sync_with_cpp()
startup_block._sync_with_cpp()
......@@ -1620,14 +1935,17 @@ class ShardingOptimizer(MetaOptimizerBase):
value=int(self._gradient_merge_acc_step),
dtype='int32',
persistable=True,
force_cpu=True)
force_cpu=True,
)
zero_var = layers.create_global_var(name="gradient_merge_zero",
shape=[1],
value=int(0),
dtype='int32',
persistable=True,
force_cpu=True)
zero_var = layers.create_global_var(
name="gradient_merge_zero",
shape=[1],
value=int(0),
dtype='int32',
persistable=True,
force_cpu=True,
)
# Add step var & cond var
current_step_var = layers.create_global_var(
......@@ -1636,42 +1954,40 @@ class ShardingOptimizer(MetaOptimizerBase):
value=int(0),
dtype='int32',
persistable=True,
force_cpu=True)
force_cpu=True,
)
cond_var = main_block.create_var(name="gradient_merge_cond",
shape=[1],
dtype='bool')
cond_var = main_block.create_var(
name="gradient_merge_cond", shape=[1], dtype='bool'
)
with device_guard("cpu"):
# step_var = (step_var + 1) % k_step
main_block.append_op(type='increment',
inputs={'X': [current_step_var]},
outputs={'Out': [current_step_var]},
attrs={
'step': float(1),
OP_ROLE_KEY: OpRole.Optimize
})
main_block.append_op(type='elementwise_mod',
inputs={
'X': current_step_var,
'Y': acc_step_var
},
outputs={'Out': current_step_var},
attrs={
'axis': -1,
OP_ROLE_KEY: OpRole.Optimize,
'use_mkldnn': False
})
main_block.append_op(
type='increment',
inputs={'X': [current_step_var]},
outputs={'Out': [current_step_var]},
attrs={'step': float(1), OP_ROLE_KEY: OpRole.Optimize},
)
main_block.append_op(
type='elementwise_mod',
inputs={'X': current_step_var, 'Y': acc_step_var},
outputs={'Out': current_step_var},
attrs={
'axis': -1,
OP_ROLE_KEY: OpRole.Optimize,
'use_mkldnn': False,
},
)
# cond_var = (step_var == 0)
main_block.append_op(type='equal',
inputs={
'X': current_step_var,
'Y': zero_var
},
outputs={'Out': cond_var},
attrs={OP_ROLE_KEY: OpRole.Optimize})
main_block.append_op(
type='equal',
inputs={'X': current_step_var, 'Y': zero_var},
outputs={'Out': cond_var},
attrs={OP_ROLE_KEY: OpRole.Optimize},
)
# paddle.static.Print(current_step_var, message="in FWBW last conditional")
return cond_var
......@@ -1681,7 +1997,7 @@ class ShardingOptimizer(MetaOptimizerBase):
grad@gradientmerge / acc_step
re-create all optimize ops of origin main block and rename them
cast(backward)
amp
amp
clip
opt
# fill constant grad@gradientmerge
......@@ -1698,35 +2014,37 @@ class ShardingOptimizer(MetaOptimizerBase):
# allreduce grad@gradientmerge
if self.hybrid_dp:
assert self.dp_ring_id >= 0, "dp_ring_id should larger than 0 when in sharding&DP mode"
assert (
self.dp_ring_id >= 0
), "dp_ring_id should larger than 0 when in sharding&DP mode"
for grad, merged_grad in self._grad2merged_grad.items():
merged_grad_var = main_block.var(merged_grad)
cur_block.append_op(type='c_allreduce_sum',
inputs={'X': merged_grad_var},
outputs={'Out': merged_grad_var},
attrs={
'ring_id': self.dp_ring_id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize
})
cur_block.append_op(
type='c_allreduce_sum',
inputs={'X': merged_grad_var},
outputs={'Out': merged_grad_var},
attrs={
'ring_id': self.dp_ring_id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize,
},
)
# grad@gradientmerge / acc_step
for grad, merged_grad in self._grad2merged_grad.items():
# grad /= k_steps
merged_grad_var = main_block.var(merged_grad)
cur_block.append_op(type='scale',
inputs={'X': merged_grad_var},
outputs={'Out': merged_grad_var},
attrs={
'scale':
1.0 / float(self._gradient_merge_acc_step),
'bias':
0.0,
'bias_after_scale':
False,
OP_ROLE_KEY:
OpRole.Optimize
})
cur_block.append_op(
type='scale',
inputs={'X': merged_grad_var},
outputs={'Out': merged_grad_var},
attrs={
'scale': 1.0 / float(self._gradient_merge_acc_step),
'bias': 0.0,
'bias_after_scale': False,
OP_ROLE_KEY: OpRole.Optimize,
},
)
# re-create optimize ops
already_moved_var_names = []
......@@ -1737,15 +2055,19 @@ class ShardingOptimizer(MetaOptimizerBase):
for input_name in new_op_desc.input_arg_names():
if input_name in self._grad2merged_grad:
new_op_desc._rename_input(
input_name, self._grad2merged_grad[input_name])
input_name, self._grad2merged_grad[input_name]
)
for output_name in new_op_desc.output_arg_names():
if output_name in self._grad2merged_grad:
new_op_desc._rename_output(
output_name, self._grad2merged_grad[output_name])
output_name, self._grad2merged_grad[output_name]
)
# move non temp optimize vars from block0 to cond block
if output_name not in already_moved_var_names and output_name not in self._grad2merged_grad.keys(
if (
output_name not in already_moved_var_names
and output_name not in self._grad2merged_grad.keys()
):
var_ = self._main_program.global_block().var(output_name)
if not var_.persistable:
......@@ -1754,11 +2076,14 @@ class ShardingOptimizer(MetaOptimizerBase):
shape_ = var_.shape
type_ = var_.dtype
self._main_program.global_block()._remove_var(
var_.name, sync=False)
self.cond_block.create_var(name=name_,
shape=shape_,
dtype=type_,
persistable=False)
var_.name, sync=False
)
self.cond_block.create_var(
name=name_,
shape=shape_,
dtype=type_,
persistable=False,
)
already_moved_var_names.append(name_)
self._main_program.global_block()._sync_with_cpp()
......@@ -1767,14 +2092,16 @@ class ShardingOptimizer(MetaOptimizerBase):
# fill zero to grad@gradientmerge
for grad, merged_grad in self._grad2merged_grad.items():
merged_grad_var = main_block.var(merged_grad)
cur_block.append_op(type='fill_constant',
outputs={'Out': merged_grad_var},
attrs={
"shape": merged_grad_var.shape,
"dtype": merged_grad_var.dtype,
"value": float(0),
OP_ROLE_KEY: OpRole.Optimize
})
cur_block.append_op(
type='fill_constant',
outputs={'Out': merged_grad_var},
attrs={
"shape": merged_grad_var.shape,
"dtype": merged_grad_var.dtype,
"value": float(0),
OP_ROLE_KEY: OpRole.Optimize,
},
)
# lr_var = main_block.var("gradient_merge_current_step")
# paddle.static.Print(lr_var, message="in OPTIMIZE last conditional")
......@@ -1786,7 +2113,10 @@ class ShardingOptimizer(MetaOptimizerBase):
create cond block
"""
if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1:
if (
self.gradient_merge_mode != "sharding_gm"
or self._gradient_merge_acc_step <= 1
):
return
main_block = self._main_program.global_block()
......@@ -1805,7 +2135,8 @@ class ShardingOptimizer(MetaOptimizerBase):
main_block._remove_op(op_idx, sync=False)
tmp_copy_block._sync_with_cpp()
self.original_optimize_ops_desc = list(
reversed(self.original_optimize_ops_desc))
reversed(self.original_optimize_ops_desc)
)
# back to block 0
self._main_program._rollback()
......@@ -1822,18 +2153,17 @@ class ShardingOptimizer(MetaOptimizerBase):
# cond op
step_scope = self._main_program.global_block().create_var(
type=core.VarDesc.VarType.STEP_SCOPES)
type=core.VarDesc.VarType.STEP_SCOPES
)
conditional_block_op = self._main_program.global_block().append_op(
type='conditional_block',
inputs={
'Cond': cond,
'Input': [],
},
outputs={
'Out': [],
'Scope': [step_scope]
},
outputs={'Out': [], 'Scope': [step_scope]},
attrs={
'sub_block': cond_block,
'is_scalar_condition': True,
})
},
)