# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import numpy as np from functools import reduce import paddle from .pass_base import PassBase, register_pass from ..auto_parallel.reshard import Resharder from ..auto_parallel.process_group import get_world_process_group from ..auto_parallel.utils import ( is_gradient_clip_op, is_optimize_op, OP_ROLE_KEY, OpRole, _get_comm_group, ) from ..auto_parallel.dist_attribute import ( TensorDistributedAttribute, OperatorDistributedAttribute, ) def _get_params_grads(block): params_grads = [] for op in reversed(block.ops): if not is_optimize_op(op): break if "Param" in op.input_names and "Grad" in op.input_names: param_name = op.input("Param")[0] grad_name = op.input("Grad")[0] param = block.var(param_name) grad = block.var(grad_name) params_grads.append((param, grad)) return params_grads def _get_dpmp_topology(origin_topology, sharding_group): """ Get dpmp topology from origin_topology Example: the parallel strategy: dp4-mp2-sharding2 the complete process_mesh: topology: [4, 2] processes: [0, 1, 2, 3, 4, 5, 6, 7] the dpmp topology: [2, 2] the sharding axis: 1 """ sharding_axis = 1 dp_sharding_topology = [ origin_topology[0] // sharding_group.nranks, sharding_group.nranks, ] if dp_sharding_topology[0] == 1: sharding_axis = 0 dp_sharding_topology = dp_sharding_topology[1:] product_dp_sharding = reduce(lambda x, y: x * y, dp_sharding_topology) product_topology = reduce(lambda x, y: x * y, origin_topology) if product_topology == product_dp_sharding: dpmp_topology = dp_sharding_topology else: assert product_topology % product_dp_sharding == 0 mp_degree = product_topology // product_dp_sharding dpmp_topology = dp_sharding_topology + [mp_degree] return dpmp_topology, sharding_axis def _get_dpmp_process_mesh(rank_id, topology, processes, sharding_group): """ Get dpmp process_mesh from the complete process_mesh which apply sharding. Example: the parallel strategy: dp4-mp2-sharding2 the complete process_mesh: topology: [4, 2] processes: [0, 1, 2, 3, 4, 5, 6, 7] the dpmp process_mesh is: 1) topology: [2, 2], processes: [0, 1, 4, 5] 2) topology: [2, 2], processes: [2, 3, 6, 7] """ if sharding_group is None: return topology, processes # get dpmp_topology dpmp_topology, sharding_axis = _get_dpmp_topology(topology, sharding_group) # get all sharding_groups of ranks sharding_groups = [] for rank in processes: group = _get_comm_group(processes, dpmp_topology, sharding_axis, rank) if group not in sharding_groups: sharding_groups.append(group) # get dpmp_processes sharding_groups = np.array(sharding_groups) dpmp_processes_in_sharding = None for i in range(sharding_groups.shape[-1]): if rank_id in sharding_groups[:, i]: dpmp_processes_in_sharding = sharding_groups[:, i] assert dpmp_processes_in_sharding is not None return dpmp_topology, list(dpmp_processes_in_sharding) def _is_about_global_norm( rank_id, tensor_shape, topology, processes, dims_mapping, sharding_group ): # get current process_mesh where the parameter exist. dpmp_topology, dpmp_processes = _get_dpmp_process_mesh( rank_id, topology, processes, sharding_group ) complete_shape = Resharder.compute_complete_shape( tensor_shape, dpmp_topology, dims_mapping ) complete_partitions = [] complete_param_ranks = [] for process in dpmp_processes: partition_index = Resharder.compute_partition_index( process, complete_shape, dims_mapping, dpmp_topology, dpmp_processes ) if partition_index not in complete_partitions: complete_partitions.append(partition_index) complete_param_ranks.append(process) return rank_id in complete_param_ranks class ClipHelper: def __init__(self, params_grads, rank_id, block, dist_context): params, _ = zip(*params_grads) self.params = list(params) self.params_name = [p.name for p in self.params] self.rank_id = rank_id self.block = block self.dist_context = dist_context self.sharding_group = None self.world_ranks = get_world_process_group().ranks if hasattr(dist_context, '_sharding_group'): self.sharding_group = dist_context._sharding_group def _is_calcuate_norm(self, name): if not self._is_local_param(name): return False, [] param = self.params[self.params_name.index(name)] dist_attr = self._get_dist_attr(name) topology = dist_attr.process_mesh.topology processes = dist_attr.process_mesh.processes dims_mapping = dist_attr.dims_mapping return _is_about_global_norm( self.rank_id, param.shape, topology, processes, dims_mapping, self.sharding_group, ) def _get_dist_attr(self, name): var = self.block.vars[name] return self.dist_context.get_tensor_dist_attr_for_program(var) def _is_local_param(self, name): if name not in self.params_name: return False return True def _is_local_var(self, name): dist_attr = self._get_dist_attr(name) assert dist_attr is not None return self.rank_id in dist_attr.process_mesh.processes def _init_dist_attr(self, op): op_dist_attr = OperatorDistributedAttribute() op_dist_attr.process_mesh = self.world_ranks for in_name in op.input_arg_names: in_var = self.block.vars[in_name] in_dist_attr = TensorDistributedAttribute() in_dist_attr.process_mesh = self.world_ranks in_dist_attr.dims_mapping = [-1] self.dist_context.set_tensor_dist_attr_for_program( in_var, in_dist_attr ) op_dist_attr.set_input_dist_attr(in_name, in_dist_attr) for out_name in op.output_arg_names: out_var = self.block.vars[out_name] out_dist_attr = TensorDistributedAttribute() out_dist_attr.process_mesh = self.world_ranks out_dist_attr.dims_mapping = [-1] self.dist_context.set_tensor_dist_attr_for_program( out_var, out_dist_attr ) op_dist_attr.set_output_dist_attr(out_name, out_dist_attr) self.dist_context.set_op_dist_attr_for_program(op, op_dist_attr) @register_pass("auto_parallel_grad_clip") class ClipGradByGloblNormPass(PassBase): """ 1. Remove norm-compute op and grad-scale op when the grad is not in current rank or is independent of the calculation of norm. 2. Each rank computes its own norm value, then gets global_norm by allreduce_sum only once. """ def __init__(self): super().__init__() self.set_attr("rank_id", None) self.set_attr("dist_context", None) self.set_attr("params_grads", None) def _check_self(self): if self.get_attr("dist_context") is None: return False dist_context = self.get_attr("dist_context") if dist_context._serial_optimizer._grad_clip is None: return False if self.get_attr("params_grads") is None: return False return True def _check_conflict(self, other_pass): return True def _apply_single_impl(self, main_program, startup_program, context): dist_context = self.get_attr("dist_context", None) rank_id = self.get_attr("rank_id", None) block = main_program.global_block() dist_params_grads = self.get_attr("params_grads", None) # dist_params_grads = _get_params_grads(block) self.clip_helper = ClipHelper( dist_params_grads, rank_id, block, dist_context ) self._remove_no_need_ops_vars(block) def _remove_no_need_ops_vars(self, block): removed_op_out_type = [ 'clip_by_norm', 'squared_l2_norm', 'square', 'reduce_sum', ] removed_op_idx = set() removed_tmp_var = set() for idx, op in enumerate(block.ops): if not is_gradient_clip_op(op): continue if op.type in removed_op_out_type: input_name = op.input("X")[0] if input_name.find("@GRAD") != -1: # 'clip_by_norm', 'squared_l2_norm', 'square' param_name = input_name[: input_name.find("@GRAD")] is_local = self.clip_helper._is_local_param(param_name) is_calculate = self.clip_helper._is_calcuate_norm( param_name ) if not is_local or ( not is_calculate and op.type != 'clip_by_norm' ): removed_op_idx.add(idx) removed_tmp_var.update(set(op.output_arg_names)) else: # 'reduce_sum' if idx - 1 in removed_op_idx: removed_op_idx.add(idx) removed_tmp_var.update(set(op.output_arg_names)) elif op.type == 'elementwise_mul': input_name = op.input("X")[0] if input_name.find("@GRAD") != -1: param_name = input_name[: input_name.find("@GRAD")] is_local = self.clip_helper._is_local_param(param_name) if not is_local: removed_op_idx.add(idx) if block.ops[idx - 1].type == 'cast': removed_op_idx.add(idx - 1) removed_tmp_var.update( set(block.ops[idx - 1].output_arg_names) ) elif op.type == 'sum': reserved_vars = [] for input_name in op.input_arg_names: if ( input_name not in removed_tmp_var and self.clip_helper._is_local_var(input_name) ): reserved_vars.append(input_name) if not reserved_vars: removed_op_idx.add(idx) removed_tmp_var.update(set(op.output_arg_names)) if block.ops[idx + 1].type == 'cast': removed_op_idx.add(idx + 1) removed_tmp_var.update( set(block.ops[idx + 1].output_arg_names) ) else: op.desc.set_input("X", reserved_vars) for idx, op in reversed(list(enumerate(block.ops))): if not is_optimize_op(op): break if not is_gradient_clip_op(op): continue if idx in removed_op_idx: block._remove_op(idx, sync=False) for idx, op in reversed(list(enumerate(block.ops))): if not is_optimize_op(op): break if not is_gradient_clip_op(op): continue if op.type == 'sqrt': input_name = op.input("X")[0] input_var = block.vars[input_name] if paddle.distributed.get_world_size() > 1: offset = 0 if input_name in removed_tmp_var: removed_tmp_var.remove(input_name) fill_constant_op = block._insert_op( idx, type='fill_constant', inputs={}, outputs={'Out': [input_var]}, attrs={ 'shape': [1], 'dtype': input_var.dtype, 'value': 0, 'force_cpu': False, OP_ROLE_KEY: OpRole.Optimize, }, ) fill_constant_op._set_attr( 'op_namescope', "/gradient_clip_pass" ) offset += 1 self.clip_helper._init_dist_attr(fill_constant_op) allreduce_op = block._insert_op( idx + offset, type='c_allreduce_sum', inputs={'X': [input_var]}, outputs={'Out': [input_var]}, attrs={ 'ring_id': 0, 'use_calc_stream': True, OP_ROLE_KEY: OpRole.Optimize, }, ) allreduce_op._set_attr( 'op_namescope', "/gradient_clip_pass" ) self.clip_helper._init_dist_attr(allreduce_op) for varname in removed_tmp_var: block._remove_var(varname, sync=False) block._sync_with_cpp()