提交 779fde8d 编写于 作者: S sandyhouse

update

上级 e166873b
...@@ -115,8 +115,8 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -115,8 +115,8 @@ class ShardingOptimizer(MetaOptimizerBase):
main_program._pipeline_opt[ main_program._pipeline_opt[
'global_rank'] = self.role_maker._worker_index() 'global_rank'] = self.role_maker._worker_index()
main_program._pipeline_opt['use_sharding'] = True main_program._pipeline_opt['use_sharding'] = True
main_program._pipeline_opt['ring_id'] = 2 main_program._pipeline_opt['ring_id'] = 20
optimize_ops, params_grads, program_list, self.pipeline_pair = pp_optimizer.minimize( optimize_ops, params_grads, program_list, self.pipeline_pair, self.pp_ring_map = pp_optimizer.minimize(
loss, startup_program, parameter_list, no_grad_set) loss, startup_program, parameter_list, no_grad_set)
self.pipeline_nodes = len(program_list) self.pipeline_nodes = len(program_list)
else: else:
...@@ -423,7 +423,9 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -423,7 +423,9 @@ class ShardingOptimizer(MetaOptimizerBase):
False) False)
else: else:
for pair in self.pipeline_pair: for pair in self.pipeline_pair:
print("pp pair:{}".format(pair)) pair_key = pair[0] * 1000 + pair[1]
ring_id = self.pp_ring_map[pair_key]
print("pp pair:{}, ring_id: {}".format(pair, ring_id))
if self.pp_rank not in pair: continue if self.pp_rank not in pair: continue
pp_group_endpoints = [ pp_group_endpoints = [
self.pp_group_endpoints[pair[0]], self.pp_group_endpoints[pair[0]],
...@@ -437,8 +439,7 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -437,8 +439,7 @@ class ShardingOptimizer(MetaOptimizerBase):
pp_rank = 0 if self.pp_rank == pair[0] else 1 pp_rank = 0 if self.pp_rank == pair[0] else 1
self._collective_helper._init_communicator( self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint, self._startup_program, self.current_endpoint,
pp_group_endpoints, pp_rank, start_ring_id, False, pp_group_endpoints, pp_rank, ring_id, False, False)
False)
startup_block = self._startup_program.global_block() startup_block = self._startup_program.global_block()
startup_block._sync_with_cpp() startup_block._sync_with_cpp()
...@@ -869,7 +870,7 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -869,7 +870,7 @@ class ShardingOptimizer(MetaOptimizerBase):
self.sharding_rank = self.global_rank % self.sharding_group_size self.sharding_rank = self.global_rank % self.sharding_group_size
assert self.sharding_group_size * self.pipeline_nodes * self._inner_parallelism_size == self.role_maker._worker_num( assert self.sharding_group_size * self.pipeline_nodes * self._inner_parallelism_size == self.role_maker._worker_num(
) )
self.pp_ring_id = 2 self.pp_ring_id = 20
self.pp_rank = self.global_rank // ( self.pp_rank = self.global_rank // (
self.sharding_group_size * self._inner_parallelism_size) self.sharding_group_size * self._inner_parallelism_size)
self.sharding_group_endpoints = [ self.sharding_group_endpoints = [
...@@ -885,7 +886,7 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -885,7 +886,7 @@ class ShardingOptimizer(MetaOptimizerBase):
else: else:
self.mp_group_id = 0 self.mp_group_id = 0
self.sharding_ring_id = 1 self.sharding_ring_id = 1
self.pp_ring_id = 2 self.pp_ring_id = 20
self.mp_rank = self.global_rank % self._inner_parallelism_size self.mp_rank = self.global_rank % self._inner_parallelism_size
self.mp_group = self.global_rank // self._inner_parallelism_size self.mp_group = self.global_rank // self._inner_parallelism_size
self.mp_group_endpoints = [ self.mp_group_endpoints = [
......
...@@ -3789,6 +3789,7 @@ class PipelineOptimizer(object): ...@@ -3789,6 +3789,7 @@ class PipelineOptimizer(object):
self._op_device_key = op_maker.kOpDeviceAttrName() self._op_device_key = op_maker.kOpDeviceAttrName()
self._param_device_map = None self._param_device_map = None
self._pipeline_pair = [] self._pipeline_pair = []
self._pp_ring_map = dict()
def _create_vars(self, block, ori_block): def _create_vars(self, block, ori_block):
# Create vars for block, copied from ori_block # Create vars for block, copied from ori_block
...@@ -3841,6 +3842,8 @@ class PipelineOptimizer(object): ...@@ -3841,6 +3842,8 @@ class PipelineOptimizer(object):
dest_var = block._clone_variable(source_var, False) dest_var = block._clone_variable(source_var, False)
dest_var.stop_gradient = source_var.stop_gradient dest_var.stop_gradient = source_var.stop_gradient
continue
# TODO add allreduce_max when without sharding
if not should_insert: continue if not should_insert: continue
out_name = op.desc.output_arg_names()[0] out_name = op.desc.output_arg_names()[0]
out_var = block.var(out_name) out_var = block.var(out_name)
...@@ -4428,12 +4431,11 @@ class PipelineOptimizer(object): ...@@ -4428,12 +4431,11 @@ class PipelineOptimizer(object):
prev_device_index = int(prev_device.split(':')[1]) prev_device_index = int(prev_device.split(':')[1])
cur_device_index = int(cur_device.split(':')[1]) cur_device_index = int(cur_device.split(':')[1])
pair = (prev_device_index, cur_device_index) pair = (prev_device_index, cur_device_index)
pair_key = prev_device_index * 1000 + cur_device_index
if cur_device_index > prev_device_index: if cur_device_index > prev_device_index:
ring_id = self.ring_id + cur_device_index - prev_device_index - 1 ring_id = self.ring_id + cur_device_index - prev_device_index - 1
else: else:
ring_id = self.ring_id + 2 + prev_device_index - cur_device_index - 1 ring_id = self.ring_id + 2 + prev_device_index - cur_device_index - 1
if pair not in self._pipeline_pair:
self._pipeline_pair.append(pair)
if self.schedule_mode == 0: # GPipe if self.schedule_mode == 0: # GPipe
block._insert_op( block._insert_op(
index=index + extra_index, index=index + extra_index,
...@@ -4467,6 +4469,13 @@ class PipelineOptimizer(object): ...@@ -4467,6 +4469,13 @@ class PipelineOptimizer(object):
extra_index += 1 extra_index += 1
continue continue
assert self.schedule_mode == 1 assert self.schedule_mode == 1
if pair not in self._pipeline_pair:
self._pipeline_pair.append(pair)
self._pp_ring_map[pair_key] = self.ring_id
ring_id = self.ring_id
self.ring_id += 1
else:
ring_id = self._pp_ring_map[pair_key]
block._insert_op( block._insert_op(
index=index + extra_index, index=index + extra_index,
#type='send_v2', #type='send_v2',
...@@ -4544,7 +4553,7 @@ class PipelineOptimizer(object): ...@@ -4544,7 +4553,7 @@ class PipelineOptimizer(object):
self._op_device_key: cur_device, self._op_device_key: cur_device,
#self._op_role_key: core.op_proto_and_checker_maker.OpRole.Backward, #self._op_role_key: core.op_proto_and_checker_maker.OpRole.Backward,
self._op_role_key: op_role, self._op_role_key: op_role,
'ring_id': self.ring_id, 'ring_id': ring_id,
#'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2, #'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
}) })
extra_index += 1 extra_index += 1
...@@ -4608,35 +4617,135 @@ class PipelineOptimizer(object): ...@@ -4608,35 +4617,135 @@ class PipelineOptimizer(object):
var = block.vars[var_name] var = block.vars[var_name]
prev_device_index = int(prev_device.split(':')[1]) prev_device_index = int(prev_device.split(':')[1])
cur_device_index = int(cur_device.split(':')[1]) cur_device_index = int(cur_device.split(':')[1])
#block._insert_op( pair = (prev_device_index, cur_device_index)
pair_key = prev_device_index * 1000 + cur_device_index
if cur_device_index > prev_device_index:
ring_id = self.ring_id + cur_device_index - prev_device_index - 1
else:
ring_id = self.ring_id + 2 + prev_device_index - cur_device_index - 1
print("call xx_insert, schedule_mode:", self.schedule_mode)
if self.schedule_mode == 0: # GPipe
block._insert_op_without_sync(
index=index + extra_index,
type='send_v2',
inputs={'X': var},
attrs={
self._op_device_key: prev_device,
self._op_role_key: op_role,
'use_calc_stream': True,
'peer': cur_device_index,
'ring_id': self.ring_id
if cur_device_index > prev_device_index else
self.ring_id + 2,
})
extra_index += 1
block._insert_op_without_sync(
index=index + extra_index,
type='recv_v2',
outputs={'Out': [var]},
attrs={
'out_shape': var.shape,
'dtype': var.dtype,
self._op_device_key: cur_device,
self._op_role_key: op_role,
'use_calc_stream': True,
'peer': prev_device_index,
'ring_id': self.ring_id
if cur_device_index > prev_device_index else
self.ring_id + 2,
})
extra_index += 1
continue
assert self.schedule_mode == 1
if pair not in self._pipeline_pair:
self._pipeline_pair.append(pair)
self._pp_ring_map[pair_key] = self.ring_id
ring_id = self.ring_id
self.ring_id += 1
else:
ring_id = self._pp_ring_map[pair_key]
print("opt: pp_pair: {}, ring_id: {}".format(pair, ring_id))
block._insert_op_without_sync( block._insert_op_without_sync(
index=index + extra_index, index=index + extra_index,
type='send_v2', #type='send_v2',
type='c_broadcast',
inputs={'X': var}, inputs={'X': var},
outputs={'Out': var},
attrs={ attrs={
self._op_device_key: prev_device, self._op_device_key: prev_device,
self._op_role_key: op_role, self._op_role_key: op_role,
'use_calc_stream': True, 'use_calc_stream': False,
'peer': cur_device_index, #'peer': cur_device_index,
'ring_id': self.ring_id, #'ring_id': self.ring_id if cur_device_index > prev_device_index else self.ring_id + 2,
'ring_id': ring_id,
#'ring_id': self.ring_id,
#'root': prev_device_index,
'root': 0,
}) })
extra_index += 1 extra_index += 1
#block._insert_op( #block._insert_op(
# index=index + extra_index,
# type='c_sync_comm_stream',
# inputs={'X': [var]},
# outputs={'Out': [var]},
# attrs={
# self._op_device_key: cur_device,
# self._op_role_key:
# core.op_proto_and_checker_maker.OpRole.Backward,
# 'ring_id': self.ring_id,
# #'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
# })
#extra_index += 1
fill_shape = list(var.shape)
fill_shape[0] = 1
block._insert_op_without_sync( block._insert_op_without_sync(
index=index + extra_index, index=index + extra_index,
type='recv_v2', #type='recv_v2',
type='fill_constant',
inputs={},
outputs={'Out': [var]}, outputs={'Out': [var]},
attrs={ attrs={
'out_shape': var.shape, 'shape': fill_shape,
'dtype': var.dtype, 'dtype': var.dtype,
self._op_device_key: cur_device, self._op_device_key: cur_device,
self._op_role_key: op_role, self._op_role_key: op_role,
'use_calc_stream': True, 'value': float(0.0),
'peer': prev_device_index, })
'ring_id': self.ring_id, extra_index += 1
block._insert_op_without_sync(
index=index + extra_index,
#type='recv_v2',
type='c_broadcast',
inputs={'X': var},
outputs={'Out': var},
attrs={
#'out_shape': var.shape,
#'dtype': var.dtype,
self._op_device_key: cur_device,
self._op_role_key: op_role,
'use_calc_stream': False,
#'peer': prev_device_index,
#'root': prev_device_index,
'root': 0,
#'ring_id': self.ring_id,
'ring_id': ring_id,
#'ring_id': self.ring_id if cur_device_index > prev_device_index else self.ring_id + 2,
#'ring_id': self.ring_id if prev_device_index < cur_device_index else self.ring_id + 2,
})
extra_index += 1
block._insert_op_without_sync(
index=index + extra_index,
type='c_sync_comm_stream',
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
self._op_device_key: cur_device,
#self._op_role_key: core.op_proto_and_checker_maker.OpRole.Backward,
self._op_role_key: op_role,
'ring_id': ring_id,
#'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
}) })
extra_index += 1 extra_index += 1
block._sync_with_cpp() block._sync_with_cpp()
def _clear_gradients(self, main_block, param_names): def _clear_gradients(self, main_block, param_names):
...@@ -5120,7 +5229,7 @@ class PipelineOptimizer(object): ...@@ -5120,7 +5229,7 @@ class PipelineOptimizer(object):
"num_microbatches": self._num_microbatches, "num_microbatches": self._num_microbatches,
"start_cpu_core_id": self._start_cpu_core_id, "start_cpu_core_id": self._start_cpu_core_id,
} }
return optimize_ops, params_grads, program_list, self._pipeline_pair return optimize_ops, params_grads, program_list, self._pipeline_pair, self._pp_ring_map
class RecomputeOptimizer(Optimizer): class RecomputeOptimizer(Optimizer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册