Commit d3105dbf authored by: S sandyhouse

update

Parent 7aa0cc3c
......@@ -16,8 +16,8 @@ from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
class GradientClipHelper(object):
def __init__(self, sharding_ring_id):
self.sharding_ring_id = sharding_ring_id
def __init__(self, mp_ring_id):
self.mp_ring_id = mp_ring_id
def _is_gradient_clip_op(self, op):
return op.desc.has_attr("op_namescope") \
......@@ -31,6 +31,7 @@ class GradientClipHelper(object):
"""
deperated_vars = set()
deperate_op_idx = set()
reversed_x_paramname = []
for idx, op in enumerate(block.ops):
if not self._is_gradient_clip_op(op):
continue
......@@ -44,6 +45,8 @@ class GradientClipHelper(object):
if shard.is_param(param_name) and \
not shard.has_param(param_name):
deperate_op = True
elif shard.is_param(param_name):
reversed_x_paramname.append(param_name)
if deperate_op:
deperate_op_idx.add(idx)
......@@ -65,31 +68,47 @@ class GradientClipHelper(object):
for input_name in op.desc.input_arg_names():
if input_name not in deperated_vars:
reversed_inputs.append(input_name)
op.desc.set_input("X", reversed_inputs)
assert (len(op.desc.output_arg_names()) == 1)
sum_res = op.desc.output_arg_names()[0]
block._insert_op_without_sync(
idx + 1,
type='c_sync_comm_stream',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={'ring_id': 0,
OP_ROLE_KEY: OpRole.Optimize})
# this allreduce should not overlap with computation and should be scheduled on the calc stream
# block._insert_op_without_sync(
# idx + 1,
# type='c_sync_comm_stream',
# inputs={'X': sum_res},
# outputs={'Out': sum_res},
# attrs={'ring_id': 0,
# OP_ROLE_KEY: OpRole.Optimize})
block._insert_op_without_sync(
idx + 1,
type='c_allreduce_sum',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'ring_id': self.sharding_ring_id,
OP_ROLE_KEY: OpRole.Optimize
'ring_id': self.mp_ring_id,
'op_namescope': "/gradient_clip_model_parallelism",
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize,
})
block._insert_op_without_sync(
idx + 1,
type='c_sync_calc_stream',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={OP_ROLE_KEY: OpRole.Optimize})
# block._insert_op_without_sync(
# idx + 1,
# type='c_sync_calc_stream',
# inputs={'X': sum_res},
# outputs={'Out': sum_res},
# attrs={OP_ROLE_KEY: OpRole.Optimize})
# the grad sum here should cover all and only the params in the current shard
to_check_param = set(reversed_x_paramname)
should_check_param = set(shard.global_params).intersection(
set([
param for param, worker_idx in shard.global_param2device.items()
if worker_idx == shard.worker_idx
]))
assert to_check_param == should_check_param, "gradient clip sum should cover all and only the params in the current shard, but missed [{}] and got unexpected [{}]".format(
should_check_param - to_check_param,
to_check_param - should_check_param)
for var_name in deperated_vars:
block._remove_var(var_name, sync=False)
......
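The assertion added above verifies that the local grad-norm sum covers exactly the parameters placed on the current shard; the full clipping norm is then recovered by all-reducing the per-shard partial sums over the model-parallel ring. A minimal numpy sketch of that reasoning (ranks, values, and names below are illustrative, not part of the commit):

import numpy as np

# Per-rank gradients for the params owned by that shard (illustrative values).
shard_grads = {
    0: [np.array([1.0, 2.0]), np.array([3.0])],   # params placed on rank 0
    1: [np.array([4.0, 5.0])],                    # params placed on rank 1
}

# Each rank sums the squared norms of only its own params ...
partial_sums = {rank: sum(float((g * g).sum()) for g in grads)
                for rank, grads in shard_grads.items()}

# ... and c_allreduce_sum adds the partials across the ring.
global_sq_sum = sum(partial_sums.values())

# Equivalent to computing the norm over all params in one place.
all_grads = np.concatenate([g for grads in shard_grads.values() for g in grads])
assert np.isclose(global_sq_sum, float((all_grads ** 2).sum()))
global_norm = np.sqrt(global_sq_sum)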
......@@ -227,14 +227,9 @@ def get_valid_op_role(block, insert_idx):
return OpRole.Forward or OpRole.Backward
"""
op_role = block.ops[insert_idx].attr('op_role')
#if (insert_idx >= len(block.ops)) or (
# op_role in [int(OpRole.Backward), int(OpRole.Optimize)]):
# return OpRole.Backward
#if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
# return OpRole.Forward
if insert_idx >= len(block.ops): return OpRole.Optimize
if op_role == int(OpRole.Backward): return OpRole.Backward
if op_role == int(OpRole.Optimize): return OpRole.Optimize
if (insert_idx >= len(block.ops)) or (
op_role in [int(OpRole.Backward), int(OpRole.Optimize)]):
return OpRole.Backward
if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
return OpRole.Forward
......@@ -485,9 +480,6 @@ def save_persistables(exe, dirname, main_program, filename=None):
This function handles the model saving for sharding training.
"""
if main_program._pipeline_opt:
main_program = main_program._pipeline_opt['section_program']['program']
def is_opt_vars(var):
# NOTE(liangjianzhong): The checks should be updated when adding a new compatible optimizer;
# currently only Momentum and Adam are compatible with sharding
......
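For context on the restored get_valid_op_role logic above: an insert position past the end of the block, or an op whose role is Backward or Optimize, maps to Backward, while Forward and Loss ops map to Forward. A standalone sketch of that decision, using illustrative stand-ins for paddle's OpRole constants (not the committed code):

from enum import IntEnum

class OpRole(IntEnum):
    # illustrative stand-ins; paddle defines its own OpRole constants
    Forward = 0
    Backward = 1
    Optimize = 2
    Loss = 256

def valid_op_role(op_roles, insert_idx):
    # mirrors the restored branch: past-the-end positions and
    # backward/optimize ops map to Backward ...
    if insert_idx >= len(op_roles) or op_roles[insert_idx] in (OpRole.Backward, OpRole.Optimize):
        return OpRole.Backward
    # ... while forward/loss ops map to Forward
    if op_roles[insert_idx] in (OpRole.Forward, OpRole.Loss):
        return OpRole.Forward

print(valid_op_role([OpRole.Forward, OpRole.Backward], 5))  # OpRole.Backward
print(valid_op_role([OpRole.Forward, OpRole.Backward], 0))  # OpRole.Forward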