Unverified commit b07e6b45, authored by MarDino, committed by GitHub

Use cublaslt in multi transformer FFN (#48052)

* Use fused MLP in multi transformer
* Restructure code
* Use cublasLt to fuse the FFN
* Fix conflicts
Parent 78b30e97
@@ -26,7 +26,9 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/dynload/cublasLt.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
@@ -37,6 +39,8 @@ limitations under the License. */
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
DECLARE_bool(gemm_use_half_precision_compute_type);
namespace paddle {
namespace operators {
@@ -1336,10 +1340,10 @@ inline cudaError_t GetNumBlocks(int64_t n, int *num_blocks) {
constexpr int kBlockSize = 128;
constexpr int kNumWaves = 16;
- const int device_id = paddle::platform::GetCurrentDeviceId();
- const int sm_count = paddle::platform::GetGPUMultiProcessors(device_id);
+ const int device_id = phi::backends::gpu::GetCurrentDeviceId();
+ const int sm_count = phi::backends::gpu::GetGPUMultiProcessors(device_id);
const int max_thread_per_multiprocessor =
- paddle::platform::GetGPUMultiProcessors(device_id);
+ phi::backends::gpu::GetGPUMultiProcessors(device_id);
*num_blocks =
std::max<int>(1,
@@ -1400,6 +1404,249 @@ void qkv_bias_add_transpose_split(const phi::GPUContext &dev_ctx,
}
}
#if CUDA_VERSION >= 11060
// Only Used in Inference
template <typename T>
class CublasFusedMLP {
public:
// (m, n, k) = bsz_seq, hidden_feature, in_feature
explicit CublasFusedMLP(const phi::GPUContext &dev_ctx) : dev_ctx_(dev_ctx) {
// Set Math Type
cudaDataType_t mat_type = CUDA_R_32F;
cudaDataType_t scale_type = CUDA_R_32F;
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
if (std::is_same<T, paddle::platform::float16>::value) {
mat_type = CUDA_R_16F;
if (FLAGS_gemm_use_half_precision_compute_type) {
compute_type = CUBLAS_COMPUTE_16F;
scale_type = CUDA_R_16F;
}
}
if (std::is_same<T, platform::bfloat16>::value) {
mat_type = CUDA_R_16BF;
}
if (std::is_same<T, double>::value) {
mat_type = CUDA_R_64F;
scale_type = CUDA_R_64F;
compute_type = CUBLAS_COMPUTE_64F;
}
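// Summary of the choices above: float16 accumulates in FP32 unless
// FLAGS_gemm_use_half_precision_compute_type is set, bfloat16 always
// accumulates in FP32, and double uses FP64 throughout.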
// Just for init.
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
&operation_desc_, compute_type, scale_type));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&x_desc_, mat_type, 1, 1, 1));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&w_desc_, mat_type, 1, 1, 1));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&out_desc_, mat_type, 1, 1, 1));
}
~CublasFusedMLP() {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescDestroy(operation_desc_));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(x_desc_));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(w_desc_));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(out_desc_));
}
// Configure the matmul descriptors from the shapes of the input and weight tensors.
void Setup(const phi::DDim &x_shape,
const phi::DDim &w_shape,
bool trans_x,
bool trans_w) {
int64_t M = trans_x ? x_shape[1] : x_shape[0];
int64_t K = trans_w ? w_shape[1] : w_shape[0];
int64_t N = trans_w ? w_shape[0] : w_shape[1];
cublasOperation_t cublas_transA = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t cublas_transB = trans_w ? CUBLAS_OP_T : CUBLAS_OP_N;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc_,
CUBLASLT_MATMUL_DESC_TRANSB,
&cublas_transA,
sizeof(cublas_transA)));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc_,
CUBLASLT_MATMUL_DESC_TRANSA,
&cublas_transB,
sizeof(cublas_transB)));
/*
cuBLAS assumes column-major storage, so the row-major product
x(M, K) * w(K, N) = out(M, N) is issued as
w_t(N, K) * x_t(K, M) = out_t(N, M) with the operands swapped.
*/
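// For example, with trans_x == false the row-major x buffer of shape (M, K)
// is described below as a column-major (K, M) matrix with ld == K, so the
// same memory is reused without a transpose copy; w and out are handled the
// same way.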
SetCublasMatrixLayout_(x_desc_, cublas_transA, K, M);
SetCublasMatrixLayout_(w_desc_, cublas_transB, N, K);
SetCublasMatrixLayout_(out_desc_, CUBLAS_OP_N, N, M);
}
void ComputeForward(const phi::DenseTensor *input,
const phi::DenseTensor *weight,
const phi::DenseTensor *bias,
phi::DenseTensor *residual,
phi::DenseTensor *output,
const std::string &activation) {
// Computes input(M, K) * weight(K, N) = out(M, N) using the descriptors
// configured in Setup().
cublasLtHandle_t lt_handle = dev_ctx_.cublaslt_handle();
size_t workspace_size = static_cast<size_t>(16) * 1024 * 1024;
cudaStream_t stream = dev_ctx_.stream();
memory::allocation::AllocationPtr workspace =
memory::Alloc(dev_ctx_.GetPlace(),
workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
const bool add_residual = (residual == nullptr) ? false : true;
const bool add_bias = (bias == nullptr) ? false : true;
if (add_bias) {
SetCublasBiasPtr_(bias);
}
// Set cublasLt epilogue.
cublasLtEpilogue_t epiloque_func = GetEpilogueType_(activation, add_bias);
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc_,
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epiloque_func,
sizeof(epiloque_func)));
const auto *x_data = input->data<T>();
const auto *w_data = weight->data<T>();
auto *residual_data =
add_residual ? residual->data<T>() : output->data<T>();
auto *out_data = output->data<T>();
// If add_residual, compute gemm_out + 1.0 * residual; otherwise beta is 0.0
// and nothing is added to the GEMM result.
double alpha64 = 1.0, beta64 = add_residual ? 1.0 : 0.0;
float alpha32 = 1.0f, beta32 = add_residual ? 1.0f : 0.0f;
void *alpha = nullptr, *beta = nullptr;
if (std::is_same<T, double>::value) {
alpha = &alpha64;
beta = &beta64;
} else {
alpha = &alpha32;
beta = &beta32;
}
auto algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle,
operation_desc_,
w_desc_,
x_desc_,
out_desc_,
alpha,
beta,
w_data,
x_data,
out_data,
stream,
workspace->ptr(),
workspace_size);
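// cublasLtMatmul computes D = epilogue(alpha * A * B + beta * C); here C
// points at the residual (or aliases the output buffer when there is no
// residual, in which case beta is 0), and the epilogue performs the bias
// add and activation.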
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmul(lt_handle,
operation_desc_,
alpha,
w_data,
w_desc_,
x_data,
x_desc_,
beta,
residual_data,
out_desc_,
out_data,
out_desc_,
algo /*algo*/,
workspace->ptr() /*workspace*/,
workspace_size,
stream));
}
private:
static cublasLtEpilogue_t GetEpilogueType_(const std::string &activation,
const bool add_bias) {
if (activation == "relu") {
if (add_bias) {
return CUBLASLT_EPILOGUE_RELU_BIAS;
} else {
return CUBLASLT_EPILOGUE_RELU;
}
} else if (activation == "gelu") {
if (add_bias) {
return CUBLASLT_EPILOGUE_GELU_BIAS;
} else {
return CUBLASLT_EPILOGUE_GELU;
}
} else if (activation == "none") {
if (add_bias) {
return CUBLASLT_EPILOGUE_BIAS;
} else {
return CUBLASLT_EPILOGUE_DEFAULT;
}
} else {
PADDLE_ENFORCE_EQ(
true,
false,
platform::errors::InvalidArgument(
"The activation attribute of fused_gemm_epilogue op should be"
" one of {\"none\", \"relu\", \"gelu\"}. But received %s.",
activation));
}
}
void SetCublasMatrixLayout_(cublasLtMatrixLayout_t layout_desc,
cublasOperation_t cublas_trans,
const size_t cublas_m,
const size_t cublas_n) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutSetAttribute(
layout_desc,
CUBLASLT_MATRIX_LAYOUT_ROWS,
cublas_trans == CUBLAS_OP_N ? &cublas_m : &cublas_n,
sizeof(cublas_m)));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutSetAttribute(
layout_desc,
CUBLASLT_MATRIX_LAYOUT_COLS,
cublas_trans == CUBLAS_OP_N ? &cublas_n : &cublas_m,
sizeof(cublas_m)));
const size_t cublas_ld = cublas_trans == CUBLAS_OP_N ? cublas_m : cublas_n;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutSetAttribute(
layout_desc,
CUBLASLT_MATRIX_LAYOUT_LD,
&cublas_ld,
sizeof(cublas_ld)));
}
void SetCublasBiasPtr_(const phi::DenseTensor *bias) {
const T *bias_data = bias->data<T>();
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc_,
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias_data,
sizeof(bias_data)));
}
const phi::GPUContext &dev_ctx_;
cublasLtMatmulDesc_t operation_desc_;
cublasLtMatrixLayout_t x_desc_;
cublasLtMatrixLayout_t w_desc_;
cublasLtMatrixLayout_t out_desc_;
};
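// Illustrative usage only; the tensor names below (ffn_ln_out, ffn1_weight,
// ffn1_bias, ffn1_out, ffn2_weight, ffn2_bias, residual, ffn2_out) are
// placeholders rather than the variables used by the operator. The two FFN
// GEMMs of one transformer layer can be fused as:
//
//   CublasFusedMLP<T> fused_mlp(dev_ctx);
//   // First GEMM: matmul + bias + GELU in a single cublasLt call.
//   fused_mlp.Setup(ffn_ln_out.dims(), ffn1_weight.dims(), false, false);
//   fused_mlp.ComputeForward(
//       &ffn_ln_out, &ffn1_weight, &ffn1_bias, nullptr, &ffn1_out, "gelu");
//   // Second GEMM: matmul + bias + residual add, no activation.
//   fused_mlp.Setup(ffn1_out.dims(), ffn2_weight.dims(), false, false);
//   fused_mlp.ComputeForward(
//       &ffn1_out, &ffn2_weight, &ffn2_bias, &residual, &ffn2_out, "none");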
#endif  // CUDA_VERSION >= 11060
} // namespace
} // namespace operators
......