未验证 提交 229befc8 编写于 作者: J JZ-LIANG 提交者: GitHub

[Auto Parallel] Data Parallel Comm & Calc Overlap Optimization (#45173)

* bugfix

* remove scaling

* support rescale_grad opt

* add unitest
上级 60e072d3
...@@ -189,8 +189,9 @@ class Engine: ...@@ -189,8 +189,9 @@ class Engine:
serial_main_prog = self._orig_main_prog.clone() serial_main_prog = self._orig_main_prog.clone()
serial_startup_prog = self._orig_startup_prog.clone() serial_startup_prog = self._orig_startup_prog.clone()
# FIXME to support grad clip # FIXME to support grad clip
with static.program_guard(serial_main_prog, serial_startup_prog), \ # with static.program_guard(serial_main_prog, serial_startup_prog), \
utils.unique_name.guard(): # utils.unique_name.guard():
with static.program_guard(serial_main_prog, serial_startup_prog):
inputs_spec = self.inputs_spec inputs_spec = self.inputs_spec
labels_spec = self.labels_spec if self.labels_spec else [] labels_spec = self.labels_spec if self.labels_spec else []
inputs = [s._create_feed_layer() for s in inputs_spec] inputs = [s._create_feed_layer() for s in inputs_spec]
...@@ -440,6 +441,7 @@ class Engine: ...@@ -440,6 +441,7 @@ class Engine:
for epoch in range(epochs): for epoch in range(epochs):
train_logs = {"epoch: {:d} ": epoch} train_logs = {"epoch: {:d} ": epoch}
for step, _ in enumerate(train_dataloader): for step, _ in enumerate(train_dataloader):
outs = self._executor.run(self.main_program, outs = self._executor.run(self.main_program,
fetch_list=fetch_list, fetch_list=fetch_list,
use_program_cache=use_cache, use_program_cache=use_cache,
......
...@@ -16,6 +16,7 @@ from collections import OrderedDict ...@@ -16,6 +16,7 @@ from collections import OrderedDict
import paddle import paddle
from paddle.fluid.framework import default_main_program from paddle.fluid.framework import default_main_program
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.distributed.auto_parallel.operators.common import is_data_parallel_scale_op, is_data_parallel_reduce_op from paddle.distributed.auto_parallel.operators.common import is_data_parallel_scale_op, is_data_parallel_reduce_op
from paddle.distributed.auto_parallel.utils import is_loss_grad_op, is_optimize_op, ring_id_to_process_group from paddle.distributed.auto_parallel.utils import is_loss_grad_op, is_optimize_op, ring_id_to_process_group
from .pass_base import PassBase, PassType, register_pass from .pass_base import PassBase, PassType, register_pass
...@@ -26,6 +27,9 @@ __rescale_grad_supported_opts__ = [ ...@@ -26,6 +27,9 @@ __rescale_grad_supported_opts__ = [
'merge_momentum' 'merge_momentum'
] ]
# a heuristic number
__max_stream_num_allow__ = 16
@register_pass("auto_parallel_data_parallel_optimization") @register_pass("auto_parallel_data_parallel_optimization")
class DataParallelOptimizationPass(PassBase): class DataParallelOptimizationPass(PassBase):
...@@ -71,7 +75,7 @@ class DataParallelOptimizationPass(PassBase): ...@@ -71,7 +75,7 @@ class DataParallelOptimizationPass(PassBase):
with paddle.static.program_guard(main_program, startup_program): with paddle.static.program_guard(main_program, startup_program):
self._analyze_program() self._analyze_program()
self._prune_grad_scaling() self._prune_grad_scaling()
self._overlap_comm() self._calc_comm_overlap()
self._fuse_allreduce() self._fuse_allreduce()
def _prune_grad_scaling(self): def _prune_grad_scaling(self):
...@@ -86,14 +90,18 @@ class DataParallelOptimizationPass(PassBase): ...@@ -86,14 +90,18 @@ class DataParallelOptimizationPass(PassBase):
self._remove_grad_scaling() self._remove_grad_scaling()
def _overlap_comm(self): def _calc_comm_overlap(self):
pass if not self._could_be_overlap():
return
self._calc_overlap_comms()
self._update_wait_comms()
def _fuse_allreduce(self): def _fuse_allreduce(self):
pass pass
def _analyze_program(self): def _analyze_program(self):
""" """
build two maps
{param_grad_name: data_parallel_group} {param_grad_name: data_parallel_group}
{pdata_parallel_group: aram_grad_name} {pdata_parallel_group: aram_grad_name}
""" """
...@@ -103,8 +111,9 @@ class DataParallelOptimizationPass(PassBase): ...@@ -103,8 +111,9 @@ class DataParallelOptimizationPass(PassBase):
scaled_grads = [] scaled_grads = []
for op in ops: for op in ops:
grad_name = op.output_arg_names[0]
if is_data_parallel_reduce_op(op): if is_data_parallel_reduce_op(op):
grad_name = op.output_arg_names[0]
if grad_name in self._grad_name_to_group_map: if grad_name in self._grad_name_to_group_map:
continue continue
assert op.has_attr( assert op.has_attr(
...@@ -123,7 +132,6 @@ class DataParallelOptimizationPass(PassBase): ...@@ -123,7 +132,6 @@ class DataParallelOptimizationPass(PassBase):
self._group_to_grad_name_map[group].append(grad_name) self._group_to_grad_name_map[group].append(grad_name)
elif is_data_parallel_scale_op(op): elif is_data_parallel_scale_op(op):
grad_name = op.output_arg_names[0]
scaled_grads.append(grad_name) scaled_grads.append(grad_name)
# TODO support multiple optimizers in on network in future. # TODO support multiple optimizers in on network in future.
...@@ -206,3 +214,66 @@ class DataParallelOptimizationPass(PassBase): ...@@ -206,3 +214,66 @@ class DataParallelOptimizationPass(PassBase):
assert scaled_grads == set(self._grad_name_to_group_map.keys( assert scaled_grads == set(self._grad_name_to_group_map.keys(
)), "Unexception: gradients [{}] are unscaled.".format( )), "Unexception: gradients [{}] are unscaled.".format(
set(self._grad_name_to_group_map.keys()) - scaled_grads) set(self._grad_name_to_group_map.keys()) - scaled_grads)
def _could_be_overlap(self):
# NOTE current different nccl comm will use different cuda stream
# so if there too many dp group there will be too many stream need to be
# created and sync.
# revise here when framework support custom stream in static mode.
num_dp_comm_stream = len(set(self._group_to_grad_name_map.keys()))
if num_dp_comm_stream > __max_stream_num_allow__:
return False
return True
def _calc_overlap_comms(self):
# TODO support InterpreterCore executor for overlap.
# InterpreterCore has a different logic for overlapping
# which is different from use_calc_stream
block = default_main_program().global_block()
ops = block.ops
# comm wait calc to finish
for idx, op in reversed(list(enumerate(block.ops))):
if is_data_parallel_reduce_op(op):
assert op.has_attr('use_calc_stream')
assert op.has_attr('ring_id')
op._set_attr('use_calc_stream', False)
ring_id = op.attr("ring_id")
block._insert_op_without_sync(idx,
type='c_wait_compute',
inputs={'X': []},
outputs={'Out': []},
attrs={
'op_role': OpRole.Backward,
'ring_id': ring_id
})
block._sync_with_cpp()
def _update_wait_comms(self):
block = default_main_program().global_block()
ops = block.ops
# update wait comm to finish
first_optimize_op_idx = -1
for idx, op in enumerate(ops):
if is_optimize_op(op):
first_optimize_op_idx = idx
break
assert first_optimize_op_idx > -1, "Unexception: not found optimizer op in program"
for group in self._group_to_grad_name_map.keys():
ring_id = group.id
block._insert_op_without_sync(first_optimize_op_idx,
type='c_wait_comm',
inputs={'X': []},
outputs={'Out': []},
attrs={
'op_role': OpRole.Backward,
'ring_id': ring_id
})
...@@ -542,9 +542,12 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None): ...@@ -542,9 +542,12 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None):
fp16_var_names = to_fp16_var_names if to_fp16_var_names else set() fp16_var_names = to_fp16_var_names if to_fp16_var_names else set()
var_scope = scope if scope else global_scope() var_scope = scope if scope else global_scope()
print(
"======================cast_parameters_to_fp16=============================="
)
for param in all_parameters: for param in all_parameters:
if param.name in fp16_var_names: if param.name in fp16_var_names:
_logger.debug("---- cast {} to fp16 dtype ----".format(param.name)) print("---- cast {} to fp16 dtype ----".format(param.name))
param_t = var_scope.find_var(param.name).get_tensor() param_t = var_scope.find_var(param.name).get_tensor()
data = np.array(param_t) data = np.array(param_t)
param_t.set(np.float16(data), place) param_t.set(np.float16(data), place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册