From 7aeec4ed42f0e4e19c2a1120f98112339cb75e95 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Fri, 12 Aug 2022 17:33:17 +0800 Subject: [PATCH] [Auto Parallel] Data Parallel Optimization Pass 1 (#44882) * bugfix * remove scaling * support rescale_grad opt --- .../auto_parallel/operators/common.py | 10 + .../auto_parallel/parallelizer_v2.py | 8 + .../auto_parallel/process_group.py | 19 +- .../auto_parallel/tuner/optimization_tuner.py | 12 +- .../paddle/distributed/auto_parallel/utils.py | 15 ++ python/paddle/distributed/passes/__init__.py | 1 + ...uto_parallel_data_parallel_optimization.py | 207 ++++++++++++++++++ .../distributed_passes/CMakeLists.txt | 2 + .../auto_parallel_pass_test_base.py | 18 +- ...arallel_data_parallel_optimization_pass.py | 110 ++++++++++ 10 files changed, 382 insertions(+), 20 deletions(-) create mode 100644 python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py create mode 100644 python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_data_parallel_optimization_pass.py diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py index 75002ae4ce1..7b4eb27fc82 100644 --- a/python/paddle/distributed/auto_parallel/operators/common.py +++ b/python/paddle/distributed/auto_parallel/operators/common.py @@ -435,3 +435,13 @@ def gradient_synchronization(dist_ctx, op, act_grad_names, out_grad_names, return sync_and_scale_gradients(dist_ctx, op, dp_group, out_grad_names) + + +def is_data_parallel_scale_op(op): + return op.type == "scale" and op.desc.has_attr("op_namescope") \ + and ParallelMode.DataParallel in op.desc.attr("op_namescope") + + +def is_data_parallel_reduce_op(op): + return op.type in ["c_reduce_sum", "c_allreduce_sum"] and op.desc.has_attr("op_namescope") \ + and ParallelMode.DataParallel in op.desc.attr("op_namescope") diff --git a/python/paddle/distributed/auto_parallel/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/parallelizer_v2.py index 005e51dfce7..e6b30e03680 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/parallelizer_v2.py @@ -195,6 +195,14 @@ class Parallelizer: params_grads): if self._strategy is None: return + + # data parallel optimization + config = {} + config["dist_context"] = self._dist_context + config["global_rank"] = rank + dp_pass = new_pass("auto_parallel_data_parallel_optimization", config) + dp_pass.apply([main_program], [startup_program], self._pass_context) + if self._strategy.sharding: config = copy.deepcopy(self._strategy.sharding_configs) config["dist_context"] = self._dist_context diff --git a/python/paddle/distributed/auto_parallel/process_group.py b/python/paddle/distributed/auto_parallel/process_group.py index 17f960381aa..5b0d5e286ff 100644 --- a/python/paddle/distributed/auto_parallel/process_group.py +++ b/python/paddle/distributed/auto_parallel/process_group.py @@ -160,21 +160,24 @@ class ProcessGroup: def is_member(self): return True - # def __eq__(self, other): - # if not isinstance(other, ProcessGroup): - # return False - # if self.id != other.id: - # return False - # return True + def __eq__(self, other): + if not isinstance(other, ProcessGroup): + return False + if self.id != other.id: + return False + return True - # def __ne__(self, other): - # return not self.__eq__(other) + def __ne__(self, other): + return not self.__eq__(other) def __str__(self): string = "id: {}, nranks: {}, ranks: {}.".format( self.id, self.nranks, ", 
".join(map(str, self.ranks))) return string + def __hash__(self): + return hash(self.__str__()) + # Note that Process group 0 is reserved for representing all ranks. # At the beginning, group 0 is empty and new ranks will be added automatically. diff --git a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py index 89b6d22e32a..bb50e2fb9c5 100644 --- a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py +++ b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py @@ -266,7 +266,7 @@ class OptimizationTuner: config["input_data"] = self._baseline_dist_context.serial_feed_vars["inputs"] \ + self._baseline_dist_context.serial_feed_vars["labels"] if config["use_pure_fp16"]: - config["base_opt"] = dist_context.optimizer + config["base_opt"] = dist_context.serial_optimizer auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config) auto_parallel_fp16_pass.apply([main_program], [startup_program], pass_context) @@ -363,11 +363,11 @@ class OptimizationTuner: profile_args = " ".join([ "--rank", - str(self.rank), - "--device_id", - str(self.device_id), - "--ctx_filename", - ctx_path, + str(self.rank), "--device_id", + str(self.device_id), "--ctx_filename", ctx_path, + "--profile_start_step", + str(self._config.profile_start_step), "--profile_end_step", + str(self._config.profile_end_step) ]) cmd_args = "-m paddle.distributed.auto_parallel.tuner.profiler" + " " + profile_args cmd = [sys.executable, "-u"] + coverage_args + shlex.split(cmd_args) diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index b0d4963140e..d6fd06647ba 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -23,6 +23,7 @@ from functools import reduce import paddle.fluid.core as core from paddle.distributed.fleet.meta_optimizers.common import OpRole +from paddle.distributed.auto_parallel.process_group import get_all_process_groups from paddle.fluid.io import is_parameter, is_belong_to_optimizer from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute @@ -1123,6 +1124,13 @@ def is_loss_op(op): int(op.all_attrs()[OP_ROLE_KEY]) == (int(OpRole.Forward) | int(OpRole.Loss)) +def is_loss_grad_op(op): + if OP_ROLE_KEY not in op.attr_names: + return False + op_role = int(op.all_attrs()[OP_ROLE_KEY]) + return op_role & int(OpRole.Backward) and op_role & int(OpRole.Loss) + + def is_prim_op(op): return op.type.endswith("_p") @@ -1481,3 +1489,10 @@ def debug_program(program, path, name): path, name + '_program' + ".%d" % (paddle.distributed.get_rank())) with open(filename, 'w') as f: f.write(str(program)) + + +def ring_id_to_process_group(ring_id): + for g in get_all_process_groups(): + if g.id == ring_id: + return g + return None diff --git a/python/paddle/distributed/passes/__init__.py b/python/paddle/distributed/passes/__init__.py index 3649d571aa4..670e7f003d7 100644 --- a/python/paddle/distributed/passes/__init__.py +++ b/python/paddle/distributed/passes/__init__.py @@ -19,6 +19,7 @@ from .auto_parallel_sharding import * from .auto_parallel_amp import * from .auto_parallel_fp16 import * from .auto_parallel_recompute import * +from .auto_parallel_data_parallel_optimization import * from .cpp_pass import * import os from .ps_trainer_pass import * diff --git a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py 
b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py
new file mode 100644
index 00000000000..b274f7b9b84
--- /dev/null
+++ b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py
@@ -0,0 +1,207 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import OrderedDict
+
+import paddle
+from paddle.fluid.framework import default_main_program
+from paddle.distributed.auto_parallel.operators.common import is_data_parallel_scale_op, is_data_parallel_reduce_op
+from paddle.distributed.auto_parallel.utils import is_loss_grad_op, is_optimize_op, ring_id_to_process_group
+from .pass_base import PassBase, PassType, register_pass
+
+# add new optimizers supporting rescale_grad here
+__rescale_grad_supported_opts__ = [
+    'lars_momentum', 'sparse_momentum', 'dgc_momentum', 'momentum',
+    'merge_momentum'
+]
+
+
+@register_pass("auto_parallel_data_parallel_optimization")
+class DataParallelOptimizationPass(PassBase):
+    """
+    Apply optimizations specialized for data parallelism in Auto Parallel:
+    1. prune grad scaling
+    2. overlap comm and calc
+    3. fuse allreduce
+    """
+
+    def __init__(self):
+        super(DataParallelOptimizationPass, self).__init__()
+        # NOTE: this pass intentionally does not depend on loss and param_grads
+        self.set_attr("dist_context", None)
+        self.set_attr("global_rank", -1)
+        # {grad1: group1, grad2: group1, grad3: group2}
+        # records the order used for fusing gradient memory
+        self._grad_name_to_group_map = OrderedDict()
+        # {group1: [grad1, grad2], group2: [grad3]}
+        self._group_to_grad_name_map = OrderedDict()
+        self._support_rescale_grad = False
+
+    def _check_self(self):
+        if self.get_attr("dist_context") is None:
+            return False
+        if (not isinstance(self.get_attr("global_rank"),
+                           int)) or self.get_attr("global_rank") < 0:
+            return False
+
+        return True
+
+    def _check_conflict(self, other_pass):
+        return True
+
+    def _type(self):
+        return PassType.COMM_OPT
+
+    def _apply_single_impl(self, main_program, startup_program, context):
+
+        self.dist_context = self.get_attr("dist_context")
+        self.global_rank = int(self.get_attr("global_rank"))
+
+        with paddle.static.program_guard(main_program, startup_program):
+            self._analyze_program()
+            self._prune_grad_scaling()
+            self._overlap_comm()
+            self._fuse_allreduce()
+
+    def _prune_grad_scaling(self):
+
+        if not self._could_be_prune():
+            return
+
+        if self._all_dp_groups_same_degree():
+            self._scale_backward_initial_grad()
+        else:
+            self._update_opt_rescale_grad()
+
+        self._remove_grad_scaling()
+
+    def _overlap_comm(self):
+        pass
+
+    def _fuse_allreduce(self):
+        pass
+
+    def _analyze_program(self):
+        """
+        build the {param_grad_name: data_parallel_group} and
+        {data_parallel_group: [param_grad_name]} maps
+        """
+
+        block = default_main_program().global_block()
+        ops = block.ops
+        scaled_grads = []
+
+        for op in ops:
+            if is_data_parallel_reduce_op(op):
+                grad_name = op.output_arg_names[0]
+                if grad_name in self._grad_name_to_group_map:
+                    continue
+                assert op.has_attr(
+                    "ring_id"
+                ), "Unexpected: comm op [{}] has no ring_id.".format(str(op))
+                group = ring_id_to_process_group(op.attr("ring_id"))
+
+                assert group is not None, "Unexpected: data parallel group of [{}] from op [{}] is None".format(
+                    grad_name, str(op))
+
+                self._grad_name_to_group_map[grad_name] = group
+
+                if group not in self._group_to_grad_name_map:
+                    self._group_to_grad_name_map[group] = [grad_name]
+                else:
+                    self._group_to_grad_name_map[group].append(grad_name)
+
+            elif is_data_parallel_scale_op(op):
+                grad_name = op.output_arg_names[0]
+                scaled_grads.append(grad_name)
+
+            # TODO: support multiple optimizers in one network in the future.
+            # here we assume that the optimizer is unique in the network.
+            elif is_optimize_op(
+                    op) and op.type in __rescale_grad_supported_opts__:
+                self._support_rescale_grad = True
+
+        not_synchronized_grads = []
+        for grad_name in scaled_grads:
+            if grad_name not in self._grad_name_to_group_map:
+                not_synchronized_grads.append(grad_name)
+        assert len(
+            not_synchronized_grads
+        ) == 0, "Unexpected: gradients [{}] are scaled but not synchronized.".format(
+            not_synchronized_grads)
+
+    def _could_be_prune(self):
+
+        return self._support_rescale_grad or self._all_dp_groups_same_degree()
+
+    def _all_dp_groups_same_degree(self):
+        return len(
+            set([
+                len(group.ranks)
+                for group in self._group_to_grad_name_map.keys()
+            ])) == 1
+
+    def _scale_backward_initial_grad(self):
+
+        block = default_main_program().global_block()
+        dp_degree = len(list(self._group_to_grad_name_map.keys())[0].ranks)
+
+        for idx, op in reversed(list(enumerate(block.ops))):
+            if is_loss_grad_op(op):
+                assert op.type == 'fill_constant', \
+                    "loss_grad_op must be fill_constant op, " \
+                    "but this op is {}".format(op.type)
+                assert op.has_attr('value')
+                loss_scale = float(op.attr('value'))
+                loss_scale = loss_scale / dp_degree
+                op._set_attr('value', loss_scale)
+                break
+
+    def _remove_grad_scaling(self):
+        block = default_main_program().global_block()
+
+        for op_idx, op in reversed(list(enumerate(block.ops))):
+            if is_data_parallel_scale_op(op):
+                block._remove_op(op_idx, False)
+
+        block._sync_with_cpp()
+
+    def _update_opt_rescale_grad(self):
+
+        block = default_main_program().global_block()
+        scaled_grads = set()
+
+        for idx, op in reversed(list(enumerate(block.ops))):
+            if is_optimize_op(
+                    op) and op.type in __rescale_grad_supported_opts__:
+                assert op.has_attr(
+                    'rescale_grad'
+                ), "Unexpected: op [{}] is expected to have a [rescale_grad] attribute.".format(
+                    str(op))
+                assert len(
+                    op.input("Grad")
+                ) == 1, "Unexpected: op [{}] is expected to have exactly one input grad var.".format(
+                    str(op))
+
+                grad_name = op.input("Grad")[0]
+                dp_degree = len(
+                    list(self._grad_name_to_group_map[grad_name].ranks))
+                scaled_grads.add(grad_name)
+
+                rescale_grad = float(op.attr('rescale_grad')) / dp_degree
+                op._set_attr('rescale_grad', rescale_grad)
+
+        assert scaled_grads == set(self._grad_name_to_group_map.keys(
+        )), "Unexpected: gradients [{}] were not rescaled.".format(
+            set(self._grad_name_to_group_map.keys()) - scaled_grads)
diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/CMakeLists.txt b/python/paddle/fluid/tests/unittests/distributed_passes/CMakeLists.txt
index 51f298eccdb..b9f4d818282 100644
--- a/python/paddle/fluid/tests/unittests/distributed_passes/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/distributed_passes/CMakeLists.txt
@@ -20,6 +20,8 @@ if((NOT WITH_GPU)
   list(REMOVE_ITEM TEST_OPS "test_auto_parallel_sharding_pass")
   list(REMOVE_ITEM TEST_OPS
"test_auto_parallel_fp16_pass") list(REMOVE_ITEM TEST_OPS "test_auto_parallel_gradient_merge_pass") + list(REMOVE_ITEM TEST_OPS + "test_auto_parallel_data_parallel_optimization_pass") endif() foreach(TEST_OP ${TEST_OPS}) diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py b/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py index 63abdeef595..ec879e77611 100644 --- a/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py +++ b/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py @@ -108,7 +108,7 @@ class AutoPallelPassTestBase(DistPassTestBase): pickle.dump(all_fetch_values, f) def get_gpt_model(self, strategy, place, batch_size, sequence_len, - vocab_size): + vocab_size, **kwargs): modeling.init_global() if strategy == "dp": modeling._global_parallel_strategy = "dp" @@ -179,11 +179,17 @@ class AutoPallelPassTestBase(DistPassTestBase): criterion = GPTPretrainingCriterion() loss = criterion(preds, labels, loss_mask) clip = paddle.nn.ClipGradByNorm(clip_norm=1.0) - optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001, - beta1=0.9, - beta2=0.999, - epsilon=1e-08, - grad_clip=clip) + + if kwargs.get('optimizer', None) == "LarsMomentum": + optimizer = paddle.fluid.optimizer.LarsMomentumOptimizer( + learning_rate=0.001, momentum=0.9) + else: + optimizer = paddle.fluid.optimizer.AdamOptimizer( + learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=clip) optimizer = fleet.distributed_optimizer(optimizer) startup_program = paddle.static.default_startup_program() _, _, dist_startup_prog, dist_main_prog = optimizer.minimize( diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_data_parallel_optimization_pass.py b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_data_parallel_optimization_pass.py new file mode 100644 index 00000000000..f8fe59f6979 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_data_parallel_optimization_pass.py @@ -0,0 +1,110 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import random +import numpy as np + +import unittest +import paddle +import paddle.nn as nn +import paddle.distributed.fleet as fleet +import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context +from paddle.distributed.passes import new_pass, PassManager, PassContext +from auto_parallel_pass_test_base import AutoPallelPassTestBase + +sys.path.append("..") +import auto_parallel_gpt_model as modeling +from auto_parallel_gpt_model import GPTModel, GPTForPretraining, GPTPretrainingCriterion + + +class TestDataParallelPassWithScale1(AutoPallelPassTestBase): + + def init(self): + if paddle.is_compiled_with_cuda(): + paddle.set_flags({'FLAGS_cudnn_deterministic': 1}) + self.rtol = 1e-5 + self.atol = 1e-8 + # NOTE a hack to compare pass apply or not, since there is no + # setting of this pass in dist_strategy + self._apply_pass = False + + rank = paddle.distributed.get_rank() + paddle.seed(rank + 2021) + random.seed(rank + 2021) + np.random.seed(rank + 2021) + + def apply_passes(self): + dist_strategy = fleet.DistributedStrategy() + dist_strategy.semi_auto = True + fleet.init(is_collective=True, strategy=dist_strategy) + self._apply_pass = True + + def apply_no_passes(self): + dist_strategy = fleet.DistributedStrategy() + dist_strategy.semi_auto = True + fleet.init(is_collective=True, strategy=dist_strategy) + self._apply_pass = False + + def test_bs_8(self): + self.check_main(gpus=[0, 1], + batch_size=8, + sequence_len=512, + vocab_size=1000) + + # test scaling with fillconstant + def get_model(self, place, batch_size, sequence_len, vocab_size): + + dist_main_prog, dist_startup_prog, data_holder, [ + loss + ], gen_data = self.get_gpt_model('dp', place, batch_size, sequence_len, + vocab_size) + if self._apply_pass: + config = {} + config["dist_context"] = get_default_distributed_context() + config["global_rank"] = paddle.distributed.get_rank() + dp_pass = new_pass("auto_parallel_data_parallel_optimization", + config) + dp_pass.apply([dist_main_prog], [dist_startup_prog], PassContext()) + + return dist_main_prog, dist_startup_prog, data_holder, [loss], gen_data + + +class TestDataParallelPassWithScale2(TestDataParallelPassWithScale1): + + # test scaling with optimizer rescale_grad + def get_model(self, place, batch_size, sequence_len, vocab_size): + + dist_main_prog, dist_startup_prog, data_holder, [ + loss + ], gen_data = self.get_gpt_model('dp', + place, + batch_size, + sequence_len, + vocab_size, + optimizer='LarsMomentum') + if self._apply_pass: + config = {} + config["dist_context"] = get_default_distributed_context() + config["global_rank"] = paddle.distributed.get_rank() + dp_pass = new_pass("auto_parallel_data_parallel_optimization", + config) + dp_pass.apply([dist_main_prog], [dist_startup_prog], PassContext()) + + return dist_main_prog, dist_startup_prog, data_holder, [loss], gen_data + + +if __name__ == "__main__": + unittest.main() -- GitLab
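The two predicates added to operators/common.py classify data-parallel communication and grad-scaling ops by the ParallelMode tag recorded in each op's op_namescope attribute. Below is a minimal sketch of that classification logic in plain Python, using hypothetical stand-in classes rather than Paddle's real op and op-desc objects; the tag string is an assumption standing in for the value of ParallelMode.DataParallel.

# Illustrative only: FakeOpDesc/FakeOp are hypothetical stand-ins, not Paddle classes.
DATA_PARALLEL_TAG = "auto_parallel/data_parallel"  # assumed value of ParallelMode.DataParallel


class FakeOpDesc:

    def __init__(self, attrs):
        self._attrs = attrs

    def has_attr(self, name):
        return name in self._attrs

    def attr(self, name):
        return self._attrs[name]


class FakeOp:

    def __init__(self, op_type, namescope=None):
        self.type = op_type
        attrs = {} if namescope is None else {"op_namescope": namescope}
        self.desc = FakeOpDesc(attrs)


def is_data_parallel_reduce_op(op):
    # same shape as the patch: op type plus namescope tag decide the classification
    return op.type in ["c_reduce_sum", "c_allreduce_sum"] \
        and op.desc.has_attr("op_namescope") \
        and DATA_PARALLEL_TAG in op.desc.attr("op_namescope")


def is_data_parallel_scale_op(op):
    return op.type == "scale" \
        and op.desc.has_attr("op_namescope") \
        and DATA_PARALLEL_TAG in op.desc.attr("op_namescope")


print(is_data_parallel_reduce_op(FakeOp("c_allreduce_sum", DATA_PARALLEL_TAG)))  # True
print(is_data_parallel_scale_op(FakeOp("scale")))  # False: op carries no data-parallel namescope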
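Applying the new pass outside Parallelizer follows the same recipe as the added unit test: build a config holding the current dist_context and global rank, create the pass by its registered name, and apply it to the distributed main/startup programs. A short sketch, assuming dist_main_prog and dist_startup_prog were already produced by the auto-parallel planner for the current rank:

import paddle
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context
from paddle.distributed.passes import new_pass, PassContext


def apply_dp_optimization(dist_main_prog, dist_startup_prog):
    # mirrors the added test: the pass only needs the dist_context and the global rank
    config = {
        "dist_context": get_default_distributed_context(),
        "global_rank": paddle.distributed.get_rank(),
    }
    dp_pass = new_pass("auto_parallel_data_parallel_optimization", config)
    dp_pass.apply([dist_main_prog], [dist_startup_prog], PassContext())
    return dist_main_prog, dist_startup_prog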
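Pruning the per-gradient scale ops relies on gradients being linear in the initial loss gradient: when every data-parallel group has the same degree, seeding the loss-grad fill_constant with value / dp_degree (or, otherwise, folding 1 / dp_degree into each supported optimizer's rescale_grad) gives the same result as scaling each gradient after the allreduce. A simplified numeric check in plain Python, not Paddle code:

# Numeric sanity check: folding 1/dp_degree into the initial loss gradient is
# equivalent to scaling the summed gradient after c_allreduce_sum.
dp_degree = 4
local_grads = [1.0, 3.0, 5.0, 7.0]      # the same gradient slot on 4 ranks
allreduced = sum(local_grads)           # what c_allreduce_sum produces

# Baseline: a scale op divides the summed gradient by dp_degree.
scaled_after_allreduce = allreduced / dp_degree

# Pass behaviour: the loss grad starts at 1/dp_degree instead of 1, so every local
# gradient is already divided before the allreduce and no per-grad scale op is needed.
prescaled = sum(g * (1.0 / dp_degree) for g in local_grads)

assert abs(scaled_after_allreduce - prescaled) < 1e-12
print(scaled_after_allreduce, prescaled)  # both 4.0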
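Re-enabling ProcessGroup.__eq__/__ne__ and adding __hash__ is what allows the pass to key its _group_to_grad_name_map by group objects. A toy stand-in class (not Paddle's ProcessGroup) showing why the two methods must be defined together for that dict bookkeeping to work:

# Illustrative stand-in: hashing must be consistent with equality so that two
# objects describing the same logical group land on the same dict entry.
class Group:

    def __init__(self, gid, ranks):
        self.id, self.ranks = gid, ranks

    def __eq__(self, other):
        return isinstance(other, Group) and self.id == other.id

    def __hash__(self):  # required once __eq__ is user-defined
        return hash(self.id)


group_to_grads = {}
g1a, g1b = Group(1, [0, 1]), Group(1, [0, 1])  # same logical group, two objects
group_to_grads.setdefault(g1a, []).append("w0@GRAD")
group_to_grads.setdefault(g1b, []).append("w1@GRAD")
print(group_to_grads[Group(1, [0, 1])])  # ['w0@GRAD', 'w1@GRAD']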