Unverified commit 01222f52, authored by wuhuachaocoding, committed via GitHub

[hybrid] Fix out of memory bug (#39009)

Parent: e43b6f65
@@ -135,6 +135,11 @@ class TestFleetMetaOptimizer(unittest.TestCase):
                 learning_rate=0.01,
                 regularization=regularization,
                 grad_clip=grad_clip)
+        elif name == 'adamw':
+            optimizer = paddle.optimizer.AdamW(
+                learning_rate=0.01,
+                weight_decay=0.01,
+                grad_clip=grad_clip)
         optimizer = fleet.distributed_optimizer(
             optimizer, strategy=strategy)
         optimizer.minimize(loss)
......
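The new `adamw` branch builds the optimizer with `paddle.optimizer.AdamW` and a global-norm gradient clip. As a point of reference, here is a minimal standalone sketch of the same construction run in dygraph mode with a toy layer; the `Linear` layer, the random input, and the `paddle.nn.ClipGradByGlobalNorm` spelling are illustrative choices, not part of the test, which routes the optimizer through `fleet.distributed_optimizer` under static graph.

```python
import paddle

# Toy dygraph setup; the test itself wraps the optimizer with
# fleet.distributed_optimizer and minimizes a static-graph program.
linear = paddle.nn.Linear(10, 10)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.AdamW(
    learning_rate=0.01,
    weight_decay=0.01,
    grad_clip=clip,
    parameters=linear.parameters())

loss = linear(paddle.rand([4, 10])).mean()
loss.backward()
optimizer.step()       # AdamW applies decoupled weight decay in this update
optimizer.clear_grad()
```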
@@ -771,6 +771,125 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
         self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36002'])
 
+    def test_hybrid_with_mp_pp_amp_gclip_for_optimizer(self):
+        train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
+        )
+        avg_cost, strategy = self.pp_net(train_prog, startup_prog)
+        self.set_strategy(strategy, 'amp')
+        strategy.sharding = True
+        strategy.sharding_configs = {
+            "sharding_degree": 1,
+            "mp_degree": 2,
+            "pp_degree": 2,
+            "dp_degree": 1,
+        }
+        strategy.pipeline = True
+        strategy.pipeline_configs = {
+            "schedule_mode": "1F1B",
+            "micro_batch_size": 2,
+            "accumulate_steps": 4,
+        }
+        clip = paddle.fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0)
+        self.optimizer(
+            avg_cost,
+            strategy,
+            train_prog,
+            startup_prog,
+            grad_clip=clip,
+            name="adamw")
+        train_prog = train_prog._pipeline_opt['section_program']
+        startup_prog = startup_prog._pipeline_opt['startup_program']
+        startup_prog_ops = startup_prog.global_block().ops
+        main_prog_ops = train_prog.global_block().ops
+
+        # check program
+        startup_prog_op_types = [op.type for op in startup_prog_ops]
+        main_prog_op_types = [op.type for op in main_prog_ops]
+
+        # ring: mp, pp_group, pp_pair, pp_pair
+        self.assertEqual(startup_prog_op_types, [
+            'uniform_random', 'fill_constant', 'uniform_random',
+            'fill_constant', 'uniform_random', 'fill_constant',
+            'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
+            'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
+            'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_broadcast',
+            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
+            'c_broadcast', 'c_broadcast'
+        ])
+        self.assertEqual(main_prog_op_types, [
+            'partial_recv', 'partial_allgather', 'cast', 'cast', 'mul', 'cast',
+            'elementwise_add', 'cast', 'tanh', 'cast', 'cast', 'mul', 'cast',
+            'elementwise_add', 'cast', 'tanh', 'cast', 'cast', 'mul', 'cast',
+            'elementwise_add', 'cast', 'tanh', 'cast', 'cast', 'mul', 'cast',
+            'elementwise_add', 'softmax', 'cast', 'cross_entropy2', 'mean',
+            'elementwise_mul', 'fill_constant', 'elementwise_mul_grad',
+            'mean_grad', 'cross_entropy_grad2', 'cast', 'softmax_grad',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'c_sync_calc_stream',
+            'partial_send', 'fill_constant', 'cast', 'sum', 'fill_constant',
+            'cast', 'sum', 'fill_constant', 'cast', 'sum', 'fill_constant',
+            'cast', 'sum', 'fill_constant', 'cast', 'sum', 'fill_constant',
+            'cast', 'sum', 'fill_constant', 'cast', 'sum', 'fill_constant',
+            'cast', 'sum', 'c_sync_comm_stream', 'check_finite_and_unscale',
+            'cast', 'c_allreduce_max', 'c_allreduce_max', 'cast',
+            'update_loss_scaling', 'memcpy', 'fill_constant', 'c_allreduce_sum',
+            'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max',
+            'elementwise_div', 'elementwise_mul', 'elementwise_mul',
+            'elementwise_mul', 'elementwise_mul', 'elementwise_mul',
+            'elementwise_mul', 'elementwise_mul', 'elementwise_mul', 'adamw',
+            'adamw', 'adamw', 'adamw', 'adamw', 'adamw', 'adamw', 'adamw'
+        ])
+
+        # pp + mp, partial send recv
+        self.assertIn('partial_recv', main_prog_op_types)
+        self.assertIn('partial_allgather', main_prog_op_types)
+        self.assertIn('partial_send', main_prog_op_types)
+
+        # amp check_finite_and_unscale, allreduce(mp)->allreduce(pp)
+        self.assertEqual(main_prog_op_types.count('c_allreduce_max'), 2)
+
+        # global gradient clip, allreduce(mp)->allreduce(pp)
+        self.assertEqual(main_prog_op_types.count('c_allreduce_sum'), 2)
+
+        # should has ring id for pp
+        created_ring_ids = [
+            op.desc.attr("ring_id") for op in startup_prog_ops
+            if op.type == "c_comm_init"
+        ]
+        self.assertIn(self.mp_ring_id, created_ring_ids)
+        self.assertIn(self.pp_pair_ring_id, created_ring_ids)
+
+        # check correctness of pp group
+        sharding_group_waiting_port = None
+        for op in startup_prog_ops:
+            if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
+                    0] == "comm_id_0":
+                mp_group_waiting_ports = op.desc.attr("other_endpoints")
+        self.assertEqual(mp_group_waiting_ports, ['127.0.0.1:36003'])
+
+        # check correctness of sharding group
+        sharding_group_waiting_port = None
+        for op in startup_prog_ops:
+            if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
+                    0] == "comm_id_1":
+                pp_group_waiting_ports = op.desc.attr("other_endpoints")
+        self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36002'])
+
+    def test_hybrid_with_pp_dp_amp_fp16allreduce(self):
         train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
         )
......
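The added test exercises AdamW under a hybrid mp=2 / pp=2 setup with AMP and global-norm clipping, then checks the generated startup and section programs op by op (communicator init rings, partial send/recv for the pipeline pair, and exactly two `c_allreduce_max` / `c_allreduce_sum` ops for the mp-then-pp reductions). For reference, the same `DistributedStrategy` configuration assembled on its own is sketched below; actually building a program with it requires a multi-rank `fleet` launch, which is assumed and not shown.

```python
import paddle.distributed.fleet as fleet

# Mirrors the strategy configured in the test: with sharding_degree=1 the
# sharding pass mainly carries the hybrid degrees (2-way model parallelism,
# a 2-stage 1F1B pipeline), and AMP is turned on as the test does through
# its set_strategy helper.
strategy = fleet.DistributedStrategy()
strategy.amp = True
strategy.sharding = True
strategy.sharding_configs = {
    "sharding_degree": 1,
    "mp_degree": 2,
    "pp_degree": 2,
    "dp_degree": 1,
}
strategy.pipeline = True
strategy.pipeline_configs = {
    "schedule_mode": "1F1B",      # one-forward-one-backward schedule
    "micro_batch_size": 2,
    "accumulate_steps": 4,
}
```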
@@ -722,6 +722,13 @@ class Optimizer(object):
                 self._append_optimize_multi_tensor_op(
                     target_block, parameters_and_grads)
             else:
+                if not framework.in_dygraph_mode():
+                    params_grads_device_map = parameters_and_grads[
+                        'params'] if isinstance(parameters_and_grads,
+                                                dict) else parameters_and_grads
+                    self._update_param_device_map(params_grads_device_map,
+                                                  target_block)
                 if isinstance(parameters_and_grads, list):
                     self._create_accumulators(target_block, [
                         p[0] for p in parameters_and_grads if not p[0].stop_gradient
@@ -757,11 +764,6 @@ class Optimizer(object):
                             self._append_optimize_op(target_block,
                                                      param_grad_dict)
                         else:
-                            params_grads_device_map = parameters_and_grads[
-                                'params'] if isinstance(parameters_and_grads,
-                                                        dict) else parameters_and_grads
-                            self._update_param_device_map(params_grads_device_map,
-                                                          target_block)
                             for param_and_grad in parameters_and_grads:
                                 if param_and_grad[1] is None:
                                     continue
......
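The `Optimizer` hunks above are the functional part of the fix: `_update_param_device_map` now runs once, before `_create_accumulators`, and only outside dygraph mode, instead of later in the static-graph branch removed by the second hunk. The diff itself does not spell out the OOM mechanism; the visible effect is simply that the parameter-to-device map is populated before any accumulators are created. Below is a self-contained stand-in sketch of that call ordering; it is not Paddle's `Optimizer`, and everything except the two method names and the mode guard is invented for illustration.

```python
# Stand-in class illustrating only the post-patch call ordering.
class ToyOptimizer:
    def __init__(self, static_graph=True):
        self.static_graph = static_graph      # stands in for "not in_dygraph_mode()"
        self._param_device_map = {}
        self._accumulators = {}

    def _update_param_device_map(self, parameters_and_grads):
        # Now called once, up front, and only in static-graph mode.
        for param, _grad in parameters_and_grads:
            self._param_device_map[param] = "gpu:0"   # placeholder device

    def _create_accumulators(self, params):
        # Accumulator creation can consult a device map that, before the
        # patch, was only filled in later in the apply path.
        for p in params:
            device = self._param_device_map.get(p, "gpu:0")
            self._accumulators[p] = {"moment1": 0.0, "moment2": 0.0,
                                     "device": device}

    def apply_gradients(self, parameters_and_grads):
        if self.static_graph:                 # the added mode guard
            self._update_param_device_map(parameters_and_grads)
        self._create_accumulators(
            [p for p, g in parameters_and_grads if g is not None])
        return self._accumulators


print(ToyOptimizer().apply_gradients([("w", 0.1), ("b", None)]))
```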