Commit 1e4345d4 authored by sandyhouse

update, test=develop

Parent 49ac67bc
@@ -227,12 +227,18 @@ def get_valid_op_role(block, insert_idx):
return OpRole.Forward or OpRole.Backward
"""
op_role = block.ops[insert_idx].attr('op_role')
if (insert_idx >= len(block.ops)) or (
op_role in [int(OpRole.Backward), int(OpRole.Optimize)]):
return OpRole.Backward
# if (insert_idx >= len(block.ops)) or (
# op_role in [int(OpRole.Backward), int(OpRole.Optimize)]):
# return OpRole.Backward
# if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
# return OpRole.Forward
# return get_valid_op_role(block, insert_idx + 1)
if insert_idx >= len(block.ops): return OpRole.Optimize
if op_role == int(OpRole.Backward): return OpRole.Backward
if op_role == int(OpRole.Optimize): return OpRole.Optimize
if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
return OpRole.Forward
return get_valid_op_role(block, insert_idx + 1)
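For reference, a minimal standalone sketch of the revised role resolution above, using stub classes in place of Paddle's real Block/Operator and illustrative OpRole values (assumptions, not Paddle's actual constants). The bounds check is performed before the attr lookup so an out-of-range index cannot raise:

from enum import IntEnum

class OpRole(IntEnum):
    Forward = 0   # illustrative values, not Paddle's real constants
    Backward = 1
    Optimize = 2
    Loss = 3

class StubOp:
    def __init__(self, role):
        self._role = int(role)
    def attr(self, name):
        return self._role if name == 'op_role' else None

class StubBlock:
    def __init__(self, roles):
        self.ops = [StubOp(r) for r in roles]

def get_valid_op_role(block, insert_idx):
    # past the last op: treat the insertion point as the optimizer stage
    if insert_idx >= len(block.ops):
        return OpRole.Optimize
    op_role = block.ops[insert_idx].attr('op_role')
    if op_role == int(OpRole.Backward):
        return OpRole.Backward
    if op_role == int(OpRole.Optimize):
        return OpRole.Optimize
    if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
        return OpRole.Forward
    # any other role value: fall through and inspect the next op
    return get_valid_op_role(block, insert_idx + 1)

block = StubBlock([OpRole.Forward, OpRole.Loss, OpRole.Backward])
assert get_valid_op_role(block, 1) == OpRole.Forward
assert get_valid_op_role(block, 2) == OpRole.Backward
assert get_valid_op_role(block, 10) == OpRole.Optimize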
@@ -480,6 +486,9 @@ def save_persistables(exe, dirname, main_program, filename=None):
This function handles the model saving for sharding training.
"""
if main_program._pipeline_opt:
main_program = main_program._pipeline_opt['section_program']['program']
def is_opt_vars(var):
# NOTE(liangjianzhong): The checks should be updated when add new compatible optimizer
# now only Momentum and adam are compatible with sharding
......
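For reference, a minimal sketch of the new unwrapping step above: when pipeline parallelism is enabled, the program actually executed is the per-stage section program stored in _pipeline_opt, so persistables have to be collected from that program. StubProgram and resolve_program_for_saving are illustrative stand-ins, not Paddle API:

class StubProgram:
    def __init__(self, pipeline_opt=None):
        self._pipeline_opt = pipeline_opt

def resolve_program_for_saving(main_program):
    # mirror the check added in the diff: unwrap the section program if present
    if main_program._pipeline_opt:
        return main_program._pipeline_opt['section_program']['program']
    return main_program

stage_program = StubProgram()
wrapped = StubProgram({'section_program': {'program': stage_program}})
assert resolve_program_for_saving(wrapped) is stage_program
plain = StubProgram()
assert resolve_program_for_saving(plain) is plain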
@@ -182,10 +182,17 @@ class ShardingOptimizer(MetaOptimizerBase):
if not self._shard.has_param(param_name):
main_block._remove_op(idx)
param_list = []
for param_name, grad_name in params_grads:
if self._shard.has_param(param_name):
param_list.append(param_name)
for idx, op in reversed(list(enumerate(main_block.ops))):
if op.type != 'cast': continue
in_name = op.input_arg_names[0]
if in_name not in self._params: continue
#if self._shard.has_param(param_name): continue
if in_name not in main_block.vars:
main_block._remove_op(idx)
#param_list = []
#for param_name, grad_name in params_grads:
# if self._shard.has_param(param_name):
# param_list.append(param_name)
#pp_optimizer._clear_gradients(main_block, param_list)
pp_optimizer._accumulate_gradients(main_block)
#if not self._shard.has_param(param_name): continue
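For reference, a minimal sketch of the dead-cast cleanup pattern added above, with stub op/block objects; only the attributes used in the diff (op.type, op.input_arg_names, block.vars, block._remove_op) are modeled, everything else is assumed:

class StubOp:
    def __init__(self, op_type, input_arg_names):
        self.type = op_type
        self.input_arg_names = input_arg_names

class StubBlock:
    def __init__(self, ops, vars):
        self.ops = ops
        self.vars = vars              # name -> var; params pruned by sharding are absent
    def _remove_op(self, idx):
        del self.ops[idx]

def remove_dead_casts(block, param_names):
    # scan in reverse so a removal does not shift indices still to be visited
    for idx, op in reversed(list(enumerate(block.ops))):
        if op.type != 'cast':
            continue
        in_name = op.input_arg_names[0]
        if in_name not in param_names:
            continue
        if in_name not in block.vars:  # its input param was pruned: the cast is dead
            block._remove_op(idx)

block = StubBlock(
    ops=[StubOp('cast', ['w0']), StubOp('mul', ['x']), StubOp('cast', ['w1'])],
    vars={'w1': object()})             # w0 was pruned by this shard, w1 was kept
remove_dead_casts(block, param_names={'w0', 'w1'})
assert [op.type for op in block.ops] == ['mul', 'cast']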
@@ -359,13 +366,19 @@ class ShardingOptimizer(MetaOptimizerBase):
# config sharding & dp groups
self._init_comm()
# inner & outer model parallelism
# global
print("global_group_endpoints:", self.global_group_endpoints)
print("global_rank:", self.global_rank)
print("global_ring_id:", self.global_group_id)
if self._as_outer_parallelism:
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.global_group_endpoints, self.global_rank,
self.global_group_id, True)
print("mp_group_endpoints:", self.mp_group_endpoints)
print("mp_rank:", self.mp_rank)
print("mp_ring_id:", self.mp_group_id)
if self._as_outer_parallelism:
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
@@ -387,6 +400,9 @@ class ShardingOptimizer(MetaOptimizerBase):
self.dp_group_endpoints, self.dp_rank, self.dp_ring_id, True)
# pp
if self.use_pipeline:
print("pp_group_endpoints:", self.pp_group_endpoints)
print("pp_rank:", self.pp_rank)
print("pp_ring_id:", self.pp_ring_id)
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.pp_group_endpoints, self.pp_rank, self.pp_ring_id, False)
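The init hunks above repeat one pattern per communication ring (global, mp, dp, pp): print the endpoints, rank, and ring id for debugging, then call _init_communicator with the same positional arguments. A hedged sketch of that pattern factored into a single helper; the helper name and the name parameter are illustrative, not part of Paddle's CollectiveHelper API:

def init_ring_with_debug(collective_helper, startup_program, current_endpoint,
                         endpoints, rank, ring_id, wait_port, name):
    # debug prints matching the ones added in this commit
    print("%s_group_endpoints:" % name, endpoints)
    print("%s_rank:" % name, rank)
    print("%s_ring_id:" % name, ring_id)
    # same positional arguments as the _init_communicator calls above
    collective_helper._init_communicator(startup_program, current_endpoint,
                                         endpoints, rank, ring_id, wait_port)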
@@ -660,6 +676,7 @@ class ShardingOptimizer(MetaOptimizerBase):
fill_constant_vars)
# step4: add `cast` ops
print("cast_ops:", cast_ops)
insert_cast_ops(block, segment._end_idx, cast_ops)
# step5: add broadcast ops
......