Unverified commit 3576e49c, authored by zhaoyingli, committed by GitHub

[AutoParallel] adapt gradient merge pass (#45915)

* adapt gradient merge

* fix op_role

* fix strategy
Parent 369a235d
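This commit adapts the auto-parallel gradient merge pass: the Engine derives k_steps from the distributed strategy, splits the user-facing batch across micro-steps, and the sharding, grad-clip, and FP16 passes now exchange params_grads and op roles consistently. As a minimal sketch (not part of this commit) of how the strategy fields read below are typically populated; the 'avg' key is an assumption about the default config:

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.gradient_merge = True
    # 'k_steps' is what Engine._validate_spec reads below; 'avg' is assumed here.
    strategy.gradient_merge_configs = {"k_steps": 4, "avg": True}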
@@ -60,14 +60,10 @@ class Engine:
                  strategy=None,
                  user_tuning_config=None):
         self.model = model
+        self.strategy = strategy or fleet.DistributedStrategy()
         self.inputs_spec = self._validate_spec(inputs_spec)
         self.labels_spec = self._validate_spec(labels_spec)
-        self.cluster = cluster
-        if self.cluster is None:
-            self.cluster = get_default_cluster()
-        self.strategy = strategy
-        if self.strategy is None:
-            self.strategy = fleet.DistributedStrategy()
+        self.cluster = cluster or get_default_cluster()
         self._user_tuning_config = user_tuning_config
         self._executor = None
@@ -433,7 +429,7 @@ class Engine:
                     break
                 train_logs["step: {:d} "] = step
-                if lr_scheduler is not None:
+                if lr_scheduler is not None and step % self.k_steps == 0:
                     lr_scheduler.step()
                 try:
                     train_logs["lr: {:5e} "] = self._lr_optimizer.get_lr()
@@ -551,6 +547,12 @@ class Engine:
                           epochs=1,
                           steps_per_epoch=None,
                           collate_fn=None):
+        if self.strategy.gradient_merge and batch_size is not None:
+            assert batch_size % self.k_steps == 0, \
+                "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, self.k_steps)
+            batch_size //= self.k_steps
         dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
         dist_startup_prog = self._dist_startup_progs[self.mode][self._cur_rank]
         dist_context = self._dist_contexts[self.mode]
@@ -612,6 +614,9 @@ class Engine:
     def _validate_spec(self, specs):
         specs = to_list(specs)
+        self.k_steps = 1
+        if self.strategy.gradient_merge:
+            self.k_steps = self.strategy.gradient_merge_configs['k_steps']
         if specs is not None:
             for i, spec in enumerate(specs):
                 assert isinstance(spec, InputSpec)
@@ -619,6 +624,12 @@ class Engine:
                     raise ValueError(
                         "Requires Input[{}].name != None, but receive `None` with {}."
                         .format(i, spec))
+                if self.k_steps > 1:
+                    shape = list(spec.shape)
+                    assert shape[0] % self.k_steps == 0, \
+                        "Requires batch_size[{}] to be divisible by k_steps[{}].".format(spec.shape[0], self.k_steps)
+                    shape[0] //= self.k_steps
+                    spec.shape = shape
         return specs

     def _is_local_var(self, var):
......
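The Engine now divides both the dataloader batch size and the leading dimension of each InputSpec by k_steps, so every micro-batch carries batch_size / k_steps samples while the optimizer still applies the gradient accumulated over the full batch. A small numeric sketch of the divisibility check and split (illustrative arithmetic only; the spec shape is hypothetical):

    # Mirrors the checks added above; plain Python, no Paddle required.
    batch_size, k_steps = 64, 4
    assert batch_size % k_steps == 0, \
        "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, k_steps)
    micro_batch_size = batch_size // k_steps     # 16 samples per micro-step
    spec_shape = [batch_size, 512]               # hypothetical InputSpec shape
    spec_shape[0] //= k_steps                    # becomes [16, 512]
    print(micro_batch_size, spec_shape)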
@@ -84,7 +84,7 @@ class AutoParallelizer:
         self._need_rank_mapping = os.getenv("PADDLE_NEED_RANK_MAPPING")
         self._need_rank_mapping = True if self._need_rank_mapping and \
             self._need_rank_mapping.lower() == 'true' else False
-        self._pass_context = None
+        # self._pass_context = None

     def _remove_distributed_attrs(self, main_program):
         suffix = core.kAutoParallelSuffix()
@@ -143,10 +143,11 @@ class AutoParallelizer:
     def _apply_optimize(self, main_program, startup_program, params_grads):

+        optimizer = copy.deepcopy(self._optimizer)
         with program_guard(main_program, startup_program):
-            optimize_ops = copy.deepcopy(
-                self._optimizer).apply_gradients(params_grads)
+            optimize_ops = optimizer.apply_gradients(params_grads)

+        self._dist_context._lr_optimizer = optimizer
         # update completion
         self._completer = Completer(self._dist_context)
         self._completer.complete_update_annotation(main_program)
@@ -165,6 +166,15 @@ class AutoParallelizer:
                                                    config)
             auto_parallel_sharding_pass.apply([main_program], [startup_program],
                                               self._pass_context)
+            params_grads = self._pass_context.get_attr("params_grads")
+
+            config = copy.deepcopy(self._dist_strategy.sharding_configs)
+            config["dist_context"] = self._dist_context
+            config["params_grads"] = params_grads
+            config["rank_id"] = rank
+            auto_parallel_clip_pass = new_pass("auto_parallel_grad_clip", config)
+            auto_parallel_clip_pass.apply([main_program], [startup_program],
+                                          self._pass_context)

         if self._dist_strategy.gradient_merge:
             config = copy.deepcopy(self._dist_strategy.gradient_merge_configs)
......
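The grad-clip pass is now chained right after sharding, and the sharded (param, grad) list is handed over through the pass context instead of being rebuilt from the block. A toy sketch of that producer/consumer handoff; the PassContext class here is a stand-in written for the example, not Paddle's implementation:

    # Stand-in PassContext to illustrate the handoff pattern.
    class PassContext:
        def __init__(self):
            self._attrs = {}

        def set_attr(self, key, value):
            self._attrs[key] = value

        def get_attr(self, key, default=None):
            return self._attrs.get(key, default)

    ctx = PassContext()
    # The sharding pass publishes its local (param, grad) pairs ...
    ctx.set_attr("params_grads", [("w0", "w0@GRAD"), ("w1", "w1@GRAD")])
    # ... and the grad-clip pass consumes them through its config.
    config = {"params_grads": ctx.get_attr("params_grads"), "rank_id": 0}
    print(config)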
@@ -230,9 +230,9 @@ class Parallelizer:
                                                    config)
             auto_parallel_sharding_pass.apply([main_program], [startup_program],
                                               self._pass_context)
+            params_grads = self._pass_context.get_attr("params_grads")

         # GradClip is train-only optimization
         if self._mode == "train":
             config = copy.deepcopy(self._strategy.sharding_configs)
             config["dist_context"] = self._dist_context
......
@@ -442,7 +442,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
     inputs = {'X': grads, 'Scale': loss_scaling}
     outputs = {'Out': grads, 'FoundInfinite': found_inf}
-    attrs = {'op_role': OpRole.Backward}
+    attrs = {'op_role': OpRole.Optimize}
     new_op = main_block.append_op(type='check_finite_and_unscale',
                                   inputs=inputs,
                                   outputs=outputs,
@@ -575,18 +575,18 @@ class FP16Pass(AMPPass):
             ) or self.get_attr("init_loss_scaling") != 1.0:
                 found_infs = []
                 if fp32_grads:
-                    with main_program._backward_role_guard():
+                    with main_program._optimized_guard([]):
                         _, found_inf_fp32 = _check_and_update_gradient(
                             fp32_grads, self._loss_scaling, "@fp32",
                             self.dist_context)
                     found_infs.append(found_inf_fp32)
                 if fp16_grads:
-                    with main_program._backward_role_guard():
+                    with main_program._optimized_guard([]):
                         _, found_inf_fp16 = _check_and_update_gradient(
                             fp16_grads, self._loss_scaling, "@fp16",
                             self.dist_context)
                     found_infs.append(found_inf_fp16)
-                with main_program._backward_role_guard():
+                with main_program._optimized_guard([]):
                     block = main_program.global_block()
                     all_infs = paddle.fluid.layers.concat(found_infs)
@@ -608,7 +608,7 @@ class FP16Pass(AMPPass):
                         block, self.dist_context)
                 if self.get_attr("use_dynamic_loss_scaling"):
-                    with main_program._backward_role_guard():
+                    with main_program._optimized_guard([]):
                         if fp32_grads:
                             self._update_loss_scaling(fp32_grads, found_inf)
                         if fp16_grads:
......
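The loss-scaling bookkeeping ops (check_finite_and_unscale, update_loss_scaling) are re-tagged from the Backward role to the Optimize role, presumably so that role-based passes such as gradient merge treat them as part of the once-every-k-steps update rather than the per-micro-batch backward. A toy role filter showing the effect; the enum values are made up for the example and are not Paddle's real OpRole constants:

    from enum import Enum

    # Toy stand-in for op roles; values are illustrative only.
    class OpRole(Enum):
        Forward = 0
        Backward = 1
        Optimize = 2

    ops = [("elementwise_add_grad", OpRole.Backward),
           ("check_finite_and_unscale", OpRole.Optimize),   # after this commit
           ("update_loss_scaling", OpRole.Optimize),
           ("adam", OpRole.Optimize)]

    every_micro_batch = [n for n, r in ops if r != OpRole.Optimize]
    every_k_steps = [n for n, r in ops if r == OpRole.Optimize]
    print(every_micro_batch)   # replayed each micro-step
    print(every_k_steps)       # deferred to the accumulated update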
@@ -207,6 +207,7 @@ class ClipGradByGloblNormPass(PassBase):
         super(ClipGradByGloblNormPass, self).__init__()
         self.set_attr("rank_id", None)
         self.set_attr("dist_context", None)
+        self.set_attr("params_grads", None)

     def _check_self(self):
         if self.get_attr("dist_context") is None:
@@ -214,6 +215,8 @@ class ClipGradByGloblNormPass(PassBase):
         dist_context = self.get_attr("dist_context")
         if dist_context._lr_optimizer._grad_clip is None:
             return False
+        if self.get_attr("params_grads") is None:
+            return False
         return True

     def _check_conflict(self, other_pass):
@@ -223,7 +226,8 @@ class ClipGradByGloblNormPass(PassBase):
         dist_context = self.get_attr("dist_context", None)
         rank_id = self.get_attr("rank_id", None)
         block = main_program.global_block()
-        dist_params_grads = _get_params_grads(block)
+        dist_params_grads = self.get_attr("params_grads", None)
+        # dist_params_grads = _get_params_grads(block)
         self.clip_helper = ClipHelper(dist_params_grads, rank_id, block,
                                       dist_context)
......
@@ -55,13 +55,6 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
     return optimize_ops_desc


-def _remove_op_role_var(param, grad):
-    op_maker = core.op_proto_and_checker_maker
-    op = grad.op
-    if op and op.has_attr(op_maker.kOpRoleVarAttrName()):
-        op._remove_attr(op_maker.kOpRoleVarAttrName())
-
-
 def _get_gm_cond_var(main_program, k_steps, dist_context):
     main_block = main_program.global_block()
     # Add const var
@@ -147,8 +140,6 @@ def _append_gradient_merge_backward_op(
             param.type != core.VarDesc.VarType.SELECTED_ROWS
         ), "SELECTED_ROWS is not supported in GradientMergeOptimizer for now"

-        _remove_op_role_var(param, grad)
-
     # {grad.name: gradient_merge_var.name} to rename opt inputs
     grad_to_gradient_merge = {}
     # {param: gradient_merge_var} to insert scale op and fill_constant op
......
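With the op-role fix above, the gradient merge pass no longer needs to strip the op_role_var attribute from backward ops; the pass itself still follows the usual accumulate-then-apply scheme hinted at by the comments above (a scale op for averaging, a fill_constant op to reset the buffer). A toy numeric sketch of what the generated graph computes, with made-up gradient values:

    # Accumulate for k_steps micro-batches, average, apply, reset.
    k_steps = 4
    micro_grads = [0.5, 1.5, 2.0, 4.0, 1.0, 1.0, 1.0, 1.0]

    acc = 0.0
    for step, g in enumerate(micro_grads, start=1):
        acc += g                        # gradient_merge_var += grad
        if step % k_steps == 0:
            avg_grad = acc / k_steps    # scale op (when averaging is enabled)
            print("apply update with grad =", avg_grad)
            acc = 0.0                   # fill_constant resets the accumulator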
@@ -59,6 +59,7 @@ class ShardingPass(PassBase):
         self.varname_to_sharding_info = {}
         self.partial_sharding = False
         self.outer_dp_group = None
+        self.shared_params_grads = []

     def _check_self(self):
         if self.get_attr("dist_context") is None:
@@ -94,6 +95,8 @@ class ShardingPass(PassBase):
         self._shard_gradient_synchronization(main_block)
         self._shard_parameter(main_block, startup_block)

+        context.set_attr("params_grads", self.shared_params_grads)
+
     def _build_sharding_groups(self, main_block, params_grads):
         self._collective_data_parallel_groups(main_block)
         self._build_sharding_infos(params_grads)
@@ -148,13 +151,10 @@ class ShardingPass(PassBase):
             self._dist_context._sharding_group = sharding_group
             # TODO(JZ-LIANG) when support multiple dp groups in future, should group param and bind them to corresponding dp group
-            params_in_group = [p for p, g in params_grads]
-            assert len(params_in_group) == len(
-                set(params_in_group)), "found duplicated param in params_grads"
             sharding_info = ShardingInfo(sharding_group, self.global_rank,
-                                         params_in_group)
+                                         params_grads)
             self.sharding_infos.append(sharding_info)
-            for param in params_in_group:
+            for param in sharding_info.params:
                 self.varname_to_sharding_info[param.name] = sharding_info

     def _shard_optimizer(self, main_block, startup_block, params_grads,
@@ -201,6 +201,7 @@ class ShardingPass(PassBase):
                     op.desc.set_output('Out', reversed_x)
             else:
                 if op.type == "check_finite_and_unscale":
+                    op_role = op.attr('op_role')
                     out_name = op.output_arg_names[0]
                     out_var = main_block.vars[out_name]
                     main_block._remove_op(idx, sync=False)
@@ -212,6 +213,7 @@ class ShardingPass(PassBase):
                             "shape": out_var.shape,
                             "dtype": out_var.dtype,
                             "value": 0,
+                            OP_ROLE_KEY: op_role,
                         })
                 else:
                     main_block._remove_op(idx, sync=False)
@@ -313,6 +315,9 @@ class ShardingPass(PassBase):
                         if varname != param_name
                     ])
                 main_block._remove_op(idx, sync=False)
+            else:
+                self.shared_params_grads.append(
+                    self._get_param_grad(param_name))

         for idx, op in reversed(list(enumerate(startup_block.ops))):
             if len(op.output_arg_names) == 1 and op.output_arg_names[
@@ -365,6 +370,13 @@ class ShardingPass(PassBase):
         sharding_info = self.varname_to_sharding_info[param_name]
         return sharding_info.is_in_local_shard(param_name)

+    def _get_param_grad(self, param_name):
+        assert param_name in self.varname_to_sharding_info
+        sharding_info = self.varname_to_sharding_info[param_name]
+        p_g = sharding_info.get_param_grad(param_name)
+        assert p_g is not None
+        return p_g
+
     def _shard_gradient_synchronization(self, main_block):
         if self.stage < 2:
@@ -705,9 +717,13 @@ def shard_parameters(params, group_size):
 class ShardingInfo(object):

-    def __init__(self, group, rank, params):
+    def __init__(self, group, rank, params_grads):
         self.group = group
-        self.params = params
+        self.params_grads = dict([(p.name, (p, g)) for p, g in params_grads])
+        assert len(self.params_grads) == len(set(
+            self.params_grads)), "found duplicated param in params_grads"
+        self.params = [p for p, _ in params_grads]
         self.param_names = [p.name for p in self.params]
         self.group_size = group.nranks
         self.global_rank = rank
@@ -762,3 +778,11 @@ class ShardingInfo(object):
             if usage > 0:
                 broadcast_vars.add(param)
         return broadcast_vars, param_usage
+
+    def get_param_grad(self, param_name):
+        if not self.is_in_local_shard(param_name):
+            raise ValueError(
+                "param[{}] not in current rank.".format(param_name))
+        if param_name not in self.params_grads:
+            raise ValueError('param[{}] not in params_grads'.format(param_name))
+        return self.params_grads.get(param_name, None)
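ShardingInfo now keeps the full (param, grad) pairs keyed by parameter name, so each rank can recover exactly the grads it owns and hand them to the grad-clip pass. A minimal stand-in sketch of that bookkeeping, using plain strings instead of Paddle variables and an assumed local shard:

    # Strings stand in for Paddle parameters/gradients; 'w0' is assumed local.
    params_grads = [("w0", "w0@GRAD"), ("w1", "w1@GRAD")]
    by_name = dict((p, (p, g)) for p, g in params_grads)
    assert len(by_name) == len(params_grads), "found duplicated param in params_grads"
    local_param_names = ["w0"]                     # names this rank is assumed to own
    local_pairs = [by_name[name] for name in local_param_names]
    print(local_pairs)                             # [('w0', 'w0@GRAD')]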
@@ -178,6 +178,7 @@ class AutoPallelPassTestBase(DistPassTestBase):
             preds = model(tokens, position_ids, attention_mask)
             criterion = GPTPretrainingCriterion()
             loss = criterion(preds, labels, loss_mask)
+            clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)

             if kwargs.get('optimizer', None) == "LarsMomentum":
                 optimizer = paddle.fluid.optimizer.LarsMomentumOptimizer(
@@ -188,7 +189,7 @@ class AutoPallelPassTestBase(DistPassTestBase):
                     beta1=0.9,
                     beta2=0.999,
                     epsilon=1e-08,
-                    grad_clip=None)
+                    grad_clip=clip)
             optimizer = fleet.distributed_optimizer(optimizer)
             startup_program = paddle.static.default_startup_program()
             _, _, dist_startup_prog, dist_main_prog = optimizer.minimize(
......
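The test base now attaches a ClipGradByNorm to the optimizer so the new auto_parallel_grad_clip wiring is actually exercised. A standalone sketch of the same kind of configuration; the Linear layer and learning rate are only there to make the snippet runnable and are not taken from the test:

    import paddle

    # Hypothetical layer so the optimizer has parameters to manage.
    linear = paddle.nn.Linear(16, 16)
    clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)
    optimizer = paddle.optimizer.Adam(learning_rate=0.001,
                                      beta1=0.9,
                                      beta2=0.999,
                                      epsilon=1e-08,
                                      grad_clip=clip,
                                      parameters=linear.parameters())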