Unverified commit 496422e9 authored by zhaoyingli and committed by GitHub

make params_grads order same between dynamic and auto_parallel (#56126)

* make params_grads order same between dynamic and static mode

* revert inplace clip

* use sorted attribute to control

* tiny fix

* fix find loss_grad_op
Parent 70a009ee
@@ -133,6 +133,9 @@ class Engine:
"'model must be sub classes of `paddle.nn.Layer` or any callable function."
)
self._model = model
self._parameter_list = (
None if not model else [p.name for p in model.parameters()]
)
if (
loss
@@ -765,9 +768,9 @@ class Engine:
self._dist_contexts[mode],
)
if not all_ranks:
parallelizer.parallel(self._cur_rank)
parallelizer.parallel(self._cur_rank, self._parameter_list)
else:
parallelizer.parallel_all()
parallelizer.parallel_all(self._parameter_list)
def _init_dist_context(self, mode):
# Init dist_context['mode'] with the first planned dist_context
......
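For context on the Engine change above, a small hypothetical dygraph model (TinyNet below is made up, not part of this PR) shows what the new self._parameter_list captures: parameter names in the definition order that dynamic mode uses, which later reach append_backward as parameter_list.

import paddle

class TinyNet(paddle.nn.Layer):  # hypothetical model, for illustration only
    def __init__(self):
        super().__init__()
        self.fc1 = paddle.nn.Linear(4, 4)
        self.fc2 = paddle.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc2(self.fc1(x))

model = TinyNet()
# Same expression the Engine now stores as self._parameter_list:
parameter_list = None if not model else [p.name for p in model.parameters()]
print(parameter_list)  # e.g. ['linear_0.w_0', 'linear_0.b_0', 'linear_1.w_0', 'linear_1.b_0']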
@@ -46,15 +46,15 @@ class Parallelizer:
def is_test(self):
return self._mode in ["eval", "predict"]
def parallel_all(self):
def parallel_all(self, parameter_list=None):
world_process_group = get_world_process_group()
all_ranks = world_process_group.ranks
for rank in all_ranks:
# self._dist_context._backup(serial=True, dist=True)
self.parallel(rank)
self.parallel(rank, parameter_list)
# self._dist_context._restore(serial=True, dist=True)
def parallel(self, rank):
def parallel(self, rank, parameter_list=None):
serial_main_program = self._dist_context.serial_main_program
serial_startup_program = self._dist_context.serial_startup_program
serial_optimizer = self._dist_context.serial_optimizer
@@ -62,7 +62,10 @@ class Parallelizer:
# Generate backward
serial_loss = self._dist_context.serial_loss
params_grads = self._generate_backward(
serial_main_program, serial_startup_program, serial_loss
serial_main_program,
serial_startup_program,
serial_loss,
parameter_list,
)
# Apply pre optimization passes
time0 = time.time()
@@ -211,10 +214,20 @@ class Parallelizer:
self._dist_context.dist_main_programs[rank] = dist_main_prog
self._dist_context.dist_startup_programs[rank] = dist_startup_prog
def _generate_backward(self, main_program, startup_program, loss):
def _generate_backward(
self, main_program, startup_program, loss, parameter_list=None
):
# NOTE(zhaoyinglia):
# Guarantee that the order of params_grads is the same between dynamic mode and static mode
# by making parameter_list equal to model.parameters(),
# because the order affects the result of ClipGradByGlobalNorm.
# If parameter_list is not None, params_grads follows the order of parameter_list.
# If parameter_list is None, params_grads follows the order of prog.global_block().all_parameters().
with program_guard(main_program, startup_program):
params_grads = append_backward(
loss, distop_context=self._dist_context.dist_op_context
loss,
parameter_list=parameter_list,
distop_context=self._dist_context.dist_op_context,
)
self._completer.complete_backward_annotation(main_program)
self._dist_context.block_state.parse_backward_blocks(main_program)
@@ -231,6 +244,7 @@ class Parallelizer:
optimizer = copy.deepcopy(optimizer)
self._dist_context._serial_optimizer = optimizer
self._dist_context._serial_optimizer._learning_rate = learning_rate
optimizer._sorted = False
with program_guard(main_program, startup_program):
with unique_name.guard("opt_"):
......
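To make the NOTE above concrete: a global-norm clip accumulates squared gradient norms in params_grads order, and float32 addition is not associative, so two different orderings of the same gradients can yield global norms that differ in the last bits and therefore slightly different clipped gradients. The sketch below uses NumPy and made-up gradients; it is independent of Paddle's ClipGradByGlobalNorm implementation.

import numpy as np

# Made-up gradients standing in for params_grads.
rng = np.random.default_rng(2023)
grads = [rng.standard_normal(257).astype(np.float32) for _ in range(64)]

def global_norm(gs):
    # Accumulate squared norms in the given order, in float32,
    # roughly the way a global-norm clip walks params_grads.
    total = np.float32(0.0)
    for g in gs:
        total += np.square(g).sum(dtype=np.float32)
    return float(np.sqrt(total))

norm_fwd = global_norm(grads)        # order coming from model.parameters()
norm_rev = global_norm(grads[::-1])  # a differently ordered params_grads
print(norm_fwd, norm_rev)            # the two values can disagree in the last bits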
@@ -1053,11 +1053,16 @@ class AMPPass(PassBase):
)
# backward
first_backward_op = main_block.ops[loss_op_idx + 2]
assert (
first_backward_op.type == "fill_constant"
and int(first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
)
first_backward_op = None
for op in main_block.ops[loss_op_idx:]:
if op.type == "fill_constant" and is_loss_grad_op(op):
first_backward_op = op
break
if is_backward_op(op):
break
assert first_backward_op is not None, "There is no loss_grad op."
scaled_loss_grad = main_block.create_var(
name=unique_name.generate("scaled_loss") + "@GRAD",
shape=loss.shape,
......
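For context on the AMP pass change above: the removed assertion hard-coded the op-role value 257, the bitwise OR of the Backward (1) and Loss (256) bits in Paddle's op-role flags, and it only held when the loss-grad fill_constant op sat exactly two positions after the loss op. The new loop searches for that op by role instead. A rough sketch of the flag arithmetic follows; the constants come from the old assertion and the helper name is made up, not Paddle's is_loss_grad_op.

# Op roles are bit flags; 257 in the removed assertion is Backward | Loss.
OP_ROLE_BACKWARD = 0x0001
OP_ROLE_LOSS = 0x0100

LOSS_GRAD_ROLE = OP_ROLE_BACKWARD | OP_ROLE_LOSS
print(LOSS_GRAD_ROLE)  # 257

def looks_like_loss_grad_op(op_type, op_role):
    # Rough stand-in for is_loss_grad_op(): a fill_constant op whose role
    # carries both the Backward and the Loss bits.
    return op_type == "fill_constant" and (op_role & LOSS_GRAD_ROLE) == LOSS_GRAD_ROLE

print(looks_like_loss_grad_op("fill_constant", 257))  # True
print(looks_like_loss_grad_op("fill_constant", 1))    # False: plain backward op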
@@ -802,7 +802,7 @@ class FP16Pass(AMPPass):
cast_startup_program()
if is_train:
if self.target_dtype == "fp16":
if self.target_dtype == "float16":
with paddle.static.program_guard(main_program, startup_program):
# TODO (JZ-LIANG)support cast forward program only when inference
self._init_amp_var()
@@ -870,7 +870,7 @@ class FP16Pass(AMPPass):
# found_inf = paddle.fluid.layers.reduce_any(all_infs)
found_inf = block.create_var(
name=paddle.utils.unique_name.generate_with_ignorable_key(
".".join(['reduce_any', 'tmp'])
".".join(['find_infinite_scale', 'tmp'])
),
dtype=all_infs.dtype,
shape=None,
@@ -915,7 +915,7 @@ class FP16Pass(AMPPass):
if self.use_optimizer_fp16:
base_opt._multi_precision = False
if self.target_dtype == "fp16":
if self.target_dtype == "float16":
if isinstance(
base_opt, (paddle.optimizer.Adam, paddle.optimizer.AdamW)
):
......
@@ -1258,8 +1258,9 @@ class Optimizer:
>>> optimizer.apply_gradients(params_grads)
"""
params_grads = sorted(params_grads, key=lambda x: x[0].name)
# NOTE(zhaoyinglia): AutoParallel set '_sorted' attribute to skip the 'sorted' operator.
if not hasattr(self, "_sorted"):
params_grads = sorted(params_grads, key=lambda x: x[0].name)
# 'optimizer(grad_clip)' or 'set_gradient_clip'
if self._grad_clip is not None:
......
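Finally, a self-contained sketch of how the new guard in apply_gradients behaves (Opt and P below are stand-in classes, not Paddle's Optimizer): by default params_grads is still sorted by parameter name, while an optimizer that the auto-parallel Parallelizer has tagged with a _sorted attribute keeps the order that came from parameter_list.

class P:  # minimal stand-in for a named parameter
    def __init__(self, name):
        self.name = name

class Opt:  # illustrative stand-in for the apply_gradients guard
    def apply_gradients(self, params_grads):
        # Sort by parameter name unless a caller opted out via '_sorted'.
        if not hasattr(self, "_sorted"):
            params_grads = sorted(params_grads, key=lambda x: x[0].name)
        return params_grads

pg = [(P("linear_1.w_0"), "grad_1"), (P("embedding_0.w_0"), "grad_0")]

print([p.name for p, _ in Opt().apply_gradients(pg)])  # sorted by name

opt = Opt()
opt._sorted = False  # what the auto-parallel Parallelizer does to its copied optimizer (see the diff above)
print([p.name for p, _ in opt.apply_gradients(pg)])    # original params_grads order kept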