Unverified commit b07e6b45, authored by MarDino, committed by GitHub

Use cublaslt in multi transformer FFN (#48052)

* use fused mlp in multi transformer
* Restructure code
* use cublaslt to fuse ffn (sketched below)
* fix conflict
Parent 78b30e97
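The core of this change is the CublasFusedMLP helper added further down in this diff: each FFN GEMM becomes a single cublasLtMatmul call whose epilogue fuses the bias add plus GELU (ffn1) or the bias add plus residual (ffn2), so separate bias/activation kernels are no longer launched per layer. Below is a minimal standalone sketch of that pattern, not part of the diff itself; the function name run_fused_ffn1, the FP32 types, and the caller-provided workspace are illustrative assumptions, and error checking is omitted for brevity.

#include <cublasLt.h>
#include <cuda_runtime.h>
#include <cstdint>

// out = GELU(x * w + bias), all row-major:
//   x: [M, K], w: [K, N], bias: [N], out: [M, N].
// Illustrative sketch of the ffn1 fusion; return statuses are ignored.
void run_fused_ffn1(cublasLtHandle_t handle, cudaStream_t stream,
                    const float *x, const float *w, const float *bias,
                    float *out, int64_t M, int64_t K, int64_t N,
                    void *workspace, size_t workspace_size) {
  cublasLtMatmulDesc_t op_desc;
  cublasLtMatmulDescCreate(&op_desc, CUBLAS_COMPUTE_32F, CUDA_R_32F);

  // Fuse bias add + GELU into the GEMM epilogue.
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
  cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
                                 &epilogue, sizeof(epilogue));
  cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER,
                                 &bias, sizeof(bias));

  // cuBLAS is column-major, so the row-major x(M,K) * w(K,N) = out(M,N) is
  // issued as w_t(N,K) * x_t(K,M) = out_t(N,M), which shares out's memory
  // layout. Rows/cols/ld below are therefore the transposed views.
  cublasLtMatrixLayout_t w_desc, x_desc, out_desc;
  cublasLtMatrixLayoutCreate(&w_desc, CUDA_R_32F, N, K, N);
  cublasLtMatrixLayoutCreate(&x_desc, CUDA_R_32F, K, M, K);
  cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_32F, N, M, N);

  // beta = 1.0f with a residual pointer passed as the C operand (in place of
  // `out`) would also fuse the residual add -- the ffn2 case.
  const float alpha = 1.0f, beta = 0.0f;
  cublasLtMatmul(handle, op_desc, &alpha, w, w_desc, x, x_desc, &beta,
                 out, out_desc, out, out_desc,
                 /*algo=*/nullptr, workspace, workspace_size, stream);

  cublasLtMatrixLayoutDestroy(out_desc);
  cublasLtMatrixLayoutDestroy(x_desc);
  cublasLtMatrixLayoutDestroy(w_desc);
  cublasLtMatmulDescDestroy(op_desc);
}

The column-major swap in the sketch (issuing w_t * x_t so the product lands directly in the row-major output buffer) is the same trick CublasFusedMLP::Setup encodes through the TRANSA/TRANSB attributes and matrix layouts in the diff below.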
@@ -14,6 +14,517 @@ limitations under the License. */
namespace paddle {
namespace operators {
#if CUDA_VERSION >= 11060 // Use cublasLt to fuse FFN operation.
template <typename T>
class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
using U = LayerNormParamType<T>;
auto &dev_ctx = ctx.cuda_device_context();
auto *time_step = ctx.Input<phi::DenseTensor>("TimeStep");
// 0. input
auto *input_x = ctx.Input<phi::DenseTensor>("X");
const auto input_x_dims = input_x->dims();
int bsz = input_x_dims[0];
int seq_len = input_x_dims[1];
int dim_embed = input_x_dims[2];
int bsz_seq = bsz * seq_len;
// 1. layer norm
const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
const float epsilon = ctx.Attr<float>("epsilon");
auto ln_scales = ctx.MultiInput<phi::DenseTensor>("LnScale");
auto ln_biases = ctx.MultiInput<phi::DenseTensor>("LnBias");
auto ln_compute = AttnLayerNorm<T>(dev_ctx, epsilon, bsz_seq, dim_embed);
Tensor ln_mean, ln_var;
ln_mean.Resize({{bsz_seq}});
auto *ln_mean_data =
dev_ctx.Alloc<U>(&ln_mean, ln_mean.numel() * sizeof(U));
ln_var.Resize({{bsz_seq}});
auto *ln_var_data = dev_ctx.Alloc<U>(&ln_var, ln_var.numel() * sizeof(U));
// 2. qkv
// x: qkv's input [batch_size, seq_len, dim_embed]
// y: qkv's weight: [3, num_head, dim_head, dim_embed]
auto qkv_weights = ctx.MultiInput<phi::DenseTensor>("QKVW");
auto qkv_biases = ctx.MultiInput<phi::DenseTensor>("QKVBias");
const bool trans_qkvw = ctx.Attr<bool>("trans_qkvw");
const auto qkv_w_dims = qkv_weights[0]->dims();
int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2];
int dim_head = trans_qkvw ? qkv_w_dims[2] : qkv_w_dims[3];
int hidden_size = num_head * dim_head;
int output_size = 3 * hidden_size;
int input_size = dim_embed;
bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr;
// (transA, transB, compute_bias) = (false, trans_qkvw, false)
// Since the QKV bias is fused into the qkv_bias_add_transpose_split kernel,
// compute_bias is set to false here.
auto qkv_compute = AttnMatMul<T>(dev_ctx,
false,
trans_qkvw,
bsz_seq,
output_size,
input_size,
/*compute_bias=*/false);
Tensor qkv_out;
qkv_out.Resize({{bsz, seq_len, 3, num_head, dim_head}});
auto *qkv_out_data =
dev_ctx.Alloc<T>(&qkv_out, qkv_out.numel() * sizeof(T));
// 3. fmha
AttnDropoutParam attn_param(
true, "upscale_in_train", 0.0, true, true, 0, nullptr);
auto fmha_compute =
FMHARef<T>(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param);
auto *src_mask = ctx.Input<phi::DenseTensor>("SrcMask");
auto cache_kvs = ctx.MultiInput<phi::DenseTensor>("CacheKV");
auto cache_kv_outs = ctx.MultiOutput<phi::DenseTensor>("CacheKVOut");
// auto *time_step = ctx.Input<phi::DenseTensor>("TimeStep");
auto pre_caches = ctx.MultiInput<phi::DenseTensor>("PreCaches");
int cache_offset = 0;
if (pre_caches.size() > 0) {
cache_offset = pre_caches[0]->dims()[3];
}
auto out_seq_len = seq_len;
if (time_step) {
PADDLE_ENFORCE_EQ(time_step->place(),
platform::CPUPlace(),
platform::errors::PreconditionNotMet(
"The place of input(TimeStep) must be CPUPlace."));
// cache_seq_len
int time_step_value = time_step->data<int>()[0];
PADDLE_ENFORCE_GT(time_step_value,
0,
platform::errors::PreconditionNotMet(
"The value of time_step must > 0, but now is %d",
time_step_value));
PADDLE_ENFORCE_EQ(
seq_len,
1,
platform::errors::PreconditionNotMet(
"In decode stage, the seq_len of input must be 1, but now is %d",
seq_len));
out_seq_len += time_step_value;
} else {
out_seq_len += cache_offset;
}
Tensor q_transpose_out, kv_transpose_out, qk_out;
q_transpose_out.Resize({{bsz, num_head, seq_len, dim_head}});
auto *q_transpose_out_data =
dev_ctx.Alloc<T>(&q_transpose_out, q_transpose_out.numel() * sizeof(T));
kv_transpose_out.Resize({{2, bsz, num_head, seq_len, dim_head}});
auto *kv_transpose_out_data = dev_ctx.Alloc<T>(
&kv_transpose_out, kv_transpose_out.numel() * sizeof(T));
qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
auto *qk_out_data = dev_ctx.Alloc<T>(&qk_out, qk_out.numel() * sizeof(T));
Tensor src_mask_out;
if (cache_offset > 0) {
src_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
auto *src_mask_out_data =
dev_ctx.Alloc<T>(&src_mask_out, src_mask_out.numel() * sizeof(T));
}
// [2, bs, num_head, cache_seq_len + seq_len, head_dim]
Tensor pre_cache_kv_out;
if (cache_offset > 0) {
pre_cache_kv_out.Resize(
{{2, bsz, num_head, seq_len + cache_offset, dim_head}});
auto *pre_cache_kv_out_data = dev_ctx.Alloc<T>(
&pre_cache_kv_out, pre_cache_kv_out.numel() * sizeof(T));
}
Tensor softmax_out;
Tensor attn_dropout_mask_out, attn_dropout_out;
Tensor qktv_out, fmha_out;
softmax_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
auto *softmax_out_data =
dev_ctx.Alloc<T>(&softmax_out, softmax_out.numel() * sizeof(T));
attn_dropout_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
auto *attn_dropout_mask_out_data = dev_ctx.Alloc<T>(
&attn_dropout_mask_out, attn_dropout_mask_out.numel() * sizeof(T));
attn_dropout_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
auto *attn_dropout_data_data = dev_ctx.Alloc<T>(
&attn_dropout_out, attn_dropout_out.numel() * sizeof(T));
qktv_out.Resize({{bsz, num_head, seq_len, dim_head}});
auto *qktv_out_data =
dev_ctx.Alloc<T>(&qktv_out, qktv_out.numel() * sizeof(T));
fmha_out.Resize({{bsz, seq_len, num_head, dim_head}});
auto *fmha_out_data =
dev_ctx.Alloc<T>(&fmha_out, fmha_out.numel() * sizeof(T));
// 4. out_linear
auto out_linear_weights = ctx.MultiInput<phi::DenseTensor>("OutLinearW");
auto out_linear_biases = ctx.MultiInput<phi::DenseTensor>("OutLinearBias");
int ring_id = ctx.Attr<int>("ring_id");
// (transA, transB, compute_bias) = (false, false, false)
auto out_linear_compute = AttnMatMul<T>(
dev_ctx, false, false, bsz_seq, dim_embed, hidden_size, false);
// 5. ln(residual + bias)
DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0);
FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
dev_ctx, bsz_seq, dim_embed, dropout_param2, epsilon);
auto ffn_ln_scales = ctx.MultiInput<phi::DenseTensor>("FFNLnScale");
auto ffn_ln_biases = ctx.MultiInput<phi::DenseTensor>("FFNLnBias");
Tensor bias_dropout_residual_out, dropout_mask_out;
T *bias_dropout_residual_out_data = nullptr;
if (pre_layer_norm) {
bias_dropout_residual_out.Resize({{bsz, seq_len, dim_embed}});
bias_dropout_residual_out_data =
dev_ctx.Alloc<T>(&bias_dropout_residual_out,
bias_dropout_residual_out.numel() * sizeof(T));
}
dropout_mask_out.Resize({{bsz, seq_len, dim_embed}});
auto *dropout_mask_out_data = dev_ctx.Alloc<uint8_t>(
&dropout_mask_out, dropout_mask_out.numel() * sizeof(uint8_t));
// 6. ffn1 matmul + bias_add + gelu.
auto ffn1_weights = ctx.MultiInput<phi::DenseTensor>("FFN1Weight");
auto ffn1_biases = ctx.MultiInput<phi::DenseTensor>("FFN1Bias");
auto ffn1_weight_dim = ffn1_weights[0]->dims();
int dim_ffn = ffn1_weight_dim[1];
Tensor ffn1_out;
ffn1_out.Resize({{bsz_seq, dim_ffn}});
auto *ffn1_out_data =
dev_ctx.Alloc<T>(&ffn1_out, ffn1_out.numel() * sizeof(T));
auto ffn1_linear_bias_gelu = CublasFusedMLP<T>(dev_ctx);
const phi::DDim ffn1_input_shape({bsz_seq, dim_embed});
ffn1_linear_bias_gelu.Setup(
ffn1_input_shape, ffn1_weight_dim, false, false);
// 8. ffn2 matmul + bias_add + residual.
auto ffn2_weights = ctx.MultiInput<phi::DenseTensor>("FFN2Weight");
auto ffn2_biases = ctx.MultiInput<phi::DenseTensor>("FFN2Bias");
auto ffn2_linear_bias_residual = CublasFusedMLP<T>(dev_ctx);
ffn2_linear_bias_residual.Setup(
ffn1_out.dims(), ffn2_weights[0]->dims(), false, false);
// 9. ffn2 residual bias
DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0);
FusedDropoutLayerNormHelper<T, uint8_t> ffn2_fused_dropout_helper(
dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon);
// calc
auto *out = ctx.Output<phi::DenseTensor>("Out");
auto *from_data = dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
Tensor *from_tensor = out;
Tensor tmp_out;
tmp_out.Resize({{bsz, seq_len, dim_embed}});
auto *tmp_out_data =
dev_ctx.Alloc<T>(&tmp_out, tmp_out.numel() * sizeof(T));
auto *x_data = input_x->data<T>();
Tensor *buf0 = nullptr;
Tensor *buf1 = nullptr;
// step0: x --> buf1
// step1: buf1 --> buf0
// step2: buf0 --> buf1
int layers = qkv_weights.size();
if (pre_layer_norm) {
if (layers & 1) {
// odd, set buf1 as out
buf0 = &tmp_out;
buf1 = out;
} else {
// even, set buf0 as out
buf0 = out;
buf1 = &tmp_out;
}
} else {
buf0 = &tmp_out;
buf1 = out;
}
for (int i = 0; i < layers; ++i) {
// step1. layer_norm
if (i == 0 && pre_layer_norm) {
auto *ln_scale_data = ln_scales[i]->data<U>();
auto *ln_bias_data = ln_biases[i]->data<U>();
// TODO(wangxi): can remove mean var in inference
ln_compute.ComputeForward(x_data,
ln_scale_data,
ln_bias_data,
buf1->data<T>(),
ln_mean_data,
ln_var_data);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step1";
#endif
// step2. qkv
const Tensor *qkv_bias = qkv_biases.size() > 0 ? qkv_biases[i] : nullptr;
// NOTE: in decoder stage, bias is fused in fmha
const Tensor *bias = time_step ? nullptr : qkv_bias;
if (!pre_layer_norm && i == 0) {
qkv_compute.ComputeForward(
qkv_weights[i], input_x, bias, &qkv_out, &qkv_out);
} else {
qkv_compute.ComputeForward(
qkv_weights[i], buf1, bias, &qkv_out, &qkv_out);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step2";
#endif
// step3. fmha
const Tensor *cache_kv = cache_kvs.size() > 0 ? cache_kvs[i] : nullptr;
Tensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr;
if (time_step) { // generation decoder stage
// [2, batch_size, num_head, max_seq_len, head_size]
int max_seq_len = cache_kv->dims()[3];
fmha<T>(dev_ctx,
qkv_out,
*qkv_bias,
*src_mask,
cache_kv_out,
&fmha_out,
bsz,
max_seq_len,
num_head,
dim_head,
time_step->data<int>()[0],
1. / sqrt(dim_head));
} else if (cache_kv_out) { // generation context stage
const Tensor *pre_cache_kv_tensor =
pre_caches.size() > 0 ? pre_caches[i] : nullptr;
Tensor *pre_cache_kv_out_tmp =
cache_offset > 0 ? &pre_cache_kv_out : nullptr;
Tensor *src_mask_tmp = cache_offset > 0 ? &src_mask_out : nullptr;
qkv_bias_add_transpose_split<T>(dev_ctx,
q_transpose_out_data,
kv_transpose_out_data,
qkv_out_data,
qkv_bias->data<T>(),
bsz,
num_head,
seq_len,
dim_head,
compute_bias);
fmha_compute.ComputeForwardWithoutTranspose(qkv_out,
pre_cache_kv_tensor,
src_mask,
&q_transpose_out,
&kv_transpose_out,
pre_cache_kv_out_tmp,
&qk_out,
src_mask_tmp,
&softmax_out,
&attn_dropout_mask_out,
&attn_dropout_out,
&qktv_out,
&fmha_out);
const T *k_ptr = nullptr;
const T *v_ptr = nullptr;
if (cache_offset > 0) {
// [2, bsz, num_head, cache_offset + seq_len, head_dim]
const T *kv_data = pre_cache_kv_out.data<T>();
k_ptr = kv_data;
int64_t k_size = bsz * num_head * (seq_len + cache_offset) * dim_head;
v_ptr = k_ptr + k_size;
} else {
// [2, bsz, num_head, seq_len, head_dim]
int64_t k_size = bsz * seq_len * num_head * dim_head;
const T *q_ptr = q_transpose_out_data;
k_ptr = kv_transpose_out_data;
v_ptr = k_ptr + k_size;
}
// [2, bsz, num_head, max_seq_len, head_dim]
int max_seq_len = cache_kv_out->dims()[3];
T *cache_kv_data = cache_kv_out->data<T>();
int64_t cache_k_size = bsz * num_head * max_seq_len * dim_head;
T *cache_k_ptr = cache_kv_data;
T *cache_v_ptr = cache_kv_data + cache_k_size;
const int seq_len_tmp = seq_len + cache_offset;
write_cache_kv<T>(dev_ctx,
cache_k_ptr,
cache_v_ptr,
k_ptr,
v_ptr,
bsz,
num_head,
seq_len_tmp,
max_seq_len,
dim_head);
} else { // not generation
// TODO(wangxi): can remove dropout in inference
qkv_bias_add_transpose_split<T>(dev_ctx,
q_transpose_out_data,
kv_transpose_out_data,
qkv_out_data,
qkv_bias->data<T>(),
bsz,
num_head,
seq_len,
dim_head,
compute_bias);
fmha_compute.ComputeForwardWithoutTranspose(qkv_out,
cache_kv,
src_mask,
&q_transpose_out,
&kv_transpose_out,
cache_kv_out,
&qk_out,
nullptr,
&softmax_out,
&attn_dropout_mask_out,
&attn_dropout_out,
&qktv_out,
&fmha_out);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step3";
#endif
if (pre_layer_norm) {
out_linear_compute.ComputeForward(
out_linear_weights[i], &fmha_out, nullptr, buf1, nullptr);
AllReduce<T>(*buf1, ring_id, buf1->numel(), dev_ctx);
} else {
out_linear_compute.ComputeForward(
out_linear_weights[i], &fmha_out, nullptr, buf0, nullptr);
AllReduce<T>(*buf0, ring_id, buf0->numel(), dev_ctx);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step4";
#endif
// step5. ln(residual + dropout(input + bias))
if (pre_layer_norm) {
auto *ln_scale_data = ffn_ln_scales[i]->data<U>();
auto *ln_bias_data = ffn_ln_biases[i]->data<U>();
auto *out_linear_bias_data = out_linear_biases[i]->data<T>();
// inplace
fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
dev_ctx,
buf1->data<T>(),
x_data,
out_linear_bias_data,
ln_scale_data,
ln_bias_data,
bias_dropout_residual_out_data,
dropout_mask_out_data,
buf1->data<T>(),
ln_mean_data,
ln_var_data);
} else {
auto *ln_scale_data = ln_scales[i]->data<U>();
auto *ln_bias_data = ln_biases[i]->data<U>();
auto *out_linear_bias_data = out_linear_biases[i]->data<T>();
auto *residual_data = (i == 0 ? x_data : buf1->data<T>());
fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
dev_ctx,
buf0->data<T>(),
residual_data,
out_linear_bias_data,
ln_scale_data,
ln_bias_data,
buf0->data<T>(),
dropout_mask_out_data,
buf1->data<T>(),
ln_mean_data,
ln_var_data);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step5";
#endif
// step6. ffn1 matmul + bias_add + gelu.
ffn1_linear_bias_gelu.ComputeForward(
buf1, ffn1_weights[i], ffn1_biases[i], nullptr, &ffn1_out, "gelu");
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step6";
#endif
// step7. ffn2 matmul + bias_add + residual.
if (pre_layer_norm) {
ffn2_linear_bias_residual.ComputeForward(&ffn1_out,
ffn2_weights[i],
ffn2_biases[i],
&bias_dropout_residual_out,
buf1,
"none");
} else {
ffn2_linear_bias_residual.ComputeForward(
&ffn1_out, ffn2_weights[i], ffn2_biases[i], buf1, buf0, "none");
}
if (pre_layer_norm) {
AllReduce<T>(*buf1, ring_id, buf1->numel(), dev_ctx);
} else {
AllReduce<T>(*buf0, ring_id, buf0->numel(), dev_ctx);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step7";
#endif
// step8. layer norm, or do nothing (bias_add + residual has already been
// fused into CublasFusedMLP).
if (pre_layer_norm) {
if (i < layers - 1) {
auto *ln_scale_data = ln_scales[i + 1]->data<U>();
auto *ln_bias_data = ln_biases[i + 1]->data<U>();
ffn2_fused_dropout_helper.LayerNorm(dev_ctx,
buf1->data<T>(),
ln_scale_data,
ln_bias_data,
buf0->data<T>(),
ln_mean_data,
ln_var_data);
}
} else {
auto *ln_scale_data = ffn_ln_scales[i]->data<U>();
auto *ln_bias_data = ffn_ln_biases[i]->data<U>();
ffn2_fused_dropout_helper.LayerNorm(dev_ctx,
buf0->data<T>(),
ln_scale_data,
ln_bias_data,
buf1->data<T>(),
ln_mean_data,
ln_var_data);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step8";
#endif
if (pre_layer_norm) {
x_data = buf1->data<T>();
std::swap(buf0, buf1);
}
}
}
};
#else
template <typename T>
class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
public:
@@ -550,6 +1061,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
}
};
#endif // CUDA_VERSION >= 11060
} // namespace operators
} // namespace paddle
......
@@ -26,7 +26,9 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/dynload/cublasLt.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
@@ -37,6 +39,8 @@ limitations under the License. */
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
DECLARE_bool(gemm_use_half_precision_compute_type);
namespace paddle {
namespace operators {
@@ -1336,10 +1340,10 @@ inline cudaError_t GetNumBlocks(int64_t n, int *num_blocks) {
constexpr int kBlockSize = 128;
constexpr int kNumWaves = 16;
- const int device_id = paddle::platform::GetCurrentDeviceId();
- const int sm_count = paddle::platform::GetGPUMultiProcessors(device_id);
+ const int device_id = phi::backends::gpu::GetCurrentDeviceId();
+ const int sm_count = phi::backends::gpu::GetGPUMultiProcessors(device_id);
const int max_thread_per_multiprocessor =
-     paddle::platform::GetGPUMultiProcessors(device_id);
+     phi::backends::gpu::GetGPUMultiProcessors(device_id);
*num_blocks =
std::max<int>(1,
@@ -1400,6 +1404,249 @@ void qkv_bias_add_transpose_split(const phi::GPUContext &dev_ctx,
}
}
#if CUDA_VERSION >= 11060
// Only Used in Inference
template <typename T>
class CublasFusedMLP {
public:
// (m, n, k) = bsz_seq, hidden_feature, in_feature
explicit CublasFusedMLP(const phi::GPUContext &dev_ctx) : dev_ctx_(dev_ctx) {
// Set Math Type
cudaDataType_t mat_type = CUDA_R_32F;
cudaDataType_t scale_type = CUDA_R_32F;
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
if (std::is_same<T, paddle::platform::float16>::value) {
mat_type = CUDA_R_16F;
if (FLAGS_gemm_use_half_precision_compute_type) {
compute_type = CUBLAS_COMPUTE_16F;
scale_type = CUDA_R_16F;
}
}
if (std::is_same<T, platform::bfloat16>::value) {
mat_type = CUDA_R_16BF;
}
if (std::is_same<T, double>::value) {
mat_type = CUDA_R_64F;
scale_type = CUDA_R_64F;
compute_type = CUBLAS_COMPUTE_64F;
}
// Just for init.
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
&operation_desc_, compute_type, scale_type));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&x_desc_, mat_type, 1, 1, 1));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&w_desc_, mat_type, 1, 1, 1));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&out_desc_, mat_type, 1, 1, 1));
}
~CublasFusedMLP() {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescDestroy(operation_desc_));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(x_desc_));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(w_desc_));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(out_desc_));
}
// Configure the GEMM problem size (M, N, K) from the input and weight shapes.
void Setup(const phi::DDim &x_shape,
const phi::DDim &w_shape,
bool trans_x,
bool trans_w) {
int64_t M = trans_x ? x_shape[1] : x_shape[0];
int64_t K = trans_w ? w_shape[1] : w_shape[0];
int64_t N = trans_w ? w_shape[0] : w_shape[1];
cublasOperation_t cublas_transA = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t cublas_transB = trans_w ? CUBLAS_OP_T : CUBLAS_OP_N;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc_,
CUBLASLT_MATMUL_DESC_TRANSB,
&cublas_transA,
sizeof(cublas_transA)));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc_,
CUBLASLT_MATMUL_DESC_TRANSA,
&cublas_transB,
sizeof(cublas_transB)));
/*
  cuBLAS uses column-major layout, so the row-major product
  x(M, K) * w(K, N) = out(M, N) is computed as
  w_t(N, K) * x_t(K, M) = out_t(N, M).
*/
SetCublasMatrixLayout_(x_desc_, cublas_transA, K, M);
SetCublasMatrixLayout_(w_desc_, cublas_transB, N, K);
SetCublasMatrixLayout_(out_desc_, CUBLAS_OP_N, N, M);
}
void ComputeForward(const phi::DenseTensor *input,
const phi::DenseTensor *weight,
const phi::DenseTensor *bias,
phi::DenseTensor *residual,
phi::DenseTensor *output,
const std::string &activation) {
// Computes input * weight, i.e. (M x K) * (K x N) = (M x N); the
// column-major swap is handled by the layouts set in Setup().
cublasLtHandle_t lt_handle = dev_ctx_.cublaslt_handle();
size_t workspace_size = static_cast<size_t>(16) * 1024 * 1024;
cudaStream_t stream = dev_ctx_.stream();
memory::allocation::AllocationPtr workspace =
memory::Alloc(dev_ctx_.GetPlace(),
workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
const bool add_residual = (residual == nullptr) ? false : true;
const bool add_bias = (bias == nullptr) ? false : true;
if (add_bias) {
SetCublasBiasPtr_(bias);
}
// Set cublasLt epilogue.
cublasLtEpilogue_t epilogue_func = GetEpilogueType_(activation, add_bias);
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc_,
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue_func,
sizeof(epilogue_func)));
const auto *x_data = input->data<T>();
const auto *w_data = weight->data<T>();
auto *residual_data =
add_residual ? residual->data<T>() : output->data<T>();
auto *out_data = output->data<T>();
// if add_residual, we compute result + 1.0 * residual, else result + 0.0 *
// out.
double alpha64 = 1.0, beta64 = add_residual ? 1.0 : 0.0;
float alpha32 = 1.0f, beta32 = add_residual ? 1.0f : 0.0f;
void *alpha = nullptr, *beta = nullptr;
if (std::is_same<T, double>::value) {
alpha = &alpha64;
beta = &beta64;
} else {
alpha = &alpha32;
beta = &beta32;
}
auto algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle,
operation_desc_,
w_desc_,
x_desc_,
out_desc_,
alpha,
beta,
w_data,
x_data,
out_data,
stream,
workspace->ptr(),
workspace_size);
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmul(lt_handle,
operation_desc_,
alpha,
w_data,
w_desc_,
x_data,
x_desc_,
beta,
residual_data,
out_desc_,
out_data,
out_desc_,
algo /*algo*/,
workspace->ptr() /*workspace*/,
workspace_size,
stream));
}
private:
static cublasLtEpilogue_t GetEpilogueType_(const std::string &activation,
const bool add_bias) {
if (activation == "relu") {
if (add_bias) {
return CUBLASLT_EPILOGUE_RELU_BIAS;
} else {
return CUBLASLT_EPILOGUE_RELU;
}
} else if (activation == "gelu") {
if (add_bias) {
return CUBLASLT_EPILOGUE_GELU_BIAS;
} else {
return CUBLASLT_EPILOGUE_GELU;
}
} else if (activation == "none") {
if (add_bias) {
return CUBLASLT_EPILOGUE_BIAS;
} else {
return CUBLASLT_EPILOGUE_DEFAULT;
}
} else {
PADDLE_ENFORCE_EQ(
true,
false,
platform::errors::InvalidArgument(
"The activation attribute of fused_gemm_epilogue op should be"
" one of {\"none\", \"relu\", \"gelu\"}. But received %s."
"But received activation=%s.",
activation));
}
}
void SetCublasMatrixLayout_(cublasLtMatrixLayout_t layout_desc,
cublasOperation_t cublas_trans,
const size_t cublas_m,
const size_t cublas_n) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutSetAttribute(
layout_desc,
CUBLASLT_MATRIX_LAYOUT_ROWS,
cublas_trans == CUBLAS_OP_N ? &cublas_m : &cublas_n,
sizeof(cublas_m)));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutSetAttribute(
layout_desc,
CUBLASLT_MATRIX_LAYOUT_COLS,
cublas_trans == CUBLAS_OP_N ? &cublas_n : &cublas_m,
sizeof(cublas_m)));
const size_t cublas_ld = cublas_trans == CUBLAS_OP_N ? cublas_m : cublas_n;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutSetAttribute(
layout_desc,
CUBLASLT_MATRIX_LAYOUT_LD,
&cublas_ld,
sizeof(cublas_ld)));
}
void SetCublasBiasPtr_(const phi::DenseTensor *bias) {
const T *bias_data = bias->data<T>();
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc_,
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias_data,
sizeof(bias_data)));
}
const phi::GPUContext &dev_ctx_;
cublasLtMatmulDesc_t operation_desc_;
cublasLtMatrixLayout_t x_desc_;
cublasLtMatrixLayout_t w_desc_;
cublasLtMatrixLayout_t out_desc_;
};
#endif // CUDA_VERSION >= 11060
} // namespace operators
} // namespace paddle
......