Commit c7f4af1d authored by JZ-LIANG, committed by sandyhouse

update

Parent d7dd3f51
@@ -275,6 +275,9 @@ def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars):
     """
     insert sync_comm_op for vars
     """
+    if len(comm_dep_vars) == 0:
+        return 0
     op_role = get_valid_op_role(block, insert_idx)
     block._insert_op_without_sync(
         insert_idx,
@@ -329,6 +332,9 @@ def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars):
     """
     _add_allreduce_ops
     """
+    if len(allreduce_vars) == 0:
+        return
     for var in allreduce_vars:
         block._insert_op_without_sync(
             insert_idx,
@@ -341,6 +347,52 @@ def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars):
     return


+def get_grad_device(grad_name, shard):
+    assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(
+        grad_name)
+    base_name = None
+    # mind the traversal order
+    possible_suffixes = ['.cast_fp16@GRAD', '@GRAD']
+    for suffix in possible_suffixes:
+        if suffix in grad_name:
+            base_name = re.sub(suffix, '', grad_name)
+            break
+
+    assert base_name in shard.global_param2device, \
+        "[{}] should be a param variable.".format(base_name)
+
+    return shard.global_param2device[base_name]
+
+
+def insert_reduce_ops(block,
+                      insert_idx,
+                      ring_id,
+                      reduce_vars,
+                      shard,
+                      op_role,
+                      use_calc_stream=False):
+    """
+    _add_reduce_ops
+    """
+    for var in reduce_vars:
+        root_id = get_grad_device(var, shard)
+        assert root_id >= 0, \
+            "root id should be a non-negative integer, but got [{}] for grad [{}]".format(
+                root_id, var)
+        block._insert_op_without_sync(
+            insert_idx,
+            type='c_reduce_sum',
+            inputs={'X': var},
+            outputs={'Out': var},
+            attrs={
+                'ring_id': ring_id,
+                'root_id': root_id,
+                'use_calc_stream': use_calc_stream,
+                OP_ROLE_KEY: op_role
+            })
+
+    return
+
+
 def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root):
     """
     _add_broadcast_ops
...
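A standalone sketch of the grad-name-to-rank lookup that `get_grad_device` introduces above: it strips the `.cast_fp16@GRAD` suffix before the plain `@GRAD` suffix (the traversal order the comment warns about) and then consults `shard.global_param2device`. The `DummyShard` stub and the parameter names are made up for illustration; inside Paddle the mapping is built by the sharding pass.

import re

class DummyShard(object):
    # stand-in for sharding.shard.Shard: only the mapping the helper needs
    def __init__(self, param2device):
        self.global_param2device = param2device

def get_grad_device(grad_name, shard):
    assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(grad_name)
    base_name = None
    # mind the traversal order: try the fp16-cast suffix before the plain one,
    # otherwise "w.cast_fp16@GRAD" would only be stripped to "w.cast_fp16"
    for suffix in ['.cast_fp16@GRAD', '@GRAD']:
        if suffix in grad_name:
            base_name = re.sub(suffix, '', grad_name)
            break
    assert base_name in shard.global_param2device, \
        "[{}] should be a param variable.".format(base_name)
    return shard.global_param2device[base_name]

shard = DummyShard({"fc_0.w_0": 1, "fc_0.b_0": 0})
print(get_grad_device("fc_0.w_0.cast_fp16@GRAD", shard))  # -> 1
print(get_grad_device("fc_0.b_0@GRAD", shard))            # -> 0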
@@ -16,7 +16,7 @@ from paddle.fluid import unique_name, core
 import paddle.fluid as fluid
 from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper
-from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op, is_update_op
+from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op, is_update_op, OpRole
 from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase
 from paddle.distributed.fleet.meta_optimizers.sharding.shard import Shard, ProgramSegment
 from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils
@@ -24,6 +24,7 @@ from paddle.distributed.fleet.meta_optimizers.sharding.weight_decay_helper import
 from paddle.distributed.fleet.meta_optimizers.sharding.gradient_clip_helper import GradientClipHelper
 from paddle.distributed.fleet.meta_optimizers.sharding.prune import ProgramDeps
 from paddle.distributed.fleet.meta_optimizers.sharding.utils import *
 import logging
 from functools import reduce
@@ -205,7 +206,16 @@ class ShardingOptimizer(MetaOptimizerBase):
             # if self._shard.has_param(param_name):
             # param_list.append(param_name)
             #pp_optimizer._clear_gradients(main_block, param_list)
-            pp_optimizer._accumulate_gradients(main_block)
+            accumulated_gradient_names, first_optimize_op_index = pp_optimizer._accumulate_gradients(
+                main_block)
+            insert_reduce_ops(
+                main_block,
+                first_optimize_op_index,
+                self.sharding_ring_id,
+                accumulated_gradient_names,
+                self._shard,
+                OpRole.Optimize,
+                use_calc_stream=True)
             #if not self._shard.has_param(param_name): continue
             ##if not main_block.has_var(grad_name): continue
             #assert main_block.has_var(grad_name)
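A small list-based sketch (no Paddle required) of what the insertion position above implies: `insert_reduce_ops` inserts every `c_reduce_sum` at the same `first_optimize_op_index`, and each insert pushes the previous ones down, so the reduce ops end up in reverse iteration order but all still sit before the first optimizer op, which is what matters here. The names below are illustrative stand-ins, not real operators.

ops = ["adam(fc_0.w_0)", "adam(fc_0.b_0)"]   # ops from first_optimize_op_index onward
first_optimize_op_index = 0
accumulated_gradient_names = ["g0", "g1", "g2"]

for grad in accumulated_gradient_names:
    # mirrors block._insert_op_without_sync(insert_idx, type='c_reduce_sum', ...)
    ops.insert(first_optimize_op_index, "c_reduce_sum({})".format(grad))

print(ops)
# ['c_reduce_sum(g2)', 'c_reduce_sum(g1)', 'c_reduce_sum(g0)',
#  'adam(fc_0.w_0)', 'adam(fc_0.b_0)']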
@@ -378,19 +388,19 @@ class ShardingOptimizer(MetaOptimizerBase):
         self._init_comm()

         # global
-        print("global_group_endpoints:", self.global_group_endpoints)
-        print("global_rank:", self.global_rank)
-        print("global_ring_id:", self.global_group_id)
         if self._as_outer_parallelism:
+            print("global_group_endpoints:", self.global_group_endpoints)
+            print("global_rank:", self.global_rank)
+            print("global_ring_id:", self.global_group_id)
             self._collective_helper._init_communicator(
                 self._startup_program, self.current_endpoint,
                 self.global_group_endpoints, self.global_rank,
-                self.global_group_id, True)
+                self.global_group_id, False)

-        print("mp_group_endpoints:", self.mp_group_endpoints)
-        print("mp_rank:", self.mp_rank)
-        print("mp_ring_id:", self.mp_group_id)
         if self._as_outer_parallelism:
+            print("mp_group_endpoints:", self.mp_group_endpoints)
+            print("mp_rank:", self.mp_rank)
+            print("mp_ring_id:", self.mp_group_id)
             self._collective_helper._init_communicator(
                 self._startup_program, self.current_endpoint,
                 self.mp_group_endpoints, self.mp_rank, self.mp_group_id, False)
@@ -408,7 +418,7 @@ class ShardingOptimizer(MetaOptimizerBase):
         if self.hybrid_dp:
             self._collective_helper._init_communicator(
                 self._startup_program, self.current_endpoint,
-                self.dp_group_endpoints, self.dp_rank, self.dp_ring_id, True)
+                self.dp_group_endpoints, self.dp_rank, self.dp_ring_id, False)

         # pp
         if self.use_pipeline:
             print("pp_group_endpoints:", self.pp_group_endpoints)
@@ -456,9 +466,13 @@ class ShardingOptimizer(MetaOptimizerBase):
             self._main_program.global_block())

     def _wait(self, ):
-        endpoints = self.role_maker._get_trainer_endpoints()
+        # only the first parallelism group that inits nccl needs to wait
+        if self._as_outer_parallelism:
+            endpoints = self.global_group_endpoints[:]
+        else:
+            endpoints = self.sharding_group_endpoints[:]
         current_endpoint = endpoints[self.role_maker._worker_index()]
-        if self.role_maker._worker_index() == 0:
+        if self.sharding_rank == 0:
             self._collective_helper._wait(current_endpoint, endpoints)

     def _split_program(self, block):
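A plain-Python sketch of the barrier choice in `_wait` above: the wait targets the endpoints of whichever group creates the first communicator (the global group when sharding is used as the outer parallelism, otherwise the sharding group), and only the worker with `sharding_rank == 0` issues it. All names and addresses below are placeholders.

as_outer_parallelism = False
global_group_endpoints = ["10.0.0.1:6170", "10.0.0.2:6170"]
sharding_group_endpoints = ["10.0.0.1:6170", "10.0.0.2:6170"]
worker_index = 0
sharding_rank = 0

endpoints = (global_group_endpoints[:] if as_outer_parallelism
             else sharding_group_endpoints[:])
current_endpoint = endpoints[worker_index]
if sharding_rank == 0:
    # in the optimizer this is self._collective_helper._wait(current_endpoint, endpoints)
    print("wait on", current_endpoint, "for", endpoints)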
@@ -500,17 +514,19 @@ class ShardingOptimizer(MetaOptimizerBase):
                         self._main_program.global_block().var(input_name))

             # find reduce vars
-            if is_backward_op(op) and \
-                    OP_ROLE_VAR_KEY in op.attr_names:
-                op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
-                if len(op_role_var) != 0:
-                    assert len(op_role_var) % 2 == 0
-                    for i in range(0, len(op_role_var), 2):
-                        param, reduced_grad = op_role_var[i], op_role_var[i + 1]
-                        segment._allreduce_vars.append(reduced_grad)
-                        #assert (
-                        #    reduced_grad not in self._reduced_grads_to_param)
-                        self._reduced_grads_to_param[reduced_grad] = param
+            if not self.use_pipeline:
+                if is_backward_op(op) and \
+                        OP_ROLE_VAR_KEY in op.attr_names:
+                    op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
+                    if len(op_role_var) != 0:
+                        assert len(op_role_var) % 2 == 0
+                        for i in range(0, len(op_role_var), 2):
+                            param, reduced_grad = op_role_var[i], op_role_var[
+                                i + 1]
+                            segment._allreduce_vars.append(reduced_grad)
+                            #assert (
+                            #    reduced_grad not in self._reduced_grads_to_param)
+                            self._reduced_grads_to_param[reduced_grad] = param

             # find cast op
             if FP16Utils.is_fp16_cast_op(block, op, self._params):
@@ -629,10 +645,17 @@ class ShardingOptimizer(MetaOptimizerBase):
     def _add_broadcast_allreduce(self, block):
         """
         _add_broadcast_allreduce
+        if combined with pipeline (gradient accumulation),
+        the grad allreduce should be done in the optimize role
         """
         if len(self._segments) < 1:
             return

         # sharding
+        if self.use_pipeline:
+            for idx in range(len(self._segments)):
+                assert len(self._segments[idx]._allreduce_vars) == 0
+
         if self._segments[-1]._allreduce_vars:
             shard_allredue_vars = self._shard.filter_grads(self._segments[-1]
                                                            ._allreduce_vars)
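A toy timeline, in plain Python, of why the cross-rank reduce moves out of the backward segments when pipeline gradient accumulation is on: each microbatch only adds a partial gradient locally, so the single reduce per grad has to run once, in the optimize role, after the last accumulation. The numbers below are made up.

# rank -> per-microbatch local gradients for one parameter
micro_grads = {0: [1.0, 2.0, 3.0], 1: [0.5, 0.5, 1.0]}

# backward role: each rank accumulates its own microbatch grads
accumulated = {rank: sum(grads) for rank, grads in micro_grads.items()}

# optimize role: one cross-rank reduce of the accumulated value
reduced = sum(accumulated.values())

print(accumulated)  # {0: 6.0, 1: 2.0}
print(reduced)      # 8.0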
@@ -780,6 +803,12 @@ class ShardingOptimizer(MetaOptimizerBase):
     def _init_comm(self):
+        # sharding alone mode
+        self.sharding_ring_id = 0
+        self.sharding_rank = self.global_rank
+        self.sharding_group_endpoints = self.endpoints[:]
+        self.sharding_group_size = len(self.endpoints)
+
         if self.hybrid_dp:
             assert self._as_outer_parallelism == False, "hybrid dp is conflict when using sharding as outer parallelism"
             self.sharding_group_size = self.user_defined_strategy.sharding_configs[
@@ -799,6 +828,9 @@ class ShardingOptimizer(MetaOptimizerBase):
                 ep for idx, ep in enumerate(self.endpoints)
                 if (idx % self.sharding_group_size) == self.sharding_rank
             ]
+            self.global_group_endpoints = self.role_maker._get_trainer_endpoints(
+            )[:]
+
             assert self.global_word_size > self.sharding_group_size, \
                 "global_word_size: {} should be larger than sharding_group_size: {}".format(self.global_word_size, self.sharding_group_size)
             assert self.global_word_size % self.sharding_group_size == 0, \
...
@@ -4842,6 +4842,9 @@ class PipelineOptimizer(object):
         """
         Accumulate the gradients generated in microbatch to the one in mini-batch.
         """
+        # the names of the real (accumulated) grad vars that should be allreduced
+        accumulated_gradient_names = []
         first_optimize_op_index = None
         accumulated_grad_names = []
         for index, op in reversed(tuple(enumerate(list(block.ops)))):
@@ -4921,6 +4924,7 @@
                                 #self._op_role_var_key: op_role_var
                             })
                         #offset += 1
+                        accumulated_gradient_names.append(real_grad_var.name)
                     else:
                         grad_name = op_role_var[i + 1]  # with _0 suffix
                         grad_var = block.vars[grad_name]
@@ -4957,6 +4961,7 @@
                                 # self._op_role_var_key: op_role_var
                             })
                         offset += 1
+                        accumulated_gradient_names.append(fp32_grad_var.name)
                         #real_grad_name = grad_name[0:grad_name.find(
                         # '@GRAD')] + '@GRAD'
                         #real_grad_var = block.vars[
@@ -4997,7 +5002,7 @@
                         # self._op_role_key: self._op_role.Backward,
                         # # self._op_role_var_key: op_role_var
                         # })
-        return first_optimize_op_index, accumulated_grad_names
+        return accumulated_gradient_names, first_optimize_op_index

     def _add_sub_blocks(self, main_block, program_list):
         main_program = main_block.program
...
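The swapped return order matches the unpacking at the sharding call site earlier in this commit: the method now hands back the names of the accumulated ("real") mini-batch grad vars rather than the per-microbatch copies, together with the index of the first optimize-role op. A minimal stub of that contract, with illustrative names and an invented index value:

def _accumulate_gradients_stub():
    # stand-in for PipelineOptimizer._accumulate_gradients after this commit
    accumulated_gradient_names = ["fc_0.w_0@GRAD", "fc_0.b_0@GRAD"]
    first_optimize_op_index = 42
    return accumulated_gradient_names, first_optimize_op_index

# the sharding optimizer unpacks it in exactly this order before insert_reduce_ops
accumulated_gradient_names, first_optimize_op_index = _accumulate_gradients_stub()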