Commit 75644caf authored by S sandyhouse

update

Parent 5cd2bfec
......
@@ -105,7 +105,7 @@ class FP16Utils(object):
reversed_x = []
reversed_x_paramname = []
for input_name in op.desc.input('X'):
- param_name = input_name.strip("@GRAD")
+ param_name = input_name.strip("@GRAD@MERGED")
if param_name not in shard.global_params:
raise ValueError(
"Input 'X' of check_finite_and_unscale must"
......
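(Note on the renamed suffix above: Python's str.strip removes any leading or trailing characters drawn from the given set, not a literal suffix, so "@GRAD@MERGED" here acts as a character set. A minimal sketch of explicit suffix removal, using a hypothetical helper name that is not part of this commit:)

```python
def remove_grad_suffix(input_name):
    # Strip the gradient suffixes used by gradient merging.
    # Note: input_name.strip("@GRAD@MERGED") would remove characters from
    # the set {'@', 'G', 'R', 'A', 'D', 'M', 'E'}, not the literal suffix,
    # so an explicit endswith check is used in this sketch instead.
    for suffix in ("@GRAD@MERGED", "@GRAD"):
        if input_name.endswith(suffix):
            return input_name[:-len(suffix)]
    return input_name

assert remove_grad_suffix("fc_0.w_0@GRAD@MERGED") == "fc_0.w_0"
assert remove_grad_suffix("fc_0.w_0@GRAD") == "fc_0.w_0"
```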
......
@@ -357,7 +357,7 @@ def get_grad_device(grad_name, shard):
base_name = None
# mind the traversal order
possible_suffixes = [
- '.cast_fp16@GRAD_0', '.cast_fp16@GRAD', '@GRAD_0', '@GRAD'
+ '.cast_fp16@GRAD@MERGED', '.cast_fp16@GRAD', '@GRAD@MERGED', '@GRAD'
]
for suffix in possible_suffixes:
if suffix in grad_name:
......
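(The traversal order of possible_suffixes matters because '@GRAD' is a substring of the longer cast and merged variants, so the more specific suffixes must be tried first. A standalone sketch of just the suffix matching, with a hypothetical helper name and without the real shard-to-device lookup:)

```python
def get_grad_base_name(grad_name):
    # Longer, more specific suffixes are checked first, because '@GRAD'
    # is a substring of '.cast_fp16@GRAD@MERGED' and '@GRAD@MERGED'.
    possible_suffixes = [
        '.cast_fp16@GRAD@MERGED', '.cast_fp16@GRAD', '@GRAD@MERGED', '@GRAD'
    ]
    for suffix in possible_suffixes:
        if suffix in grad_name:
            return grad_name.split(suffix)[0]
    return grad_name

assert get_grad_base_name("fc_0.w_0.cast_fp16@GRAD@MERGED") == "fc_0.w_0"
assert get_grad_base_name("fc_0.w_0@GRAD") == "fc_0.w_0"
```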
......
@@ -103,8 +103,6 @@ class ShardingOptimizer(MetaOptimizerBase):
self.pp_bz = self.user_defined_strategy.sharding_configs["pp_bz"]
self.pp_allreduce_in_optimize = self.user_defined_strategy.sharding_configs[
"pp_allreduce_in_optimize"]
- self.optimize_offload = self.user_defined_strategy.sharding_configs[
- "optimize_offload"]
if self.inner_opt is None:
raise ValueError(
......
@@ -947,8 +945,9 @@ class ShardingOptimizer(MetaOptimizerBase):
]
self.pp_group_size = self.pipeline_nodes
self.pp_group_endpoints = [
- ep for idx, ep in enumerate(self.endpoints) if
- (idx % self.sharding_group_size) == self.sharding_rank
+ ep for idx, ep in enumerate(self.endpoints)
+ if (idx % self.sharding_group_size
+ ) == self.sharding_rank
]
else:
self.mp_group_id = 0
......
@@ -972,12 +971,11 @@ class ShardingOptimizer(MetaOptimizerBase):
self._inner_parallelism_size * self.sharding_group_size)
self.megatron_rank = self.global_rank % self._inner_parallelism_size
self.sharding_group_endpoints = [
- ep for idx, ep in enumerate(self.endpoints) if
- (idx //
- (self._inner_parallelism_size *
- self.sharding_group_size)) == self.sharding_group_id
- and
- idx % self._inner_parallelism_size == self.megatron_rank
+ ep for idx, ep in enumerate(self.endpoints)
+ if (idx // (self._inner_parallelism_size *
+ self.sharding_group_size)
+ ) == self.sharding_group_id and idx %
+ self._inner_parallelism_size == self.megatron_rank
]
print("sharding_endpoint:", self.sharding_group_endpoints)
print("sharding_rank:", self.sharding_rank)
......
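(The two list comprehensions above select group endpoints purely by index arithmetic. A self-contained sketch of the same selection logic with made-up sizes, not the optimizer's actual configuration:)

```python
# Hypothetical layout: 8 trainers, inner (megatron) parallelism size 2,
# sharding group size 2, so each block of 4 consecutive endpoints holds
# two sharding groups interleaved by inner-parallel rank.
endpoints = ["worker%d:6170" % i for i in range(8)]
inner_parallelism_size = 2
sharding_group_size = 2

global_rank = 5
megatron_rank = global_rank % inner_parallelism_size
sharding_group_id = global_rank // (inner_parallelism_size * sharding_group_size)

# Endpoints in this rank's sharding group: same outer block
# (idx // (inner * sharding size)) and same inner-parallel rank within it.
sharding_group_endpoints = [
    ep for idx, ep in enumerate(endpoints)
    if idx // (inner_parallelism_size * sharding_group_size) == sharding_group_id
    and idx % inner_parallelism_size == megatron_rank
]
print(sharding_group_endpoints)  # ['worker5:6170', 'worker7:6170']
```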
......
@@ -4898,6 +4898,7 @@ class PipelineOptimizer(object):
self._op_role_key: self._op_role.Backward,
})
offset += 1
+ merged_gradient_names.append(merged_param_grad_name)
else:
# cast gradient to fp32 to accumulate to merged gradient
cast_grad_var_name = param_grad_name + '@TMP'
......
@@ -4928,6 +4929,8 @@ class PipelineOptimizer(object):
self._op_role_var_key: op_role_var
})
offset += 1
+ merged_gradient_names.append(merged_param_grad_name)
+ return merged_gradient_names
def _add_sub_blocks(self, main_block, program_list):
main_program = main_block.program
......
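(merged_gradient_names collects one accumulation-buffer name per parameter; in the fp16 branch each micro-batch gradient is cast to fp32 before being summed into that buffer. A numpy sketch of the accumulation pattern only, not the actual Paddle ops:)

```python
import numpy as np

# Persistent fp32 accumulation buffers, one per parameter, analogous to the
# <param>@GRAD@MERGED variables whose names the routine above returns.
merged_grads = {}

def accumulate(param_name, micro_batch_grad):
    # Cast an fp16 micro-batch gradient to fp32 and add it into the merged
    # buffer, creating the buffer on first use; return the buffer's name.
    grad_fp32 = micro_batch_grad.astype(np.float32)
    merged_name = param_name + "@GRAD@MERGED"
    if merged_name not in merged_grads:
        merged_grads[merged_name] = np.zeros_like(grad_fp32)
    merged_grads[merged_name] += grad_fp32
    return merged_name

for _ in range(4):  # four micro-batches
    accumulate("fc_0.w_0", np.ones((2, 2), dtype=np.float16))

# After four micro-batches the merged gradient holds the fp32 sum.
assert np.allclose(merged_grads["fc_0.w_0@GRAD@MERGED"], 4.0)
```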