Commit eeca5ef6 authored by JZ-LIANG, committed by sandyhouse

update

Parent 479efeeb
@@ -16,7 +16,7 @@ from paddle.fluid import unique_name, core
import paddle.fluid as fluid
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper
from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op, is_update_op, OpRole
from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op, is_update_op
from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase
from paddle.distributed.fleet.meta_optimizers.sharding.shard import Shard, ProgramSegment
from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils
@@ -208,7 +208,8 @@ class ShardingOptimizer(MetaOptimizerBase):
#pp_optimizer._clear_gradients(main_block, param_list)
accumulated_grad_names = pp_optimizer._accumulate_gradients(
main_block)
accumulated_grad_names = sorted(accumulated_grad_names)
# accumulated_grad_names = sorted(accumulated_grad_names)
print("persistable FP32 grad: ")
print(accumulated_grad_names)
first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
main_block)
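For context, a minimal sketch of how the first check_finite_and_unscale op is typically located in a block (assumed logic for illustration only; the repository's actual helper may differ):

def get_first_check_finite_and_unscale_op_idx(block):
    # Scan ops in order and return the index of the first AMP
    # check_finite_and_unscale op; gradient allreduce ops are then
    # inserted in front of this position.
    for idx, op in enumerate(block.ops):
        if op.type == "check_finite_and_unscale":
            return idx
    raise ValueError("no check_finite_and_unscale op found in block")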
@@ -218,7 +219,7 @@ class ShardingOptimizer(MetaOptimizerBase):
self.sharding_ring_id,
accumulated_grad_names,
self._shard,
OpRole.Optimize,
core.op_proto_and_checker_maker.OpRole.Optimize,
use_calc_stream=True)
#if not self._shard.has_param(param_name): continue
##if not main_block.has_var(grad_name): continue
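The call above passes the accumulated FP32 gradient names together with core.op_proto_and_checker_maker.OpRole.Optimize, so the inserted allreduce ops run in the optimize stage on the calculation stream. A minimal sketch of such an insertion (hypothetical helper name and signature, not the repository's code):

from paddle.fluid import core

OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
OpRole = core.op_proto_and_checker_maker.OpRole

def insert_grad_allreduce(block, insert_idx, ring_id, grad_names):
    # Insert one c_allreduce_sum per accumulated gradient at insert_idx,
    # tagged with the Optimize role and run on the calculation stream.
    for grad_name in reversed(grad_names):
        grad_var = block.var(grad_name)
        block._insert_op(
            insert_idx,
            type='c_allreduce_sum',
            inputs={'X': grad_var},
            outputs={'Out': grad_var},
            attrs={
                'ring_id': ring_id,
                'use_calc_stream': True,
                OP_ROLE_KEY: OpRole.Optimize,
            })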
@@ -470,10 +471,20 @@ class ShardingOptimizer(MetaOptimizerBase):
self._main_program.global_block())
def _wait(self, ):
    # Only the first parallelism group that initializes NCCL needs to wait.
    if self._as_outer_parallelism:
        endpoints = self.role_maker._get_trainer_endpoints()
        current_endpoint = endpoints[self.role_maker._worker_index()]
    else:
        endpoints = self.sharding_group_endpoints[:]
        current_endpoint = self.sharding_group_endpoints[self.sharding_rank]
    if self._as_outer_parallelism:
        if self.role_maker._worker_index() == 0:
            self._collective_helper._wait(current_endpoint, endpoints)
    else:
        if self.sharding_rank == 0:
            self._collective_helper._wait(current_endpoint, endpoints)
# def _wait(self, ):
#     # only the first parallelism group that initializes NCCL needs to wait.
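The new _wait picks the barrier group from _as_outer_parallelism: when sharding is the outer parallelism, the NCCL-init barrier spans all trainer endpoints and only global worker 0 waits; otherwise only the local sharding group is involved and its rank 0 waits. Restated as a standalone helper (purely illustrative, not part of the change):

def select_wait_group(role_maker, sharding_group_endpoints, sharding_rank,
                      as_outer_parallelism):
    # Choose the endpoint list and the rank that should issue the wait.
    if as_outer_parallelism:
        endpoints = role_maker._get_trainer_endpoints()
        rank = role_maker._worker_index()
    else:
        endpoints = sharding_group_endpoints[:]
        rank = sharding_rank
    current_endpoint = endpoints[rank]
    return current_endpoint, endpoints, rank == 0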
@@ -4879,8 +4879,9 @@ class PipelineOptimizer(object):
if '@BroadCast' in param_name:
param_name = param_name[0:param_name.find('@BroadCast')]
# clear gradient
assert param_name in self.origin_main_block.vars, "[{}] not in original main block".format(
param_name)
param_grad_name = self._append_grad_suffix(param_name)
accumulated_grad_names.append(param_grad_name)
if not block.has_var(param_grad_name):
self._create_var(
block, self.origin_main_block.vars[param_name],
@@ -4925,7 +4926,7 @@ class PipelineOptimizer(object):
#self._op_role_var_key: op_role_var
})
#offset += 1
# accumulated_gradient_names.append(param_grad_var.name)
accumulated_grad_names.append(param_grad_var.name)
else:
grad_name = op_role_var[i + 1] # with _0 suffix
grad_var = block.vars[grad_name]
@@ -4962,7 +4963,7 @@ class PipelineOptimizer(object):
# self._op_role_var_key: op_role_var
})
offset += 1
# accumulated_gradient_names.append(param_grad_var.name)
accumulated_grad_names.append(param_grad_var.name)
#real_grad_name = grad_name[0:grad_name.find(
# '@GRAD')] + '@GRAD'
#real_grad_var = block.vars[
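Both branches now record the merged gradient variable name in accumulated_grad_names (the previously commented-out lines referred to a misspelled accumulated_gradient_names). A minimal sketch of the accumulate-and-record pattern (assumed helper, for illustration only; the actual pipeline code builds these ops with additional attributes):

def accumulate_into_buffer(block, insert_idx, grad_var, param_grad_var,
                           accumulated_grad_names):
    # Sum the per-microbatch gradient into the persistable @GRAD buffer
    # and record the buffer name for later allreduce / clear passes.
    block._insert_op(
        insert_idx,
        type='sum',
        inputs={'X': [param_grad_var, grad_var]},
        outputs={'Out': param_grad_var},
        attrs={'use_mkldnn': False})
    accumulated_grad_names.append(param_grad_var.name)
    return insert_idx + 1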