Unverified commit 3070dc8b, authored by JZ-LIANG, committed by GitHub

[Auto Parallel] Generalize Amp Pass (#46519)

* support input mask
Parent 526d963e
......@@ -110,10 +110,12 @@ class AutoParallelizer:
auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
auto_parallel_fp16_pass.apply([main_program], [startup_program],
self._pass_context)
loss = auto_parallel_fp16_pass.get_loss()
else:
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply([main_program], [startup_program],
self._pass_context)
loss = auto_parallel_amp_pass.get_loss()
# apply recompute pass
if self._dist_strategy.recompute:
......
......@@ -192,10 +192,12 @@ class Parallelizer:
auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
auto_parallel_fp16_pass.apply([main_program], [startup_program],
self._pass_context)
loss = auto_parallel_fp16_pass.get_loss()
else:
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply([main_program], [startup_program],
self._pass_context)
loss = auto_parallel_amp_pass.get_loss()
# apply recompute pass
# recompute is a train-only optimization
......
......@@ -271,10 +271,12 @@ class OptimizationTuner:
auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
auto_parallel_fp16_pass.apply([main_program], [startup_program],
pass_context)
dist_context.serial_loss = auto_parallel_fp16_pass.get_loss()
else:
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply([main_program], [startup_program],
pass_context)
dist_context.serial_loss = auto_parallel_amp_pass.get_loss()
if new_strategy.recompute.enable:
config = copy.deepcopy(new_strategy.recompute.to_dict())
......
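Note: all three call sites above share one pattern: after the amp/fp16 pass is applied, the effective loss is re-fetched through get_loss() instead of reusing the variable captured before the pass ran. A minimal standalone sketch of the hazard this guards against (plain static-graph Paddle, hypothetical names, none of the auto-parallel machinery):

import paddle

paddle.enable_static()

main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data("x", [8, 8], dtype="float16")
    loss = paddle.sum(x)                      # fp16 loss, as after fp16 rewriting
    cast_loss = paddle.cast(loss, "float32")  # the cast such a pass inserts

# A later pass holding the stale `loss` reference would anchor on the fp16
# tensor; the effective loss for recompute / gradient merge is `cast_loss`.
assert loss.dtype != cast_loss.dtype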
......@@ -614,21 +614,17 @@ class AMPPass(PassBase):
loss_op)
if loss.dtype != core.VarDesc.VarType.FP32:
# cast loss here will change the effective loss tensor for the computation graph
# and therefore will affect all following passes whose logic is based on the loss tensor (Recompute & Gradient Merge),
# so it is not allowed for now; to be fixed in the future.
raise NotImplementedError(
"Loss's generator op is not supported in FP16 in Auto Parallel for now; please put that op into your black-list."
)
tmp_name = unique_name.generate(loss.name + ".cast_fp32")
cast_loss = main_block.create_var(name=tmp_name,
dtype=core.VarDesc.VarType.FP32)
loss_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(
loss)
ref_mesh = loss_op_dist_attr.process_mesh
self.dist_context.set_tensor_dist_attr_for_program(
cast_loss, loss_dist_attr)
# forward
loss_op_idx = find_op_index(main_block.desc, loss_op.desc)
cast_op = main_block._insert_op(
loss_op_idx + 1,
......@@ -645,7 +641,34 @@ class AMPPass(PassBase):
core.op_proto_and_checker_maker.OpRole.Forward)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, [-1], self.dist_context)
loss = loss.astype('float32')
# backward
first_backward_op = main_block.ops[loss_op_idx + 2]
assert first_backward_op.type == "fill_constant" and int(
first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
cast_loss_grad = main_block.create_var(
name=unique_name.generate(tmp_name + "@GRAD"),
shape=loss.shape,
dtype=core.VarDesc.VarType.FP32,
persistable=loss.persistable)
set_var_dist_attr(self.dist_context, cast_loss_grad, [-1], ref_mesh)
pre_grad_name = first_backward_op.output_arg_names[0]
first_backward_op._rename_output(pre_grad_name, cast_loss_grad.name)
cast_grad_op = main_block._insert_op(
loss_op_idx + 3,
type='cast',
inputs={'X': [cast_loss_grad]},
outputs={'Out': [pre_grad_name]},
attrs={
"in_dtype": core.VarDesc.VarType.FP32,
"out_dtype": core.VarDesc.VarType.FP16,
'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
})
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_grad_op, ref_mesh, [-1], self.dist_context)
loss_op = cast_op
loss = cast_loss
if self.get_attr("use_dynamic_loss_scaling"
) or self.get_attr("init_loss_scaling") != 1.0:
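Note: the assertion above keys on the literal 257. A small check of where that value comes from (this uses Paddle's op-role enum, which already appears elsewhere in this diff; treat it as an illustrative sketch):

from paddle.fluid import core

# The fill_constant that seeds the loss gradient carries both the Backward
# and Loss roles, so its op-role attribute equals 1 | 256 == 257. That is
# how the pass locates the first backward op directly after the loss op.
OpRole = core.op_proto_and_checker_maker.OpRole
assert int(OpRole.Backward) | int(OpRole.Loss) == 257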
......@@ -718,7 +741,7 @@ class AMPPass(PassBase):
else:
self._scaled_loss = loss
self._loss = loss
main_block._sync_with_cpp()
def _update_loss_scaling(self, grads, found_inf):
......@@ -782,3 +805,13 @@ class AMPPass(PassBase):
self.dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
main_block._sync_with_cpp()
def get_loss(self):
# the amp / fp16 pass might change the effective loss variable of the network
# and therefore affect subsequent passes that rely on the loss;
# return the effective loss after the amp / fp16 pass has run.
if self._loss:
return self._loss
else:
return self.get_attr("loss")
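Note: a toy model of the get_loss() contract (standalone stand-in class, not the real AMPPass): callers always receive a usable loss variable, whether or not the pass rewrote it.

class LossFallbackSketch:
    # _loss is assumed to start as None and to be set only when the pass
    # replaced the loss (e.g. by the cast inserted in _cast_loss above).
    def __init__(self, config_loss):
        self._attrs = {"loss": config_loss}
        self._loss = None

    def get_attr(self, name):
        return self._attrs[name]

    def get_loss(self):
        if self._loss:
            return self._loss
        else:
            return self.get_attr("loss")

sketch = LossFallbackSketch(config_loss="orig_loss")
assert sketch.get_loss() == "orig_loss"  # loss was not rewritten
sketch._loss = "cast_loss"
assert sketch.get_loss() == "cast_loss"  # loss was rewritten by the pass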
......@@ -368,6 +368,10 @@ class FP16State(object):
for cast_name, src_name, dst_dtype, src_dtype, slot_name in self.forward_input_cast_ops[
forward_op_id]:
# some forward outputs are not needed by the backward computation, e.g. the logits in softmax_with_cross_entropy
if slot_name not in op.input_names:
continue
# rename input
assert src_name in op.input(
slot_name), "var: {} not in op's {}. {}".format(
......@@ -379,12 +383,15 @@ class FP16State(object):
# create cast grad
grad_slot_name = slot_name + "@GRAD"
assert grad_slot_name in op.output_names, "[{}], Current Op: {}".format(
grad_slot_name, str(op))
# some forward inputs may have stop_gradient=True, e.g. an input_mask
if len(op.output(grad_slot_name)) == 0:
var = block.var(src_name)
assert var.stop_gradient is True
continue
assert len(
op.output(grad_slot_name)) == 1, "[{}], Current Op: {}".format(
grad_slot_name, str(op))
grad_name = op.output(grad_slot_name)[0]
grad = block.var(grad_name)
grad_dist_attr = grad_op_attr.get_output_dist_attr(grad_name)
......
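Note: the stop_gradient early-exit above covers inputs that never receive a gradient. A small dygraph sketch of the situation (hypothetical tensors; the motivating case is a transformer's attention input_mask):

import paddle

x = paddle.rand([2, 4])
x.stop_gradient = False      # trainable input: a @GRAD var is produced
mask = paddle.ones([2, 4])   # e.g. an attention input_mask
mask.stop_gradient = True    # no gradient flows back to it

loss = (x * mask).sum()
loss.backward()

assert x.grad is not None
assert mask.grad is None     # the grad output slot for mask stays empty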