Unverified commit cc24427e, authored by JZ-LIANG, committed by GitHub

[Dist Pass] Amp Pass (#38764)

* auto parallel sharding base

* chmod

* add unittest

* set unittest cmake dist label

* revise code according to review

* chmod

* bugfix for grad_clip and param broadcast

* chmod

* update unittest

* chmod

* add clip

* chmod

* add amp pass

* chmod

* add unittest

* remove grad update

* fixed bug

* fixed bug

* fixed typos

* fixed typos
Parent 4a64ca1e
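For context, the new pass is driven through fleet's distributed strategy. The sketch below mirrors the configuration used by the unit test added in this commit; the model and optimizer setup from the test base are omitted, and the script is assumed to run under a multi-GPU launcher.

```python
import paddle.distributed.fleet as fleet

# Enable the auto-parallel AMP pass the same way the unit test in this commit does.
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True  # hand program compilation to AutoParallelizer
dist_strategy.amp = True        # turn on new_pass("auto_parallel_amp", ...)
dist_strategy.amp_configs = {
    "custom_white_list": ['softmax', 'layer_norm', 'gelu'],
    "custom_black_list": ['c_softmax_with_cross_entropy'],
    "init_loss_scaling": 32768,
    "use_dynamic_loss_scaling": True,
}
fleet.init(is_collective=True, strategy=dist_strategy)
# A subsequent fleet.distributed_optimizer(opt).minimize(loss) then reaches
# _apply_pre_optimization_passed, which builds and applies the AMP pass.
```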
......@@ -23,3 +23,4 @@ from . import dist_reshape
from . import dist_softmax
from . import dist_transpose
from . import dist_default
from . import dist_check_finite_and_unscale
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..utils import set_var_dist_attr
from ..utils import set_dist_op_desc_original_id
from ..process_group import new_process_group
from ..dist_attribute import OperatorDistributedAttribute
from paddle.distributed.auto_parallel.process_group import get_world_process_group
global_process_mesh = get_world_process_group().ranks
class DistributedCheckFiniteAndUnscale(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedCheckFiniteAndUnscale, self).__init__()
self._name = name
register_distributed_operator_impl_container(
"check_finite_and_unscale",
DistributedCheckFiniteAndUnscale("check_finite_and_unscale"))
class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedCheckFiniteAndUnscaleImpl, self).__init__()
self._name = name
self._forward_implemented = False
self._backward_implemented = True
def is_input_compatible(self, dist_op):
raise RuntimeError(
"DistributedCheckFiniteAndUnscaleImpl's is_input_compatible should not be called !"
)
def is_output_compatible(self, dist_op):
raise RuntimeError(
"DistributedCheckFiniteAndUnscaleImpl's is_output_compatible should not be called !"
)
def update_dims_mapping(self, dist_op):
raise RuntimeError(
"DistributedCheckFiniteAndUnscaleImpl's update_dims_mapping should not be called !"
)
@staticmethod
def forward(ctx, *args, **kwargs):
raise RuntimeError(
"DistributedCheckFiniteAndUnscaleImpl's forward should not be called !"
)
@staticmethod
def backward(ctx, *args, **kwargs):
# for now, the backward function only inserts the gradient allreduce for the dist op itself
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.get_dst_main_program().global_block()
backward_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_context.get_rank_id()
dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
assert dist_attr is not None, "backward op [{}] doesn't have a dist attribute!".format(
str(backward_op))
assert rank_id in dist_attr.process_mesh.processes
assert 'X' in kwargs, "input [{}] is not given".format('X')
assert 'Scale' in kwargs, "input [{}] is not given".format('Scale')
assert 'Out' in kwargs, "output [{}] is not given".format('Out')
assert 'FoundInfinite' in kwargs, "output [{}] is not given".format(
'FoundInfinite')
assert len(
kwargs['Scale']
) == 1, "check_finite_and_unscale input Scale takes 1 variable but got {}".format(
kwargs['Scale'])
assert len(
kwargs['FoundInfinite']
) == 1, "check_finite_and_unscale output FoundInfinite takes 1 variable but got {}".format(
kwargs['FoundInfinite'])
assert len(kwargs['X']) == len(
kwargs['Out']
), "check_finite_and_unscale got [{}] X and [{}] Out, which are supposed to be equal".format(
len(kwargs['X']), len(kwargs['Out']))
filter_vars = []
for varname in kwargs['X']:
if rank_id in ctx.get_tensor_dist_attr_for_program(
main_block.var(varname)).process_mesh.processes:
filter_vars.append(varname)
# replicate op in dist program
dist_op_desc = main_block.desc.append_op()
dist_op_desc.copy_from(backward_op.desc)
set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
dist_op_desc.set_input('X', filter_vars)
dist_op_desc.set_output('Out', filter_vars)
main_block._sync_with_cpp()
# synchronize the FoundInfinite flag across all ranks (cast to int32, allreduce_max, cast back)
group = new_process_group(global_process_mesh)
inf_var = main_block.var(kwargs['FoundInfinite'][0])
inf_var_int32 = main_block.create_var(
name=inf_var.name + "@cast_int32",
shape=inf_var.shape,
dtype=core.VarDesc.VarType.INT32)
set_var_dist_attr(
ctx, inf_var_int32,
ctx.get_tensor_dist_attr_for_program(inf_var).dims_mapping,
ctx.get_tensor_dist_attr_for_program(inf_var).process_mesh)
cast_op1 = main_block.append_op(
type='cast',
inputs={'X': inf_var},
outputs={'Out': inf_var_int32},
attrs={
"in_dtype": inf_var.dtype,
"out_dtype": inf_var_int32.dtype,
OP_ROLE_KEY: OpRole.Backward
})
allreduce_op = main_block.append_op(
type='c_allreduce_max',
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_int32},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Backward
})
cast_op2 = main_block.append_op(
type='cast',
inputs={'X': inf_var_int32},
outputs={'Out': inf_var},
attrs={
"in_dtype": inf_var_int32.dtype,
"out_dtype": inf_var.dtype,
OP_ROLE_KEY: OpRole.Backward
})
main_block._sync_with_cpp()
for op in [cast_op1, allreduce_op, cast_op2]:
new_op_dist_attr = OperatorDistributedAttribute()
for varname in op.input_arg_names:
var_dist_attr = ctx.get_tensor_dist_attr_for_program(
main_block.var(varname))
assert var_dist_attr is not None
new_op_dist_attr.set_input_dims_mapping(
varname, var_dist_attr.dims_mapping)
for varname in op.output_arg_names:
var_dist_attr = ctx.get_tensor_dist_attr_for_program(
main_block.var(varname))
new_op_dist_attr.set_output_dims_mapping(
varname, var_dist_attr.dims_mapping)
new_op_dist_attr.process_mesh = var_dist_attr.process_mesh
ctx.set_op_dist_attr_for_program(op, new_op_dist_attr)
register_distributed_operator_impl(
"check_finite_and_unscale",
DistributedCheckFiniteAndUnscaleImpl("check_finite_and_unscale"))
......@@ -36,7 +36,7 @@ from .completion import complete_annotation, complete_backward_annotation, compl
from .partitioner import Partitioner
from .process_group import get_all_process_groups
from .process_group import get_process_group
from .process_group import get_world_process_groups
from .process_group import get_world_process_group
from .process_group import _g_process_group_map, ProcessGroup
from .utils import make_data_unshard
from .utils import set_grad_var_shape
......@@ -97,13 +97,16 @@ class AutoParallelizer:
if suffix in attr_name:
op._remove_attr(attr_name)
def _apply_serial_pass(self, main_program, startup_program):
def _apply_pre_optimization_passed(self, main_program, startup_program,
loss, params_grads):
# apply amp pass
if self._dist_strategy.amp:
auto_parallel_amp_pass = new_pass("auto_parallel_amp_pass",
self._dist_strategy.amp_configs)
auto_parallel_amp_pass.apply(main_program, startup_program,
config = copy.deepcopy(self._dist_strategy.amp_configs)
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["loss"] = loss
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply([main_program], [startup_program],
self._pass_context)
# apply recompute pass
......@@ -185,10 +188,10 @@ class AutoParallelizer:
self._parameter_list, self._no_grad_set, self._callbacks)
# serial forward pass
self._apply_serial_pass(completed_main_program, serial_startup_program)
self._apply_pre_optimization_passed(completed_main_program,
serial_startup_program, serial_loss,
params_grads)
# Logical partition
rank = paddle.distributed.get_rank()
partitioner = Partitioner(self._dist_context, rank)
dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
completed_main_program, serial_startup_program, params_grads)
......@@ -235,7 +238,7 @@ class AutoParallelizer:
assert self._cluster is not None, \
"The cluster must not be none when using auto mapping."
dist_programs = {}
world_process_group = get_world_process_groups()
world_process_group = get_world_process_group()
dist_context = None
# auto search
if self._dist_strategy.auto_search:
......
......@@ -33,7 +33,7 @@ def get_process_group(group_id, g_process_group_map=None):
group_id, None)
def get_world_process_groups():
def get_world_process_group():
global _g_process_group_map
return _g_process_group_map[0]
......
......@@ -16,6 +16,7 @@ from .pass_base import new_pass, PassManager, PassContext
from .fuse_all_reduce import *
from .auto_parallel_gradient_merge import *
from .auto_parallel_sharding import *
from .auto_parallel_amp import *
from .cpp_pass import *
__all__ = [
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.framework import core
from paddle.fluid import unique_name
from .pass_base import PassBase, register_pass
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type
from paddle.distributed.auto_parallel.utils import get_loss_op, set_var_dist_attr
from paddle.distributed.auto_parallel.utils import naive_set_dist_op_attr_for_program_by_mesh_and_mapping
from paddle.distributed.auto_parallel.process_group import get_world_process_group
from paddle.fluid.contrib.mixed_precision.fp16_utils import AutoMixedPrecisionLists
from paddle.fluid.contrib.mixed_precision.fp16_utils import _keep_fp32_input, _keep_fp32_output, find_op_index
from paddle.fluid.contrib.mixed_precision.fp16_utils import _valid_types, find_true_post_op, find_true_prev_op
from paddle.fluid.contrib.mixed_precision.fp16_utils import _is_in_black_varnames, _dtype_to_str, _rename_arg
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute
global_process_mesh = get_world_process_group().ranks
class AMPState(object):
def __init__(self, block):
self._block = block
self._op_fp16_dict = {
} # op_id --> True/False. 'True' means that the current op is in fp16 mode.
self._var_name_dict = {} # fwd_op_id --> {old_name: cast_name}
def _is_fp16_op(self, op_id):
return self._op_fp16_dict.get(op_id, None)
def _build_stats(self, amp_lists, dist_context):
ops = self._block.ops
dist_op_context = dist_context.dist_op_context
for op in ops:
if int(op.attr('op_role')) == int(OpRole.Forward):
self._mark_black_white_ops(amp_lists)
elif int(op.attr('op_role')) == int(OpRole.Backward):
if op.desc.id() in dist_op_context.grad_op_id_to_op_id:
fwd_op_id = dist_op_context.grad_op_id_to_op_id[op.desc.id(
)]
if self._is_fp16_op(fwd_op_id) == True:
self._op_fp16_dict[op.desc.id()] = True
elif self._is_fp16_op(fwd_op_id) == False:
self._op_fp16_dict[op.desc.id()] = False
elif int(op.attr('op_role')) == int(OpRole.Optimize):
break
def _mark_black_white_ops(self, amp_lists):
"""
this function is modified from paddle.fluid.contrib.mixed_precision
"""
self._block._sync_with_cpp()
ops = self._block.ops
for op in ops:
if int(op.attr('op_role')) == int(OpRole.Backward):
break
if op.type == 'create_py_reader' or op.type == 'read':
continue
if amp_lists.black_varnames is not None and _is_in_black_varnames(
op, amp_lists):
self._op_fp16_dict[op.desc.id()] = False
continue
if op.type in amp_lists.black_list:
self._op_fp16_dict[op.desc.id()] = False
elif op.type in amp_lists.white_list:
self._op_fp16_dict[op.desc.id()] = True
elif op.type in amp_lists.gray_list:
is_black_op = False
is_white_op = False
for in_name in op.input_names:
# if this op has inputs
if in_name:
for in_var_name in op.input(in_name):
in_var = self._block.var(in_var_name)
# this in_var isn't the output of another op
if in_var.op is None:
continue
elif in_var.op is op:
prev_op = find_true_prev_op(ops, op,
in_var_name)
if prev_op is None:
continue
else:
prev_op = in_var.op
# classify this op according to the list membership of the producer of its input
if self._is_fp16_op(prev_op.desc.id()) == False or \
prev_op.type in amp_lists.black_list:
is_black_op = True
elif self._is_fp16_op(prev_op.desc.id()) == True or \
prev_op.type in amp_lists.white_list:
is_white_op = True
if is_black_op:
self._op_fp16_dict[op.desc.id()] = False
elif is_white_op:
self._op_fp16_dict[op.desc.id()] = True
else:
pass
else:
# For numerical safety, we apply fp32 computation to ops that
# could not be assigned to any of the lists.
self._op_fp16_dict[op.desc.id()] = False
def cast_forward_program(self, dist_context):
ops = self._block.ops
idx = 0
while idx < len(ops):
op = ops[idx]
num_cast_ops = 0
if int(op.attr('op_role')) == int(OpRole.Backward):
break
if self._is_fp16_op(op.desc.id()) == False:
num_cast_ops = self._insert_cast_op_forward(
op, idx, core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32, dist_context)
elif self._is_fp16_op(op.desc.id()) == True:
num_cast_ops = self._insert_cast_op_forward(
op, idx, core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16, dist_context)
else:
pass
idx += num_cast_ops + 1
self._block._sync_with_cpp()
def _insert_cast_op_forward(self, op, idx, src_dtype, dst_dtype,
dist_context):
"""
only for forward cast
modified from paddle.fluid.contrib.mixed_precision
"""
num_cast_ops = 0
for in_name in op.input_names:
var_name_dict = {}
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
op, in_name):
continue
for in_var_name in op.input(in_name):
in_var = self._block._find_var_recursive(in_var_name)
if in_var.type not in _valid_types or in_var.dtype == dst_dtype:
continue
if in_var.dtype == src_dtype:
cast_name = in_var.name + '.cast_' + _dtype_to_str(
dst_dtype)
out_var = self._block.vars.get(cast_name)
var_name_dict[in_var.name] = cast_name
consume_op_attr = dist_context.get_op_dist_attr_for_program(
op)
assert consume_op_attr is not None
if out_var is None or out_var.dtype != dst_dtype:
# NOTE: we give the cast op and the cast var the dist attr of the op that
# consumes the cast var, instead of the op that generates the var
in_var_dist_attr = consume_op_attr.get_input_dist_attr(
in_var.name)
assert in_var_dist_attr is not None
ref_mesh = in_var_dist_attr.process_mesh
ref_mapping = in_var_dist_attr.dims_mapping
consume_op_attr.set_input_dist_attr(cast_name,
in_var_dist_attr)
out_var = self._block.create_var(
name=cast_name,
dtype=dst_dtype,
persistable=False,
stop_gradient=in_var.stop_gradient)
set_var_dist_attr(dist_context, out_var, ref_mapping,
ref_mesh)
cast_op = self._block._insert_op_without_sync(
idx,
type="cast",
inputs={"X": in_var},
outputs={"Out": out_var},
attrs={
"in_dtype": in_var.dtype,
"out_dtype": out_var.dtype,
})
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, ref_mapping, dist_context)
num_cast_ops += 1
else:
in_var_dist_attr = consume_op_attr.get_input_dist_attr(
in_var.name)
consume_op_attr.set_input_dist_attr(cast_name,
in_var_dist_attr)
_rename_arg(op, in_var.name, cast_name)
else:
if op.has_attr('in_dtype'):
op._set_attr('in_dtype', dst_dtype)
self._var_name_dict[op.desc.id()] = var_name_dict
if src_dtype == core.VarDesc.VarType.FP32 and dst_dtype == core.VarDesc.VarType.FP16:
for out_name in op.output_names:
if _keep_fp32_output(op, out_name):
continue
for out_var_name in op.output(out_name):
out_var = self._block.var(out_var_name)
if out_var.type not in _valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.FP32:
out_var.desc.set_dtype(core.VarDesc.VarType.FP16)
if op.has_attr('out_dtype'):
op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
return num_cast_ops
def cast_backward_program(self, params_grads, dist_context):
self._block._sync_with_cpp()
ops = self._block.ops
loss_op = get_loss_op(self._block)
loss_op_index = find_op_index(self._block.desc, loss_op.desc)
idx = loss_op_index + 1
while idx < len(ops):
num_cast_ops = 0
grad_op = ops[idx]
dist_op_context = dist_context.dist_op_context
if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id:
if self._is_fp16_op(grad_op.desc.id()) == False: # fp32
num_cast_ops = self._insert_cast_op_backward(
grad_op, idx, core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32, dist_context)
elif self._is_fp16_op(grad_op.desc.id()) == True: # fp16
num_cast_ops = self._insert_cast_op_backward(
grad_op, idx, core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16, dist_context)
elif grad_op.type == "sum":
in_var_name = grad_op.desc.input_arg_names()[0]
src_dtype = self._block.var(in_var_name).dtype
for in_var_name in grad_op.desc.input_arg_names():
assert src_dtype == self._block.var(in_var_name).dtype
out_var_name = grad_op.desc.output_arg_names()[0]
out_var = self._block.var(out_var_name)
if out_var.dtype != src_dtype:
out_var.desc.set_dtype(src_dtype)
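# an op_role of 257 (checked below) is OpRole.Backward | OpRole.Loss, e.g. the
# fill_constant that seeds the loss gradient; such ops need no cast handling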
elif int(grad_op.attr('op_role')) == 257:
pass
else:
raise ValueError(
"'{}' op is not supported in the complete amp pass.".format(
grad_op.type))
idx += num_cast_ops + 1
self._block._sync_with_cpp()
_update_backward_cast_ops(params_grads, dist_context)
def _insert_cast_op_backward(self, grad_op, idx, src_dtype, dst_dtype,
dist_context):
""" only for backward cast """
def _keep_fp32_input(op, in_name):
op_type = op.type
if op_type in ['layer_norm_grad']:
return in_name not in {'X', 'Y@GRAD'}
return False
def _keep_fp32_output(op, out_name):
op_type = op.type
if op_type in ['layer_norm_grad']:
return out_name != 'X@GRAD'
return False
num_cast_ops = 0
dist_op_context = dist_context.dist_op_context
fwd_op_id = dist_op_context.grad_op_id_to_op_id[grad_op.desc.id()]
for in_name in grad_op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
grad_op, in_name):
for in_var_name in grad_op.input(in_name):
in_var = self._block._find_var_recursive(in_var_name)
assert in_var.dtype == core.VarDesc.VarType.FP32
continue
for in_var_name in grad_op.input(in_name):
in_var = self._block._find_var_recursive(in_var_name)
if in_var.dtype == src_dtype:
consume_op_attr = dist_context.get_op_dist_attr_for_program(
grad_op)
if in_var_name in self._var_name_dict[fwd_op_id]:
# NOTE: if an in_var consumed by grad_op was cast in the forward pass,
# it should be renamed here and its dist_attr reset.
cast_name = self._var_name_dict[fwd_op_id][in_var_name]
grad_op.desc._rename_input(in_var_name, cast_name)
in_var_dist_attr = consume_op_attr.get_input_dist_attr(
in_var_name)
consume_op_attr.set_input_dist_attr(cast_name,
in_var_dist_attr)
else:
assert in_var.dtype == dst_dtype
for out_name in grad_op.output_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output(
grad_op, out_name):
for out_var_name in grad_op.output(out_name):
out_var = self._block._find_var_recursive(out_var_name)
assert out_var.dtype == core.VarDesc.VarType.FP32
continue
for out_var_name in grad_op.output(out_name):
out_var = self._block._find_var_recursive(out_var_name)
out_var_name_prefix = out_var_name[:out_var_name.find("@")]
fwd_var = self._block._find_var_recursive(out_var_name_prefix)
# NOTE: the dtype of grad_op's out_var should equal the dtype of the corresponding fwd_var
if out_var.dtype != fwd_var.dtype:
out_var.desc.set_dtype(fwd_var.dtype)
if out_var.dtype == src_dtype:
if out_var_name_prefix in self._var_name_dict[fwd_op_id]:
# NOTE: if the out_var produced by grad_op was cast in the forward pass,
# it should be renamed and its dist_attr reset; then we insert a cast op to
# convert the cast_var back to the original dtype
consume_op_attr = dist_context.get_op_dist_attr_for_program(
grad_op)
fwd_cast_name = self._var_name_dict[fwd_op_id][
out_var_name_prefix]
cast_name = fwd_cast_name + "@GRAD"
cast_var = self._block.vars.get(cast_name)
if cast_var is None or cast_var.dtype != dst_dtype:
grad_op.desc._rename_output(out_var_name, cast_name)
out_var_dist_attr = consume_op_attr.get_output_dist_attr(
out_var_name)
ref_mesh = out_var_dist_attr.process_mesh
ref_mapping = out_var_dist_attr.dims_mapping
consume_op_attr.set_output_dist_attr(
cast_name, out_var_dist_attr)
assert ref_mapping is not None
cast_var = self._block.create_var(
name=cast_name,
shape=out_var.shape,
dtype=dst_dtype,
persistable=False,
stop_gradient=out_var.stop_gradient)
set_var_dist_attr(dist_context, cast_var,
ref_mapping, ref_mesh)
cast_op = self._block._insert_op(
idx + 1,
type="cast",
inputs={"X": cast_var},
outputs={"Out": out_var},
attrs={
"in_dtype": cast_var.dtype,
"out_dtype": out_var.dtype,
"op_role": OpRole.Backward
})
cast_op._remove_attr("op_role_var")
cast_op._remove_attr("op_namescope")
cast_op._remove_attr("with_quant_attr")
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, ref_mapping, dist_context)
num_cast_ops += 1
else:
assert out_var.dtype == dst_dtype
return num_cast_ops
def _update_backward_cast_ops(params_grads, dist_context):
"""
move the cast of each parameter gradient to the end of the backward segment
in order to enable fp16 allreduce
"""
# TODO filter optimize ops in future
main_block = paddle.static.default_main_program().global_block()
main_block._sync_with_cpp()
for p, g in params_grads:
op = g.op
if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast':
if int(op.attr('op_role')) == int(OpRole.Backward) and op.has_attr(
'op_role_var'):
op._remove_attr("op_role_var")
post_ops = find_true_post_op(main_block.ops, op, g.name)
if post_ops:
raise ValueError("The cast op {0}'s output should not be"
"used by a non-optimize op, however, it"
"is used by {1}".format(op, post_ops[0]))
if op == main_block.ops[-1]:
continue
# add the new op on both the Python and C++ sides at the same time
new_op_desc = main_block.desc.append_op()
new_op_desc.copy_from(op.desc)
new_op = paddle.fluid.framework.Operator(
block=main_block,
desc=new_op_desc,
type=None,
inputs=None,
outputs=None,
attrs=None)
main_block.ops.append(new_op)
# dist attr
param_dist_attr = dist_context.get_tensor_dist_attr_for_program(p)
output_dist_attr = dist_context.get_tensor_dist_attr_for_program(
main_block.var(op.output_arg_names[0]))
assert param_dist_attr is not None
assert output_dist_attr is not None
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
new_op, param_dist_attr.process_mesh,
param_dist_attr.dims_mapping, dist_context)
output_dist_attr.process_mesh = param_dist_attr.process_mesh
output_dist_attr.dims_mapping = param_dist_attr.dims_mapping
op_idx = find_op_index(main_block.desc, op.desc)
if op_idx == -1:
raise ValueError("The op {0} is not in program".format(op))
main_block._remove_op(op_idx, sync=False)
main_block._sync_with_cpp()
def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
main_block = paddle.static.default_main_program().global_block()
main_block._sync_with_cpp()
grads = [g for _, g in params_grads]
check_type(grads, 'x', (tuple, list), 'check_finite_and_unscale')
for e in grads:
check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
'check_finite_and_unscale')
found_inf = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
['find_infinite_scale', 'tmp'])),
shape=[1],
dtype='bool',
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
set_var_dist_attr(dist_context, found_inf, [-1], global_process_mesh)
inputs = {'X': grads, 'Scale': loss_scaling}
outputs = {'Out': grads, 'FoundInfinite': found_inf}
attrs = {'op_role': OpRole.Backward}
new_op = main_block.append_op(
type='check_finite_and_unscale',
inputs=inputs,
outputs=outputs,
attrs=attrs)
new_op_dist_attr = OperatorDistributedAttribute()
new_op_dist_attr.process_mesh = global_process_mesh
if len(global_process_mesh) > 1:
new_op_dist_attr.impl_idx = 0
for g in grads:
g_dist_attr = dist_context.get_tensor_dist_attr_for_program(g)
assert g_dist_attr is not None
new_op_dist_attr.set_input_dims_mapping(g.name,
g_dist_attr.dims_mapping)
new_op_dist_attr.set_output_dims_mapping(g.name,
g_dist_attr.dims_mapping)
dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
return grads, found_inf
@register_pass("auto_parallel_amp")
class AMPPass(PassBase):
def __init__(self):
super(AMPPass, self).__init__()
self.set_attr("loss", None)
self.set_attr("dist_context", None)
self.set_attr("custom_white_list", None)
self.set_attr("custom_black_list", None)
self.set_attr("custom_black_varnames", None)
self.set_attr("init_loss_scaling", 32768.0)
self.set_attr("incr_every_n_steps", 1000)
self.set_attr("decr_every_n_nan_or_inf", 2)
self.set_attr("incr_ratio", 2.0)
self.set_attr("decr_ratio", 0.8)
self.set_attr("use_dynamic_loss_scaling", False)
self.set_attr("params_grads", [])
self._loss_scaling = None
self._num_good_steps = None
self._num_bad_steps = None
def _check_self(self):
if self.get_attr("init_loss_scaling") < 0:
return False
if self.get_attr("incr_every_n_steps") < 0:
return False
if self.get_attr("decr_every_n_nan_or_inf") < 0:
return False
if self.get_attr("incr_ratio") < 0:
return False
if self.get_attr("decr_ratio") < 0:
return False
if len(self.get_attr("params_grads")) <= 0:
return False
if self.get_attr("dist_context") is None:
return False
return True
def _check_conflict(self, other_pass):
return True
# NOTE: why can AMPPass override apply_single_impl instead of apply_impl?
# AMP is an optimization pass for the serial program; in the distributed
# scenario, all ranks should apply the same modification.
def _apply_single_impl(self, main_program, startup_program, context):
self.dist_context = self.get_attr("dist_context")
params_grads = self.get_attr("params_grads")
amp_lists = AutoMixedPrecisionLists(
set(self.get_attr("custom_white_list")),
set(self.get_attr("custom_black_list")),
set(self.get_attr("custom_black_varnames")))
amp_state = AMPState(main_program.global_block())
amp_state._build_stats(amp_lists, self.dist_context)
with paddle.static.program_guard(main_program, startup_program):
amp_state.cast_forward_program(self.dist_context)
amp_state.cast_backward_program(params_grads, self.dist_context)
# TODO(JZ-LIANG): support casting only the forward program for inference
self._init_amp_var()
self._scale_loss()
if self.get_attr("use_dynamic_loss_scaling") or self.get_attr(
"init_loss_scaling") != 1.0:
grads, found_inf = _check_and_update_gradient(
params_grads, self._loss_scaling, self.dist_context)
if self.get_attr("use_dynamic_loss_scaling"):
self._update_loss_scaling(grads, found_inf)
def _init_amp_var(self):
self._loss_scaling = paddle.static.create_global_var(
name=unique_name.generate("loss_scaling"),
shape=[1],
value=self.get_attr("init_loss_scaling"),
dtype='float32',
persistable=True)
set_var_dist_attr(self.dist_context, self._loss_scaling, [-1],
global_process_mesh)
if self.get_attr("use_dynamic_loss_scaling"):
self._num_good_steps = paddle.static.create_global_var(
name=unique_name.generate("num_good_steps"),
shape=[1],
value=0,
dtype='int32',
persistable=True)
set_var_dist_attr(self.dist_context, self._num_good_steps, [-1],
global_process_mesh)
self._num_bad_steps = paddle.static.create_global_var(
name=unique_name.generate("num_bad_steps"),
shape=[1],
value=0,
dtype='int32',
persistable=True)
set_var_dist_attr(self.dist_context, self._num_bad_steps, [-1],
global_process_mesh)
def _scale_loss(self):
main_block = paddle.static.default_main_program().global_block()
main_block._sync_with_cpp()
loss = self.get_attr("loss")
assert loss is not None
loss_op = loss.op
loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
loss_op)
if loss.dtype != core.VarDesc.VarType.FP32:
loss = loss.astype('float32')
if self.get_attr("use_dynamic_loss_scaling") or self.get_attr(
"init_loss_scaling") != 1.0:
loss_op_idx = find_op_index(main_block.desc, loss_op.desc)
# forward
ref_mesh = loss_op_dist_attr.process_mesh
self._scaled_loss = main_block.create_var(
name=unique_name.generate("scaled_loss"),
shape=loss.shape,
dtype=loss.dtype,
persistable=loss.persistable)
set_var_dist_attr(self.dist_context, self._scaled_loss, [-1],
ref_mesh)
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
elementwise_mul_op = main_block._insert_op(
loss_op_idx + 1,
type='elementwise_mul',
inputs={'X': [loss],
'Y': [self._loss_scaling]},
outputs={'Out': [self._scaled_loss]},
attrs={'op_role': loss_op.all_attrs()[OP_ROLE_KEY], })
loss_op._set_attr(OP_ROLE_KEY,
core.op_proto_and_checker_maker.OpRole.Forward)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
elementwise_mul_op, ref_mesh, [-1], self.dist_context)
# backward
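# loss_op_idx + 1 is the elementwise_mul inserted above, so loss_op_idx + 2 must be
# the fill_constant op that seeds the loss gradient
# (op_role 257, i.e. OpRole.Backward | OpRole.Loss)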
first_backward_op = main_block.ops[loss_op_idx + 2]
assert first_backward_op.type == "fill_constant" and int(
first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
self._scaled_loss_grad = main_block.create_var(
name=unique_name.generate("scaled_loss") + "@GRAD",
shape=loss.shape,
dtype=loss.dtype,
persistable=loss.persistable)
set_var_dist_attr(self.dist_context, self._scaled_loss_grad, [-1],
ref_mesh)
pre_grad_name = first_backward_op.output_arg_names[0]
first_backward_op._rename_output(pre_grad_name,
self._scaled_loss_grad.name)
# FIXME(JZ-LIANG) a trick to insert backward op
main_block._sync_with_cpp()
elementwise_mul_grad_op_desc = main_block.desc._insert_op(
loss_op_idx + 3)
elementwise_mul_grad_op_desc.set_type("elementwise_mul_grad")
elementwise_mul_grad_op_desc.set_input(
'Out@GRAD', [self._scaled_loss_grad.name])
elementwise_mul_grad_op_desc.set_input('X', [loss.name])
elementwise_mul_grad_op_desc.set_input('Y',
[self._loss_scaling.name])
elementwise_mul_grad_op_desc.set_output('X@GRAD', [pre_grad_name])
elementwise_mul_grad_op_desc.set_output('Y@GRAD', [])
elementwise_mul_grad_op_desc._set_attr(
OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Backward)
elementwise_mul_grad_op_desc._set_attr('axis', -1)
elementwise_mul_grad_op = paddle.fluid.framework.Operator(
main_block, elementwise_mul_grad_op_desc)
main_block.ops.insert(loss_op_idx + 3, elementwise_mul_grad_op)
main_block._sync_with_cpp()
elementwise_mul_grad_op = main_block.ops[loss_op_idx + 3]
assert elementwise_mul_grad_op.type == "elementwise_mul_grad"
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
elementwise_mul_grad_op, ref_mesh, [-1], self.dist_context)
else:
self._scaled_loss = loss
main_block._sync_with_cpp()
def _update_loss_scaling(self, grads, found_inf):
main_block = paddle.static.default_main_program().global_block()
main_block._sync_with_cpp()
check_variable_and_dtype(self._loss_scaling, "prev_loss_scaling",
['float32', 'float64'], "update_loss_scaling")
check_type(grads, 'x', (tuple, list), 'update_loss_scaling')
for e in grads:
check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
'update_loss_scaling')
assert self._loss_scaling.dtype == e.dtype, \
"The dtype of prev_loss_scaling should be equal to the dtype of x."
inputs = {
'X': grads,
'FoundInfinite': found_inf,
'PrevLossScaling': self._loss_scaling,
'InGoodSteps': self._num_good_steps,
'InBadSteps': self._num_bad_steps
}
outputs = {
'Out': grads,
'LossScaling': self._loss_scaling,
'OutGoodSteps': self._num_good_steps,
'OutBadSteps': self._num_bad_steps
}
attrs = {
'incr_every_n_steps': self.get_attr("incr_every_n_steps"),
'decr_every_n_nan_or_inf': self.get_attr("decr_every_n_nan_or_inf"),
'incr_ratio': self.get_attr("incr_ratio"),
'decr_ratio': self.get_attr("decr_ratio"),
'stop_update': self.get_attr("stop_update"),
'op_role': OpRole.Backward
}
new_op = main_block.append_op(
type='update_loss_scaling',
inputs=inputs,
outputs=outputs,
attrs=attrs)
new_op_dist_attr = OperatorDistributedAttribute()
new_op_dist_attr.process_mesh = global_process_mesh
if len(global_process_mesh) > 1:
new_op_dist_attr.impl_idx = 0
for g in grads:
g_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(g)
assert g_dist_attr is not None
new_op_dist_attr.set_input_dims_mapping(g.name,
g_dist_attr.dims_mapping)
new_op_dist_attr.set_output_dims_mapping(g.name,
g_dist_attr.dims_mapping)
self.dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
main_block._sync_with_cpp()
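For reference, the set_attr calls in AMPPass.__init__ above define the full configuration surface of this pass. dist_context, params_grads, and loss are injected by AutoParallelizer, while the remaining keys map onto amp_configs; a sketch of the user-facing defaults, with values copied from the code above:

```python
# User-facing AMP pass options and their defaults as registered in AMPPass.__init__;
# dist_context, params_grads and loss are filled in by AutoParallelizer, not by users.
amp_configs = {
    "custom_white_list": None,
    "custom_black_list": None,
    "custom_black_varnames": None,
    "init_loss_scaling": 32768.0,
    "incr_every_n_steps": 1000,
    "decr_every_n_nan_or_inf": 2,
    "incr_ratio": 2.0,
    "decr_ratio": 0.8,
    "use_dynamic_loss_scaling": False,
}
```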
......@@ -21,7 +21,7 @@ from paddle.framework import core
from paddle.fluid import unique_name
from .pass_base import PassBase, register_pass
from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op
from paddle.distributed.auto_parallel.process_group import get_world_process_groups, new_process_group
from paddle.distributed.auto_parallel.process_group import new_process_group
from paddle.distributed.auto_parallel.operators.common import is_parameter_related
from paddle.distributed.auto_parallel.utils import _get_comm_group, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import random
import numpy as np
import unittest
import paddle
import paddle.distributed.fleet as fleet
from auto_parallel_pass_test_base import AutoPallelPassTestBase
class TestAMPPass(AutoPallelPassTestBase):
def init(self):
if paddle.is_compiled_with_cuda():
paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
self.rtol = 1e-5
self.atol = 1e-8
rank = paddle.distributed.get_rank()
paddle.seed(rank + 2021)
random.seed(rank + 2021)
np.random.seed(rank + 2021)
def apply_passes(self):
dist_strategy = fleet.DistributedStrategy()
dist_strategy.amp = True
dist_strategy.amp_configs = {
"custom_white_list": [
'softmax',
'layer_norm',
'gelu',
],
"custom_black_list": ['c_softmax_with_cross_entropy'],
"init_loss_scaling": 32768,
"use_dynamic_loss_scaling": True,
}
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
def test_bs_8(self):
self.check_main(
gpus=[0, 1], batch_size=8, sequence_len=512, vocab_size=1000)
def get_model(self, place, batch_size, sequence_len, vocab_size):
return self.get_gpt_model("mp", place, batch_size, sequence_len,
vocab_size)
if __name__ == "__main__":
unittest.main()