diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py
index 63fc4660c7f043cc1a424b78f12ab6a001390f6a..250d7c9d58d4e2997b978fa63277db230a21e9df 100644
--- a/python/paddle/distributed/auto_parallel/parallelizer.py
+++ b/python/paddle/distributed/auto_parallel/parallelizer.py
@@ -110,10 +110,12 @@ class AutoParallelizer:
                 auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
                 auto_parallel_fp16_pass.apply([main_program], [startup_program],
                                               self._pass_context)
+                loss = auto_parallel_fp16_pass.get_loss()
             else:
                 auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
                 auto_parallel_amp_pass.apply([main_program], [startup_program],
                                              self._pass_context)
+                loss = auto_parallel_amp_pass.get_loss()
 
         # apply recompute pass
         if self._dist_strategy.recompute:
diff --git a/python/paddle/distributed/auto_parallel/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/parallelizer_v2.py
index a1dd58fef71316507098c2e93aa1f4e403fab6a9..32f7b5f3aa68b642921b1de604f7c95f1a2e4673 100644
--- a/python/paddle/distributed/auto_parallel/parallelizer_v2.py
+++ b/python/paddle/distributed/auto_parallel/parallelizer_v2.py
@@ -192,10 +192,12 @@ class Parallelizer:
                 auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
                 auto_parallel_fp16_pass.apply([main_program], [startup_program],
                                               self._pass_context)
+                loss = auto_parallel_fp16_pass.get_loss()
             else:
                 auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
                 auto_parallel_amp_pass.apply([main_program], [startup_program],
                                              self._pass_context)
+                loss = auto_parallel_amp_pass.get_loss()
 
         # apply recompute pass
         # recompute is then train-only optimization
diff --git a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py
index a2da7396ce83c6e67bcb2edc1bdaadb8c4a7970b..835faed0f18e2e1294c44a4ee32a64ada4153492 100644
--- a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py
+++ b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py
@@ -271,10 +271,12 @@ class OptimizationTuner:
                 auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
                 auto_parallel_fp16_pass.apply([main_program], [startup_program],
                                               pass_context)
+                dist_context.serial_loss = auto_parallel_fp16_pass.get_loss()
             else:
                 auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
                 auto_parallel_amp_pass.apply([main_program], [startup_program],
                                              pass_context)
+                dist_context.serial_loss = auto_parallel_amp_pass.get_loss()
 
         if new_strategy.recompute.enable:
             config = copy.deepcopy(new_strategy.recompute.to_dict())
diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py
index 3545783ba177e4428c3793bb5cd1a2a67b7a8172..064075bff366776fa5738138b47f827c18a1fe48 100644
--- a/python/paddle/distributed/passes/auto_parallel_amp.py
+++ b/python/paddle/distributed/passes/auto_parallel_amp.py
@@ -614,21 +614,17 @@ class AMPPass(PassBase):
             loss_op)
 
         if loss.dtype != core.VarDesc.VarType.FP32:
-            # cast loss here will change the effective loss tensor for the computation graph
-            # and therefore will effect all following passes whose logic is based on the loss tensor(Recompute & Gradient Merge),
-            # so we it is not allowed by now. fixed it in future.
-            raise NotImplementedError(
-                "Loss's generator op is not support in FP16 in Auto Parallel by now, please put that op into your black-list."
-            )
 
             tmp_name = unique_name.generate(loss.name + ".cast_fp32")
-            cast_loss = main_block.create_var(name=tmp_name, dtype=dtype)
+            cast_loss = main_block.create_var(name=tmp_name,
+                                              dtype=core.VarDesc.VarType.FP32)
             loss_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(
                 loss)
             ref_mesh = loss_op_dist_attr.process_mesh
             self.dist_context.set_tensor_dist_attr_for_program(
                 cast_loss, loss_dist_attr)
 
+            # forward
             loss_op_idx = find_op_index(main_block.desc, loss_op.desc)
             cast_op = main_block._insert_op(
                 loss_op_idx + 1,
@@ -645,7 +641,34 @@ class AMPPass(PassBase):
                     core.op_proto_and_checker_maker.OpRole.Forward)
             naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                 cast_op, ref_mesh, [-1], self.dist_context)
-            loss = loss.astype('float32')
+
+            # backward
+            first_backward_op = main_block.ops[loss_op_idx + 2]
+            assert first_backward_op.type == "fill_constant" and int(
+                first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
+            cast_loss_grad = main_block.create_var(
+                name=unique_name.generate(tmp_name + "@GRAD"),
+                shape=loss.shape,
+                dtype=core.VarDesc.VarType.FP32,
+                persistable=loss.persistable)
+            set_var_dist_attr(self.dist_context, cast_loss_grad, [-1], ref_mesh)
+
+            pre_grad_name = first_backward_op.output_arg_names[0]
+            first_backward_op._rename_output(pre_grad_name, cast_loss_grad.name)
+            cast_grad_op = main_block._insert_op(
+                loss_op_idx + 3,
+                type='cast',
+                inputs={'X': [cast_loss_grad]},
+                outputs={'Out': [pre_grad_name]},
+                attrs={
+                    "in_dtype": core.VarDesc.VarType.FP32,
+                    "out_dtype": core.VarDesc.VarType.FP16,
+                    'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
+                })
+            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
+                cast_grad_op, ref_mesh, [-1], self.dist_context)
+            loss_op = cast_op
+            loss = cast_loss
 
         if self.get_attr("use_dynamic_loss_scaling"
                          ) or self.get_attr("init_loss_scaling") != 1.0:
@@ -718,7 +741,7 @@ class AMPPass(PassBase):
 
         else:
             self._scaled_loss = loss
-
+        self._loss = loss
         main_block._sync_with_cpp()
 
     def _update_loss_scaling(self, grads, found_inf):
@@ -782,3 +805,13 @@ class AMPPass(PassBase):
             self.dist_context.set_op_dist_attr_for_program(new_op,
                                                            new_op_dist_attr)
         main_block._sync_with_cpp()
+
+    def get_loss(self):
+        # the amp / fp16 pass might change the effective loss variable of the network,
+        # which would affect subsequent passes (e.g. Recompute & Gradient Merge) that rely on the loss,
+        # so return the effective loss after the amp / fp16 pass has been applied.
+
+        if self._loss:
+            return self._loss
+        else:
+            return self.get_attr("loss")
diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py
index 64562668a42ac7cf74d159582e2dd3e5493fbb3f..541901f0c7665add0306d78e786f29d5f0be3396 100644
--- a/python/paddle/distributed/passes/auto_parallel_fp16.py
+++ b/python/paddle/distributed/passes/auto_parallel_fp16.py
@@ -368,6 +368,10 @@ class FP16State(object):
         for cast_name, src_name, dst_dtype, src_dtype, slot_name in self.forward_input_cast_ops[
                 forward_op_id]:
 
+            # some forward outputs are not needed by the backward computation, e.g. logit in softmax_with_cross_entropy
+            if slot_name not in op.input_names:
+                continue
+
             # rename input
             assert src_name in op.input(
                 slot_name), "var: {} not in op's {}. {}".format(
@@ -379,12 +383,15 @@ class FP16State(object):
 
             # create cast grad
             grad_slot_name = slot_name + "@GRAD"
-            assert grad_slot_name in op.output_names
+            assert grad_slot_name in op.output_names, "[{}], Current Op: {}".format(
+                grad_slot_name, str(op))
+
+            # some forward inputs may have stop_gradient=True, e.g. input_mask
            if len(op.output(grad_slot_name)) == 0:
-                var = block.var(src_name)
-                assert var.stop_gradient is True
                 continue
-            assert len(op.output(grad_slot_name)) == 1
+            assert len(
+                op.output(grad_slot_name)) == 1, "[{}], Current Op: {}".format(
+                    grad_slot_name, str(op))
             grad_name = op.output(grad_slot_name)[0]
             grad = block.var(grad_name)
             grad_dist_attr = grad_op_attr.get_output_dist_attr(grad_name)
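
Usage note (illustrative, not part of the patch): a minimal sketch of how a caller is expected to pick up the effective loss once the AMP / FP16 pass has run, mirroring the parallelizer changes above. The `new_pass` / `apply` / `get_loss` calls are the ones shown in the diff; `config` (including a `use_pure_fp16` flag), `main_program`, `startup_program` and `pass_context` are assumed to be prepared by the surrounding auto-parallel setup code.

    # Sketch only; assumes `config`, `main_program`, `startup_program` and
    # `pass_context` already exist, and that `config["use_pure_fp16"]` selects
    # between the fp16 and amp variants of the pass.
    if config["use_pure_fp16"]:
        amp_or_fp16_pass = new_pass("auto_parallel_fp16", config)
    else:
        amp_or_fp16_pass = new_pass("auto_parallel_amp", config)
    amp_or_fp16_pass.apply([main_program], [startup_program], pass_context)

    # With this patch an FP16 loss is cast to FP32 instead of raising
    # NotImplementedError, so the original `loss` variable may no longer be the
    # one feeding the rest of the graph; re-fetch it before later passes
    # (recompute, gradient merge) that key off the loss.
    loss = amp_or_fp16_pass.get_loss()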