Unverified commit 747000dd, authored by JZ-LIANG, committed by GitHub

[Auto Parallel] Pass bugfix (#38741)

Parent aec493c0
@@ -23,6 +23,7 @@ import logging
import pickle
import time
import paddle
from paddle.fluid.backward import append_backward
from paddle.distributed.utils import get_logger
from paddle.distributed.fleet import cloud_utils
import paddle.fluid.core as core
@@ -96,49 +97,35 @@ class AutoParallelizer:
if suffix in attr_name:
op._remove_attr(attr_name)
def _apply_serial_forward_pass(self, main_program, startup_program):
def _apply_serial_pass(self, main_program, startup_program):
# apply amp forward pass
# apply amp pass
if self._dist_strategy.amp:
auto_parallel_amp_pass = new_pass("auto_parallel_amp_pass",
self._dist_strategy.amp_configs)
auto_parallel_amp_pass.apply_forward(main_program, startup_program,
self._pass_context)
auto_parallel_amp_pass.apply(main_program, startup_program,
self._pass_context)
# apply recompute forward pass
# apply recompute pass
if self._dist_strategy.recompute:
auto_parallel_recompute_pass = new_pass(
"auto_parallel_recompute_pass",
self._dist_strategy.recompute_configs)
auto_parallel_recompute_pass.apply_forward(
main_program, startup_program, self._pass_context)
auto_parallel_recompute_pass.apply(main_program, startup_program,
self._pass_context)
def _generate_backward(self, main_program, startup_program, loss,
parameter_list, no_grad_set, callbacks):
# apply recompute backward pass
if self._dist_strategy.recompute:
assert auto_parallel_recompute_pass
auto_parallel_recompute_pass.apply_forward(
main_program, startup_program, parameter_list, no_grad_set,
self._pass_context)
else:
from paddle.fluid.backward import append_backward
with program_guard(main_program, startup_program):
params_grads = append_backward(
loss,
parameter_list,
no_grad_set,
callbacks,
distop_context=self._dist_context.dist_op_context)
complete_backward_annotation(
main_program, dist_context=self._dist_context)
# apply amp forward pass
if self._dist_strategy.amp:
assert auto_parallel_amp_pass
auto_parallel_amp_pass.apply_backward(main_program, startup_program,
self._pass_context)
with program_guard(main_program, startup_program):
params_grads = append_backward(
loss,
parameter_list,
no_grad_set,
callbacks,
distop_context=self._dist_context.dist_op_context)
complete_backward_annotation(
main_program, dist_context=self._dist_context)
return params_grads
@@ -192,14 +179,14 @@ class AutoParallelizer:
completed_main_program = serial_main_program
self._dist_context = copy.deepcopy(dist_context)
# serial forward pass
self._apply_serial_forward_pass(completed_main_program,
serial_startup_program)
# serial backward pass
params_grads = self._generate_backward(
completed_main_program, serial_startup_program, serial_loss,
self._parameter_list, self._no_grad_set, self._callbacks)
# serial forward pass
self._apply_serial_pass(completed_main_program, serial_startup_program)
# Logical partition
rank = paddle.distributed.get_rank()
partitioner = Partitioner(self._dist_context, rank)
......
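For orientation, the two hunks above reorder the serial build in `AutoParallelizer`: the backward program is generated first with `append_backward` (now imported at module level), and the AMP/recompute passes are then applied once through the unified `apply()` entry point instead of separate `apply_forward`/`apply_backward` calls. The sketch below is a condensed, non-authoritative restatement of that flow; the helper name `build_serial_programs` and the omission of the `complete_backward_annotation` step are my own simplifications.

```python
# Condensed sketch of the reordered flow (assumes a parallelizer object with the
# attributes shown in the diff: _parameter_list, _no_grad_set, _callbacks,
# _dist_context, _apply_serial_pass). Not the real AutoParallelizer.
from paddle.fluid.backward import append_backward
from paddle.static import program_guard


def build_serial_programs(parallelizer, main_program, startup_program, loss):
    # 1. Serial backward: append_backward records the dist-op context so the
    #    backward ops can be annotated afterwards (annotation step omitted here).
    with program_guard(main_program, startup_program):
        params_grads = append_backward(
            loss,
            parallelizer._parameter_list,
            parallelizer._no_grad_set,
            parallelizer._callbacks,
            distop_context=parallelizer._dist_context.dist_op_context)

    # 2. Serial passes (AMP / recompute) now rewrite forward and backward in one
    #    shot via pass.apply(...), which is why they run after the backward step.
    parallelizer._apply_serial_pass(main_program, startup_program)
    return params_grads
```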
@@ -94,7 +94,7 @@ class ShardingPass(PassBase):
def _collective_data_parallel_groups(self, main_block):
for op in main_block.ops:
if op.type in _skip_ops:
if not _is_forward_op(op) or op.type in _skip_ops:
continue
group = _inference_data_parallel_group_for_operator(
self.global_rank, op, self._dist_context)
@@ -106,7 +106,7 @@ class ShardingPass(PassBase):
if len(self.dp_groups) != 1:
raise NotImplementedError(
"So far Only and Exactly one data parallel group in network are supported, but got [{}] different data parallel groups".
format(len(groups)))
format(len(self.dp_groups)))
def _build_sharding_infos(self, params_grads):
@@ -193,18 +193,32 @@ class ShardingPass(PassBase):
return
# TODO (JZ-LIANG) support calculate global norm with tensor parallelism
is_clip_grad_by_global_norm = False
removed_op_type = ['elementwise_mul', 'squared_l2_norm', 'clip_by_norm']
removed_op_idx = set()
removed_tmp_var = set()
for idx, op in list(enumerate(main_block.ops)):
if not _is_gradient_clip_op(op):
continue
if op.type == 'sum':
is_clip_grad_by_global_norm = True
break
if not is_clip_grad_by_global_norm:
return
removed_op_idx = set()
removed_tmp_var = set()
if op.type in removed_op_type:
input_name = op.input("X")[0]
param_name = input_name[:input_name.find("@GRAD")]
if not self._is_parameter_in_local_shard(param_name):
removed_op_idx.add(idx)
if op.type in ['squared_l2_norm', 'clip_by_norm']:
for output_name in op.output_arg_names:
removed_tmp_var.add(output_name)
for idx, op in reversed(list(enumerate(main_block.ops))):
if not _is_gradient_clip_op(op):
continue
if idx in removed_op_idx:
main_block._remove_op(idx, sync=False)
for varname in removed_tmp_var:
main_block._remove_var(varname, sync=False)
for idx, op in list(enumerate(main_block.ops)):
if not _is_gradient_clip_op(op):
continue
@@ -218,7 +232,7 @@ class ShardingPass(PassBase):
sum_op_output = op.desc.output_arg_names()[0]
for i, sharding_info in enumerate(self.sharding_infos):
new_op = main_block._insert_op(
idx + i,
idx + i + 1,
type='c_allreduce_sum',
inputs={'X': [sum_op_output]},
outputs={'Out': [sum_op_output]},
@@ -235,21 +249,6 @@ class ShardingPass(PassBase):
new_op, dist_attr.process_mesh, dist_attr.dims_mapping,
self._dist_context)
break
for input_name in op.input_arg_names:
param_name = input_name[:input_name.find("@GRAD")]
if not self._is_parameter_in_local_shard(param_name):
removed_op_idx.add(idx)
for output_name in op.output_arg_names:
removed_tmp_var.add(output_name)
for idx, op in reversed(list(enumerate(main_block.ops))):
if not _is_gradient_clip_op(op):
continue
if idx in removed_op_idx:
main_block._remove_op(idx, sync=False)
for varname in removed_tmp_var:
main_block._remove_var(varname, sync=False)
main_block._sync_with_cpp()
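The hunks above change how grad-clip-by-global-norm is pruned under sharding: only `elementwise_mul`, `squared_l2_norm`, and `clip_by_norm` ops whose gradient belongs to a parameter outside the local shard are removed (together with the temporaries of the latter two), and the `c_allreduce_sum` is now inserted after the `sum` op (`idx + i + 1`) rather than at its index. The snippet below is a schematic of that selection rule over plain dicts, not the real Block/Operator API; `is_local_param` stands in for `ShardingPass._is_parameter_in_local_shard` and the op-dict layout is a hypothetical stand-in.

```python
# Schematic of the pruning rule; each op is modelled as
# {'type': ..., 'is_gradient_clip': ..., 'inputs': {'X': [...]}, 'outputs': [...]}.
PRUNABLE_CLIP_OPS = ['elementwise_mul', 'squared_l2_norm', 'clip_by_norm']


def collect_prunable_clip_ops(ops, is_local_param):
    """Return (op indices, temp var names) that can be dropped because the
    clipped gradient's parameter is not kept in the local shard."""
    removed_op_idx, removed_tmp_var = set(), set()
    for idx, op in enumerate(ops):
        if not op.get('is_gradient_clip', False):
            continue
        if op['type'] not in PRUNABLE_CLIP_OPS:
            continue
        grad_name = op['inputs']['X'][0]              # e.g. "linear_0.w_0@GRAD"
        param_name = grad_name[:grad_name.find("@GRAD")]
        if not is_local_param(param_name):
            removed_op_idx.add(idx)
            # squared_l2_norm / clip_by_norm produce temporaries that become
            # dead once the op is removed; elementwise_mul rewrites the gradient
            # itself, so its output is kept.
            if op['type'] in ('squared_l2_norm', 'clip_by_norm'):
                removed_tmp_var.update(op['outputs'])
    return removed_op_idx, removed_tmp_var


# Example: a squared_l2_norm over a gradient whose parameter lives on another rank.
ops = [{
    'type': 'squared_l2_norm',
    'is_gradient_clip': True,
    'inputs': {'X': ['linear_0.w_0@GRAD']},
    'outputs': ['tmp_squared_norm_0'],
}]
print(collect_prunable_clip_ops(ops, is_local_param=lambda name: False))
# -> ({0}, {'tmp_squared_norm_0'})
```

The removal itself then walks the block in reverse (so indices stay valid), drops the collected ops and temporaries, and finally inserts one `c_allreduce_sum` per sharding group right after the `sum` op so the global norm is still accumulated across shards.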
@@ -424,12 +423,15 @@ class ShardingPass(PassBase):
startup_block._remove_op(idx, sync=False)
continue
if op.type != "c_broadcast" and output_name in not_used_param_nane:
if op.type != "c_broadcast" and output_name in param_usage and sharding_info.get_var_rank(
output_name) != sharding_info.local_rank:
startup_block._remove_op(idx, sync=False)
for varname in not_used_param_nane:
main_block._remove_var(varname, sync=False)
startup_block._remove_var(varname, sync=False)
for param_name in param_usage:
if sharding_info.get_var_rank(
param_name) != sharding_info.local_rank:
main_block._remove_var(param_name, sync=False)
startup_block._remove_var(param_name, sync=False)
main_block._sync_with_cpp()
startup_block._sync_with_cpp()
@@ -594,6 +596,10 @@ def _is_param_grad_allreduce_op(op, block, dp_ring_ids):
return block.var(base_name).is_parameter
def _is_forward_op(op):
return op.attr("op_role") == 0
def _inference_data_parallel_group_for_operator(rank_id, op, dist_context):
dp_group = None
......
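The new `_is_forward_op` helper added above compares the raw `op_role` attribute with the literal 0. For readers unfamiliar with the magic number, a slightly more explicit equivalent is sketched below; it assumes the `OpRole` constants exposed through `paddle.fluid.core`, where `OpRole.Forward` evaluates to 0.

```python
# Sketch of an equivalent, more explicit form of the helper from the hunk above.
from paddle.fluid import core

op_maker = core.op_proto_and_checker_maker
OP_ROLE_KEY = op_maker.kOpRoleAttrName()  # the "op_role" attribute name
OpRole = op_maker.OpRole


def _is_forward_op(op):
    # OpRole.Forward is 0, so this matches `op.attr("op_role") == 0`.
    return op.attr(OP_ROLE_KEY) == int(OpRole.Forward)
```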
@@ -178,13 +178,13 @@ class AutoPallelPassTestBase(DistPassTestBase):
preds = model(tokens, position_ids, attention_mask)
criterion = GPTPretrainingCriterion()
loss = criterion(preds, labels, loss_mask)
clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)
optimizer = paddle.fluid.optimizer.AdamOptimizer(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
grad_clip=clip)
optimizer = fleet.distributed_optimizer(optimizer)
startup_program = paddle.static.default_startup_program()
_, _, dist_startup_prog, dist_main_prog = optimizer.minimize(
......
@@ -46,7 +46,7 @@ class TestShardingPass(AutoPallelPassTestBase):
dist_strategy.sharding = True
dist_strategy.sharding_configs = {
"sharding_degree": 2,
"stage": 3,
"stage": 2,
}
fleet.init(is_collective=True, strategy=dist_strategy)
......
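For context, this test change only lowers the exercised sharding stage from 3 to 2. A minimal strategy setup mirroring the configuration in the hunk (the model, optimizer, and launch script around it are omitted, and the inline comments are my own interpretation of the keys) would look like:

```python
# Minimal fleet strategy mirroring the test configuration above; only the keys
# shown in the diff are set, everything else stays at its default.
import paddle.distributed.fleet as fleet

dist_strategy = fleet.DistributedStrategy()
dist_strategy.sharding = True
dist_strategy.sharding_configs = {
    "sharding_degree": 2,  # shard across 2 ranks
    "stage": 2,            # ZeRO-2-style: shard optimizer states and gradients
}
fleet.init(is_collective=True, strategy=dist_strategy)
```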
@@ -157,9 +157,6 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
complete_train_program = auto.complete_annotation(train_program,
dist_context)
parallelizer._apply_serial_forward_pass(complete_train_program,
startup_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
startup_program,
......
@@ -478,8 +478,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
# auto completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
parallelizer._apply_serial_forward_pass(complete_train_program,
startup_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
startup_program,
......
@@ -884,10 +884,6 @@ class TestGPTPartitioner(unittest.TestCase):
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# serial forward pass
parallelizer._apply_serial_forward_pass(complete_train_program,
startup_program)
# serial backward pass
params_grads = parallelizer._generate_backward(
complete_train_program,
......
@@ -155,9 +155,6 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
complete_train_program = auto.complete_annotation(train_program,
dist_context)
parallelizer._apply_serial_forward_pass(complete_train_program,
startup_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
startup_program,
......
@@ -119,9 +119,6 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
complete_train_program = auto.complete_annotation(train_program,
dist_context)
parallelizer._apply_serial_forward_pass(complete_train_program,
startup_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
startup_program,
......
@@ -134,8 +134,6 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
# serial forward & backward completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
parallelizer._apply_serial_forward_pass(complete_train_program,
startup_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
......