From ceef73c9b5876311e947505c9d04b5cb51cd73f0 Mon Sep 17 00:00:00 2001
From: JZ-LIANG
Date: Mon, 18 Apr 2022 11:32:08 +0800
Subject: [PATCH] [Auto parallel] Transformer MHA & FFN Fused Dist op (#41163)

* adapt dist op

* [Auto Parallel] Support the auto completion of while_op

* add dist_fill_constant_batch_size_like

* align infer accuracy
---
 .../distributed/auto_parallel/dist_loader.py  |  18 +-
 .../distributed/auto_parallel/engine.py       |  18 +-
 .../auto_parallel/operators/__init__.py       |   2 +
 .../operators/dist_fused_attention.py         | 211 ++++++++++++++++++
 .../operators/dist_fused_feedforward.py       | 203 +++++++++++++++++
 .../distributed/passes/auto_parallel_amp.py   |   1 +
 .../distributed/passes/auto_parallel_fp16.py  |  21 +-
 7 files changed, 460 insertions(+), 14 deletions(-)
 create mode 100644 python/paddle/distributed/auto_parallel/operators/dist_fused_attention.py
 create mode 100644 python/paddle/distributed/auto_parallel/operators/dist_fused_feedforward.py

diff --git a/python/paddle/distributed/auto_parallel/dist_loader.py b/python/paddle/distributed/auto_parallel/dist_loader.py
index 187c7cc028..9449b52952 100644
--- a/python/paddle/distributed/auto_parallel/dist_loader.py
+++ b/python/paddle/distributed/auto_parallel/dist_loader.py
@@ -97,15 +97,19 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
                 if not isinstance(data, list):
                     data = to_list(data)
 
-                if batch_data is None:
-                    batch_data = [[] for i in range(len(data))]
+                if self.batch_size == 1:
+                    yield data
+                    batch_data = None
+                else:
+                    if batch_data is None:
+                        batch_data = [[] for i in range(len(data))]
 
-                for idx in range(len(data)):
-                    batch_data[idx].append(data[idx])
+                    for idx in range(len(data)):
+                        batch_data[idx].append(data[idx])
 
-                if (step + 1) % self.batch_size == 0:
-                    yield batch_data
-                    batch_data = None
+                    if (step + 1) % self.batch_size == 0:
+                        yield batch_data
+                        batch_data = None
 
         dataloader = paddle.fluid.io.DataLoader.from_generator(
             feed_list=self.feed_list, capacity=70, iterable=False)
diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py
index c71ca9b7c6..a5fec789df 100644
--- a/python/paddle/distributed/auto_parallel/engine.py
+++ b/python/paddle/distributed/auto_parallel/engine.py
@@ -194,6 +194,9 @@ class Engine:
             self._apply_post_optimization(dist_main_prog, dist_startup_prog,
                                           rank, dist_params_grads)
         else:
+            # Apply pre optimization passes
+            self._apply_pre_optimization(serial_main_program,
+                                         serial_startup_program, None, None)
             # Do logical partition
             partitioner = Partitioner(dist_context, rank)
             dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
@@ -231,15 +234,24 @@ class Engine:
 
     def _apply_pre_optimization(self, main_program, startup_program, loss,
                                 params_grads):
+        # apply amp pass
         if self.strategy.amp:
             config = copy.deepcopy(self.strategy.amp_configs)
             config["dist_context"] = self._dist_contexts[self.mode]
             config["params_grads"] = params_grads
             config["loss"] = loss
-            auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
-            auto_parallel_amp_pass.apply([main_program], [startup_program],
-                                         self._pass_contexts[self.mode])
+            config["input_data"] = self._feed_vars[self.mode][
+                "inputs"] + self._feed_vars[self.mode]["labels"]
+            if config["use_pure_fp16"]:
+                config["base_opt"] = self._optimizer
+                auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
+                auto_parallel_fp16_pass.apply(
+                    [main_program], [startup_program], self._pass_context)
+            else:
+                auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
+                auto_parallel_amp_pass.apply([main_program], [startup_program],
+                                             self._pass_context)
 
         # apply recompute pass
         if self.strategy.recompute:
diff --git a/python/paddle/distributed/auto_parallel/operators/__init__.py b/python/paddle/distributed/auto_parallel/operators/__init__.py
index db6f909f8c..3c22974657 100644
--- a/python/paddle/distributed/auto_parallel/operators/__init__.py
+++ b/python/paddle/distributed/auto_parallel/operators/__init__.py
@@ -28,3 +28,5 @@ from . import dist_check_finite_and_unscale
 from . import dist_update_loss_scaling
 from . import dist_split
 from . import dist_fill_constant_batch_size_like
+from . import dist_fused_feedforward
+from . import dist_fused_attention
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_fused_attention.py b/python/paddle/distributed/auto_parallel/operators/dist_fused_attention.py
new file mode 100644
index 0000000000..bc3992ec03
--- /dev/null
+++ b/python/paddle/distributed/auto_parallel/operators/dist_fused_attention.py
@@ -0,0 +1,211 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .common import DistributedOperatorImplContainer
+from .common import DistributedOperatorImpl
+from .common import register_distributed_operator_impl_container
+from .common import register_distributed_operator_impl
+from ..utils import is_dim_shard, is_dim_replicate
+from ..utils import is_valid_list_index
+from ..utils import compute_compatible_dim_mapping
+from ..utils import compute_compatible_dims_mapping
+from ..utils import compute_compatible_and_update_dim_mapping
+from .dist_default import DistributedDefaultImpl0
+from ..utils import _get_comm_group, _get_corresponding_rank
+from ..process_group import new_process_group
+
+
+class DistributedFusedAttention(DistributedOperatorImplContainer):
+    def __init__(self, op_type):
+        super(DistributedFusedAttention, self).__init__(op_type)
+
+
+register_distributed_operator_impl_container(
+    DistributedFusedAttention("fused_attention"))
+
+
+class DistributedFusedAttentionImpl(DistributedOperatorImpl):
+    def __init__(self, name):
+        super(DistributedFusedAttentionImpl, self).__init__(name)
+        self._forward_implemented = True
+        self._backward_implemented = True
+
+    def is_input_compatible(self, dist_op):
+        op_desc = dist_op.serial_op.desc
+        op_dist_attr = dist_op.dist_attr
+        x_name = op_desc.input('X')[0]
+        qkv_w = op_desc.input('QKVW')[0]
+        qkv_bias = op_desc.input('QKVBias')[0]
+        out_w = op_desc.input('OutLinearW')[0]
+        out_bias = op_desc.input('OutLinearBias')[0]
+
+        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
+        qkv_w_dims_mapping = op_dist_attr.get_input_dims_mapping(qkv_w)
+        qkv_bias_dims_mapping = op_dist_attr.get_input_dims_mapping(qkv_bias)
+        out_w_dims_mapping = op_dist_attr.get_input_dims_mapping(out_w)
+        out_bias_dims_mapping = op_dist_attr.get_input_dims_mapping(out_bias)
+
+        head_axis = 1
+        for mapping in x_dims_mapping[1:-1]:
+            if is_dim_shard(mapping):
+                return False
+        if len(qkv_w_dims_mapping) != 4 or is_dim_replicate(qkv_w_dims_mapping[
+                head_axis]):
+            return False
+        if len(qkv_bias_dims_mapping) != 3 or is_dim_replicate(
+                qkv_bias_dims_mapping[head_axis]):
+            return False
+        if is_dim_replicate(out_w_dims_mapping[0]):
+            return False
+        if is_dim_shard(out_bias_dims_mapping[-1]):
+            return False
+
+        replicated_dims = [
+            qkv_w_dims_mapping[0], qkv_w_dims_mapping[-2],
+            qkv_w_dims_mapping[-1], qkv_bias_dims_mapping[0],
+            qkv_bias_dims_mapping[-1], out_w_dims_mapping[-1],
+            out_bias_dims_mapping[-1]
+        ]
+        for mapping in replicated_dims:
+            if is_dim_shard(mapping):
+                return False
+        if qkv_bias_dims_mapping[head_axis] != qkv_w_dims_mapping[head_axis]:
+            return False
+        if qkv_bias_dims_mapping[head_axis] != out_w_dims_mapping[0]:
+            return False
+
+        return True
+
+    def is_output_compatible(self, dist_op):
+        op_desc = dist_op.serial_op.desc
+        op_dist_attr = dist_op.dist_attr
+
+        # none of output should be sharded
+        for out_name in op_desc.output_names():
+            out = op_desc.output(out_name)[0]
+            out_dims_mapping = op_dist_attr.get_output_dims_mapping(out)
+            for mapping in out_dims_mapping[1:-1]:
+                if is_dim_shard(mapping):
+                    return False
+        return True
+
+    def is_auto_compatible(self, dist_op):
+        if (not self.is_input_compatible(dist_op)) or \
+            (not self.is_output_compatible(dist_op)):
+            return False
+
+        op_desc = dist_op.serial_op.desc
+        op_dist_attr = dist_op.dist_attr
+        x_name = op_desc.input('X')[0]
+        out_names = op_desc.output('Y')
+        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
+        for out_name in out_names:
+            out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
+            if x_dims_mapping != out_dims_mapping:
+                return False
+
+        return True
+
+    def update_dims_mapping(self, dist_op):
+        changed = False
+        op_desc = dist_op.serial_op.desc
+        op_dist_attr = dist_op.dist_attr
+        x_name = op_desc.input('X')[0]
+        out_names = op_desc.output('Y')
+        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
+
+        for out_name in out_names:
+            out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
+            for i in range(len(x_dims_mapping)):
+                dim_changed = compute_compatible_and_update_dim_mapping(
+                    [x_dims_mapping, out_dims_mapping], [i, i])
+                if dim_changed:
+                    changed = True
+
+        return changed
+
+    @staticmethod
+    def forward(ctx, *args, **kwargs):
+
+        dist_op_context = ctx.dist_op_context
+        main_block = dist_op_context.work_block
+        startup_block = dist_op_context.startup_block
+        src_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
+        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
+
+        if rank_id not in op_dist_attr.process_mesh.processes:
+            rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
+                                              rank_id)
+
+        # infer logic comm presentation
+        head_axis = 1
+        qkv_w = src_op.input('QKVW')[0]
+        qkv_w_col_dim_mapping = op_dist_attr.get_input_dims_mapping(qkv_w)[
+            head_axis]
+        assert qkv_w_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
+            qkv_w_col_dim_mapping)
+        process_mesh_shape = op_dist_attr.process_mesh.topology
+        process_mesh_group = op_dist_attr.process_mesh.processes
+
+        parallel_axis = qkv_w_col_dim_mapping
+        group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
+                                      parallel_axis, rank_id)
+        group = new_process_group(group_ranks)
+
+        # insert op
+        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
+
+        # setting comm id
+        new_op = main_block.ops[-1]
+        assert new_op.type == "fused_attention"
+        new_op._set_attr("ring_id", int(group.id))
+
+    @staticmethod
+    def backward(ctx, *args, **kwargs):
+        dist_op_context = ctx.dist_op_context
+        main_block = dist_op_context.work_block
+        startup_block = dist_op_context.startup_block
+        src_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
+        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
+
+        if rank_id not in op_dist_attr.process_mesh.processes:
+            rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
+                                              rank_id)
+
+        # infer logic comm presentation
+        out_w = src_op.input('OutLinearW')[0]
+        out_w_col_dim_mapping = op_dist_attr.get_input_dims_mapping(out_w)[-1]
+        assert out_w_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
+            out_w_col_dim_mapping)
+        process_mesh_shape = op_dist_attr.process_mesh.topology
+        process_mesh_group = op_dist_attr.process_mesh.processes
+
+        parallel_axis = out_w_col_dim_mapping
+        group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
+                                      parallel_axis, rank_id)
+        group = new_process_group(group_ranks)
+
+        # insert op
+        DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
+
+        # setting comm id
+        new_op = main_block.ops[-1]
+        assert new_op.type == "fused_attention_grad"
+        new_op._set_attr("ring_id", int(group.id))
+
+
+register_distributed_operator_impl(
+    "fused_attention", DistributedFusedAttentionImpl("tensor_parallel"))
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_fused_feedforward.py b/python/paddle/distributed/auto_parallel/operators/dist_fused_feedforward.py
new file mode 100644
index 0000000000..76f526adbb
--- /dev/null
+++ b/python/paddle/distributed/auto_parallel/operators/dist_fused_feedforward.py
@@ -0,0 +1,203 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .common import DistributedOperatorImplContainer
+from .common import DistributedOperatorImpl
+from .common import register_distributed_operator_impl_container
+from .common import register_distributed_operator_impl
+from ..utils import is_dim_shard, is_dim_replicate
+from ..utils import is_valid_list_index
+from ..utils import compute_compatible_dim_mapping
+from ..utils import compute_compatible_dims_mapping
+from ..utils import compute_compatible_and_update_dim_mapping
+from .dist_default import DistributedDefaultImpl0
+from ..utils import _get_comm_group, _get_corresponding_rank
+from ..process_group import new_process_group
+
+
+class DistributedFusedFeedForward(DistributedOperatorImplContainer):
+    def __init__(self, op_type):
+        super(DistributedFusedFeedForward, self).__init__(op_type)
+
+
+register_distributed_operator_impl_container(
+    DistributedFusedFeedForward("fused_feedforward"))
+
+
+class DistributedFusedFeedForwardImpl(DistributedOperatorImpl):
+    def __init__(self, name):
+        super(DistributedFusedFeedForwardImpl, self).__init__(name)
+        self._forward_implemented = True
+        self._backward_implemented = True
+
+    def is_input_compatible(self, dist_op):
+        op_desc = dist_op.serial_op.desc
+        op_dist_attr = dist_op.dist_attr
+        x_name = op_desc.input('X')[0]
+        linear1_weight = op_desc.input('Linear1Weight')[0]
+        linear1_bias = op_desc.input('Linear1Bias')[0]
+        linear2_weight = op_desc.input('Linear2Weight')[0]
+        linear2_bias = op_desc.input('Linear2Bias')[0]
+
+        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
+        linear1_weight_dims_mapping = op_dist_attr.get_input_dims_mapping(
+            linear1_weight)
+        linear1_bias_dims_mapping = op_dist_attr.get_input_dims_mapping(
+            linear1_bias)
+        linear2_weight_dims_mapping = op_dist_attr.get_input_dims_mapping(
+            linear2_weight)
+        linear2_bias_dims_mapping = op_dist_attr.get_input_dims_mapping(
+            linear2_bias)
+
+        for mapping in x_dims_mapping[1:-1]:
+            if is_dim_shard(mapping):
+                return False
+        if is_dim_shard(linear1_weight_dims_mapping[-2]) or is_dim_replicate(
+                linear1_weight_dims_mapping[-1]):
+            return False
+        if is_dim_replicate(linear1_bias_dims_mapping[-1]):
+            return False
+        if is_dim_replicate(linear2_weight_dims_mapping[-2]) or is_dim_shard(
+                linear2_weight_dims_mapping[-1]):
+            return False
+        if is_dim_shard(linear2_bias_dims_mapping[-1]):
+            return False
+        if linear1_weight_dims_mapping[-1] != linear2_weight_dims_mapping[-2]:
+            return False
+
+        return True
+
+    def is_output_compatible(self, dist_op):
+        op_desc = dist_op.serial_op.desc
+        op_dist_attr = dist_op.dist_attr
+
+        # none of output should be sharded
+        for out_name in op_desc.output_names():
+            out = op_desc.output(out_name)[0]
+            out_dims_mapping = op_dist_attr.get_output_dims_mapping(out)
+            for mapping in out_dims_mapping[1:-1]:
+                if is_dim_shard(mapping):
+                    return False
+        return True
+
+    def is_auto_compatible(self, dist_op):
+        if (not self.is_input_compatible(dist_op)) or \
+            (not self.is_output_compatible(dist_op)):
+            return False
+
+        op_desc = dist_op.serial_op.desc
+        op_dist_attr = dist_op.dist_attr
+        x_name = op_desc.input('X')[0]
+        out_names = op_desc.output('Out')
+        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
+        for out_name in out_names:
+            out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
+            if x_dims_mapping != out_dims_mapping:
+                return False
+
+        return True
+
+    def update_dims_mapping(self, dist_op):
+        changed = False
+        op_desc = dist_op.serial_op.desc
+        op_dist_attr = dist_op.dist_attr
+        x_name = op_desc.input('X')[0]
+        out_names = op_desc.output('Out')
+        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
+
+        for out_name in out_names:
+            out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
+            for i in range(len(x_dims_mapping)):
+                dim_changed = compute_compatible_and_update_dim_mapping(
+                    [x_dims_mapping, out_dims_mapping], [i, i])
+                if dim_changed:
+                    changed = True
+
+        return changed
+
+    @staticmethod
+    def forward(ctx, *args, **kwargs):
+
+        dist_op_context = ctx.dist_op_context
+        main_block = dist_op_context.work_block
+        startup_block = dist_op_context.startup_block
+        src_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
+        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
+
+        if rank_id not in op_dist_attr.process_mesh.processes:
+            rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
+                                              rank_id)
+
+        # infer logic comm presentation
+        linear1_weight = src_op.input('Linear1Weight')[0]
+        linear1_weight_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
+            linear1_weight)[-1]
+        assert linear1_weight_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
+            linear1_weight_col_dim_mapping)
+        process_mesh_shape = op_dist_attr.process_mesh.topology
+        process_mesh_group = op_dist_attr.process_mesh.processes
+
+        parallel_axis = linear1_weight_col_dim_mapping
+        group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
+                                      parallel_axis, rank_id)
+        group = new_process_group(group_ranks)
+
+        # insert op
+        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
+
+        # setting comm id
+        new_op = main_block.ops[-1]
+        assert new_op.type == "fused_feedforward"
+        new_op._set_attr("ring_id", int(group.id))
+
+    @staticmethod
+    def backward(ctx, *args, **kwargs):
+
+        dist_op_context = ctx.dist_op_context
+        main_block = dist_op_context.work_block
+        startup_block = dist_op_context.startup_block
+        src_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
+        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
+
+        if rank_id not in op_dist_attr.process_mesh.processes:
+            rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
+                                              rank_id)
+
+        # infer logic comm presentation
+        linear2_weight = src_op.input('Linear2Weight')[0]
+        linear2_weight_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
+            linear2_weight)[-1]
+        assert linear2_weight_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
+            linear2_weight_col_dim_mapping)
+        process_mesh_shape = op_dist_attr.process_mesh.topology
+        process_mesh_group = op_dist_attr.process_mesh.processes
+
+        parallel_axis = linear2_weight_col_dim_mapping
+        group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
+                                      parallel_axis, rank_id)
+        group = new_process_group(group_ranks)
+
+        # insert op
+        DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
+
+        # setting comm id
+        new_op = main_block.ops[-1]
+        assert new_op.type == "fused_feedforward_grad"
+        new_op._set_attr("ring_id", int(group.id))
+
+
+register_distributed_operator_impl(
+    "fused_feedforward", DistributedFusedFeedForwardImpl("tensor_parallel"))
diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py
index 5fdd88ac1d..fe94c25e12 100644
--- a/python/paddle/distributed/passes/auto_parallel_amp.py
+++ b/python/paddle/distributed/passes/auto_parallel_amp.py
@@ -487,6 +487,7 @@ class AMPPass(PassBase):
         self.set_attr("incr_ratio", 2.0)
self.set_attr("decr_ratio", 0.8) self.set_attr("use_dynamic_loss_scaling", False) + self.set_attr("input_data", []) self.set_attr("params_grads", []) self._loss_scaling = None self._num_good_steps = None diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py index 725b4459d7..69c3eef7e3 100644 --- a/python/paddle/distributed/passes/auto_parallel_fp16.py +++ b/python/paddle/distributed/passes/auto_parallel_fp16.py @@ -95,12 +95,21 @@ def _keep_fp32_output(op, out_name): class FP16State(object): - def __init__(self, program, amp_list, dist_context, use_fp16_guard): + def __init__(self, + program, + amp_list, + dist_context, + use_fp16_guard, + input_data_var_names=None): self.program = program self.amp_list = amp_list self.use_fp16_guard = use_fp16_guard self.dist_context = dist_context self.grad_op_to_op_map = self.dist_context.dist_op_context.grad_op_id_to_op_id + if input_data_var_names: + self.input_data_var_names = input_data_var_names + else: + self.input_data_var_names = [] self._op_fp16_dict = { } # op_id --> True/False. 'True' means that the op is should run in fp16 mode. # a trick to determine leaf tensor node in program {varname: generator_op_id} @@ -191,7 +200,7 @@ class FP16State(object): if _keep_fp32_input(op, in_name): continue for in_var_name in op.input(in_name): - if in_var_name not in self.forward_non_leaf_tensors: + if in_var_name not in self.forward_non_leaf_tensors and in_var_name not in self.input_data_var_names: self.set_var_to_fp16(in_var_name, block) for out_name in op.output_names: if _keep_fp32_output(op, out_name): @@ -498,10 +507,14 @@ class FP16Pass(AMPPass): set(self.get_attr("custom_white_list")), set(self.get_attr("custom_black_list")), None) - # TODO support multiple blocks + # NOTE don't not change input data dtype, since it is controled by dataloader + # and which is out of control of FP16 Pass + input_data_var_names = [var.name for var in self.get_attr("input_data")] + with paddle.static.program_guard(main_program, startup_program): fp16_state = FP16State(main_program, amp_list, self.dist_context, - self.get_attr("use_fp16_guard")) + self.get_attr("use_fp16_guard"), + input_data_var_names) is_train = fp16_state._build_state() if is_train: -- GitLab