diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu index e38ac9a0ad2da52c62b4f64f4ea50eaaa90faec9..fdf19ac46c91cc7877ea55733ed312a73f22a651 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu @@ -530,7 +530,10 @@ inline __device__ void zero(T &dst) { // NOLINT dst = tmp.raw; } -template __global__ void masked_multihead_attention_kernel( Masked_multihead_attention_params params) { @@ -830,8 +833,10 @@ __global__ void masked_multihead_attention_kernel( template inline size_t smem_size_in_bytes( - const Masked_multihead_attention_params ¶ms, int dim_head, - int threads_per_value, int threads_per_block) { + const Masked_multihead_attention_params ¶ms, + int dim_head, + int threads_per_value, + int threads_per_block) { size_t qk_sz = div_up(params.timestep + 1, 4) * 16; size_t logits_sz = 0; @@ -848,14 +853,17 @@ inline size_t smem_size_in_bytes( return max(softmax_sz, red_sz); } -#define MMHA_LAUNCH_KERNEL(T, Dh, THDS_PER_KEY, THDS_PER_VALUE, \ - THDS_PER_BLOCK, stream) \ +#define MMHA_LAUNCH_KERNEL( \ + T, Dh, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream) \ size_t smem_sz = \ smem_size_in_bytes(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_head, params.batch_size); \ - masked_multihead_attention_kernel< \ - T, Dh, THDS_PER_KEY, THDS_PER_VALUE, \ - THDS_PER_BLOCK><<>>(params) + masked_multihead_attention_kernel \ + <<>>(params) template void fmha_launch_kernel(const Masked_multihead_attention_params ¶ms, @@ -871,10 +879,17 @@ void fmha_launch_kernel(const Masked_multihead_attention_params ¶ms, } template -void fmha(const platform::CUDADeviceContext &dev_ctx, const Tensor &qkv_tensor, - const Tensor &qkv_bias_tensor, const Tensor &src_mask_tensor, - Tensor *cache_kv_tensor, Tensor *out_tensor, int batch_size, - int max_seq_length, int num_head, int dim_head, int timestep, +void fmha(const platform::CUDADeviceContext &dev_ctx, + const Tensor &qkv_tensor, + const Tensor &qkv_bias_tensor, + const Tensor &src_mask_tensor, + Tensor *cache_kv_tensor, + Tensor *out_tensor, + int batch_size, + int max_seq_length, + int num_head, + int dim_head, + int timestep, float inv_sqrt_dh) { Masked_multihead_attention_params params; params.out = out_tensor->data(); @@ -911,8 +926,11 @@ void fmha(const platform::CUDADeviceContext &dev_ctx, const Tensor &qkv_tensor, constexpr int VEC_16B = 16; template -__global__ void write_cache_k_kernel(T *cache_k, const T *k, const int num_head, - const int dim_head, const int seq_len, +__global__ void write_cache_k_kernel(T *cache_k, + const T *k, + const int num_head, + const int dim_head, + const int seq_len, const int max_seq_len) { const int bi = blockIdx.y; const int hi = blockIdx.z; @@ -946,8 +964,11 @@ __global__ void write_cache_k_kernel(T *cache_k, const T *k, const int num_head, } template -__global__ void write_cache_v_kernel(T *cache_v, const T *v, const int num_head, - const int dim_head, const int seq_len, +__global__ void write_cache_v_kernel(T *cache_v, + const T *v, + const int num_head, + const int dim_head, + const int seq_len, const int max_seq_len) { const int bi = blockIdx.y; const int hi = blockIdx.z; @@ -970,16 +991,23 @@ __global__ void write_cache_v_kernel(T *cache_v, const T *v, const int num_head, } template -void write_cache_kv(const platform::CUDADeviceContext &dev_ctx, T *cache_k, - T *cache_v, const T *k, const T *v, const int bsz, - const int num_head, const int seq_len, - const int max_seq_len, const int dim_head) { +void write_cache_kv(const platform::CUDADeviceContext &dev_ctx, + T *cache_k, + T *cache_v, + const T *k, + const T *v, + const int bsz, + const int num_head, + const int seq_len, + const int max_seq_len, + const int dim_head) { constexpr int block_sz = 128; constexpr int x = VEC_16B / sizeof(T); assert(dim_head % x == 0); PADDLE_ENFORCE_EQ( - dim_head % x, 0, + dim_head % x, + 0, platform::errors::PreconditionNotMet( "dim_head=%d must be divisible by vec_size=%d", dim_head, x)); @@ -1043,15 +1071,15 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr; // (transA, transB, compute_bias) = (false, true, false) - auto qkv_compute = AttnMatMul(dev_ctx, false, true, bsz_seq, output_size, - input_size, compute_bias); + auto qkv_compute = AttnMatMul( + dev_ctx, false, true, bsz_seq, output_size, input_size, compute_bias); Tensor qkv_out; auto *qkv_out_data = qkv_out.mutable_data({bsz, seq_len, 3, num_head, dim_head}, place); // 3. fmha - AttnDropoutParam attn_param(true, "upscale_in_train", 0.0, true, true, 0, - nullptr); + AttnDropoutParam attn_param( + true, "upscale_in_train", 0.0, true, true, 0, nullptr); auto fmha_compute = FMHARef(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param); auto *src_mask = ctx.Input("SrcMask"); @@ -1061,17 +1089,20 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { auto out_seq_len = seq_len; if (time_step) { - PADDLE_ENFORCE_EQ(time_step->place(), platform::CPUPlace(), + PADDLE_ENFORCE_EQ(time_step->place(), + platform::CPUPlace(), platform::errors::PreconditionNotMet( "The place of input(TimeStep) must be CPUPlace.")); // cache_seq_len int time_step_value = time_step->data()[0]; - PADDLE_ENFORCE_GT(time_step_value, 0, + PADDLE_ENFORCE_GT(time_step_value, + 0, platform::errors::PreconditionNotMet( "The value of time_step must > 0, but now is %d", time_step_value)); PADDLE_ENFORCE_EQ( - seq_len, 1, + seq_len, + 1, platform::errors::PreconditionNotMet( "In decode stage, the seq_len of input must be 1, but now is %d", seq_len)); @@ -1107,8 +1138,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { auto out_linear_biases = ctx.MultiInput("OutLinearBias"); int ring_id = ctx.Attr("ring_id"); // (transA, transB, compute_bias) = (false, false, false) - auto out_linear_compute = AttnMatMul(dev_ctx, false, false, bsz_seq, - dim_embed, hidden_size, false); + auto out_linear_compute = AttnMatMul( + dev_ctx, false, false, bsz_seq, dim_embed, hidden_size, false); // 5. ln(residual + bias) DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0); @@ -1117,9 +1148,12 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { auto ffn_ln_scales = ctx.MultiInput("FFNLnScale"); auto ffn_ln_biases = ctx.MultiInput("FFNLnBias"); Tensor bias_dropout_residual_out, dropout_mask_out; - auto *bias_dropout_residual_out_data = - bias_dropout_residual_out.mutable_data({bsz, seq_len, dim_embed}, - place); + T *bias_dropout_residual_out_data = nullptr; + if (pre_layer_norm) { + bias_dropout_residual_out_data = + bias_dropout_residual_out.mutable_data({bsz, seq_len, dim_embed}, + place); + } auto *dropout_mask_out_data = dropout_mask_out.mutable_data( {bsz, seq_len, dim_embed}, place); @@ -1129,8 +1163,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { auto ffn1_weight_dim = ffn1_weights[0]->dims(); int dim_ffn = ffn1_weight_dim[1]; - auto ffn1_linear_compute = AttnMatMul(dev_ctx, false, false, bsz_seq, - dim_ffn, dim_embed, false); + auto ffn1_linear_compute = AttnMatMul( + dev_ctx, false, false, bsz_seq, dim_ffn, dim_embed, false); Tensor ffn1_out; auto *ffn1_out_data = ffn1_out.mutable_data({bsz_seq, dim_ffn}, place); @@ -1147,8 +1181,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { // 8. ffn2 matmul auto ffn2_weights = ctx.MultiInput("FFN2Weight"); auto ffn2_biases = ctx.MultiInput("FFN2Bias"); - auto ffn2_linear_compute = AttnMatMul(dev_ctx, false, false, bsz_seq, - dim_embed, dim_ffn, false); + auto ffn2_linear_compute = AttnMatMul( + dev_ctx, false, false, bsz_seq, dim_embed, dim_ffn, false); // 9. ffn2 residual bias DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0); @@ -1171,14 +1205,19 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { // step1: buf1 --> buf0 // step2: buf0 --> buf1 int layers = qkv_weights.size(); - if (layers & 1) { - // odd, set buf1 as out + if (pre_layer_norm) { + if (layers & 1) { + // odd, set buf1 as out + buf0 = &tmp_out; + buf1 = out; + } else { + // even, set buf0 as out + buf0 = out; + buf1 = &tmp_out; + } + } else { buf0 = &tmp_out; buf1 = out; - } else { - // even, set buf0 as out - buf0 = out; - buf1 = &tmp_out; } for (int i = 0; i < layers; ++i) { @@ -1187,11 +1226,12 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { auto *ln_scale_data = ln_scales[i]->data(); auto *ln_bias_data = ln_biases[i]->data(); // TODO(wangxi): can remove mean var in inference - ln_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data, - buf1->data(), ln_mean_data, ln_var_data); - } else if (!pre_layer_norm) { - PADDLE_THROW(platform::errors::Unimplemented( - "Unimplemented post_layer_norm for now.")); + ln_compute.ComputeForward(x_data, + ln_scale_data, + ln_bias_data, + buf1->data(), + ln_mean_data, + ln_var_data); } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step1"; @@ -1201,8 +1241,13 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { const Tensor *qkv_bias = qkv_biases.size() > 0 ? qkv_biases[i] : nullptr; // NOTE: in decoder stage, bias is fused in fmha const Tensor *bias = time_step ? nullptr : qkv_bias; - qkv_compute.ComputeForward(qkv_weights[i], buf1, bias, &qkv_out, - &qkv_out); + if (!pre_layer_norm && i == 0) { + qkv_compute.ComputeForward( + qkv_weights[i], input_x, bias, &qkv_out, &qkv_out); + } else { + qkv_compute.ComputeForward( + qkv_weights[i], buf1, bias, &qkv_out, &qkv_out); + } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step2"; #endif @@ -1214,15 +1259,32 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { if (time_step) { // generation decoder stage // [2, batch_size, num_head, max_seq_len, head_size] int max_seq_len = cache_kv->dims()[3]; - fmha(dev_ctx, qkv_out, *qkv_bias, *src_mask, cache_kv_out, &fmha_out, - bsz, max_seq_len, num_head, dim_head, time_step->data()[0], + fmha(dev_ctx, + qkv_out, + *qkv_bias, + *src_mask, + cache_kv_out, + &fmha_out, + bsz, + max_seq_len, + num_head, + dim_head, + time_step->data()[0], 1. / sqrt(dim_head)); } else if (cache_kv_out) { // generation context stage // TODO(wangxi): can remove dropout in inference - fmha_compute.ComputeForward( - qkv_out, nullptr, src_mask, &transpose_out_2, nullptr, &qk_out, - &src_mask_out, &softmax_out, &attn_dropout_mask_out, - &attn_dropout_out, &qktv_out, &fmha_out); + fmha_compute.ComputeForward(qkv_out, + nullptr, + src_mask, + &transpose_out_2, + nullptr, + &qk_out, + &src_mask_out, + &softmax_out, + &attn_dropout_mask_out, + &attn_dropout_out, + &qktv_out, + &fmha_out); // [3, bsz, num_head, seq_len, head_dim] T *qkv_data = transpose_out_2_data; int64_t q_size = bsz * seq_len * num_head * dim_head; @@ -1239,23 +1301,45 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { T *cache_k_ptr = cache_kv_data; T *cache_v_ptr = cache_kv_data + cache_k_size; - write_cache_kv(dev_ctx, cache_k_ptr, cache_v_ptr, k_ptr, v_ptr, bsz, - num_head, seq_len, max_seq_len, dim_head); + write_cache_kv(dev_ctx, + cache_k_ptr, + cache_v_ptr, + k_ptr, + v_ptr, + bsz, + num_head, + seq_len, + max_seq_len, + dim_head); } else { // not generation // TODO(wangxi): can remove dropout in inference - fmha_compute.ComputeForward( - qkv_out, cache_kv, src_mask, &transpose_out_2, cache_kv_out, - &qk_out, &src_mask_out, &softmax_out, &attn_dropout_mask_out, - &attn_dropout_out, &qktv_out, &fmha_out); + fmha_compute.ComputeForward(qkv_out, + cache_kv, + src_mask, + &transpose_out_2, + cache_kv_out, + &qk_out, + &src_mask_out, + &softmax_out, + &attn_dropout_mask_out, + &attn_dropout_out, + &qktv_out, + &fmha_out); } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step3"; #endif // step4. out_linear - out_linear_compute.ComputeForward(out_linear_weights[i], &fmha_out, - nullptr, buf1, nullptr); - AllReduce(*buf1, ring_id, dev_ctx); + if (pre_layer_norm) { + out_linear_compute.ComputeForward( + out_linear_weights[i], &fmha_out, nullptr, buf1, nullptr); + AllReduce(*buf1, ring_id, dev_ctx); + } else { + out_linear_compute.ComputeForward( + out_linear_weights[i], &fmha_out, nullptr, buf0, nullptr); + AllReduce(*buf0, ring_id, dev_ctx); + } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step4"; #endif @@ -1268,39 +1352,75 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { // inplace fused_dropout_layernorm_helper.LayernormResidualDropoutBias( - dev_ctx, buf1->data(), x_data, out_linear_bias_data, - ln_scale_data, ln_bias_data, bias_dropout_residual_out_data, - dropout_mask_out_data, buf1->data(), ln_mean_data, ln_var_data); + dev_ctx, + buf1->data(), + x_data, + out_linear_bias_data, + ln_scale_data, + ln_bias_data, + bias_dropout_residual_out_data, + dropout_mask_out_data, + buf1->data(), + ln_mean_data, + ln_var_data); } else { + auto *ln_scale_data = ln_scales[i]->data(); + auto *ln_bias_data = ln_biases[i]->data(); + auto *out_linear_bias_data = out_linear_biases[i]->data(); + auto *residual_data = (i == 0 ? x_data : buf1->data()); + fused_dropout_layernorm_helper.LayernormResidualDropoutBias( + dev_ctx, + buf0->data(), + residual_data, + out_linear_bias_data, + ln_scale_data, + ln_bias_data, + buf0->data(), + dropout_mask_out_data, + buf1->data(), + ln_mean_data, + ln_var_data); } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step5"; #endif // step6. ffn matmul1 - ffn1_linear_compute.ComputeForward(ffn1_weights[i], buf1, nullptr, - &ffn1_out, nullptr); + ffn1_linear_compute.ComputeForward( + ffn1_weights[i], buf1, nullptr, &ffn1_out, nullptr); #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step6"; #endif // step7. act bias // TODO(wangxi): remove dropout mask in inference - fused_act_dropout_helper.DropoutActBias( - dev_ctx, ffn1_out_data, ffn1_biases[i]->data(), "gelu", - ffn1_dropout_out_data, ffn1_dropout_mask_data); + fused_act_dropout_helper.DropoutActBias(dev_ctx, + ffn1_out_data, + ffn1_biases[i]->data(), + "gelu", + ffn1_dropout_out_data, + ffn1_dropout_mask_data); #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step7"; #endif // step8. ffn matmul2 - ffn2_linear_compute.ComputeForward(ffn2_weights[i], &ffn1_dropout_out, - nullptr, buf1, nullptr); + if (pre_layer_norm) { + ffn2_linear_compute.ComputeForward( + ffn2_weights[i], &ffn1_dropout_out, nullptr, buf1, nullptr); + } else { + ffn2_linear_compute.ComputeForward( + ffn2_weights[i], &ffn1_dropout_out, nullptr, buf0, nullptr); + } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step8.0"; #endif - AllReduce(*buf1, ring_id, dev_ctx); + if (pre_layer_norm) { + AllReduce(*buf1, ring_id, dev_ctx); + } else { + AllReduce(*buf0, ring_id, dev_ctx); + } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step8.1"; #endif @@ -1312,23 +1432,49 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { auto *ln_scale_data = ln_scales[i + 1]->data(); auto *ln_bias_data = ln_biases[i + 1]->data(); ffn2_fused_dropout_helper.LayernormResidualDropoutBias( - dev_ctx, buf1->data(), bias_dropout_residual_out_data, - ffn2_biases[i]->data(), ln_scale_data, ln_bias_data, - buf1->data(), dropout_mask_out_data, buf0->data(), - ln_mean_data, ln_var_data); + dev_ctx, + buf1->data(), + bias_dropout_residual_out_data, + ffn2_biases[i]->data(), + ln_scale_data, + ln_bias_data, + buf1->data(), + dropout_mask_out_data, + buf0->data(), + ln_mean_data, + ln_var_data); } else { ffn2_fused_dropout_helper.ResidualDropoutBias( - dev_ctx, buf1->data(), bias_dropout_residual_out_data, - ffn2_biases[i]->data(), buf1->data(), + dev_ctx, + buf1->data(), + bias_dropout_residual_out_data, + ffn2_biases[i]->data(), + buf1->data(), dropout_mask_out_data); } } else { + auto *ln_scale_data = ffn_ln_scales[i]->data(); + auto *ln_bias_data = ffn_ln_biases[i]->data(); + ffn2_fused_dropout_helper.LayernormResidualDropoutBias( + dev_ctx, + buf0->data(), + buf1->data(), + ffn2_biases[i]->data(), + ln_scale_data, + ln_bias_data, + buf0->data(), + dropout_mask_out_data, + buf1->data(), + ln_mean_data, + ln_var_data); } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step9"; #endif - x_data = buf1->data(); - std::swap(buf0, buf1); + if (pre_layer_norm) { + x_data = buf1->data(); + std::swap(buf0, buf1); + } } } }; diff --git a/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py b/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py index 2e18471d3f3be6d5b24d1822fce2a1bba2fd83f6..76d9376c567ae299142c7e094feb7c599ce5c6a2 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py @@ -39,6 +39,7 @@ default_main_program().random_seed = 42 class TestFusedMultiTransformerOp(OpTest): + def setUp(self): self.config() self.generate_input_data() @@ -61,39 +62,33 @@ class TestFusedMultiTransformerOp(OpTest): bias_attr = paddle.fluid.ParamAttr( initializer=paddle.fluid.initializer.Constant(value=0.0005)) - self.q_proj = Linear( - self.embed_dim, - self.embed_dim, - self.weight_attr, - bias_attr=bias_attr) + self.q_proj = Linear(self.embed_dim, + self.embed_dim, + self.weight_attr, + bias_attr=bias_attr) #bias_attr=self.bias_attr) - self.k_proj = Linear( - self.kdim, - self.embed_dim, - self.weight_attr, - bias_attr=self.bias_attr) - self.v_proj = Linear( - self.vdim, - self.embed_dim, - self.weight_attr, - bias_attr=self.bias_attr) - self.out_proj = Linear( - self.embed_dim, - self.embed_dim, - self.weight_attr, - bias_attr=self.bias_attr) - - self.ffn1_proj = Linear( - self.embed_dim, - 4 * self.embed_dim, - self.weight_attr, - bias_attr=self.bias_attr) - self.ffn2_proj = Linear( - 4 * self.embed_dim, - self.embed_dim, - self.weight_attr, - bias_attr=self.bias_attr) + self.k_proj = Linear(self.kdim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr) + self.v_proj = Linear(self.vdim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr) + self.out_proj = Linear(self.embed_dim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr) + + self.ffn1_proj = Linear(self.embed_dim, + 4 * self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr) + self.ffn2_proj = Linear(4 * self.embed_dim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr) paddle.set_default_dtype(np.float32) self.norm = LayerNorm(self.embed_dim) @@ -228,8 +223,10 @@ class TestFusedMultiTransformerOp(OpTest): # [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim] # --> [B, n_head, seq_len, out_seq_len] - qk_out = layers.matmul( - x=q_out, y=k_out, transpose_y=True, alpha=self.head_dim**-0.5) + qk_out = layers.matmul(x=q_out, + y=k_out, + transpose_y=True, + alpha=self.head_dim**-0.5) if self.debug: print('qk out is') @@ -249,11 +246,10 @@ class TestFusedMultiTransformerOp(OpTest): print('softmax out is') print(softmax_out[0][0][0]) if self.dropout_prob: - dropout_out = F.dropout( - softmax_out, - self.dropout_prob, - training=self.training, - mode="upscale_in_train") + dropout_out = F.dropout(softmax_out, + self.dropout_prob, + training=self.training, + mode="upscale_in_train") # [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim] # --> [B, n_head, seq_len, head_dim] qktv_out = tensor.matmul(dropout_out, v_out) @@ -265,8 +261,7 @@ class TestFusedMultiTransformerOp(OpTest): print('fmha out is') print(fmha_out[0][0][0]) out_linear_in = tensor.reshape( - x=fmha_out, - shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]]) + x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]]) out = self.out_proj(out_linear_in) residual_out = residual + self.dropout(out) @@ -296,44 +291,44 @@ class TestFusedMultiTransformerOp(OpTest): def GetFusedMultiTransformerOut(self): paddle.disable_static(place=paddle.CUDAPlace(0)) - q_proj_weight = paddle.to_tensor( - self.q_proj.weight, stop_gradient=False) - k_proj_weight = paddle.to_tensor( - self.k_proj.weight, stop_gradient=False) - v_proj_weight = paddle.to_tensor( - self.v_proj.weight, stop_gradient=False) - out_linear_weight = paddle.to_tensor( - self.out_proj.weight, stop_gradient=False) - ffn1_weight = paddle.to_tensor( - self.ffn1_proj.weight, stop_gradient=False) - ffn2_weight = paddle.to_tensor( - self.ffn2_proj.weight, stop_gradient=False) + q_proj_weight = paddle.to_tensor(self.q_proj.weight, + stop_gradient=False) + k_proj_weight = paddle.to_tensor(self.k_proj.weight, + stop_gradient=False) + v_proj_weight = paddle.to_tensor(self.v_proj.weight, + stop_gradient=False) + out_linear_weight = paddle.to_tensor(self.out_proj.weight, + stop_gradient=False) + ffn1_weight = paddle.to_tensor(self.ffn1_proj.weight, + stop_gradient=False) + ffn2_weight = paddle.to_tensor(self.ffn2_proj.weight, + stop_gradient=False) if self.bias_attr is False: qkv_bias_tensor = None out_linear_bias = None else: - q_proj_bias = paddle.to_tensor( - self.q_proj.bias, stop_gradient=False) - k_proj_bias = paddle.to_tensor( - self.k_proj.bias, stop_gradient=False) - v_proj_bias = paddle.to_tensor( - self.v_proj.bias, stop_gradient=False) + q_proj_bias = paddle.to_tensor(self.q_proj.bias, + stop_gradient=False) + k_proj_bias = paddle.to_tensor(self.k_proj.bias, + stop_gradient=False) + v_proj_bias = paddle.to_tensor(self.v_proj.bias, + stop_gradient=False) qkv_bias = np.concatenate( (q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy())) qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim)) qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False) - out_linear_bias = paddle.to_tensor( - self.out_proj.bias, stop_gradient=False) - ffn1_bias = paddle.to_tensor( - self.ffn1_proj.bias, stop_gradient=False) - ffn2_bias = paddle.to_tensor( - self.ffn2_proj.bias, stop_gradient=False) + out_linear_bias = paddle.to_tensor(self.out_proj.bias, + stop_gradient=False) + ffn1_bias = paddle.to_tensor(self.ffn1_proj.bias, + stop_gradient=False) + ffn2_bias = paddle.to_tensor(self.ffn2_proj.bias, + stop_gradient=False) ln_scale = paddle.to_tensor(self.norm.weight, stop_gradient=False) ln_bias = paddle.to_tensor(self.norm.bias, stop_gradient=False) - ffn_ln_scale = paddle.to_tensor( - self.ffn_norm.weight, stop_gradient=False) + ffn_ln_scale = paddle.to_tensor(self.ffn_norm.weight, + stop_gradient=False) ffn_ln_bias = paddle.to_tensor(self.ffn_norm.bias, stop_gradient=False) q_proj_weight = q_proj_weight.numpy().transpose((1, 0)) @@ -351,12 +346,11 @@ class TestFusedMultiTransformerOp(OpTest): cache_kvs = [] max_seq_length = (self.cache_length + 128) // 128 * 128 - cache_kv = np.zeros( - [ - 2, self.batch_size, self.num_heads, max_seq_length, - self.head_dim - ], - dtype=self.x_type) + cache_kv = np.zeros([ + 2, self.batch_size, self.num_heads, max_seq_length, + self.head_dim + ], + dtype=self.x_type) elems = 4 if self.x_type is np.float16: @@ -384,8 +378,9 @@ class TestFusedMultiTransformerOp(OpTest): assert self.query_length == self.cache_length cache_kv[:] = 0 else: - time_step = paddle.to_tensor( - [self.cache_length], dtype='int32', place=paddle.CPUPlace()) + time_step = paddle.to_tensor([self.cache_length], + dtype='int32', + place=paddle.CPUPlace()) if self.has_attn_mask: attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) else: @@ -417,31 +412,29 @@ class TestFusedMultiTransformerOp(OpTest): ffn_ln_scales.append(ffn_ln_scale) ffn_ln_biases.append(ffn_ln_bias) if self.has_cache_kv: - cache_kvs.append( - paddle.to_tensor( - cache_kv, stop_gradient=False)) - - final_out = fused_multi_transformer( - x, - ln_scales, - ln_biases, - qkv_weights, - qkv_biases, - out_weights, - out_biases, - ffn_ln_scales, - ffn_ln_biases, - ffn1_weights, - ffn1_biases, - ffn2_weights, - ffn2_biases, - pre_layer_norm=self.pre_layer_norm, - epsilon=epsilon, - cache_kvs=cache_kvs, - time_step=time_step, - attn_mask=attn_mask, - dropout_rate=self.dropout_prob, - training=self.training) + cache_kvs.append(paddle.to_tensor(cache_kv, + stop_gradient=False)) + + final_out = fused_multi_transformer(x, + ln_scales, + ln_biases, + qkv_weights, + qkv_biases, + out_weights, + out_biases, + ffn_ln_scales, + ffn_ln_biases, + ffn1_weights, + ffn1_biases, + ffn2_weights, + ffn2_biases, + pre_layer_norm=self.pre_layer_norm, + epsilon=epsilon, + cache_kvs=cache_kvs, + time_step=time_step, + attn_mask=attn_mask, + dropout_rate=self.dropout_prob, + training=self.training) if self.has_cache_kv: return final_out[0], final_out[1] @@ -463,9 +456,9 @@ class TestFusedMultiTransformerOp(OpTest): if self.debug: print("cache_k out timestep=128") - print(cache_kv_out[0].reshape([ - 2, bsz, num_head, v_elems, max_seq_len, elems - ])[0, 0, 0, :, self.cache_length, :]) + print(cache_kv_out[0].reshape( + [2, bsz, num_head, v_elems, max_seq_len, + elems])[0, 0, 0, :, self.cache_length, :]) print("cache_v out timestep=128") print(cache_kv_out[0][1, 0, 0, self.cache_length, :]) @@ -486,18 +479,25 @@ class TestFusedMultiTransformerOp(OpTest): cache_v = cache_kv_out[i][1, :, :, :self.cache_length, :] - np.testing.assert_allclose( - cache_k_ref, cache_k, rtol=self.rtol, atol=self.atol) - np.testing.assert_allclose( - cache_v_ref, cache_v, rtol=self.rtol, atol=self.atol) + np.testing.assert_allclose(cache_k_ref, + cache_k, + rtol=self.rtol, + atol=self.atol) + np.testing.assert_allclose(cache_v_ref, + cache_v, + rtol=self.rtol, + atol=self.atol) if i == 0: break - np.testing.assert_allclose( - final_out_ref, final_out, rtol=self.rtol, atol=self.atol) + np.testing.assert_allclose(final_out_ref, + final_out, + rtol=self.rtol, + atol=self.atol) class TestFusedMultiTransformerOpFp16(TestFusedMultiTransformerOp): + def config(self): super().config() self.x_type = np.float16 @@ -505,6 +505,7 @@ class TestFusedMultiTransformerOpFp16(TestFusedMultiTransformerOp): class TestFusedMultiTransformerOpCacheKV(TestFusedMultiTransformerOp): + def config(self): super().config() self.has_cache_kv = True @@ -514,6 +515,7 @@ class TestFusedMultiTransformerOpCacheKV(TestFusedMultiTransformerOp): class TestFusedMultiTransformerOpCacheKVFp16(TestFusedMultiTransformerOp): + def config(self): super().config() self.has_cache_kv = True @@ -523,6 +525,7 @@ class TestFusedMultiTransformerOpCacheKVFp16(TestFusedMultiTransformerOp): class TestFusedMultiTransformerOpGenCacheKV(TestFusedMultiTransformerOp): + def config(self): super().config() self.has_cache_kv = True @@ -530,12 +533,68 @@ class TestFusedMultiTransformerOpGenCacheKV(TestFusedMultiTransformerOp): class TestFusedMultiTransformerOpGenCacheKVFp16(TestFusedMultiTransformerOp): + + def config(self): + super().config() + self.has_cache_kv = True + self.gen_cache_kv = True + self.x_type = np.float16 + self.layers = 3 # odd layers + + +class TestFusedMultiTransformerOpPostLayerNormFp16(TestFusedMultiTransformerOp): + + def config(self): + super().config() + self.x_type = np.float16 + self.layers = 3 # odd layers + self.pre_layer_norm = False + + +class TestFusedMultiTransformerOpCacheKVPostLayerNorm( + TestFusedMultiTransformerOp): + + def config(self): + super().config() + self.has_cache_kv = True + self.query_length = 1 + self.key_length, self.value_length = 1, 1 + self.layers = 3 # odd layers + self.pre_layer_norm = False + + +class TestFusedMultiTransformerOpCacheKVPostLayerNormFp16( + TestFusedMultiTransformerOp): + + def config(self): + super().config() + self.has_cache_kv = True + self.query_length = 1 + self.key_length, self.value_length = 1, 1 + self.x_type = np.float16 + self.pre_layer_norm = False + + +class TestFusedMultiTransformerOpGenCacheKVPostLayerNorm( + TestFusedMultiTransformerOp): + + def config(self): + super().config() + self.has_cache_kv = True + self.gen_cache_kv = True + self.pre_layer_norm = False + + +class TestFusedMultiTransformerOpGenCacheKVPostLayerNormFp16( + TestFusedMultiTransformerOp): + def config(self): super().config() self.has_cache_kv = True self.gen_cache_kv = True self.x_type = np.float16 self.layers = 3 # odd layers + self.pre_layer_norm = False if __name__ == "__main__":