Unverified commit a9673b44, authored by WangXi, committed by GitHub


[Hybrid Performance] Move the AMP cast op that casts the fp32 param to fp16 into the optimizer (#34965)
Parent: 51939c83
@@ -42,6 +42,7 @@ message ShardingConfig {
   optional bool optimize_offload = 9 [ default = false ];
   optional bool pp_allreduce_in_optimize = 10 [ default = false ];
   optional int32 pp_degree = 11 [ default = 1 ];
+  optional bool optimize_cast = 12 [ default = false ];
 }
 message HybridConfig {
...
@@ -888,6 +888,9 @@ class DistributedStrategy(object):
             pp_allreduce_in_optimize(bool, optional): [Hybrid parallelism ONLY] move the allreduce operations from the backward stage to the update (optimize) stage when pipeline parallelism is on.
                               This configuration affects the communication speed of hybrid parallelism training, depending on the network topology. This strategy is experimental for now. Default is False.
+            optimize_cast(bool, optional): [Hybrid parallelism ONLY] Move the AMP cast op that casts the fp32 param to fp16 into the optimizer. optimize_cast keeps the fp16 param persistent:
+                              it takes more memory but runs faster, trading space for time. Recommended only when pipeline parallelism is on or gradient_merge_acc_step is large.

         Examples:
...
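For orientation, here is a minimal configuration sketch for the new switch, mirroring the unit test added later in this diff (model construction and optimizer setup are omitted; the API and field names below come from this diff and paddle.distributed.fleet):

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.amp = True
    strategy.sharding = True
    strategy.sharding_configs = {
        "sharding_degree": 1,
        "mp_degree": 1,
        "pp_degree": 2,
        "dp_degree": 2,
        # persist the fp16 params and move the AMP cast into the optimizer stage
        "optimize_cast": True,
    }
    strategy.pipeline = True
    strategy.pipeline_configs = {
        "schedule_mode": "1F1B",
        "micro_batch_size": 2,
        "accumulate_steps": 4,
    }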
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole
+from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole, is_update_op
 from paddle.fluid import core, unique_name
 __all__ = []

@@ -84,7 +84,7 @@ class OffloadHelper(object):
             dtype=var.dtype,
             persistable=True)

-    def offload_fp32param(self, block, startup_block):
+    def offload_fp32param(self, block, startup_block, offload=True):
         """
         (p_fp16) = cast(p)
         (p_fp16_recompute) = cast(p)

@@ -113,11 +113,12 @@ class OffloadHelper(object):
         # step1: record param
         for idx, op in reversed(list(enumerate(block.ops))):
-            if op.type in ('adam', 'momentum', 'lars', 'lamb'):
+            if is_update_op(op):
                 param = op.desc.input("Param")[0]
                 param_to_idx[param] = idx
-        # step2: remove param which can't offload
+        # step2: remove param which can't offload and
+        #        record param->fp16param, fp16param->recompute_var
         for idx, op in enumerate(block.ops):
             if is_optimizer_op(op):
                 break

@@ -125,7 +126,7 @@ class OffloadHelper(object):
                 if input_name not in param_to_idx:
                     continue
-                # param is real used by fp32 op
+                # param which will be used by fp32 op
                 if op.type != 'cast':
                     remove_param(input_name)
                     continue

@@ -154,17 +155,19 @@ class OffloadHelper(object):
         # step3: main_block add offload, cast op
         # change recompute to fp16, remove cast(param) to fp16
         for idx, op in reversed(list(enumerate(block.ops))):
-            if op.type in ('adam', 'momentum', 'lars', 'lamb'):
+            if is_update_op(op):
                 param = op.desc.input("Param")[0]
                 if param not in param_to_idx: continue
                 # step3.1: create offload_var
                 offload_var_name = self._get_offload_var_name(param)
                 param_name_to_offload_name[param] = offload_var_name
-                self._create_offload_var(param, offload_var_name,
-                                         [block, startup_block])
-                # step3.2: insert cast op and offload op
-                self._insert_offload_op(block, idx + 1, param, offload_var_name)
+                if offload:
+                    self._create_offload_var(param, offload_var_name,
+                                             [block, startup_block])
+                    # step3.2: insert cast op and offload op
+                    self._insert_offload_op(block, idx + 1, param,
+                                            offload_var_name)
                 assert param in param_to_fp16
                 fp16_param_name = param_to_fp16[param]

@@ -173,8 +176,9 @@ class OffloadHelper(object):
                     self._insert_cast_op(block, idx + 1, param,
                                          param_to_fp16[param])
-                # step3.3: insert fetch op
-                self._insert_fetch_op(block, idx, offload_var_name, param)
+                if offload:
+                    # step3.3: insert fetch op
+                    self._insert_fetch_op(block, idx, offload_var_name, param)
                 continue
             # step3.4: remove cast op

@@ -206,9 +210,10 @@ class OffloadHelper(object):
                 if out_name in param_name_to_offload_name:
                     var_name = out_name
-                    offload_var_name = param_name_to_offload_name[var_name]
-                    self._insert_offload_op(startup_block, idx + 1, var_name,
-                                            offload_var_name)
+                    if offload:
+                        offload_var_name = param_name_to_offload_name[var_name]
+                        self._insert_offload_op(startup_block, idx + 1,
+                                                var_name, offload_var_name)
                     self._insert_cast_op(startup_block, idx + 1, var_name,
                                          param_to_fp16[var_name])

@@ -217,6 +222,19 @@ class OffloadHelper(object):
         block._sync_with_cpp()
         startup_block._sync_with_cpp()

+    def cast_fp32param_in_optimize(self, block, startup_block):
+        """
+        (p_fp16) = cast(p)
+        (p_fp16_recompute) = cast(p)
+        (pout,) = adam(p)
+        ===========================>
+        rename(p_fp16_recompute, p_fp16)
+        (pout,) = adam(p)
+        (p_fp16) = cast(p)
+        """
+        self.offload_fp32param(block, startup_block, offload=False)
+
     def offload(self, block, startup_block):
         """
         (m1, m2) = prefetch(m1@offload, m2@offload)
...
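To make the rewrite above concrete, the following is a self-contained sketch on a toy op list (plain dicts standing in for Paddle ops, not the real Block/Op API; UPDATE_OPS and the function name are illustrative). It shows the essential effect of cast_fp32param_in_optimize: the cast that refreshes the persistent fp16 copy of an fp32 param is deferred until after that param's update op.

    # Toy ops: {"type": ..., "in": [...], "out": [...]}.  In the real pass, p_fp16 is made
    # persistable and initialized once in the startup program, so the forward-time cast can
    # be removed from the main program and a single cast after the update keeps p_fp16 fresh.
    UPDATE_OPS = {"momentum", "adam", "lamb", "lars"}

    def cast_fp32param_in_optimize(ops):
        """Move each cast(fp32 param -> fp16 param) to just after that param's update op."""
        updated_params = {op["in"][0] for op in ops if op["type"] in UPDATE_OPS}
        deferred, result = [], []
        for op in ops:
            if op["type"] == "cast" and op["in"][0] in updated_params:
                deferred.append(op)  # drop the early cast; re-emit it after the update
                continue
            result.append(op)
            if op["type"] in UPDATE_OPS:
                # re-insert the deferred cast so p_fp16 reflects the updated p
                result.extend(c for c in deferred if c["in"][0] == op["in"][0])
        return result

    ops = [
        {"type": "cast", "in": ["p"], "out": ["p_fp16"]},      # AMP cast ahead of forward
        {"type": "mul", "in": ["x", "p_fp16"], "out": ["y"]},  # fp16 compute
        {"type": "momentum", "in": ["p"], "out": ["p"]},       # fp32 master param update
    ]
    print([op["type"] for op in cast_fp32param_in_optimize(ops)])
    # ['mul', 'momentum', 'cast']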
@@ -400,7 +400,14 @@ class ShardingOptimizer(MetaOptimizerBase):
             logger.info("Sharding with optimize offload !")
             offload_helper = OffloadHelper()
             offload_helper.offload(main_block, startup_block)
+            # The optimize_cast is already included in offload_fp32param
             offload_helper.offload_fp32param(main_block, startup_block)
+        elif sharding_configs['optimize_cast']:
+            logger.info("Sharding with optimize cast !")
+            # NOTE(wangxi): optimize_cast will persist fp16 param, it
+            # will take more memory, but will be faster. Trade space for time.
+            offload_helper = OffloadHelper()
+            offload_helper.cast_fp32param_in_optimize(main_block, startup_block)

     def _dump_program_for_debug(self):
         main_block = self._main_program.global_block()

@@ -444,6 +451,7 @@ class ShardingOptimizer(MetaOptimizerBase):
         # loss div dp_degree
         self._insert_loss_grad_scale_op()

+        # apply optimize offload or optimize cast
         self._apply_optimize_offload_pass()

         # step6: (optional) sharding gradient merge
...
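As a plain-Python restatement of the branch above (illustrative only; the helper name is hypothetical, and the real method also constructs the OffloadHelper and rewrites the program blocks):

    def select_optimize_pass(sharding_configs):
        """Pick which post-optimizer rewrite applies for a given sharding config."""
        if sharding_configs.get("optimize_offload", False):
            # offload_fp32param already performs the cast-in-optimizer rewrite,
            # so optimize_offload subsumes optimize_cast
            return "offload_fp32param"
        if sharding_configs.get("optimize_cast", False):
            # keep fp16 params persistent: more memory, fewer per-step casts
            return "cast_fp32param_in_optimize"
        return None

    assert select_optimize_pass({"optimize_cast": True}) == "cast_fp32param_in_optimize"
    assert select_optimize_pass({"optimize_offload": True, "optimize_cast": True}) == "offload_fp32param"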
@@ -859,6 +859,197 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
             self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36002'])

+    def test_hybrid_with_pp_dp_amp_fp16allreduce_optimize_cast(self):
+        train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
+        )
+        avg_cost, strategy = self.pp_net(train_prog, startup_prog)
+        strategy.amp = True
+        strategy.amp_configs = {'custom_black_varnames': ['fc_6.b_0'], }
+        strategy.sharding = True
+        strategy.sharding_configs = {
+            "sharding_degree": 1,
+            "mp_degree": 1,
+            "pp_degree": 2,
+            "dp_degree": 2,
+            "optimize_cast": True,
+        }
+        strategy.pipeline = True
+        strategy.pipeline_configs = {
+            "schedule_mode": "1F1B",
+            "micro_batch_size": 2,
+            "accumulate_steps": 4,
+        }
+        strategy.fp16_allreduce = True
+        self.optimizer(avg_cost, strategy, train_prog, startup_prog)
+        train_prog = train_prog._pipeline_opt['section_program']
+        startup_prog = startup_prog._pipeline_opt['startup_program']
+
+        startup_prog_ops = startup_prog.global_block().ops
+        main_prog_ops = train_prog.global_block().ops
+
+        # check program
+        startup_prog_op_types = [op.type for op in startup_prog_ops]
+        main_prog_op_types = [op.type for op in main_prog_ops]
+
+        # ring: mp, pp_group, pp_pair, pp_pair
+        self.assertEqual(startup_prog_op_types, [
+            'uniform_random', 'cast', 'fill_constant', 'cast', 'uniform_random',
+            'cast', 'fill_constant', 'cast', 'uniform_random', 'cast',
+            'fill_constant', 'cast', 'uniform_random', 'cast', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
+            'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
+            'c_sync_comm_stream'
+        ])
+
+        self.assertEqual(main_prog_op_types, [
+            'recv_v2', 'mul', 'elementwise_add', 'tanh', 'mul',
+            'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', 'mul',
+            'cast', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
+            'elementwise_mul', 'fill_constant', 'scale', 'scale',
+            'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2',
+            'softmax_grad', 'elementwise_add_grad', 'cast', 'mul_grad',
+            'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'send_v2',
+            'fill_constant', 'cast', 'sum', 'fill_constant', 'sum',
+            'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant',
+            'sum', 'fill_constant', 'sum', 'fill_constant', 'sum',
+            'fill_constant', 'sum', 'coalesce_tensor', 'c_allreduce_sum',
+            'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast',
+            'c_sync_comm_stream', 'check_finite_and_unscale', 'cast',
+            'c_allreduce_max', 'cast', 'update_loss_scaling', 'momentum',
+            'cast', 'momentum', 'cast', 'momentum', 'cast', 'momentum', 'cast',
+            'momentum', 'cast', 'momentum', 'cast', 'momentum', 'momentum',
+            'cast'
+        ])
+
+        # amp check_finite_and_unscale, allreduce(pp)
+        self.assertEqual(main_prog_op_types.count('c_allreduce_max'), 1)
+
+        # should has ring id for pp
+        created_ring_ids = [
+            op.desc.attr("ring_id") for op in startup_prog_ops
+            if op.type == "c_comm_init"
+        ]
+        self.assertIn(self.pp_pair_ring_id, created_ring_ids)
+        self.assertIn(self.dp_ring_id, created_ring_ids)
+
+        # check correctness of pp group
+        for op in startup_prog_ops:
+            if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
+                    0] == "comm_id_0":
+                pp_group_waiting_ports = op.desc.attr("other_endpoints")
+        self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36003'])
+
+        # check correctness of dp group
+        for op in startup_prog_ops:
+            if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
+                    0] == "comm_id_3":
+                dp_group_waiting_ports = op.desc.attr("other_endpoints")
+        self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
+
+    def test_hybrid_with_pp_dp_amp_fp16allreduce_optimize_offload(self):
+        train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
+        )
+        avg_cost, strategy = self.pp_net(train_prog, startup_prog)
+        strategy.amp = True
+        strategy.amp_configs = {'custom_black_varnames': ['fc_6.b_0'], }
+        strategy.sharding = True
+        strategy.sharding_configs = {
+            "sharding_degree": 1,
+            "mp_degree": 1,
+            "pp_degree": 2,
+            "dp_degree": 2,
+            "optimize_offload": True,
+        }
+        strategy.pipeline = True
+        strategy.pipeline_configs = {
+            "schedule_mode": "1F1B",
+            "micro_batch_size": 2,
+            "accumulate_steps": 4,
+        }
+        strategy.fp16_allreduce = True
+        self.optimizer(avg_cost, strategy, train_prog, startup_prog)
+        train_prog = train_prog._pipeline_opt['section_program']
+        startup_prog = startup_prog._pipeline_opt['startup_program']
+
+        startup_prog_ops = startup_prog.global_block().ops
+        main_prog_ops = train_prog.global_block().ops
+
+        # check program
+        startup_prog_op_types = [op.type for op in startup_prog_ops]
+        main_prog_op_types = [op.type for op in main_prog_ops]
+
+        # ring: mp, pp_group, pp_pair, pp_pair
+        self.assertEqual(startup_prog_op_types, [
+            'uniform_random', 'cast', 'memcpy', 'fill_constant', 'cast',
+            'memcpy', 'uniform_random', 'cast', 'memcpy', 'fill_constant',
+            'cast', 'memcpy', 'uniform_random', 'cast', 'memcpy',
+            'fill_constant', 'cast', 'memcpy', 'uniform_random', 'cast',
+            'memcpy', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
+            'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
+            'c_gen_nccl_id', 'c_comm_init', 'c_sync_comm_stream'
+        ])
+
+        self.assertEqual(main_prog_op_types, [
+            'recv_v2', 'mul', 'elementwise_add', 'tanh', 'mul',
+            'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', 'mul',
+            'cast', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
+            'elementwise_mul', 'fill_constant', 'scale', 'scale',
+            'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2',
+            'softmax_grad', 'elementwise_add_grad', 'cast', 'mul_grad',
+            'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'send_v2',
+            'fill_constant', 'cast', 'sum', 'fill_constant', 'sum',
+            'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant',
+            'sum', 'fill_constant', 'sum', 'fill_constant', 'sum',
+            'fill_constant', 'sum', 'coalesce_tensor', 'c_allreduce_sum',
+            'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast',
+            'c_sync_comm_stream', 'check_finite_and_unscale', 'cast',
+            'c_allreduce_max', 'cast', 'update_loss_scaling', 'memcpy',
+            'momentum', 'cast', 'memcpy', 'memcpy', 'momentum', 'cast',
+            'memcpy', 'memcpy', 'momentum', 'cast', 'memcpy', 'memcpy',
+            'momentum', 'cast', 'memcpy', 'memcpy', 'momentum', 'cast',
+            'memcpy', 'memcpy', 'momentum', 'cast', 'memcpy', 'momentum',
+            'memcpy', 'momentum', 'cast', 'memcpy'
+        ])
+
+        # amp check_finite_and_unscale, allreduce(pp)
+        self.assertEqual(main_prog_op_types.count('c_allreduce_max'), 1)
+
+        # should has ring id for pp
+        created_ring_ids = [
+            op.desc.attr("ring_id") for op in startup_prog_ops
+            if op.type == "c_comm_init"
+        ]
+        self.assertIn(self.pp_pair_ring_id, created_ring_ids)
+        self.assertIn(self.dp_ring_id, created_ring_ids)
+
+        # check correctness of pp group
+        for op in startup_prog_ops:
+            if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
+                    0] == "comm_id_0":
+                pp_group_waiting_ports = op.desc.attr("other_endpoints")
+        self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36003'])
+
+        # check correctness of dp group
+        for op in startup_prog_ops:
+            if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
+                    0] == "comm_id_3":
+                dp_group_waiting_ports = op.desc.attr("other_endpoints")
+        self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
+
 if __name__ == "__main__":
     unittest.main()