Commit 767422ee authored by S sandyhouse

update

Parent c4d789af
@@ -231,20 +231,21 @@ def get_valid_op_role(block, insert_idx):
return OpRole.Forward or OpRole.Backward
"""
op_role = block.ops[insert_idx].attr('op_role')
# if (insert_idx >= len(block.ops)) or (
# op_role in [int(OpRole.Backward), int(OpRole.Optimize)]):
# return OpRole.Backward
# if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
# return OpRole.Forward
# return get_valid_op_role(block, insert_idx + 1)
if insert_idx >= len(block.ops): return OpRole.Optimize
if op_role == int(OpRole.Backward): return OpRole.Backward
if op_role == int(OpRole.Optimize): return OpRole.Optimize
if (insert_idx >= len(block.ops)) or (
op_role in [int(OpRole.Backward), int(OpRole.Optimize)]):
return OpRole.Backward
if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
return OpRole.Forward
return get_valid_op_role(block, insert_idx + 1)
# if insert_idx >= len(block.ops): return OpRole.Optimize
# if op_role == int(OpRole.Backward): return OpRole.Backward
# if op_role == int(OpRole.Optimize): return OpRole.Optimize
# if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
# return OpRole.Forward
# return get_valid_op_role(block, insert_idx + 1)
def insert_sync_calc_op(block, insert_idx, calc_dep_vars):
"""
......
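In the hunk above, the uncommented branch of get_valid_op_role resolves the role for an op inserted at insert_idx by walking forward through block.ops until it reaches an op with a definitive role, falling back to Backward once the index runs past the end of the block. Below is a minimal standalone sketch of that resolution, using a plain list of integer roles instead of a Paddle Block; the constants mirror OpRole but are illustrative, not Paddle's API.

# Illustrative role values only, not Paddle's OpRole enum.
FORWARD, BACKWARD, OPTIMIZE, LOSS, LR_SCHED = 0, 1, 2, 256, 16

def resolve_role(op_roles, insert_idx):
    """Return the role a newly inserted op should take at insert_idx."""
    # Past the end of the op list, or at a backward/optimize op, the insert
    # belongs to the backward part of the program.
    if insert_idx >= len(op_roles) or op_roles[insert_idx] in (BACKWARD, OPTIMIZE):
        return BACKWARD
    if op_roles[insert_idx] in (FORWARD, LOSS):
        return FORWARD
    # Other roles (e.g. lr scheduling) are not definitive; keep walking forward.
    return resolve_role(op_roles, insert_idx + 1)

# An lr-schedule op sitting before the first backward op resolves to BACKWARD.
assert resolve_role([FORWARD, LOSS, LR_SCHED, BACKWARD], 2) == BACKWARD
# Past the last op, the fallback is BACKWARD as well.
assert resolve_role([FORWARD, LOSS, BACKWARD], 3) == BACKWARD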
@@ -153,7 +153,6 @@ class ShardingOptimizer(MetaOptimizerBase):
if self.use_pipeline:
pp_optimizer._rename_gradient_var_name(main_block)
pp_optimizer._accumulate_gradients(main_block)
with open("main_%d" % self.role_maker._worker_index(), 'w') as f:
f.writelines(str(main_program))
@@ -201,6 +200,8 @@ class ShardingOptimizer(MetaOptimizerBase):
#if self._shard.has_param(param_name): continue
if in_name not in main_block.vars:
main_block._remove_op(idx)
accumulated_grad_names = pp_optimizer._accumulate_gradients(
main_block)
# accumulated_grad_names = sorted(accumulated_grad_names)
if self.pp_allreduce_in_optimize:
print("persistable FP32 grad: ")
......
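The hunk above removes ops whose input variable is no longer present in main_block before collecting the accumulated gradient names from the pipeline optimizer. Below is a small sketch of that pruning step on a plain list of op records; the dict-based ops and names are stand-ins, not Paddle's Block/Operator API, and the reverse iteration is only how this sketch keeps indices valid while ops are deleted.

def prune_ops_with_missing_inputs(ops, known_vars):
    # ops: list of records like {"type": "cast", "inputs": ["x.cast_fp16"]}
    # known_vars: set of variable names that still exist in the block
    for idx in reversed(range(len(ops))):
        if any(name not in known_vars for name in ops[idx]["inputs"]):
            del ops[idx]
    return ops

ops = [
    {"type": "cast", "inputs": ["w@GRAD"]},
    {"type": "cast", "inputs": ["dropped_var"]},  # input no longer in the block
]
assert len(prune_ops_with_missing_inputs(ops, {"w@GRAD"})) == 1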
@@ -4836,6 +4836,7 @@ class PipelineOptimizer(object):
input_names = op.input_arg_names
output_names = op.output_arg_names
in_out_names = input_names + output_names
if op.type == 'cast': continue
# append "MERGED" to the names of parameter gradients,
# and modify the op_role_var attribute (by rename_arg func).
for name in in_out_names:
@@ -4857,13 +4858,16 @@ class PipelineOptimizer(object):
if self._is_optimize_op(op) and op.type == 'cast':
in_name = op.input_arg_names[0]
out_name = op.output_arg_names[0]
if out_name.strip('@GRAD@MERGED') in self._param_device_map:
if out_name.strip('@GRAD') in self._param_device_map:
assert in_name.replace('.cast_fp16', '') == out_name
block._remove_op(index)
continue
if self._is_backward_op(op) and not first_opt_op_idx:
first_opt_op_idx = index + 1
if block.ops[first_opt_op_idx].type == "c_sync_comm_stream":
#block.ops[first_opt_op_idx]._set_attr(self._op_role_key, self._op_role.Backward)
first_opt_op_idx += 1
if self._is_backward_op(op) and (
self._op_role_var_key in op.attr_names):
@@ -4872,17 +4876,13 @@ class PipelineOptimizer(object):
if len(op_role_var) == 0:
continue
assert len(op_role_var) % 2 == 0
op._remove_attr(self._op_role_var_key)
#op._remove_attr(self._op_role_var_key)
for i in range(0, len(op_role_var), 2):
offset = 0
param_name = op_role_var[i]
assert block.has_var(param_name), (
"parameter {} not in "
"current block.".format(param_name))
# clear gradient
assert param_name in self.origin_main_block.vars, "[{}] not in original main block".format(
param_name)
param_grad_name = self._append_grad_suffix(param_name)
if not block.has_var(param_name): continue
if '@BroadCast' in param_name: continue
param_grad_name = param_name + core.grad_var_suffix()
merged_param_grad_name = param_grad_name + '@MERGED'
if not block.has_var(merged_param_grad_name):
self._create_var(block, block.vars[param_name],
@@ -4944,7 +4944,7 @@ class PipelineOptimizer(object):
attrs={
# self._op_device_key: device,
self._op_role_key: self._op_role.Backward,
self._op_role_var_key: op_role_var
#self._op_role_var_key: op_role_var
})
offset += 1
merged_gradient_names.append(merged_param_grad_name)
......
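The cast handling and merged-gradient bookkeeping in the hunks above rely on a few naming conventions: the '@GRAD' suffix from core.grad_var_suffix() for parameter gradients, an extra '@MERGED' suffix for the accumulation buffer, and a '.cast_fp16' tag for half-precision copies. Below is a small sketch of those conventions; the helper names are hypothetical, not Paddle API. Note that Python's str.strip('@GRAD') removes a set of characters from both ends rather than that literal suffix, which is why the sketch removes suffixes with an explicit endswith check.

GRAD_SUFFIX = "@GRAD"        # value returned by core.grad_var_suffix()
MERGED_SUFFIX = "@MERGED"    # tag for the gradient accumulation buffer
FP16_TAG = ".cast_fp16"      # tag carried by fp16 copies of a variable

def merged_grad_name(param_name):
    # fc_0.w_0 -> fc_0.w_0@GRAD@MERGED
    return param_name + GRAD_SUFFIX + MERGED_SUFFIX

def strip_suffix(name, suffix):
    # Suffix-aware removal: only drop the suffix when it is actually present.
    return name[:-len(suffix)] if name.endswith(suffix) else name

assert merged_grad_name("fc_0.w_0") == "fc_0.w_0@GRAD@MERGED"
assert strip_suffix("fc_0.w_0@GRAD@MERGED", MERGED_SUFFIX) == "fc_0.w_0@GRAD"
# Stripping the grad suffix and the fp16 tag recovers the base parameter name.
assert strip_suffix("fc_0.w_0.cast_fp16@GRAD", GRAD_SUFFIX).replace(FP16_TAG, "") == "fc_0.w_0"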