Commit 75644caf authored by S sandyhouse

update

Parent 5cd2bfec
...
@@ -105,7 +105,7 @@ class FP16Utils(object):
                 reversed_x = []
                 reversed_x_paramname = []
                 for input_name in op.desc.input('X'):
-                    param_name = input_name.strip("@GRAD")
+                    param_name = input_name.strip("@GRAD@MERGED")
                     if param_name not in shard.global_params:
                         raise ValueError(
                             "Input 'X' of check_finite_and_unscale must"
...
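Note on the changed line above: Python's `str.strip` removes any of the listed characters from both string ends rather than a literal suffix, so passing `"@GRAD@MERGED"` lets both the plain and the merged gradient names reduce to the parameter name. A minimal standalone illustration, not part of the commit (the variable names here are made up):

```python
# strip() drops any of the characters {'@','G','R','A','D','M','E'} from
# both ends of the string, so both suffixed forms reduce to the parameter name.
for input_name in ["fc_0.w_0@GRAD", "fc_0.w_0@GRAD@MERGED"]:
    print(input_name.strip("@GRAD@MERGED"))  # prints "fc_0.w_0" twice
```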
...
@@ -357,7 +357,7 @@ def get_grad_device(grad_name, shard):
     base_name = None
     # mind the traversal order
     possible_suffixes = [
-        '.cast_fp16@GRAD_0', '.cast_fp16@GRAD', '@GRAD_0', '@GRAD'
+        '.cast_fp16@GRAD@MERGED', '.cast_fp16@GRAD', '@GRAD@MERGED', '@GRAD'
     ]
     for suffix in possible_suffixes:
         if suffix in grad_name:
...
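The "mind the traversal order" comment is why the longer suffixes are listed before their shorter prefixes: the first suffix found in `grad_name` wins. A rough sketch of that idea with a hypothetical helper, not the Paddle implementation:

```python
def base_param_name(grad_name, possible_suffixes):
    """Strip the first matching gradient suffix from grad_name.

    Hypothetical helper mirroring the list above: the more specific
    '.cast_fp16...' suffixes must be tried before the bare '@GRAD...'
    ones, otherwise '.cast_fp16' would be left in the base name.
    """
    for suffix in possible_suffixes:
        if suffix in grad_name:
            return grad_name.split(suffix)[0]
    return grad_name


suffixes = ['.cast_fp16@GRAD@MERGED', '.cast_fp16@GRAD', '@GRAD@MERGED', '@GRAD']
print(base_param_name('fc_0.w_0.cast_fp16@GRAD@MERGED', suffixes))  # fc_0.w_0
```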
...
@@ -103,8 +103,6 @@ class ShardingOptimizer(MetaOptimizerBase):
         self.pp_bz = self.user_defined_strategy.sharding_configs["pp_bz"]
         self.pp_allreduce_in_optimize = self.user_defined_strategy.sharding_configs[
             "pp_allreduce_in_optimize"]
-        self.optimize_offload = self.user_defined_strategy.sharding_configs[
-            "optimize_offload"]
         if self.inner_opt is None:
             raise ValueError(
...
@@ -947,8 +945,9 @@ class ShardingOptimizer(MetaOptimizerBase):
             ]
             self.pp_group_size = self.pipeline_nodes
             self.pp_group_endpoints = [
-                ep for idx, ep in enumerate(self.endpoints) if
-                (idx % self.sharding_group_size) == self.sharding_rank
+                ep for idx, ep in enumerate(self.endpoints)
+                if (idx % self.sharding_group_size
+                    ) == self.sharding_rank
             ]
         else:
             self.mp_group_id = 0
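The reflow above does not change behavior: the pipeline group still collects every endpoint whose index has the same remainder modulo `sharding_group_size` as this rank's `sharding_rank`. A toy example with made-up endpoints and sizes:

```python
# Hypothetical 4-node setup: sharding_group_size = 2, sharding_rank = 1.
endpoints = ["127.0.0.1:6170", "127.0.0.1:6171", "127.0.0.1:6172", "127.0.0.1:6173"]
sharding_group_size, sharding_rank = 2, 1
pp_group_endpoints = [
    ep for idx, ep in enumerate(endpoints)
    if (idx % sharding_group_size) == sharding_rank
]
print(pp_group_endpoints)  # ['127.0.0.1:6171', '127.0.0.1:6173']
```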
...
@@ -972,12 +971,11 @@ class ShardingOptimizer(MetaOptimizerBase):
                 self._inner_parallelism_size * self.sharding_group_size)
             self.megatron_rank = self.global_rank % self._inner_parallelism_size
             self.sharding_group_endpoints = [
-                ep for idx, ep in enumerate(self.endpoints) if
-                (idx //
-                 (self._inner_parallelism_size *
-                  self.sharding_group_size)) == self.sharding_group_id
-                and
-                idx % self._inner_parallelism_size == self.megatron_rank
+                ep for idx, ep in enumerate(self.endpoints)
+                if (idx // (self._inner_parallelism_size *
+                            self.sharding_group_size)
+                    ) == self.sharding_group_id and idx %
+                self._inner_parallelism_size == self.megatron_rank
             ]
             print("sharding_endpoint:", self.sharding_group_endpoints)
             print("sharding_rank:", self.sharding_rank)
...
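This rewrite is likewise only a reformatting: it keeps endpoints that fall in the same sharding group (same `idx // (inner_parallelism_size * sharding_group_size)`) and that carry the same megatron rank (same `idx % inner_parallelism_size`). A toy example with assumed sizes, not taken from the commit:

```python
# Hypothetical topology: 2-way inner (megatron) parallelism x 2-way sharding x 2 groups.
endpoints = ["n%d" % i for i in range(8)]
inner_parallelism_size, sharding_group_size = 2, 2
sharding_group_id, megatron_rank = 1, 0
sharding_group_endpoints = [
    ep for idx, ep in enumerate(endpoints)
    if (idx // (inner_parallelism_size * sharding_group_size)) == sharding_group_id
    and idx % inner_parallelism_size == megatron_rank
]
print(sharding_group_endpoints)  # ['n4', 'n6']
```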
...
@@ -4898,6 +4898,7 @@ class PipelineOptimizer(object):
                         self._op_role_key: self._op_role.Backward,
                     })
                 offset += 1
+                merged_gradient_names.append(merged_param_grad_name)
             else:
                 # cast gradient to fp32 to accumulate to merged gradient
                 cast_grad_var_name = param_grad_name + '@TMP'
...
@@ -4928,6 +4929,8 @@ class PipelineOptimizer(object):
                         self._op_role_var_key: op_role_var
                     })
                 offset += 1
+                merged_gradient_names.append(merged_param_grad_name)
+        return merged_gradient_names

     def _add_sub_blocks(self, main_block, program_list):
         main_program = main_block.program
...
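Taken together, the two pipeline hunks make the gradient-accumulation pass record each merged gradient variable and return the collected names, presumably so the sharding pass can locate the `@GRAD@MERGED` variables it now expects. A simplified sketch of that pattern with a hypothetical helper, not the actual `PipelineOptimizer` code:

```python
def accumulate_gradients(param_grad_names):
    """Hypothetical stand-in for the merging loop: emit accumulation ops
    (elided here) and return the names of the merged gradient variables."""
    merged_gradient_names = []
    for param_grad_name in param_grad_names:
        merged_param_grad_name = param_grad_name + '@MERGED'
        # ... insert the sum/cast ops that accumulate into merged_param_grad_name ...
        merged_gradient_names.append(merged_param_grad_name)
    return merged_gradient_names


print(accumulate_gradients(['fc_0.w_0@GRAD', 'fc_0.b_0@GRAD']))
# ['fc_0.w_0@GRAD@MERGED', 'fc_0.b_0@GRAD@MERGED']
```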