提交 779fde8d 编写于 作者: S sandyhouse

update

上级 e166873b
......@@ -115,8 +115,8 @@ class ShardingOptimizer(MetaOptimizerBase):
main_program._pipeline_opt[
'global_rank'] = self.role_maker._worker_index()
main_program._pipeline_opt['use_sharding'] = True
main_program._pipeline_opt['ring_id'] = 2
optimize_ops, params_grads, program_list, self.pipeline_pair = pp_optimizer.minimize(
main_program._pipeline_opt['ring_id'] = 20
optimize_ops, params_grads, program_list, self.pipeline_pair, self.pp_ring_map = pp_optimizer.minimize(
loss, startup_program, parameter_list, no_grad_set)
self.pipeline_nodes = len(program_list)
else:
......@@ -423,7 +423,9 @@ class ShardingOptimizer(MetaOptimizerBase):
False)
else:
for pair in self.pipeline_pair:
print("pp pair:{}".format(pair))
pair_key = pair[0] * 1000 + pair[1]
ring_id = self.pp_ring_map[pair_key]
print("pp pair:{}, ring_id: {}".format(pair, ring_id))
if self.pp_rank not in pair: continue
pp_group_endpoints = [
self.pp_group_endpoints[pair[0]],
......@@ -437,8 +439,7 @@ class ShardingOptimizer(MetaOptimizerBase):
pp_rank = 0 if self.pp_rank == pair[0] else 1
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
pp_group_endpoints, pp_rank, start_ring_id, False,
False)
pp_group_endpoints, pp_rank, ring_id, False, False)
startup_block = self._startup_program.global_block()
startup_block._sync_with_cpp()
......@@ -869,7 +870,7 @@ class ShardingOptimizer(MetaOptimizerBase):
self.sharding_rank = self.global_rank % self.sharding_group_size
assert self.sharding_group_size * self.pipeline_nodes * self._inner_parallelism_size == self.role_maker._worker_num(
)
self.pp_ring_id = 2
self.pp_ring_id = 20
self.pp_rank = self.global_rank // (
self.sharding_group_size * self._inner_parallelism_size)
self.sharding_group_endpoints = [
......@@ -885,7 +886,7 @@ class ShardingOptimizer(MetaOptimizerBase):
else:
self.mp_group_id = 0
self.sharding_ring_id = 1
self.pp_ring_id = 2
self.pp_ring_id = 20
self.mp_rank = self.global_rank % self._inner_parallelism_size
self.mp_group = self.global_rank // self._inner_parallelism_size
self.mp_group_endpoints = [
......
......@@ -3789,6 +3789,7 @@ class PipelineOptimizer(object):
self._op_device_key = op_maker.kOpDeviceAttrName()
self._param_device_map = None
self._pipeline_pair = []
self._pp_ring_map = dict()
def _create_vars(self, block, ori_block):
# Create vars for block, copied from ori_block
......@@ -3841,6 +3842,8 @@ class PipelineOptimizer(object):
dest_var = block._clone_variable(source_var, False)
dest_var.stop_gradient = source_var.stop_gradient
continue
# TODO add allreduce_max when without sharding
if not should_insert: continue
out_name = op.desc.output_arg_names()[0]
out_var = block.var(out_name)
......@@ -4428,12 +4431,11 @@ class PipelineOptimizer(object):
prev_device_index = int(prev_device.split(':')[1])
cur_device_index = int(cur_device.split(':')[1])
pair = (prev_device_index, cur_device_index)
pair_key = prev_device_index * 1000 + cur_device_index
if cur_device_index > prev_device_index:
ring_id = self.ring_id + cur_device_index - prev_device_index - 1
else:
ring_id = self.ring_id + 2 + prev_device_index - cur_device_index - 1
if pair not in self._pipeline_pair:
self._pipeline_pair.append(pair)
if self.schedule_mode == 0: # GPipe
block._insert_op(
index=index + extra_index,
......@@ -4467,6 +4469,13 @@ class PipelineOptimizer(object):
extra_index += 1
continue
assert self.schedule_mode == 1
if pair not in self._pipeline_pair:
self._pipeline_pair.append(pair)
self._pp_ring_map[pair_key] = self.ring_id
ring_id = self.ring_id
self.ring_id += 1
else:
ring_id = self._pp_ring_map[pair_key]
block._insert_op(
index=index + extra_index,
#type='send_v2',
......@@ -4544,7 +4553,7 @@ class PipelineOptimizer(object):
self._op_device_key: cur_device,
#self._op_role_key: core.op_proto_and_checker_maker.OpRole.Backward,
self._op_role_key: op_role,
'ring_id': self.ring_id,
'ring_id': ring_id,
#'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
})
extra_index += 1
......@@ -4608,7 +4617,14 @@ class PipelineOptimizer(object):
var = block.vars[var_name]
prev_device_index = int(prev_device.split(':')[1])
cur_device_index = int(cur_device.split(':')[1])
#block._insert_op(
pair = (prev_device_index, cur_device_index)
pair_key = prev_device_index * 1000 + cur_device_index
if cur_device_index > prev_device_index:
ring_id = self.ring_id + cur_device_index - prev_device_index - 1
else:
ring_id = self.ring_id + 2 + prev_device_index - cur_device_index - 1
print("call xx_insert, schedule_mode:", self.schedule_mode)
if self.schedule_mode == 0: # GPipe
block._insert_op_without_sync(
index=index + extra_index,
type='send_v2',
......@@ -4618,10 +4634,11 @@ class PipelineOptimizer(object):
self._op_role_key: op_role,
'use_calc_stream': True,
'peer': cur_device_index,
'ring_id': self.ring_id,
'ring_id': self.ring_id
if cur_device_index > prev_device_index else
self.ring_id + 2,
})
extra_index += 1
#block._insert_op(
block._insert_op_without_sync(
index=index + extra_index,
type='recv_v2',
......@@ -4633,10 +4650,102 @@ class PipelineOptimizer(object):
self._op_role_key: op_role,
'use_calc_stream': True,
'peer': prev_device_index,
'ring_id': self.ring_id,
'ring_id': self.ring_id
if cur_device_index > prev_device_index else
self.ring_id + 2,
})
extra_index += 1
continue
assert self.schedule_mode == 1
if pair not in self._pipeline_pair:
self._pipeline_pair.append(pair)
self._pp_ring_map[pair_key] = self.ring_id
ring_id = self.ring_id
self.ring_id += 1
else:
ring_id = self._pp_ring_map[pair_key]
print("opt: pp_pair: {}, ring_id: {}".format(pair, ring_id))
block._insert_op_without_sync(
index=index + extra_index,
#type='send_v2',
type='c_broadcast',
inputs={'X': var},
outputs={'Out': var},
attrs={
self._op_device_key: prev_device,
self._op_role_key: op_role,
'use_calc_stream': False,
#'peer': cur_device_index,
#'ring_id': self.ring_id if cur_device_index > prev_device_index else self.ring_id + 2,
'ring_id': ring_id,
#'ring_id': self.ring_id,
#'root': prev_device_index,
'root': 0,
})
extra_index += 1
#block._insert_op(
# index=index + extra_index,
# type='c_sync_comm_stream',
# inputs={'X': [var]},
# outputs={'Out': [var]},
# attrs={
# self._op_device_key: cur_device,
# self._op_role_key:
# core.op_proto_and_checker_maker.OpRole.Backward,
# 'ring_id': self.ring_id,
# #'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
# })
#extra_index += 1
fill_shape = list(var.shape)
fill_shape[0] = 1
block._insert_op_without_sync(
index=index + extra_index,
#type='recv_v2',
type='fill_constant',
inputs={},
outputs={'Out': [var]},
attrs={
'shape': fill_shape,
'dtype': var.dtype,
self._op_device_key: cur_device,
self._op_role_key: op_role,
'value': float(0.0),
})
extra_index += 1
block._insert_op_without_sync(
index=index + extra_index,
#type='recv_v2',
type='c_broadcast',
inputs={'X': var},
outputs={'Out': var},
attrs={
#'out_shape': var.shape,
#'dtype': var.dtype,
self._op_device_key: cur_device,
self._op_role_key: op_role,
'use_calc_stream': False,
#'peer': prev_device_index,
#'root': prev_device_index,
'root': 0,
#'ring_id': self.ring_id,
'ring_id': ring_id,
#'ring_id': self.ring_id if cur_device_index > prev_device_index else self.ring_id + 2,
#'ring_id': self.ring_id if prev_device_index < cur_device_index else self.ring_id + 2,
})
extra_index += 1
block._insert_op_without_sync(
index=index + extra_index,
type='c_sync_comm_stream',
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
self._op_device_key: cur_device,
#self._op_role_key: core.op_proto_and_checker_maker.OpRole.Backward,
self._op_role_key: op_role,
'ring_id': ring_id,
#'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
})
extra_index += 1
block._sync_with_cpp()
def _clear_gradients(self, main_block, param_names):
......@@ -5120,7 +5229,7 @@ class PipelineOptimizer(object):
"num_microbatches": self._num_microbatches,
"start_cpu_core_id": self._start_cpu_core_id,
}
return optimize_ops, params_grads, program_list, self._pipeline_pair
return optimize_ops, params_grads, program_list, self._pipeline_pair, self._pp_ring_map
class RecomputeOptimizer(Optimizer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册