# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import print_function import unittest import unittest.mock from io import StringIO import numpy as np import paddle import paddle.nn as nn import paddle.static as static import paddle.nn.functional as F import paddle.utils as utils import paddle.tensor as tensor from paddle.fluid import layers from paddle.nn.layer.transformer import _convert_param_attr_to_list import paddle.distributed.auto_parallel as auto from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix from paddle.distributed.auto_parallel.context import DistributedContext from paddle.distributed.auto_parallel.context import set_default_distributed_context from paddle.distributed import fleet from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.utils import _get_comm_group from paddle.distributed.auto_parallel.process import new_process_group paddle.enable_static() _global_parallel_strategy = None _global_process_mesh = None ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]]) def get_programs(annotated_func): train_program = static.Program() start_program = static.Program() dist_context = DistributedContext() global _global_process_mesh dist_context.set_process_mesh(_global_process_mesh) train_program, start_program = annotated_func(train_program, start_program) complete_train_program = auto.complete_annotation(train_program, dist_context) rank_id = 3 dist_strategy = fleet.DistributedStrategy() partitioner = Partitioner(dist_strategy, dist_context, rank_id) test_auto_parallel_dist_main_prog, test_auto_parallel_dist_startup_prog = partitioner.transpile_forward( complete_train_program, start_program) return complete_train_program, start_program, test_auto_parallel_dist_main_prog, test_auto_parallel_dist_startup_prog, dist_context def is_all_parameters_shape_equal(prog1, prog2): params1 = prog1.all_parameters() params2 = prog2.all_parameters() params1.sort(key=lambda x: x.name) params2.sort(key=lambda x: x.name) shape1 = [tensor.shape for tensor in params1] shape2 = [tensor.shape for tensor in params2] if len(shape1) != len(shape2): return False for i in range(len(shape1)): if shape1[i] != shape2[i]: return False return True def check_tensor_split(prog1, varnames1, prog2, varnames2, axis, nsplit): for i in range(len(varnames1)): var1 = prog1.global_block().var(varnames1[i]) var2 = prog2.global_block().var(varnames2[i]) if var1.shape[axis] != (var2.shape[axis] // nsplit): return False return True def initialization_check(mode, dist_context, dist_startup_prog, serial_startup_prog, var_need_broadcast): if 'mp' in mode: mp_parallel_axis, process_mesh = dist_context._get_model_parallel_info() group_ranks = _get_comm_group(process_mesh.process_group, process_mesh.topology, mp_parallel_axis, 3) mp_ring_id = new_process_group(group_ranks).id broadcast_ops = [ op for op in dist_startup_prog.global_block().ops if (op.type == "c_broadcast" and op.desc.attr("ring_id") == mp_ring_id) ] broadcast_varnames = sorted( [op.desc.output_arg_names()[0] for op in broadcast_ops]) if broadcast_varnames != var_need_broadcast: return False if 'dp' in mode: dp_parallel_axis, process_mesh = dist_context._get_data_parallel_info() group_ranks = _get_comm_group(process_mesh.process_group, process_mesh.topology, dp_parallel_axis, 3) dp_ring_id = new_process_group(group_ranks).id nparam = len(serial_startup_prog.all_parameters()) nbroadcast_dp = len([ op for op in dist_startup_prog.global_block().ops if (op.type == "c_broadcast" and op.desc.attr("ring_id") == dp_ring_id) ]) if nparam != nbroadcast_dp: return False if "dp" in mode and 'mp' in mode: nbroadcast = len([ op for op in dist_startup_prog.global_block().ops if op.type == "c_broadcast" ]) if len(var_need_broadcast) + nbroadcast_dp != nbroadcast: return False return True class MLPLayer(nn.Layer): def __init__(self, hidden_size=1024, intermediate_size=4 * 1024, dropout_ratio=0.1, initializer_range=0.02): super(MLPLayer, self).__init__() d_model = hidden_size dim_feedforward = intermediate_size weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( mean=0.0, std=initializer_range)) bias_attr = None self.linear0 = nn.Linear( d_model, dim_feedforward, weight_attr, bias_attr=bias_attr) self.linear1 = nn.Linear( dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) self.norm = nn.LayerNorm(d_model, epsilon=1e-5) self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train") def forward(self, input): if _global_parallel_strategy == "mp": auto.shard_tensor( self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0]) auto.shard_tensor( self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1]) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1]) auto.shard_tensor( self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1]) else: auto.shard_tensor( self.linear0.weight, _global_process_mesh, dim_mapping=[-1, -1]) auto.shard_tensor( self.linear1.weight, _global_process_mesh, dim_mapping=[-1, -1]) out = self.norm(input) out = self.linear0(out) out = F.gelu(out, approximate=True) out = self.linear1(out) out = self.dropout(out) return out def mlp_pretrain_forward(train_program, start_program): with static.program_guard(train_program, start_program), utils.unique_name.guard(): batch_size = 4 hidden_size = 1024 sequence_len = 512 input = static.data( name="input", shape=[batch_size, sequence_len, hidden_size], dtype='float32') if _global_parallel_strategy == "dp": auto.shard_tensor( input, _global_process_mesh, dim_mapping=[0, -1, -1]) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( input, _global_process_mesh, dim_mapping=[0, -1, -1]) mlp = MLPLayer( hidden_size=hidden_size, intermediate_size=4 * hidden_size, dropout_ratio=0.1, initializer_range=0.02) out = mlp(input) return train_program, start_program class TestMLPAutoPartitioner(unittest.TestCase): def test_mlp_dp(self): global _global_parallel_strategy _global_parallel_strategy = "dp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( mesh=[0, 1, 2, 3], parent=ROOT_MESH) serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( mlp_pretrain_forward) # parameter should not be partitioned self.assertTrue( is_all_parameters_shape_equal(serial_main_prog, dist_main_prog)) self.assertTrue( is_all_parameters_shape_equal(serial_startup_prog, dist_startup_prog)) # op in main prog should be the same serial_ops = serial_main_prog.global_block().ops dist_ops = dist_main_prog.global_block().ops serial_ops = [op.type for op in serial_ops] dist_ops = [op.type for op in dist_ops] self.assertTrue(serial_ops == dist_ops) # parameter initialization var_need_broadcast = [] self.assertTrue( initialization_check(_global_parallel_strategy, dist_context, dist_startup_prog, serial_startup_prog, var_need_broadcast)) def test_mlp_mp(self): global _global_parallel_strategy _global_parallel_strategy = "mp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( mesh=[0, 1, 2, 3], parent=ROOT_MESH) serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( mlp_pretrain_forward) # param should be partition nrank = 4 # col parallel weights = ['linear_0.w_0'] self.assertTrue( check_tensor_split(dist_main_prog, weights, serial_main_prog, weights, 1, nrank)) weights = ['linear_0.b_0'] self.assertTrue( check_tensor_split(dist_main_prog, weights, serial_main_prog, weights, 0, nrank)) # row parallel weights = ['linear_1.w_0'] self.assertTrue( check_tensor_split(dist_main_prog, weights, serial_main_prog, weights, 0, nrank)) weights = ['linear_1.b_0'] self.assertTrue( check_tensor_split(dist_main_prog, weights, serial_main_prog, weights, 0, 1)) # row and col allreduce dist_ops = dist_main_prog.global_block().ops dist_ops = [op.type for op in dist_ops] ref_ops = [ 'layer_norm', 'c_identity', 'matmul', 'elementwise_add', 'gelu', 'matmul', 'c_allreduce_sum', 'elementwise_add', 'dropout' ] self.assertTrue(dist_ops == ref_ops) # parameter initialization var_need_broadcast = sorted( ['layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_1.b_0']) self.assertTrue( initialization_check(_global_parallel_strategy, dist_context, dist_startup_prog, serial_startup_prog, var_need_broadcast)) def test_mlp_dp_mp(self): global _global_parallel_strategy _global_parallel_strategy = "dp_mp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( mlp_pretrain_forward) # param should be partition nrank = 4 # col parallel weights = ['linear_0.w_0'] self.assertTrue( check_tensor_split(dist_main_prog, weights, serial_main_prog, weights, 1, nrank)) weights = ['linear_0.b_0'] self.assertTrue( check_tensor_split(dist_main_prog, weights, serial_main_prog, weights, 0, nrank)) # row parallel weights = ['linear_1.w_0'] self.assertTrue( check_tensor_split(dist_main_prog, weights, serial_main_prog, weights, 0, nrank)) weights = ['linear_1.b_0'] self.assertTrue( check_tensor_split(dist_main_prog, weights, serial_main_prog, weights, 0, 1)) # row and col allreduce dist_ops = dist_main_prog.global_block().ops dist_ops = [op.type for op in dist_ops] ref_ops = [ 'layer_norm', 'c_identity', 'matmul', 'elementwise_add', 'gelu', 'matmul', 'c_allreduce_sum', 'elementwise_add', 'dropout' ] self.assertTrue(dist_ops == ref_ops) # parameter initialization var_need_broadcast = sorted( ['layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_1.b_0']) self.assertTrue( initialization_check(_global_parallel_strategy, dist_context, dist_startup_prog, serial_startup_prog, var_need_broadcast)) class AttentionLayer(nn.Layer): def __init__(self, hidden_size=1024, sequence_len=512, intermediate_size=4 * 1024, num_heads=16, dropout_ratio=0.1, initializer_range=0.02): super(AttentionLayer, self).__init__() self.hidden_size = hidden_size self.sequence_len = sequence_len self.embed_dim = self.hidden_size self.kdim = self.embed_dim self.vdim = self.embed_dim self.num_heads = num_heads self.head_dim = self.embed_dim // self.num_heads assert self.head_dim * self.num_heads == self.embed_dim, \ "embed_dim must be divisible by num_heads" self.dropout_ratio = dropout_ratio self.initializer_range = initializer_range self.training = True self.attn_mask = None weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( mean=0.0, std=initializer_range)) bias_attr = None self.q_proj = nn.Linear( self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr) self.k_proj = nn.Linear( self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr) self.v_proj = nn.Linear( self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr) self.out_proj = nn.Linear( self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr) def forward(self, input): if _global_parallel_strategy == "dp": auto.shard_tensor( input, _global_process_mesh, dim_mapping=[0, -1, -1]) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( input, _global_process_mesh, dim_mapping=[0, -1, -1]) q = self.q_proj(input) q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) q = tensor.transpose(x=q, perm=[0, 2, 1, 3]) k = self.k_proj(input) v = self.v_proj(input) if _global_parallel_strategy == "mp": auto.shard_tensor( self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) auto.shard_tensor( self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) auto.shard_tensor( self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) auto.shard_tensor( self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) auto.shard_tensor( self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) v = tensor.transpose(x=v, perm=[0, 2, 1, 3]) # scale dot product attention product = layers.matmul( x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5) if self.attn_mask is not None: product = product + self.attn_mask weights = F.softmax(product) if self.dropout_ratio: weights = F.dropout( weights, self.dropout_ratio, training=self.training, mode="upscale_in_train") out = tensor.matmul(weights, v) # combine heads out = tensor.transpose(out, perm=[0, 2, 1, 3]) out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) # project to output out = self.out_proj(out) if _global_parallel_strategy == "mp": auto.shard_tensor( self.out_proj.weight, _global_process_mesh, dim_mapping=[0, -1]) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( self.out_proj.weight, _global_process_mesh, dim_mapping=[1, -1]) return out def attn_pretrain_forward(train_program, start_program): with static.program_guard(train_program, start_program), utils.unique_name.guard(): batch_size = 4 hidden_size = 1024 sequence_len = 512 input = static.data( name="query", shape=[batch_size, sequence_len, hidden_size], dtype='float32') attn = AttentionLayer( hidden_size=hidden_size, sequence_len=sequence_len, intermediate_size=4 * hidden_size, num_heads=16, dropout_ratio=0.1, initializer_range=0.02) out = attn(input) return train_program, start_program class TestAttentionAutoPartitioner(unittest.TestCase): def test_attn_dp(self): global _global_parallel_strategy _global_parallel_strategy = "dp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( mesh=[0, 1, 2, 3], parent=ROOT_MESH) serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( attn_pretrain_forward) # parameter should not be partitioned self.assertTrue( is_all_parameters_shape_equal(serial_main_prog, dist_main_prog)) self.assertTrue( is_all_parameters_shape_equal(serial_startup_prog, dist_startup_prog)) # op in main prog should be the same serial_ops = serial_main_prog.global_block().ops dist_ops = dist_main_prog.global_block().ops serial_ops = [op.type for op in serial_ops] dist_ops = [op.type for op in dist_ops] self.assertTrue(serial_ops == dist_ops) # parameter initialization var_need_broadcast = [] self.assertTrue( initialization_check(_global_parallel_strategy, dist_context, dist_startup_prog, serial_startup_prog, var_need_broadcast)) def test_attn_mp(self): global _global_parallel_strategy _global_parallel_strategy = "mp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( mesh=[0, 1, 2, 3], parent=ROOT_MESH) serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( attn_pretrain_forward) # param should be partition nrank = 4 # col parallel weights = ['linear_0.w_0', 'linear_1.w_0', 'linear_2.w_0'] self.assertTrue( check_tensor_split(dist_main_prog, weights, serial_main_prog, weights, 1, nrank)) weights = ['linear_0.b_0', 'linear_1.b_0', 'linear_2.b_0'] self.assertTrue( check_tensor_split(dist_main_prog, weights, serial_main_prog, weights, 0, nrank)) # row parallel weights = ['linear_3.w_0'] self.assertTrue( check_tensor_split(dist_main_prog, weights, serial_main_prog, weights, 0, nrank)) weights = ['linear_3.b_0'] self.assertTrue( check_tensor_split(dist_main_prog, weights, serial_main_prog, weights, 0, 1)) # row and col allreduce dist_ops = dist_main_prog.global_block().ops dist_ops = [op.type for op in dist_ops] ref_ops = [ 'c_identity', 'matmul', 'elementwise_add', 'reshape2', 'transpose2', 'c_identity', 'matmul', 'elementwise_add', 'c_identity', 'matmul', 'elementwise_add', 'reshape2', 'transpose2', 'reshape2', 'transpose2', 'matmul', 'softmax', 'dropout', 'matmul_v2', 'transpose2', 'reshape2', 'matmul', 'c_allreduce_sum', 'elementwise_add' ] self.assertTrue(dist_ops == ref_ops) # parameter initialization var_need_broadcast = ['linear_3.b_0'] self.assertTrue( initialization_check(_global_parallel_strategy, dist_context, dist_startup_prog, serial_startup_prog, var_need_broadcast)) def test_attn_dp_mp(self): global _global_parallel_strategy _global_parallel_strategy = "dp_mp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( attn_pretrain_forward) # param should be partition nrank = 4 # col parallel weights = ['linear_0.w_0', 'linear_1.w_0', 'linear_2.w_0'] self.assertTrue( check_tensor_split(dist_main_prog, weights, serial_main_prog, weights, 1, nrank)) weights = ['linear_0.b_0', 'linear_1.b_0', 'linear_2.b_0'] self.assertTrue( check_tensor_split(dist_main_prog, weights, serial_main_prog, weights, 0, nrank)) # row parallel weights = ['linear_3.w_0'] self.assertTrue( check_tensor_split(dist_main_prog, weights, serial_main_prog, weights, 0, nrank)) weights = ['linear_3.b_0'] self.assertTrue( check_tensor_split(dist_main_prog, weights, serial_main_prog, weights, 0, 1)) # row and col allreduce dist_ops = dist_main_prog.global_block().ops dist_ops = [op.type for op in dist_ops] ref_ops = [ 'c_identity', 'matmul', 'elementwise_add', 'reshape2', 'transpose2', 'c_identity', 'matmul', 'elementwise_add', 'c_identity', 'matmul', 'elementwise_add', 'reshape2', 'transpose2', 'reshape2', 'transpose2', 'matmul', 'softmax', 'dropout', 'matmul_v2', 'transpose2', 'reshape2', 'matmul', 'c_allreduce_sum', 'elementwise_add' ] self.assertTrue(dist_ops == ref_ops) # parameter initialization var_need_broadcast = ['linear_3.b_0'] self.assertTrue( initialization_check(_global_parallel_strategy, dist_context, dist_startup_prog, serial_startup_prog, var_need_broadcast)) class DecoderLayer(nn.Layer): def __init__(self, vocab_size=32768, hidden_size=1024, sequence_len=512, max_position_embeddings=512, intermediate_size=4 * 1024, num_heads=16, dropout_ratio=0.1, initializer_range=0.02): super(DecoderLayer, self).__init__() self.vocab_size = vocab_size self.hidden_size = hidden_size self.max_position_embeddings = max_position_embeddings self.sequence_len = sequence_len self.embed_dim = self.hidden_size self.kdim = self.embed_dim self.vdim = self.embed_dim self.num_heads = num_heads self.dropout_ratio = dropout_ratio self.initializer_range = initializer_range self.training = True self.attn_mask = None self.head_dim = self.embed_dim // self.num_heads assert self.head_dim * self.num_heads == self.embed_dim, \ "embed_dim must be divisible by num_heads" self.word_embeddings = nn.Embedding( self.vocab_size, self.hidden_size, weight_attr=paddle.ParamAttr( name="word_embeddings", initializer=nn.initializer.Normal( mean=0.0, std=self.initializer_range))) self.position_embeddings = nn.Embedding( self.max_position_embeddings, self.hidden_size, weight_attr=paddle.ParamAttr( name="pos_embeddings", initializer=nn.initializer.Normal( mean=0.0, std=self.initializer_range))) weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( mean=0.0, std=self.initializer_range)) bias_attr = None self.q_proj = nn.Linear( self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr) self.k_proj = nn.Linear( self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr) self.v_proj = nn.Linear( self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr) self.out_proj = nn.Linear( self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr) intermediate_size = 4 * self.hidden_size d_model = self.hidden_size dim_feedforward = intermediate_size weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( mean=0.0, std=self.initializer_range)) bias_attr = None self.linear0 = nn.Linear( d_model, dim_feedforward, weight_attr, bias_attr=bias_attr) self.linear1 = nn.Linear( dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) self.norm = nn.LayerNorm(d_model, epsilon=1e-5) self.dropout1 = nn.Dropout(self.dropout_ratio) self.dropout2 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train") self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train") def forward(self, input_ids, position_ids): if _global_parallel_strategy == "dp": auto.shard_tensor( input_ids, _global_process_mesh, dim_mapping=[0, -1]) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( input_ids, _global_process_mesh, dim_mapping=[0, -1]) input_embeddings = self.word_embeddings(input_ids) position_embeddings = self.position_embeddings(position_ids) if _global_parallel_strategy == "mp": auto.shard_tensor( self.word_embeddings.weight, _global_process_mesh, dim_mapping=[0, -1]) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( self.word_embeddings.weight, _global_process_mesh, dim_mapping=[1, -1]) embeddings = input_embeddings + position_embeddings embeddings = self.dropout1(embeddings) # Pre-norm target = self.norm(embeddings) # The following is the attention part q = self.q_proj(target) q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) q = tensor.transpose(x=q, perm=[0, 2, 1, 3]) k = self.k_proj(target) v = self.v_proj(target) if _global_parallel_strategy == "mp": auto.shard_tensor( self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) auto.shard_tensor( self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) auto.shard_tensor( self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) auto.shard_tensor( self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) auto.shard_tensor( self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) v = tensor.transpose(x=v, perm=[0, 2, 1, 3]) # scale dot product attention product = layers.matmul( x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5) if self.attn_mask is not None: product = product + self.attn_mask weights = F.softmax(product) if self.dropout_ratio: weights = F.dropout( weights, self.dropout_ratio, training=self.training, mode="upscale_in_train") out = tensor.matmul(weights, v) # combine heads out = tensor.transpose(out, perm=[0, 2, 1, 3]) out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) # project to output out = self.out_proj(out) if _global_parallel_strategy == "mp": auto.shard_tensor( self.out_proj.weight, _global_process_mesh, dim_mapping=[0, -1]) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( self.out_proj.weight, _global_process_mesh, dim_mapping=[1, -1]) else: auto.shard_tensor( self.out_proj.weight, _global_process_mesh, dim_mapping=[-1, -1]) # Add residual residual = embeddings + self.dropout2(out) # Pre-norm out0 = self.norm(residual) # The following is the MLP part out1 = self.linear0(out0) out2 = F.gelu(out1, approximate=True) out3 = self.linear1(out2) if _global_parallel_strategy == "mp": auto.shard_tensor( self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0]) auto.shard_tensor( self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1]) elif _global_parallel_strategy == "dp_mp": auto.shard_tensor( self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1]) auto.shard_tensor( self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1]) # Add residual final = residual + self.dropout3(out3) return final def decoder_pretrain_forward(train_program, start_program): with static.program_guard(train_program, start_program), utils.unique_name.guard(): batch_size = 4 hidden_size = 1024 sequence_len = 512 input_ids = static.data( name="input_ids", shape=[batch_size, sequence_len], dtype='int64') position_ids = static.data( name="position_ids", shape=[batch_size, sequence_len], dtype='int64') decoder = DecoderLayer( vocab_size=32768, hidden_size=hidden_size, sequence_len=sequence_len, max_position_embeddings=512, intermediate_size=4 * hidden_size, num_heads=16, dropout_ratio=0.1, initializer_range=0.02) out = decoder(input_ids, position_ids) return train_program, start_program class TestDecoderLayerPartitioner(unittest.TestCase): def test_decoder_dp_mp(self): global _global_parallel_strategy _global_parallel_strategy = "dp_mp" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( decoder_pretrain_forward) # param should be partition nrank = 4 # col parallel weights = [ 'linear_0.w_0', 'linear_1.w_0', 'linear_2.w_0', 'linear_4.w_0' ] self.assertTrue( check_tensor_split(dist_main_prog, weights, serial_main_prog, weights, 1, nrank)) weights = [ 'linear_0.b_0', 'linear_1.b_0', 'linear_2.b_0', 'linear_4.b_0' ] self.assertTrue( check_tensor_split(dist_main_prog, weights, serial_main_prog, weights, 0, nrank)) # row parallel weights = ['word_embeddings', 'linear_3.w_0', 'linear_5.w_0'] self.assertTrue( check_tensor_split(dist_main_prog, weights, serial_main_prog, weights, 0, nrank)) weights = [ 'linear_3.b_0', 'pos_embeddings', 'layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_5.b_0' ] self.assertTrue( check_tensor_split(dist_main_prog, weights, serial_main_prog, weights, 0, 1)) # row and col allreduce dist_ops = dist_main_prog.global_block().ops dist_ops = [op.type for op in dist_ops] ref_ops = [ 'c_embedding', 'c_allreduce_sum', 'lookup_table_v2', 'elementwise_add', 'dropout', 'layer_norm', 'c_identity', 'matmul', 'elementwise_add', 'reshape2', 'transpose2', 'c_identity', 'matmul', 'elementwise_add', 'c_identity', 'matmul', 'elementwise_add', 'reshape2', 'transpose2', 'reshape2', 'transpose2', 'matmul', 'softmax', 'dropout', 'matmul_v2', 'transpose2', 'reshape2', 'matmul', 'c_allreduce_sum', 'elementwise_add', 'dropout', 'elementwise_add', 'layer_norm', 'c_identity', 'matmul', 'elementwise_add', 'gelu', 'matmul', 'c_allreduce_sum', 'elementwise_add', 'dropout', 'elementwise_add' ] self.assertTrue(dist_ops == ref_ops) # parameter initialization var_need_broadcast = sorted([ 'linear_3.b_0', 'pos_embeddings', 'layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_5.b_0' ]) self.assertTrue( initialization_check(_global_parallel_strategy, dist_context, dist_startup_prog, serial_startup_prog, var_need_broadcast)) def test_decoder_noparallel(self): global _global_parallel_strategy _global_parallel_strategy = "None" global _global_process_mesh _global_process_mesh = auto.ProcessMesh( mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( decoder_pretrain_forward) # param should be partition nrank = 1 # col parallel weights = [ 'linear_0.w_0', 'linear_1.w_0', 'linear_2.w_0', 'linear_4.w_0' ] self.assertTrue( check_tensor_split(dist_main_prog, weights, serial_main_prog, weights, 1, nrank)) weights = [ 'linear_0.b_0', 'linear_1.b_0', 'linear_2.b_0', 'linear_4.b_0' ] self.assertTrue( check_tensor_split(dist_main_prog, weights, serial_main_prog, weights, 0, nrank)) # row parallel weights = ['word_embeddings', 'linear_3.w_0', 'linear_5.w_0'] self.assertTrue( check_tensor_split(dist_main_prog, weights, serial_main_prog, weights, 0, nrank)) weights = [ 'linear_3.b_0', 'pos_embeddings', 'layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_5.b_0' ] self.assertTrue( check_tensor_split(dist_main_prog, weights, serial_main_prog, weights, 0, 1)) # row and col allreduce dist_ops = dist_main_prog.global_block().ops dist_ops = [op.type for op in dist_ops] ref_ops = [ 'lookup_table_v2', 'lookup_table_v2', 'elementwise_add', 'dropout', 'layer_norm', 'matmul', 'elementwise_add', 'reshape2', 'transpose2', 'matmul', 'elementwise_add', 'matmul', 'elementwise_add', 'reshape2', 'transpose2', 'reshape2', 'transpose2', 'matmul', 'softmax', 'dropout', 'matmul_v2', 'transpose2', 'reshape2', 'matmul', 'elementwise_add', 'dropout', 'elementwise_add', 'layer_norm', 'matmul', 'elementwise_add', 'gelu', 'matmul', 'elementwise_add', 'dropout', 'elementwise_add' ] self.assertTrue(dist_ops == ref_ops) dist_ops = dist_startup_prog.global_block().ops dist_ops = [op.type for op in dist_ops] ref_ops = [ 'gaussian_random', 'gaussian_random', 'gaussian_random', 'fill_constant', 'gaussian_random', 'fill_constant', 'gaussian_random', 'fill_constant', 'gaussian_random', 'fill_constant', 'gaussian_random', 'fill_constant', 'gaussian_random', 'fill_constant', 'fill_constant', 'fill_constant' ] self.assertTrue(dist_ops == ref_ops) if __name__ == "__main__": unittest.main()