Unverified commit b99c1d07 authored by J JZ-LIANG, committed by GitHub

[Auto parallel] Mixed Precision FP16 Pass (#40615)

* add FP16 Pass

* Support the auto completion of while_op

* accuracy aligned
Parent 5c5a3660
......@@ -67,6 +67,8 @@ message AMPConfig {
repeated string custom_black_varnames = 9;
optional bool use_pure_fp16 = 10 [ default = false ];
optional bool use_fp16_guard = 11 [ default = true ];
  optional bool use_optimizer_fp16 = 12
      [ default = false ]; // effective only for auto parallel
}
message LocalSGDConfig {
......
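These new AMPConfig switches are driven from user code through `fleet.DistributedStrategy().amp_configs`. A minimal sketch of enabling the pure-FP16 path, with values borrowed from the test case added at the end of this diff (model construction and data loading are assumed to happen elsewhere):

```python
import paddle.distributed.fleet as fleet

# Sketch only: enable the new pure-FP16 (auto_parallel_fp16) path via the
# AMPConfig fields above. use_optimizer_fp16 keeps its default (False),
# so the optimizer update stays in FP32.
dist_strategy = fleet.DistributedStrategy()
dist_strategy.amp = True
dist_strategy.amp_configs = {
    "custom_white_list": ["softmax", "layer_norm", "gelu"],
    "custom_black_list": ["c_softmax_with_cross_entropy"],
    "init_loss_scaling": 32768,
    "use_dynamic_loss_scaling": True,
    "use_pure_fp16": True,
}
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
```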
......@@ -105,9 +105,15 @@ class AutoParallelizer:
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["loss"] = loss
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply([main_program], [startup_program],
self._pass_context)
if config["use_pure_fp16"]:
config["base_opt"] = self._optimizer
auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
auto_parallel_fp16_pass.apply(
[main_program], [startup_program], self._pass_context)
else:
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply([main_program], [startup_program],
self._pass_context)
# apply recompute pass
if self._dist_strategy.recompute:
......
......@@ -357,10 +357,11 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
src_var = src_block.var(src_varname)
if src_var.type in __not_shape_var_type__:
persist = getattr(src_var, 'persistable', False)
new_var = dst_block.create_var(
type=src_var.type,
name=dst_varname,
persistable=True,
persistable=persist,
stop_gradient=True)
target_shape = None
else:
......
......@@ -1047,8 +1047,7 @@ def set_grad_var_shape(program, dist_context):
forward_input_dist_attr = op_dist_attr.get_input_dist_attr(
forward_var_name)
assert forward_input_dist_attr is not None, f"{forward_var_name}"
assert forward_input_dist_attr is not None, f"{forward_var_name, str(op)}"
forward_var = vars[forward_var_name]
forward_var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
forward_var)
......
......@@ -17,6 +17,7 @@ from .fuse_all_reduce import *
from .auto_parallel_gradient_merge import *
from .auto_parallel_sharding import *
from .auto_parallel_amp import *
from .auto_parallel_fp16 import *
from .auto_parallel_recompute import *
from .cpp_pass import *
import os
......
......@@ -503,8 +503,6 @@ class AMPPass(PassBase):
return False
if self.get_attr("decr_ratio") < 0:
return False
if len(self.get_attr("params_grads")) <= 0:
return False
if self.get_attr("dist_context") is None:
return False
return True
......@@ -576,6 +574,8 @@ class AMPPass(PassBase):
main_block = paddle.static.default_main_program().global_block()
main_block._sync_with_cpp()
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
loss = self.get_attr("loss")
assert loss is not None
loss_op = loss.op
......@@ -583,6 +583,37 @@ class AMPPass(PassBase):
loss_op)
if loss.dtype != core.VarDesc.VarType.FP32:
            # casting the loss here would change the effective loss tensor for the computation graph
            # and therefore affect all following passes whose logic is based on the loss tensor (Recompute & Gradient Merge),
            # so it is not allowed for now; to be fixed in the future.
            raise NotImplementedError(
                "Loss's generator op is not supported in FP16 in Auto Parallel for now, please put that op into your black-list."
            )
tmp_name = unique_name.generate(loss.name + ".cast_fp32")
cast_loss = main_block.create_var(name=tmp_name, dtype=dtype)
loss_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(
loss)
ref_mesh = loss_op_dist_attr.process_mesh
self.dist_context.set_tensor_dist_attr_for_program(cast_loss,
loss_dist_attr)
loss_op_idx = find_op_index(main_block.desc, loss_op.desc)
cast_op = main_block._insert_op(
loss_op_idx + 1,
type='cast',
inputs={'X': [loss]},
outputs={'Out': [cast_loss]},
attrs={
"in_dtype": loss.dtype,
"out_dtype": core.VarDesc.VarType.FP32,
'op_role': loss_op.all_attrs()[OP_ROLE_KEY],
})
loss_op._set_attr(OP_ROLE_KEY,
core.op_proto_and_checker_maker.OpRole.Forward)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, [-1], self.dist_context)
loss = loss.astype('float32')
if self.get_attr("use_dynamic_loss_scaling") or self.get_attr(
......@@ -600,7 +631,6 @@ class AMPPass(PassBase):
set_var_dist_attr(self.dist_context, self._scaled_loss, [-1],
ref_mesh)
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
elementwise_mul_op = main_block._insert_op(
loss_op_idx + 1,
type='elementwise_mul',
......@@ -667,8 +697,11 @@ class AMPPass(PassBase):
for e in grads:
check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
'update_loss_scaling')
assert self._loss_scaling.dtype == e.dtype, \
"The dtype of prev_loss_scaling should be equal to the dtype of x."
if e.dtype == core.VarDesc.VarType.FP16:
assert self._loss_scaling.dtype == core.VarDesc.VarType.FP32, \
"The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
else:
assert self._loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x."
inputs = {
'X': grads,
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
import paddle
from paddle.framework import core
from paddle.fluid import unique_name
from .pass_base import register_pass
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type
from paddle.distributed.auto_parallel.utils import set_var_dist_attr, naive_set_dist_op_attr_for_program_by_mesh_and_mapping
from paddle.distributed.auto_parallel.process_group import get_world_process_group
from paddle.fluid.contrib.mixed_precision.fp16_utils import AutoMixedPrecisionLists
from paddle.fluid.contrib.mixed_precision.fp16_utils import _keep_layer_norm_scale_bias_to_fp32, _need_keep_fp32, _valid_types, _dtype_to_str
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute
from paddle.distributed.auto_parallel.utils import is_forward_op, is_backward_op
from .auto_parallel_amp import AMPPass
world_process_group = get_world_process_group()
# if the user uses python "+, -, *, /" to build the network, there might be cast ops in the vanilla program
__amp_skip_ops__ = [
'create_py_reader',
'create_double_buffer_reader',
'while',
'cast',
]
def set_op_dtype_to_fp16(op):
if op.has_attr('in_dtype') and op.attr(
'in_dtype') == core.VarDesc.VarType.FP32:
op._set_attr('in_dtype', core.VarDesc.VarType.FP16)
if op.has_attr('out_dtype') and op.attr(
'out_dtype') == core.VarDesc.VarType.FP32:
op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
if op.has_attr('dtype') and op.attr('dtype') == core.VarDesc.VarType.FP32:
op._set_attr('dtype', core.VarDesc.VarType.FP16)
# adapted to also cover backward ops
def _keep_fp32_input(op, in_name):
op_type = op.type
if op_type == 'batch_norm':
# Scale, Bias, Mean, Variance should be float32.
return in_name != 'X'
if op_type == 'layer_norm' and _keep_layer_norm_scale_bias_to_fp32():
return in_name != 'X'
if op_type == 'fused_bn_add_activation':
return in_name not in {'X', 'Z'}
if op_type == 'resnet_unit':
return in_name not in {'X', 'FilterX', 'Z', 'FilterZ'}
if op_type in ['fused_attention', 'fused_feedforward']:
return in_name in {
'LnScale', 'LnBias', 'Ln2Scale', 'Ln2Bias', "Ln1Scale", "Ln1Bias"
}
# backward
if op_type in ['batch_norm_grad']:
return in_name not in {'X', 'Y@GRAD'}
if op_type in ['layer_norm_grad']:
return in_name not in {'X', 'Y@GRAD'}
return False
def _keep_fp32_output(op, out_name):
op_type = op.type
if op_type in ['batch_norm', 'fused_bn_add_activation']:
return out_name != 'Y'
if op_type == 'layer_norm' and _keep_layer_norm_scale_bias_to_fp32():
return out_name != 'Y'
if op_type == 'resnet_unit':
return out_name not in {'Y', 'ConvX', 'ConvZ'}
if op_type in ['fused_attention', 'fused_feedforward']:
return out_name in {
'LnMean', 'LnVariance', 'Ln2Mean', 'Ln2Variance', 'Ln1Mean',
'Ln1Variance'
}
# backward
if op_type in ['layer_norm_grad']:
return out_name != 'X@GRAD'
if op_type in ['batch_norm_grad']:
return out_name != 'X@GRAD'
return False
class FP16State(object):
def __init__(self, program, amp_list, dist_context, use_fp16_guard):
self.program = program
self.amp_list = amp_list
self.use_fp16_guard = use_fp16_guard
self.dist_context = dist_context
self.grad_op_to_op_map = self.dist_context.dist_op_context.grad_op_id_to_op_id
        self._op_fp16_dict = {
        }  # op_id --> True/False. 'True' means that the op should run in fp16 mode.
        # a trick to determine leaf tensor nodes in the program: {varname: generator_op_id}
self.forward_non_leaf_tensors = {}
        # record the cast ops that are inserted for a forward op
self.forward_input_cast_ops = defaultdict(
list
) # {forward_op_id: [(output_name, input_name, out_dtype, in_dtype, slot_name), ]}
self.is_train = False
def _is_fp16_op(self, op_id):
return self._op_fp16_dict.get(op_id, None)
def _build_state(self):
"""
        mark the execution mode (fp16 or fp32) for ops in all blocks,
        including forward ops & backward ops
"""
# mark op dtype
        # assume all backward blocks are behind forward blocks
for block in self.program.blocks:
for op in block.ops:
self._mark_op(op)
# set forward tensor dtype
for block in self.program.blocks:
self.resolute_tensor_dtype(block)
# insert cast ops
for block in self.program.blocks:
self.cast_block(block)
return self.is_train
def _mark_op(self, op):
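        # Decide per-op whether it should run in fp16; backward ops inherit the decision of their matching forward op.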
if op.type in __amp_skip_ops__:
return
if is_forward_op(op):
# ernie inference trick
if op.type == "assign" and "array_" in op.input_arg_names[0]:
self._op_fp16_dict[op.desc.id()] = False
return
if _need_keep_fp32(op, self.amp_list.unsupported_list,
self.use_fp16_guard):
self._op_fp16_dict[op.desc.id()] = False
else:
self._op_fp16_dict[op.desc.id()] = True
for var_name in op.output_arg_names:
# assert var_name not in self.forward_non_leaf_tensors, "{}".format(var_name)
self.forward_non_leaf_tensors[var_name] = op.desc.id()
        elif is_backward_op(op):
if op.desc.id() in self.grad_op_to_op_map:
fwd_op_id = self.grad_op_to_op_map[op.desc.id()]
assert fwd_op_id in self._op_fp16_dict, "{}".format(str(op))
self._op_fp16_dict[op.desc.id()] = self._op_fp16_dict[fwd_op_id]
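        # op_role 257 == OpRole.Backward | OpRole.Loss, i.e. the op producing the loss gradient, which only exists in training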
if int(op.attr('op_role')) == 257:
self.is_train = True
def set_var_to_fp16(self, var_name, block):
var = None
try:
var = block.var(var_name)
except ValueError as e:
var = self.program.global_block().var(var_name)
# NOTE(JZ-LIANG) "array_" is a hack to adopt for ernie3.0 inference, since there is
# a trick which make the LOD_TENSOR_ARRAY to the float32 in while block to reset the LOD_TENSOR_ARRAY
if var is None or var.type not in _valid_types or "array_" in var_name:
return
if var.dtype == core.VarDesc.VarType.FP32:
var.desc.set_dtype(core.VarDesc.VarType.FP16)
def resolute_tensor_dtype(self, block):
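        # Re-set the dtype of forward tensors so that it matches the fp16/fp32 decision recorded for their ops.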
for op in block.ops:
op_id = op.desc.id()
if is_forward_op(op):
                # NOTE (JZ-LIANG) unexpected cast op when the user calls "+, -, *, /" in python
if self._is_fp16_op(op_id) == True or op.type == "cast":
for in_name in op.input_names:
if _keep_fp32_input(op, in_name):
continue
for in_var_name in op.input(in_name):
if in_var_name not in self.forward_non_leaf_tensors:
self.set_var_to_fp16(in_var_name, block)
for out_name in op.output_names:
if _keep_fp32_output(op, out_name):
continue
for out_var_name in op.output(out_name):
self.set_var_to_fp16(out_var_name, block)
set_op_dtype_to_fp16(op)
                # NOTE (JZ-LIANG) unexpected cast op when the user calls "+, -, *, /" in python
elif self._is_fp16_op(op_id) == False:
for out_var_name in op.output_arg_names:
out_var = block.vars.get(out_var_name)
if out_var is None or out_var.type not in _valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.FP16:
out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
elif is_backward_op(op):
if self._is_fp16_op(op_id) == True:
for out_name in op.output_names:
if _keep_fp32_output(op, out_name):
continue
for out_var_name in op.output(out_name):
self.set_var_to_fp16(out_var_name, block)
set_op_dtype_to_fp16(op)
                # NOTE (JZ-LIANG) unexpected cast op when the user calls "+, -, *, /" in python
elif self._is_fp16_op(op_id) == False:
for out_var_name in op.output_arg_names:
out_var = block.vars.get(out_var_name)
if out_var is None or out_var.type not in _valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.FP16:
out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
def cast_block(self, block):
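        # Walk the block and insert the cast ops needed wherever an op's execution dtype differs from its inputs' dtype.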
dist_op_context = self.dist_context.dist_op_context
idx = 0
while idx < len(block.ops):
op = block.ops[idx]
op_id = op.desc.id()
num_cast_ops = 0
if op.type in __amp_skip_ops__:
idx += 1
continue
elif is_forward_op(op):
if self._is_fp16_op(op_id) == False:
num_cast_ops = self._insert_forward_cast_ops(
op, idx, block, core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32, self.dist_context)
elif self._is_fp16_op(op_id) == True:
num_cast_ops = self._insert_forward_cast_ops(
op, idx, block, core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16, self.dist_context)
elif is_backward_op(op):
if op_id in dist_op_context.grad_op_id_to_op_id:
if self._is_fp16_op(op_id) == False:
num_cast_ops = self._insert_backward_cast_ops(
op, idx, block, core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32, self.dist_context)
elif self._is_fp16_op(op_id) == True:
num_cast_ops = self._insert_backward_cast_ops(
op, idx, block, core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16, self.dist_context)
elif op.type == "sum":
                    # all input dtypes of sum should be equal, and the output dtype should follow the input
out_var_name = op.output_arg_names[0]
in_var_name = op.input_arg_names[0]
out_var = block.var(out_var_name)
in_var = block._find_var_recursive(in_var_name)
for in_var_name in op.input_arg_names:
assert in_var.dtype == block.var(
in_var_name).dtype, "{}, {}, {}".format(
in_var, block.var(in_var_name), str(op))
out_var.desc.set_dtype(in_var.dtype)
idx += num_cast_ops + 1
block._sync_with_cpp()
def _insert_forward_cast_ops(self, op, idx, block, src_dtype, dst_dtype,
dist_context):
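        # Cast a forward op's inputs to the dtype it runs in (reusing an existing
        # cast var when possible) and record each cast so the matching backward op can mirror it.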
num_cast_ops = 0
op_id = op.desc.id()
for in_name in op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
op, in_name):
continue
consume_op_attr = dist_context.get_op_dist_attr_for_program(op)
assert consume_op_attr is not None
for in_var_name in op.input(in_name):
in_var = block._find_var_recursive(in_var_name)
if in_var is None or in_var.type not in _valid_types or in_var.dtype == dst_dtype:
continue
if in_var.dtype == src_dtype:
cast_name = in_var.name + '.cast_' + _dtype_to_str(
dst_dtype)
cast_var = block.vars.get(cast_name)
self.forward_input_cast_ops[op_id] += [(
cast_name, in_var.name, dst_dtype, src_dtype, in_name)]
in_var_dist_attr = consume_op_attr.get_input_dist_attr(
in_var.name)
assert in_var_dist_attr is not None
                    # truly insert the cast op
if cast_var is None or cast_var.dtype != dst_dtype:
                        # NOTE we take the dist attr for the cast op and var from the op that consumes
                        # the cast var, instead of from the op which generates the var
                        # refine op's dist_attr
ref_mesh = in_var_dist_attr.process_mesh
ref_mapping = in_var_dist_attr.dims_mapping
cast_var = block.create_var(
name=cast_name,
dtype=dst_dtype,
persistable=False,
stop_gradient=in_var.stop_gradient)
set_var_dist_attr(dist_context, cast_var, ref_mapping,
ref_mesh)
cast_op = block._insert_op_without_sync(
idx,
type="cast",
inputs={"X": in_var},
outputs={"Out": cast_var},
attrs={
"in_dtype": in_var.dtype,
"out_dtype": cast_var.dtype,
})
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, ref_mapping, dist_context)
num_cast_ops += 1
op._rename_input(in_var.name, cast_name)
consume_op_attr.set_input_dist_attr(cast_name,
in_var_dist_attr)
if op.has_attr('out_dtype') and op.attr('out_dtype') != -1:
assert op.attr('out_dtype') == dst_dtype
return num_cast_ops
def _insert_backward_cast_ops(self, op, idx, block, src_dtype, dst_dtype,
dist_context):
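        # Mirror the forward-input casts on the grad op: rename its inputs to the
        # cast vars and insert cast ops that bring the produced gradients back to the original dtype.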
num_cast_ops = 0
op_id = op.desc.id()
dist_op_context = dist_context.dist_op_context
forward_op_id = dist_op_context.grad_op_id_to_op_id[op_id]
grad_op_attr = dist_context.get_op_dist_attr_for_program(op)
assert grad_op_attr is not None
for out_var_name in op.output_arg_names:
out_var = block.var(out_var_name)
if _keep_fp32_output(op, out_var.name):
continue
assert out_var.dtype == dst_dtype, "{}, {}".format(
str(out_var), dst_dtype)
for cast_name, src_name, dst_dtype, src_dtype, slot_name in self.forward_input_cast_ops[
forward_op_id]:
# rename input
assert src_name in op.input(
slot_name), "var: {} not in op's {}. {}".format(src_name,
slot_name,
str(op))
src_var_dist_attr = grad_op_attr.get_input_dist_attr(src_name)
assert src_var_dist_attr is not None
op._rename_input(src_name, cast_name)
grad_op_attr.set_input_dist_attr(cast_name, src_var_dist_attr)
# create cast grad
grad_slot_name = slot_name + "@GRAD"
assert grad_slot_name in op.output_names
assert len(op.output(grad_slot_name)) == 1
grad_name = op.output(grad_slot_name)[0]
grad = block.var(grad_name)
grad_dist_attr = grad_op_attr.get_output_dist_attr(grad_name)
assert grad_dist_attr is not None, "{}".format(grad_name)
ref_mesh = grad_dist_attr.process_mesh
ref_mapping = grad_dist_attr.dims_mapping
cast_grad = block.create_var(
name=unique_name.generate_with_ignorable_key("".join(
[cast_name, '@GRAD'])),
dtype=dst_dtype,
shape=grad.shape,
type=grad.type,
persistable=grad.persistable,
stop_gradient=grad.stop_gradient)
dist_context.set_tensor_dist_attr_for_program(cast_grad,
grad_dist_attr)
op._rename_output(grad_name, cast_grad.name)
grad_op_attr.set_output_dist_attr(cast_grad.name, grad_dist_attr)
# add cast
cast_op = block._insert_op_without_sync(
idx + 1,
type="cast",
inputs={"X": [cast_grad.name]},
outputs={"Out": [grad.name]},
attrs={
"in_dtype": dst_dtype,
"out_dtype": src_dtype,
})
grad.desc.set_dtype(src_dtype)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, ref_mapping, dist_context)
num_cast_ops += 1
return num_cast_ops
def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
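    # Append a check_finite_and_unscale op over the given grads and attach
    # replicated dist attrs (on the world process group) to the op and its found_inf output.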
main_block = paddle.static.default_main_program().global_block()
main_block._sync_with_cpp()
check_type(grads, 'x', (tuple, list), 'check_finite_and_unscale')
for e in grads:
check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
'check_finite_and_unscale')
found_inf = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
['find_infinite_scale', name])),
shape=[1],
dtype='bool',
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
set_var_dist_attr(dist_context, found_inf, [-1], world_process_group.ranks)
inputs = {'X': grads, 'Scale': loss_scaling}
outputs = {'Out': grads, 'FoundInfinite': found_inf}
attrs = {'op_role': OpRole.Backward}
new_op = main_block.append_op(
type='check_finite_and_unscale',
inputs=inputs,
outputs=outputs,
attrs=attrs)
new_op_dist_attr = OperatorDistributedAttribute()
new_op_dist_attr.process_mesh = world_process_group.ranks
new_op_dist_attr.impl_idx = 0
if len(world_process_group.ranks) > 1:
new_op_dist_attr.impl_type = "check_finite_and_unscale"
for g in grads:
g_dist_attr = dist_context.get_tensor_dist_attr_for_program(g)
assert g_dist_attr is not None
new_op_dist_attr.set_input_dims_mapping(g.name,
g_dist_attr.dims_mapping)
new_op_dist_attr.set_output_dims_mapping(g.name,
g_dist_attr.dims_mapping)
dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
return grads, found_inf
def _split_grads(params_grads):
grads = [g for _, g in params_grads]
fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32]
fp16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16]
assert len(fp32_grads) + len(fp16_grads) == len(grads), \
"Data types of all grads must be either fp16 or fp32."
return grads, fp32_grads, fp16_grads
def _set_op_dist_attr_with_ranks(new_op, ranks, block, dist_context):
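    # Give the newly appended op a dist attr on the given ranks, copying the
    # dims mappings from the dist attrs of its input and output tensors.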
new_op_dist_attr = OperatorDistributedAttribute()
new_op_dist_attr.process_mesh = ranks
new_op_dist_attr.impl_idx = 0
for var_name in new_op.input_arg_names:
var = block.var(var_name)
var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
assert var_dist_attr is not None
new_op_dist_attr.set_input_dims_mapping(var_name,
var_dist_attr.dims_mapping)
for var_name in new_op.output_arg_names:
var = block.var(var_name)
var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
assert var_dist_attr is not None
new_op_dist_attr.set_output_dims_mapping(var_name,
var_dist_attr.dims_mapping)
dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
@register_pass("auto_parallel_fp16")
class FP16Pass(AMPPass):
def __init__(self):
super(FP16Pass, self).__init__()
    # NOTE: why can FP16Pass override apply_single_impl instead of
    # apply_impl? AMP is an optimization pass for the serial program;
    # in the distributed scenario, all ranks should make the same modification.
def _apply_single_impl(self, main_program, startup_program, context):
self.dist_context = self.get_attr("dist_context")
params_grads = self.get_attr("params_grads")
amp_list = AutoMixedPrecisionLists(
set(self.get_attr("custom_white_list")),
set(self.get_attr("custom_black_list")), None)
# TODO support multiple blocks
with paddle.static.program_guard(main_program, startup_program):
fp16_state = FP16State(main_program, amp_list, self.dist_context,
self.get_attr("use_fp16_guard"))
is_train = fp16_state._build_state()
if is_train:
with paddle.static.program_guard(main_program, startup_program):
                # TODO (JZ-LIANG) support casting only the forward program when doing inference
self._init_amp_var()
self._scale_loss()
grads, fp32_grads, fp16_grads = _split_grads(params_grads)
if self.get_attr("use_dynamic_loss_scaling") or self.get_attr(
"init_loss_scaling") != 1.0:
found_infs = []
if fp32_grads:
with main_program._backward_role_guard():
_, found_inf_fp32 = _check_and_update_gradient(
fp32_grads, self._loss_scaling, "@fp32",
self.dist_context)
found_infs.append(found_inf_fp32)
if fp16_grads:
with main_program._backward_role_guard():
_, found_inf_fp16 = _check_and_update_gradient(
fp16_grads, self._loss_scaling, "@fp16",
self.dist_context)
found_infs.append(found_inf_fp16)
with main_program._backward_role_guard():
block = main_program.global_block()
all_infs = paddle.fluid.layers.concat(found_infs)
set_var_dist_attr(self.dist_context, all_infs, [-1],
world_process_group.ranks)
new_op = block.ops[-1]
assert new_op.type == "concat"
_set_op_dist_attr_with_ranks(new_op,
world_process_group.ranks,
block, self.dist_context)
found_inf = paddle.fluid.layers.reduce_any(all_infs)
set_var_dist_attr(self.dist_context, found_inf, [-1],
world_process_group.ranks)
new_op = block.ops[-1]
assert new_op.type == "reduce_any"
_set_op_dist_attr_with_ranks(new_op,
world_process_group.ranks,
block, self.dist_context)
if self.get_attr("use_dynamic_loss_scaling"):
with main_program._backward_role_guard():
if fp32_grads:
self._update_loss_scaling(fp32_grads, found_inf)
if fp16_grads:
self._update_loss_scaling(fp16_grads, found_inf)
# modify optimizer
base_opt = self.get_attr("base_opt")
base_opt._multi_precision = True
if self.get_attr("use_optimizer_fp16"):
base_opt._multi_precision = False
if isinstance(base_opt, (paddle.fluid.optimizer.Adam,
paddle.optimizer.AdamW)):
# with main_program._optimized_guard([]):
# found_inf = paddle.tensor.creation._memcpy(
# found_inf, paddle.CPUPlace())
base_opt._set_auxiliary_var('found_inf', found_inf.name)
elif hasattr(base_opt, "_set_auxiliary_var"):
base_opt._set_auxiliary_var('found_inf', found_inf.name)
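For reference, a minimal sketch of dispatching to this pass directly, mirroring the AutoParallelizer hunk earlier in this diff; `dist_strategy`, `dist_context`, `params_grads`, `loss`, `optimizer`, the two programs, and `pass_context` are assumed to come from an existing auto-parallel setup:

```python
import copy

from paddle.distributed.passes import new_pass

# Sketch only: mirror the dispatch logic from the AutoParallelizer hunk
# above. All objects on the right-hand side are assumed to already exist.
config = copy.deepcopy(dist_strategy.amp_configs)
config["dist_context"] = dist_context
config["params_grads"] = params_grads
config["loss"] = loss
if config["use_pure_fp16"]:
    config["base_opt"] = optimizer
    fp16_pass = new_pass("auto_parallel_fp16", config)
    fp16_pass.apply([main_program], [startup_program], pass_context)
else:
    amp_pass = new_pass("auto_parallel_amp", config)
    amp_pass.apply([main_program], [startup_program], pass_context)
```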
......@@ -14,6 +14,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_XPU) AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
list(REMOVE_ITEM TEST_OPS "test_auto_parallel_amp_pass")
list(REMOVE_ITEM TEST_OPS "test_auto_parallel_recompute_pass")
list(REMOVE_ITEM TEST_OPS "test_auto_parallel_sharding_pass")
list(REMOVE_ITEM TEST_OPS "test_auto_parallel_fp16_pass")
endif()
foreach(TEST_OP ${TEST_OPS})
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import random
import numpy as np
import unittest
import paddle
import paddle.distributed.fleet as fleet
from auto_parallel_pass_test_base import AutoPallelPassTestBase
from test_auto_parallel_amp_pass import TestAMPPass
class TestFP16Pass(TestAMPPass):
def apply_passes(self):
dist_strategy = fleet.DistributedStrategy()
dist_strategy.amp = True
dist_strategy.amp_configs = {
"custom_white_list": [
'softmax',
'layer_norm',
'gelu',
],
"custom_black_list": ['c_softmax_with_cross_entropy'],
"init_loss_scaling": 32768,
"use_dynamic_loss_scaling": True,
"use_pure_fp16": True
}
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
if __name__ == "__main__":
unittest.main()