Unverified commit 201d99d6, authored by JZ-LIANG, committed by GitHub

[Auto Parallel] Gradient Fuse Allreduce (#45643)

* bugfix (#45332)

* dist embedding support lookup table v1

* add unit test

* customize wait_comm

* group gradients

* bugfix

* update program
Parent 5829069d
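For orientation, below is a minimal, hypothetical sketch of the grouping idea this PR implements: consecutive data-parallel gradients are greedily packed into groups bounded by a maximum element count, and each group is later replaced by a single coalesced allreduce. The helper name and sizes here are illustrative only and are not part of the commit.

    import numpy as np

    def group_by_numel(grad_shapes, max_fuse_numel):
        # Greedily pack consecutive gradients until the element-count budget is hit.
        groups, cur, cur_numel = [], [], 0
        for shape in grad_shapes:
            n = int(np.prod(shape))
            if cur and cur_numel + n > max_fuse_numel:
                groups.append(cur)
                cur, cur_numel = [], 0
            cur.append(shape)
            cur_numel += n
        if cur:
            groups.append(cur)
        return groups

    # six 1M-element gradients with a 3M budget -> two groups of three
    print(group_by_numel([(1024, 1024)] * 6, max_fuse_numel=3 * 1024 * 1024))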
@@ -354,40 +354,6 @@ class Engine:
prune_startup_prog = dist_startup_prog._prune(uninitialized)
self._executor.run(prune_startup_prog)
if self.strategy.amp and self.strategy.amp_configs['use_pure_fp16']:
# from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_parameters_to_fp16
def cast_parameters_to_fp16(place,
program,
scope=None,
to_fp16_var_names=None):
"""
Traverse all parameters in the whole model and set them to the FP16 data type.
However, this function will keep the parameters of batchnorm layers in FP32.
Args:
place(fluid.CPUPlace|fluid.CUDAPlace): `place` is used to restore the FP16 weight tensors.
program (Program): The used program.
scope(fluid.Scope, optional): `scope` is used to get the FP32 weight tensor values.
Default is None.
to_fp16_var_names(set|list, optional): The data types of vars in `to_fp16_var_names`
will be set to FP16. Usually, it is the returned
value of `cast_model_to_fp16` API.
"""
from paddle.framework import core
import numpy as np
all_parameters = []
for block in program.blocks:
all_parameters.extend(block.all_parameters())
var_scope = scope if scope else paddle.static.global_scope()
for param in all_parameters:
if param.dtype == core.VarDesc.VarType.FP16:
param_t = var_scope.find_var(
param.name).get_tensor()
data = np.array(param_t)
param_t.set(np.float16(data), place)
cast_parameters_to_fp16(place, prune_startup_prog)
def fit(self,
train_data,
batch_size=1,
...
@@ -1504,3 +1504,15 @@ def ring_id_to_process_group(ring_id):
if g.id == ring_id:
return g
return None
def find_higher_order_backward_op(program):
higher_order_op_suffix = ['_grad_grad', 'triple_grad']
for block in program.blocks:
for op in block.ops:
for suffix in higher_order_op_suffix:
if suffix in op.type:
return True
return False
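A brief usage sketch for the new helper (hypothetical; it mirrors how _could_be_fuse in the pass below uses it): callers skip gradient fusion when the program contains higher-order backward ops.

    import paddle
    from paddle.distributed.auto_parallel.utils import find_higher_order_backward_op

    main_prog = paddle.static.default_main_program()
    if find_higher_order_backward_op(main_prog):
        # keep per-gradient allreduce; fusing could break _grad_grad dependencies
        pass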
@@ -314,7 +314,9 @@ class AMPState(object):
consume_op_attr.set_input_dist_attr(
cast_name, in_var_dist_attr)
else:
assert in_var.dtype == dst_dtype, "op [{}] expects input [{}] to be dtype [{}] but got [{}]. {}".format(
grad_op.type, in_name, dst_dtype, in_var.dtype,
str(grad_op))
for out_name in grad_op.output_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output(
...
@@ -13,12 +13,14 @@
# limitations under the License.
from collections import OrderedDict
import numpy as np
import paddle
from paddle.fluid import core, unique_name
from paddle.fluid.framework import default_main_program
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from paddle.distributed.auto_parallel.operators.common import is_data_parallel_scale_op, is_data_parallel_reduce_op
from paddle.distributed.auto_parallel.utils import is_loss_grad_op, is_optimize_op, is_backward_op, ring_id_to_process_group, find_higher_order_backward_op
from .pass_base import PassBase, PassType, register_pass
# add new optimizers supporting rescale_grad here
@@ -31,6 +33,10 @@ __rescale_grad_supported_opts__ = [
__max_stream_num_allow__ = 16
def numel(var):
return np.prod(list(var.shape))
@register_pass("auto_parallel_data_parallel_optimization") @register_pass("auto_parallel_data_parallel_optimization")
class DataParallelOptimizationPass(PassBase): class DataParallelOptimizationPass(PassBase):
""" """
...@@ -78,7 +84,9 @@ class DataParallelOptimizationPass(PassBase): ...@@ -78,7 +84,9 @@ class DataParallelOptimizationPass(PassBase):
self._analyze_program() self._analyze_program()
self._prune_grad_scaling() self._prune_grad_scaling()
self._calc_comm_overlap() self._calc_comm_overlap()
self._fuse_allreduce() grad_group = self._fuse_allreduce()
# self.summary(grad_group)
def _prune_grad_scaling(self): def _prune_grad_scaling(self):
...@@ -99,7 +107,19 @@ class DataParallelOptimizationPass(PassBase): ...@@ -99,7 +107,19 @@ class DataParallelOptimizationPass(PassBase):
self._calc_wait_comms() self._calc_wait_comms()
def _fuse_allreduce(self): def _fuse_allreduce(self):
pass
if not self._could_be_fuse():
return []
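# dump the program before and after the fusion rewrite to
# ./before_program.txt.<rank> and ./after_program.txt.<rank> for debugging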
with open('./before_program.txt.' + str(paddle.distributed.get_rank()),
'w') as f:
f.write(str(default_main_program()))
grad_group = self._group_grads()
self._update_program(grad_group)
with open('./after_program.txt.' + str(paddle.distributed.get_rank()),
'w') as f:
f.write(str(default_main_program()))
return grad_group
def _analyze_program(self):
"""
@@ -316,3 +336,247 @@ class DataParallelOptimizationPass(PassBase):
'op_role': OpRole.Backward,
'ring_id': ring_id
})
def _could_be_fuse(self):
# TODO support gradient fusion for higher order gradients.
# should analyse the dependencies of gradients in backward.
if find_higher_order_backward_op(default_main_program()):
return False
if self.use_sharding:
return False
return True
def _group_grads(self):
"""
conditions for gradients to be grouped:
1. group size < max_fuse_numel
2. same dp group
3. same dtype
4. dependency: the grad is NOT used by other ops within the group segment
gradients inside the same group will be fused into one coalesce tensor
"""
block = default_main_program().global_block()
ops = block.ops
# group individual grad vars
# TODO consider fusing gradients for sharding reduce
# TODO let the user set fuse_grad_size
# emb = 50000 * h, ffn = 8 * h * h, mha = 4 * h * h
h = 2048
ffn_numel = 2 * (4 * h) * h
mha_numel = 3 * h * h + h * h
max_fuse_numel = ffn_numel + mha_numel
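# e.g. with h = 2048: ffn_numel = 8 * h * h = 33,554,432 and mha_numel = 4 * h * h = 16,777,216,
# so max_fuse_numel = 50,331,648 elements per group (about 192 MiB in FP32)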
grad_groups = []
cur_group = GradientsGroup(ops, max_fuse_numel)
grouped_grad_names = set()
def collect_group(cur_group, grad_var, ring_id, i):
if len(cur_group.gradients) == 0:
cur_group = None
elif len(cur_group.gradients) == 1:
grouped_grad_names.remove(cur_group.gradients[0].name)
else:
cur_group.finalize()
grad_groups.append(cur_group)
new_group = GradientsGroup(ops, max_fuse_numel)
if grad_var:
new_group.add(grad_var, ring_id, i)
grouped_grad_names.add(grad_var.name)
return new_group
def op_depend_on_group(op, group):
vars_ = set(op.input_arg_names + op.output_arg_names)
grad_names = set([grad.name for grad in group.gradients])
return len(vars_.intersection(grad_names)) > 0
for i, op in enumerate(ops):
if is_data_parallel_reduce_op(op):
ring_id = op.attr("ring_id")
grad_name = op.output_arg_names[0]
grad_var = block.var(grad_name)
grad_numel = numel(grad_var)
if cur_group.acceptable(grad_var, ring_id):
assert grad_name not in grouped_grad_names
grouped_grad_names.add(grad_name)
cur_group.add(grad_var, ring_id, i)
else:
cur_group = collect_group(cur_group, grad_var, ring_id, i)
else:
if op_depend_on_group(op, cur_group):
cur_group = collect_group(cur_group, None, None, None)
# collect last group
collect_group(cur_group, None, None, None)
return grad_groups
def _update_program(self, grad_groups):
block = default_main_program().global_block()
remove_op_types = ['scale', 'c_allreduce_sum', 'c_wait_compute']
for i, group in enumerate(grad_groups[::-1]):
# create coalesce tensor
group.coalesce_var = block.create_var(name=unique_name.generate(
'coalesce_grad_{}'.format(i)),
dtype=group.dtype,
persistable=False,
stop_gradient=True)
# update allreduce & scale op
if group.scale_op_idx != -1:
scale_op = block.ops[group.scale_op_idx]
assert scale_op.type == 'scale', "expected scale op but found {}".format(
str(scale_op))
scale_op._rename_input(scale_op.input_arg_names[0],
group.coalesce_var.name)
scale_op._rename_output(scale_op.output_arg_names[0],
group.coalesce_var.name)
allreduce_op = block.ops[group.allreduce_op_idx]
assert allreduce_op.type == 'c_allreduce_sum', "expected c_allreduce_sum op but found {}".format(
str(allreduce_op))
allreduce_op._rename_input(allreduce_op.input_arg_names[0],
group.coalesce_var.name)
allreduce_op._rename_output(allreduce_op.output_arg_names[0],
group.coalesce_var.name)
# remove unused ops
remove_op_indices = group.remove_wait_op_indices + group.remove_allreduce_op_indices + group.remove_scale_op_indices
for idx in sorted(remove_op_indices, reverse=True):
assert block.ops[
idx].type in remove_op_types, "Unexpected: trying to remove op {}".format(
str(block.ops[idx]))
block._remove_op(idx)
# insert coalesce op
concated_shapes = []
concated_ranks = []
for grad_ in group.gradients:
shape = grad_.shape
concated_shapes.extend(shape)
concated_ranks.append(len(shape))
grad_names = [grad.name for grad in group.gradients]
block._insert_op_without_sync(group.coalesce_op_idx,
type="coalesce_tensor",
inputs={"Input": grad_names},
outputs={
"Output": grad_names,
"FusedOutput": group.coalesce_var
},
attrs={
"copy_data": False,
"use_align": True,
"dtype": group.dtype,
"concated_shapes":
concated_shapes,
"concated_ranks": concated_ranks,
OP_ROLE_KEY: OpRole.Backward
})
block._sync_with_cpp()
# TODO update dist attr
def summary(self, grad_groups=[]):
# TODO: add logger module
import logging
self._logger = logging.getLogger()
self._logger.propagate = False
if not self._logger.handlers:
self._logger.setLevel(logging.INFO)
log_handler = logging.StreamHandler()
log_format = logging.Formatter(
'[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
)
log_handler.setFormatter(log_format)
self._logger.addHandler(log_handler)
if len(grad_groups) > 0:
self._logger.info(
"origin {} allreduce ops are fused into {} coalecse allreduce ops."
.format(len(self._grad_name_to_group_map.keys()),
len(grad_groups)))
self._logger.info("gradient fusing group are following: ")
fused_grads = set()
for i, group in enumerate(grad_groups):
self._logger.info(
"coalecse gradient [{}] is composed by: {}".format(
i, [grad.name for grad in group.gradients]))
fused_grads.update([grad.name for grad in group.gradients])
individual_grads = set(
self._grad_name_to_group_map.keys()) - set(fused_grads)
self._logger.info(
"the following [{}] gradients are not fused: ".format(
len(individual_grads)))
self._logger.info("individual gradient {}".format(individual_grads))
class GradientsGroup(object):
def __init__(self, ops, max_group_size):
self.max_group_size = max_group_size
self.ops = ops
self.gradients = []
self.numel = 0
self.dtype = None
self.ring_id = None
self.coalesce_var = None
self.coalesce_op_idx = -1
self.allreduce_op_idx = -1
self.scale_op_idx = -1
self.remove_wait_op_indices = []
self.remove_allreduce_op_indices = []
self.remove_scale_op_indices = []
def acceptable(self, grad_var, ring_id):
if len(self.gradients) == 0:
return True
if ring_id != self.ring_id:
return False
if numel(grad_var) + self.numel > self.max_group_size:
return False
if grad_var.dtype != self.dtype:
return False
return True
def add(self, grad_var, ring_id, i):
self.gradients.append(grad_var)
self.ring_id = ring_id
self.dtype = grad_var.dtype
self.numel += numel(grad_var)
# remove the auxiliary ops of the non-fused dp allreduce
self.remove_allreduce_op_indices.append(i)
# NOTE this pass relies on the synchronization added by previous passes
# (same stream or calc_wait_comm & comm_wait_calc)
# to guarantee the correctness of the comm-calc execution order,
# so the calc_wait_comm should be kept.
grad_op_idx = i - 1
if i > 0 and self.ops[i - 1].type == 'c_wait_compute':
self.remove_wait_op_indices.append(i - 1)
grad_op_idx -= 1
if i + 1 < len(self.ops) and is_data_parallel_scale_op(self.ops[i + 1]):
self.remove_scale_op_indices.append(i + 1)
if len(self.gradients) == 1:
grad_op = self.ops[grad_op_idx]
assert grad_var.name in grad_op.output_arg_names, "grad [{}] should be output of {}".format(
grad_var.name, str(grad_op))
self.coalesce_op_idx = grad_op_idx
def finalize(self):
self.allreduce_op_idx = self.remove_allreduce_op_indices.pop()
if len(self.remove_wait_op_indices) > 1:
self.remove_wait_op_indices.pop()
if len(self.remove_scale_op_indices) > 1:
self.scale_op_idx = self.remove_scale_op_indices.pop()
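To make the effect of the coalesce_tensor rewrite concrete, here is a small standalone sketch (illustrative only, plain numpy instead of Paddle ops): the gradients are packed into one contiguous buffer, a single collective stand-in runs over that buffer, and the original tensors are recovered by slicing.

    import numpy as np

    grads = [np.ones((4, 8)), np.ones((16,)), np.ones((2, 2, 2))]

    # "coalesce": pack all gradients into one flat buffer, remembering their shapes
    shapes = [g.shape for g in grads]
    sizes = [g.size for g in grads]
    fused = np.concatenate([g.ravel() for g in grads])

    fused *= 2  # stand-in for the single c_allreduce_sum issued over the fused buffer

    # split the buffer back into tensors with the original shapes
    offsets = np.cumsum([0] + sizes[:-1])
    grads = [fused[o:o + s].reshape(shape)
             for o, s, shape in zip(offsets, sizes, shapes)]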
@@ -16,6 +16,7 @@ from collections import defaultdict
import paddle
from paddle.framework import core
from paddle.fluid.framework import default_main_program, default_startup_program
from paddle.fluid import unique_name
from .pass_base import register_pass
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type
@@ -536,6 +537,39 @@ def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"):
return output_var
def cast_startup_program():
main_program = default_main_program()
startup_program = default_startup_program()
param_to_dtype = {}
for block in main_program.blocks:
for p in block.all_parameters():
param_to_dtype[p.name] = p.dtype
def is_initialization_op(op):
comm_op_prefix = "c_"
op_type = op.type
if op_type.startswith(comm_op_prefix):
return False
if len(op.output_arg_names) != 1 and len(op.input_arg_names) != 0:
return False
return True
for op in startup_program.global_block().ops:
if is_initialization_op(op):
output_name = op.output_arg_names[0]
if param_to_dtype.get(output_name,
None) == core.VarDesc.VarType.FP16:
assert op.has_attr(
'dtype'
), "initialization op is expected to have a dtype attribute but got {}.".format(
str(op))
if op.attr('dtype') == core.VarDesc.VarType.FP32:
op._set_attr('dtype', core.VarDesc.VarType.FP16)
@register_pass("auto_parallel_fp16") @register_pass("auto_parallel_fp16")
class FP16Pass(AMPPass): class FP16Pass(AMPPass):
...@@ -563,6 +597,8 @@ class FP16Pass(AMPPass): ...@@ -563,6 +597,8 @@ class FP16Pass(AMPPass):
input_data_var_names) input_data_var_names)
is_train = fp16_state._build_state() is_train = fp16_state._build_state()
cast_startup_program()
if is_train: if is_train:
with paddle.static.program_guard(main_program, startup_program): with paddle.static.program_guard(main_program, startup_program):
# TODO (JZ-LIANG)support cast forward program only when inference # TODO (JZ-LIANG)support cast forward program only when inference
......