Commit d3105dbf authored by sandyhouse

update

Parent 7aa0cc3c
@@ -16,8 +16,8 @@ from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
 class GradientClipHelper(object):
-    def __init__(self, sharding_ring_id):
-        self.sharding_ring_id = sharding_ring_id
+    def __init__(self, mp_ring_id):
+        self.mp_ring_id = mp_ring_id
 
     def _is_gradient_clip_op(self, op):
         return op.desc.has_attr("op_namescope") \
@@ -31,6 +31,7 @@ class GradientClipHelper(object):
         """
         deperated_vars = set()
         deperate_op_idx = set()
+        reversed_x_paramname = []
         for idx, op in enumerate(block.ops):
             if not self._is_gradient_clip_op(op):
                 continue
@@ -44,6 +45,8 @@ class GradientClipHelper(object):
                 if shard.is_param(param_name) and \
                         not shard.has_param(param_name):
                     deperate_op = True
+                elif shard.is_param(param_name):
+                    reversed_x_paramname.append(param_name)
             if deperate_op:
                 deperate_op_idx.add(idx)
@@ -65,31 +68,47 @@ class GradientClipHelper(object):
                 for input_name in op.desc.input_arg_names():
                     if input_name not in deperated_vars:
                         reversed_inputs.append(input_name)
                 op.desc.set_input("X", reversed_inputs)
                 assert (len(op.desc.output_arg_names()) == 1)
                 sum_res = op.desc.output_arg_names()[0]
-                block._insert_op_without_sync(
-                    idx + 1,
-                    type='c_sync_comm_stream',
-                    inputs={'X': sum_res},
-                    outputs={'Out': sum_res},
-                    attrs={'ring_id': 0,
-                           OP_ROLE_KEY: OpRole.Optimize})
+
+                # this allreduce should not overlap with calc and should be scheduled in calc stream
+                # block._insert_op_without_sync(
+                #     idx + 1,
+                #     type='c_sync_comm_stream',
+                #     inputs={'X': sum_res},
+                #     outputs={'Out': sum_res},
+                #     attrs={'ring_id': 0,
+                #            OP_ROLE_KEY: OpRole.Optimize})
                 block._insert_op_without_sync(
                     idx + 1,
                     type='c_allreduce_sum',
                     inputs={'X': sum_res},
                     outputs={'Out': sum_res},
                     attrs={
-                        'ring_id': self.sharding_ring_id,
-                        OP_ROLE_KEY: OpRole.Optimize
+                        'ring_id': self.mp_ring_id,
+                        'op_namescope': "/gradient_clip_model_parallelism",
+                        'use_calc_stream': True,
+                        OP_ROLE_KEY: OpRole.Optimize,
                     })
-                block._insert_op_without_sync(
-                    idx + 1,
-                    type='c_sync_calc_stream',
-                    inputs={'X': sum_res},
-                    outputs={'Out': sum_res},
-                    attrs={OP_ROLE_KEY: OpRole.Optimize})
+                # block._insert_op_without_sync(
+                #     idx + 1,
+                #     type='c_sync_calc_stream',
+                #     inputs={'X': sum_res},
+                #     outputs={'Out': sum_res},
+                #     attrs={OP_ROLE_KEY: OpRole.Optimize})
+
+        # the grad sum here should take the all and only param in the current shard
+        to_check_param = set(reversed_x_paramname)
+        should_check_param = set(shard.global_params).intersection(
+            set([
+                param for param, worker_idx in shard.global_param2device.items()
+                if worker_idx == shard.worker_idx
+            ]))
+        assert to_check_param == should_check_param, "amp check_finite_and_unscale checking miss [{}] and got unexpected [{}]".format(
+            should_check_param - to_check_param,
+            to_check_param - should_check_param)
 
         for var_name in deperated_vars:
             block._remove_var(var_name, sync=False)
......
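The assertion added at the end of the pruning pass above checks that the gradient-clip sum covers all and only the parameters placed on the current sharding worker. Below is a minimal, self-contained sketch of that coverage check, using a plain dict in place of the real Shard object (param2device, worker_idx, and the parameter names are made-up stand-ins, not Paddle APIs).

def check_clip_covers_local_shard(seen_params, param2device, worker_idx):
    # Assert the gradient-clip sum saw all and only the params placed on this worker.
    to_check = set(seen_params)
    should_check = {p for p, dev in param2device.items() if dev == worker_idx}
    missing = should_check - to_check
    unexpected = to_check - should_check
    assert to_check == should_check, \
        "gradient clip missed {} and got unexpected {}".format(missing, unexpected)


if __name__ == "__main__":
    # hypothetical placement of three parameters across two sharding workers
    param2device = {"fc_0.w_0": 0, "fc_0.b_0": 0, "fc_1.w_0": 1}
    # worker 0 should see exactly its own two parameters
    check_clip_covers_local_shard(["fc_0.w_0", "fc_0.b_0"], param2device, worker_idx=0)
    print("coverage check passed")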
@@ -227,14 +227,9 @@ def get_valid_op_role(block, insert_idx):
         return OpRole.Forward or OpRole.Backward
     """
     op_role = block.ops[insert_idx].attr('op_role')
-    #if (insert_idx >= len(block.ops)) or (
-    #        op_role in [int(OpRole.Backward), int(OpRole.Optimize)]):
-    #    return OpRole.Backward
-    #if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
-    #    return OpRole.Forward
-    if insert_idx >= len(block.ops): return OpRole.Optimize
-    if op_role == int(OpRole.Backward): return OpRole.Backward
-    if op_role == int(OpRole.Optimize): return OpRole.Optimize
+    if (insert_idx >= len(block.ops)) or (
+            op_role in [int(OpRole.Backward), int(OpRole.Optimize)]):
+        return OpRole.Backward
     if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
         return OpRole.Forward
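The restored branch ordering above maps an out-of-range index or a Backward/Optimize op to OpRole.Backward, and Forward/Loss ops to OpRole.Forward. A standalone sketch of that decision logic follows; the OpRole values, the list-of-ints interface, and the final fallback are illustrative stand-ins rather than Paddle's actual enum and Block API, and the bounds check here is done before reading the role.

from enum import IntEnum


class OpRole(IntEnum):          # illustrative values, not Paddle's real enum
    Forward = 0
    Backward = 1
    Optimize = 2
    Loss = 256


def valid_op_role(op_roles, insert_idx):
    # op_roles: plain list of int role attrs, one per op in the block
    # (a stand-in for block.ops[i].attr('op_role'))
    if insert_idx >= len(op_roles):
        return OpRole.Backward
    op_role = op_roles[insert_idx]
    if op_role in (int(OpRole.Backward), int(OpRole.Optimize)):
        return OpRole.Backward
    if op_role in (int(OpRole.Forward), int(OpRole.Loss)):
        return OpRole.Forward
    return OpRole.Forward  # hypothetical fallback for roles not covered above


if __name__ == "__main__":
    print(valid_op_role([int(OpRole.Forward), int(OpRole.Backward)], 1))  # OpRole.Backward
    print(valid_op_role([int(OpRole.Forward)], 5))                        # OpRole.Backward (past end)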
@@ -485,9 +480,6 @@ def save_persistables(exe, dirname, main_program, filename=None):
     This function handles the model saving for sharding training.
     """
-    if main_program._pipeline_opt:
-        main_program = main_program._pipeline_opt['section_program']['program']
 
     def is_opt_vars(var):
         # NOTE(liangjianzhong): The checks should be updated when add new compatible optimizer
         # now only Momentum and adam are compatible with sharding
......
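The save_persistables hunk above drops the pipeline-specific program unwrapping, while the surrounding is_opt_vars helper keeps filtering optimizer-state variables by name suffix. The sketch below shows that kind of suffix filtering in isolation; the suffix list and names are hypothetical examples, not the exact set checked by the real helper.

OPT_VAR_SUFFIXES = ("_moment1_0", "_moment2_0", "_beta1_pow_acc_0",
                    "_beta2_pow_acc_0", "_velocity_0")   # hypothetical suffix list


def is_opt_var_name(name):
    # True if the variable name looks like per-parameter optimizer state
    return any(name.endswith(suffix) for suffix in OPT_VAR_SUFFIXES)


if __name__ == "__main__":
    names = ["fc_0.w_0", "fc_0.w_0_moment1_0", "fc_0.w_0_velocity_0", "learning_rate_0"]
    print([n for n in names if is_opt_var_name(n)])
    # ['fc_0.w_0_moment1_0', 'fc_0.w_0_velocity_0']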