diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py
index a383780f7740298c37c74c5d4897e9122f7a0d5a..4712634a6c4eb5d5db48ffc5355b1cd3f4254938 100644
--- a/python/paddle/distributed/auto_parallel/engine.py
+++ b/python/paddle/distributed/auto_parallel/engine.py
@@ -60,14 +60,10 @@ class Engine:
                  strategy=None,
                  user_tuning_config=None):
         self.model = model
+        self.strategy = strategy or fleet.DistributedStrategy()
         self.inputs_spec = self._validate_spec(inputs_spec)
         self.labels_spec = self._validate_spec(labels_spec)
-        self.cluster = cluster
-        if self.cluster is None:
-            self.cluster = get_default_cluster()
-        self.strategy = strategy
-        if self.strategy is None:
-            self.strategy = fleet.DistributedStrategy()
+        self.cluster = cluster or get_default_cluster()
         self._user_tuning_config = user_tuning_config

         self._executor = None
@@ -433,7 +429,7 @@ class Engine:
                     break
                 train_logs["step: {:d} "] = step
-                if lr_scheduler is not None:
+                if lr_scheduler is not None and step % self.k_steps == 0:
                     lr_scheduler.step()
                 try:
                     train_logs["lr: {:5e} "] = self._lr_optimizer.get_lr()
@@ -551,6 +547,12 @@ class Engine:
                            epochs=1,
                            steps_per_epoch=None,
                            collate_fn=None):
+
+        if self.strategy.gradient_merge and batch_size is not None:
+            assert batch_size % self.k_steps == 0, \
+                "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, self.k_steps)
+            batch_size //= self.k_steps
+
         dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
         dist_startup_prog = self._dist_startup_progs[self.mode][self._cur_rank]
         dist_context = self._dist_contexts[self.mode]
@@ -612,6 +614,9 @@ class Engine:

     def _validate_spec(self, specs):
         specs = to_list(specs)
+        self.k_steps = 1
+        if self.strategy.gradient_merge:
+            self.k_steps = self.strategy.gradient_merge_configs['k_steps']
         if specs is not None:
             for i, spec in enumerate(specs):
                 assert isinstance(spec, InputSpec)
@@ -619,6 +624,12 @@ class Engine:
                     raise ValueError(
                         "Requires Input[{}].name != None, but receive `None` with {}."
                         .format(i, spec))
+                if self.k_steps > 1:
+                    shape = list(spec.shape)
+                    assert shape[0] % self.k_steps == 0, \
+                        "Requires batch_size[{}] to be divisible by k_steps[{}].".format(spec.shape[0], self.k_steps)
+                    shape[0] //= self.k_steps
+                    spec.shape = shape
         return specs

     def _is_local_var(self, var):
diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py
index 4b538431bb072b55e987477fc6bbed82208ad5ca..01f7207ab91b3e1d888807021be6961fcecaba7d 100644
--- a/python/paddle/distributed/auto_parallel/parallelizer.py
+++ b/python/paddle/distributed/auto_parallel/parallelizer.py
@@ -84,7 +84,7 @@ class AutoParallelizer:
         self._need_rank_mapping = os.getenv("PADDLE_NEED_RANK_MAPPING")
         self._need_rank_mapping = True if self._need_rank_mapping and \
             self._need_rank_mapping.lower() == 'true' else False
-        self._pass_context = None
+        # self._pass_context = None

     def _remove_distributed_attrs(self, main_program):
         suffix = core.kAutoParallelSuffix()
@@ -143,10 +143,11 @@ class AutoParallelizer:

     def _apply_optimize(self, main_program, startup_program, params_grads):

+        optimizer = copy.deepcopy(self._optimizer)
         with program_guard(main_program, startup_program):
-            optimize_ops = copy.deepcopy(
-                self._optimizer).apply_gradients(params_grads)
+            optimize_ops = optimizer.apply_gradients(params_grads)

+        self._dist_context._lr_optimizer = optimizer
         # update completion
         self._completer = Completer(self._dist_context)
         self._completer.complete_update_annotation(main_program)
@@ -165,6 +166,15 @@ class AutoParallelizer:
                                                      config)
             auto_parallel_sharding_pass.apply([main_program], [startup_program],
                                               self._pass_context)
+            params_grads = self._pass_context.get_attr("params_grads")
+
+            config = copy.deepcopy(self._dist_strategy.sharding_configs)
+            config["dist_context"] = self._dist_context
+            config["params_grads"] = params_grads
+            config["rank_id"] = rank
+            auto_parallel_clip_pass = new_pass("auto_parallel_grad_clip", config)
+            auto_parallel_clip_pass.apply([main_program], [startup_program],
+                                          self._pass_context)

         if self._dist_strategy.gradient_merge:
             config = copy.deepcopy(self._dist_strategy.gradient_merge_configs)
diff --git a/python/paddle/distributed/auto_parallel/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/parallelizer_v2.py
index d4bb81e6b222c28a129950e0384fde9287486c9d..7e43ee95266438b0689656d3916ea91d3008bf31 100644
--- a/python/paddle/distributed/auto_parallel/parallelizer_v2.py
+++ b/python/paddle/distributed/auto_parallel/parallelizer_v2.py
@@ -230,9 +230,9 @@ class Parallelizer:
                                                      config)
             auto_parallel_sharding_pass.apply([main_program], [startup_program],
                                               self._pass_context)
+            params_grads = self._pass_context.get_attr("params_grads")

         # GradClip is train-only optimization
-        if self._mode == "train":
             config = copy.deepcopy(self._strategy.sharding_configs)
             config["dist_context"] = self._dist_context
diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py
index f65b7591e59727dfa85acdacda744d59869ec975..7702de7c01edd50fec1c3d5ea153cfc52f6c5209 100644
--- a/python/paddle/distributed/passes/auto_parallel_fp16.py
+++ b/python/paddle/distributed/passes/auto_parallel_fp16.py
@@ -442,7 +442,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):

     inputs = {'X': grads, 'Scale': loss_scaling}
     outputs = {'Out': grads, 'FoundInfinite': found_inf}
-    attrs = {'op_role': OpRole.Backward}
+    attrs = {'op_role': OpRole.Optimize}
     new_op = main_block.append_op(type='check_finite_and_unscale',
                                   inputs=inputs,
                                   outputs=outputs,
@@ -575,18 +575,18 @@ class FP16Pass(AMPPass):
             ) or self.get_attr("init_loss_scaling") != 1.0:
                 found_infs = []
                 if fp32_grads:
-                    with main_program._backward_role_guard():
+                    with main_program._optimized_guard([]):
                         _, found_inf_fp32 = _check_and_update_gradient(
                             fp32_grads, self._loss_scaling, "@fp32",
                             self.dist_context)
                     found_infs.append(found_inf_fp32)
                 if fp16_grads:
-                    with main_program._backward_role_guard():
+                    with main_program._optimized_guard([]):
                         _, found_inf_fp16 = _check_and_update_gradient(
                             fp16_grads, self._loss_scaling, "@fp16",
                             self.dist_context)
                     found_infs.append(found_inf_fp16)
-                with main_program._backward_role_guard():
+                with main_program._optimized_guard([]):

                     block = main_program.global_block()
                     all_infs = paddle.fluid.layers.concat(found_infs)
@@ -608,7 +608,7 @@ class FP16Pass(AMPPass):
                         block, self.dist_context)

             if self.get_attr("use_dynamic_loss_scaling"):
-                with main_program._backward_role_guard():
+                with main_program._optimized_guard([]):
                     if fp32_grads:
                         self._update_loss_scaling(fp32_grads, found_inf)
                     if fp16_grads:
diff --git a/python/paddle/distributed/passes/auto_parallel_grad_clip.py b/python/paddle/distributed/passes/auto_parallel_grad_clip.py
index 6fba98ce75207928ae20e5f9ee648e16d9afba00..f1a0c6e38674ab6e962f2dad7802f75b319ce0d5 100644
--- a/python/paddle/distributed/passes/auto_parallel_grad_clip.py
+++ b/python/paddle/distributed/passes/auto_parallel_grad_clip.py
@@ -207,6 +207,7 @@ class ClipGradByGloblNormPass(PassBase):
         super(ClipGradByGloblNormPass, self).__init__()
         self.set_attr("rank_id", None)
         self.set_attr("dist_context", None)
+        self.set_attr("params_grads", None)

     def _check_self(self):
         if self.get_attr("dist_context") is None:
@@ -214,6 +215,8 @@ class ClipGradByGloblNormPass(PassBase):
         dist_context = self.get_attr("dist_context")
         if dist_context._lr_optimizer._grad_clip is None:
             return False
+        if self.get_attr("params_grads") is None:
+            return False
         return True

     def _check_conflict(self, other_pass):
@@ -223,7 +226,8 @@ class ClipGradByGloblNormPass(PassBase):
         dist_context = self.get_attr("dist_context", None)
         rank_id = self.get_attr("rank_id", None)
         block = main_program.global_block()
-        dist_params_grads = _get_params_grads(block)
+        dist_params_grads = self.get_attr("params_grads", None)
+        # dist_params_grads = _get_params_grads(block)

         self.clip_helper = ClipHelper(dist_params_grads, rank_id, block,
                                       dist_context)
diff --git a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py
index 717f8fa27f2df0091a5f7fc21bac8d0e9529cb55..c61d944400d665fe38b19ea09664c3fc4c300a80 100644
--- a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py
+++ b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py
@@ -55,13 +55,6 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
     return optimize_ops_desc


-def _remove_op_role_var(param, grad):
-    op_maker = core.op_proto_and_checker_maker
-    op = grad.op
-    if op and op.has_attr(op_maker.kOpRoleVarAttrName()):
-        op._remove_attr(op_maker.kOpRoleVarAttrName())
-
-
 def _get_gm_cond_var(main_program, k_steps, dist_context):
     main_block = main_program.global_block()
     # Add const var
@@ -147,8 +140,6 @@ def _append_gradient_merge_backward_op(
             param.type != core.VarDesc.VarType.SELECTED_ROWS
         ), "SELECTED_ROWS is not supported in GradientMergeOptimizer for now"

-        _remove_op_role_var(param, grad)
-
     # {grad.name: gradient_merge_var.name} to rename opt inputs
     grad_to_gradient_merge = {}
     # {param: gradient_merge_var} to insert scale op and fill_constant op
diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py
index fa07915bf60dd5c782727c351decffcb205581fb..8b4d2288b791f417223049edb36ef889223f18d0 100644
--- a/python/paddle/distributed/passes/auto_parallel_sharding.py
+++ b/python/paddle/distributed/passes/auto_parallel_sharding.py
@@ -59,6 +59,7 @@ class ShardingPass(PassBase):
         self.varname_to_sharding_info = {}
         self.partial_sharding = False
         self.outer_dp_group = None
+        self.shared_params_grads = []

     def _check_self(self):
         if self.get_attr("dist_context") is None:
@@ -94,6 +95,8 @@ class ShardingPass(PassBase):
         self._shard_gradient_synchronization(main_block)
         self._shard_parameter(main_block, startup_block)

+        context.set_attr("params_grads", self.shared_params_grads)
+
     def _build_sharding_groups(self, main_block, params_grads):
         self._collective_data_parallel_groups(main_block)
         self._build_sharding_infos(params_grads)
@@ -148,13 +151,10 @@ class ShardingPass(PassBase):
             self._dist_context._sharding_group = sharding_group
             # TODO(JZ-LIANG) when support multiple dp groups in future, should group param and bind them to corresponding dp group
-            params_in_group = [p for p, g in params_grads]
-            assert len(params_in_group) == len(
-                set(params_in_group)), "found duplicated param in params_grads"
             sharding_info = ShardingInfo(sharding_group, self.global_rank,
-                                         params_in_group)
+                                         params_grads)
             self.sharding_infos.append(sharding_info)
-            for param in params_in_group:
+            for param in sharding_info.params:
                 self.varname_to_sharding_info[param.name] = sharding_info

     def _shard_optimizer(self, main_block, startup_block, params_grads,
@@ -201,6 +201,7 @@ class ShardingPass(PassBase):
                     op.desc.set_output('Out', reversed_x)
                 else:
                     if op.type == "check_finite_and_unscale":
+                        op_role = op.attr('op_role')
                         out_name = op.output_arg_names[0]
                         out_var = main_block.vars[out_name]
                         main_block._remove_op(idx, sync=False)
@@ -212,6 +213,7 @@ class ShardingPass(PassBase):
                                 "shape": out_var.shape,
                                 "dtype": out_var.dtype,
                                 "value": 0,
+                                OP_ROLE_KEY: op_role,
                             })
                     else:
                         main_block._remove_op(idx, sync=False)
@@ -313,6 +315,9 @@ class ShardingPass(PassBase):
                     if varname != param_name
                 ])
                 main_block._remove_op(idx, sync=False)
+            else:
+                self.shared_params_grads.append(
+                    self._get_param_grad(param_name))

         for idx, op in reversed(list(enumerate(startup_block.ops))):
             if len(op.output_arg_names) == 1 and op.output_arg_names[
@@ -365,6 +370,13 @@ class ShardingPass(PassBase):
         sharding_info = self.varname_to_sharding_info[param_name]
         return sharding_info.is_in_local_shard(param_name)

+    def _get_param_grad(self, param_name):
+        assert param_name in self.varname_to_sharding_info
+        sharding_info = self.varname_to_sharding_info[param_name]
+        p_g = sharding_info.get_param_grad(param_name)
+        assert p_g is not None
+        return p_g
+
     def _shard_gradient_synchronization(self, main_block):

         if self.stage < 2:
@@ -705,9 +717,13 @@ def shard_parameters(params, group_size):

 class ShardingInfo(object):

-    def __init__(self, group, rank, params):
+    def __init__(self, group, rank, params_grads):
         self.group = group
-        self.params = params
+        self.params_grads = dict([(p.name, (p, g)) for p, g in params_grads])
+        assert len(self.params_grads) == len(set(
+            self.params_grads)), "found duplicated param in params_grads"
+
+        self.params = [p for p, _ in params_grads]
         self.param_names = [p.name for p in self.params]
         self.group_size = group.nranks
         self.global_rank = rank
@@ -762,3 +778,11 @@ class ShardingInfo(object):
             if usage > 0:
                 broadcast_vars.add(param)
         return broadcast_vars, param_usage
+
+    def get_param_grad(self, param_name):
+        if not self.is_in_local_shard(param_name):
+            raise ValueError(
+                "param[{}] not in current rank.".format(param_name))
+        if param_name not in self.params_grads:
+            raise ValueError('param[{}] not in params_grads'.format(param_name))
+        return self.params_grads.get(param_name, None)
diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py b/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py
index 7708767609e66c7cdd03e62f32ac32474a60f35a..ec879e77611cd491a7585d8b673d145577303b79 100644
--- a/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py
+++ b/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py
@@ -178,6 +178,7 @@ class AutoPallelPassTestBase(DistPassTestBase):
             preds = model(tokens, position_ids, attention_mask)
             criterion = GPTPretrainingCriterion()
             loss = criterion(preds, labels, loss_mask)
+            clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)

         if kwargs.get('optimizer', None) == "LarsMomentum":
             optimizer = paddle.fluid.optimizer.LarsMomentumOptimizer(
@@ -188,7 +189,7 @@ class AutoPallelPassTestBase(DistPassTestBase):
                 beta1=0.9,
                 beta2=0.999,
                 epsilon=1e-08,
-                grad_clip=None)
+                grad_clip=clip)
         optimizer = fleet.distributed_optimizer(optimizer)
         startup_program = paddle.static.default_startup_program()
         _, _, dist_startup_prog, dist_main_prog = optimizer.minimize(