diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index f8e26ee2406f18152dd7c771c2f75ed05315824a..ce9c18a3ff51b22da390bd6f06b1428586f06cfd 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -115,8 +115,8 @@ class ShardingOptimizer(MetaOptimizerBase): main_program._pipeline_opt[ 'global_rank'] = self.role_maker._worker_index() main_program._pipeline_opt['use_sharding'] = True - main_program._pipeline_opt['ring_id'] = 2 - optimize_ops, params_grads, program_list, self.pipeline_pair = pp_optimizer.minimize( + main_program._pipeline_opt['ring_id'] = 20 + optimize_ops, params_grads, program_list, self.pipeline_pair, self.pp_ring_map = pp_optimizer.minimize( loss, startup_program, parameter_list, no_grad_set) self.pipeline_nodes = len(program_list) else: @@ -423,7 +423,9 @@ class ShardingOptimizer(MetaOptimizerBase): False) else: for pair in self.pipeline_pair: - print("pp pair:{}".format(pair)) + pair_key = pair[0] * 1000 + pair[1] + ring_id = self.pp_ring_map[pair_key] + print("pp pair:{}, ring_id: {}".format(pair, ring_id)) if self.pp_rank not in pair: continue pp_group_endpoints = [ self.pp_group_endpoints[pair[0]], @@ -437,8 +439,7 @@ class ShardingOptimizer(MetaOptimizerBase): pp_rank = 0 if self.pp_rank == pair[0] else 1 self._collective_helper._init_communicator( self._startup_program, self.current_endpoint, - pp_group_endpoints, pp_rank, start_ring_id, False, - False) + pp_group_endpoints, pp_rank, ring_id, False, False) startup_block = self._startup_program.global_block() startup_block._sync_with_cpp() @@ -869,7 +870,7 @@ class ShardingOptimizer(MetaOptimizerBase): self.sharding_rank = self.global_rank % self.sharding_group_size assert self.sharding_group_size * self.pipeline_nodes * self._inner_parallelism_size == self.role_maker._worker_num( ) - self.pp_ring_id = 2 + self.pp_ring_id = 20 self.pp_rank = self.global_rank // ( self.sharding_group_size * self._inner_parallelism_size) self.sharding_group_endpoints = [ @@ -885,7 +886,7 @@ class ShardingOptimizer(MetaOptimizerBase): else: self.mp_group_id = 0 self.sharding_ring_id = 1 - self.pp_ring_id = 2 + self.pp_ring_id = 20 self.mp_rank = self.global_rank % self._inner_parallelism_size self.mp_group = self.global_rank // self._inner_parallelism_size self.mp_group_endpoints = [ diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index ffb54861904f4a7c6d54ba958343f9cb00cb559b..8f549f93145bdf6af8f84a3c3890ba6031dbe67b 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -3789,6 +3789,7 @@ class PipelineOptimizer(object): self._op_device_key = op_maker.kOpDeviceAttrName() self._param_device_map = None self._pipeline_pair = [] + self._pp_ring_map = dict() def _create_vars(self, block, ori_block): # Create vars for block, copied from ori_block @@ -3841,6 +3842,8 @@ class PipelineOptimizer(object): dest_var = block._clone_variable(source_var, False) dest_var.stop_gradient = source_var.stop_gradient + continue + # TODO add allreduce_max when without sharding if not should_insert: continue out_name = op.desc.output_arg_names()[0] out_var = block.var(out_name) @@ -4428,12 +4431,11 @@ class PipelineOptimizer(object): prev_device_index = int(prev_device.split(':')[1]) cur_device_index = int(cur_device.split(':')[1]) pair = (prev_device_index, cur_device_index) + pair_key = prev_device_index * 1000 + cur_device_index if cur_device_index > prev_device_index: ring_id = self.ring_id + cur_device_index - prev_device_index - 1 else: ring_id = self.ring_id + 2 + prev_device_index - cur_device_index - 1 - if pair not in self._pipeline_pair: - self._pipeline_pair.append(pair) if self.schedule_mode == 0: # GPipe block._insert_op( index=index + extra_index, @@ -4467,6 +4469,13 @@ class PipelineOptimizer(object): extra_index += 1 continue assert self.schedule_mode == 1 + if pair not in self._pipeline_pair: + self._pipeline_pair.append(pair) + self._pp_ring_map[pair_key] = self.ring_id + ring_id = self.ring_id + self.ring_id += 1 + else: + ring_id = self._pp_ring_map[pair_key] block._insert_op( index=index + extra_index, #type='send_v2', @@ -4544,7 +4553,7 @@ class PipelineOptimizer(object): self._op_device_key: cur_device, #self._op_role_key: core.op_proto_and_checker_maker.OpRole.Backward, self._op_role_key: op_role, - 'ring_id': self.ring_id, + 'ring_id': ring_id, #'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2, }) extra_index += 1 @@ -4608,35 +4617,135 @@ class PipelineOptimizer(object): var = block.vars[var_name] prev_device_index = int(prev_device.split(':')[1]) cur_device_index = int(cur_device.split(':')[1]) - #block._insert_op( + pair = (prev_device_index, cur_device_index) + pair_key = prev_device_index * 1000 + cur_device_index + if cur_device_index > prev_device_index: + ring_id = self.ring_id + cur_device_index - prev_device_index - 1 + else: + ring_id = self.ring_id + 2 + prev_device_index - cur_device_index - 1 + print("call xx_insert, schedule_mode:", self.schedule_mode) + if self.schedule_mode == 0: # GPipe + block._insert_op_without_sync( + index=index + extra_index, + type='send_v2', + inputs={'X': var}, + attrs={ + self._op_device_key: prev_device, + self._op_role_key: op_role, + 'use_calc_stream': True, + 'peer': cur_device_index, + 'ring_id': self.ring_id + if cur_device_index > prev_device_index else + self.ring_id + 2, + }) + extra_index += 1 + block._insert_op_without_sync( + index=index + extra_index, + type='recv_v2', + outputs={'Out': [var]}, + attrs={ + 'out_shape': var.shape, + 'dtype': var.dtype, + self._op_device_key: cur_device, + self._op_role_key: op_role, + 'use_calc_stream': True, + 'peer': prev_device_index, + 'ring_id': self.ring_id + if cur_device_index > prev_device_index else + self.ring_id + 2, + }) + extra_index += 1 + continue + assert self.schedule_mode == 1 + if pair not in self._pipeline_pair: + self._pipeline_pair.append(pair) + self._pp_ring_map[pair_key] = self.ring_id + ring_id = self.ring_id + self.ring_id += 1 + else: + ring_id = self._pp_ring_map[pair_key] + print("opt: pp_pair: {}, ring_id: {}".format(pair, ring_id)) block._insert_op_without_sync( index=index + extra_index, - type='send_v2', + #type='send_v2', + type='c_broadcast', inputs={'X': var}, + outputs={'Out': var}, attrs={ self._op_device_key: prev_device, self._op_role_key: op_role, - 'use_calc_stream': True, - 'peer': cur_device_index, - 'ring_id': self.ring_id, + 'use_calc_stream': False, + #'peer': cur_device_index, + #'ring_id': self.ring_id if cur_device_index > prev_device_index else self.ring_id + 2, + 'ring_id': ring_id, + #'ring_id': self.ring_id, + #'root': prev_device_index, + 'root': 0, }) extra_index += 1 #block._insert_op( + # index=index + extra_index, + # type='c_sync_comm_stream', + # inputs={'X': [var]}, + # outputs={'Out': [var]}, + # attrs={ + # self._op_device_key: cur_device, + # self._op_role_key: + # core.op_proto_and_checker_maker.OpRole.Backward, + # 'ring_id': self.ring_id, + # #'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2, + # }) + #extra_index += 1 + fill_shape = list(var.shape) + fill_shape[0] = 1 block._insert_op_without_sync( index=index + extra_index, - type='recv_v2', + #type='recv_v2', + type='fill_constant', + inputs={}, outputs={'Out': [var]}, attrs={ - 'out_shape': var.shape, + 'shape': fill_shape, 'dtype': var.dtype, self._op_device_key: cur_device, self._op_role_key: op_role, - 'use_calc_stream': True, - 'peer': prev_device_index, - 'ring_id': self.ring_id, + 'value': float(0.0), + }) + extra_index += 1 + block._insert_op_without_sync( + index=index + extra_index, + #type='recv_v2', + type='c_broadcast', + inputs={'X': var}, + outputs={'Out': var}, + attrs={ + #'out_shape': var.shape, + #'dtype': var.dtype, + self._op_device_key: cur_device, + self._op_role_key: op_role, + 'use_calc_stream': False, + #'peer': prev_device_index, + #'root': prev_device_index, + 'root': 0, + #'ring_id': self.ring_id, + 'ring_id': ring_id, + #'ring_id': self.ring_id if cur_device_index > prev_device_index else self.ring_id + 2, + #'ring_id': self.ring_id if prev_device_index < cur_device_index else self.ring_id + 2, + }) + extra_index += 1 + block._insert_op_without_sync( + index=index + extra_index, + type='c_sync_comm_stream', + inputs={'X': [var]}, + outputs={'Out': [var]}, + attrs={ + self._op_device_key: cur_device, + #self._op_role_key: core.op_proto_and_checker_maker.OpRole.Backward, + self._op_role_key: op_role, + 'ring_id': ring_id, + #'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2, }) extra_index += 1 - block._sync_with_cpp() def _clear_gradients(self, main_block, param_names): @@ -5120,7 +5229,7 @@ class PipelineOptimizer(object): "num_microbatches": self._num_microbatches, "start_cpu_core_id": self._start_cpu_core_id, } - return optimize_ops, params_grads, program_list, self._pipeline_pair + return optimize_ops, params_grads, program_list, self._pipeline_pair, self._pp_ring_map class RecomputeOptimizer(Optimizer):