Unverified commit 3576e49c, authored by zhaoyingli, committed by GitHub

[AutoParallel] adapt gradient merge pass (#45915)

* adapt gradient merge

* fix op_role

* fix strategy
Parent 369a235d
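For context, gradient merge (gradient accumulation) runs k_steps micro-batches per optimizer update; the pass adapted in this commit rewrites the static program to do this, but the underlying arithmetic is sketched below. This is a framework-agnostic illustration only, and the helper names (micro_batches, grad_fn, apply_fn) are hypothetical, not part of the pass.

def train_with_gradient_merge(micro_batches, grad_fn, apply_fn, k_steps):
    """Accumulate gradients over k_steps micro-batches, then apply them once."""
    acc = None
    for step, batch in enumerate(micro_batches, start=1):
        grad = grad_fn(batch)                       # backward pass on one micro-batch
        acc = grad if acc is None else acc + grad   # merge (accumulate) the gradient
        if step % k_steps == 0:
            apply_fn(acc / k_steps)                 # one optimizer update per k_steps
            acc = None                              # reset the accumulator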
@@ -60,14 +60,10 @@ class Engine:
strategy=None,
user_tuning_config=None):
self.model = model
self.strategy = strategy or fleet.DistributedStrategy()
self.inputs_spec = self._validate_spec(inputs_spec)
self.labels_spec = self._validate_spec(labels_spec)
self.cluster = cluster
if self.cluster is None:
self.cluster = get_default_cluster()
self.strategy = strategy
if self.strategy is None:
self.strategy = fleet.DistributedStrategy()
self.cluster = cluster or get_default_cluster()
self._user_tuning_config = user_tuning_config
self._executor = None
@@ -433,7 +429,7 @@ class Engine:
break
train_logs["step: {:d} "] = step
if lr_scheduler is not None:
if lr_scheduler is not None and step % self.k_steps == 0:
lr_scheduler.step()
try:
train_logs["lr: {:5e} "] = self._lr_optimizer.get_lr()
@@ -551,6 +547,12 @@ class Engine:
epochs=1,
steps_per_epoch=None,
collate_fn=None):
if self.strategy.gradient_merge and batch_size is not None:
assert batch_size % self.k_steps == 0, \
"Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, self.k_steps)
batch_size //= self.k_steps
dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
dist_startup_prog = self._dist_startup_progs[self.mode][self._cur_rank]
dist_context = self._dist_contexts[self.mode]
@@ -612,6 +614,9 @@ class Engine:
def _validate_spec(self, specs):
specs = to_list(specs)
self.k_steps = 1
if self.strategy.gradient_merge:
self.k_steps = self.strategy.gradient_merge_configs['k_steps']
if specs is not None:
for i, spec in enumerate(specs):
assert isinstance(spec, InputSpec)
@@ -619,6 +624,12 @@ class Engine:
raise ValueError(
"Requires Input[{}].name != None, but receive `None` with {}."
.format(i, spec))
if self.k_steps > 1:
shape = list(spec.shape)
assert shape[0] % self.k_steps == 0, \
"Requires batch_size[{}] to be divisible by k_steps[{}].".format(spec.shape[0], self.k_steps)
shape[0] //= self.k_steps
spec.shape = shape
return specs
def _is_local_var(self, var):
......
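Taken together, the Engine changes above keep batch_size as the user-facing global batch while dividing both the dataloader batch size and the leading InputSpec dimension by k_steps, so each forward/backward consumes one micro-batch. A small numeric illustration (all values hypothetical):

batch_size, k_steps = 64, 4                  # global batch and gradient-merge steps
assert batch_size % k_steps == 0, \
    "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, k_steps)
micro_batch_size = batch_size // k_steps     # dataloader yields 16 samples per step
spec_shape = [batch_size, 1024]              # hypothetical InputSpec shape
spec_shape[0] //= k_steps                    # leading dimension becomes 16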
@@ -84,7 +84,7 @@ class AutoParallelizer:
self._need_rank_mapping = os.getenv("PADDLE_NEED_RANK_MAPPING")
self._need_rank_mapping = True if self._need_rank_mapping and \
self._need_rank_mapping.lower() == 'true' else False
self._pass_context = None
# self._pass_context = None
def _remove_distributed_attrs(self, main_program):
suffix = core.kAutoParallelSuffix()
@@ -143,10 +143,11 @@ class AutoParallelizer:
def _apply_optimize(self, main_program, startup_program, params_grads):
optimizer = copy.deepcopy(self._optimizer)
with program_guard(main_program, startup_program):
optimize_ops = copy.deepcopy(
self._optimizer).apply_gradients(params_grads)
optimize_ops = optimizer.apply_gradients(params_grads)
self._dist_context._lr_optimizer = optimizer
# update completion
self._completer = Completer(self._dist_context)
self._completer.complete_update_annotation(main_program)
@@ -165,6 +166,15 @@ class AutoParallelizer:
config)
auto_parallel_sharding_pass.apply([main_program], [startup_program],
self._pass_context)
params_grads = self._pass_context.get_attr("params_grads")
config = copy.deepcopy(self._dist_strategy.sharding_configs)
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["rank_id"] = rank
auto_parallel_clip_pass = new_pass("auto_parallel_grad_clip", config)
auto_parallel_clip_pass.apply([main_program], [startup_program],
self._pass_context)
if self._dist_strategy.gradient_merge:
config = copy.deepcopy(self._dist_strategy.gradient_merge_configs)
......
@@ -230,9 +230,9 @@ class Parallelizer:
config)
auto_parallel_sharding_pass.apply([main_program], [startup_program],
self._pass_context)
params_grads = self._pass_context.get_attr("params_grads")
# GradClip is train-only optimization
if self._mode == "train":
config = copy.deepcopy(self._strategy.sharding_configs)
config["dist_context"] = self._dist_context
......
@@ -442,7 +442,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
inputs = {'X': grads, 'Scale': loss_scaling}
outputs = {'Out': grads, 'FoundInfinite': found_inf}
attrs = {'op_role': OpRole.Backward}
attrs = {'op_role': OpRole.Optimize}
new_op = main_block.append_op(type='check_finite_and_unscale',
inputs=inputs,
outputs=outputs,
@@ -575,18 +575,18 @@ class FP16Pass(AMPPass):
) or self.get_attr("init_loss_scaling") != 1.0:
found_infs = []
if fp32_grads:
with main_program._backward_role_guard():
with main_program._optimized_guard([]):
_, found_inf_fp32 = _check_and_update_gradient(
fp32_grads, self._loss_scaling, "@fp32",
self.dist_context)
found_infs.append(found_inf_fp32)
if fp16_grads:
with main_program._backward_role_guard():
with main_program._optimized_guard([]):
_, found_inf_fp16 = _check_and_update_gradient(
fp16_grads, self._loss_scaling, "@fp16",
self.dist_context)
found_infs.append(found_inf_fp16)
with main_program._backward_role_guard():
with main_program._optimized_guard([]):
block = main_program.global_block()
all_infs = paddle.fluid.layers.concat(found_infs)
@@ -608,7 +608,7 @@ class FP16Pass(AMPPass):
block, self.dist_context)
if self.get_attr("use_dynamic_loss_scaling"):
with main_program._backward_role_guard():
with main_program._optimized_guard([]):
if fp32_grads:
self._update_loss_scaling(fp32_grads, found_inf)
if fp16_grads:
......
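A note on the op_role fix above: wrapping _check_and_update_gradient and the loss-scaling update in _optimized_guard tags the appended ops with OpRole.Optimize rather than OpRole.Backward, consistent with the attrs change at the top of this file's diff. A minimal illustration of the pattern, with names abbreviated from the surrounding code:

# Illustrative only: loss_scaling / dist_context stand in for self._loss_scaling
# and self.dist_context inside the pass.
with main_program._optimized_guard([]):
    _, found_inf_fp32 = _check_and_update_gradient(
        fp32_grads, loss_scaling, "@fp32", dist_context)   # appended op gets OpRole.Optimize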
@@ -207,6 +207,7 @@ class ClipGradByGloblNormPass(PassBase):
super(ClipGradByGloblNormPass, self).__init__()
self.set_attr("rank_id", None)
self.set_attr("dist_context", None)
self.set_attr("params_grads", None)
def _check_self(self):
if self.get_attr("dist_context") is None:
@@ -214,6 +215,8 @@ class ClipGradByGloblNormPass(PassBase):
dist_context = self.get_attr("dist_context")
if dist_context._lr_optimizer._grad_clip is None:
return False
if self.get_attr("params_grads") is None:
return False
return True
def _check_conflict(self, other_pass):
@@ -223,7 +226,8 @@ class ClipGradByGloblNormPass(PassBase):
dist_context = self.get_attr("dist_context", None)
rank_id = self.get_attr("rank_id", None)
block = main_program.global_block()
dist_params_grads = _get_params_grads(block)
dist_params_grads = self.get_attr("params_grads", None)
# dist_params_grads = _get_params_grads(block)
self.clip_helper = ClipHelper(dist_params_grads, rank_id, block,
dist_context)
......
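The grad-clip change above pairs with the sharding pass further down: ClipGradByGloblNormPass now receives the shard-local params_grads that ShardingPass publishes on the pass context, instead of re-deriving them from the block. A rough sketch of the hand-off, reusing the parallelizer-side names from the diff above (illustrative, not a complete driver):

sharding_pass.apply([main_program], [startup_program], pass_context)
params_grads = pass_context.get_attr("params_grads")        # published by ShardingPass

clip_config = copy.deepcopy(dist_strategy.sharding_configs)
clip_config["dist_context"] = dist_context
clip_config["params_grads"] = params_grads                  # consumed by ClipGradByGloblNormPass
clip_config["rank_id"] = rank
clip_pass = new_pass("auto_parallel_grad_clip", clip_config)
clip_pass.apply([main_program], [startup_program], pass_context)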
@@ -55,13 +55,6 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
return optimize_ops_desc
def _remove_op_role_var(param, grad):
op_maker = core.op_proto_and_checker_maker
op = grad.op
if op and op.has_attr(op_maker.kOpRoleVarAttrName()):
op._remove_attr(op_maker.kOpRoleVarAttrName())
def _get_gm_cond_var(main_program, k_steps, dist_context):
main_block = main_program.global_block()
# Add const var
@@ -147,8 +140,6 @@ def _append_gradient_merge_backward_op(
param.type != core.VarDesc.VarType.SELECTED_ROWS
), "SELECTED_ROWS is not supported in GradientMergeOptimizer for now"
_remove_op_role_var(param, grad)
# {grad.name: gradient_merge_var.name} to rename opt inputs
grad_to_gradient_merge = {}
# {param: gradient_merge_var} to insert scale op and fill_constant op
......
@@ -59,6 +59,7 @@ class ShardingPass(PassBase):
self.varname_to_sharding_info = {}
self.partial_sharding = False
self.outer_dp_group = None
self.shared_params_grads = []
def _check_self(self):
if self.get_attr("dist_context") is None:
@@ -94,6 +95,8 @@ class ShardingPass(PassBase):
self._shard_gradient_synchronization(main_block)
self._shard_parameter(main_block, startup_block)
context.set_attr("params_grads", self.shared_params_grads)
def _build_sharding_groups(self, main_block, params_grads):
self._collective_data_parallel_groups(main_block)
self._build_sharding_infos(params_grads)
@@ -148,13 +151,10 @@ class ShardingPass(PassBase):
self._dist_context._sharding_group = sharding_group
# TODO(JZ-LIANG) when support multiple dp groups in future, should group param and bind them to corresponding dp group
params_in_group = [p for p, g in params_grads]
assert len(params_in_group) == len(
set(params_in_group)), "found duplicated param in params_grads"
sharding_info = ShardingInfo(sharding_group, self.global_rank,
params_in_group)
params_grads)
self.sharding_infos.append(sharding_info)
for param in params_in_group:
for param in sharding_info.params:
self.varname_to_sharding_info[param.name] = sharding_info
def _shard_optimizer(self, main_block, startup_block, params_grads,
@@ -201,6 +201,7 @@ class ShardingPass(PassBase):
op.desc.set_output('Out', reversed_x)
else:
if op.type == "check_finite_and_unscale":
op_role = op.attr('op_role')
out_name = op.output_arg_names[0]
out_var = main_block.vars[out_name]
main_block._remove_op(idx, sync=False)
@@ -212,6 +213,7 @@ class ShardingPass(PassBase):
"shape": out_var.shape,
"dtype": out_var.dtype,
"value": 0,
OP_ROLE_KEY: op_role,
})
else:
main_block._remove_op(idx, sync=False)
@@ -313,6 +315,9 @@ class ShardingPass(PassBase):
if varname != param_name
])
main_block._remove_op(idx, sync=False)
else:
self.shared_params_grads.append(
self._get_param_grad(param_name))
for idx, op in reversed(list(enumerate(startup_block.ops))):
if len(op.output_arg_names) == 1 and op.output_arg_names[
@@ -365,6 +370,13 @@ class ShardingPass(PassBase):
sharding_info = self.varname_to_sharding_info[param_name]
return sharding_info.is_in_local_shard(param_name)
def _get_param_grad(self, param_name):
assert param_name in self.varname_to_sharding_info
sharding_info = self.varname_to_sharding_info[param_name]
p_g = sharding_info.get_param_grad(param_name)
assert p_g is not None
return p_g
def _shard_gradient_synchronization(self, main_block):
if self.stage < 2:
@@ -705,9 +717,13 @@ def shard_parameters(params, group_size):
class ShardingInfo(object):
def __init__(self, group, rank, params):
def __init__(self, group, rank, params_grads):
self.group = group
self.params = params
self.params_grads = dict([(p.name, (p, g)) for p, g in params_grads])
assert len(self.params_grads) == len(set(
self.params_grads)), "found duplicated param in params_grads"
self.params = [p for p, _ in params_grads]
self.param_names = [p.name for p in self.params]
self.group_size = group.nranks
self.global_rank = rank
@@ -762,3 +778,11 @@ class ShardingInfo(object):
if usage > 0:
broadcast_vars.add(param)
return broadcast_vars, param_usage
def get_param_grad(self, param_name):
if not self.is_in_local_shard(param_name):
raise ValueError(
"param[{}] not in current rank.".format(param_name))
if param_name not in self.params_grads:
raise ValueError('param[{}] not in params_grads'.format(param_name))
return self.params_grads.get(param_name, None)
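Illustrative use of the new ShardingInfo bookkeeping above (the parameter name and the group/rank/params_grads objects are assumed to exist already):

info = ShardingInfo(group, rank, params_grads)   # stores {p.name: (p, g)} internally
name = "linear_0.w_0"                            # hypothetical parameter name
if name in info.param_names and info.is_in_local_shard(name):
    param, grad = info.get_param_grad(name)      # raises ValueError for off-shard params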
@@ -178,6 +178,7 @@ class AutoPallelPassTestBase(DistPassTestBase):
preds = model(tokens, position_ids, attention_mask)
criterion = GPTPretrainingCriterion()
loss = criterion(preds, labels, loss_mask)
clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)
if kwargs.get('optimizer', None) == "LarsMomentum":
optimizer = paddle.fluid.optimizer.LarsMomentumOptimizer(
@@ -188,7 +189,7 @@ class AutoPallelPassTestBase(DistPassTestBase):
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
grad_clip=clip)
optimizer = fleet.distributed_optimizer(optimizer)
startup_program = paddle.static.default_startup_program()
_, _, dist_startup_prog, dist_main_prog = optimizer.minimize(
......