提交 c7f4af1d 编写于 作者: J JZ-LIANG 提交者: sandyhouse

update

上级 d7dd3f51
......@@ -275,6 +275,9 @@ def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars):
"""
insert sync_comm_op for vars
"""
if len(comm_dep_vars) == 0:
return 0
op_role = get_valid_op_role(block, insert_idx)
block._insert_op_without_sync(
insert_idx,
......@@ -329,6 +332,9 @@ def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars):
"""
_add_allreduce_ops
"""
if len(allreduce_vars) == 0:
return
for var in allreduce_vars:
block._insert_op_without_sync(
insert_idx,
......@@ -341,6 +347,52 @@ def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars):
return
def get_grad_device(grad_name, shard):
assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(
grad_name)
base_name = None
# mind the traversal order
possible_suffixes = ['.cast_fp16@GRAD', '@GRAD']
for suffix in possible_suffixes:
if suffix in grad_name:
base_name = re.sub(suffix, '', grad_name)
break
assert base_name in shard.global_param2device, "[{}] should be a param variable.".format(
base_name)
return shard.global_param2device[base_name]
def insert_reduce_ops(block,
insert_idx,
ring_id,
reduce_vars,
shard,
op_role,
use_calc_stream=False):
"""
_add_allreduce_ops
"""
for var in reduce_vars:
root_id = get_grad_device(var, shard)
assert root_id >= 0, "root id should be a positive int".format(var)
block._insert_op_without_sync(
insert_idx,
type='c_reduce_sum',
inputs={'X': var},
outputs={'Out': var},
attrs={
'ring_id': ring_id,
'root_id': root_id,
'use_calc_stream': use_calc_stream,
OP_ROLE_KEY: op_role
})
return
def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root):
"""
_add_broadcast_ops
......
......@@ -16,7 +16,7 @@ from paddle.fluid import unique_name, core
import paddle.fluid as fluid
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper
from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op, is_update_op
from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op, is_update_op, OpRole
from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase
from paddle.distributed.fleet.meta_optimizers.sharding.shard import Shard, ProgramSegment
from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils
......@@ -24,6 +24,7 @@ from paddle.distributed.fleet.meta_optimizers.sharding.weight_decay_helper impor
from paddle.distributed.fleet.meta_optimizers.sharding.gradient_clip_helper import GradientClipHelper
from paddle.distributed.fleet.meta_optimizers.sharding.prune import ProgramDeps
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *
import logging
from functools import reduce
......@@ -205,7 +206,16 @@ class ShardingOptimizer(MetaOptimizerBase):
# if self._shard.has_param(param_name):
# param_list.append(param_name)
#pp_optimizer._clear_gradients(main_block, param_list)
pp_optimizer._accumulate_gradients(main_block)
accumulated_gradient_names, first_optimize_op_index = pp_optimizer._accumulate_gradients(
main_block)
insert_reduce_ops(
main_block,
first_optimize_op_index,
self.sharding_ring_id,
accumulated_gradient_names,
self._shard,
OpRole.Optimize,
use_calc_stream=True)
#if not self._shard.has_param(param_name): continue
##if not main_block.has_var(grad_name): continue
#assert main_block.has_var(grad_name)
......@@ -378,19 +388,19 @@ class ShardingOptimizer(MetaOptimizerBase):
self._init_comm()
# global
print("global_group_endpoints:", self.global_group_endpoints)
print("global_rank:", self.global_rank)
print("global_ring_id:", self.global_group_id)
if self._as_outer_parallelism:
print("global_group_endpoints:", self.global_group_endpoints)
print("global_rank:", self.global_rank)
print("global_ring_id:", self.global_group_id)
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.global_group_endpoints, self.global_rank,
self.global_group_id, True)
self.global_group_id, False)
print("mp_group_endpoints:", self.mp_group_endpoints)
print("mp_rank:", self.mp_rank)
print("mp_ring_id:", self.mp_group_id)
if self._as_outer_parallelism:
print("mp_group_endpoints:", self.mp_group_endpoints)
print("mp_rank:", self.mp_rank)
print("mp_ring_id:", self.mp_group_id)
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.mp_group_endpoints, self.mp_rank, self.mp_group_id, False)
......@@ -408,7 +418,7 @@ class ShardingOptimizer(MetaOptimizerBase):
if self.hybrid_dp:
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.dp_group_endpoints, self.dp_rank, self.dp_ring_id, True)
self.dp_group_endpoints, self.dp_rank, self.dp_ring_id, False)
# pp
if self.use_pipeline:
print("pp_group_endpoints:", self.pp_group_endpoints)
......@@ -456,9 +466,13 @@ class ShardingOptimizer(MetaOptimizerBase):
self._main_program.global_block())
def _wait(self, ):
endpoints = self.role_maker._get_trainer_endpoints()
# only the first parallelsm group that init nccl need to be wait.
if self._as_outer_parallelism:
endpoints = self.global_group_endpoints[:]
else:
endpoints = self.sharding_group_endpoints[:]
current_endpoint = endpoints[self.role_maker._worker_index()]
if self.role_maker._worker_index() == 0:
if self.sharding_rank == 0:
self._collective_helper._wait(current_endpoint, endpoints)
def _split_program(self, block):
......@@ -500,17 +514,19 @@ class ShardingOptimizer(MetaOptimizerBase):
self._main_program.global_block().var(input_name))
# find reduce vars
if is_backward_op(op) and \
OP_ROLE_VAR_KEY in op.attr_names:
op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
if len(op_role_var) != 0:
assert len(op_role_var) % 2 == 0
for i in range(0, len(op_role_var), 2):
param, reduced_grad = op_role_var[i], op_role_var[i + 1]
segment._allreduce_vars.append(reduced_grad)
#assert (
# reduced_grad not in self._reduced_grads_to_param)
self._reduced_grads_to_param[reduced_grad] = param
if not self.use_pipeline:
if is_backward_op(op) and \
OP_ROLE_VAR_KEY in op.attr_names:
op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
if len(op_role_var) != 0:
assert len(op_role_var) % 2 == 0
for i in range(0, len(op_role_var), 2):
param, reduced_grad = op_role_var[i], op_role_var[
i + 1]
segment._allreduce_vars.append(reduced_grad)
#assert (
# reduced_grad not in self._reduced_grads_to_param)
self._reduced_grads_to_param[reduced_grad] = param
# find cast op
if FP16Utils.is_fp16_cast_op(block, op, self._params):
......@@ -629,10 +645,17 @@ class ShardingOptimizer(MetaOptimizerBase):
def _add_broadcast_allreduce(self, block):
"""
_add_broadcast_allreduce
if combined with pipeline(grad accumulate),
the grad allreduce should be done in optimize role
"""
if len(self._segments) < 1:
return
# sharding
if self.use_pipeline:
for idx in range(len(self._segments)):
assert len(self._segments[idx]._allreduce_vars) == 0
if self._segments[-1]._allreduce_vars:
shard_allredue_vars = self._shard.filter_grads(self._segments[-1]
._allreduce_vars)
......@@ -780,6 +803,12 @@ class ShardingOptimizer(MetaOptimizerBase):
def _init_comm(self):
# sharding alone mode
self.sharding_ring_id = 0
self.sharding_rank = self.global_rank
self.sharding_group_endpoints = self.endpoints[:]
self.sharding_group_size = len(self.endpoints)
if self.hybrid_dp:
assert self._as_outer_parallelism == False, "hybrid dp is conflict when using sharding as outer parallelism"
self.sharding_group_size = self.user_defined_strategy.sharding_configs[
......@@ -799,6 +828,9 @@ class ShardingOptimizer(MetaOptimizerBase):
ep for idx, ep in enumerate(self.endpoints)
if (idx % self.sharding_group_size) == self.sharding_rank
]
self.global_group_endpoints = self.role_maker._get_trainer_endpoints(
)[:]
assert self.global_word_size > self.sharding_group_size, \
"global_word_size: {} should be larger than sharding_group_size: {}".format(self.global_word_size, self.sharding_group_size)
assert self.global_word_size % self.sharding_group_size == 0, \
......
......@@ -4842,6 +4842,9 @@ class PipelineOptimizer(object):
"""
Accumulate the gradients generated in microbatch to the one in mini-batch.
"""
# the name of real grad vars that should be allreduce
accumulated_gradient_names = []
first_optimize_op_index = None
accumulated_grad_names = []
for index, op in reversed(tuple(enumerate(list(block.ops)))):
......@@ -4921,6 +4924,7 @@ class PipelineOptimizer(object):
#self._op_role_var_key: op_role_var
})
#offset += 1
accumulated_gradient_names.append(real_grad_var.name)
else:
grad_name = op_role_var[i + 1] # with _0 suffix
grad_var = block.vars[grad_name]
......@@ -4957,6 +4961,7 @@ class PipelineOptimizer(object):
# self._op_role_var_key: op_role_var
})
offset += 1
accumulated_gradient_names.append(fp32_grad_var.name)
#real_grad_name = grad_name[0:grad_name.find(
# '@GRAD')] + '@GRAD'
#real_grad_var = block.vars[
......@@ -4997,7 +5002,7 @@ class PipelineOptimizer(object):
# self._op_role_key: self._op_role.Backward,
# # self._op_role_var_key: op_role_var
# })
return first_optimize_op_index, accumulated_grad_names
return accumulated_gradient_names, first_optimize_op_index
def _add_sub_blocks(self, main_block, program_list):
main_program = main_block.program
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册