# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest

import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.process_group import _g_process_group_map, ProcessGroup
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr

paddle.enable_static()

# Module-level configuration read by MLPLayer.forward and mlp_forward while
# the program is being built.  Each test assigns the values it needs.
_global_parallel_strategy = None
_global_process_mesh = None
PP_MESH_0 = None
PP_MESH_1 = None


class MLPLayer(nn.Layer):
    """LayerNorm -> Linear -> GELU -> Linear block used as the test model."""

    def __init__(self,
                 hidden_size=1024,
                 intermediate_size=4 * 1024,
                 initializer_range=0.02):
        super(MLPLayer, self).__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range))
        bias_attr = None

        self.linear0 = nn.Linear(d_model,
                                 dim_feedforward,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.linear1 = nn.Linear(dim_feedforward,
                                 d_model,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)

    def forward(self, input):
        """Annotate both linear weights with a process mesh and run the MLP.

        With the "pp" strategy the first weight is placed on PP_MESH_0 and
        the second on PP_MESH_1; otherwise both use _global_process_mesh.
        """
        if _global_parallel_strategy == "pp":
            mesh_for_linear0, mesh_for_linear1 = PP_MESH_0, PP_MESH_1
        else:
            mesh_for_linear0 = mesh_for_linear1 = _global_process_mesh

        auto.shard_tensor(self.linear0.weight,
                          dist_attr={
                              "process_mesh": mesh_for_linear0,
                              "dims_mapping": [-1, -1]
                          })
        auto.shard_tensor(self.linear1.weight,
                          dist_attr={
                              "process_mesh": mesh_for_linear1,
                              "dims_mapping": [-1, -1]
                          })

        out = self.norm(input)
        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
        out = self.linear1(out)

        return out


def mlp_forward(train_program, start_program):
    """Build the MLP forward pass and mean-squared-error loss.

    The graph is built inside ``train_program`` / ``start_program`` under a
    fresh unique-name guard.  The input (and, for "pp", the label) is
    annotated according to ``_global_parallel_strategy``.

    Returns:
        tuple: ``(loss, train_program, start_program)``.
    """
    with static.program_guard(train_program,
                              start_program), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        input_tensor = static.data(name="input",
                                   shape=[batch_size, hidden_size],
                                   dtype='float32')
        label = static.data(name="label",
                            shape=[batch_size, 1],
                            dtype='float32')

        if _global_parallel_strategy == "pp":
            auto.shard_tensor(input_tensor,
                              dist_attr={
                                  "process_mesh": PP_MESH_0,
                                  "dims_mapping": [-1, -1]
                              })
            auto.shard_tensor(label,
                              dist_attr={
                                  "process_mesh": PP_MESH_1,
                                  "dims_mapping": [-1, -1]
                              })
        elif _global_parallel_strategy == "dp":
            auto.shard_tensor(input_tensor,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [0, -1]
                              })
        else:
            auto.shard_tensor(input_tensor,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, -1]
                              })

        mlp = MLPLayer(hidden_size=hidden_size,
                       intermediate_size=4 * hidden_size,
                       initializer_range=0.02)

        predict = mlp(input_tensor)
        error_cost = paddle.nn.functional.square_error_cost(predict, label)
        loss = paddle.mean(error_cost)

    return loss, train_program, start_program


def get_dist_prog(train_program,
                  startup_program,
                  dist_context,
                  rank_id,
                  change_process_mesh=False):
    """Build the MLP program and partition it for ``rank_id``.

    Args:
        train_program: Program that receives the forward/backward graph.
        startup_program: Program that receives parameter initialization.
        dist_context: DistributedContext shared by completion/partitioning.
        rank_id: Rank the program is partitioned for.
        change_process_mesh: When true, the tensor dist attr of
            ``gelu_0.tmp_0`` is re-assigned to ``PP_MESH_1`` before the
            backward pass is generated.

    Returns:
        tuple: ``(dist_main_prog, dist_startup_prog, dist_params_grads)``.
    """
    loss, train_program, startup_program = mlp_forward(train_program,
                                                       startup_program)

    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # serial forward & backward completion
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program)
    dist_context.block_state.parse_forward_blocks(complete_train_program)
    if change_process_mesh:
        # PP_MESH_1 is only read here, so no ``global`` declaration is needed.
        gelu_out = train_program.global_block().vars["gelu_0.tmp_0"]
        dist_context.get_tensor_dist_attr_for_program(
            gelu_out).process_mesh = PP_MESH_1

    params_grads = parallelizer._generate_backward(complete_train_program,
                                                   startup_program,
                                                   loss,
                                                   parameter_list=None,
                                                   no_grad_set=None,
                                                   callbacks=None)

    # logical partition
    partitioner = Partitioner(dist_context, rank_id)
    auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
        complete_train_program, startup_program, params_grads)

    # Called for its effect on the partitioned programs; the returned ops are
    # not needed by the tests.
    parallelizer._apply_optimize(auto_parallel_main_prog,
                                 auto_parallel_startup_prog,
                                 dist_params_grads)

    return auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads


def _tensor_has_dist_attr(dist_context, var):
    """Return True when ``var`` has a non-empty dims_mapping and process_mesh.

    A ``None`` lookup result is treated as "no dist attr" rather than
    raising AttributeError.
    """
    tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
    if tensor_dist_attr is None:
        return False
    return bool(tensor_dist_attr.dims_mapping) and bool(
        tensor_dist_attr.process_mesh)


def check_backward_dist_attr(dist_context, dist_main_prog, op_need_check):
    """Return True when ``op_need_check`` and all its tensors carry dist attrs.

    The op must have a dist attr with a process mesh, every input must have
    an input dims mapping on the op plus a tensor dist attr, and every output
    must have a tensor dist attr.  Any missing piece yields False.
    """
    block_vars = dist_main_prog.global_block().vars

    op_dist_attr = dist_context.get_op_dist_attr_for_program(op_need_check)
    # Bail out early: the per-input checks below dereference op_dist_attr.
    if not op_dist_attr or not op_dist_attr.process_mesh:
        return False

    for var_name in op_need_check.input_arg_names:
        if not op_dist_attr.get_input_dims_mapping(var_name):
            return False
        if not _tensor_has_dist_attr(dist_context, block_vars[var_name]):
            return False

    return all(
        _tensor_has_dist_attr(dist_context, block_vars[var_name])
        for var_name in op_need_check.output_arg_names)


def check_send_recv_result(dist_main_prog, rank_id):
    """Return True when both the expected send_v2 and recv_v2 ops exist.

    Rank 0 is expected to send ``gelu_0.tmp_0`` and receive its gradient;
    any other rank is expected to send the gradient and receive the tensor.
    """
    if rank_id == 0:
        send_name, recv_name = "gelu_0.tmp_0", "gelu_0.tmp_0@GRAD"
    else:
        send_name, recv_name = "gelu_0.tmp_0@GRAD", "gelu_0.tmp_0"

    ops = dist_main_prog.global_block().ops
    # The send side is an exact match against the list of input names.
    send_result = any(
        op.type == "send_v2" and send_name in op.input_arg_names
        for op in ops)
    # The recv side is a substring match against the first output name.
    recv_result = any(
        op.type == "recv_v2" and recv_name in op.output_arg_names[0]
        for op in ops)

    return send_result and recv_result


def _parameter_names(program):
    """Return the names of the parameter variables in the global block."""
    return [
        var_name for var_name, var in program.global_block().vars.items()
        if var.is_parameter
    ]


def check_initialization(dist_startup_prog, rank_id):
    """Return True when the startup program holds exactly this rank's params.

    The comparison is order-sensitive: the parameter names must appear in
    the same order as the expected list.
    """
    if rank_id == 0:
        need_check_params = [
            "layer_norm_0.b_0", "layer_norm_0.w_0", "linear_0.w_0",
            "linear_0.b_0"
        ]
    else:
        need_check_params = ['linear_1.w_0', 'linear_1.b_0']

    return _parameter_names(dist_startup_prog) == need_check_params


def check_initialization_for_dp(dist_startup_prog):
    """Return True when every parameter is present and broadcast.

    The parameter names and the first output name of every ``c_broadcast``
    op are compared, order-insensitively, against the full expected list.
    """
    need_check_params = [
        "layer_norm_0.b_0", "layer_norm_0.w_0", "linear_0.w_0", "linear_0.b_0"
    ] + ['linear_1.w_0', 'linear_1.b_0']
    params = _parameter_names(dist_startup_prog)
    broadcast_varnames = [
        op.output_arg_names[0] for op in dist_startup_prog.global_block().ops
        if op.type == "c_broadcast"
    ]

    return sorted(params) == sorted(need_check_params) == sorted(
        broadcast_varnames)


class TestMLPReshard(unittest.TestCase):
    """Reshard tests for the MLP under pipeline and data-parallel setups."""

    def setUp(self):
        # Start every test from the module defaults so that the result does
        # not depend on which test ran before it.
        global _global_parallel_strategy, _global_process_mesh
        global PP_MESH_0, PP_MESH_1
        _global_parallel_strategy = None
        _global_process_mesh = None
        PP_MESH_0 = None
        PP_MESH_1 = None

    def tearDown(self):
        # clear _g_process_group_map, even when the test body failed
        _g_process_group_map.clear()
        _g_process_group_map[0] = ProcessGroup(0, [])

    def _build_and_reshard(self, rank_id, change_process_mesh=False):
        """Build the distributed programs for ``rank_id`` and reshard them.

        Returns:
            tuple: ``(dist_main_prog, dist_startup_prog)``.
        """
        train_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        dist_context = DistributedContext()
        dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
            train_program, startup_program, dist_context, rank_id,
            change_process_mesh)
        resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
                              dist_context, dist_params_grads)
        resharder.reshard()
        return dist_main_prog, dist_startup_prog

    def test_complete_backward_annotation(self):
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1])

        train_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        dist_context = DistributedContext()
        rank_id = 0
        dist_main_prog, _, _ = get_dist_prog(train_program, startup_program,
                                             dist_context, rank_id)

        op_need_check = next(
            (op for op in dist_main_prog.global_block().ops
             if op.type == "gelu_grad"), None)
        self.assertIsNotNone(op_need_check,
                             "gelu_grad op not found in the main program")

        # grad op should have dist attr
        self.assertTrue(
            check_backward_dist_attr(dist_context, dist_main_prog,
                                     op_need_check))

    def test_mlp_pp(self):
        global _global_parallel_strategy, _global_process_mesh
        global PP_MESH_0, PP_MESH_1
        _global_parallel_strategy = "pp"
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
        PP_MESH_0 = auto.ProcessMesh(mesh=[0])
        PP_MESH_1 = auto.ProcessMesh(mesh=[1])

        rank_id = 1
        dist_main_prog, dist_startup_prog = self._build_and_reshard(rank_id)

        # check send and recv result
        self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))

        # parameter initialization of every rank should be different in the
        # pipeline scene
        self.assertTrue(check_initialization(dist_startup_prog, rank_id))

    def test_mlp_pp_diff_process_mesh(self):
        global _global_parallel_strategy, _global_process_mesh
        global PP_MESH_0, PP_MESH_1
        _global_parallel_strategy = "pp"
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
        PP_MESH_0 = auto.ProcessMesh(mesh=[0])
        PP_MESH_1 = auto.ProcessMesh(mesh=[1])

        rank_id = 1
        dist_main_prog, dist_startup_prog = self._build_and_reshard(
            rank_id, change_process_mesh=True)

        # check send and recv result
        self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
        self.assertTrue(check_initialization(dist_startup_prog, rank_id))

    def test_mlp_dp(self):
        global _global_parallel_strategy, _global_process_mesh
        _global_parallel_strategy = "dp"
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1])

        rank_id = 0
        dist_main_prog, dist_startup_prog = self._build_and_reshard(rank_id)

        # send and recv should not exist in dp scene.
        self.assertFalse(check_send_recv_result(dist_main_prog, rank_id))

        # all parameters should be initialized in dp scene
        self.assertTrue(check_initialization_for_dp(dist_startup_prog))


if __name__ == "__main__":
    unittest.main()