diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py index 9ff673b1d2901ba6ab8e60b6c48d5ddac8af8379..7cad4d746bbf25307bf27db4fa41fc898a86d0b5 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer.py +++ b/python/paddle/distributed/auto_parallel/parallelizer.py @@ -23,6 +23,7 @@ import logging import pickle import time import paddle +from paddle.fluid.backward import append_backward from paddle.distributed.utils import get_logger from paddle.distributed.fleet import cloud_utils import paddle.fluid.core as core @@ -96,49 +97,35 @@ class AutoParallelizer: if suffix in attr_name: op._remove_attr(attr_name) - def _apply_serial_forward_pass(self, main_program, startup_program): + def _apply_serial_pass(self, main_program, startup_program): - # apply amp forward pass + # apply amp pass if self._dist_strategy.amp: auto_parallel_amp_pass = new_pass("auto_parallel_amp_pass", self._dist_strategy.amp_configs) - auto_parallel_amp_pass.apply_forward(main_program, startup_program, - self._pass_context) + auto_parallel_amp_pass.apply(main_program, startup_program, + self._pass_context) - # apply recompute forward pass + # apply recompute pass if self._dist_strategy.recompute: auto_parallel_recompute_pass = new_pass( "auto_parallel_recompute_pass", self._dist_strategy.recompute_configs) - auto_parallel_recompute_pass.apply_forward( - main_program, startup_program, self._pass_context) + auto_parallel_recompute_pass.apply(main_program, startup_program, + self._pass_context) def _generate_backward(self, main_program, startup_program, loss, parameter_list, no_grad_set, callbacks): - # apply recompute backward pass - if self._dist_strategy.recompute: - assert auto_parallel_recompute_pass - auto_parallel_recompute_pass.apply_forward( - main_program, startup_program, parameter_list, no_grad_set, - self._pass_context) - else: - from paddle.fluid.backward import append_backward - with 
program_guard(main_program, startup_program): - params_grads = append_backward( - loss, - parameter_list, - no_grad_set, - callbacks, - distop_context=self._dist_context.dist_op_context) - complete_backward_annotation( - main_program, dist_context=self._dist_context) - - # apply amp forward pass - if self._dist_strategy.amp: - assert auto_parallel_amp_pass - auto_parallel_amp_pass.apply_backward(main_program, startup_program, - self._pass_context) + with program_guard(main_program, startup_program): + params_grads = append_backward( + loss, + parameter_list, + no_grad_set, + callbacks, + distop_context=self._dist_context.dist_op_context) + complete_backward_annotation( + main_program, dist_context=self._dist_context) return params_grads @@ -192,14 +179,14 @@ class AutoParallelizer: completed_main_program = serial_main_program self._dist_context = copy.deepcopy(dist_context) - # serial forward pass - self._apply_serial_forward_pass(completed_main_program, - serial_startup_program) # serial backward pass params_grads = self._generate_backward( completed_main_program, serial_startup_program, serial_loss, self._parameter_list, self._no_grad_set, self._callbacks) + # serial pass (amp / recompute), applied after backward generation + self._apply_serial_pass(completed_main_program, serial_startup_program) + # Logical partition rank = paddle.distributed.get_rank() partitioner = Partitioner(self._dist_context, rank) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index 5e799c52092db3b4800046633b6a03c104473556..2785eae6e8a469642c39fe1942c3e2ede4bfd87b 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -94,7 +94,7 @@ class ShardingPass(PassBase): def _collective_data_parallel_groups(self, main_block): for op in main_block.ops: - if op.type in _skip_ops: + if not _is_forward_op(op) or op.type in _skip_ops: continue group = 
_inference_data_parallel_group_for_operator( self.global_rank, op, self._dist_context) @@ -106,7 +106,7 @@ class ShardingPass(PassBase): if len(self.dp_groups) != 1: raise NotImplementedError( "So far Only and Exactly one data parallel group in network are supported, but got [{}] different data parallel groups". - format(len(groups))) + format(len(self.dp_groups))) def _build_sharding_infos(self, params_grads): @@ -193,18 +193,32 @@ class ShardingPass(PassBase): return # TODO (JZ-LIANG) support calculate global norm with tensor parallelism - is_clip_grad_by_global_norm = False + removed_op_type = ['elementwise_mul', 'squared_l2_norm', 'clip_by_norm'] + removed_op_idx = set() + removed_tmp_var = set() + for idx, op in list(enumerate(main_block.ops)): if not _is_gradient_clip_op(op): continue - if op.type == 'sum': - is_clip_grad_by_global_norm = True - break - if not is_clip_grad_by_global_norm: - return - removed_op_idx = set() - removed_tmp_var = set() + if op.type in removed_op_type: + input_name = op.input("X")[0] + param_name = input_name[:input_name.find("@GRAD")] + if not self._is_parameter_in_local_shard(param_name): + removed_op_idx.add(idx) + if op.type in ['squared_l2_norm', 'clip_by_norm']: + for output_name in op.output_arg_names: + removed_tmp_var.add(output_name) + + for idx, op in reversed(list(enumerate(main_block.ops))): + if not _is_gradient_clip_op(op): + continue + if idx in removed_op_idx: + main_block._remove_op(idx, sync=False) + + for varname in removed_tmp_var: + main_block._remove_var(varname, sync=False) + for idx, op in list(enumerate(main_block.ops)): if not _is_gradient_clip_op(op): continue @@ -218,7 +232,7 @@ class ShardingPass(PassBase): sum_op_output = op.desc.output_arg_names()[0] for i, sharding_info in enumerate(self.sharding_infos): new_op = main_block._insert_op( - idx + i, + idx + i + 1, type='c_allreduce_sum', inputs={'X': [sum_op_output]}, outputs={'Out': [sum_op_output]}, @@ -235,21 +249,6 @@ class ShardingPass(PassBase): 
new_op, dist_attr.process_mesh, dist_attr.dims_mapping, self._dist_context) break - for input_name in op.input_arg_names: - param_name = input_name[:input_name.find("@GRAD")] - if not self._is_parameter_in_local_shard(param_name): - removed_op_idx.add(idx) - for output_name in op.output_arg_names: - removed_tmp_var.add(output_name) - - for idx, op in reversed(list(enumerate(main_block.ops))): - if not _is_gradient_clip_op(op): - continue - if idx in removed_op_idx: - main_block._remove_op(idx, sync=False) - - for varname in removed_tmp_var: - main_block._remove_var(varname, sync=False) main_block._sync_with_cpp() @@ -424,12 +423,15 @@ class ShardingPass(PassBase): startup_block._remove_op(idx, sync=False) continue - if op.type != "c_broadcast" and output_name in not_used_param_nane: + if op.type != "c_broadcast" and output_name in param_usage and sharding_info.get_var_rank( + output_name) != sharding_info.local_rank: startup_block._remove_op(idx, sync=False) - for varname in not_used_param_nane: - main_block._remove_var(varname, sync=False) - startup_block._remove_var(varname, sync=False) + for param_name in param_usage: + if sharding_info.get_var_rank( + param_name) != sharding_info.local_rank: + main_block._remove_var(param_name, sync=False) + startup_block._remove_var(param_name, sync=False) main_block._sync_with_cpp() startup_block._sync_with_cpp() @@ -594,6 +596,10 @@ def _is_param_grad_allreduce_op(op, block, dp_ring_ids): return block.var(base_name).is_parameter +def _is_forward_op(op): + return op.attr("op_role") == 0 + + def _inference_data_parallel_group_for_operator(rank_id, op, dist_context): dp_group = None diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py b/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py index f5eda2fdbf8e27252b9036d78ce8f1a6082456e8..42bdf678242206f62592dd1359b8b15f7c59a1c8 100644 --- 
a/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py +++ b/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py @@ -178,13 +178,13 @@ class AutoPallelPassTestBase(DistPassTestBase): preds = model(tokens, position_ids, attention_mask) criterion = GPTPretrainingCriterion() loss = criterion(preds, labels, loss_mask) - + clip = paddle.nn.ClipGradByNorm(clip_norm=1.0) optimizer = paddle.fluid.optimizer.AdamOptimizer( learning_rate=0.00001, beta1=0.9, beta2=0.999, epsilon=1e-08, - grad_clip=None) + grad_clip=clip) optimizer = fleet.distributed_optimizer(optimizer) startup_program = paddle.static.default_startup_program() _, _, dist_startup_prog, dist_main_prog = optimizer.minimize( diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_sharding_pass.py b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_sharding_pass.py index f6b42701c2195f9907fd35a51b27f123bde8d02b..51e87260609df262c21e521d1c8cb080824e39d4 100644 --- a/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_sharding_pass.py +++ b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_sharding_pass.py @@ -46,7 +46,7 @@ class TestShardingPass(AutoPallelPassTestBase): dist_strategy.sharding = True dist_strategy.sharding_configs = { "sharding_degree": 2, - "stage": 3, + "stage": 2, } fleet.init(is_collective=True, strategy=dist_strategy) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py index ab91c3fe7c4c228167f4d3982a6cc34132c5dae4..83254de61298b32aa79194a3a5b1d9bb4f31c0a2 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py @@ -157,9 +157,6 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): complete_train_program = 
auto.complete_annotation(train_program, dist_context) - parallelizer._apply_serial_forward_pass(complete_train_program, - startup_program) - params_grads = parallelizer._generate_backward( complete_train_program, startup_program, diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py index 9fe5a52cf08af17489221dc7b5572f5aaeffa0cd..3a28595c833e03990270c80c931decf820cd7995 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py @@ -478,8 +478,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): # auto completion complete_train_program = auto.complete_annotation(train_program, dist_context) - parallelizer._apply_serial_forward_pass(complete_train_program, - startup_program) + params_grads = parallelizer._generate_backward( complete_train_program, startup_program, diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py index 3270cfc3c8a5412fd118c92eda9c70a9aac31ad1..dc2ad1d900f52562d9dc30d15948252d830fdb58 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py @@ -884,10 +884,6 @@ class TestGPTPartitioner(unittest.TestCase): complete_train_program = auto.complete_annotation(train_program, dist_context) - # serial forward pass - parallelizer._apply_serial_forward_pass(complete_train_program, - startup_program) - # serial backward pass params_grads = parallelizer._generate_backward( complete_train_program, diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py index 0631cc74a32bdacddf8f45284355a571aac2b577..614b996d265214a2a21cb33914fda571f6a3d142 100644 --- 
a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py @@ -155,9 +155,6 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): complete_train_program = auto.complete_annotation(train_program, dist_context) - parallelizer._apply_serial_forward_pass(complete_train_program, - startup_program) - params_grads = parallelizer._generate_backward( complete_train_program, startup_program, diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py index 0e098664f7ebb1e177f96fe694fe233216e51569..cfbb7653fad8eaf759afdfedcb9f4618fb081886 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py @@ -119,9 +119,6 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): complete_train_program = auto.complete_annotation(train_program, dist_context) - parallelizer._apply_serial_forward_pass(complete_train_program, - startup_program) - params_grads = parallelizer._generate_backward( complete_train_program, startup_program, diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py index c6b1be652073cb6845a96d2ffc28750dad1e36c1..272c1c212f08e782f56e22e6b388d3bce06928d1 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py @@ -134,8 +134,6 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): # serial forward & backward completion complete_train_program = auto.complete_annotation(train_program, dist_context) - parallelizer._apply_serial_forward_pass(complete_train_program, - startup_program) params_grads = 
parallelizer._generate_backward( complete_train_program,