Unverified commit 496422e9, authored by zhaoyingli, committed by GitHub

make params_grads order same between dynamic and auto_parallel (#56126)

* make params_grads order same between dynamic and static mode

* revert inplace clip

* use sorted attribute to control

* tiny fix

* fix find loss_grad_op
Parent 70a009ee
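For background on why the ordering matters at all: ClipGradByGlobalNorm accumulates squared gradient values, and floating-point summation is not associative, so feeding the same gradients in a different order can shift the clipped result by a few ULPs. A minimal NumPy sketch of that effect, independent of the commit (array size and seed are arbitrary):

import numpy as np

rng = np.random.default_rng(0)
squares = rng.standard_normal(10000).astype(np.float32) ** 2

total_fwd = np.float32(0.0)
for v in squares:          # accumulate in one order
    total_fwd += v

total_rev = np.float32(0.0)
for v in squares[::-1]:    # same values, reversed order
    total_rev += v

# The two "global norms" typically differ in the last bits.
print(np.sqrt(total_fwd), np.sqrt(total_rev), total_fwd == total_rev)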
@@ -133,6 +133,9 @@ class Engine:
                     "'model must be sub classes of `paddle.nn.Layer` or any callable function."
                 )
         self._model = model
+        self._parameter_list = (
+            None if not model else [p.name for p in model.parameters()]
+        )

         if (
             loss
@@ -765,9 +768,9 @@ class Engine:
             self._dist_contexts[mode],
         )
         if not all_ranks:
-            parallelizer.parallel(self._cur_rank)
+            parallelizer.parallel(self._cur_rank, self._parameter_list)
         else:
-            parallelizer.parallel_all()
+            parallelizer.parallel_all(self._parameter_list)

     def _init_dist_context(self, mode):
         # Init dist_context['mode'] with the first planned dist_context
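The new `self._parameter_list` just records parameter names in the order `model.parameters()` yields them, i.e. the dynamic-mode declaration order. A small standalone sketch of that behavior (the layer below is illustrative, not from the commit):

import paddle

class TinyNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc_b = paddle.nn.Linear(8, 4)   # declared first
        self.fc_a = paddle.nn.Linear(4, 2)   # declared second

    def forward(self, x):
        return self.fc_a(self.fc_b(x))

model = TinyNet()
# parameters() follows declaration order, not alphabetical order,
# so fc_b's weight and bias come before fc_a's.
print([p.name for p in model.parameters()])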
@@ -46,15 +46,15 @@ class Parallelizer:
     def is_test(self):
         return self._mode in ["eval", "predict"]

-    def parallel_all(self):
+    def parallel_all(self, parameter_list=None):
         world_process_group = get_world_process_group()
         all_ranks = world_process_group.ranks
         for rank in all_ranks:
             # self._dist_context._backup(serial=True, dist=True)
-            self.parallel(rank)
+            self.parallel(rank, parameter_list)
             # self._dist_context._restore(serial=True, dist=True)

-    def parallel(self, rank):
+    def parallel(self, rank, parameter_list=None):
         serial_main_program = self._dist_context.serial_main_program
         serial_startup_program = self._dist_context.serial_startup_program
         serial_optimizer = self._dist_context.serial_optimizer
@@ -62,7 +62,10 @@ class Parallelizer:
             # Generate backward
             serial_loss = self._dist_context.serial_loss
             params_grads = self._generate_backward(
-                serial_main_program, serial_startup_program, serial_loss
+                serial_main_program,
+                serial_startup_program,
+                serial_loss,
+                parameter_list,
             )
             # Apply pre optimization passes
             time0 = time.time()
@@ -211,10 +214,20 @@ class Parallelizer:
         self._dist_context.dist_main_programs[rank] = dist_main_prog
         self._dist_context.dist_startup_programs[rank] = dist_startup_prog

-    def _generate_backward(self, main_program, startup_program, loss):
+    def _generate_backward(
+        self, main_program, startup_program, loss, parameter_list=None
+    ):
+        # NOTE(zhaoyinglia):
+        # Guarantee the order of params_grads is the same between dynamic mode and static mode
+        # by making parameter_list equal to model.parameters(),
+        # because the order affects the result of ClipGradByGlobalNorm.
+        # If parameter_list is not None, the order of params_grads is the same as parameter_list.
+        # If parameter_list is None, params_grads follows prog.global_block().all_parameters().
         with program_guard(main_program, startup_program):
             params_grads = append_backward(
-                loss, distop_context=self._dist_context.dist_op_context
+                loss,
+                parameter_list=parameter_list,
+                distop_context=self._dist_context.dist_op_context,
             )
         self._completer.complete_backward_annotation(main_program)
         self._dist_context.block_state.parse_backward_blocks(main_program)
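The NOTE above relies on `paddle.static.append_backward` honoring the order of `parameter_list` when one is supplied. A rough standalone sketch of that contract (the parameter names and shapes are made up, and the real call in the pass also passes `distop_context`):

import paddle

paddle.enable_static()

def params_grads_order(parameter_list=None):
    main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.static.data(name="x", shape=[None, 8], dtype="float32")
        # Created in this order, so the block's own parameter list is [w1, w2].
        w1 = paddle.static.create_parameter(shape=[8, 4], dtype="float32", name="w1")
        w2 = paddle.static.create_parameter(shape=[4, 2], dtype="float32", name="w2")
        loss = paddle.mean(paddle.matmul(paddle.matmul(x, w1), w2))
        params_grads = paddle.static.append_backward(
            loss, parameter_list=parameter_list
        )
    return [p.name for p, _ in params_grads]

# parameter_list=None: order comes from the program's own parameters.
print(params_grads_order())
# Explicit names (what Engine now forwards): params_grads follows that order.
print(params_grads_order(parameter_list=["w2", "w1"]))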
@@ -231,6 +244,7 @@ class Parallelizer:
         optimizer = copy.deepcopy(optimizer)
         self._dist_context._serial_optimizer = optimizer
         self._dist_context._serial_optimizer._learning_rate = learning_rate
+        optimizer._sorted = False

         with program_guard(main_program, startup_program):
             with unique_name.guard("opt_"):
@@ -1053,11 +1053,16 @@ class AMPPass(PassBase):
            )

            # backward
-           first_backward_op = main_block.ops[loss_op_idx + 2]
-           assert (
-               first_backward_op.type == "fill_constant"
-               and int(first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
-           )
+           first_backward_op = None
+           for op in main_block.ops[loss_op_idx:]:
+               if op.type == "fill_constant" and is_loss_grad_op(op):
+                   first_backward_op = op
+                   break
+               if is_backward_op(op):
+                   break
+           assert first_backward_op is not None, "There is not loss_grad op."
            scaled_loss_grad = main_block.create_var(
                name=unique_name.generate("scaled_loss") + "@GRAD",
                shape=loss.shape,
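A side note on the replaced assert (the bit values below are assumed from Paddle's OpRole enum): op roles are stored as a bitmask, and the loss-grad `fill_constant` op carries both the Backward and Loss bits, which is where the old magic number 257 came from; `is_loss_grad_op` encapsulates the same check instead of hard-coding the constant.

# Assumed OpRole bit values: Backward = 0x1, Loss = 0x100.
OP_ROLE_BACKWARD = 0x0001
OP_ROLE_LOSS = 0x0100

def looks_like_loss_grad(op_role: int) -> bool:
    # The condition the old assert hard-coded as `== 257`.
    return op_role == (OP_ROLE_BACKWARD | OP_ROLE_LOSS)

print(OP_ROLE_BACKWARD | OP_ROLE_LOSS)  # 257
print(looks_like_loss_grad(257))        # True: loss-grad op
print(looks_like_loss_grad(1))          # False: ordinary backward op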
@@ -802,7 +802,7 @@ class FP16Pass(AMPPass):
         cast_startup_program()

         if is_train:
-            if self.target_dtype == "fp16":
+            if self.target_dtype == "float16":
                 with paddle.static.program_guard(main_program, startup_program):
                     # TODO (JZ-LIANG)support cast forward program only when inference
                     self._init_amp_var()
@@ -870,7 +870,7 @@ class FP16Pass(AMPPass):
                 # found_inf = paddle.fluid.layers.reduce_any(all_infs)
                 found_inf = block.create_var(
                     name=paddle.utils.unique_name.generate_with_ignorable_key(
-                        ".".join(['reduce_any', 'tmp'])
+                        ".".join(['find_infinite_scale', 'tmp'])
                     ),
                     dtype=all_infs.dtype,
                     shape=None,
@@ -915,7 +915,7 @@ class FP16Pass(AMPPass):
             if self.use_optimizer_fp16:
                 base_opt._multi_precision = False

-            if self.target_dtype == "fp16":
+            if self.target_dtype == "float16":
                 if isinstance(
                     base_opt, (paddle.optimizer.Adam, paddle.optimizer.AdamW)
                 ):
@@ -1258,8 +1258,9 @@ class Optimizer:
                 >>> optimizer.apply_gradients(params_grads)
         """

-        params_grads = sorted(params_grads, key=lambda x: x[0].name)
+        # NOTE(zhaoyinglia): AutoParallel sets the '_sorted' attribute to skip the 'sorted' operator.
+        if not hasattr(self, "_sorted"):
+            params_grads = sorted(params_grads, key=lambda x: x[0].name)

         # 'optimizer(grad_clip)' or 'set_gradient_clip'
         if self._grad_clip is not None:
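How the two `_sorted` pieces fit together (a minimal stand-in sketch, not the real Paddle classes): the Parallelizer hunk above sets `optimizer._sorted = False`, and the patched `apply_gradients` then leaves the incoming `params_grads` order untouched instead of re-sorting by name.

class FakeParam:
    def __init__(self, name):
        self.name = name

class FakeOptimizer:
    def apply_gradients(self, params_grads):
        # Mirrors the patched Optimizer.apply_gradients above.
        if not hasattr(self, "_sorted"):
            params_grads = sorted(params_grads, key=lambda x: x[0].name)
        return [p.name for p, _ in params_grads]

pg = [(FakeParam("linear_1.w_0"), "g1"), (FakeParam("linear_0.w_0"), "g0")]

opt = FakeOptimizer()
print(opt.apply_gradients(list(pg)))  # re-sorted: ['linear_0.w_0', 'linear_1.w_0']

opt._sorted = False                   # what the Parallelizer hunk above sets
print(opt.apply_gradients(list(pg)))  # order preserved: ['linear_1.w_0', 'linear_0.w_0']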