未验证 提交 8a717a3e 编写于 作者: MarDino's avatar MarDino 提交者: GitHub

Support more activation in fused multi transformer (#48371)

* add activation support
* fix cublasLt bug
* remove useless code and fix test random range
上级 e9ca7600
......@@ -270,7 +270,17 @@ class FusedMultiTransformerOpOpMaker
"dropout_implementation can only be downgrade_in_infer or "
"upscale_in_train"));
});
AddAttr<std::string>("act_method", "act_method").SetDefault("gelu");
AddAttr<std::string>("act_method", "act_method")
.SetDefault("gelu")
.AddCustomChecker([](const std::string &act_type) {
PADDLE_ENFORCE_EQ(
act_type == "gelu" || act_type == "relu" || act_type == "none",
true,
platform::errors::InvalidArgument(
"Only support `gelu`, `relu`, `none` activation in "
"FusedMultiTransformer. "));
});
AddAttr<bool>(
"trans_qkvw",
"Whether the weights of qkv should be transposed. If true,"
......
......@@ -31,6 +31,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
int seq_len = input_x_dims[1];
int dim_embed = input_x_dims[2];
int bsz_seq = bsz * seq_len;
const std::string act_method = ctx.Attr<std::string>("act_method");
// 1. layer norm
const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
......@@ -61,7 +62,6 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr;
// (transA, transB, compute_bias) = (false, trans_qkvw, false)
// Since we fused QKVBias into QKVBiasAddTransposeSplit kernel, here we set
// compute_bias as false.
auto qkv_compute = AttnMatMul<T>(dev_ctx,
......@@ -191,24 +191,23 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto *dropout_mask_out_data = dev_ctx.Alloc<uint8_t>(
&dropout_mask_out, dropout_mask_out.numel() * sizeof(uint8_t));
// 6. ffn1 matmul + bias_add + gelu.
// 6. ffn1 matmul + act + bias
auto ffn1_weights = ctx.MultiInput<phi::DenseTensor>("FFN1Weight");
auto ffn1_biases = ctx.MultiInput<phi::DenseTensor>("FFN1Bias");
auto ffn1_weight_dim = ffn1_weights[0]->dims();
int dim_ffn = ffn1_weight_dim[1];
auto ffn1_cublas_linear = CublasFusedMLP<T>(dev_ctx);
const phi::DDim ffn1_input_shape({bsz_seq, dim_embed});
ffn1_cublas_linear.Setup(ffn1_input_shape, ffn1_weight_dim, false, false);
Tensor ffn1_out;
ffn1_out.Resize({{bsz_seq, dim_ffn}});
auto *ffn1_out_data =
dev_ctx.Alloc<T>(&ffn1_out, ffn1_out.numel() * sizeof(T));
auto ffn1_linear_bias_gelu = CublasFusedMLP<T>(dev_ctx);
const phi::DDim ffn1_input_shape({bsz_seq, dim_ffn});
ffn1_linear_bias_gelu.Setup(
ffn1_input_shape, ffn1_weight_dim, false, false);
// 8. ffn2 matmul + bias_add + residual.
// 7. ffn2 matmul + bias + residual.
auto ffn2_weights = ctx.MultiInput<phi::DenseTensor>("FFN2Weight");
auto ffn2_biases = ctx.MultiInput<phi::DenseTensor>("FFN2Bias");
......@@ -216,7 +215,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
ffn2_linear_bias_residual.Setup(
ffn1_out.dims(), ffn2_weights[0]->dims(), false, false);
// 9. ffn2 residual bias
// 8. ffn2 Layernorm
DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0);
FusedDropoutLayerNormHelper<T, uint8_t> ffn2_fused_dropout_helper(
dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon);
......@@ -333,7 +332,6 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
&attn_dropout_out,
&qktv_out,
&fmha_out);
const T *k_ptr = nullptr;
const T *v_ptr = nullptr;
......@@ -450,20 +448,23 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
ln_mean_data,
ln_var_data);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step5";
#endif
// step6. ffn1 matmul + bias_add + gelu.
ffn1_linear_bias_gelu.ComputeForward(
buf1, ffn1_weights[i], ffn1_biases[i], nullptr, &ffn1_out, "gelu");
// step6. ffn matmul1
ffn1_cublas_linear.ComputeForward(buf1,
ffn1_weights[i],
ffn1_biases[i],
nullptr,
&ffn1_out,
act_method);
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step6";
#endif
// step7. ffn2 matmul + bias_add + residual.
// step7. ffn2 matmul
if (pre_layer_norm) {
ffn2_linear_bias_residual.ComputeForward(&ffn1_out,
ffn2_weights[i],
......@@ -477,18 +478,21 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
&ffn1_out, ffn2_weights[i], ffn2_biases[i], buf1, buf0, "none");
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step7";
#endif
if (pre_layer_norm) {
AllReduce<T>(*buf1, ring_id, buf1->numel(), dev_ctx);
} else {
AllReduce<T>(*buf0, ring_id, buf0->numel(), dev_ctx);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step7";
VLOG(0) << "step7.1";
#endif
// step8. layer norm or do nothing(because bias_add + residual has been
// fused into cublasFusedMLP. )
// step8. layer norm or do nothing
// because bias_add + residual has been fused into cublasFusedMLP
if (pre_layer_norm) {
if (i < layers - 1) {
auto *ln_scale_data = ln_scales[i + 1]->data<U>();
......@@ -512,6 +516,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
ln_mean_data,
ln_var_data);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step8";
#endif
......@@ -540,6 +545,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
int seq_len = input_x_dims[1];
int dim_embed = input_x_dims[2];
int bsz_seq = bsz * seq_len;
const std::string act_method = ctx.Attr<std::string>("act_method");
// 1. layer norm
const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
......@@ -570,8 +576,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr;
// (transA, transB, compute_bias) = (false, trans_qkvw, false)
// Since we fused QKVBias into QKVBiasAddTransposeSplit kernel, here we set
// compute_bias as false.
// Since we fused QKVBias into QKVBiasAddTransposeSplit kernel, here we
// set compute_bias as false.
auto qkv_compute = AttnMatMul<T>(dev_ctx,
false,
trans_qkvw,
......@@ -979,7 +985,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
fused_act_dropout_helper.DropoutActBias(dev_ctx,
ffn1_out_data,
ffn1_biases[i]->data<T>(),
"gelu",
act_method,
ffn1_dropout_out_data,
ffn1_dropout_mask_data);
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
......
......@@ -1414,14 +1414,15 @@ class CublasFusedMLP {
public:
// (m, n, k) = bsz_seq, hidden_feature, in_feature
explicit CublasFusedMLP(const phi::GPUContext &dev_ctx) : dev_ctx_(dev_ctx) {
// Set Math Type
cudaDataType_t mat_type = CUDA_R_32F;
cudaDataType_t scale_type = CUDA_R_32F;
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
if (std::is_same<T, paddle::platform::float16>::value) {
mat_type = CUDA_R_16F;
if (FLAGS_gemm_use_half_precision_compute_type) {
// This option default value is true, it tends to result NaN, but get
// better inference speed. you can turn off by using `export
// FLAGS_gemm_use_half_precision_compute_type=0`.
compute_type = CUBLAS_COMPUTE_16F;
scale_type = CUDA_R_16F;
}
......@@ -1435,7 +1436,6 @@ class CublasFusedMLP {
compute_type = CUBLAS_COMPUTE_64F;
}
// Just for init.
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
&operation_desc_, compute_type, scale_type));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
......@@ -1445,7 +1445,6 @@ class CublasFusedMLP {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&out_desc_, mat_type, 1, 1, 1));
}
~CublasFusedMLP() {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescDestroy(operation_desc_));
......@@ -1457,7 +1456,6 @@ class CublasFusedMLP {
platform::dynload::cublasLtMatrixLayoutDestroy(out_desc_));
}
// Change to use tensor's shape.
void Setup(const phi::DDim &x_shape,
const phi::DDim &w_shape,
bool trans_x,
......@@ -1481,39 +1479,34 @@ class CublasFusedMLP {
&cublas_transB,
sizeof(cublas_transB)));
/*
cublas use col major: x(M, K) matmul w(K, N) = out(M, N) equals to w_t(N, K)
* x_t(K, M) = out(N, M)
*/
SetCublasMatrixLayout_(x_desc_, cublas_transA, K, M);
SetCublasMatrixLayout_(w_desc_, cublas_transB, N, K);
SetCublasMatrixLayout_(out_desc_, CUBLAS_OP_N, N, M);
SetCublasMatrixLayout(x_desc_, trans_x, M, K);
SetCublasMatrixLayout(w_desc_, trans_w, K, N);
SetCublasMatrixLayout(out_desc_, false, M, N);
}
void ComputeForward(const phi::DenseTensor *input,
void ComputeForward(const phi::DenseTensor *x,
const phi::DenseTensor *weight,
const phi::DenseTensor *bias,
phi::DenseTensor *residual,
phi::DenseTensor *output,
const std::string &activation) {
// here: (transa, transb): nt, input * weight.
// (M * K) * (K * N)
cublasLtHandle_t lt_handle = dev_ctx_.cublaslt_handle();
size_t workspace_size = static_cast<size_t>(16) * 1024 * 1024;
cudaStream_t stream = dev_ctx_.stream();
memory::allocation::AllocationPtr workspace =
memory::Alloc(dev_ctx_.GetPlace(),
workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
T *out_data = output->data<T>();
const bool add_residual = (residual == nullptr) ? false : true;
const bool add_bias = (bias == nullptr) ? false : true;
const T *bias_data = nullptr;
if (add_bias) {
SetCublasBiasPtr_(bias);
bias_data = bias->data<T>();
}
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc_,
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias_data,
sizeof(bias_data)));
// Set cublasLt epilogue.
cublasLtEpilogue_t epiloque_func = GetEpilogueType_(activation, add_bias);
cublasLtEpilogue_t epiloque_func = GetEpilogueType(activation, add_bias);
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc_,
......@@ -1521,25 +1514,44 @@ class CublasFusedMLP {
&epiloque_func,
sizeof(epiloque_func)));
const auto *x_data = input->data<T>();
const auto *w_data = weight->data<T>();
auto *residual_data =
add_residual ? residual->data<T>() : output->data<T>();
auto *out_data = output->data<T>();
T *residual_data = add_residual ? residual->data<T>() : out_data;
cublasLtHandle_t lt_handle = dev_ctx_.cublaslt_handle();
size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
cudaStream_t stream = dev_ctx_.stream();
memory::allocation::AllocationPtr workspace = memory::Alloc(
dev_ctx_.GetPlace(),
workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx_.stream())));
// if add_residual, we compute result + 1.0 * residual, else result + 0.0 *
// out.
// if add_residual, we compute result + 1.0 * residual,
// else result + 0.0 * out.
double alpha64 = 1.0, beta64 = add_residual ? 1.0 : 0.0;
float alpha32 = 1.0f, beta32 = add_residual ? 1.0f : 0.0f;
half alpha16 = static_cast<half>(1.0),
beta16 =
add_residual ? static_cast<half>(1.0) : static_cast<half>(0.0);
void *alpha = nullptr, *beta = nullptr;
if (std::is_same<T, double>::value) {
alpha = &alpha64;
beta = &beta64;
} else if (std::is_same<T, float>::value) {
alpha = &alpha64;
beta = &beta64;
} else if (std::is_same<T, phi::dtype::float16>::value) {
alpha = &alpha16;
beta = &beta16;
} else {
alpha = &alpha32;
beta = &beta32;
PADDLE_ENFORCE_EQ(true,
false,
platform::errors::InvalidArgument(
"Only support double, float, half data type. "));
}
const auto *x_data = x->data<T>();
const auto *w_data = weight->data<T>();
auto algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle,
operation_desc_,
w_desc_,
......@@ -1567,15 +1579,15 @@ class CublasFusedMLP {
out_desc_,
out_data,
out_desc_,
algo /*algo*/,
workspace->ptr() /*workspace*/,
algo,
workspace->ptr(),
workspace_size,
stream));
}
private:
static cublasLtEpilogue_t GetEpilogueType_(const std::string &activation,
const bool add_bias) {
cublasLtEpilogue_t GetEpilogueType(const std::string &activation,
const bool add_bias) {
if (activation == "relu") {
if (add_bias) {
return CUBLASLT_EPILOGUE_RELU_BIAS;
......@@ -1606,23 +1618,41 @@ class CublasFusedMLP {
}
}
void SetCublasMatrixLayout_(cublasLtMatrixLayout_t layout_desc,
cublasOperation_t cublas_trans,
const size_t cublas_m,
const size_t cublas_n) {
void SetCublasMatrixLayout(cublasLtMatrixLayout_t layout_desc,
const bool transpose,
const uint64_t cublas_row,
const uint64_t cublas_col) {
cudaDataType_t mat_type = CUDA_R_32F;
if (std::is_same<T, paddle::platform::float16>::value) {
mat_type = CUDA_R_16F;
}
if (std::is_same<T, platform::bfloat16>::value) {
mat_type = CUDA_R_16BF;
}
if (std::is_same<T, double>::value) {
mat_type = CUDA_R_64F;
}
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutSetAttribute(
layout_desc,
CUBLASLT_MATRIX_LAYOUT_TYPE,
&mat_type,
sizeof(mat_type)));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutSetAttribute(
layout_desc,
CUBLASLT_MATRIX_LAYOUT_ROWS,
cublas_trans == CUBLAS_OP_N ? &cublas_m : &cublas_n,
sizeof(cublas_m)));
transpose ? &cublas_row : &cublas_col,
sizeof(cublas_row)));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutSetAttribute(
layout_desc,
CUBLASLT_MATRIX_LAYOUT_COLS,
cublas_trans == CUBLAS_OP_N ? &cublas_n : &cublas_m,
sizeof(cublas_m)));
const size_t cublas_ld = cublas_trans == CUBLAS_OP_N ? cublas_m : cublas_n;
transpose ? &cublas_col : &cublas_row,
sizeof(cublas_col)));
int64_t cublas_ld = transpose ? cublas_row : cublas_col;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutSetAttribute(
layout_desc,
......@@ -1631,21 +1661,11 @@ class CublasFusedMLP {
sizeof(cublas_ld)));
}
void SetCublasBiasPtr_(const phi::DenseTensor *bias) {
const T *bias_data = bias->data<T>();
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc_,
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias_data,
sizeof(bias_data)));
}
const phi::GPUContext &dev_ctx_;
cublasLtMatmulDesc_t operation_desc_;
cublasLtMatrixLayout_t x_desc_;
cublasLtMatrixLayout_t w_desc_;
cublasLtMatrixLayout_t out_desc_;
cublasLtMatmulDesc_t operation_desc_ = NULL;
cublasLtMatrixLayout_t x_desc_ = NULL;
cublasLtMatrixLayout_t w_desc_ = NULL;
cublasLtMatrixLayout_t out_desc_ = NULL;
};
#endif // PADDLE_FLUID_OPERATORS_FUSED_FUSED_MULTI_TRANSFORMER_OP_CU_H_
......
......@@ -124,6 +124,7 @@ class TestFusedMultiTransformerOp(OpTest):
self.training = False
self.layers = 4
self.batch_size = 8
self.query_length = 128
self.cache_length = 128
......@@ -144,21 +145,27 @@ class TestFusedMultiTransformerOp(OpTest):
)
def generate_input_data(self):
self.query = np.random.rand(
self.batch_size, self.query_length, self.embed_dim
self.query = np.random.uniform(
-1, 1, (self.batch_size, self.query_length, self.embed_dim)
).astype(self.x_type)
out_seq_len = self.key_length
if self.has_cache_kv:
assert self.training is False, ValueError(
'cache_kv can only used in inference'
)
self.cache_kv = np.random.rand(
2,
self.batch_size,
self.num_heads,
self.cache_length,
self.head_dim,
self.cache_kv = np.random.uniform(
-1,
1,
(
2,
self.batch_size,
self.num_heads,
self.cache_length,
self.head_dim,
),
).astype(self.x_type)
if self.gen_cache_kv:
self.cache_kv[:] = 0
else:
......@@ -168,12 +175,16 @@ class TestFusedMultiTransformerOp(OpTest):
if self.has_pre_cache:
out_seq_len += self.pre_cache_num
self.pre_cache_kv = np.random.rand(
2,
self.batch_size,
self.num_heads,
self.pre_cache_num,
self.head_dim,
self.pre_cache_kv = np.random.uniform(
-1,
1,
(
2,
self.batch_size,
self.num_heads,
self.pre_cache_num,
self.head_dim,
),
).astype(self.x_type)
if self.has_attn_mask:
......@@ -204,8 +215,8 @@ class TestFusedMultiTransformerOp(OpTest):
self.attn_mask = None
self.key, self.value = self.query, self.query
self.dout = np.random.random(
(self.batch_size, self.query_length, self.embed_dim)
self.dout = np.random.uniform(
-1, 1, (self.batch_size, self.query_length, self.embed_dim)
).astype(self.x_type)
def GetBaselineOut(self):
......@@ -544,6 +555,7 @@ class TestFusedMultiTransformerOp(OpTest):
time_step=time_step,
attn_mask=attn_mask,
dropout_rate=self.dropout_prob,
activation=self.act_method,
training=self.training,
)
......@@ -668,6 +680,7 @@ class TestFusedMultiTransformerOp(OpTest):
self.num_heads,
4 * self.embed_dim,
self.dropout_prob,
activation=self.act_method,
normalize_before=self.pre_layer_norm,
ln_scale_attrs=ln_scales_attr,
ln_bias_attrs=ln_biases_attr,
......@@ -797,6 +810,14 @@ class TestFusedMultiTransformerOpFp16(TestFusedMultiTransformerOp):
self.layers = 3 # odd layers
class TestFusedMultiTransformerOpActReluFp16(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.x_type = np.float16
self.act_method = "relu"
self.layers = 3 # odd layers
class TestFusedMultiTransformerOpCacheKV(TestFusedMultiTransformerOp):
def config(self):
super().config()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册