未验证 提交 643c94e4 编写于 作者: C carryyu 提交者: GitHub

enhance fused_multi_transformer_op(post_layer_norm) (#44789)

* add fused_multi_transformer post_layer_norm

* add test post_layer_norm
上级 bdce552b
......@@ -1279,9 +1279,12 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto ffn_ln_scales = ctx.MultiInput<Tensor>("FFNLnScale");
auto ffn_ln_biases = ctx.MultiInput<Tensor>("FFNLnBias");
Tensor bias_dropout_residual_out, dropout_mask_out;
auto *bias_dropout_residual_out_data =
bias_dropout_residual_out.mutable_data<T>({bsz, seq_len, dim_embed},
place);
T *bias_dropout_residual_out_data = nullptr;
if (pre_layer_norm) {
bias_dropout_residual_out_data =
bias_dropout_residual_out.mutable_data<T>({bsz, seq_len, dim_embed},
place);
}
auto *dropout_mask_out_data = dropout_mask_out.mutable_data<uint8_t>(
{bsz, seq_len, dim_embed}, place);
......@@ -1333,14 +1336,19 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
// step1: buf1 --> buf0
// step2: buf0 --> buf1
int layers = qkv_weights.size();
if (layers & 1) {
// odd, set buf1 as out
if (pre_layer_norm) {
if (layers & 1) {
// odd, set buf1 as out
buf0 = &tmp_out;
buf1 = out;
} else {
// even, set buf0 as out
buf0 = out;
buf1 = &tmp_out;
}
} else {
buf0 = &tmp_out;
buf1 = out;
} else {
// even, set buf0 as out
buf0 = out;
buf1 = &tmp_out;
}
for (int i = 0; i < layers; ++i) {
......@@ -1355,9 +1363,6 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
buf1->data<T>(),
ln_mean_data,
ln_var_data);
} else if (!pre_layer_norm) {
PADDLE_THROW(platform::errors::Unimplemented(
"Unimplemented post_layer_norm for now."));
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step1";
......@@ -1367,8 +1372,13 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
const Tensor *qkv_bias = qkv_biases.size() > 0 ? qkv_biases[i] : nullptr;
// NOTE: in decoder stage, bias is fused in fmha
const Tensor *bias = time_step ? nullptr : qkv_bias;
qkv_compute.ComputeForward(
qkv_weights[i], buf1, bias, &qkv_out, &qkv_out);
if (!pre_layer_norm && i == 0) {
qkv_compute.ComputeForward(
qkv_weights[i], input_x, bias, &qkv_out, &qkv_out);
} else {
qkv_compute.ComputeForward(
qkv_weights[i], buf1, bias, &qkv_out, &qkv_out);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step2";
#endif
......@@ -1451,10 +1461,15 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
VLOG(0) << "step3";
#endif
// step4. out_linear
out_linear_compute.ComputeForward(
out_linear_weights[i], &fmha_out, nullptr, buf1, nullptr);
AllReduce<T>(*buf1, ring_id, dev_ctx);
if (pre_layer_norm) {
out_linear_compute.ComputeForward(
out_linear_weights[i], &fmha_out, nullptr, buf1, nullptr);
AllReduce<T>(*buf1, ring_id, dev_ctx);
} else {
out_linear_compute.ComputeForward(
out_linear_weights[i], &fmha_out, nullptr, buf0, nullptr);
AllReduce<T>(*buf0, ring_id, dev_ctx);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step4";
#endif
......@@ -1479,6 +1494,22 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
ln_mean_data,
ln_var_data);
} else {
auto *ln_scale_data = ln_scales[i]->data<U>();
auto *ln_bias_data = ln_biases[i]->data<U>();
auto *out_linear_bias_data = out_linear_biases[i]->data<T>();
auto *residual_data = (i == 0 ? x_data : buf1->data<T>());
fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
dev_ctx,
buf0->data<T>(),
residual_data,
out_linear_bias_data,
ln_scale_data,
ln_bias_data,
buf0->data<T>(),
dropout_mask_out_data,
buf1->data<T>(),
ln_mean_data,
ln_var_data);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step5";
......@@ -1504,13 +1535,22 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
#endif
// step8. ffn matmul2
ffn2_linear_compute.ComputeForward(
ffn2_weights[i], &ffn1_dropout_out, nullptr, buf1, nullptr);
if (pre_layer_norm) {
ffn2_linear_compute.ComputeForward(
ffn2_weights[i], &ffn1_dropout_out, nullptr, buf1, nullptr);
} else {
ffn2_linear_compute.ComputeForward(
ffn2_weights[i], &ffn1_dropout_out, nullptr, buf0, nullptr);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step8.0";
#endif
AllReduce<T>(*buf1, ring_id, dev_ctx);
if (pre_layer_norm) {
AllReduce<T>(*buf1, ring_id, dev_ctx);
} else {
AllReduce<T>(*buf0, ring_id, dev_ctx);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step8.1";
#endif
......@@ -1543,12 +1583,28 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
dropout_mask_out_data);
}
} else {
auto *ln_scale_data = ffn_ln_scales[i]->data<U>();
auto *ln_bias_data = ffn_ln_biases[i]->data<U>();
ffn2_fused_dropout_helper.LayernormResidualDropoutBias(
dev_ctx,
buf0->data<T>(),
buf1->data<T>(),
ffn2_biases[i]->data<T>(),
ln_scale_data,
ln_bias_data,
buf0->data<T>(),
dropout_mask_out_data,
buf1->data<T>(),
ln_mean_data,
ln_var_data);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step9";
#endif
x_data = buf1->data<T>();
std::swap(buf0, buf1);
if (pre_layer_norm) {
x_data = buf1->data<T>();
std::swap(buf0, buf1);
}
}
}
};
......
......@@ -548,5 +548,60 @@ class TestFusedMultiTransformerOpGenCacheKVFp16(TestFusedMultiTransformerOp):
self.layers = 3 # odd layers
class TestFusedMultiTransformerOpPostLayerNormFp16(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.x_type = np.float16
self.layers = 3 # odd layers
self.pre_layer_norm = False
class TestFusedMultiTransformerOpCacheKVPostLayerNorm(
TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.query_length = 1
self.key_length, self.value_length = 1, 1
self.layers = 3 # odd layers
self.pre_layer_norm = False
class TestFusedMultiTransformerOpCacheKVPostLayerNormFp16(
TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.query_length = 1
self.key_length, self.value_length = 1, 1
self.x_type = np.float16
self.pre_layer_norm = False
class TestFusedMultiTransformerOpGenCacheKVPostLayerNorm(
TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.gen_cache_kv = True
self.pre_layer_norm = False
class TestFusedMultiTransformerOpGenCacheKVPostLayerNormFp16(
TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.gen_cache_kv = True
self.x_type = np.float16
self.layers = 3 # odd layers
self.pre_layer_norm = False
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册