You need to sign in or sign up before continuing.
未验证 提交 cc24427e 编写于 作者: J JZ-LIANG 提交者: GitHub

[Dist Pass] Amp Pass (#38764)

* auto parallel sharding base

* chmod

* add unitest

* set unitest cmake dist label

* revise code according to rewiew

* chmod

* bugfix for grad_clip and param broadcast

* chmod

* update unitest

* chmod

* add clip

* chmod

* add amp pass

* chmod

* add unitest

* remove grad update

* fixed bug

* fixed bug

* fixed typose

* fixed typoes
上级 4a64ca1e
......@@ -23,3 +23,4 @@ from . import dist_reshape
from . import dist_softmax
from . import dist_transpose
from . import dist_default
from . import dist_check_finite_and_unscale
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..utils import set_var_dist_attr
from ..utils import set_dist_op_desc_original_id
from ..process_group import new_process_group
from ..dist_attribute import OperatorDistributedAttribute
from paddle.distributed.auto_parallel.process_group import get_world_process_group
global_process_mesh = get_world_process_group().ranks
class DistributedCheckFiniteAndUnscale(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedCheckFiniteAndUnscale, self).__init__()
self._name = name
register_distributed_operator_impl_container(
"check_finite_and_unscale",
DistributedCheckFiniteAndUnscale("check_finite_and_unscale"))
class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedCheckFiniteAndUnscaleImpl, self).__init__()
self._name = name
self._forward_implemented = False
self._backward_implemented = True
def is_input_compatible(self, dist_op):
raise RuntimeError(
"DistributedCheckFiniteAndUnscaleImpl's is_input_compatible should not be called !"
)
def is_output_compatible(self, dist_op):
raise RuntimeError(
"DistributedCheckFiniteAndUnscaleImpl's is_output_compatible should not be called !"
)
def update_dims_mapping(self, dist_op):
raise RuntimeError(
"DistributedCheckFiniteAndUnscaleImpl's update_dims_mapping should not be called !"
)
@staticmethod
def forward(ctx, *args, **kwargs):
raise RuntimeError(
"DistributedCheckFiniteAndUnscaleImpl's forward should not be called !"
)
@staticmethod
def backward(ctx, *args, **kwargs):
# by now the backward function only insert the gradient allreduce for dist op itself
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.get_dst_main_program().global_block()
backward_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_context.get_rank_id()
dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(backward_op))
assert rank_id in dist_attr.process_mesh.processes
assert 'X' in kwargs, "input [{}] is not given".format('X')
assert 'Scale' in kwargs, "input [{}] is not given".format('Scale')
assert 'Out' in kwargs, "input [{}] is not given".format('Out')
assert 'FoundInfinite' in kwargs, "output [{}] is not given".format(
'FoundInfinite')
assert len(
kwargs['Scale']
) == 1, "check_finite_and_unscale input Scale take 1 variable but got {}".format(
kwargs['Scale'])
assert len(
kwargs['FoundInfinite']
) == 1, "check_finite_and_unscale input FoundInfinite take 1 variable but got {}".format(
kwargs['FoundInfinite'])
assert len(kwargs['X']) == len(
kwargs['Out']
), "check_finite_and_unscale got [{}] X and [{}] Out, which are supposed to be equal".format(
len(kwargs['X']), len(kwargs['Out']))
filter_vars = []
for varname in kwargs['X']:
if rank_id in ctx.get_tensor_dist_attr_for_program(
main_block.var(varname)).process_mesh.processes:
filter_vars.append(varname)
# replicate op in dist program
dist_op_desc = main_block.desc.append_op()
dist_op_desc.copy_from(backward_op.desc)
set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
dist_op_desc.set_input('X', filter_vars)
dist_op_desc.set_output('Out', filter_vars)
main_block._sync_with_cpp()
# sync result
group = new_process_group(global_process_mesh)
inf_var = main_block.var(kwargs['FoundInfinite'][0])
inf_var_int32 = main_block.create_var(
name=inf_var.name + "@cast_int32",
shape=inf_var.shape,
dtype=core.VarDesc.VarType.INT32)
set_var_dist_attr(
ctx, inf_var_int32,
ctx.get_tensor_dist_attr_for_program(inf_var).dims_mapping,
ctx.get_tensor_dist_attr_for_program(inf_var).process_mesh)
cast_op1 = main_block.append_op(
type='cast',
inputs={'X': inf_var},
outputs={'Out': inf_var_int32},
attrs={
"in_dtype": inf_var.dtype,
"out_dtype": inf_var_int32.dtype,
OP_ROLE_KEY: OpRole.Backward
})
allreduce_op = main_block.append_op(
type='c_allreduce_max',
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_int32},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Backward
})
cast_op2 = main_block.append_op(
type='cast',
inputs={'X': inf_var_int32},
outputs={'Out': inf_var},
attrs={
"in_dtype": inf_var_int32.dtype,
"out_dtype": inf_var.dtype,
OP_ROLE_KEY: OpRole.Backward
})
main_block._sync_with_cpp()
for op in [cast_op1, allreduce_op, cast_op2]:
new_op_dist_attr = OperatorDistributedAttribute()
for varname in op.input_arg_names:
var_dist_attr = ctx.get_tensor_dist_attr_for_program(
main_block.var(varname))
assert var_dist_attr is not None
new_op_dist_attr.set_input_dims_mapping(
varname, var_dist_attr.dims_mapping)
for varname in op.output_arg_names:
var_dist_attr = ctx.get_tensor_dist_attr_for_program(
main_block.var(varname))
new_op_dist_attr.set_output_dims_mapping(
varname, var_dist_attr.dims_mapping)
new_op_dist_attr.process_mesh = var_dist_attr.process_mesh
ctx.set_op_dist_attr_for_program(op, new_op_dist_attr)
register_distributed_operator_impl(
"check_finite_and_unscale",
DistributedCheckFiniteAndUnscaleImpl("check_finite_and_unscale"))
......@@ -36,7 +36,7 @@ from .completion import complete_annotation, complete_backward_annotation, compl
from .partitioner import Partitioner
from .process_group import get_all_process_groups
from .process_group import get_process_group
from .process_group import get_world_process_groups
from .process_group import get_world_process_group
from .process_group import _g_process_group_map, ProcessGroup
from .utils import make_data_unshard
from .utils import set_grad_var_shape
......@@ -97,13 +97,16 @@ class AutoParallelizer:
if suffix in attr_name:
op._remove_attr(attr_name)
def _apply_serial_pass(self, main_program, startup_program):
def _apply_pre_optimization_passed(self, main_program, startup_program,
loss, params_grads):
# apply amp pass
if self._dist_strategy.amp:
auto_parallel_amp_pass = new_pass("auto_parallel_amp_pass",
self._dist_strategy.amp_configs)
auto_parallel_amp_pass.apply(main_program, startup_program,
config = copy.deepcopy(self._dist_strategy.amp_configs)
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["loss"] = loss
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply([main_program], [startup_program],
self._pass_context)
# apply recompute pass
......@@ -185,10 +188,10 @@ class AutoParallelizer:
self._parameter_list, self._no_grad_set, self._callbacks)
# serial forward pass
self._apply_serial_pass(completed_main_program, serial_startup_program)
self._apply_pre_optimization_passed(completed_main_program,
serial_startup_program, serial_loss,
params_grads)
# Logical partition
rank = paddle.distributed.get_rank()
partitioner = Partitioner(self._dist_context, rank)
dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
completed_main_program, serial_startup_program, params_grads)
......@@ -235,7 +238,7 @@ class AutoParallelizer:
assert self._cluster is not None, \
"The cluster must not be none when using auto mapping."
dist_programs = {}
world_process_group = get_world_process_groups()
world_process_group = get_world_process_group()
dist_context = None
# auto search
if self._dist_strategy.auto_search:
......
......@@ -33,7 +33,7 @@ def get_process_group(group_id, g_process_group_map=None):
group_id, None)
def get_world_process_groups():
def get_world_process_group():
global _g_process_group_map
return _g_process_group_map[0]
......
......@@ -16,6 +16,7 @@ from .pass_base import new_pass, PassManager, PassContext
from .fuse_all_reduce import *
from .auto_parallel_gradient_merge import *
from .auto_parallel_sharding import *
from .auto_parallel_amp import *
from .cpp_pass import *
__all__ = [
......
此差异已折叠。
......@@ -21,7 +21,7 @@ from paddle.framework import core
from paddle.fluid import unique_name
from .pass_base import PassBase, register_pass
from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op
from paddle.distributed.auto_parallel.process_group import get_world_process_groups, new_process_group
from paddle.distributed.auto_parallel.process_group import new_process_group
from paddle.distributed.auto_parallel.operators.common import is_parameter_related
from paddle.distributed.auto_parallel.utils import _get_comm_group, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import random
import numpy as np
import unittest
import paddle
import paddle.distributed.fleet as fleet
from auto_parallel_pass_test_base import AutoPallelPassTestBase
class TestAMPPass(AutoPallelPassTestBase):
def init(self):
if paddle.is_compiled_with_cuda():
paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
self.rtol = 1e-5
self.atol = 1e-8
rank = paddle.distributed.get_rank()
paddle.seed(rank + 2021)
random.seed(rank + 2021)
np.random.seed(rank + 2021)
def apply_passes(self):
dist_strategy = fleet.DistributedStrategy()
dist_strategy.amp = True
dist_strategy.amp_configs = {
"custom_white_list": [
'softmax',
'layer_norm',
'gelu',
],
"custom_black_list": ['c_softmax_with_cross_entropy'],
"init_loss_scaling": 32768,
"use_dynamic_loss_scaling": True,
}
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
def test_bs_8(self):
self.check_main(
gpus=[0, 1], batch_size=8, sequence_len=512, vocab_size=1000)
def get_model(self, place, batch_size, sequence_len, vocab_size):
return self.get_gpt_model("mp", place, batch_size, sequence_len,
vocab_size)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册