提交 f874e02b 编写于 作者: S sandyhouse

update optimizer

上级 d2c81529
......@@ -31,6 +31,8 @@ __all__ = ["ShardingOptimizer"]
class ShardingOptimizer(MetaOptimizerBase):
"""Sharding Optimizer."""
def __init__(self, optimizer):
super(ShardingOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer
......@@ -77,6 +79,7 @@ class ShardingOptimizer(MetaOptimizerBase):
startup_program=None,
parameter_list=None,
no_grad_set=None):
"""Implementation of minimize."""
# TODO: (JZ-LIANG) support multiple comm in future
# self._nrings = self.user_defined_strategy.nccl_comm_num
self._nrings_sharding = 1
......@@ -91,12 +94,15 @@ class ShardingOptimizer(MetaOptimizerBase):
self.user_defined_strategy.sharding_configs["parallelism"])
self.use_pipeline = self.user_defined_strategy.sharding_configs[
"use_pipeline"]
self.acc_steps = self.user_defined_strategy.sharding_configs[
"acc_steps"]
if self.inner_opt is None:
raise ValueError(
"self.inner_opt of ShardingOptimizer should not be None.")
if self.use_pipeline:
pp_optimizer = fluid.optimizer.PipelineOptimizer(self.inner_opt)
pp_optimizer = fluid.optimizer.PipelineOptimizer(self.inner_opt,
self.acc_steps)
main_program = loss.block.program
main_program._pipeline_opt = dict()
pp_rank = self.role_maker._worker_index() // (
......@@ -107,7 +113,7 @@ class ShardingOptimizer(MetaOptimizerBase):
'global_rank'] = self.role_maker._worker_index()
main_program._pipeline_opt['use_sharding'] = True
main_program._pipeline_opt['ring_id'] = 2
optimize_ops, params_grads, program_list = pp_optimizer.minimize(
optimize_ops, params_grads, program_list, self.pipeline_pair = pp_optimizer.minimize(
loss, startup_program, parameter_list, no_grad_set)
self.pipeline_nodes = len(program_list)
else:
......@@ -349,8 +355,8 @@ class ShardingOptimizer(MetaOptimizerBase):
# check op dependecy
check_broadcast(main_block)
check_allreduce_sum(main_block, self._shard, self.sharding_ring_id,
self.dp_ring_id)
#check_allreduce_sum(main_block, self._shard, self.sharding_ring_id,
# self.dp_ring_id)
#check_allreduce_sum(main_block, self._shard, self.dp_ring_id)
self._wait()
return optimize_ops, params_grads
......@@ -403,9 +409,20 @@ class ShardingOptimizer(MetaOptimizerBase):
print("pp_group_endpoints:", self.pp_group_endpoints)
print("pp_rank:", self.pp_rank)
print("pp_ring_id:", self.pp_ring_id)
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.pp_group_endpoints, self.pp_rank, self.pp_ring_id, False)
for pair in self.pipeline_pair:
if self.pp_rank not in pair: continue
pp_group_endpoints = [
self.pp_group_endpoints[pair[0]],
self.pp_group_endpoints[pair[1]],
]
if pair[0] < pair[1]:
start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1
else:
start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[1] - 1
pp_rank = 0 if self.pp_rank == pair[0] else 1
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
pp_group_endpoints, pp_rank, start_ring_id, False)
startup_block = self._startup_program.global_block()
startup_block._sync_with_cpp()
......
......@@ -413,6 +413,8 @@ class Section(DeviceWorker):
section_param = trainer_desc.section_param
section_param.num_microbatches = pipeline_opt["num_microbatches"]
section_param.start_cpu_core_id = pipeline_opt["start_cpu_core_id"]
section_param.pipeline_stage = pipeline_opt["pipeline_stage"]
section_param.num_pipeline_stages = pipeline_opt["num_pipeline_stages"]
cfg = section_param.section_config
program = pipeline_opt["section_program"]
cfg.program_desc.ParseFromString(program["program"]._get_desc()
......
......@@ -3788,6 +3788,7 @@ class PipelineOptimizer(object):
self._op_role_var_key = op_maker.kOpRoleVarAttrName()
self._op_device_key = op_maker.kOpDeviceAttrName()
self._param_device_map = None
self._pipeline_pair = []
def _create_vars(self, block, ori_block):
# Create vars for block, copied from ori_block
......@@ -4134,6 +4135,7 @@ class PipelineOptimizer(object):
if not var_name in first_block.vars:
self._create_var(first_block, main_var, var_name)
dev_index = int(device.split(':')[1])
print("dev_index:", dev_index)
first_block._insert_op(
index=insert_index,
type='send_v2',
......@@ -4141,9 +4143,11 @@ class PipelineOptimizer(object):
attrs={
self._op_device_key: first_dev_spec,
self._op_role_key: self._op_role.Forward,
'use_calc_stream': True,
'use_calc_stream': False,
'peer': dev_index,
'ring_id': self.ring_id,
#'ring_id': self.ring_id,
'ring_id': self.ring_id
if dev_index > first_dev_index else self.ring_id + 2,
})
# Get the device that that data on
assert device in devices
......@@ -4168,7 +4172,21 @@ class PipelineOptimizer(object):
self._op_role_key: self._op_role.Forward,
'peer': first_dev_index,
'use_calc_stream': True,
'ring_id': self.ring_id,
#'ring_id': self.ring_id,
'ring_id': self.ring_id
if first_dev_index < dev_index else self.ring_id + 2,
})
block._insert_op(
index=index + 1,
type='c_sync_comm_stream',
inputs={'X': [new_var]},
outputs={'Out': [new_var]},
attrs={
self._op_device_key: device,
self._op_role_key: self._op_role.Forward,
#'ring_id': self.ring_id,
'ring_id': self.ring_id
if first_dev_index > dev_index else self.ring_id + 2,
})
def _strip_grad_suffix(self, name):
......@@ -4409,30 +4427,91 @@ class PipelineOptimizer(object):
var = block.vars[var_name]
prev_device_index = int(prev_device.split(':')[1])
cur_device_index = int(cur_device.split(':')[1])
pair = (prev_device_index, cur_device_index)
if cur_device_index > prev_device_index:
ring_id = self.ring_id + cur_device_index - prev_device_index - 1
else:
ring_id = self.ring_id + 2 + prev_device_index - cur_device_index - 1
if pair not in self._pipeline_pair:
self._pipeline_pair.append(pair)
block._insert_op(
index=index + extra_index,
type='send_v2',
#type='send_v2',
type='c_broadcast',
inputs={'X': var},
outputs={'Out': var},
attrs={
self._op_device_key: prev_device,
self._op_role_key: op_role,
'use_calc_stream': True,
'peer': cur_device_index,
'use_calc_stream': False,
#'peer': cur_device_index,
#'ring_id': self.ring_id if cur_device_index > prev_device_index else self.ring_id + 2,
'ring_id': ring_id,
#'ring_id': self.ring_id,
#'root': prev_device_index,
'root': 0,
})
extra_index += 1
block._insert_op(
index=index + extra_index,
type='c_sync_comm_stream',
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
self._op_device_key: cur_device,
self._op_role_key:
core.op_proto_and_checker_maker.OpRole.Backward,
'ring_id': self.ring_id,
#'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
})
extra_index += 1
fill_shape = list(var.shape)
fill_shape[0] = 1
block._insert_op(
index=index + extra_index,
type='recv_v2',
#type='recv_v2',
type='fill_constant',
inputs={},
outputs={'Out': [var]},
attrs={
'out_shape': var.shape,
'shape': fill_shape,
'dtype': var.dtype,
self._op_device_key: cur_device,
self._op_role_key: op_role,
'use_calc_stream': True,
'peer': prev_device_index,
'value': float(0.0),
})
extra_index += 1
block._insert_op(
index=index + extra_index,
#type='recv_v2',
type='c_broadcast',
inputs={'X': var},
outputs={'Out': var},
attrs={
#'out_shape': var.shape,
#'dtype': var.dtype,
self._op_device_key: cur_device,
self._op_role_key: op_role,
'use_calc_stream': False,
#'peer': prev_device_index,
#'root': prev_device_index,
'root': 0,
#'ring_id': self.ring_id,
'ring_id': ring_id,
#'ring_id': self.ring_id if cur_device_index > prev_device_index else self.ring_id + 2,
#'ring_id': self.ring_id if prev_device_index < cur_device_index else self.ring_id + 2,
})
extra_index += 1
block._insert_op(
index=index + extra_index,
type='c_sync_comm_stream',
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
self._op_device_key: cur_device,
self._op_role_key: op_role,
'ring_id': self.ring_id,
#'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
})
extra_index += 1
......@@ -4512,6 +4591,15 @@ class PipelineOptimizer(object):
first_optimize_op_index = None
for index, op in reversed(tuple(enumerate(list(block.ops)))):
# device = op.attr(self._op_device_key)
# remove the cast op of fp16 grad to fp32 grad
if self._is_optimize_op(op) and op.type == 'cast':
in_name = op.input_arg_names[0]
out_name = op.output_arg_names[0]
if out_name.strip('@GRAD') in self._param_device_map:
assert in_name.replace('.cast_fp16', '') == out_name
block._remove_op(index)
continue
if not self._is_optimize_op(op) and not first_optimize_op_index:
first_optimize_op_index = index + 1
if block.ops[
......@@ -4553,11 +4641,11 @@ class PipelineOptimizer(object):
# a trick to run this op once per mini-batch
self._op_role_key: self._op_role.Optimize.LRSched,
})
offset += 1
#offset += 1
grad_name = op_role_var[i + 1] # with _0 suffix
grad_var = block.vars[grad_name] # without _0 suffix
grad_var = block.vars[grad_name]
real_grad_name = grad_name[0:grad_name.find(
'@GRAD')] + '@GRAD'
'@GRAD')] + '@GRAD' # without _0 suffix
real_grad_var = block.vars[
real_grad_name] # without _0 suffix
# new_grad_var_name = unique_name.generate(grad_name)
......@@ -4567,7 +4655,7 @@ class PipelineOptimizer(object):
# self._rename_arg(op, grad_name, new_grad_var_name)
if not 'cast_fp16' in grad_name:
block._insert_op(
index=first_optimize_op_index + offset,
index=index + 1,
type='sum',
inputs={'X': [grad_var, real_grad_var]},
outputs={'Out': real_grad_var},
......@@ -4576,58 +4664,83 @@ class PipelineOptimizer(object):
self._op_role_key: self._op_role.Backward,
#self._op_role_var_key: op_role_var
})
offset += 1
#offset += 1
else:
grad_name = op_role_var[i + 1] # with _0 suffix
grad_var = block.vars[grad_name] # without _0 suffix
fp32_grad_var_name = param_name + core.grad_var_suffix()
grad_var = block.vars[grad_name]
fp32_grad_var_name = param_name + core.grad_var_suffix(
) # without _0 suffix
fp32_grad_var = block.vars[fp32_grad_var_name]
fp32_grad_var.persistable = True
cast_grad_var_name = unique_name.generate(
fp32_grad_var_name)
cast_var = self._create_var(block, grad_var,
cast_grad_var_name)
cast_var.persistable = False
real_grad_name = grad_name[0:grad_name.find(
'@GRAD')] + '@GRAD'
real_grad_var = block.vars[
real_grad_name] # without _0 suffix
cast_grad_var = self._create_var(block, fp32_grad_var,
cast_grad_var_name)
cast_grad_var.persistable = False
block._insert_op(
index=first_optimize_op_index + offset,
index=index + 1,
type='cast',
inputs={'X': fp32_grad_var},
outputs={'Out': cast_var},
inputs={'X': grad_var},
outputs={'Out': cast_grad_var},
attrs={
'in_dtype': fp32_grad_var.dtype,
'out_dtype': cast_var.dtype,
'in_dtype': grad_var.dtype,
'out_dtype': cast_grad_var.dtype,
# self._op_device_key: device,
self._op_role_key: self._op_role.Backward,
# self._op_role_var_key: op_role_var
})
offset += 1
block._insert_op(
index=first_optimize_op_index + offset,
index=index + 2,
type='sum',
inputs={'X': [grad_var, cast_var]},
outputs={'Out': real_grad_var},
attrs={
# self._op_device_key: device,
self._op_role_key: self._op_role.Backward,
# self._op_role_var_key: op_role_var
})
offset += 1
block._insert_op(
index=first_optimize_op_index + offset,
type='cast',
inputs={'X': real_grad_var},
inputs={'X': [fp32_grad_var, cast_grad_var]},
outputs={'Out': fp32_grad_var},
attrs={
'in_dtype': real_grad_var.dtype,
'out_dtype': fp32_grad_var.dtype,
# self._op_device_key: device,
self._op_role_key: self._op_role.Backward,
# self._op_role_var_key: op_role_var
})
offset += 1
#real_grad_name = grad_name[0:grad_name.find(
# '@GRAD')] + '@GRAD'
#real_grad_var = block.vars[
# real_grad_name] # without _0 suffix
#block._insert_op(
# index=first_optimize_op_index + offset,
# type='cast',
# inputs={'X': fp32_grad_var},
# outputs={'Out': cast_var},
# attrs={
# 'in_dtype': fp32_grad_var.dtype,
# 'out_dtype': cast_var.dtype,
# # self._op_device_key: device,
# self._op_role_key: self._op_role.Backward,
# # self._op_role_var_key: op_role_var
# })
#offset += 1
#block._insert_op(
# index=first_optimize_op_index + offset,
# type='sum',
# inputs={'X': [grad_var, cast_var]},
# outputs={'Out': real_grad_var},
# attrs={
# # self._op_device_key: device,
# self._op_role_key: self._op_role.Backward,
# # self._op_role_var_key: op_role_var
# })
#offset += 1
#block._insert_op(
# index=first_optimize_op_index + offset,
# type='cast',
# inputs={'X': real_grad_var},
# outputs={'Out': fp32_grad_var},
# attrs={
# 'in_dtype': real_grad_var.dtype,
# 'out_dtype': fp32_grad_var.dtype,
# # self._op_device_key: device,
# self._op_role_key: self._op_role.Backward,
# # self._op_role_var_key: op_role_var
# })
def _add_sub_blocks(self, main_block, program_list):
main_program = main_block.program
......@@ -4720,12 +4833,14 @@ class PipelineOptimizer(object):
inputs={'X': write_block.var(var_name), },
attrs={
self._op_device_key: write_device,
'use_calc_stream': True,
'use_calc_stream': False,
# A trick to make the role LRSched to avoid copy every
# microbatch
self._op_role_key: self._op_role.LRSched,
'peer': read_dev_index,
'ring_id': self.ring_id,
#'ring_id': self.ring_id,
'ring_id': self.ring_id if
read_dev_index > write_dev_index else self.ring_id + 2,
})
read_block._insert_op(
index=0,
......@@ -4735,12 +4850,28 @@ class PipelineOptimizer(object):
'out_shape': read_block.var(var_name).shape,
'dtype': read_block.var(var_name).dtype,
self._op_device_key: read_device,
'use_calc_stream': True,
'use_calc_stream': False,
# A trick to make the role LRSched to avoid copy every
# microbatch
self._op_role_key: self._op_role.LRSched,
'peer': write_dev_index,
'ring_id': self.ring_id,
#'ring_id': self.ring_id,
'ring_id': self.ring_id if
write_dev_index < read_dev_index else self.ring_id + 2,
})
read_block._insert_op(
index=1,
type='c_sync_comm_stream',
inputs={'X': [read_block.var(var_name)]},
outputs={'Out': [read_block.var(var_name)]},
attrs={
self._op_device_key: read_device,
# A trick to make the role LRSched to avoid copy every
# microbatch
self._op_role_key: self._op_role.LRSched,
#'ring_id': self.ring_id,
'ring_id': self.ring_id if
write_dev_index > read_dev_index else self.ring_id + 2,
})
def _is_gradient_clip_op(self, op):
......@@ -4809,8 +4940,8 @@ class PipelineOptimizer(object):
program_list = self._split_program(main_program, device_list)
for p in program_list:
self._create_vars(p["program"].block(0), main_block)
self._insert_sendrecv_for_data_var(main_block, program_list,
startup_program, device_list)
#self._insert_sendrecv_for_data_var(main_block, program_list,
# startup_program, device_list)
# Step4: Special Case: process persistable vars that exist in
# multiple sections
......@@ -4824,8 +4955,8 @@ class PipelineOptimizer(object):
place_list = []
for dev in device_list:
dev_index = int(dev.split(":")[1]) % 8
place_list.append(core.CUDAPlace(dev_index))
dev_index = int(dev.split(":")[1])
place_list.append(core.CUDAPlace(dev_index % 8))
# Step6: Split startup program
new_startup_program = self._split_startup_program(startup_program,
......@@ -4851,6 +4982,8 @@ class PipelineOptimizer(object):
"trainer": "PipelineTrainer",
"device_worker": "Section",
"inner_parallelism": len(device_list),
"num_pipeline_stages": len(device_list),
"pipeline_stage": local_rank,
"section_program": program_list[local_rank],
"place": place_list[local_rank],
"place_id": place_id,
......@@ -4858,7 +4991,7 @@ class PipelineOptimizer(object):
"num_microbatches": self._num_microbatches,
"start_cpu_core_id": self._start_cpu_core_id,
}
return optimize_ops, params_grads, program_list
return optimize_ops, params_grads, program_list, self._pipeline_pair
class RecomputeOptimizer(Optimizer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册