Commit 767422ee authored by sandyhouse

update

Parent: c4d789af
@@ -231,20 +231,21 @@ def get_valid_op_role(block, insert_idx):
         return OpRole.Forward or OpRole.Backward
     """
     op_role = block.ops[insert_idx].attr('op_role')
-    # if (insert_idx >= len(block.ops)) or (
-    #         op_role in [int(OpRole.Backward), int(OpRole.Optimize)]):
-    #     return OpRole.Backward
-    # if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
-    #     return OpRole.Forward
-    # return get_valid_op_role(block, insert_idx + 1)
-    if insert_idx >= len(block.ops): return OpRole.Optimize
-    if op_role == int(OpRole.Backward): return OpRole.Backward
-    if op_role == int(OpRole.Optimize): return OpRole.Optimize
+    if (insert_idx >= len(block.ops)) or (
+            op_role in [int(OpRole.Backward), int(OpRole.Optimize)]):
+        return OpRole.Backward
     if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
         return OpRole.Forward
     return get_valid_op_role(block, insert_idx + 1)
+    # if insert_idx >= len(block.ops): return OpRole.Optimize
+    # if op_role == int(OpRole.Backward): return OpRole.Backward
+    # if op_role == int(OpRole.Optimize): return OpRole.Optimize
+    # if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
+    #     return OpRole.Forward
+    # return get_valid_op_role(block, insert_idx + 1)


 def insert_sync_calc_op(block, insert_idx, calc_dep_vars):
     """
...
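Note on the hunk above: the commit swaps which of the two variants of get_valid_op_role is active, so an out-of-range insert_idx or an op with a Backward/Optimize role is reported as Backward, rather than distinguishing Optimize. Below is a minimal, self-contained sketch of the restored control flow; the IntEnum values and the plain list standing in for block.ops are assumptions of this note, not Paddle's real OpRole constants or Block API.

    from enum import IntEnum

    class OpRole(IntEnum):
        # stand-in values; Paddle's real OpRole constants differ
        Forward = 0
        Backward = 1
        Optimize = 2
        Loss = 3

    def get_valid_op_role(op_roles, insert_idx):
        """Walk forward from insert_idx until a decisive role is found."""
        if insert_idx >= len(op_roles):
            # past the last op: treat the insertion point as backward section
            return OpRole.Backward
        op_role = op_roles[insert_idx]
        if op_role in (int(OpRole.Backward), int(OpRole.Optimize)):
            # Backward and Optimize both map to Backward again
            return OpRole.Backward
        if op_role in (int(OpRole.Forward), int(OpRole.Loss)):
            return OpRole.Forward
        return get_valid_op_role(op_roles, insert_idx + 1)

    roles = [int(OpRole.Forward), int(OpRole.Loss), int(OpRole.Backward)]
    assert get_valid_op_role(roles, 0) == OpRole.Forward
    assert get_valid_op_role(roles, 2) == OpRole.Backward
    assert get_valid_op_role(roles, 5) == OpRole.Backward

For readability the sketch checks the index bound before reading the role; the patched function reads op_role from block.ops[insert_idx] first, so callers are expected to pass a valid index.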
@@ -153,7 +153,6 @@ class ShardingOptimizer(MetaOptimizerBase):
         if self.use_pipeline:
             pp_optimizer._rename_gradient_var_name(main_block)
-            pp_optimizer._accumulate_gradients(main_block)
             with open("main_%d" % self.role_maker._worker_index(), 'w') as f:
                 f.writelines(str(main_program))
@@ -201,6 +200,8 @@ class ShardingOptimizer(MetaOptimizerBase):
                 #if self._shard.has_param(param_name): continue
                 if in_name not in main_block.vars:
                     main_block._remove_op(idx)
+        accumulated_grad_names = pp_optimizer._accumulate_gradients(
+            main_block)
         # accumulated_grad_names = sorted(accumulated_grad_names)
         if self.pp_allreduce_in_optimize:
             print("persistable FP32 grad: ")
...
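Note on the two ShardingOptimizer hunks above: pp_optimizer._accumulate_gradients is no longer called right after _rename_gradient_var_name; it now runs after the pass that prunes ops whose inputs were removed, and its return value (presumably the merged gradient names) is kept in accumulated_grad_names. The sketch below shows the resulting call order; the helper name and the pruning loop are illustrative reconstructions around the lines visible in the diff, not the real ShardingOptimizer code.

    def apply_pipeline_passes(pp_optimizer, main_block, pp_allreduce_in_optimize):
        # 1. rename gradient variables first (unchanged by this commit)
        pp_optimizer._rename_gradient_var_name(main_block)

        # 2. prune ops whose inputs no longer exist; iterating in reverse
        #    keeps the indices valid while ops are removed
        for idx in reversed(range(len(main_block.ops))):
            op = main_block.ops[idx]
            if any(name not in main_block.vars for name in op.input_arg_names):
                main_block._remove_op(idx)

        # 3. only now accumulate gradients, keeping the merged names so the
        #    later allreduce/optimize logic can refer to them
        accumulated_grad_names = pp_optimizer._accumulate_gradients(main_block)
        if pp_allreduce_in_optimize:
            print("persistable FP32 grad: ", accumulated_grad_names)
        return accumulated_grad_names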
@@ -4836,6 +4836,7 @@ class PipelineOptimizer(object):
             input_names = op.input_arg_names
             output_names = op.output_arg_names
             in_out_names = input_names + output_names
+            if op.type == 'cast': continue
             # append "MERGED" to the names of parameter gradients,
             # and mofify the op_role_var attribute (by rename_arg func).
             for name in in_out_names:
@@ -4857,13 +4858,16 @@ class PipelineOptimizer(object):
             if self._is_optimize_op(op) and op.type == 'cast':
                 in_name = op.input_arg_names[0]
                 out_name = op.output_arg_names[0]
-                if out_name.strip('@GRAD@MERGED') in self._param_device_map:
+                if out_name.strip('@GRAD') in self._param_device_map:
                     assert in_name.replace('.cast_fp16', '') == out_name
                     block._remove_op(index)
                     continue
             if self._is_backward_op(op) and not first_opt_op_idx:
                 first_opt_op_idx = index + 1
+                if block.ops[first_opt_op_idx].type == "c_sync_comm_stream":
+                    #block.ops[first_opt_op_idx]._set_attr(self._op_role_key, self._op_role.Backward)
+                    first_opt_op_idx += 1
             if self._is_backward_op(op) and (
                     self._op_role_var_key in op.attr_names):
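Note on the hunk above: after locating the first op past the backward section (first_opt_op_idx), the patch now also steps over a trailing c_sync_comm_stream op, so the gradient-accumulation ops inserted at that index land after the communication sync. A small stand-alone sketch of that index adjustment, using a list of dicts as a stand-in for block.ops (the 'is_backward' key replaces self._is_backward_op and is an assumption of this note):

    def first_optimize_idx(ops):
        """ops: list of {'type': ..., 'is_backward': ...} stand-in records."""
        first_opt_op_idx = None
        # the real pass walks block.ops in reverse, so the first backward op
        # seen from the end marks the start of the optimizer section
        for index in reversed(range(len(ops))):
            if ops[index]["is_backward"] and first_opt_op_idx is None:
                first_opt_op_idx = index + 1
                # new in this commit: skip a trailing comm-stream sync
                # (the bounds check is added for this sketch)
                if (first_opt_op_idx < len(ops) and
                        ops[first_opt_op_idx]["type"] == "c_sync_comm_stream"):
                    first_opt_op_idx += 1
        return first_opt_op_idx

    ops = [{"type": "mul_grad", "is_backward": True},
           {"type": "c_sync_comm_stream", "is_backward": False},
           {"type": "sgd", "is_backward": False}]
    assert first_optimize_idx(ops) == 2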
@@ -4872,17 +4876,13 @@ class PipelineOptimizer(object):
                 if len(op_role_var) == 0:
                     continue
                 assert len(op_role_var) % 2 == 0
-                op._remove_attr(self._op_role_var_key)
+                #op._remove_attr(self._op_role_var_key)
                 for i in range(0, len(op_role_var), 2):
                     offset = 0
                     param_name = op_role_var[i]
-                    assert block.has_var(param_name), (
-                        "parameter {} not in "
-                        "current block.".format(param_name))
-                    # clear gradient
-                    assert param_name in self.origin_main_block.vars, "[{}] not in original main block".format(
-                        param_name)
-                    param_grad_name = self._append_grad_suffix(param_name)
+                    if not block.has_var(param_name): continue
+                    if '@BroadCast' in param_name: continue
+                    param_grad_name = param_name + core.grad_var_suffix()
                     merged_param_grad_name = param_grad_name + '@MERGED'
                     if not block.has_var(merged_param_grad_name):
                         self._create_var(block, block.vars[param_name],
@@ -4944,7 +4944,7 @@ class PipelineOptimizer(object):
                         attrs={
                             # self._op_device_key: device,
                             self._op_role_key: self._op_role.Backward,
-                            self._op_role_var_key: op_role_var
+                            #self._op_role_var_key: op_role_var
                         })
                     offset += 1
                     merged_gradient_names.append(merged_param_grad_name)
...
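Note on the remaining PipelineOptimizer hunks: the accumulation pass now skips cast ops when appending '@MERGED', derives the gradient name as param_name + core.grad_var_suffix() (i.e. param@GRAD) instead of a private helper, silently skips parameters that are missing from the block or that carry '@BroadCast' in their name, and no longer re-attaches op_role_var to the inserted accumulation op. The snippet below only illustrates the naming convention and the str.strip subtlety in the cast-pruning check; the parameter name fc_0.w_0 is an example, not taken from the diff.

    GRAD_SUFFIX = '@GRAD'      # what core.grad_var_suffix() returns in Paddle
    MERGED_SUFFIX = '@MERGED'  # appended for the accumulation buffer

    def merged_grad_name(param_name):
        # e.g. 'fc_0.w_0' -> 'fc_0.w_0@GRAD@MERGED'
        return param_name + GRAD_SUFFIX + MERGED_SUFFIX

    # str.strip removes *characters* from both ends, not a literal suffix:
    # 'fc_0.w_0@GRAD'.strip('@GRAD') yields 'fc_0.w_0' only because the
    # parameter name itself does not end with '@', 'G', 'R', 'A' or 'D'.
    assert 'fc_0.w_0@GRAD'.strip('@GRAD') == 'fc_0.w_0'
    assert merged_grad_name('fc_0.w_0') == 'fc_0.w_0@GRAD@MERGED'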