From f874e02b71af44a600eb2b06b5d89be68c4100ea Mon Sep 17 00:00:00 2001
From: sandyhouse
Date: Tue, 2 Mar 2021 18:15:40 +0800
Subject: [PATCH] update optimizer

---
 .../meta_optimizers/sharding_optimizer.py |  31 ++-
 python/paddle/fluid/device_worker.py      |   2 +
 python/paddle/fluid/optimizer.py          | 239 ++++++++++++++----
 3 files changed, 212 insertions(+), 60 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index b17084a979e..38dcc14427e 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -31,6 +31,8 @@ __all__ = ["ShardingOptimizer"]
 
 
 class ShardingOptimizer(MetaOptimizerBase):
+    """Sharding Optimizer."""
+
     def __init__(self, optimizer):
         super(ShardingOptimizer, self).__init__(optimizer)
         self.inner_opt = optimizer
@@ -77,6 +79,7 @@ class ShardingOptimizer(MetaOptimizerBase):
                  startup_program=None,
                  parameter_list=None,
                  no_grad_set=None):
+        """Implementation of minimize."""
         # TODO: (JZ-LIANG) support multiple comm in future
         # self._nrings = self.user_defined_strategy.nccl_comm_num
         self._nrings_sharding = 1
@@ -91,12 +94,15 @@ class ShardingOptimizer(MetaOptimizerBase):
             self.user_defined_strategy.sharding_configs["parallelism"])
         self.use_pipeline = self.user_defined_strategy.sharding_configs[
             "use_pipeline"]
+        self.acc_steps = self.user_defined_strategy.sharding_configs[
+            "acc_steps"]
         if self.inner_opt is None:
             raise ValueError(
                 "self.inner_opt of ShardingOptimizer should not be None.")
 
         if self.use_pipeline:
-            pp_optimizer = fluid.optimizer.PipelineOptimizer(self.inner_opt)
+            pp_optimizer = fluid.optimizer.PipelineOptimizer(self.inner_opt,
+                                                             self.acc_steps)
             main_program = loss.block.program
             main_program._pipeline_opt = dict()
             pp_rank = self.role_maker._worker_index() // (
@@ -107,7 +113,7 @@ class ShardingOptimizer(MetaOptimizerBase):
                 'global_rank'] = self.role_maker._worker_index()
             main_program._pipeline_opt['use_sharding'] = True
             main_program._pipeline_opt['ring_id'] = 2
-            optimize_ops, params_grads, program_list = pp_optimizer.minimize(
+            optimize_ops, params_grads, program_list, self.pipeline_pair = pp_optimizer.minimize(
                 loss, startup_program, parameter_list, no_grad_set)
             self.pipeline_nodes = len(program_list)
         else:
@@ -349,8 +355,8 @@ class ShardingOptimizer(MetaOptimizerBase):
 
         # check op dependecy
         check_broadcast(main_block)
-        check_allreduce_sum(main_block, self._shard, self.sharding_ring_id,
-                            self.dp_ring_id)
+        #check_allreduce_sum(main_block, self._shard, self.sharding_ring_id,
+        #                    self.dp_ring_id)
         #check_allreduce_sum(main_block, self._shard, self.dp_ring_id)
         self._wait()
         return optimize_ops, params_grads
@@ -403,9 +409,20 @@ class ShardingOptimizer(MetaOptimizerBase):
             print("pp_group_endpoints:", self.pp_group_endpoints)
             print("pp_rank:", self.pp_rank)
             print("pp_ring_id:", self.pp_ring_id)
-            self._collective_helper._init_communicator(
-                self._startup_program, self.current_endpoint,
-                self.pp_group_endpoints, self.pp_rank, self.pp_ring_id, False)
+            for pair in self.pipeline_pair:
+                if self.pp_rank not in pair: continue
+                pp_group_endpoints = [
+                    self.pp_group_endpoints[pair[0]],
+                    self.pp_group_endpoints[pair[1]],
+                ]
+                if pair[0] < pair[1]:
+                    start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1
+                else:
+                    start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[1] - 1
+                pp_rank = 0 if self.pp_rank == pair[0] else 1
+                self._collective_helper._init_communicator(
+                    self._startup_program, self.current_endpoint,
+                    pp_group_endpoints, pp_rank, start_ring_id, False)
 
         startup_block = self._startup_program.global_block()
         startup_block._sync_with_cpp()

diff --git a/python/paddle/fluid/device_worker.py b/python/paddle/fluid/device_worker.py
index 838aea37f18..e9b9bca3804 100644
--- a/python/paddle/fluid/device_worker.py
+++ b/python/paddle/fluid/device_worker.py
@@ -413,6 +413,8 @@ class Section(DeviceWorker):
         section_param = trainer_desc.section_param
         section_param.num_microbatches = pipeline_opt["num_microbatches"]
         section_param.start_cpu_core_id = pipeline_opt["start_cpu_core_id"]
+        section_param.pipeline_stage = pipeline_opt["pipeline_stage"]
+        section_param.num_pipeline_stages = pipeline_opt["num_pipeline_stages"]
         cfg = section_param.section_config
         program = pipeline_opt["section_program"]
         cfg.program_desc.ParseFromString(program["program"]._get_desc()

diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 42099fbe8c7..72da8d672c7 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -3788,6 +3788,7 @@ class PipelineOptimizer(object):
         self._op_role_var_key = op_maker.kOpRoleVarAttrName()
         self._op_device_key = op_maker.kOpDeviceAttrName()
         self._param_device_map = None
+        self._pipeline_pair = []
 
     def _create_vars(self, block, ori_block):
         # Create vars for block, copied from ori_block
@@ -4134,6 +4135,7 @@ class PipelineOptimizer(object):
             if not var_name in first_block.vars:
                 self._create_var(first_block, main_var, var_name)
             dev_index = int(device.split(':')[1])
+            print("dev_index:", dev_index)
             first_block._insert_op(
                 index=insert_index,
                 type='send_v2',
@@ -4141,9 +4143,11 @@ class PipelineOptimizer(object):
                 attrs={
                     self._op_device_key: first_dev_spec,
                     self._op_role_key: self._op_role.Forward,
-                    'use_calc_stream': True,
+                    'use_calc_stream': False,
                     'peer': dev_index,
-                    'ring_id': self.ring_id,
+                    #'ring_id': self.ring_id,
+                    'ring_id': self.ring_id
+                    if dev_index > first_dev_index else self.ring_id + 2,
                 })
             # Get the device that that data on
             assert device in devices
@@ -4168,7 +4172,21 @@ class PipelineOptimizer(object):
                     self._op_role_key: self._op_role.Forward,
                     'peer': first_dev_index,
                     'use_calc_stream': True,
-                    'ring_id': self.ring_id,
+                    #'ring_id': self.ring_id,
+                    'ring_id': self.ring_id
+                    if first_dev_index < dev_index else self.ring_id + 2,
+                })
+            block._insert_op(
+                index=index + 1,
+                type='c_sync_comm_stream',
+                inputs={'X': [new_var]},
+                outputs={'Out': [new_var]},
+                attrs={
+                    self._op_device_key: device,
+                    self._op_role_key: self._op_role.Forward,
+                    #'ring_id': self.ring_id,
+                    'ring_id': self.ring_id
+                    if first_dev_index > dev_index else self.ring_id + 2,
                 })
 
     def _strip_grad_suffix(self, name):
@@ -4409,30 +4427,91 @@ class PipelineOptimizer(object):
                 var = block.vars[var_name]
                 prev_device_index = int(prev_device.split(':')[1])
                 cur_device_index = int(cur_device.split(':')[1])
+                pair = (prev_device_index, cur_device_index)
+                if cur_device_index > prev_device_index:
+                    ring_id = self.ring_id + cur_device_index - prev_device_index - 1
+                else:
+                    ring_id = self.ring_id + 2 + prev_device_index - cur_device_index - 1
+                if pair not in self._pipeline_pair:
+                    self._pipeline_pair.append(pair)
                 block._insert_op(
                     index=index + extra_index,
-                    type='send_v2',
+                    #type='send_v2',
+                    type='c_broadcast',
                     inputs={'X': var},
+                    outputs={'Out': var},
                     attrs={
                         self._op_device_key: prev_device,
                         self._op_role_key: op_role,
-                        'use_calc_stream': True,
-                        'peer': cur_device_index,
+                        'use_calc_stream': False,
+                        #'peer': cur_device_index,
+                        #'ring_id': self.ring_id if cur_device_index > prev_device_index else self.ring_id + 2,
+                        'ring_id': ring_id,
+                        #'ring_id': self.ring_id,
+                        #'root': prev_device_index,
+                        'root': 0,
+                    })
+                extra_index += 1
+                block._insert_op(
+                    index=index + extra_index,
+                    type='c_sync_comm_stream',
+                    inputs={'X': [var]},
+                    outputs={'Out': [var]},
+                    attrs={
+                        self._op_device_key: cur_device,
+                        self._op_role_key:
+                        core.op_proto_and_checker_maker.OpRole.Backward,
                         'ring_id': self.ring_id,
+                        #'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
                     })
                 extra_index += 1
+                fill_shape = list(var.shape)
+                fill_shape[0] = 1
                 block._insert_op(
                     index=index + extra_index,
-                    type='recv_v2',
+                    #type='recv_v2',
+                    type='fill_constant',
+                    inputs={},
                     outputs={'Out': [var]},
                     attrs={
-                        'out_shape': var.shape,
+                        'shape': fill_shape,
                         'dtype': var.dtype,
                         self._op_device_key: cur_device,
                         self._op_role_key: op_role,
-                        'use_calc_stream': True,
-                        'peer': prev_device_index,
+                        'value': float(0.0),
+                    })
+                extra_index += 1
+                block._insert_op(
+                    index=index + extra_index,
+                    #type='recv_v2',
+                    type='c_broadcast',
+                    inputs={'X': var},
+                    outputs={'Out': var},
+                    attrs={
+                        #'out_shape': var.shape,
+                        #'dtype': var.dtype,
+                        self._op_device_key: cur_device,
+                        self._op_role_key: op_role,
+                        'use_calc_stream': False,
+                        #'peer': prev_device_index,
+                        #'root': prev_device_index,
+                        'root': 0,
+                        #'ring_id': self.ring_id,
+                        'ring_id': ring_id,
+                        #'ring_id': self.ring_id if cur_device_index > prev_device_index else self.ring_id + 2,
+                        #'ring_id': self.ring_id if prev_device_index < cur_device_index else self.ring_id + 2,
+                    })
+                extra_index += 1
+                block._insert_op(
+                    index=index + extra_index,
+                    type='c_sync_comm_stream',
+                    inputs={'X': [var]},
+                    outputs={'Out': [var]},
+                    attrs={
+                        self._op_device_key: cur_device,
+                        self._op_role_key: op_role,
                         'ring_id': self.ring_id,
+                        #'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
                     })
                 extra_index += 1
@@ -4512,6 +4591,15 @@ class PipelineOptimizer(object):
         first_optimize_op_index = None
         for index, op in reversed(tuple(enumerate(list(block.ops)))):
             # device = op.attr(self._op_device_key)
+            # remove the cast op of fp16 grad to fp32 grad
+            if self._is_optimize_op(op) and op.type == 'cast':
+                in_name = op.input_arg_names[0]
+                out_name = op.output_arg_names[0]
+                if out_name.strip('@GRAD') in self._param_device_map:
+                    assert in_name.replace('.cast_fp16', '') == out_name
+                    block._remove_op(index)
+                    continue
+
             if not self._is_optimize_op(op) and not first_optimize_op_index:
                 first_optimize_op_index = index + 1
                 if block.ops[
@@ -4553,11 +4641,11 @@ class PipelineOptimizer(object):
                                 # a trick to run this op once per mini-batch
                                 self._op_role_key: self._op_role.Optimize.LRSched,
                             })
-                        offset += 1
+                        #offset += 1
                     grad_name = op_role_var[i + 1]  # with _0 suffix
-                    grad_var = block.vars[grad_name]  # without _0 suffix
+                    grad_var = block.vars[grad_name]
                     real_grad_name = grad_name[0:grad_name.find(
-                        '@GRAD')] + '@GRAD'
+                        '@GRAD')] + '@GRAD'  # without _0 suffix
                     real_grad_var = block.vars[
                         real_grad_name]  # without _0 suffix
                     # new_grad_var_name = unique_name.generate(grad_name)
@@ -4567,7 +4655,7 @@ class PipelineOptimizer(object):
                     # new_var = self._create_var(block, grad_var,
                     #                            new_grad_var_name)
                     # self._rename_arg(op, grad_name, new_grad_var_name)
                     if not 'cast_fp16' in grad_name:
                         block._insert_op(
-                            index=first_optimize_op_index + offset,
+                            index=index + 1,
                             type='sum',
                             inputs={'X': [grad_var, real_grad_var]},
                             outputs={'Out': real_grad_var},
                             attrs={
@@ -4576,58 +4664,83 @@ class PipelineOptimizer(object):
                                 # self._op_device_key: device,
                                 self._op_role_key:
                                 self._op_role.Backward,
                                 #self._op_role_var_key: op_role_var
                             })
-                        offset += 1
+                        #offset += 1
                     else:
                         grad_name = op_role_var[i + 1]  # with _0 suffix
-                        grad_var = block.vars[grad_name]  # without _0 suffix
-                        fp32_grad_var_name = param_name + core.grad_var_suffix()
+                        grad_var = block.vars[grad_name]
+                        fp32_grad_var_name = param_name + core.grad_var_suffix(
+                        )  # without _0 suffix
                         fp32_grad_var = block.vars[fp32_grad_var_name]
                         fp32_grad_var.persistable = True
                         cast_grad_var_name = unique_name.generate(
                             fp32_grad_var_name)
-                        cast_var = self._create_var(block, grad_var,
-                                                    cast_grad_var_name)
-                        cast_var.persistable = False
-                        real_grad_name = grad_name[0:grad_name.find(
-                            '@GRAD')] + '@GRAD'
-                        real_grad_var = block.vars[
-                            real_grad_name]  # without _0 suffix
+                        cast_grad_var = self._create_var(block, fp32_grad_var,
+                                                         cast_grad_var_name)
+                        cast_grad_var.persistable = False
                         block._insert_op(
-                            index=first_optimize_op_index + offset,
+                            index=index + 1,
                             type='cast',
-                            inputs={'X': fp32_grad_var},
-                            outputs={'Out': cast_var},
+                            inputs={'X': grad_var},
+                            outputs={'Out': cast_grad_var},
                             attrs={
-                                'in_dtype': fp32_grad_var.dtype,
-                                'out_dtype': cast_var.dtype,
+                                'in_dtype': grad_var.dtype,
+                                'out_dtype': cast_grad_var.dtype,
                                 # self._op_device_key: device,
                                 self._op_role_key: self._op_role.Backward,
                                 # self._op_role_var_key: op_role_var
                             })
                         offset += 1
                         block._insert_op(
-                            index=first_optimize_op_index + offset,
+                            index=index + 2,
                             type='sum',
-                            inputs={'X': [grad_var, cast_var]},
-                            outputs={'Out': real_grad_var},
-                            attrs={
-                                # self._op_device_key: device,
-                                self._op_role_key: self._op_role.Backward,
-                                # self._op_role_var_key: op_role_var
-                            })
-                        offset += 1
-                        block._insert_op(
-                            index=first_optimize_op_index + offset,
-                            type='cast',
-                            inputs={'X': real_grad_var},
+                            inputs={'X': [fp32_grad_var, cast_grad_var]},
                             outputs={'Out': fp32_grad_var},
                             attrs={
-                                'in_dtype': real_grad_var.dtype,
-                                'out_dtype': fp32_grad_var.dtype,
                                 # self._op_device_key: device,
                                 self._op_role_key: self._op_role.Backward,
                                 # self._op_role_var_key: op_role_var
                             })
+                        offset += 1
+                        #real_grad_name = grad_name[0:grad_name.find(
+                        #    '@GRAD')] + '@GRAD'
+                        #real_grad_var = block.vars[
+                        #    real_grad_name]  # without _0 suffix
+                        #block._insert_op(
+                        #    index=first_optimize_op_index + offset,
+                        #    type='cast',
+                        #    inputs={'X': fp32_grad_var},
+                        #    outputs={'Out': cast_var},
+                        #    attrs={
+                        #        'in_dtype': fp32_grad_var.dtype,
+                        #        'out_dtype': cast_var.dtype,
+                        #        # self._op_device_key: device,
+                        #        self._op_role_key: self._op_role.Backward,
+                        #        # self._op_role_var_key: op_role_var
+                        #    })
+                        #offset += 1
+                        #block._insert_op(
+                        #    index=first_optimize_op_index + offset,
+                        #    type='sum',
+                        #    inputs={'X': [grad_var, cast_var]},
+                        #    outputs={'Out': real_grad_var},
+                        #    attrs={
+                        #        # self._op_device_key: device,
+                        #        self._op_role_key: self._op_role.Backward,
+                        #        # self._op_role_var_key: op_role_var
+                        #    })
+                        #offset += 1
+                        #block._insert_op(
+                        #    index=first_optimize_op_index + offset,
+                        #    type='cast',
+                        #    inputs={'X': real_grad_var},
+                        #    outputs={'Out': fp32_grad_var},
+                        #    attrs={
+                        #        'in_dtype': real_grad_var.dtype,
+                        #        'out_dtype': fp32_grad_var.dtype,
+                        #        # self._op_device_key: device,
+                        #        self._op_role_key: self._op_role.Backward,
+                        #        # self._op_role_var_key: op_role_var
+                        #    })
 
     def _add_sub_blocks(self, main_block, program_list):
         main_program = main_block.program
@@ -4720,12 +4833,14 @@ class PipelineOptimizer(object):
                     inputs={'X': write_block.var(var_name), },
                     attrs={
                         self._op_device_key: write_device,
-                        'use_calc_stream': True,
+                        'use_calc_stream': False,
                         # A trick to make the role LRSched to avoid copy every
                        # microbatch
                         self._op_role_key: self._op_role.LRSched,
                         'peer': read_dev_index,
-                        'ring_id': self.ring_id,
+                        #'ring_id': self.ring_id,
+                        'ring_id': self.ring_id if
+                        read_dev_index > write_dev_index else self.ring_id + 2,
                     })
                 read_block._insert_op(
                     index=0,
@@ -4735,12 +4850,28 @@ class PipelineOptimizer(object):
                     type='recv_v2',
                     outputs={'Out': [read_block.var(var_name)]},
                     attrs={
                         'out_shape': read_block.var(var_name).shape,
                         'dtype': read_block.var(var_name).dtype,
                         self._op_device_key: read_device,
-                        'use_calc_stream': True,
+                        'use_calc_stream': False,
                         # A trick to make the role LRSched to avoid copy every
                         # microbatch
                         self._op_role_key: self._op_role.LRSched,
                         'peer': write_dev_index,
-                        'ring_id': self.ring_id,
+                        #'ring_id': self.ring_id,
+                        'ring_id': self.ring_id if
+                        write_dev_index < read_dev_index else self.ring_id + 2,
+                    })
+                read_block._insert_op(
+                    index=1,
+                    type='c_sync_comm_stream',
+                    inputs={'X': [read_block.var(var_name)]},
+                    outputs={'Out': [read_block.var(var_name)]},
+                    attrs={
+                        self._op_device_key: read_device,
+                        # A trick to make the role LRSched to avoid copy every
+                        # microbatch
+                        self._op_role_key: self._op_role.LRSched,
+                        #'ring_id': self.ring_id,
+                        'ring_id': self.ring_id if
+                        write_dev_index > read_dev_index else self.ring_id + 2,
                     })
 
     def _is_gradient_clip_op(self, op):
@@ -4809,8 +4940,8 @@ class PipelineOptimizer(object):
         program_list = self._split_program(main_program, device_list)
         for p in program_list:
             self._create_vars(p["program"].block(0), main_block)
-        self._insert_sendrecv_for_data_var(main_block, program_list,
-                                           startup_program, device_list)
+        #self._insert_sendrecv_for_data_var(main_block, program_list,
+        #                                   startup_program, device_list)
 
         # Step4: Special Case: process persistable vars that exist in
         # multiple sections
@@ -4824,8 +4955,8 @@ class PipelineOptimizer(object):
 
         place_list = []
         for dev in device_list:
-            dev_index = int(dev.split(":")[1]) % 8
-            place_list.append(core.CUDAPlace(dev_index))
+            dev_index = int(dev.split(":")[1])
+            place_list.append(core.CUDAPlace(dev_index % 8))
 
         # Step6: Split startup program
         new_startup_program = self._split_startup_program(startup_program,
@@ -4851,6 +4982,8 @@ class PipelineOptimizer(object):
             "trainer": "PipelineTrainer",
            "device_worker": "Section",
             "inner_parallelism": len(device_list),
+            "num_pipeline_stages": len(device_list),
+            "pipeline_stage": local_rank,
             "section_program": program_list[local_rank],
             "place": place_list[local_rank],
             "place_id": place_id,
@@ -4858,7 +4991,7 @@ class PipelineOptimizer(object):
             "num_microbatches": self._num_microbatches,
             "start_cpu_core_id": self._start_cpu_core_id,
         }
-        return optimize_ops, params_grads, program_list
+        return optimize_ops, params_grads, program_list, self._pipeline_pair
 
 
 class RecomputeOptimizer(Optimizer):
-- 
GitLab
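Note on the ring-id scheme the patch uses throughout: each point-to-point exchange between two pipeline stages is mapped to its own communication ring, with forward pairs (src < dst) using ring_id + dst - src - 1 and backward pairs using ring_id + 2 + src - dst - 1, matching start_ring_id in sharding_optimizer.py and ring_id in optimizer.py. A minimal sketch of that mapping; the helper name below is hypothetical and only restates the arithmetic from the diff:

# Hypothetical helper mirroring the ring-id arithmetic in the patch.
def ring_id_for_pair(base_ring_id, pair):
    src, dst = pair
    if src < dst:
        # forward direction (activations sent downstream)
        return base_ring_id + dst - src - 1
    # backward direction (gradients sent upstream)
    return base_ring_id + 2 + src - dst - 1

# With the base ring id of 2 that the sharding path stores in
# main_program._pipeline_opt['ring_id'], the adjacent forward pair (0, 1)
# maps to ring 2 and the reverse pair (1, 0) maps to ring 4.
assert ring_id_for_pair(2, (0, 1)) == 2
assert ring_id_for_pair(2, (1, 0)) == 4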
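Note on the fp16 gradient-accumulation rewrite in optimizer.py: the newly inserted ops cast each micro-batch fp16 gradient to fp32 and immediately sum it into the persistable fp32 gradient, instead of accumulating first and casting afterwards. A small NumPy sketch of that semantics, assuming the fp32 accumulator starts at zero for the mini-batch; the function name is illustrative, not a Paddle API:

import numpy as np

# Illustrative only: mirrors the new op order (a 'cast' of the fp16 grad to
# fp32, followed by a 'sum' into the persistable fp32 gradient).
def merge_fp16_grads(microbatch_grads_fp16):
    acc_fp32 = np.zeros_like(microbatch_grads_fp16[0], dtype=np.float32)
    for g in microbatch_grads_fp16:
        acc_fp32 += g.astype(np.float32)  # cast, then sum
    return acc_fp32

grads = [np.ones((2, 2), dtype=np.float16) * 0.5 for _ in range(4)]
print(merge_fp16_grads(grads))  # every entry is 2.0 after four micro-batches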