未验证 提交 a9673b44 编写于 作者: W WangXi 提交者: GitHub

[Hybrid Performance] Move the cast op of AMP which cast fp32 param to fp16...

[Hybrid Performance] Move the cast op of AMP which cast fp32 param to fp16 param to the optimizer (#34965)
上级 51939c83
......@@ -42,6 +42,7 @@ message ShardingConfig {
optional bool optimize_offload = 9 [ default = false ];
optional bool pp_allreduce_in_optimize = 10 [ default = false ];
optional int32 pp_degree = 11 [ default = 1 ];
optional bool optimize_cast = 12 [ default = false ];
}
message HybridConfig {
......
......@@ -888,6 +888,9 @@ class DistributedStrategy(object):
pp_allreduce_in_optimize(bool, optional): [Hybrid parallelism ONLY] move the allreduce operations from backward stage to update(optimize) stage when pipeline parallelsim is on.
This configuration will affect the communication speed of Hybrid parallelism training depeneded on network topology. this strategy is experimental by now.. Default is False.
optimize_cast(bool, optional): [Hybrid parallelism ONLY] Move the cast op of AMP which cast fp32 param to fp16 param to optimizer. optimize_cast will persist fp16 param, it
will take more memory, but will be faster, trade space for time. Recommend to turn on only when using pipeline or gradient_merge_acc_step large.
Examples:
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole
from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole, is_update_op
from paddle.fluid import core, unique_name
__all__ = []
......@@ -84,7 +84,7 @@ class OffloadHelper(object):
dtype=var.dtype,
persistable=True)
def offload_fp32param(self, block, startup_block):
def offload_fp32param(self, block, startup_block, offload=True):
"""
(p_fp16) = cast(p)
(p_fp16_recompute) = cast(p)
......@@ -113,11 +113,12 @@ class OffloadHelper(object):
# step1: record param
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in ('adam', 'momentum', 'lars', 'lamb'):
if is_update_op(op):
param = op.desc.input("Param")[0]
param_to_idx[param] = idx
# step2: remove param which can't offload
# step2: remove param which can't offload and
# record param->fp16param, fp16param->recompute_var
for idx, op in enumerate(block.ops):
if is_optimizer_op(op):
break
......@@ -125,7 +126,7 @@ class OffloadHelper(object):
if input_name not in param_to_idx:
continue
# param is real used by fp32 op
# param which will be used by fp32 op
if op.type != 'cast':
remove_param(input_name)
continue
......@@ -154,17 +155,19 @@ class OffloadHelper(object):
# step3: main_block add offload, cast op
# change recompute to fp16, remove cast(param) to fp16
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in ('adam', 'momentum', 'lars', 'lamb'):
if is_update_op(op):
param = op.desc.input("Param")[0]
if param not in param_to_idx: continue
# step3.1: create offload_var
offload_var_name = self._get_offload_var_name(param)
param_name_to_offload_name[param] = offload_var_name
self._create_offload_var(param, offload_var_name,
[block, startup_block])
if offload:
self._create_offload_var(param, offload_var_name,
[block, startup_block])
# step3.2: insert cast op and offload op
self._insert_offload_op(block, idx + 1, param, offload_var_name)
# step3.2: insert cast op and offload op
self._insert_offload_op(block, idx + 1, param,
offload_var_name)
assert param in param_to_fp16
fp16_param_name = param_to_fp16[param]
......@@ -173,8 +176,9 @@ class OffloadHelper(object):
self._insert_cast_op(block, idx + 1, param,
param_to_fp16[param])
# step3.3: insert fetch op
self._insert_fetch_op(block, idx, offload_var_name, param)
if offload:
# step3.3: insert fetch op
self._insert_fetch_op(block, idx, offload_var_name, param)
continue
# step3.4: remove cast op
......@@ -206,9 +210,10 @@ class OffloadHelper(object):
if out_name in param_name_to_offload_name:
var_name = out_name
offload_var_name = param_name_to_offload_name[var_name]
self._insert_offload_op(startup_block, idx + 1, var_name,
offload_var_name)
if offload:
offload_var_name = param_name_to_offload_name[var_name]
self._insert_offload_op(startup_block, idx + 1,
var_name, offload_var_name)
self._insert_cast_op(startup_block, idx + 1, var_name,
param_to_fp16[var_name])
......@@ -217,6 +222,19 @@ class OffloadHelper(object):
block._sync_with_cpp()
startup_block._sync_with_cpp()
def cast_fp32param_in_optimize(self, block, startup_block):
"""
(p_fp16) = cast(p)
(p_fp16_recompute) = cast(p)
(pout,) = adam(p)
===========================>
rename(p_fp16_recompute, p_fp16)
(pout,) = adam(p)
(p_fp16) = cast(p)
"""
self.offload_fp32param(block, startup_block, offload=False)
def offload(self, block, startup_block):
"""
(m1, m2) = prefetch(m1@offload, m2@offload)
......
......@@ -400,7 +400,14 @@ class ShardingOptimizer(MetaOptimizerBase):
logger.info("Sharding with optimize offload !")
offload_helper = OffloadHelper()
offload_helper.offload(main_block, startup_block)
# The optimize_cast is already included in offload_fp32param
offload_helper.offload_fp32param(main_block, startup_block)
elif sharding_configs['optimize_cast']:
logger.info("Sharding with optimize cast !")
# NOTE(wangxi): optimize_cast will persist fp16 param, it
# will take more memory, but will be faster. Trade space for time.
offload_helper = OffloadHelper()
offload_helper.cast_fp32param_in_optimize(main_block, startup_block)
def _dump_program_for_debug(self):
main_block = self._main_program.global_block()
......@@ -444,6 +451,7 @@ class ShardingOptimizer(MetaOptimizerBase):
# loss div dp_degree
self._insert_loss_grad_scale_op()
# apply optimize offload or optimize cast
self._apply_optimize_offload_pass()
# step6: (optional) sharding gradient merge
......
......@@ -859,6 +859,197 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36002'])
def test_hybrid_with_pp_dp_amp_fp16allreduce_optimize_cast(self):
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
)
avg_cost, strategy = self.pp_net(train_prog, startup_prog)
strategy.amp = True
strategy.amp_configs = {'custom_black_varnames': ['fc_6.b_0'], }
strategy.sharding = True
strategy.sharding_configs = {
"sharding_degree": 1,
"mp_degree": 1,
"pp_degree": 2,
"dp_degree": 2,
"optimize_cast": True,
}
strategy.pipeline = True
strategy.pipeline_configs = {
"schedule_mode": "1F1B",
"micro_batch_size": 2,
"accumulate_steps": 4,
}
strategy.fp16_allreduce = True
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
train_prog = train_prog._pipeline_opt['section_program']
startup_prog = startup_prog._pipeline_opt['startup_program']
startup_prog_ops = startup_prog.global_block().ops
main_prog_ops = train_prog.global_block().ops
# check program
startup_prog_op_types = [op.type for op in startup_prog_ops]
main_prog_op_types = [op.type for op in main_prog_ops]
# ring: mp, pp_group, pp_pair, pp_pair
self.assertEqual(startup_prog_op_types, [
'uniform_random', 'cast', 'fill_constant', 'cast', 'uniform_random',
'cast', 'fill_constant', 'cast', 'uniform_random', 'cast',
'fill_constant', 'cast', 'uniform_random', 'cast', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_sync_comm_stream'
])
self.assertEqual(main_prog_op_types, [
'recv_v2', 'mul', 'elementwise_add', 'tanh', 'mul',
'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', 'mul',
'cast', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
'elementwise_mul', 'fill_constant', 'scale', 'scale',
'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2',
'softmax_grad', 'elementwise_add_grad', 'cast', 'mul_grad',
'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'send_v2',
'fill_constant', 'cast', 'sum', 'fill_constant', 'sum',
'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant',
'sum', 'fill_constant', 'sum', 'fill_constant', 'sum',
'fill_constant', 'sum', 'coalesce_tensor', 'c_allreduce_sum',
'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast',
'c_sync_comm_stream', 'check_finite_and_unscale', 'cast',
'c_allreduce_max', 'cast', 'update_loss_scaling', 'momentum',
'cast', 'momentum', 'cast', 'momentum', 'cast', 'momentum', 'cast',
'momentum', 'cast', 'momentum', 'cast', 'momentum', 'momentum',
'cast'
])
# amp check_finite_and_unscale, allreduce(pp)
self.assertEqual(main_prog_op_types.count('c_allreduce_max'), 1)
# should has ring id for pp
created_ring_ids = [
op.desc.attr("ring_id") for op in startup_prog_ops
if op.type == "c_comm_init"
]
self.assertIn(self.pp_pair_ring_id, created_ring_ids)
self.assertIn(self.dp_ring_id, created_ring_ids)
# check correctness of pp group
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "comm_id_0":
pp_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36003'])
# check correctness of dp group
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "comm_id_3":
dp_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
def test_hybrid_with_pp_dp_amp_fp16allreduce_optimize_offload(self):
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
)
avg_cost, strategy = self.pp_net(train_prog, startup_prog)
strategy.amp = True
strategy.amp_configs = {'custom_black_varnames': ['fc_6.b_0'], }
strategy.sharding = True
strategy.sharding_configs = {
"sharding_degree": 1,
"mp_degree": 1,
"pp_degree": 2,
"dp_degree": 2,
"optimize_offload": True,
}
strategy.pipeline = True
strategy.pipeline_configs = {
"schedule_mode": "1F1B",
"micro_batch_size": 2,
"accumulate_steps": 4,
}
strategy.fp16_allreduce = True
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
train_prog = train_prog._pipeline_opt['section_program']
startup_prog = startup_prog._pipeline_opt['startup_program']
startup_prog_ops = startup_prog.global_block().ops
main_prog_ops = train_prog.global_block().ops
# check program
startup_prog_op_types = [op.type for op in startup_prog_ops]
main_prog_op_types = [op.type for op in main_prog_ops]
# ring: mp, pp_group, pp_pair, pp_pair
self.assertEqual(startup_prog_op_types, [
'uniform_random', 'cast', 'memcpy', 'fill_constant', 'cast',
'memcpy', 'uniform_random', 'cast', 'memcpy', 'fill_constant',
'cast', 'memcpy', 'uniform_random', 'cast', 'memcpy',
'fill_constant', 'cast', 'memcpy', 'uniform_random', 'cast',
'memcpy', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_sync_comm_stream'
])
self.assertEqual(main_prog_op_types, [
'recv_v2', 'mul', 'elementwise_add', 'tanh', 'mul',
'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', 'mul',
'cast', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
'elementwise_mul', 'fill_constant', 'scale', 'scale',
'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2',
'softmax_grad', 'elementwise_add_grad', 'cast', 'mul_grad',
'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'send_v2',
'fill_constant', 'cast', 'sum', 'fill_constant', 'sum',
'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant',
'sum', 'fill_constant', 'sum', 'fill_constant', 'sum',
'fill_constant', 'sum', 'coalesce_tensor', 'c_allreduce_sum',
'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast',
'c_sync_comm_stream', 'check_finite_and_unscale', 'cast',
'c_allreduce_max', 'cast', 'update_loss_scaling', 'memcpy',
'momentum', 'cast', 'memcpy', 'memcpy', 'momentum', 'cast',
'memcpy', 'memcpy', 'momentum', 'cast', 'memcpy', 'memcpy',
'momentum', 'cast', 'memcpy', 'memcpy', 'momentum', 'cast',
'memcpy', 'memcpy', 'momentum', 'cast', 'memcpy', 'momentum',
'memcpy', 'momentum', 'cast', 'memcpy'
])
# amp check_finite_and_unscale, allreduce(pp)
self.assertEqual(main_prog_op_types.count('c_allreduce_max'), 1)
# should has ring id for pp
created_ring_ids = [
op.desc.attr("ring_id") for op in startup_prog_ops
if op.type == "c_comm_init"
]
self.assertIn(self.pp_pair_ring_id, created_ring_ids)
self.assertIn(self.dp_ring_id, created_ring_ids)
# check correctness of pp group
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "comm_id_0":
pp_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36003'])
# check correctness of dp group
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "comm_id_3":
dp_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册