Unverified commit d12588d2, authored by Yiqun Liu, committed by GitHub

Broadcast the master weight along with param for distributed training. (#52638)

* Broadcast the master weight along with param for distributed training.

* Fix codestyle.
Parent ba9a22db
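For orientation, a minimal sketch of the idea named in the commit title (not code from this diff, whose hunks in this excerpt are mostly the "Fix codestyle" reformatting of sharding_optimizer.py): when sharding broadcasts a parameter from the rank that owns it, the fp32 master weight kept by the AMP optimizer should be broadcast on the same ring, so every rank's optimizer state stays consistent. All names below are illustrative assumptions, not the actual Paddle helpers.

# Illustrative sketch only: broadcast_param_with_master_weight and the
# master_weights mapping are hypothetical names, not Paddle's implementation.
def broadcast_param_with_master_weight(block, param, master_weights, ring_id, root_rank):
    def _broadcast(var):
        block.append_op(
            type='c_broadcast',
            inputs={'X': var},
            outputs={'Out': var},
            attrs={'ring_id': ring_id, 'root': root_rank, 'use_calc_stream': True},
        )

    # broadcast the (possibly fp16) parameter from its owning rank
    _broadcast(param)
    # behavior described by the title: also broadcast the fp32 master weight, if any
    master = master_weights.get(param.name)
    if master is not None:
        _broadcast(master)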
@@ -28,6 +28,7 @@ from .sharding.gradient_clip_helper import GradientClipHelper
from .sharding.offload_helper import OffloadHelper
from .sharding.prune import ProgramDeps
from .sharding import utils

# FIXME: import *
from .sharding.utils import *

import logging
@@ -84,7 +85,7 @@ class ShardingOptimizer(MetaOptimizerBase):
        dist_strategy.sharding_configs = {"segment_broadcast_MB": 32}

    def _get_sharding_segment_strategy(self):
        """get
        self._sharding_segment_strategy
        1. if by_size: self._broadcast_MB
        2. if by_anchors: self._sharding_segment_anchors
@@ -97,21 +98,26 @@ class ShardingOptimizer(MetaOptimizerBase):
        if segment_strategy == "segment_broadcast_MB":
            self._broadcast_MB = sharding_configs["segment_broadcast_MB"]
            assert (
                self._broadcast_MB > 0
            ), "segment size should larger than zero !"
        elif segment_strategy == "segment_anchors":
            self._sharding_segment_anchors = sharding_configs["segment_anchors"]
            assert (
                len(self._sharding_segment_anchors) > 0
            ), "you should set the sharding segment anchors !"
            self._backward_remain_anchors = self._sharding_segment_anchors[:]
            self._forward_remain_anchors = []
        else:
            raise NotImplementedError(
                "the sharding segment strategy [{}] is not implemented".format(
                    str(segment_strategy)
                )
            )
        self._sharding_segment_strategy = segment_strategy

    def _get_hybrid_degree(self):
        """get
        self.hybrid_dp
        self.sharding_degree
        self.mp_degree
@@ -135,21 +141,32 @@ class ShardingOptimizer(MetaOptimizerBase):
            assert strategy.pipeline is True

        if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None):
            assert pp_degree == 2, (
                "For manually set pipeline, only " "pp_degree = 2 is supported."
            )
            assert (
                global_world_size == mp_degree * sharding_degree * dp_degree
            ), "global work size [{}], mp_degree [{}], sharding_degree [{}], dp_degree [{}].".format(
                global_world_size, mp_degree, sharding_degree, dp_degree
            )
        else:
            assert (
                global_world_size
                == mp_degree * sharding_degree * pp_degree * dp_degree
            ), "global work size [{}], mp_degree [{}], sharding_degree [{}], pp_degree [{}], dp_degree [{}].".format(
                global_world_size,
                mp_degree,
                sharding_degree,
                pp_degree,
                dp_degree,
            )

        # FIXME (JZ-LIANG) deprecated hybrid_dp
        if sharding_configs["hybrid_dp"]:
            logger.warning(
                "[hybrid_dp] API setting is deprecated. Now when "
                "dp_degree >= 2, its will be in hybrid dp mode automatically"
            )
            assert dp_degree >= 1

        self.hybrid_dp = True if dp_degree > 1 else False
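As a quick numeric check of the world-size assertion above (the degrees are made-up example values, not taken from this diff):

# Example only: 2-way tensor parallel, 4-way sharding, 2-way pipeline, 2-way data parallel.
mp_degree, sharding_degree, pp_degree, dp_degree = 2, 4, 2, 2
global_world_size = 32  # number of workers the launcher must start
assert global_world_size == mp_degree * sharding_degree * pp_degree * dp_degree  # 2 * 4 * 2 * 2 == 32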
@@ -159,7 +176,7 @@ class ShardingOptimizer(MetaOptimizerBase):
        self.dp_degree = dp_degree

    def _get_hybrid_dp_mode(self):
        """get
        self.hybrid_dp_mode = 'pp_hybrid_dp' or 'sharding_hybrid_dp'
        self.gradient_merge_mode = 'pp_gm' or 'sharding_gm'
        self._gradient_merge_acc_step
@@ -183,9 +200,10 @@ class ShardingOptimizer(MetaOptimizerBase):
            if self.pp_degree > 1:
                dp_mode = "pp_hybrid_dp"
            else:
                assert self.sharding_degree > 1, (
                    "by now we only support five kind of hybrid dp: sharding_hybrid_dp, "
                    "mp_sharding_hybrid_dp, pp_hybrid_dp, mp_sharding_pp_hybrid_dp, sharding_pp_hybrid_dp."
                )
                dp_mode = "sharding_hybrid_dp"

        # gradient merge
@@ -198,23 +216,33 @@ class ShardingOptimizer(MetaOptimizerBase):
            gm_mode = "pp_gm"
            gm_acc_step = strategy.pipeline_configs['accumulate_steps']
            gradient_scale_configs = strategy.gradient_scale_configs
            assert gradient_scale_configs['scale_strategy'] == 'avg', (
                'For pipeline mode, the '
                'gradient scale mode should '
                'be "avg", but got {}'.format(
                    gradient_scale_configs['scale_strategy']
                )
            )
            # Note (Yuang Liu): this avg_loss flag determines where to do the average op for grad merge.
            # If True, will do sum firstly for gradient merge, then do scale by gm_acc_step.
            # If False, will scale loss by gm_acc_step first, then do sum for gradient merge.
            self.scale_gradient = gradient_scale_configs['scale_gradient']

        if gm_acc_step > 1:
            logger.info(
                "Gradient merge in [{}], acc step = [{}]".format(
                    gm_mode, gm_acc_step
                )
            )

        optimizer_sharding = False
        # TODO(wangxi): need support dp_as_opt_sharding with sharding
        # need support without pp in future
        if (
            self.sharding_degree == 1
            and self.dp_degree > 1
            and sharding_configs['_dp_as_optimizer_sharding']
            and self.pp_degree > 1
        ):
            optimizer_sharding = True

        self.hybrid_dp_mode = dp_mode
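The scale_gradient note above relies on the identity (g_1 + ... + g_k) / k == g_1 / k + ... + g_k / k; a minimal numeric sketch with made-up gradients:

# Made-up per-micro-batch gradients for one parameter, gm_acc_step = 4.
gm_acc_step = 4
micro_grads = [0.1, 0.2, 0.3, 0.4]

sum_then_scale = sum(micro_grads) / gm_acc_step             # scale_gradient=True path
scale_then_sum = sum(g / gm_acc_step for g in micro_grads)  # scale_gradient=False path

assert abs(sum_then_scale - scale_then_sum) < 1e-12         # both equal 0.25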
@@ -224,19 +252,23 @@ class ShardingOptimizer(MetaOptimizerBase):
        # this feature is design for ascend, and should NOT be used in GPU training
        self.pp_allreduce_in_optimize = sharding_configs[
            "pp_allreduce_in_optimize"
        ]

    def _inner_opt_minimize(
        self, loss, startup_program, parameter_list, no_grad_set
    ):
        pipeline_configs = self.user_defined_strategy.pipeline_configs

        if self.inner_opt is None:
            raise ValueError(
                "self.inner_opt of ShardingOptimizer should not be None."
            )

        if self.pp_degree > 1:
            pp_optimizer = fluid.optimizer.PipelineOptimizer(
                self.inner_opt, self._gradient_merge_acc_step
            )
            self._pp_optimizer = pp_optimizer

            global_rank = self.role_maker._worker_index()
@@ -253,17 +285,25 @@ class ShardingOptimizer(MetaOptimizerBase):
                'global_ring_id': 3,
                'mp_degree': self.mp_degree,
                'mp_rank': global_rank % self.mp_degree,
                'scale_gradient': self.scale_gradient,
            }
            main_program = loss.block.program
            main_program._pipeline_opt = pipeline_opt
            (
                optimize_ops,
                params_grads,
                program_list,
                self.pipeline_pair,
                self.pp_ring_map,
            ) = pp_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set
            )
            assert self.pp_degree == len(program_list)
        else:
            optimize_ops, params_grads = self.inner_opt.minimize(
                loss, startup_program, parameter_list, no_grad_set
            )

        if startup_program is None:
            startup_program = default_startup_program()
@@ -272,8 +312,9 @@ class ShardingOptimizer(MetaOptimizerBase):
            startup_program = startup_program._pipeline_opt['startup_program']
            print("pp_rank:", self.pp_rank)
            if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None):
                main_program = program_list[
                    int(os.getenv("PADDLE_MANUAL_PIPELINE_STAGE"))
                ]
            else:
                main_program = program_list[self.pp_rank]
            with open("main_%d" % self.role_maker._worker_index(), 'w') as f:
@@ -299,14 +340,16 @@ class ShardingOptimizer(MetaOptimizerBase):
        return optimize_ops, params_grads

    def _apply_sharding_pass(self, params_grads):
        if self.sharding_degree == 1:
            return

        main_block = self._main_program.global_block()
        startup_block = self._startup_program.global_block()

        # step1: build shard
        self._build_shard(
            params_grads, self.sharding_rank, self.sharding_degree
        )

        # step2: split_program
        self._split_program(main_block)
@@ -318,13 +361,16 @@ class ShardingOptimizer(MetaOptimizerBase):
        # step4: remove unneeded ops and vars from block
        self._prune_main_program(
            main_block,
            self._shard,
            [self.mp_ring_id, self.sharding_ring_id, self.pp_ring_id],
        )
        self._prune_startup_program(startup_block, self._shard)

    def _apply_opt_sharding_pass(self, params_grads):
        """outer dp as optimizer sharding"""
        if self._optimizer_sharding is False:
            return

        main_block = self._main_program.global_block()
        startup_block = self._startup_program.global_block()
@@ -338,12 +384,15 @@ class ShardingOptimizer(MetaOptimizerBase):
        # step4: remove unneeded ops and vars from block
        self._prune_main_program(
            main_block,
            self._shard,
            [self.mp_ring_id, self.pp_ring_id, self.dp_ring_id],
        )
        self._prune_startup_program(startup_block, self._shard)

    def _insert_allreduce_for_pp(self, params_grads):
        if self.pp_degree == 1:
            return

        strategy = self.user_defined_strategy
        sharding_configs = strategy.sharding_configs
@@ -363,10 +412,12 @@ class ShardingOptimizer(MetaOptimizerBase):
                    main_block._remove_op(idx)

        for idx, op in reversed(list(enumerate(main_block.ops))):
            if op.type != 'cast':
                continue
            in_name = op.input_arg_names[0]
            if in_name not in self._params:
                continue
            # if self._shard.has_param(param_name): continue
            if in_name not in main_block.vars:
                main_block._remove_op(idx)
@@ -376,7 +427,8 @@ class ShardingOptimizer(MetaOptimizerBase):
        shard = self._shard if self._optimizer_sharding else None
        accumulated_grad_names = self._pp_optimizer._accumulate_gradients(
            main_block, strategy=strategy, shard=shard
        )

        len_of_ops = len(main_block.ops)
        if self.scale_gradient:
@@ -384,8 +436,9 @@ class ShardingOptimizer(MetaOptimizerBase):
        first_optimize_op_index = get_first_optimize_op_idx(main_block)

        if self.pp_allreduce_in_optimize:
            logger.info(
                "Pipeline Persistable grad is {}".format(accumulated_grad_names)
            )
            # FIXME(wangxi): accumulated_grad get from pipeline is not
            # include sharding's param@BroadCast grad when
            # pp_allreduce_in_optimize
@@ -397,10 +450,11 @@ class ShardingOptimizer(MetaOptimizerBase):
                self._shard,
                core.op_proto_and_checker_maker.OpRole.Optimize,
                use_calc_stream=True,
                rank=self.sharding_rank,
            )
            logger.info("PP-Sharding grad is {}".format(accumulated_grad_names))
            first_optimize_op_index += len(main_block.ops) - len_of_ops
            len_of_ops = len(main_block.ops)

        if self._optimizer_sharding:
@@ -413,10 +467,12 @@ class ShardingOptimizer(MetaOptimizerBase):
                OpRole.Optimize,
                use_calc_stream=True,
                rank=self.dp_rank,
                strategy=strategy,
            )
            logger.info(
                "Optimizer grad in this rank {}".format(accumulated_grad_names)
            )
            first_optimize_op_index += len(main_block.ops) - len_of_ops
            len_of_ops = len(main_block.ops)

            # NOTE(wangxi): we fused after optimize_cast
@@ -424,14 +480,17 @@ class ShardingOptimizer(MetaOptimizerBase):
            optimizer_param = utils.insert_broadcast_param_ops(
                main_block,
                len_of_ops,
                self.dp_ring_id,
                [x[0].name for x in params_grads],
                self._shard,
                OpRole.Optimize,
                use_calc_stream=True,
                rank=self.dp_rank,
                strategy=None if optimize_cast else strategy,
            )
            logger.info(
                "Optimizer param in this rank {}".format(optimizer_param)
            )
            if not strategy.fuse_grad_merge and not optimize_cast:
                assert len(accumulated_grad_names) == len(optimizer_param)
        elif self.hybrid_dp and self.hybrid_dp_mode == "pp_hybrid_dp":
@@ -442,15 +501,20 @@ class ShardingOptimizer(MetaOptimizerBase):
                accumulated_grad_names,
                core.op_proto_and_checker_maker.OpRole.Optimize,
                use_calc_stream=True,
                user_defined_strategy=strategy,
            )
            first_optimize_op_index += len(main_block.ops) - len_of_ops
            len_of_ops = len(main_block.ops)

        # FIXME(wangxi): if fp16_allreduce, put cast fp16->fp32 to there?

    def _avg_grad_merge_after_sum(self, main_block, accumulated_grad_names):
        if (
            self.user_defined_strategy.amp
            and self.user_defined_strategy.amp_configs[
                'use_dynamic_loss_scaling'
            ]
        ):
            # For AMP, if using dynamic loss scaling the avg
            # operation can be simple done by modify the LossScaling op.
            for idx, op in enumerate(main_block.ops):
@@ -461,7 +525,8 @@ class ShardingOptimizer(MetaOptimizerBase):
                    loss_scale_tmp_var = main_block.create_var(
                        name=loss_scale_tmp_var_name,
                        shape=loss_scaling_var.shape,
                        dtype=loss_scaling_var.dtype,
                    )
                    main_block._insert_op_without_sync(
                        idx,
                        type='scale',
@@ -471,8 +536,9 @@ class ShardingOptimizer(MetaOptimizerBase):
                            'scale': self._gradient_merge_acc_step,
                            'bias': 0.0,
                            'bias_after_scale': False,
                            OP_ROLE_KEY: OpRole.Optimize,
                        },
                    )
                    op._rename_input(loss_scale_name, loss_scale_tmp_var_name)
                    break
        else:
@@ -483,7 +549,9 @@ class ShardingOptimizer(MetaOptimizerBase):
                if is_optimizer_op(op) and op.type != 'c_sync_comm_stream':
                    tmp_first_opt_idx = idx
                    break
            assert (
                tmp_first_opt_idx is not None
            ), 'Occurs some errors, no optimize ops'
            for grad in accumulated_grad_names:
                main_block._insert_op_without_sync(
                    tmp_first_opt_idx,
@@ -494,14 +562,17 @@ class ShardingOptimizer(MetaOptimizerBase):
                        'scale': 1.0 / self._gradient_merge_acc_step,
                        'bias': 0.0,
                        'bias_after_scale': False,
                        OP_ROLE_KEY: OpRole.Optimize,
                    },
                )

    def _adapt_amp_clip_without_sharding(self):
        # if not use sharding, adapt amp/clip, for remain parallelism.
        # cast --> amp --> clip --> opt
        if self.sharding_degree > 1:
            return
        if self._optimizer_sharding:
            return

        main_block = self._main_program.global_block()
        startup_block = self._startup_program.global_block()
@@ -515,9 +586,9 @@ class ShardingOptimizer(MetaOptimizerBase):
        FP16Utils.sync_amp_check_nan_inf(main_block, rings)

        gradientclip_helper = GradientClipHelper(None)
        gradientclip_helper.sync_global_norm(
            main_block, [self.mp_ring_id, self.pp_ring_id], self.mp_rank
        )

    def _insert_loss_grad_scale_op(self):
        main_block = self._main_program.global_block()
@@ -538,8 +609,9 @@ class ShardingOptimizer(MetaOptimizerBase):
        mp_ring_id = self.mp_ring_id if self.mp_degree > 1 else None
        dp_ring_id = self.dp_ring_id if self.dp_degree > 1 else None
        offload_helper = OffloadHelper(
            mp_ring_id=mp_ring_id, dp_ring_id=dp_ring_id
        )

        # optimize offload should be enable while gradient merge is enable and
        # acc_step is quite large (e.g. >> 100). Since its memcpy could not be
@@ -555,32 +627,32 @@ class ShardingOptimizer(MetaOptimizerBase):
        # will take more memory, but will be faster. Trade space for time.
        if self._optimizer_sharding:
            offload_helper.opt_sharding_cast_fp32param(
                main_block, startup_block, [x[0].name for x in params_grads]
            )
            # NOTE(wangxi): fused after optimize_cast
            utils.fuse_opt_broadcast_param_ops(
                main_block, dp_ring_id, self._shard, strategy=strategy
            )
        else:
            offload_helper.cast_fp32param_in_optimize(
                main_block, startup_block
            )

    def _dump_program_for_debug(self):
        main_block = self._main_program.global_block()
        startup_block = self._startup_program.global_block()
        with open(
            "start_sharding_%d" % self.role_maker._worker_index(), 'w'
        ) as f:
            f.writelines(str(startup_block.program))
        with open(
            "main_sharding_%d" % self.role_maker._worker_index(), 'w'
        ) as f:
            f.writelines(str(main_block.program))

    def minimize_impl(
        self, loss, startup_program=None, parameter_list=None, no_grad_set=None
    ):
        # TODO: (JZ-LIANG) support multiple comm in future
        # self._nrings = self.user_defined_strategy.nccl_comm_num
        self._nrings_sharding = 1
@@ -595,7 +667,8 @@ class ShardingOptimizer(MetaOptimizerBase):
        # inner optimize minimize
        optimize_ops, params_grads = self._inner_opt_minimize(
            loss, startup_program, parameter_list, no_grad_set
        )

        self._init_comm()
@@ -644,13 +717,15 @@ class ShardingOptimizer(MetaOptimizerBase):
        ]
        pp_rank = 0 if self.pp_rank == pair[0] else 1
        if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None) is None:
            self._collective_helper._init_communicator(
                self._startup_program,
                self.current_endpoint,
                pp_group_endpoints,
                pp_rank,
                ring_id,
                False,
                sync=False,
            )

    def _init_npu_pipeline_comm(self, startup_block):
        # NOTE(wangxi): some bug with hccl, must set pp_degree be even number
@@ -668,15 +743,22 @@ class ShardingOptimizer(MetaOptimizerBase):
                my_pair.append(pair)

        # for example: self.pp_rank=2, self.pp_degree=4
        send_to_next_pair = (
            self.pp_rank,
            (self.pp_rank + 1) % self.pp_degree,
        )  # 2->3
        recv_from_next_pair = (
            (self.pp_rank + 1) % self.pp_degree,
            self.pp_rank,
        )  # 3->2
        recv_from_prev_pair = (
            (self.pp_rank - 1 + self.pp_degree) % self.pp_degree,
            self.pp_rank,
        )  # 1->2
        send_to_prev_pair = (
            self.pp_rank,
            (self.pp_rank - 1 + self.pp_degree) % self.pp_degree,
        )  # 2->1

        even = (self.pp_rank % 2) == 0
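A standalone check of the pair arithmetic above, using the example from the comments (pp_rank=2, pp_degree=4):

# Values follow the "# 2->3", "# 3->2", "# 1->2", "# 2->1" comments above.
pp_rank, pp_degree = 2, 4
assert (pp_rank, (pp_rank + 1) % pp_degree) == (2, 3)                    # send_to_next_pair
assert ((pp_rank + 1) % pp_degree, pp_rank) == (3, 2)                    # recv_from_next_pair
assert ((pp_rank - 1 + pp_degree) % pp_degree, pp_rank) == (1, 2)        # recv_from_prev_pair
assert (pp_rank, (pp_rank - 1 + pp_degree) % pp_degree) == (2, 1)        # send_to_prev_pair
assert (pp_rank % 2) == 0                                                # rank 2 follows the "even" schedule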
@@ -685,54 +767,66 @@ class ShardingOptimizer(MetaOptimizerBase):
        ring_id = self.pp_ring_map[pair[0] * 1000 + pair[1]]
        self._init_pair_comm(pair, ring_id)
        my_pair.remove(pair)
        logger.info(
            "pair0(even->odd): pp pair:{}, ring_id: {}".format(pair, ring_id)
        )

        # 2. even recv from next, odd send to prev, 1->0, 3->2
        pair = recv_from_next_pair if even else send_to_prev_pair
        ring_id = self.pp_ring_map[pair[0] * 1000 + pair[1]]
        self._init_pair_comm(pair, ring_id)
        my_pair.remove(pair)
        logger.info(
            "pair1(even<-odd): pp pair:{}, ring_id: {}".format(pair, ring_id)
        )

        # if pp_degree is 2, only need pair(0->1, 1->0)
        if self.pp_degree > 2:
            # 3. odd send to next, even recv from prev, 1->2, 3->0
            pair = send_to_next_pair if not even else recv_from_prev_pair
            ring_id = self.pp_ring_map.get(
                pair[0] * 1000 + pair[1], max_ring_id + 1
            )  # 3->0 not in pp_ring_map
            self._init_pair_comm(pair, ring_id)
            if self.pp_rank != 0 and self.pp_rank != self.pp_degree - 1:
                my_pair.remove(pair)
            logger.info(
                "pair2(odd->even): pp pair:{}, ring_id: {}".format(
                    pair, ring_id
                )
            )

            # 4. odd recv from next, even send to prev, 2->1, 0->3
            pair = recv_from_next_pair if not even else send_to_prev_pair
            ring_id = self.pp_ring_map.get(
                pair[0] * 1000 + pair[1], max_ring_id + 2
            )  # 0->3 not in pp_ring_map
            self._init_pair_comm(pair, ring_id)
            if self.pp_rank != 0 and self.pp_rank != self.pp_degree - 1:
                my_pair.remove(pair)
            logger.info(
                "pair3(odd<-even): pp pair:{}, ring_id: {}".format(
                    pair, ring_id
                )
            )

        assert len(my_pair) == 0, (
            "Current pipeline does not support cross stage communication, "
            "please check unexpected pair {}".format(my_pair)
        )

    def _init_pipeline_comm(self, startup_block):
        # TODO (JZ-LIANG) to unify pp_rank_ and pp_rank
        if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None) is None:
            self._collective_helper._init_communicator(
                self._startup_program,
                self.current_endpoint,
                self.pp_group_endpoints,
                self.pp_rank,
                self.pp_ring_id,
                False,
                sync=False,
            )

        if core.is_compiled_with_npu():
            self._init_npu_pipeline_comm(startup_block)
@@ -752,13 +846,15 @@ class ShardingOptimizer(MetaOptimizerBase):
        # mp ring
        if self.mp_degree > 1:
            self._collective_helper._init_communicator(
                self._startup_program,
                self.current_endpoint,
                self.mp_group_endpoints,
                self.mp_rank,
                self.mp_ring_id,
                False,
                sync=False,
            )

        # sharding ring
        if self.sharding_degree > 1:
@@ -769,7 +865,8 @@ class ShardingOptimizer(MetaOptimizerBase):
                self.sharding_rank,
                self.sharding_ring_id,
                False,
                sync=False,
            )

        # pp ring
        if self.pp_degree > 1:
@@ -777,13 +874,15 @@ class ShardingOptimizer(MetaOptimizerBase):
        # pure dp ring
        if self.dp_degree > 1:
            self._collective_helper._init_communicator(
                self._startup_program,
                self.current_endpoint,
                self.dp_group_endpoints,
                self.dp_rank,
                self.dp_ring_id,
                False,
                sync=False,
            )

        startup_block._sync_with_cpp()
@@ -794,9 +893,12 @@ class ShardingOptimizer(MetaOptimizerBase):
        # step 3: get broadcast vars
        self._broadcast_vars = self._shard.find_broadcast_params(
            self._main_program.global_block()
        )

    def _wait(
        self,
    ):
        endpoints = self.global_endpoints[:]
        current_endpoint = endpoints[self.global_rank]
        if self.global_rank == 0:
@@ -821,7 +923,7 @@ class ShardingOptimizer(MetaOptimizerBase):
        segment._end_idx = last_backward_op_idx
        for op_idx in reversed(range(last_backward_op_idx)):
            op = block.ops[op_idx]
            assert int(op.attr('op_role')) != int(OpRole.Optimize)
            if self._sharding_segment_strategy == "segment_broadcast_MB":
                if segment._param_mem >= self._broadcast_MB:
                    segment = self.collect_segment(segment, op_idx, block)
@@ -835,21 +937,27 @@ class ShardingOptimizer(MetaOptimizerBase):
                            if ".cast_fp16@GRAD" not in input_name:
                                continue
                            else:
                                input_name = input_name[
                                    : input_name.find(".cast_fp16@GRAD")
                                ]

                        if input_name in self._backward_remain_anchors:
                            segment = self.collect_segment(
                                segment, op_idx, block
                            )
                            assert (
                                input_name not in self._forward_remain_anchors
                            ), "segment anchor [{}] met twice !".format(
                                input_name
                            )
                            self._backward_remain_anchors.remove(input_name)
                            self._forward_remain_anchors.append(input_name)
                elif int(op.attr('op_role')) == int(OpRole.Forward):
                    for output_name in op.desc.output_arg_names():
                        if output_name in self._forward_remain_anchors:
                            segment = self.collect_segment(
                                segment, op_idx, block
                            )
                            self._forward_remain_anchors.remove(output_name)

            # find broadcast vars
@@ -865,47 +973,49 @@ class ShardingOptimizer(MetaOptimizerBase):
                if self._shard.has_param(input_name):
                    broadcast_var_name = input_name
                else:
                    broadcast_var_name = unique_name.generate(
                        input_name + "@BroadCast"
                    )
                    segment._fill_constant_vars.append(broadcast_var_name)

                # (JZ-LIANG) should use Param base name ?
                broadcast_var_base_name = input_name
                if "subprog" in broadcast_var_base_name:
                    # remove suffix
                    broadcast_var_base_name = broadcast_var_base_name[
                        : broadcast_var_base_name.find(".subprog")
                    ]

                var2broadcast_time[broadcast_var_base_name] = (
                    var2broadcast_time.get(broadcast_var_base_name, 0) + 1
                )

                segment._param2broadcast[input_name] = broadcast_var_name
                segment._broadcast_vars.append(
                    (broadcast_var_name, self._shard.device(input_name))
                )
                segment._param_mem += get_var_size(
                    self._main_program.global_block().var(input_name)
                )

            # find reduce vars
            if self.pp_degree > 1 and self.pp_allreduce_in_optimize:
                # place pipeline gradient allreduce in optimize
                pass
            else:
                if is_backward_op(op) and OP_ROLE_VAR_KEY in op.attr_names:
                    op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
                    if len(op_role_var) != 0:
                        assert len(op_role_var) % 2 == 0
                        for i in range(0, len(op_role_var), 2):
                            param, reduced_grad = (
                                op_role_var[i],
                                op_role_var[i + 1],
                            )
                            segment._allreduce_vars.append(reduced_grad)
                            assert (
                                reduced_grad not in self._reduced_grads_to_param
                            )
                            self._reduced_grads_to_param[reduced_grad] = param

            # find cast op
@@ -920,29 +1030,40 @@ class ShardingOptimizer(MetaOptimizerBase):
            self._segments.insert(0, segment)

        if self._sharding_segment_strategy == "segment_anchors":
            assert (
                len(self._forward_remain_anchors) == 0
            ), "remain anchors {}".format(self._forward_remain_anchors)
            assert (
                len(self._backward_remain_anchors) == 0
            ), "remain anchors {}".format(self._backward_remain_anchors)

        if self._verbose:
            for varname in sorted(
                var2broadcast_time, key=var2broadcast_time.get, reverse=True
            ):
                logger.info(
                    "Sharding broadcast: [{}] times [{}]".format(
                        var2broadcast_time[varname], varname
                    )
                )
            for idx_ in range(len(self._segments)):
                logger.info("segment [{}] :".format(idx_))
                logger.info(
                    "start op: [{}] [{}]".format(
                        block.ops[self._segments[idx_]._start_idx].desc.type(),
                        block.ops[
                            self._segments[idx_]._start_idx
                        ].desc.input_arg_names(),
                    )
                )
                logger.info(
                    "end op: [{}] [{}]".format(
                        block.ops[self._segments[idx_]._end_idx].desc.type(),
                        block.ops[
                            self._segments[idx_]._end_idx
                        ].desc.input_arg_names(),
                    )
                )
        return
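The segment_broadcast_MB path above cuts a new segment whenever the accumulated parameter size of the current segment reaches the configured threshold. A simplified, self-contained sketch of that greedy rule (the helper and the sizes are illustrative, not the _split_program implementation):

# Greedy size-based segmentation sketch; sizes are made-up MB values.
def split_by_broadcast_mb(var_sizes_mb, broadcast_mb=32):
    segments, current, current_mem = [], [], 0.0
    for name, size_mb in var_sizes_mb:
        if current and current_mem >= broadcast_mb:
            segments.append(current)  # close the segment once it reaches the threshold
            current, current_mem = [], 0.0
        current.append(name)
        current_mem += size_mb
    if current:
        segments.append(current)
    return segments

print(split_by_broadcast_mb([("w0", 20), ("w1", 15), ("w2", 30), ("w3", 5)]))
# [['w0', 'w1'], ['w2', 'w3']]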
    def _prune_main_program(self, block, shard, rings):
@@ -954,7 +1075,7 @@ class ShardingOptimizer(MetaOptimizerBase):
        2. prune cast_fp32_to_fp16; update amp_infine_checking
        3. prune gradient_clip related; update global_norm_sum
        4. prune optimizer op + param + gradient
        """
        weightdecay_helper = WeightDecayHelper()
        weightdecay_helper.prune_weight_decay(block, shard)
@@ -975,17 +1096,18 @@ class ShardingOptimizer(MetaOptimizerBase):
            input_names = op.desc.input_arg_names()
            output_names = op.desc.output_arg_names()
            # FIXME(wangxi): need use grads, pipeline grad is @GRAD@MERGE
            if (
                op.type == "c_allreduce_sum"
                and op.attr('use_model_parallel') is False
            ):
                assert len(output_names) == 1
                output_name = output_names[0]
                reduced_grads.append(output_name)

        # prune optimizer state and param
        pruned_opti_vars = []
        for var_name in list(block.vars.keys()):
            if shard.is_opti_var(var_name) and not shard.has_opt_var(var_name):
                pruned_opti_vars.append(var_name)
        program_deps = ProgramDeps(block, reduced_grads, pruned_opti_vars)
@@ -996,17 +1118,17 @@ class ShardingOptimizer(MetaOptimizerBase):
        # Prune
        for idx, op in reversed(list(enumerate(block.ops))):
            if op.type in [
                "c_allreduce_sum",
                "c_sync_comm_stream",
                "c_calc_comm_stream",
                "c_gen_nccl_id",
                "c_comm_init",
                'send_v2',
                'recv_v2',
            ]:
                pass
            elif op.type == "conditional_block":
                assert op.desc.has_attr("sub_block")
                subblock_idx = op.desc.attr("sub_block").id
                subblock_deps = program_deps.get_sub_block_deps(subblock_idx)
                # only prune amp subblock
@@ -1022,7 +1144,8 @@ class ShardingOptimizer(MetaOptimizerBase):
                        reversed_output_vars.append(output_name)
                # prune
                for sub_op_idx, _ in reversed(
                    list(enumerate(subblock_deps._block.ops))
                ):
                    if subblock_deps.should_remove_op(sub_op_idx):
                        subblock_deps.remove_op(sub_op_idx)
                reversed_input_vars = []
@@ -1038,7 +1161,9 @@ class ShardingOptimizer(MetaOptimizerBase):
                # _should_removed_var: opt state not cur shard
                if program_deps.should_remove_op(idx):
                    # NOTE(wangxi): need reserve all param in optimizer_sharding
                    reserved_vars = (
                        self._params if self._optimizer_sharding else None
                    )
                    program_deps.remove_op(idx, reserved_vars)

        # NOTE (JZ-LIANG) revise and unify logic here
@@ -1049,7 +1174,8 @@ class ShardingOptimizer(MetaOptimizerBase):
                # remove inputs that not on this card
                reserved_x = []
                for var_name in op.desc.input("X"):
                    if block.has_var(var_name):
                        reserved_x.append(var_name)
                op.desc.set_input('X', reserved_x)
        block._sync_with_cpp()
        return
@@ -1059,7 +1185,7 @@ class ShardingOptimizer(MetaOptimizerBase):
        add broadcast allreduce op
        if enable gradient_merge, insert related ops
        if combined with pipeline(grad accumulate),
        the grad allreduce should be done in optimize role
        """
        if len(self._segments) < 1:
@@ -1072,175 +1198,280 @@ class ShardingOptimizer(MetaOptimizerBase):
        # NOTE (JZ-LIANG) revise and unify logic here
        # fix the _end_idx for segments[-1] if pp is used.
        new_end_idx = self._segments[-1]._end_idx
        for idx in range(
            self._segments[-1]._end_idx - 1,
            self._segments[-1]._start_idx - 1,
            -1,
        ):
            op = block.ops[idx]
            if op.type == "fill_constant" or op.type == "sum":
                if "MERGED" in op.output_arg_names[0]:
                    new_end_idx = idx + 1
            elif op.type == "cast":
                if "@TMP" in op.output_arg_names[0]:
                    new_end_idx = idx + 1
        self._segments[-1]._end_idx = new_end_idx

        if self._segments[-1]._allreduce_vars:
            shard_allredue_vars = self._shard.filter_grads(
                self._segments[-1]._allreduce_vars
            )
            if (
                self.gradient_merge_mode != "sharding_gm"
                or self._gradient_merge_acc_step <= 1
            ):
                if (
                    self.hybrid_dp
                    and self.hybrid_dp_mode == "sharding_hybrid_dp"
                    and len(shard_allredue_vars) >= 1
                ):
                    insert_sync_comm_ops(
                        block,
                        self._segments[-1]._end_idx,
                        self.dp_ring_id,
                        shard_allredue_vars,
                    )
                    insert_allreduce_ops(
                        block,
                        self._segments[-1]._end_idx,
                        self.dp_ring_id,
                        shard_allredue_vars,
                        user_defined_strategy=self.user_defined_strategy,
                    )
            # gradient merge
            elif (
                self.gradient_merge_mode == "sharding_gm"
                and self._gradient_merge_acc_step > 1
            ):
                self.create_persistable_gradients_and_insert_merge_ops(
                    block,
                    self._startup_program.global_block(),
                    self._segments[-1]._end_idx,
                    shard_allredue_vars,
                    self._shard,
                )

            insert_sync_comm_ops(
                block,
                self._segments[-1]._end_idx,
                self.sharding_ring_id,
                self._segments[-1]._allreduce_vars,
            )
            # allreduce --> reduce
            insert_reduce_ops(
                block,
                self._segments[-1]._end_idx,
                self.sharding_ring_id,
                self._segments[-1]._allreduce_vars,
                self._shard,
                op_role=OpRole.Backward,
                use_calc_stream=False,
            )

        for idx, segment in reversed(list(enumerate(self._segments))):
            allreduce_vars = (
                self._segments[idx - 1]._allreduce_vars if idx > 0 else []
            )
            broadcast_vars = (
                self._segments[idx + 1]._broadcast_vars
                if idx < len(self._segments) - 1
                else []
            )
            fill_constant_vars = (
                self._segments[idx + 2]._fill_constant_vars
                if idx < len(self._segments) - 2
                else []
            )
            cast_ops = (
                self._segments[idx + 2]._cast_ops
                if idx < len(self._segments) - 2
                else {}
            )

            for op_idx in reversed(range(segment._start_idx, segment._end_idx)):
                op = block.ops[op_idx]
                for input_name in op.desc.input_arg_names():
                    if (
                        input_name in segment._param2broadcast
                        and input_name != segment._param2broadcast[input_name]
                    ):
                        op._rename_input(
                            input_name, segment._param2broadcast[input_name]
                        )

            for param_name, broadcast_name in segment._param2broadcast.items():
                if param_name != broadcast_name:
                    block.create_var(
                        name=broadcast_name,
                        shape=self._main_program.global_block()
                        .var(param_name)
                        .shape,
                        dtype=self._main_program.global_block()
                        .var(param_name)
                        .dtype,
                        persistable=False,
                    )

            # step1: remove cast ops
            block._sync_with_cpp()
            segment._end_idx += FP16Utils.remove_cast_op(
                block, self._params, segment, 0
            )

            # step2: add Sync ops
            shard_allredue_vars = self._shard.filter_grads(allreduce_vars)

            if (
                self.gradient_merge_mode != "sharding_gm"
                or self._gradient_merge_acc_step <= 1
            ):
                if (
                    self.hybrid_dp
                    and self.hybrid_dp_mode == "sharding_hybrid_dp"
                    and len(shard_allredue_vars) >= 1
                ):
                    insert_sync_comm_ops(
                        block,
                        segment._end_idx,
                        self.dp_ring_id,
                        shard_allredue_vars,
                    )

                    broad_cast_vars = [x[0] for x in broadcast_vars]
                    if len(broad_cast_vars) > 0:
                        insert_sync_comm_ops(
                            block,
                            segment._end_idx,
                            self.sharding_ring_id,
                            broad_cast_vars,
                        )
                else:
                    comm_dep_vars = allreduce_vars + [
                        x[0] for x in broadcast_vars
                    ]
                    if len(comm_dep_vars) > 0:
                        insert_sync_comm_ops(
                            block,
                            segment._end_idx,
                            self.sharding_ring_id,
                            comm_dep_vars,
                        )
            # gradient merge
            elif (
                self.gradient_merge_mode == "sharding_gm"
                and self._gradient_merge_acc_step > 1
            ):
                broad_cast_vars = [x[0] for x in broadcast_vars]
                if len(broad_cast_vars) > 0:
                    insert_sync_comm_ops(
                        block,
                        segment._end_idx,
                        self.sharding_ring_id,
                        broad_cast_vars,
                    )

            calc_dep_vars = (
                fill_constant_vars
                + [k for k, v in cast_ops.items()]
                + self._segments[idx]._allreduce_vars
            )
            if len(calc_dep_vars) > 0:
                insert_sync_calc_op(
                    block, segment._end_idx, [calc_dep_vars[-1]]
                )

            # step3: insert `fill_constant` ops
            insert_fill_constant_ops(
                block, segment._end_idx, fill_constant_vars
            )

            # step4: add `cast` ops
            insert_cast_ops(block, segment._end_idx, cast_ops)

            # step5: add broadcast ops
            # gradient merge
            if (
                self.gradient_merge_mode == "sharding_gm"
                and self._gradient_merge_acc_step > 1
            ):
                self.create_persistable_gradients_and_insert_merge_ops(
                    block,
                    self._startup_program.global_block(),
                    segment._start_idx,
                    shard_allredue_vars,
                    self._shard,
                )

            insert_broadcast_ops(
                block, segment._start_idx, self.sharding_ring_id, broadcast_vars
            )

            # step6: add all_reduce ops
            # dp
            if (
                self.gradient_merge_mode != "sharding_gm"
                or self._gradient_merge_acc_step <= 1
            ):
                if (
                    self.hybrid_dp
                    and self.hybrid_dp_mode == "sharding_hybrid_dp"
                    and len(shard_allredue_vars) >= 1
                ):
                    insert_allreduce_ops(
                        block,
                        segment._start_idx,
                        self.dp_ring_id,
                        shard_allredue_vars,
                        user_defined_strategy=self.user_defined_strategy,
                    )
                    insert_sync_comm_ops(
                        block,
                        segment._start_idx,
                        self.sharding_ring_id,
                        allreduce_vars,
                    )
            # gradient merge
            elif (
                self.gradient_merge_mode == "sharding_gm"
                and self._gradient_merge_acc_step > 1
            ):
                insert_sync_comm_ops(
                    block,
                    segment._start_idx,
                    self.sharding_ring_id,
                    allreduce_vars,
                )

            # sharding
            # allreduce --> reduce
            # TODO temp change
            if len(allreduce_vars) > 0:
                insert_reduce_ops(
                    block,
                    segment._start_idx,
                    self.sharding_ring_id,
                    allreduce_vars,
                    self._shard,
                    op_role=OpRole.Backward,
                    use_calc_stream=False,
                )

            block._sync_with_cpp()

        if self._segments[0]._broadcast_vars:
            broadcast_vars = [x[0] for x in self._segments[0]._broadcast_vars]
            insert_sync_comm_ops(
                block,
                self._segments[0]._start_idx,
                self.sharding_ring_id,
                broadcast_vars,
            )
            insert_broadcast_ops(
                block,
                self._segments[0]._start_idx,
                self.sharding_ring_id,
                self._segments[0]._broadcast_vars,
            )

        fill_constant_vars = []
        for x in self._segments[:2]:
...@@ -1254,12 +1485,14 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1254,12 +1485,14 @@ class ShardingOptimizer(MetaOptimizerBase):
calc_deps_vars = fill_constant_vars + [k for k, v in cast_ops.items()] calc_deps_vars = fill_constant_vars + [k for k, v in cast_ops.items()]
if fill_constant_vars or cast_ops: if fill_constant_vars or cast_ops:
insert_sync_calc_op(block, self._segments[0]._start_idx, insert_sync_calc_op(
[calc_deps_vars[-1]]) block, self._segments[0]._start_idx, [calc_deps_vars[-1]]
)
if fill_constant_vars: if fill_constant_vars:
insert_fill_constant_ops(block, self._segments[0]._start_idx, insert_fill_constant_ops(
fill_constant_vars) block, self._segments[0]._start_idx, fill_constant_vars
)
if cast_ops: if cast_ops:
insert_cast_ops(block, self._segments[0]._start_idx, cast_ops) insert_cast_ops(block, self._segments[0]._start_idx, cast_ops)
...@@ -1273,7 +1506,7 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1273,7 +1506,7 @@ class ShardingOptimizer(MetaOptimizerBase):
continue continue
if self._optimizer_sharding and shard.is_param(output_name): if self._optimizer_sharding and shard.is_param(output_name):
continue continue
#TODO why do we remove op, when only one var is removed # TODO why do we remove op, when only one var is removed
block._remove_op(idx, sync=False) block._remove_op(idx, sync=False)
break break
...@@ -1295,23 +1528,36 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1295,23 +1528,36 @@ class ShardingOptimizer(MetaOptimizerBase):
pp: 4 pp: 4
pp-pair: >= 20 pp-pair: >= 20
if one parallelism is not enable: -1 if one parallelism is not enable: -1
and only support parallelism hierarchy: mp --> sharding --> pp --> dp and only support parallelism hierarchy: mp --> sharding --> pp --> dp
""" """
# step 1: initialize nccl # step 1: initialize nccl
self.global_word_size = self.role_maker._worker_num() self.global_word_size = self.role_maker._worker_num()
self.global_rank = self.role_maker._worker_index() self.global_rank = self.role_maker._worker_index()
self.global_endpoints = self.role_maker._get_trainer_endpoints() self.global_endpoints = self.role_maker._get_trainer_endpoints()
self.current_endpoint = self.global_endpoints[self.global_rank] self.current_endpoint = self.global_endpoints[self.global_rank]
self._collective_helper = CollectiveHelper(self.role_maker, self._collective_helper = CollectiveHelper(
nrings=self._nrings_sharding) self.role_maker, nrings=self._nrings_sharding
assert self.global_word_size % self.mp_degree == 0, \ )
"global_word_size: {} should be divisible to the mp_degree: {}".format(self.global_word_size, self.mp_degree) assert (
assert self.global_word_size % self.sharding_degree == 0, \ self.global_word_size % self.mp_degree == 0
"global_word_size: {} should be divisible to the sharding_degree: {}".format(self.global_word_size, self.sharding_degree) ), "global_word_size: {} should be divisible to the mp_degree: {}".format(
assert self.global_word_size % self.pp_degree == 0, \ self.global_word_size, self.mp_degree
"global_word_size: {} should be divisible to the pp_degree: {}".format(self.global_word_size, self.pp_degree) )
assert self.global_word_size % self.dp_degree == 0, \ assert (
"global_word_size: {} should be divisible to the dp_degree: {}".format(self.global_word_size, self.dp_degree) self.global_word_size % self.sharding_degree == 0
), "global_word_size: {} should be divisible to the sharding_degree: {}".format(
self.global_word_size, self.sharding_degree
)
assert (
self.global_word_size % self.pp_degree == 0
), "global_word_size: {} should be divisible to the pp_degree: {}".format(
self.global_word_size, self.pp_degree
)
assert (
self.global_word_size % self.dp_degree == 0
), "global_word_size: {} should be divisible to the dp_degree: {}".format(
self.global_word_size, self.dp_degree
)
# mp group # mp group
if self.mp_degree > 1: if self.mp_degree > 1:
...@@ -1319,14 +1565,16 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1319,14 +1565,16 @@ class ShardingOptimizer(MetaOptimizerBase):
self.mp_rank = self.global_rank % self.mp_degree self.mp_rank = self.global_rank % self.mp_degree
self.mp_group_id = self.global_rank // self.mp_degree self.mp_group_id = self.global_rank // self.mp_degree
self.mp_group_endpoints = [ self.mp_group_endpoints = [
ep for idx, ep in enumerate(self.global_endpoints) ep
for idx, ep in enumerate(self.global_endpoints)
if idx // self.mp_degree == self.mp_group_id if idx // self.mp_degree == self.mp_group_id
] ]
assert self.current_endpoint in self.mp_group_endpoints assert self.current_endpoint in self.mp_group_endpoints
assert len( assert (
self.mp_group_endpoints len(self.mp_group_endpoints) == self.mp_degree
) == self.mp_degree, "num of mp worker in group is [{}], but mp group size is [{}]".format( ), "num of mp worker in group is [{}], but mp group size is [{}]".format(
len(self.mp_group_endpoints), self.mp_degree) len(self.mp_group_endpoints), self.mp_degree
)
else: else:
self.mp_degree = 1 self.mp_degree = 1
self.mp_ring_id = -1 self.mp_ring_id = -1
...@@ -1337,23 +1585,28 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1337,23 +1585,28 @@ class ShardingOptimizer(MetaOptimizerBase):
# sharding # sharding
if self.sharding_degree > 1: if self.sharding_degree > 1:
self.sharding_ring_id = 1 self.sharding_ring_id = 1
self.sharding_rank = (self.global_rank // self.sharding_rank = (
self.mp_degree) % self.sharding_degree self.global_rank // self.mp_degree
self.sharding_group_id = self.global_rank // (self.mp_degree * ) % self.sharding_degree
self.sharding_degree) self.sharding_group_id = self.global_rank // (
self.mp_degree * self.sharding_degree
)
# mp + sharding + ... # mp + sharding + ...
if self.mp_degree > 1: if self.mp_degree > 1:
self.sharding_group_endpoints = [ self.sharding_group_endpoints = [
ep for idx, ep in enumerate(self.global_endpoints) ep
if (idx // (self.mp_degree * self.sharding_degree)) == self. for idx, ep in enumerate(self.global_endpoints)
sharding_group_id and idx % self.mp_degree == self.mp_rank if (idx // (self.mp_degree * self.sharding_degree))
== self.sharding_group_id
and idx % self.mp_degree == self.mp_rank
] ]
# sharding + ... # sharding + ...
else: else:
self.sharding_group_endpoints = [ self.sharding_group_endpoints = [
ep for idx, ep in enumerate(self.global_endpoints) ep
if (idx // (self.mp_degree * self.sharding_degree) for idx, ep in enumerate(self.global_endpoints)
) == self.sharding_group_id if (idx // (self.mp_degree * self.sharding_degree))
== self.sharding_group_id
] ]
assert self.current_endpoint in self.sharding_group_endpoints assert self.current_endpoint in self.sharding_group_endpoints
else: else:
...@@ -1368,20 +1621,28 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1368,20 +1621,28 @@ class ShardingOptimizer(MetaOptimizerBase):
self.pp_pair_ring_id = 20 self.pp_pair_ring_id = 20
# pipeline global ring_id set to 4 for sharding0, mp1, dp2, global3 # pipeline global ring_id set to 4 for sharding0, mp1, dp2, global3
self.pp_ring_id = 4 self.pp_ring_id = 4
self.pp_rank = self.global_rank // (self.sharding_degree * self.pp_rank = (
self.mp_degree) % self.pp_degree self.global_rank
// (self.sharding_degree * self.mp_degree)
% self.pp_degree
)
# (NOTE): Already adjust for (outter-pure) dp # (NOTE): Already adjust for (outter-pure) dp
self.pp_group_id = self.global_rank // ( self.pp_group_id = self.global_rank // (
self.mp_degree * self.sharding_degree * self.pp_degree) self.mp_degree * self.sharding_degree * self.pp_degree
)
pp_first_stage_idx = self.global_rank % ( pp_first_stage_idx = self.global_rank % (
self.sharding_degree * self.mp_degree) + self.pp_group_id * ( self.sharding_degree * self.mp_degree
self.mp_degree * self.sharding_degree * self.pp_degree) ) + self.pp_group_id * (
self.mp_degree * self.sharding_degree * self.pp_degree
)
pp_stage_offset = self.sharding_degree * self.mp_degree pp_stage_offset = self.sharding_degree * self.mp_degree
self.pp_group_endpoints = [] self.pp_group_endpoints = []
for i in range(self.pp_degree): for i in range(self.pp_degree):
self.pp_group_endpoints.append( self.pp_group_endpoints.append(
self.global_endpoints[pp_first_stage_idx + self.global_endpoints[
pp_stage_offset * i]) pp_first_stage_idx + pp_stage_offset * i
]
)
assert self.current_endpoint in self.pp_group_endpoints assert self.current_endpoint in self.pp_group_endpoints
else: else:
self.pp_ring_id = -1 self.pp_ring_id = -1
...@@ -1397,29 +1658,48 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1397,29 +1658,48 @@ class ShardingOptimizer(MetaOptimizerBase):
# sharding-hybrid-dp as one senario of outter-pure-dp # sharding-hybrid-dp as one senario of outter-pure-dp
local_pp_degree = self.pp_degree local_pp_degree = self.pp_degree
if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None): if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None):
assert self.pp_degree == 2, ("For manually set pipeline, only " assert self.pp_degree == 2, (
"pp_degree = 2 is supported.") "For manually set pipeline, only " "pp_degree = 2 is supported."
assert self.global_word_size == self.mp_degree * self.sharding_degree * self.dp_degree, \ )
"global work size [{}], mp_degree [{}], sharding_degree [{}], dp_degree [{}].".format( assert (
self.global_word_size, self.mp_degree, self.sharding_degree, self.dp_degree) self.global_word_size
== self.mp_degree * self.sharding_degree * self.dp_degree
), "global work size [{}], mp_degree [{}], sharding_degree [{}], dp_degree [{}].".format(
self.global_word_size,
self.mp_degree,
self.sharding_degree,
self.dp_degree,
)
local_pp_degree = 1 local_pp_degree = 1
else: else:
assert self.global_word_size == self.mp_degree * self.sharding_degree * self.pp_degree * self.dp_degree, "mp_degree: [{}], sharding_degree: [{}], pp_degree: [{}], dp_degree: [{}]; BUT global nrank: [{}]".format( assert (
self.mp_degree, self.sharding_degree, self.pp_degree, self.global_word_size
self.dp_degree, self.global_word_size) == self.mp_degree
* self.sharding_degree
* self.pp_degree
* self.dp_degree
), "mp_degree: [{}], sharding_degree: [{}], pp_degree: [{}], dp_degree: [{}]; BUT global nrank: [{}]".format(
self.mp_degree,
self.sharding_degree,
self.pp_degree,
self.dp_degree,
self.global_word_size,
)
if self.dp_degree > 1: if self.dp_degree > 1:
self.dp_ring_id = 2 self.dp_ring_id = 2
self.dp_rank = self.global_rank // ( self.dp_rank = self.global_rank // (
self.sharding_degree * self.mp_degree * local_pp_degree) self.sharding_degree * self.mp_degree * local_pp_degree
)
dp_first_rank_idx = self.global_rank % ( dp_first_rank_idx = self.global_rank % (
self.sharding_degree * self.mp_degree * local_pp_degree) self.sharding_degree * self.mp_degree * local_pp_degree
dp_offset = (self.sharding_degree * self.mp_degree * )
local_pp_degree) dp_offset = self.sharding_degree * self.mp_degree * local_pp_degree
self.dp_group_endpoints = [] self.dp_group_endpoints = []
for i in range(self.dp_degree): for i in range(self.dp_degree):
self.dp_group_endpoints.append( self.dp_group_endpoints.append(
self.global_endpoints[dp_first_rank_idx + dp_offset * i]) self.global_endpoints[dp_first_rank_idx + dp_offset * i]
)
assert self.current_endpoint in self.dp_group_endpoints assert self.current_endpoint in self.dp_group_endpoints
logger.info("Hybrid DP mode turn on !") logger.info("Hybrid DP mode turn on !")
else: else:
...@@ -1448,8 +1728,9 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1448,8 +1728,9 @@ class ShardingOptimizer(MetaOptimizerBase):
logger.info("sharding group size: {}".format(self.sharding_degree)) logger.info("sharding group size: {}".format(self.sharding_degree))
logger.info("sharding rank: {}".format(self.sharding_rank)) logger.info("sharding rank: {}".format(self.sharding_rank))
logger.info("sharding group id: {}".format(self.sharding_group_id)) logger.info("sharding group id: {}".format(self.sharding_group_id))
logger.info("sharding group endpoints: {}".format( logger.info(
self.sharding_group_endpoints)) "sharding group endpoints: {}".format(self.sharding_group_endpoints)
)
logger.info("sharding ring id: {}".format(self.sharding_ring_id)) logger.info("sharding ring id: {}".format(self.sharding_ring_id))
logger.info("#####" * 6) logger.info("#####" * 6)
...@@ -1462,15 +1743,15 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1462,15 +1743,15 @@ class ShardingOptimizer(MetaOptimizerBase):
logger.info("pure dp group size: {}".format(self.dp_degree)) logger.info("pure dp group size: {}".format(self.dp_degree))
logger.info("pure dp rank: {}".format(self.dp_rank)) logger.info("pure dp rank: {}".format(self.dp_rank))
logger.info("pure dp group endpoints: {}".format( logger.info(
self.dp_group_endpoints)) "pure dp group endpoints: {}".format(self.dp_group_endpoints)
)
logger.info("pure dp ring id: {}".format(self.dp_ring_id)) logger.info("pure dp ring id: {}".format(self.dp_ring_id))
logger.info("#####" * 6) logger.info("#####" * 6)
return return
def _recreate_not_persist_param_as_var(self): def _recreate_not_persist_param_as_var(self):
def recreate_not_persist_param_as_var(program): def recreate_not_persist_param_as_var(program):
block = program.global_block() block = program.global_block()
params = block.all_parameters() params = block.all_parameters()
...@@ -1494,14 +1775,16 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1494,14 +1775,16 @@ class ShardingOptimizer(MetaOptimizerBase):
is_distributed = param.is_distributed is_distributed = param.is_distributed
block._remove_var(name, sync=False) block._remove_var(name, sync=False)
var = block.create_var(name=name, var = block.create_var(
shape=shape, name=name,
dtype=dtype, shape=shape,
type=type, dtype=dtype,
lod_level=lod_level, type=type,
stop_gradient=stop_gradient, lod_level=lod_level,
trainable=trainable, stop_gradient=stop_gradient,
persistable=False) trainable=trainable,
persistable=False,
)
if have_dist_attr: if have_dist_attr:
var.is_distributed = is_distributed var.is_distributed = is_distributed
...@@ -1516,6 +1799,13 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1516,6 +1799,13 @@ class ShardingOptimizer(MetaOptimizerBase):
identical when hybrid-dp is used, and the initialization of identical when hybrid-dp is used, and the initialization of
not distributed param between mp group to be identical. not distributed param between mp group to be identical.
""" """
def _find_master_param(all_vars_name, param_name):
for var_name in all_vars_name:
if param_name in var_name and "fp32_master" in var_name:
return var_name
return None
if self.dp_degree <= 1 and self.mp_degree <= 1: if self.dp_degree <= 1 and self.mp_degree <= 1:
return return
...@@ -1536,8 +1826,10 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1536,8 +1826,10 @@ class ShardingOptimizer(MetaOptimizerBase):
if op.type == 'c_broadcast': if op.type == 'c_broadcast':
broadcast_params.add(op.desc.output_arg_names()[0]) broadcast_params.add(op.desc.output_arg_names()[0])
all_vars_name = startup_block.vars
for param in params_name: for param in params_name:
if param in broadcast_params: continue if param in broadcast_params:
continue
rings = [] rings = []
# need sync not distributed param in mp group # need sync not distributed param in mp group
...@@ -1547,30 +1839,51 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1547,30 +1839,51 @@ class ShardingOptimizer(MetaOptimizerBase):
rings.append(self.dp_ring_id) rings.append(self.dp_ring_id)
for ring in rings: for ring in rings:
startup_block.append_op(type='c_broadcast', startup_block.append_op(
inputs={'X': param}, type='c_broadcast',
outputs={'Out': param}, inputs={'X': param},
attrs={ outputs={'Out': param},
'ring_id': ring, attrs={
'root': 0, 'ring_id': ring,
'use_calc_stream': True, 'root': 0,
OP_ROLE_KEY: OpRole.Forward 'use_calc_stream': True,
}) OP_ROLE_KEY: OpRole.Forward,
},
)
# Broadcast the master weight at the same time for AMP-O2 training.
master_param = _find_master_param(all_vars_name, param)
if master_param is not None:
startup_block.append_op(
type='c_broadcast',
inputs={'X': master_param},
outputs={'Out': master_param},
attrs={
'ring_id': ring,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward,
},
)
startup_block._sync_with_cpp() startup_block._sync_with_cpp()
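The new branch above is the core of this change: with AMP-O2 each fp16 parameter keeps an fp32 master copy, and if only the fp16 parameter is broadcast at startup, the master copies initialized independently on the other ranks can diverge. A small usage sketch of the lookup, with made-up variable names (the real names depend on how AMP created the master weights):

def find_master_param(all_vars_name, param_name):
    # mirrors _find_master_param above: a var whose name embeds both the
    # parameter name and the "fp32_master" tag is taken as its master copy
    for var_name in all_vars_name:
        if param_name in var_name and "fp32_master" in var_name:
            return var_name
    return None

all_vars_name = [
    "linear_0.w_0",
    "linear_0.w_0_fp32_master_0",  # assumed naming of the fp32 master copy
    "linear_0.b_0",
]
assert find_master_param(all_vars_name, "linear_0.w_0") == "linear_0.w_0_fp32_master_0"
assert find_master_param(all_vars_name, "linear_0.b_0") is None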
# sharding gradient merge # sharding gradient merge
def create_persistable_gradients_and_insert_merge_ops( def create_persistable_gradients_and_insert_merge_ops(
self, main_block, startup_block, insert_idx, grad_names, shard): self, main_block, startup_block, insert_idx, grad_names, shard
):
for grad_name in grad_names: for grad_name in grad_names:
assert get_grad_device( assert (
grad_name, shard get_grad_device(grad_name, shard) == shard.worker_idx
) == shard.worker_idx, "try to merge gradient not belong to current shard: [{}]".format( ), "try to merge gradient not belong to current shard: [{}]".format(
grad_name) grad_name
)
persistable_grad_name = grad_name + '@GradiantMerge' persistable_grad_name = grad_name + '@GradiantMerge'
assert grad_name not in self._grad2merged_grad, "grad [{}] already in grad2merged_grad, maybe you meet sharing weight case !".format( assert (
grad_name) grad_name not in self._grad2merged_grad
), "grad [{}] already in grad2merged_grad, maybe you meet sharing weight case !".format(
grad_name
)
self._grad2merged_grad[grad_name] = persistable_grad_name self._grad2merged_grad[grad_name] = persistable_grad_name
grad_var = main_block.var(grad_name) grad_var = main_block.var(grad_name)
# create var # create var
...@@ -1578,36 +1891,38 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1578,36 +1891,38 @@ class ShardingOptimizer(MetaOptimizerBase):
name=persistable_grad_name, name=persistable_grad_name,
shape=grad_var.shape, shape=grad_var.shape,
dtype=grad_var.dtype, dtype=grad_var.dtype,
persistable=True) persistable=True,
)
startup_gradient_merge_var = startup_block.create_var( startup_gradient_merge_var = startup_block.create_var(
name=persistable_grad_name, name=persistable_grad_name,
shape=grad_var.shape, shape=grad_var.shape,
dtype=grad_var.dtype, dtype=grad_var.dtype,
persistable=True) persistable=True,
)
# merge gradient # merge gradient
main_block._insert_op_without_sync( main_block._insert_op_without_sync(
insert_idx, insert_idx,
type="elementwise_add", type="elementwise_add",
inputs={ inputs={'X': grad_name, 'Y': gradient_merge_var},
'X': grad_name,
'Y': gradient_merge_var
},
outputs={'Out': gradient_merge_var}, outputs={'Out': gradient_merge_var},
attrs={ attrs={
'axis': -1, 'axis': -1,
'use_mkldnn': False, 'use_mkldnn': False,
OP_ROLE_KEY: OpRole.Backward OP_ROLE_KEY: OpRole.Backward,
}) },
)
# startup initialization # startup initialization
startup_block.append_op(type="fill_constant", startup_block.append_op(
outputs={"Out": startup_gradient_merge_var}, type="fill_constant",
attrs={ outputs={"Out": startup_gradient_merge_var},
"shape": grad_var.shape, attrs={
"dtype": grad_var.dtype, "shape": grad_var.shape,
"value": float(0), "dtype": grad_var.dtype,
}) "value": float(0),
},
)
main_block._sync_with_cpp() main_block._sync_with_cpp()
startup_block._sync_with_cpp() startup_block._sync_with_cpp()
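Numerically, the elementwise_add inserted above just accumulates each micro-batch gradient into the persistable @GradiantMerge buffer that fill_constant zero-initializes; the division by the accumulation step happens later in the cond block. A plain-Python stand-in for one accumulation window:

acc_step = 4
merged_grad = 0.0                        # fill_constant(0.0) in the startup program
micro_batch_grads = [0.5, 0.25, -0.25, 0.5]
for g in micro_batch_grads:              # one elementwise_add per backward pass
    merged_grad += g
print(merged_grad / acc_step)            # 0.25, applied by the scale op in the cond block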
...@@ -1620,14 +1935,17 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1620,14 +1935,17 @@ class ShardingOptimizer(MetaOptimizerBase):
value=int(self._gradient_merge_acc_step), value=int(self._gradient_merge_acc_step),
dtype='int32', dtype='int32',
persistable=True, persistable=True,
force_cpu=True) force_cpu=True,
)
zero_var = layers.create_global_var(name="gradient_merge_zero", zero_var = layers.create_global_var(
shape=[1], name="gradient_merge_zero",
value=int(0), shape=[1],
dtype='int32', value=int(0),
persistable=True, dtype='int32',
force_cpu=True) persistable=True,
force_cpu=True,
)
# Add step var & cond var # Add step var & cond var
current_step_var = layers.create_global_var( current_step_var = layers.create_global_var(
...@@ -1636,42 +1954,40 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1636,42 +1954,40 @@ class ShardingOptimizer(MetaOptimizerBase):
value=int(0), value=int(0),
dtype='int32', dtype='int32',
persistable=True, persistable=True,
force_cpu=True) force_cpu=True,
)
cond_var = main_block.create_var(name="gradient_merge_cond", cond_var = main_block.create_var(
shape=[1], name="gradient_merge_cond", shape=[1], dtype='bool'
dtype='bool') )
with device_guard("cpu"): with device_guard("cpu"):
# step_var = (step_var + 1) % k_step # step_var = (step_var + 1) % k_step
main_block.append_op(type='increment', main_block.append_op(
inputs={'X': [current_step_var]}, type='increment',
outputs={'Out': [current_step_var]}, inputs={'X': [current_step_var]},
attrs={ outputs={'Out': [current_step_var]},
'step': float(1), attrs={'step': float(1), OP_ROLE_KEY: OpRole.Optimize},
OP_ROLE_KEY: OpRole.Optimize )
})
main_block.append_op(
main_block.append_op(type='elementwise_mod', type='elementwise_mod',
inputs={ inputs={'X': current_step_var, 'Y': acc_step_var},
'X': current_step_var, outputs={'Out': current_step_var},
'Y': acc_step_var attrs={
}, 'axis': -1,
outputs={'Out': current_step_var}, OP_ROLE_KEY: OpRole.Optimize,
attrs={ 'use_mkldnn': False,
'axis': -1, },
OP_ROLE_KEY: OpRole.Optimize, )
'use_mkldnn': False
})
# cond_var = (step_var == 0) # cond_var = (step_var == 0)
main_block.append_op(type='equal', main_block.append_op(
inputs={ type='equal',
'X': current_step_var, inputs={'X': current_step_var, 'Y': zero_var},
'Y': zero_var outputs={'Out': cond_var},
}, attrs={OP_ROLE_KEY: OpRole.Optimize},
outputs={'Out': cond_var}, )
attrs={OP_ROLE_KEY: OpRole.Optimize})
# paddle.static.Print(current_step_var, message="in FWBW last conditional") # paddle.static.Print(current_step_var, message="in FWBW last conditional")
return cond_var return cond_var
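Stripped of the Program plumbing, the counter built above is an increment with wrap-around, and the optimizer fires whenever it wraps back to zero. A plain-Python equivalent, assuming an accumulation step of 4:

acc_step = 4
current_step = 0
for it in range(8):
    current_step = (current_step + 1) % acc_step    # increment + elementwise_mod
    if current_step == 0:                           # the equal(...) -> cond_var op
        print("run optimize block at iteration", it)  # iterations 3 and 7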
...@@ -1681,7 +1997,7 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1681,7 +1997,7 @@ class ShardingOptimizer(MetaOptimizerBase):
grad@gradientmerge / acc_step grad@gradientmerge / acc_step
re-create all optimize ops of origin main block and rename them re-create all optimize ops of origin main block and rename them
cast(backward) cast(backward)
amp amp
clip clip
opt opt
# fill constant grad@gradientmerge # fill constant grad@gradientmerge
...@@ -1698,35 +2014,37 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1698,35 +2014,37 @@ class ShardingOptimizer(MetaOptimizerBase):
# allreduce grad@gradientmerge # allreduce grad@gradientmerge
if self.hybrid_dp: if self.hybrid_dp:
assert self.dp_ring_id >= 0, "dp_ring_id should larger than 0 when in sharding&DP mode" assert (
self.dp_ring_id >= 0
), "dp_ring_id should larger than 0 when in sharding&DP mode"
for grad, merged_grad in self._grad2merged_grad.items(): for grad, merged_grad in self._grad2merged_grad.items():
merged_grad_var = main_block.var(merged_grad) merged_grad_var = main_block.var(merged_grad)
cur_block.append_op(type='c_allreduce_sum', cur_block.append_op(
inputs={'X': merged_grad_var}, type='c_allreduce_sum',
outputs={'Out': merged_grad_var}, inputs={'X': merged_grad_var},
attrs={ outputs={'Out': merged_grad_var},
'ring_id': self.dp_ring_id, attrs={
'use_calc_stream': True, 'ring_id': self.dp_ring_id,
OP_ROLE_KEY: OpRole.Optimize 'use_calc_stream': True,
}) OP_ROLE_KEY: OpRole.Optimize,
},
)
# grad@gradientmerge / acc_step # grad@gradientmerge / acc_step
for grad, merged_grad in self._grad2merged_grad.items(): for grad, merged_grad in self._grad2merged_grad.items():
# grad /= k_steps # grad /= k_steps
merged_grad_var = main_block.var(merged_grad) merged_grad_var = main_block.var(merged_grad)
cur_block.append_op(type='scale', cur_block.append_op(
inputs={'X': merged_grad_var}, type='scale',
outputs={'Out': merged_grad_var}, inputs={'X': merged_grad_var},
attrs={ outputs={'Out': merged_grad_var},
'scale': attrs={
1.0 / float(self._gradient_merge_acc_step), 'scale': 1.0 / float(self._gradient_merge_acc_step),
'bias': 'bias': 0.0,
0.0, 'bias_after_scale': False,
'bias_after_scale': OP_ROLE_KEY: OpRole.Optimize,
False, },
OP_ROLE_KEY: )
OpRole.Optimize
})
# re-create optimize ops # re-create optimize ops
already_moved_var_names = [] already_moved_var_names = []
...@@ -1737,15 +2055,19 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1737,15 +2055,19 @@ class ShardingOptimizer(MetaOptimizerBase):
for input_name in new_op_desc.input_arg_names(): for input_name in new_op_desc.input_arg_names():
if input_name in self._grad2merged_grad: if input_name in self._grad2merged_grad:
new_op_desc._rename_input( new_op_desc._rename_input(
input_name, self._grad2merged_grad[input_name]) input_name, self._grad2merged_grad[input_name]
)
for output_name in new_op_desc.output_arg_names(): for output_name in new_op_desc.output_arg_names():
if output_name in self._grad2merged_grad: if output_name in self._grad2merged_grad:
new_op_desc._rename_output( new_op_desc._rename_output(
output_name, self._grad2merged_grad[output_name]) output_name, self._grad2merged_grad[output_name]
)
# move non temp optimize vars from block0 to cond block # move non temp optimize vars from block0 to cond block
if output_name not in already_moved_var_names and output_name not in self._grad2merged_grad.keys( if (
output_name not in already_moved_var_names
and output_name not in self._grad2merged_grad.keys()
): ):
var_ = self._main_program.global_block().var(output_name) var_ = self._main_program.global_block().var(output_name)
if not var_.persistable: if not var_.persistable:
...@@ -1754,11 +2076,14 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1754,11 +2076,14 @@ class ShardingOptimizer(MetaOptimizerBase):
shape_ = var_.shape shape_ = var_.shape
type_ = var_.dtype type_ = var_.dtype
self._main_program.global_block()._remove_var( self._main_program.global_block()._remove_var(
var_.name, sync=False) var_.name, sync=False
self.cond_block.create_var(name=name_, )
shape=shape_, self.cond_block.create_var(
dtype=type_, name=name_,
persistable=False) shape=shape_,
dtype=type_,
persistable=False,
)
already_moved_var_names.append(name_) already_moved_var_names.append(name_)
self._main_program.global_block()._sync_with_cpp() self._main_program.global_block()._sync_with_cpp()
...@@ -1767,14 +2092,16 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1767,14 +2092,16 @@ class ShardingOptimizer(MetaOptimizerBase):
# fill zero to grad@gradientmerge # fill zero to grad@gradientmerge
for grad, merged_grad in self._grad2merged_grad.items(): for grad, merged_grad in self._grad2merged_grad.items():
merged_grad_var = main_block.var(merged_grad) merged_grad_var = main_block.var(merged_grad)
cur_block.append_op(type='fill_constant', cur_block.append_op(
outputs={'Out': merged_grad_var}, type='fill_constant',
attrs={ outputs={'Out': merged_grad_var},
"shape": merged_grad_var.shape, attrs={
"dtype": merged_grad_var.dtype, "shape": merged_grad_var.shape,
"value": float(0), "dtype": merged_grad_var.dtype,
OP_ROLE_KEY: OpRole.Optimize "value": float(0),
}) OP_ROLE_KEY: OpRole.Optimize,
},
)
# lr_var = main_block.var("gradient_merge_current_step") # lr_var = main_block.var("gradient_merge_current_step")
# paddle.static.Print(lr_var, message="in OPTIMIZE last conditional") # paddle.static.Print(lr_var, message="in OPTIMIZE last conditional")
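Taken together, the ops appended to the cond block in this function amount to the following per-window sequence, sketched here with assumed helper callables (dp_allreduce, optimizer_step) rather than real Paddle ops:

def run_optimize_window(merged_grads, acc_step, hybrid_dp, dp_allreduce, optimizer_step):
    # merged_grads: name -> accumulated @GradiantMerge value held on this rank
    if hybrid_dp:
        for name in merged_grads:
            dp_allreduce(name)                      # c_allreduce_sum over the dp ring
    for name in merged_grads:
        merged_grads[name] /= float(acc_step)       # scale by 1 / acc_step
    optimizer_step(merged_grads)                    # the re-created optimize ops
    for name in merged_grads:
        merged_grads[name] = 0.0                    # fill_constant(0) for the next window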
...@@ -1786,7 +2113,10 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1786,7 +2113,10 @@ class ShardingOptimizer(MetaOptimizerBase):
create cond block create cond block
""" """
if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1: if (
self.gradient_merge_mode != "sharding_gm"
or self._gradient_merge_acc_step <= 1
):
return return
main_block = self._main_program.global_block() main_block = self._main_program.global_block()
...@@ -1805,7 +2135,8 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1805,7 +2135,8 @@ class ShardingOptimizer(MetaOptimizerBase):
main_block._remove_op(op_idx, sync=False) main_block._remove_op(op_idx, sync=False)
tmp_copy_block._sync_with_cpp() tmp_copy_block._sync_with_cpp()
self.original_optimize_ops_desc = list( self.original_optimize_ops_desc = list(
reversed(self.original_optimize_ops_desc)) reversed(self.original_optimize_ops_desc)
)
# back to block 0 # back to block 0
self._main_program._rollback() self._main_program._rollback()
...@@ -1822,18 +2153,17 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1822,18 +2153,17 @@ class ShardingOptimizer(MetaOptimizerBase):
# cond op # cond op
step_scope = self._main_program.global_block().create_var( step_scope = self._main_program.global_block().create_var(
type=core.VarDesc.VarType.STEP_SCOPES) type=core.VarDesc.VarType.STEP_SCOPES
)
conditional_block_op = self._main_program.global_block().append_op( conditional_block_op = self._main_program.global_block().append_op(
type='conditional_block', type='conditional_block',
inputs={ inputs={
'Cond': cond, 'Cond': cond,
'Input': [], 'Input': [],
}, },
outputs={ outputs={'Out': [], 'Scope': [step_scope]},
'Out': [],
'Scope': [step_scope]
},
attrs={ attrs={
'sub_block': cond_block, 'sub_block': cond_block,
'is_scalar_condition': True, 'is_scalar_condition': True,
}) },
)