Unverified commit 8a717a3e authored by MarDino, committed by GitHub

Support more activation in fused multi transformer (#48371)

* add activation support
* fix cublasLt bug
* remove useless code and fix test random range
Parent e9ca7600
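With this change, the op's `act_method` attribute (exposed as `activation` on the Python side, see the test diff below) accepts `gelu`, `relu`, or `none`. A minimal usage sketch, assuming the `paddle.incubate.nn.FusedMultiTransformer` layer and a GPU build of Paddle; sizes and argument values are illustrative only, not taken from this commit:

import paddle
from paddle.incubate.nn import FusedMultiTransformer

# Illustrative sizes; after this commit "gelu", "relu", and "none" are accepted,
# any other string is rejected by the op's attribute checker.
layer = FusedMultiTransformer(
    embed_dim=1024,
    num_heads=16,
    dim_feedforward=4096,
    dropout_rate=0.0,
    activation="relu",
    normalize_before=True,
)

x = paddle.rand([2, 128, 1024])  # (batch, seq_len, embed_dim)
out = layer(x)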
@@ -270,7 +270,17 @@ class FusedMultiTransformerOpOpMaker
              "dropout_implementation can only be downgrade_in_infer or "
              "upscale_in_train"));
        });
-    AddAttr<std::string>("act_method", "act_method").SetDefault("gelu");
+    AddAttr<std::string>("act_method", "act_method")
+        .SetDefault("gelu")
+        .AddCustomChecker([](const std::string &act_type) {
+          PADDLE_ENFORCE_EQ(
+              act_type == "gelu" || act_type == "relu" || act_type == "none",
+              true,
+              platform::errors::InvalidArgument(
+                  "Only support `gelu`, `relu`, `none` activation in "
+                  "FusedMultiTransformer. "));
+        });
    AddAttr<bool>(
        "trans_qkvw",
        "Whether the weights of qkv should be transposed. If true,"
...
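The custom checker above is the only validation this attribute gets. As a plain-Python sketch of the same rule (the helper name is hypothetical, not part of the commit):

def check_act_method(act_type: str) -> None:
    # Mirrors the AddCustomChecker lambda: only these three strings pass,
    # anything else fails op construction with an InvalidArgument error.
    if act_type not in ("gelu", "relu", "none"):
        raise ValueError(
            "Only support `gelu`, `relu`, `none` activation in "
            "FusedMultiTransformer."
        )

check_act_method("relu")     # accepted
# check_act_method("swish")  # would raise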
@@ -31,6 +31,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
    int seq_len = input_x_dims[1];
    int dim_embed = input_x_dims[2];
    int bsz_seq = bsz * seq_len;
+    const std::string act_method = ctx.Attr<std::string>("act_method");

    // 1. layer norm
    const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
@@ -61,7 +62,6 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
    bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr;
    // (transA, transB, compute_bias) = (false, trans_qkvw, false)
    // Since we fused QKVBias into QKVBiasAddTransposeSplit kernel, here we set
    // compute_bias as false.
-
    auto qkv_compute = AttnMatMul<T>(dev_ctx,
@@ -191,24 +191,23 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
    auto *dropout_mask_out_data = dev_ctx.Alloc<uint8_t>(
        &dropout_mask_out, dropout_mask_out.numel() * sizeof(uint8_t));

-    // 6. ffn1 matmul + bias_add + gelu.
+    // 6. ffn1 matmul + act + bias
    auto ffn1_weights = ctx.MultiInput<phi::DenseTensor>("FFN1Weight");
    auto ffn1_biases = ctx.MultiInput<phi::DenseTensor>("FFN1Bias");
    auto ffn1_weight_dim = ffn1_weights[0]->dims();
    int dim_ffn = ffn1_weight_dim[1];

+    auto ffn1_cublas_linear = CublasFusedMLP<T>(dev_ctx);
+    const phi::DDim ffn1_input_shape({bsz_seq, dim_embed});
+    ffn1_cublas_linear.Setup(ffn1_input_shape, ffn1_weight_dim, false, false);
+
    Tensor ffn1_out;
    ffn1_out.Resize({{bsz_seq, dim_ffn}});
    auto *ffn1_out_data =
        dev_ctx.Alloc<T>(&ffn1_out, ffn1_out.numel() * sizeof(T));

-    auto ffn1_linear_bias_gelu = CublasFusedMLP<T>(dev_ctx);
-    const phi::DDim ffn1_input_shape({bsz_seq, dim_ffn});
-    ffn1_linear_bias_gelu.Setup(
-        ffn1_input_shape, ffn1_weight_dim, false, false);
-
-    // 8. ffn2 matmul + bias_add + residual.
+    // 7. ffn2 matmul + bias + residual.
    auto ffn2_weights = ctx.MultiInput<phi::DenseTensor>("FFN2Weight");
    auto ffn2_biases = ctx.MultiInput<phi::DenseTensor>("FFN2Bias");
@@ -216,7 +215,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
    ffn2_linear_bias_residual.Setup(
        ffn1_out.dims(), ffn2_weights[0]->dims(), false, false);

-    // 9. ffn2 residual bias
+    // 8. ffn2 Layernorm
    DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0);
    FusedDropoutLayerNormHelper<T, uint8_t> ffn2_fused_dropout_helper(
        dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon);
@@ -333,7 +332,6 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
                                      &attn_dropout_out,
                                      &qktv_out,
                                      &fmha_out);
-
    const T *k_ptr = nullptr;
    const T *v_ptr = nullptr;
@@ -450,20 +448,23 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
                              ln_mean_data,
                              ln_var_data);
      }
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
      VLOG(0) << "step5";
#endif

-      // step6. ffn1 matmul + bias_add + gelu.
-      ffn1_linear_bias_gelu.ComputeForward(
-          buf1, ffn1_weights[i], ffn1_biases[i], nullptr, &ffn1_out, "gelu");
+      // step6. ffn matmul1
+      ffn1_cublas_linear.ComputeForward(buf1,
+                                        ffn1_weights[i],
+                                        ffn1_biases[i],
+                                        nullptr,
+                                        &ffn1_out,
+                                        act_method);

#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
      VLOG(0) << "step6";
#endif

-      // step7. ffn2 matmul + bias_add + residual.
+      // step7. ffn2 matmul
      if (pre_layer_norm) {
        ffn2_linear_bias_residual.ComputeForward(&ffn1_out,
                                                 ffn2_weights[i],
@@ -477,18 +478,21 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
            &ffn1_out, ffn2_weights[i], ffn2_biases[i], buf1, buf0, "none");
      }

+#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
+      VLOG(0) << "step7";
+#endif
      if (pre_layer_norm) {
        AllReduce<T>(*buf1, ring_id, buf1->numel(), dev_ctx);
      } else {
        AllReduce<T>(*buf0, ring_id, buf0->numel(), dev_ctx);
      }
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
-      VLOG(0) << "step7";
+      VLOG(0) << "step7.1";
#endif

-      // step8. layer norm or do nothing(because bias_add + residual has been
-      // fused into cublasFusedMLP. )
+      // step8. layer norm or do nothing
+      // because bias_add + residual has been fused into cublasFusedMLP
      if (pre_layer_norm) {
        if (i < layers - 1) {
          auto *ln_scale_data = ln_scales[i + 1]->data<U>();
@@ -512,6 +516,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
                              ln_mean_data,
                              ln_var_data);
        }
+
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
      VLOG(0) << "step8";
#endif
@@ -540,6 +545,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
    int seq_len = input_x_dims[1];
    int dim_embed = input_x_dims[2];
    int bsz_seq = bsz * seq_len;
+    const std::string act_method = ctx.Attr<std::string>("act_method");

    // 1. layer norm
    const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
@@ -570,8 +576,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
    bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr;
    // (transA, transB, compute_bias) = (false, trans_qkvw, false)
-    // Since we fused QKVBias into QKVBiasAddTransposeSplit kernel, here we set
-    // compute_bias as false.
+    // Since we fused QKVBias into QKVBiasAddTransposeSplit kernel, here we
+    // set compute_bias as false.
    auto qkv_compute = AttnMatMul<T>(dev_ctx,
                                     false,
                                     trans_qkvw,
@@ -979,7 +985,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
      fused_act_dropout_helper.DropoutActBias(dev_ctx,
                                              ffn1_out_data,
                                              ffn1_biases[i]->data<T>(),
-                                             "gelu",
+                                             act_method,
                                              ffn1_dropout_out_data,
                                              ffn1_dropout_mask_data);
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
...
@@ -1414,14 +1414,15 @@ class CublasFusedMLP {
 public:
  // (m, n, k) = bsz_seq, hidden_feature, in_feature
  explicit CublasFusedMLP(const phi::GPUContext &dev_ctx) : dev_ctx_(dev_ctx) {
+    // Set Math Type
    cudaDataType_t mat_type = CUDA_R_32F;
    cudaDataType_t scale_type = CUDA_R_32F;
    cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
    if (std::is_same<T, paddle::platform::float16>::value) {
      mat_type = CUDA_R_16F;
      if (FLAGS_gemm_use_half_precision_compute_type) {
+        // This option default value is true, it tends to result NaN, but get
+        // better inference speed. you can turn off by using `export
+        // FLAGS_gemm_use_half_precision_compute_type=0`.
        compute_type = CUBLAS_COMPUTE_16F;
        scale_type = CUDA_R_16F;
      }
@@ -1435,7 +1436,6 @@ class CublasFusedMLP {
      compute_type = CUBLAS_COMPUTE_64F;
    }

-    // Just for init.
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
        &operation_desc_, compute_type, scale_type));
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
@@ -1445,7 +1445,6 @@ class CublasFusedMLP {
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
        &out_desc_, mat_type, 1, 1, 1));
  }
-
  ~CublasFusedMLP() {
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cublasLtMatmulDescDestroy(operation_desc_));
@@ -1457,7 +1456,6 @@ class CublasFusedMLP {
        platform::dynload::cublasLtMatrixLayoutDestroy(out_desc_));
  }

-  // Change to use tensor's shape.
  void Setup(const phi::DDim &x_shape,
             const phi::DDim &w_shape,
             bool trans_x,
@@ -1481,39 +1479,34 @@ class CublasFusedMLP {
                             &cublas_transB,
                             sizeof(cublas_transB)));

-    /*
-    cublas use col major: x(M, K) matmul w(K, N) = out(M, N) equals to w_t(N, K)
-    * x_t(K, M) = out(N, M)
-    */
-    SetCublasMatrixLayout_(x_desc_, cublas_transA, K, M);
-    SetCublasMatrixLayout_(w_desc_, cublas_transB, N, K);
-    SetCublasMatrixLayout_(out_desc_, CUBLAS_OP_N, N, M);
+    SetCublasMatrixLayout(x_desc_, trans_x, M, K);
+    SetCublasMatrixLayout(w_desc_, trans_w, K, N);
+    SetCublasMatrixLayout(out_desc_, false, M, N);
  }

-  void ComputeForward(const phi::DenseTensor *input,
+  void ComputeForward(const phi::DenseTensor *x,
                      const phi::DenseTensor *weight,
                      const phi::DenseTensor *bias,
                      phi::DenseTensor *residual,
                      phi::DenseTensor *output,
                      const std::string &activation) {
-    // here: (transa, transb): nt, input * weight.
-    // (M * K) * (K * N)
-    cublasLtHandle_t lt_handle = dev_ctx_.cublaslt_handle();
-    size_t workspace_size = static_cast<size_t>(16) * 1024 * 1024;
-    cudaStream_t stream = dev_ctx_.stream();
-    memory::allocation::AllocationPtr workspace =
-        memory::Alloc(dev_ctx_.GetPlace(),
-                      workspace_size,
-                      phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
+    T *out_data = output->data<T>();

    const bool add_residual = (residual == nullptr) ? false : true;
    const bool add_bias = (bias == nullptr) ? false : true;
+
+    const T *bias_data = nullptr;
    if (add_bias) {
-      SetCublasBiasPtr_(bias);
+      bias_data = bias->data<T>();
    }
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        platform::dynload::cublasLtMatmulDescSetAttribute(
+            operation_desc_,
+            CUBLASLT_MATMUL_DESC_BIAS_POINTER,
+            &bias_data,
+            sizeof(bias_data)));

-    // Set cublasLt epilogue.
-    cublasLtEpilogue_t epiloque_func = GetEpilogueType_(activation, add_bias);
+    cublasLtEpilogue_t epiloque_func = GetEpilogueType(activation, add_bias);
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cublasLtMatmulDescSetAttribute(
            operation_desc_,
@@ -1521,25 +1514,44 @@ class CublasFusedMLP {
            &epiloque_func,
            sizeof(epiloque_func)));

-    const auto *x_data = input->data<T>();
-    const auto *w_data = weight->data<T>();
-    auto *residual_data =
-        add_residual ? residual->data<T>() : output->data<T>();
-    auto *out_data = output->data<T>();
+    T *residual_data = add_residual ? residual->data<T>() : out_data;
+
+    cublasLtHandle_t lt_handle = dev_ctx_.cublaslt_handle();
+    size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
+    cudaStream_t stream = dev_ctx_.stream();
+    memory::allocation::AllocationPtr workspace = memory::Alloc(
+        dev_ctx_.GetPlace(),
+        workspace_size,
+        phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx_.stream())));

-    // if add_residual, we compute result + 1.0 * residual, else result + 0.0 *
-    // out.
+    // if add_residual, we compute result + 1.0 * residual,
+    // else result + 0.0 * out.
    double alpha64 = 1.0, beta64 = add_residual ? 1.0 : 0.0;
    float alpha32 = 1.0f, beta32 = add_residual ? 1.0f : 0.0f;
+    half alpha16 = static_cast<half>(1.0),
+         beta16 =
+             add_residual ? static_cast<half>(1.0) : static_cast<half>(0.0);
+
    void *alpha = nullptr, *beta = nullptr;
    if (std::is_same<T, double>::value) {
      alpha = &alpha64;
      beta = &beta64;
+    } else if (std::is_same<T, float>::value) {
+      alpha = &alpha64;
+      beta = &beta64;
+    } else if (std::is_same<T, phi::dtype::float16>::value) {
+      alpha = &alpha16;
+      beta = &beta16;
    } else {
-      alpha = &alpha32;
-      beta = &beta32;
+      PADDLE_ENFORCE_EQ(true,
+                        false,
+                        platform::errors::InvalidArgument(
+                            "Only support double, float, half data type. "));
    }

+    const auto *x_data = x->data<T>();
+    const auto *w_data = weight->data<T>();
+
    auto algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle,
                                                              operation_desc_,
                                                              w_desc_,
@@ -1567,15 +1579,15 @@ class CublasFusedMLP {
                              out_desc_,
                              out_data,
                              out_desc_,
-                             algo /*algo*/,
-                             workspace->ptr() /*workspace*/,
+                             algo,
+                             workspace->ptr(),
                              workspace_size,
                              stream));
  }

 private:
-  static cublasLtEpilogue_t GetEpilogueType_(const std::string &activation,
-                                             const bool add_bias) {
+  cublasLtEpilogue_t GetEpilogueType(const std::string &activation,
+                                     const bool add_bias) {
    if (activation == "relu") {
      if (add_bias) {
        return CUBLASLT_EPILOGUE_RELU_BIAS;
@@ -1606,23 +1618,41 @@ class CublasFusedMLP {
      }
    }
  }

-  void SetCublasMatrixLayout_(cublasLtMatrixLayout_t layout_desc,
-                              cublasOperation_t cublas_trans,
-                              const size_t cublas_m,
-                              const size_t cublas_n) {
+  void SetCublasMatrixLayout(cublasLtMatrixLayout_t layout_desc,
+                             const bool transpose,
+                             const uint64_t cublas_row,
+                             const uint64_t cublas_col) {
+    cudaDataType_t mat_type = CUDA_R_32F;
+    if (std::is_same<T, paddle::platform::float16>::value) {
+      mat_type = CUDA_R_16F;
+    }
+    if (std::is_same<T, platform::bfloat16>::value) {
+      mat_type = CUDA_R_16BF;
+    }
+    if (std::is_same<T, double>::value) {
+      mat_type = CUDA_R_64F;
+    }
+
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        platform::dynload::cublasLtMatrixLayoutSetAttribute(
+            layout_desc,
+            CUBLASLT_MATRIX_LAYOUT_TYPE,
+            &mat_type,
+            sizeof(mat_type)));
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cublasLtMatrixLayoutSetAttribute(
            layout_desc,
            CUBLASLT_MATRIX_LAYOUT_ROWS,
-            cublas_trans == CUBLAS_OP_N ? &cublas_m : &cublas_n,
-            sizeof(cublas_m)));
+            transpose ? &cublas_row : &cublas_col,
+            sizeof(cublas_row)));
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cublasLtMatrixLayoutSetAttribute(
            layout_desc,
            CUBLASLT_MATRIX_LAYOUT_COLS,
-            cublas_trans == CUBLAS_OP_N ? &cublas_n : &cublas_m,
-            sizeof(cublas_m)));
-    const size_t cublas_ld = cublas_trans == CUBLAS_OP_N ? cublas_m : cublas_n;
+            transpose ? &cublas_col : &cublas_row,
+            sizeof(cublas_col)));
+    int64_t cublas_ld = transpose ? cublas_row : cublas_col;
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cublasLtMatrixLayoutSetAttribute(
            layout_desc,

@@ -1631,21 +1661,11 @@ class CublasFusedMLP {
            sizeof(cublas_ld)));
  }

-  void SetCublasBiasPtr_(const phi::DenseTensor *bias) {
-    const T *bias_data = bias->data<T>();
-    PADDLE_ENFORCE_GPU_SUCCESS(
-        platform::dynload::cublasLtMatmulDescSetAttribute(
-            operation_desc_,
-            CUBLASLT_MATMUL_DESC_BIAS_POINTER,
-            &bias_data,
-            sizeof(bias_data)));
-  }
-
  const phi::GPUContext &dev_ctx_;

-  cublasLtMatmulDesc_t operation_desc_;
-  cublasLtMatrixLayout_t x_desc_;
-  cublasLtMatrixLayout_t w_desc_;
-  cublasLtMatrixLayout_t out_desc_;
+  cublasLtMatmulDesc_t operation_desc_ = NULL;
+  cublasLtMatrixLayout_t x_desc_ = NULL;
+  cublasLtMatrixLayout_t w_desc_ = NULL;
+  cublasLtMatrixLayout_t out_desc_ = NULL;
};

#endif  // PADDLE_FLUID_OPERATORS_FUSED_FUSED_MULTI_TRANSFORMER_OP_CU_H_
...
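The new SetCublasMatrixLayout calls in Setup() pass the row-major shapes directly and let the helper swap rows and columns. This relies on the column-major identity the removed comment described: a row-major x(M, K) · w(K, N) = out(M, N) occupies the same memory as the column-major product w_t(N, K) · x_t(K, M) = out_t(N, M). A quick NumPy check of that identity (illustrative only, not part of the commit):

import numpy as np

M, K, N = 4, 3, 5  # illustrative bsz_seq, in_feature, hidden_feature
x = np.random.rand(M, K).astype(np.float32)
w = np.random.rand(K, N).astype(np.float32)

out_row_major = x @ w          # what the op wants: out(M, N)
out_col_major = (w.T @ x.T).T  # what a column-major GEMM computes, transposed back

assert np.allclose(out_row_major, out_col_major, atol=1e-5)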
@@ -124,6 +124,7 @@ class TestFusedMultiTransformerOp(OpTest):
        self.training = False

        self.layers = 4
        self.batch_size = 8
        self.query_length = 128
        self.cache_length = 128
@@ -144,21 +145,27 @@ class TestFusedMultiTransformerOp(OpTest):
        )

    def generate_input_data(self):
-        self.query = np.random.rand(
-            self.batch_size, self.query_length, self.embed_dim
+        self.query = np.random.uniform(
+            -1, 1, (self.batch_size, self.query_length, self.embed_dim)
        ).astype(self.x_type)

        out_seq_len = self.key_length
        if self.has_cache_kv:
            assert self.training is False, ValueError(
                'cache_kv can only used in inference'
            )
-            self.cache_kv = np.random.rand(
-                2,
-                self.batch_size,
-                self.num_heads,
-                self.cache_length,
-                self.head_dim,
+            self.cache_kv = np.random.uniform(
+                -1,
+                1,
+                (
+                    2,
+                    self.batch_size,
+                    self.num_heads,
+                    self.cache_length,
+                    self.head_dim,
+                ),
            ).astype(self.x_type)

            if self.gen_cache_kv:
                self.cache_kv[:] = 0
            else:
@@ -168,12 +175,16 @@ class TestFusedMultiTransformerOp(OpTest):
        if self.has_pre_cache:
            out_seq_len += self.pre_cache_num
-            self.pre_cache_kv = np.random.rand(
-                2,
-                self.batch_size,
-                self.num_heads,
-                self.pre_cache_num,
-                self.head_dim,
+            self.pre_cache_kv = np.random.uniform(
+                -1,
+                1,
+                (
+                    2,
+                    self.batch_size,
+                    self.num_heads,
+                    self.pre_cache_num,
+                    self.head_dim,
+                ),
            ).astype(self.x_type)

        if self.has_attn_mask:
@@ -204,8 +215,8 @@ class TestFusedMultiTransformerOp(OpTest):
            self.attn_mask = None
        self.key, self.value = self.query, self.query

-        self.dout = np.random.random(
-            (self.batch_size, self.query_length, self.embed_dim)
+        self.dout = np.random.uniform(
+            -1, 1, (self.batch_size, self.query_length, self.embed_dim)
        ).astype(self.x_type)

    def GetBaselineOut(self):
@@ -544,6 +555,7 @@ class TestFusedMultiTransformerOp(OpTest):
                time_step=time_step,
                attn_mask=attn_mask,
                dropout_rate=self.dropout_prob,
+                activation=self.act_method,
                training=self.training,
            )
@@ -668,6 +680,7 @@ class TestFusedMultiTransformerOp(OpTest):
            self.num_heads,
            4 * self.embed_dim,
            self.dropout_prob,
+            activation=self.act_method,
            normalize_before=self.pre_layer_norm,
            ln_scale_attrs=ln_scales_attr,
            ln_bias_attrs=ln_biases_attr,
@@ -797,6 +810,14 @@ class TestFusedMultiTransformerOpFp16(TestFusedMultiTransformerOp):
        self.layers = 3  # odd layers


+class TestFusedMultiTransformerOpActReluFp16(TestFusedMultiTransformerOp):
+    def config(self):
+        super().config()
+        self.x_type = np.float16
+        self.act_method = "relu"
+        self.layers = 3  # odd layers
+
+
class TestFusedMultiTransformerOpCacheKV(TestFusedMultiTransformerOp):
    def config(self):
        super().config()
...
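The switch from np.random.rand (values in [0, 1)) to np.random.uniform(-1, 1, shape) is what the commit message calls "fix test random range": with only non-negative inputs a relu is effectively a no-op, so the symmetric range is needed for the new relu path to be exercised. A small illustration with an assumed shape (not taken from the test):

import numpy as np

shape = (8, 128, 1024)  # mirrors (batch_size, query_length, embed_dim) above
old_style = np.random.rand(*shape)           # all samples in [0, 1)
new_style = np.random.uniform(-1, 1, shape)  # symmetric around zero

assert old_style.min() >= 0.0    # relu would change nothing here
assert (new_style < 0).any()     # relu actually clips these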