Unverified commit c5f4a9cc, authored by carryyu, committed by GitHub

add post layer norm (#44931)

Parent 9336dd3e
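This commit extends the fused_multi_transformer op, which previously only supported the pre-layer-norm ("pre-LN") block layout, so that it also handles the post-layer-norm ("post-LN") layout of the original Transformer. As a reference for the diff below, here is a minimal Python sketch of the two orderings for a single layer; layer_norm, pre_ln_layer, post_ln_layer, attn and ffn are illustrative names, not Paddle APIs, and the learned scale/bias of the real LayerNorm are omitted for brevity.

import numpy as np

def layer_norm(x, eps=1e-5):
    # plain LayerNorm over the last axis, no learned scale/bias for brevity
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def pre_ln_layer(x, attn, ffn):
    # pre-LN: normalize the input of each sublayer; the residual adds the raw input
    h = x + attn(layer_norm(x))
    return h + ffn(layer_norm(h))

def post_ln_layer(x, attn, ffn):
    # post-LN: residual add first, then normalize the sublayer output
    h = layer_norm(x + attn(x))
    return layer_norm(h + ffn(h))

In the kernel changes below, the residual add and LayerNorm are fused (LayernormResidualDropoutBias / ResidualDropoutBias), and the post-LN ordering is the newly added code path; the new unit tests at the bottom exercise it via pre_layer_norm=False.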
......@@ -530,7 +530,10 @@ inline __device__ void zero(T &dst) { // NOLINT
dst = tmp.raw;
}
template <typename T, int Dh, int THREADS_PER_KEY, int THREADS_PER_VALUE,
template <typename T,
int Dh,
int THREADS_PER_KEY,
int THREADS_PER_VALUE,
int THREADS_PER_BLOCK>
__global__ void masked_multihead_attention_kernel(
Masked_multihead_attention_params<T> params) {
......@@ -830,8 +833,10 @@ __global__ void masked_multihead_attention_kernel(
template <typename T>
inline size_t smem_size_in_bytes(
const Masked_multihead_attention_params<T> &params, int dim_head,
int threads_per_value, int threads_per_block) {
const Masked_multihead_attention_params<T> &params,
int dim_head,
int threads_per_value,
int threads_per_block) {
size_t qk_sz = div_up(params.timestep + 1, 4) * 16;
size_t logits_sz = 0;
......@@ -848,14 +853,17 @@ inline size_t smem_size_in_bytes(
return max(softmax_sz, red_sz);
}
#define MMHA_LAUNCH_KERNEL(T, Dh, THDS_PER_KEY, THDS_PER_VALUE, \
THDS_PER_BLOCK, stream) \
#define MMHA_LAUNCH_KERNEL( \
T, Dh, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream) \
size_t smem_sz = \
smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \
dim3 grid(params.num_head, params.batch_size); \
masked_multihead_attention_kernel< \
T, Dh, THDS_PER_KEY, THDS_PER_VALUE, \
THDS_PER_BLOCK><<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
masked_multihead_attention_kernel<T, \
Dh, \
THDS_PER_KEY, \
THDS_PER_VALUE, \
THDS_PER_BLOCK> \
<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
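The macro above launches the decoder-stage attention kernel with one thread block per (head, batch) pair and a dynamic shared-memory budget derived from the current timestep. A hedged Python sketch of the visible part of that sizing and of the launch shape follows; div_up mirrors the helper in this file, mmha_launch_shape is a hypothetical name, threads_per_block is actually chosen by fmha_launch_kernel per head size, and the softmax/reduction terms computed in the elided portion of smem_size_in_bytes are not reproduced here.

def div_up(a, b):
    return (a + b - 1) // b

def mmha_launch_shape(batch_size, num_head, timestep, threads_per_block):
    # mirrors: size_t qk_sz = div_up(params.timestep + 1, 4) * 16;
    qk_sz_bytes = div_up(timestep + 1, 4) * 16
    # mirrors: dim3 grid(params.num_head, params.batch_size);
    grid = (num_head, batch_size)
    return grid, threads_per_block, qk_sz_bytes

print(mmha_launch_shape(batch_size=8, num_head=16, timestep=128, threads_per_block=128))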
template <typename T, int Dh>
void fmha_launch_kernel(const Masked_multihead_attention_params<T> &params,
......@@ -871,10 +879,17 @@ void fmha_launch_kernel(const Masked_multihead_attention_params<T> &params,
}
template <typename T>
void fmha(const platform::CUDADeviceContext &dev_ctx, const Tensor &qkv_tensor,
const Tensor &qkv_bias_tensor, const Tensor &src_mask_tensor,
Tensor *cache_kv_tensor, Tensor *out_tensor, int batch_size,
int max_seq_length, int num_head, int dim_head, int timestep,
void fmha(const platform::CUDADeviceContext &dev_ctx,
const Tensor &qkv_tensor,
const Tensor &qkv_bias_tensor,
const Tensor &src_mask_tensor,
Tensor *cache_kv_tensor,
Tensor *out_tensor,
int batch_size,
int max_seq_length,
int num_head,
int dim_head,
int timestep,
float inv_sqrt_dh) {
Masked_multihead_attention_params<T> params;
params.out = out_tensor->data<T>();
......@@ -911,8 +926,11 @@ void fmha(const platform::CUDADeviceContext &dev_ctx, const Tensor &qkv_tensor,
constexpr int VEC_16B = 16;
template <typename T>
__global__ void write_cache_k_kernel(T *cache_k, const T *k, const int num_head,
const int dim_head, const int seq_len,
__global__ void write_cache_k_kernel(T *cache_k,
const T *k,
const int num_head,
const int dim_head,
const int seq_len,
const int max_seq_len) {
const int bi = blockIdx.y;
const int hi = blockIdx.z;
......@@ -946,8 +964,11 @@ __global__ void write_cache_k_kernel(T *cache_k, const T *k, const int num_head,
}
template <typename T>
__global__ void write_cache_v_kernel(T *cache_v, const T *v, const int num_head,
const int dim_head, const int seq_len,
__global__ void write_cache_v_kernel(T *cache_v,
const T *v,
const int num_head,
const int dim_head,
const int seq_len,
const int max_seq_len) {
const int bi = blockIdx.y;
const int hi = blockIdx.z;
......@@ -970,16 +991,23 @@ __global__ void write_cache_v_kernel(T *cache_v, const T *v, const int num_head,
}
template <typename T>
void write_cache_kv(const platform::CUDADeviceContext &dev_ctx, T *cache_k,
T *cache_v, const T *k, const T *v, const int bsz,
const int num_head, const int seq_len,
const int max_seq_len, const int dim_head) {
void write_cache_kv(const platform::CUDADeviceContext &dev_ctx,
T *cache_k,
T *cache_v,
const T *k,
const T *v,
const int bsz,
const int num_head,
const int seq_len,
const int max_seq_len,
const int dim_head) {
constexpr int block_sz = 128;
constexpr int x = VEC_16B / sizeof(T);
assert(dim_head % x == 0);
PADDLE_ENFORCE_EQ(
dim_head % x, 0,
dim_head % x,
0,
platform::errors::PreconditionNotMet(
"dim_head=%d must be divisible by vec_size=%d", dim_head, x));
......@@ -1043,15 +1071,15 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr;
// (transA, transB, compute_bias) = (false, true, false)
auto qkv_compute = AttnMatMul<T>(dev_ctx, false, true, bsz_seq, output_size,
input_size, compute_bias);
auto qkv_compute = AttnMatMul<T>(
dev_ctx, false, true, bsz_seq, output_size, input_size, compute_bias);
Tensor qkv_out;
auto *qkv_out_data =
qkv_out.mutable_data<T>({bsz, seq_len, 3, num_head, dim_head}, place);
// 3. fmha
AttnDropoutParam attn_param(true, "upscale_in_train", 0.0, true, true, 0,
nullptr);
AttnDropoutParam attn_param(
true, "upscale_in_train", 0.0, true, true, 0, nullptr);
auto fmha_compute =
FMHARef<T>(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param);
auto *src_mask = ctx.Input<Tensor>("SrcMask");
......@@ -1061,17 +1089,20 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto out_seq_len = seq_len;
if (time_step) {
PADDLE_ENFORCE_EQ(time_step->place(), platform::CPUPlace(),
PADDLE_ENFORCE_EQ(time_step->place(),
platform::CPUPlace(),
platform::errors::PreconditionNotMet(
"The place of input(TimeStep) must be CPUPlace."));
// cache_seq_len
int time_step_value = time_step->data<int>()[0];
PADDLE_ENFORCE_GT(time_step_value, 0,
PADDLE_ENFORCE_GT(time_step_value,
0,
platform::errors::PreconditionNotMet(
"The value of time_step must > 0, but now is %d",
time_step_value));
PADDLE_ENFORCE_EQ(
seq_len, 1,
seq_len,
1,
platform::errors::PreconditionNotMet(
"In decode stage, the seq_len of input must be 1, but now is %d",
seq_len));
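The decode (generation) path is gated on the optional TimeStep input; the checks above require it to live on the CPU, hold a positive step count, and be paired with a length-1 query. A hedged sketch of preparing such an input from Python, mirroring the unit test at the bottom of this diff; make_decode_time_step is a hypothetical helper name.

import paddle

def make_decode_time_step(cache_length):
    # TimeStep must be an int32 tensor on CPUPlace (its value is read on the
    # host), must be > 0, and the decode-stage query must have seq_len == 1.
    assert cache_length > 0
    return paddle.to_tensor([cache_length], dtype='int32', place=paddle.CPUPlace())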
......@@ -1107,8 +1138,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto out_linear_biases = ctx.MultiInput<Tensor>("OutLinearBias");
int ring_id = ctx.Attr<int>("ring_id");
// (transA, transB, compute_bias) = (false, false, false)
auto out_linear_compute = AttnMatMul<T>(dev_ctx, false, false, bsz_seq,
dim_embed, hidden_size, false);
auto out_linear_compute = AttnMatMul<T>(
dev_ctx, false, false, bsz_seq, dim_embed, hidden_size, false);
// 5. ln(residual + bias)
DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0);
......@@ -1117,9 +1148,12 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto ffn_ln_scales = ctx.MultiInput<Tensor>("FFNLnScale");
auto ffn_ln_biases = ctx.MultiInput<Tensor>("FFNLnBias");
Tensor bias_dropout_residual_out, dropout_mask_out;
auto *bias_dropout_residual_out_data =
bias_dropout_residual_out.mutable_data<T>({bsz, seq_len, dim_embed},
place);
T *bias_dropout_residual_out_data = nullptr;
if (pre_layer_norm) {
bias_dropout_residual_out_data =
bias_dropout_residual_out.mutable_data<T>({bsz, seq_len, dim_embed},
place);
}
auto *dropout_mask_out_data = dropout_mask_out.mutable_data<uint8_t>(
{bsz, seq_len, dim_embed}, place);
......@@ -1129,8 +1163,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto ffn1_weight_dim = ffn1_weights[0]->dims();
int dim_ffn = ffn1_weight_dim[1];
auto ffn1_linear_compute = AttnMatMul<T>(dev_ctx, false, false, bsz_seq,
dim_ffn, dim_embed, false);
auto ffn1_linear_compute = AttnMatMul<T>(
dev_ctx, false, false, bsz_seq, dim_ffn, dim_embed, false);
Tensor ffn1_out;
auto *ffn1_out_data = ffn1_out.mutable_data<T>({bsz_seq, dim_ffn}, place);
......@@ -1147,8 +1181,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
// 8. ffn2 matmul
auto ffn2_weights = ctx.MultiInput<Tensor>("FFN2Weight");
auto ffn2_biases = ctx.MultiInput<Tensor>("FFN2Bias");
auto ffn2_linear_compute = AttnMatMul<T>(dev_ctx, false, false, bsz_seq,
dim_embed, dim_ffn, false);
auto ffn2_linear_compute = AttnMatMul<T>(
dev_ctx, false, false, bsz_seq, dim_embed, dim_ffn, false);
// 9. ffn2 residual bias
DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0);
......@@ -1171,14 +1205,19 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
// step1: buf1 --> buf0
// step2: buf0 --> buf1
int layers = qkv_weights.size();
if (layers & 1) {
// odd, set buf1 as out
if (pre_layer_norm) {
if (layers & 1) {
// odd, set buf1 as out
buf0 = &tmp_out;
buf1 = out;
} else {
// even, set buf0 as out
buf0 = out;
buf1 = &tmp_out;
}
} else {
buf0 = &tmp_out;
buf1 = out;
} else {
// even, set buf0 as out
buf0 = out;
buf1 = &tmp_out;
}
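The two temporaries buf0/buf1 are ping-ponged so that the last layer always writes into the op's Out tensor. A Python sketch of the selection logic above, including the post-LN special case handled later in the loop where the first layer's QKV matmul reads the raw input; pick_buffers is a hypothetical name, the return values mirror the kernel's (buf0, buf1).

def pick_buffers(layers, pre_layer_norm, out, tmp_out):
    if pre_layer_norm:
        # pre-LN swaps buf0/buf1 after every layer, so the parity of the layer
        # count decides which buffer must start out as the real output tensor.
        if layers % 2 == 1:          # odd: buf1 ends up holding the output
            return tmp_out, out
        return out, tmp_out          # even: buf0 ends up holding the output
    # post-LN never swaps: every layer leaves its normalized result in buf1
    # (the Out tensor), and layer 0 reads input_x directly for its QKV matmul.
    return tmp_out, out

buf0, buf1 = pick_buffers(layers=3, pre_layer_norm=False, out="Out", tmp_out="tmp_out")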
for (int i = 0; i < layers; ++i) {
......@@ -1187,11 +1226,12 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto *ln_scale_data = ln_scales[i]->data<U>();
auto *ln_bias_data = ln_biases[i]->data<U>();
// TODO(wangxi): can remove mean var in inference
ln_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data,
buf1->data<T>(), ln_mean_data, ln_var_data);
} else if (!pre_layer_norm) {
PADDLE_THROW(platform::errors::Unimplemented(
"Unimplemented post_layer_norm for now."));
ln_compute.ComputeForward(x_data,
ln_scale_data,
ln_bias_data,
buf1->data<T>(),
ln_mean_data,
ln_var_data);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step1";
......@@ -1201,8 +1241,13 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
const Tensor *qkv_bias = qkv_biases.size() > 0 ? qkv_biases[i] : nullptr;
// NOTE: in decoder stage, bias is fused in fmha
const Tensor *bias = time_step ? nullptr : qkv_bias;
qkv_compute.ComputeForward(qkv_weights[i], buf1, bias, &qkv_out,
&qkv_out);
if (!pre_layer_norm && i == 0) {
qkv_compute.ComputeForward(
qkv_weights[i], input_x, bias, &qkv_out, &qkv_out);
} else {
qkv_compute.ComputeForward(
qkv_weights[i], buf1, bias, &qkv_out, &qkv_out);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step2";
#endif
......@@ -1214,15 +1259,32 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
if (time_step) { // generation decoder stage
// [2, batch_size, num_head, max_seq_len, head_size]
int max_seq_len = cache_kv->dims()[3];
fmha<T>(dev_ctx, qkv_out, *qkv_bias, *src_mask, cache_kv_out, &fmha_out,
bsz, max_seq_len, num_head, dim_head, time_step->data<int>()[0],
fmha<T>(dev_ctx,
qkv_out,
*qkv_bias,
*src_mask,
cache_kv_out,
&fmha_out,
bsz,
max_seq_len,
num_head,
dim_head,
time_step->data<int>()[0],
1. / sqrt(dim_head));
} else if (cache_kv_out) { // generation context stage
// TODO(wangxi): can remove dropout in inference
fmha_compute.ComputeForward(
qkv_out, nullptr, src_mask, &transpose_out_2, nullptr, &qk_out,
&src_mask_out, &softmax_out, &attn_dropout_mask_out,
&attn_dropout_out, &qktv_out, &fmha_out);
fmha_compute.ComputeForward(qkv_out,
nullptr,
src_mask,
&transpose_out_2,
nullptr,
&qk_out,
&src_mask_out,
&softmax_out,
&attn_dropout_mask_out,
&attn_dropout_out,
&qktv_out,
&fmha_out);
// [3, bsz, num_head, seq_len, head_dim]
T *qkv_data = transpose_out_2_data;
int64_t q_size = bsz * seq_len * num_head * dim_head;
......@@ -1239,23 +1301,45 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
T *cache_k_ptr = cache_kv_data;
T *cache_v_ptr = cache_kv_data + cache_k_size;
write_cache_kv<T>(dev_ctx, cache_k_ptr, cache_v_ptr, k_ptr, v_ptr, bsz,
num_head, seq_len, max_seq_len, dim_head);
write_cache_kv<T>(dev_ctx,
cache_k_ptr,
cache_v_ptr,
k_ptr,
v_ptr,
bsz,
num_head,
seq_len,
max_seq_len,
dim_head);
} else { // not generation
// TODO(wangxi): can remove dropout in inference
fmha_compute.ComputeForward(
qkv_out, cache_kv, src_mask, &transpose_out_2, cache_kv_out,
&qk_out, &src_mask_out, &softmax_out, &attn_dropout_mask_out,
&attn_dropout_out, &qktv_out, &fmha_out);
fmha_compute.ComputeForward(qkv_out,
cache_kv,
src_mask,
&transpose_out_2,
cache_kv_out,
&qk_out,
&src_mask_out,
&softmax_out,
&attn_dropout_mask_out,
&attn_dropout_out,
&qktv_out,
&fmha_out);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step3";
#endif
// step4. out_linear
out_linear_compute.ComputeForward(out_linear_weights[i], &fmha_out,
nullptr, buf1, nullptr);
AllReduce<T>(*buf1, ring_id, dev_ctx);
if (pre_layer_norm) {
out_linear_compute.ComputeForward(
out_linear_weights[i], &fmha_out, nullptr, buf1, nullptr);
AllReduce<T>(*buf1, ring_id, dev_ctx);
} else {
out_linear_compute.ComputeForward(
out_linear_weights[i], &fmha_out, nullptr, buf0, nullptr);
AllReduce<T>(*buf0, ring_id, dev_ctx);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step4";
#endif
......@@ -1268,39 +1352,75 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
// inplace
fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
dev_ctx, buf1->data<T>(), x_data, out_linear_bias_data,
ln_scale_data, ln_bias_data, bias_dropout_residual_out_data,
dropout_mask_out_data, buf1->data<T>(), ln_mean_data, ln_var_data);
dev_ctx,
buf1->data<T>(),
x_data,
out_linear_bias_data,
ln_scale_data,
ln_bias_data,
bias_dropout_residual_out_data,
dropout_mask_out_data,
buf1->data<T>(),
ln_mean_data,
ln_var_data);
} else {
auto *ln_scale_data = ln_scales[i]->data<U>();
auto *ln_bias_data = ln_biases[i]->data<U>();
auto *out_linear_bias_data = out_linear_biases[i]->data<T>();
auto *residual_data = (i == 0 ? x_data : buf1->data<T>());
fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
dev_ctx,
buf0->data<T>(),
residual_data,
out_linear_bias_data,
ln_scale_data,
ln_bias_data,
buf0->data<T>(),
dropout_mask_out_data,
buf1->data<T>(),
ln_mean_data,
ln_var_data);
}
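In the new post-LN branch the attention output (held in buf0 after the out_linear matmul) is combined with the residual and normalized in one fused call: buf0 receives the un-normalized residual sum and buf1 the LayerNorm result that feeds the FFN. A Python sketch of that dataflow, assuming LayernormResidualDropoutBias computes residual + dropout(src + bias) followed by LayerNorm; dropout is omitted because the op runs with dropout rate 0 here, and post_ln_attn_epilogue is a hypothetical name.

import numpy as np

def layer_norm(x, scale, bias, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps) * scale + bias

def post_ln_attn_epilogue(attn_out, residual, out_linear_bias, ln_scale, ln_bias):
    # buf0 <- residual + (attn_out + bias)   (the bias_dropout_residual_out slot)
    buf0 = residual + attn_out + out_linear_bias
    # buf1 <- LayerNorm(buf0)                (input to the FFN sublayer)
    buf1 = layer_norm(buf0, ln_scale, ln_bias)
    return buf0, buf1

The FFN sublayer later in the loop mirrors this pattern: ffn2's result lands in buf0, the residual is buf1, and the second fused call (using ffn_ln_scales/ffn_ln_biases) leaves the layer's final output in buf1, which the next iteration consumes without the pre-LN buffer swap.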
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step5";
#endif
// step6. ffn matmul1
ffn1_linear_compute.ComputeForward(ffn1_weights[i], buf1, nullptr,
&ffn1_out, nullptr);
ffn1_linear_compute.ComputeForward(
ffn1_weights[i], buf1, nullptr, &ffn1_out, nullptr);
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step6";
#endif
// step7. act bias
// TODO(wangxi): remove dropout mask in inference
fused_act_dropout_helper.DropoutActBias(
dev_ctx, ffn1_out_data, ffn1_biases[i]->data<T>(), "gelu",
ffn1_dropout_out_data, ffn1_dropout_mask_data);
fused_act_dropout_helper.DropoutActBias(dev_ctx,
ffn1_out_data,
ffn1_biases[i]->data<T>(),
"gelu",
ffn1_dropout_out_data,
ffn1_dropout_mask_data);
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step7";
#endif
// step8. ffn matmul2
ffn2_linear_compute.ComputeForward(ffn2_weights[i], &ffn1_dropout_out,
nullptr, buf1, nullptr);
if (pre_layer_norm) {
ffn2_linear_compute.ComputeForward(
ffn2_weights[i], &ffn1_dropout_out, nullptr, buf1, nullptr);
} else {
ffn2_linear_compute.ComputeForward(
ffn2_weights[i], &ffn1_dropout_out, nullptr, buf0, nullptr);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step8.0";
#endif
AllReduce<T>(*buf1, ring_id, dev_ctx);
if (pre_layer_norm) {
AllReduce<T>(*buf1, ring_id, dev_ctx);
} else {
AllReduce<T>(*buf0, ring_id, dev_ctx);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step8.1";
#endif
......@@ -1312,23 +1432,49 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto *ln_scale_data = ln_scales[i + 1]->data<U>();
auto *ln_bias_data = ln_biases[i + 1]->data<U>();
ffn2_fused_dropout_helper.LayernormResidualDropoutBias(
dev_ctx, buf1->data<T>(), bias_dropout_residual_out_data,
ffn2_biases[i]->data<T>(), ln_scale_data, ln_bias_data,
buf1->data<T>(), dropout_mask_out_data, buf0->data<T>(),
ln_mean_data, ln_var_data);
dev_ctx,
buf1->data<T>(),
bias_dropout_residual_out_data,
ffn2_biases[i]->data<T>(),
ln_scale_data,
ln_bias_data,
buf1->data<T>(),
dropout_mask_out_data,
buf0->data<T>(),
ln_mean_data,
ln_var_data);
} else {
ffn2_fused_dropout_helper.ResidualDropoutBias(
dev_ctx, buf1->data<T>(), bias_dropout_residual_out_data,
ffn2_biases[i]->data<T>(), buf1->data<T>(),
dev_ctx,
buf1->data<T>(),
bias_dropout_residual_out_data,
ffn2_biases[i]->data<T>(),
buf1->data<T>(),
dropout_mask_out_data);
}
} else {
auto *ln_scale_data = ffn_ln_scales[i]->data<U>();
auto *ln_bias_data = ffn_ln_biases[i]->data<U>();
ffn2_fused_dropout_helper.LayernormResidualDropoutBias(
dev_ctx,
buf0->data<T>(),
buf1->data<T>(),
ffn2_biases[i]->data<T>(),
ln_scale_data,
ln_bias_data,
buf0->data<T>(),
dropout_mask_out_data,
buf1->data<T>(),
ln_mean_data,
ln_var_data);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step9";
#endif
x_data = buf1->data<T>();
std::swap(buf0, buf1);
if (pre_layer_norm) {
x_data = buf1->data<T>();
std::swap(buf0, buf1);
}
}
}
};
......
......@@ -39,6 +39,7 @@ default_main_program().random_seed = 42
class TestFusedMultiTransformerOp(OpTest):
def setUp(self):
self.config()
self.generate_input_data()
......@@ -61,39 +62,33 @@ class TestFusedMultiTransformerOp(OpTest):
bias_attr = paddle.fluid.ParamAttr(
initializer=paddle.fluid.initializer.Constant(value=0.0005))
self.q_proj = Linear(
self.embed_dim,
self.embed_dim,
self.weight_attr,
bias_attr=bias_attr)
self.q_proj = Linear(self.embed_dim,
self.embed_dim,
self.weight_attr,
bias_attr=bias_attr)
#bias_attr=self.bias_attr)
self.k_proj = Linear(
self.kdim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
self.v_proj = Linear(
self.vdim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
self.out_proj = Linear(
self.embed_dim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
self.ffn1_proj = Linear(
self.embed_dim,
4 * self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
self.ffn2_proj = Linear(
4 * self.embed_dim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
self.k_proj = Linear(self.kdim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
self.v_proj = Linear(self.vdim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
self.out_proj = Linear(self.embed_dim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
self.ffn1_proj = Linear(self.embed_dim,
4 * self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
self.ffn2_proj = Linear(4 * self.embed_dim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
paddle.set_default_dtype(np.float32)
self.norm = LayerNorm(self.embed_dim)
......@@ -228,8 +223,10 @@ class TestFusedMultiTransformerOp(OpTest):
# [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
# --> [B, n_head, seq_len, out_seq_len]
qk_out = layers.matmul(
x=q_out, y=k_out, transpose_y=True, alpha=self.head_dim**-0.5)
qk_out = layers.matmul(x=q_out,
y=k_out,
transpose_y=True,
alpha=self.head_dim**-0.5)
if self.debug:
print('qk out is')
......@@ -249,11 +246,10 @@ class TestFusedMultiTransformerOp(OpTest):
print('softmax out is')
print(softmax_out[0][0][0])
if self.dropout_prob:
dropout_out = F.dropout(
softmax_out,
self.dropout_prob,
training=self.training,
mode="upscale_in_train")
dropout_out = F.dropout(softmax_out,
self.dropout_prob,
training=self.training,
mode="upscale_in_train")
# [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim]
# --> [B, n_head, seq_len, head_dim]
qktv_out = tensor.matmul(dropout_out, v_out)
......@@ -265,8 +261,7 @@ class TestFusedMultiTransformerOp(OpTest):
print('fmha out is')
print(fmha_out[0][0][0])
out_linear_in = tensor.reshape(
x=fmha_out,
shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]])
x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]])
out = self.out_proj(out_linear_in)
residual_out = residual + self.dropout(out)
......@@ -296,44 +291,44 @@ class TestFusedMultiTransformerOp(OpTest):
def GetFusedMultiTransformerOut(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
q_proj_weight = paddle.to_tensor(
self.q_proj.weight, stop_gradient=False)
k_proj_weight = paddle.to_tensor(
self.k_proj.weight, stop_gradient=False)
v_proj_weight = paddle.to_tensor(
self.v_proj.weight, stop_gradient=False)
out_linear_weight = paddle.to_tensor(
self.out_proj.weight, stop_gradient=False)
ffn1_weight = paddle.to_tensor(
self.ffn1_proj.weight, stop_gradient=False)
ffn2_weight = paddle.to_tensor(
self.ffn2_proj.weight, stop_gradient=False)
q_proj_weight = paddle.to_tensor(self.q_proj.weight,
stop_gradient=False)
k_proj_weight = paddle.to_tensor(self.k_proj.weight,
stop_gradient=False)
v_proj_weight = paddle.to_tensor(self.v_proj.weight,
stop_gradient=False)
out_linear_weight = paddle.to_tensor(self.out_proj.weight,
stop_gradient=False)
ffn1_weight = paddle.to_tensor(self.ffn1_proj.weight,
stop_gradient=False)
ffn2_weight = paddle.to_tensor(self.ffn2_proj.weight,
stop_gradient=False)
if self.bias_attr is False:
qkv_bias_tensor = None
out_linear_bias = None
else:
q_proj_bias = paddle.to_tensor(
self.q_proj.bias, stop_gradient=False)
k_proj_bias = paddle.to_tensor(
self.k_proj.bias, stop_gradient=False)
v_proj_bias = paddle.to_tensor(
self.v_proj.bias, stop_gradient=False)
q_proj_bias = paddle.to_tensor(self.q_proj.bias,
stop_gradient=False)
k_proj_bias = paddle.to_tensor(self.k_proj.bias,
stop_gradient=False)
v_proj_bias = paddle.to_tensor(self.v_proj.bias,
stop_gradient=False)
qkv_bias = np.concatenate(
(q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy()))
qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False)
out_linear_bias = paddle.to_tensor(
self.out_proj.bias, stop_gradient=False)
ffn1_bias = paddle.to_tensor(
self.ffn1_proj.bias, stop_gradient=False)
ffn2_bias = paddle.to_tensor(
self.ffn2_proj.bias, stop_gradient=False)
out_linear_bias = paddle.to_tensor(self.out_proj.bias,
stop_gradient=False)
ffn1_bias = paddle.to_tensor(self.ffn1_proj.bias,
stop_gradient=False)
ffn2_bias = paddle.to_tensor(self.ffn2_proj.bias,
stop_gradient=False)
ln_scale = paddle.to_tensor(self.norm.weight, stop_gradient=False)
ln_bias = paddle.to_tensor(self.norm.bias, stop_gradient=False)
ffn_ln_scale = paddle.to_tensor(
self.ffn_norm.weight, stop_gradient=False)
ffn_ln_scale = paddle.to_tensor(self.ffn_norm.weight,
stop_gradient=False)
ffn_ln_bias = paddle.to_tensor(self.ffn_norm.bias, stop_gradient=False)
q_proj_weight = q_proj_weight.numpy().transpose((1, 0))
......@@ -351,12 +346,11 @@ class TestFusedMultiTransformerOp(OpTest):
cache_kvs = []
max_seq_length = (self.cache_length + 128) // 128 * 128
cache_kv = np.zeros(
[
2, self.batch_size, self.num_heads, max_seq_length,
self.head_dim
],
dtype=self.x_type)
cache_kv = np.zeros([
2, self.batch_size, self.num_heads, max_seq_length,
self.head_dim
],
dtype=self.x_type)
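The test pre-allocates the KV cache with its sequence capacity rounded up in blocks of 128. A small worked example of that sizing; the numbers are illustrative, not the test's defaults.

import numpy as np

cache_length, batch_size, num_heads, head_dim = 64, 2, 4, 32
max_seq_length = (cache_length + 128) // 128 * 128    # 64 -> 128
cache_kv = np.zeros([2, batch_size, num_heads, max_seq_length, head_dim],
                    dtype=np.float16)
print(cache_kv.shape)    # (2, 2, 4, 128, 32)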
elems = 4
if self.x_type is np.float16:
......@@ -384,8 +378,9 @@ class TestFusedMultiTransformerOp(OpTest):
assert self.query_length == self.cache_length
cache_kv[:] = 0
else:
time_step = paddle.to_tensor(
[self.cache_length], dtype='int32', place=paddle.CPUPlace())
time_step = paddle.to_tensor([self.cache_length],
dtype='int32',
place=paddle.CPUPlace())
if self.has_attn_mask:
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
else:
......@@ -417,31 +412,29 @@ class TestFusedMultiTransformerOp(OpTest):
ffn_ln_scales.append(ffn_ln_scale)
ffn_ln_biases.append(ffn_ln_bias)
if self.has_cache_kv:
cache_kvs.append(
paddle.to_tensor(
cache_kv, stop_gradient=False))
final_out = fused_multi_transformer(
x,
ln_scales,
ln_biases,
qkv_weights,
qkv_biases,
out_weights,
out_biases,
ffn_ln_scales,
ffn_ln_biases,
ffn1_weights,
ffn1_biases,
ffn2_weights,
ffn2_biases,
pre_layer_norm=self.pre_layer_norm,
epsilon=epsilon,
cache_kvs=cache_kvs,
time_step=time_step,
attn_mask=attn_mask,
dropout_rate=self.dropout_prob,
training=self.training)
cache_kvs.append(paddle.to_tensor(cache_kv,
stop_gradient=False))
final_out = fused_multi_transformer(x,
ln_scales,
ln_biases,
qkv_weights,
qkv_biases,
out_weights,
out_biases,
ffn_ln_scales,
ffn_ln_biases,
ffn1_weights,
ffn1_biases,
ffn2_weights,
ffn2_biases,
pre_layer_norm=self.pre_layer_norm,
epsilon=epsilon,
cache_kvs=cache_kvs,
time_step=time_step,
attn_mask=attn_mask,
dropout_rate=self.dropout_prob,
training=self.training)
if self.has_cache_kv:
return final_out[0], final_out[1]
......@@ -463,9 +456,9 @@ class TestFusedMultiTransformerOp(OpTest):
if self.debug:
print("cache_k out timestep=128")
print(cache_kv_out[0].reshape([
2, bsz, num_head, v_elems, max_seq_len, elems
])[0, 0, 0, :, self.cache_length, :])
print(cache_kv_out[0].reshape(
[2, bsz, num_head, v_elems, max_seq_len,
elems])[0, 0, 0, :, self.cache_length, :])
print("cache_v out timestep=128")
print(cache_kv_out[0][1, 0, 0, self.cache_length, :])
......@@ -486,18 +479,25 @@ class TestFusedMultiTransformerOp(OpTest):
cache_v = cache_kv_out[i][1, :, :, :self.cache_length, :]
np.testing.assert_allclose(
cache_k_ref, cache_k, rtol=self.rtol, atol=self.atol)
np.testing.assert_allclose(
cache_v_ref, cache_v, rtol=self.rtol, atol=self.atol)
np.testing.assert_allclose(cache_k_ref,
cache_k,
rtol=self.rtol,
atol=self.atol)
np.testing.assert_allclose(cache_v_ref,
cache_v,
rtol=self.rtol,
atol=self.atol)
if i == 0:
break
np.testing.assert_allclose(
final_out_ref, final_out, rtol=self.rtol, atol=self.atol)
np.testing.assert_allclose(final_out_ref,
final_out,
rtol=self.rtol,
atol=self.atol)
class TestFusedMultiTransformerOpFp16(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.x_type = np.float16
......@@ -505,6 +505,7 @@ class TestFusedMultiTransformerOpFp16(TestFusedMultiTransformerOp):
class TestFusedMultiTransformerOpCacheKV(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
......@@ -514,6 +515,7 @@ class TestFusedMultiTransformerOpCacheKV(TestFusedMultiTransformerOp):
class TestFusedMultiTransformerOpCacheKVFp16(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
......@@ -523,6 +525,7 @@ class TestFusedMultiTransformerOpCacheKVFp16(TestFusedMultiTransformerOp):
class TestFusedMultiTransformerOpGenCacheKV(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
......@@ -530,12 +533,68 @@ class TestFusedMultiTransformerOpGenCacheKV(TestFusedMultiTransformerOp):
class TestFusedMultiTransformerOpGenCacheKVFp16(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.gen_cache_kv = True
self.x_type = np.float16
self.layers = 3 # odd layers
class TestFusedMultiTransformerOpPostLayerNormFp16(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.x_type = np.float16
self.layers = 3 # odd layers
self.pre_layer_norm = False
class TestFusedMultiTransformerOpCacheKVPostLayerNorm(
TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.query_length = 1
self.key_length, self.value_length = 1, 1
self.layers = 3 # odd layers
self.pre_layer_norm = False
class TestFusedMultiTransformerOpCacheKVPostLayerNormFp16(
TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.query_length = 1
self.key_length, self.value_length = 1, 1
self.x_type = np.float16
self.pre_layer_norm = False
class TestFusedMultiTransformerOpGenCacheKVPostLayerNorm(
TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.gen_cache_kv = True
self.pre_layer_norm = False
class TestFusedMultiTransformerOpGenCacheKVPostLayerNormFp16(
TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.gen_cache_kv = True
self.x_type = np.float16
self.layers = 3 # odd layers
self.pre_layer_norm = False
if __name__ == "__main__":
......