Unverified commit 4d7af372, authored by WangXi, committed by GitHub

[hybrid] pp+dp support fp16 allreduce (#34762)

Parent 3f962e77
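
The change makes pipeline-parallel + data-parallel training accumulate and all-reduce gradients in fp16 (into @MERGED@FP16 buffers) and cast them back to fp32 only just before the optimizer, roughly halving all-reduce traffic. Below is a minimal, framework-agnostic sketch of that flow, using numpy and an in-process stand-in for c_allreduce_sum rather than Paddle's actual communication ops:

import numpy as np

def all_reduce_sum(tensors):
    # stand-in for c_allreduce_sum across data-parallel ranks
    total = np.sum(np.stack(tensors), axis=0)
    return [total.copy() for _ in tensors]

num_dp_ranks, micro_steps = 2, 4
rng = np.random.default_rng(0)

# each rank accumulates its micro-batch grads into an fp16 @MERGED@FP16 buffer
merged_fp16 = []
for _ in range(num_dp_ranks):
    acc = np.zeros(8, dtype=np.float16)
    for _ in range(micro_steps):
        acc += rng.standard_normal(8).astype(np.float16)  # grad arrives (or is cast) as fp16
    merged_fp16.append(acc)

reduced = all_reduce_sum(merged_fp16)                  # communication happens in fp16
grads_fp32 = [g.astype(np.float32) for g in reduced]   # cast fp16 -> fp32 before momentum/adam
print(grads_fp32[0].dtype, grads_fp32[0].shape)        # float32 (8,)
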
......@@ -14,7 +14,7 @@
import paddle
from paddle.fluid import core, unique_name
from functools import reduce
from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op
from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op, is_backward_op
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
import re
......@@ -431,15 +431,19 @@ def insert_reduce_ops(block,
reduce_vars,
shard,
op_role=OpRole.Backward,
use_calc_stream=False):
use_calc_stream=False,
rank=None):
"""
_add_allreduce_ops
"""
grad_in_this_device = []
for var in reduce_vars:
root_id = get_grad_device(var, shard)
assert root_id >= 0, "root id should be a non-negative integer, but got {}".format(
root_id)
if rank is not None and rank == root_id:
grad_in_this_device.append(var)
block._insert_op_without_sync(
insert_idx,
type='c_reduce_sum',
......@@ -451,16 +455,23 @@ def insert_reduce_ops(block,
'use_calc_stream': use_calc_stream,
OP_ROLE_KEY: op_role
})
return
return grad_in_this_device
def get_grad_device(grad_name, shard):
assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(
grad_name)
base_name = None
# mind the traversal order
# NOTE: mind the traversal order
possible_suffixes = [
'.cast_fp16@GRAD@MERGED', '.cast_fp16@GRAD', '@GRAD@MERGED', '@GRAD'
# sharding gm
'.cast_fp16@GRAD@MERGED',
'.cast_fp16@GRAD',
# pipeline
'@GRAD@MERGED@FP16',
'@GRAD@MERGED',
'@GRAD',
]
for suffix in possible_suffixes:
if suffix in grad_name:
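
get_grad_device recovers the owning device of a gradient by stripping one of the known name suffixes (the traversal order matters: the longer, more specific suffixes must be tried before the plain '@GRAD') and then looking the base parameter up in shard.global_param2device. A standalone sketch of that lookup, with a hypothetical param-to-device map in place of the shard object:

import re

# hypothetical mapping; in Paddle it comes from shard.global_param2device
param2device = {'fc_0.w_0': 0, 'fc_1.w_0': 1}

# NOTE: mind the traversal order -- longer suffixes first
possible_suffixes = [
    '.cast_fp16@GRAD@MERGED', '.cast_fp16@GRAD',
    '@GRAD@MERGED@FP16', '@GRAD@MERGED', '@GRAD',
]

def grad_device(grad_name):
    assert '@GRAD' in grad_name, "[{}] should be a grad variable.".format(grad_name)
    for suffix in possible_suffixes:
        if suffix in grad_name:
            base = re.sub(suffix, '', grad_name)
            return param2device[base]

print(grad_device('fc_1.w_0@GRAD@MERGED@FP16'))  # -> 1
print(grad_device('fc_0.w_0.cast_fp16@GRAD'))    # -> 0
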
......@@ -487,6 +498,15 @@ def get_first_check_finite_and_unscale_op_idx(block, raise_error=True):
return -1
def get_first_optimize_op_idx(block):
first_opt_op_idx = None
for index, op in reversed(tuple(enumerate(block.ops))):
if is_backward_op(op) and first_opt_op_idx is None:
first_opt_op_idx = index + 1
break
return first_opt_op_idx
def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root):
"""
_add_broadcast_ops
......@@ -672,23 +692,6 @@ def save_persistables(exe, dirname, main_program, filename=None):
return
def get_grad_device(grad_name, shard):
assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(
grad_name)
base_name = None
# mind the traversal order
possible_suffixes = ['.cast_fp16@GRAD', '@GRAD']
for suffix in possible_suffixes:
if suffix in grad_name:
base_name = re.sub(suffix, '', grad_name)
break
assert base_name in shard.global_param2device, "[{}] should be a param variable.".format(
base_name)
return shard.global_param2device[base_name]
def append_naive_sync(block, sync_var, ring_id):
# NOTE (JZ-LIANG) update this to use barrier sync for more elegant logic
# sync within global
......
......@@ -294,6 +294,8 @@ class ShardingOptimizer(MetaOptimizerBase):
if self.pp_degree == 1: return
strategy = self.user_defined_strategy
fp16_allreduce = strategy.fp16_allreduce
main_block = self._main_program.global_block()
startup_block = self._startup_program.global_block()
......@@ -317,33 +319,44 @@ class ShardingOptimizer(MetaOptimizerBase):
main_block._remove_op(idx)
accumulated_grad_names = self._pp_optimizer._accumulate_gradients(
main_block)
# accumulated_grad_names = sorted(accumulated_grad_names)
main_block, fp16_allreduce=fp16_allreduce)
len_of_ops = len(main_block.ops)
first_optimize_op_index = get_first_optimize_op_idx(main_block)
if self.pp_allreduce_in_optimize:
print("persistable FP32 grad: ")
print(accumulated_grad_names)
first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
main_block, raise_error=strategy.amp)
insert_reduce_ops(
logger.info("Pipeline Persistable grad is {}".format(
accumulated_grad_names))
# FIXME(wangxi): accumulated_grad obtained from pipeline does not
# include sharding's param@BroadCast grad when
# pp_allreduce_in_optimize is enabled
accumulated_grad_names = insert_reduce_ops(
main_block,
first_optimize_op_index,
self.sharding_ring_id,
accumulated_grad_names,
self._shard,
core.op_proto_and_checker_maker.OpRole.Optimize,
use_calc_stream=True)
use_calc_stream=True,
rank=self.sharding_rank)
logger.info("PP-Sharding grad is {}".format(accumulated_grad_names))
first_optimize_op_index += (len(main_block.ops) - len_of_ops)
len_of_ops = len(main_block.ops)
if self.hybrid_dp and self.hybrid_dp_mode == "pp_hybrid_dp":
first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
main_block, raise_error=strategy.amp)
if first_optimize_op_index >= 0:
insert_allreduce_ops(
main_block,
first_optimize_op_index,
self.dp_ring_id,
accumulated_grad_names,
core.op_proto_and_checker_maker.OpRole.Optimize,
use_calc_stream=True,
user_defined_strategy=strategy)
insert_allreduce_ops(
main_block,
first_optimize_op_index,
self.dp_ring_id,
accumulated_grad_names,
core.op_proto_and_checker_maker.OpRole.Optimize,
use_calc_stream=True,
user_defined_strategy=strategy)
first_optimize_op_index += (len(main_block.ops) - len_of_ops)
len_of_ops = len(main_block.ops)
# FIXME(wangxi): if fp16_allreduce, should the fp16->fp32 cast go there?
def _adapt_amp_clip_without_sharding(self):
if self.sharding_degree > 1: return
......
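
In the pp+dp path above, when pp_allreduce_in_optimize is on, c_reduce_sum ops are first inserted in front of the optimize phase for the sharding group; then, under hybrid dp, c_allreduce_sum ops are inserted for the dp group. Since every insertion shifts later op indices, the code advances first_optimize_op_index by the growth of len(main_block.ops) between steps. A plain-list sketch of that bookkeeping, with a hypothetical insert_ops helper standing in for block._insert_op_without_sync:

# Plain-list sketch of the insertion-index bookkeeping; 'ops' stands in for
# main_block.ops, and the inserted op names are illustrative only.
ops = ['backward_0', 'backward_1', 'check_finite_and_unscale', 'momentum']

def insert_ops(ops, idx, new_ops):
    # stand-in for repeated block._insert_op_without_sync calls
    ops[idx:idx] = new_ops

first_optimize_op_index = ops.index('check_finite_and_unscale')
len_of_ops = len(ops)

# step 1: reduce within the sharding group
insert_ops(ops, first_optimize_op_index, ['c_reduce_sum', 'c_reduce_sum'])
first_optimize_op_index += len(ops) - len_of_ops   # account for the ops just inserted
len_of_ops = len(ops)

# step 2: allreduce across the dp group, still ahead of check_finite_and_unscale
insert_ops(ops, first_optimize_op_index, ['c_allreduce_sum'])
print(ops)
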
......@@ -4528,7 +4528,7 @@ class PipelineOptimizer(object):
op._rename_input(old_name, new_name)
op._rename_output(old_name, new_name)
def _create_var(self, block, ref_var, name):
def _create_var(self, block, ref_var, name, dtype=None):
"""
Create a new var for block, which has the same type,
shape and dtype as ref_var, then rename it with the
......@@ -4537,7 +4537,7 @@ class PipelineOptimizer(object):
new_var = block.create_var(
name=name,
shape=ref_var.shape,
dtype=ref_var.dtype,
dtype=ref_var.dtype if dtype is None else dtype,
type=ref_var.type,
lod_level=ref_var.lod_level,
persistable=ref_var.persistable,
......@@ -5044,7 +5044,10 @@ class PipelineOptimizer(object):
new_grad_name = name + "@MERGED"
self._rename_arg(op, name, new_grad_name)
def _accumulate_gradients(self, block, pp_allreduce_in_optimize=False):
def _accumulate_gradients(self,
block,
pp_allreduce_in_optimize=False,
fp16_allreduce=False):
"""
Create a new merged gradient for each parameter and accumulate the
corresponding gradient to it.
......@@ -5052,6 +5055,9 @@ class PipelineOptimizer(object):
merged_gradient_names = []
first_opt_op_idx = None
merged_suffix = '@MERGED@FP16' if fp16_allreduce else '@MERGED'
dtype = paddle.float16 if fp16_allreduce else None
for index, op in reversed(tuple(enumerate(list(block.ops)))):
# remove the cast op of fp16 grad to fp32 grad
if self._is_optimize_op(op) and op.type == 'cast':
......@@ -5062,12 +5068,10 @@ class PipelineOptimizer(object):
block._remove_op(index)
continue
if self._is_backward_op(op) and not first_opt_op_idx:
if self._is_backward_op(op) and first_opt_op_idx is None:
first_opt_op_idx = index + 1
# no optimize phase
if first_opt_op_idx == len(block.ops): return
if block.ops[first_opt_op_idx].type == "c_sync_comm_stream":
first_opt_op_idx += 1
if self._is_backward_op(op) and (
self._op_role_var_key in op.attr_names):
......@@ -5079,12 +5083,14 @@ class PipelineOptimizer(object):
param_name = op_role_var[i]
if not block.has_var(param_name): continue
if '@BroadCast' in param_name: continue
param_grad_name = param_name + core.grad_var_suffix()
merged_param_grad_name = param_grad_name + '@MERGED'
merged_param_grad_name = param_grad_name + merged_suffix
if not block.has_var(merged_param_grad_name):
self._create_var(block, block.vars[param_name],
merged_param_grad_name)
merged_param_grad_name, dtype)
assert block.has_var(merged_param_grad_name)
param_grad_var = block.var(param_grad_name)
merged_param_grad_var = block.var(merged_param_grad_name)
merged_param_grad_var.persistable = True
......@@ -5103,22 +5109,18 @@ class PipelineOptimizer(object):
offset += 1
grad_name = op_role_var[i + 1]
grad_var = block.vars[grad_name]
if not 'cast_fp16' in grad_name:
block._insert_op(
index=first_opt_op_idx + offset,
type='sum',
inputs={'X': [grad_var, merged_param_grad_var]},
outputs={'Out': merged_param_grad_var},
attrs={
self._op_role_key: self._op_role.Backward,
})
offset += 1
merged_gradient_names.append(merged_param_grad_name)
else:
# cast gradient to fp32 to accumulate to merged gradient
is_fp16_grad = 'cast_fp16' in grad_name
need_cast = (is_fp16_grad is not fp16_allreduce)
if need_cast:
# if fp16_allreduce:
# cast grad to fp16 to accumulate into the merged gradient
# else:
# cast grad to fp32 to accumulate into the merged gradient
cast_grad_var_name = param_grad_name + '@TMP'
cast_grad_var = self._create_var(block, param_grad_var,
cast_grad_var_name)
cast_grad_var = self._create_var(
block, param_grad_var, cast_grad_var_name, dtype)
cast_grad_var.persistable = False
block._insert_op(
index=first_opt_op_idx + offset,
......@@ -5131,18 +5133,52 @@ class PipelineOptimizer(object):
self._op_role_key: self._op_role.Backward,
})
offset += 1
block._insert_op(
index=first_opt_op_idx + offset,
type='sum',
inputs={
'X': [merged_param_grad_var, cast_grad_var]
},
outputs={'Out': merged_param_grad_var},
attrs={
self._op_role_key: self._op_role.Backward,
})
offset += 1
merged_gradient_names.append(merged_param_grad_name)
grad_var = cast_grad_var
block._insert_op(
index=first_opt_op_idx + offset,
type='sum',
inputs={'X': [merged_param_grad_var, grad_var]},
outputs={'Out': merged_param_grad_var},
attrs={self._op_role_key: self._op_role.Backward, })
offset += 1
merged_gradient_names.append(merged_param_grad_name)
if not fp16_allreduce: return merged_gradient_names
first_opt_op_idx = None
for index, op in reversed(tuple(enumerate(list(block.ops)))):
if self._is_backward_op(op) and first_opt_op_idx is None:
first_opt_op_idx = index + 1
break
assert first_opt_op_idx is not None
# insert cast op from fp16->fp32
# FIXME(wangxi): it may be better to put this in sharding, since some grads
# are not on this sharding device.
for fp16_grad_name in merged_gradient_names:
grad_name = fp16_grad_name.replace('@FP16', '')
param_name = fp16_grad_name.replace('@GRAD@MERGED@FP16', '')
if not block.has_var(grad_name):
self._create_var(block, block.vars[param_name], grad_name)
assert block.has_var(grad_name)
fp16_grad_var = block.var(fp16_grad_name)
grad_var = block.var(grad_name)
grad_var.persistable = False
block._insert_op(
index=first_opt_op_idx,
type='cast',
inputs={'X': fp16_grad_var},
outputs={'Out': grad_var},
attrs={
'in_dtype': fp16_grad_var.dtype,
'out_dtype': grad_var.dtype,
self._op_role_key: self._op_role.Optimize,
})
return merged_gradient_names
def _add_sub_blocks(self, main_block, program_list):
......
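
The accumulation loop above casts an incoming micro-batch gradient only when its dtype disagrees with the merged buffer: fp16 grads are cast up when accumulating in fp32, and fp32 grads are cast down when fp16_allreduce is on, which is exactly what need_cast = (is_fp16_grad is not fp16_allreduce) expresses. A quick enumeration of that decision:

# Enumerate the cast decision used in _accumulate_gradients above.
for fp16_allreduce in (False, True):
    merged_dtype = 'fp16' if fp16_allreduce else 'fp32'
    for is_fp16_grad in (False, True):
        grad_dtype = 'fp16' if is_fp16_grad else 'fp32'
        need_cast = (is_fp16_grad is not fp16_allreduce)
        print("merged={} grad={} -> cast before sum: {}".format(
            merged_dtype, grad_dtype, need_cast))
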
......@@ -552,9 +552,9 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
'c_sync_calc_stream', 'c_reduce_sum', 'c_reduce_sum',
'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream',
'c_sync_comm_stream', 'fill_constant', 'sum', 'fill_constant',
'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant',
'sum', 'fill_constant', 'sum', 'fill_constant', 'sum',
'fill_constant', 'sum', 'momentum', 'momentum', 'momentum',
'c_sync_comm_stream', 'momentum', 'momentum', 'momentum',
'momentum', 'momentum'
])
......@@ -694,6 +694,171 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36002'])
def test_hybrid_with_pp_dp_amp_fp16allreduce(self):
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
)
avg_cost, strategy = self.pp_net(train_prog, startup_prog)
strategy.amp = True
strategy.amp_configs = {'custom_black_varnames': ['fc_6.b_0'], }
strategy.sharding = True
strategy.sharding_configs = {
"sharding_degree": 1,
"mp_degree": 1,
"pp_degree": 2,
"dp_degree": 2,
}
strategy.pipeline = True
strategy.pipeline_configs = {
"schedule_mode": "1F1B",
"micro_batch_size": 2,
"accumulate_steps": 4,
}
strategy.fp16_allreduce = True
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
train_prog = train_prog._pipeline_opt['section_program']
startup_prog = startup_prog._pipeline_opt['startup_program']
startup_prog_ops = startup_prog.global_block().ops
main_prog_ops = train_prog.global_block().ops
# check program
startup_prog_op_types = [op.type for op in startup_prog_ops]
main_prog_op_types = [op.type for op in main_prog_ops]
# ring: mp, pp_group, pp_pair, pp_pair
self.assertEqual(startup_prog_op_types, [
'uniform_random', 'fill_constant', 'uniform_random',
'fill_constant', 'uniform_random', 'fill_constant',
'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_sync_comm_stream'
])
self.assertEqual(main_prog_op_types, [
'recv_v2', 'cast', 'mul', 'cast', 'elementwise_add', 'tanh', 'cast',
'mul', 'cast', 'elementwise_add', 'tanh', 'cast', 'mul', 'cast',
'elementwise_add', 'tanh', 'cast', 'mul', 'cast', 'elementwise_add',
'softmax', 'cross_entropy2', 'mean', 'elementwise_mul',
'fill_constant', 'scale', 'scale', 'elementwise_mul_grad',
'mean_grad', 'cross_entropy_grad2', 'softmax_grad',
'elementwise_add_grad', 'cast', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'send_v2',
'fill_constant', 'cast', 'sum', 'fill_constant', 'sum',
'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant',
'sum', 'fill_constant', 'sum', 'fill_constant', 'sum',
'fill_constant', 'sum', 'coalesce_tensor', 'c_allreduce_sum',
'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast',
'c_sync_comm_stream', 'check_finite_and_unscale', 'cast',
'c_allreduce_max', 'cast', 'update_loss_scaling', 'momentum',
'momentum', 'momentum', 'momentum', 'momentum', 'momentum',
'momentum', 'momentum'
])
# amp check_finite_and_unscale, allreduce(pp)
self.assertEqual(main_prog_op_types.count('c_allreduce_max'), 1)
# should have a ring id for pp
created_ring_ids = [
op.desc.attr("ring_id") for op in startup_prog_ops
if op.type == "c_comm_init"
]
self.assertIn(self.pp_pair_ring_id, created_ring_ids)
self.assertIn(self.dp_ring_id, created_ring_ids)
# check correctness of pp group
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "comm_id_0":
pp_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36003'])
# check correctness of dp group
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "comm_id_3":
dp_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
def test_hybrid_with_sharding_pp_amp_fp16allreduce_in_optimize(self):
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
)
avg_cost, strategy = self.pp_net(train_prog, startup_prog)
strategy.amp = True
strategy.amp_configs = {'custom_black_varnames': ['fc_6.b_0'], }
strategy.sharding = True
strategy.sharding_configs = {
"segment_broadcast_MB": 0.1,
"sharding_degree": 2,
"mp_degree": 1,
"pp_degree": 2,
"dp_degree": 1,
'pp_allreduce_in_optimize': True,
}
strategy.pipeline = True
strategy.pipeline_configs = {
"schedule_mode": "1F1B",
"micro_batch_size": 2,
"accumulate_steps": 4,
}
strategy.fp16_allreduce = True
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
train_prog = train_prog._pipeline_opt['section_program']
startup_prog = startup_prog._pipeline_opt['startup_program']
startup_prog_ops = startup_prog.global_block().ops
main_prog_ops = train_prog.global_block().ops
# check program
startup_prog_op_types = [op.type for op in startup_prog_ops]
main_prog_op_types = [op.type for op in main_prog_ops]
# ring: sharding, pp_group, pp_pair, pp_pair
self.assertEqual(startup_prog_op_types, [
'fill_constant', 'uniform_random', 'fill_constant',
'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'c_gen_nccl_id',
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init'
])
# FIXME(wangxi): some bug in sharding+pp with pp_allreduce_in_optimize
# self.assertEqual(main_prog_op_types, [])
# amp check_finite_and_unscale, allreduce(pp)
self.assertEqual(main_prog_op_types.count('c_allreduce_max'), 2)
# should have a ring id for pp
created_ring_ids = [
op.desc.attr("ring_id") for op in startup_prog_ops
if op.type == "c_comm_init"
]
self.assertIn(self.sharding_ring_id, created_ring_ids)
self.assertIn(self.pp_pair_ring_id, created_ring_ids)
# check correctness of sharding group
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "comm_id_0":
sharding_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003'])
# check correctness of pp group
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "comm_id_1":
pp_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36002'])
if __name__ == "__main__":
unittest.main()