未验证 提交 57f6a469 编写于 作者: Y Yiqun Liu 提交者: GitHub

Enable matmul + bias fusion in fused_gat_attention. (#50755)

* Enable matmul + bias fusion in fused_gat_attention.

* Add a variable to control whether using fused matmul + bias.
上级 7c73910e
...@@ -14,12 +14,13 @@ limitations under the License. */ ...@@ -14,12 +14,13 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h" #include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -44,13 +45,43 @@ class AttnMatMul { ...@@ -44,13 +45,43 @@ class AttnMatMul {
input_size_(input_size), input_size_(input_size),
compute_bias_(compute_bias) {} compute_bias_(compute_bias) {}
~AttnMatMul() {}
void ComputeForward(const phi::DenseTensor* weight, void ComputeForward(const phi::DenseTensor* weight,
const phi::DenseTensor* input, const phi::DenseTensor* input,
const phi::DenseTensor* bias, const phi::DenseTensor* bias,
phi::DenseTensor* output, phi::DenseTensor* output,
phi::DenseTensor* bias_out) { phi::DenseTensor* bias_out,
bool fused = false) {
VLOG(6) << "input.shape={" << input->dims() << "}, weight.shape={"
<< weight->dims() << "}, output.shape={" << output->dims()
<< "}, batch_size=" << bsz_seq_ << ", output_size=" << output_size_
<< ", input_size=" << input_size_ << ", transA=" << transA_
<< ", transB=" << transB_ << ", compute_bias=" << compute_bias_
<< ", fused=" << fused;
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
if (compute_bias_ && fused) {
PADDLE_ENFORCE_EQ(
!output || output == bias_out,
true,
phi::errors::InvalidArgument(
"The output (= input * weight) is expected to be nullptr or the "
"same as bias_out when fused is true."));
ComputeFusedGemmEpilogueForward<T>(dev_ctx_,
input,
weight,
bias,
bsz_seq_, // M
output_size_, // N
input_size_, // K
transA_,
transB_,
"none",
bias_out,
nullptr);
return;
}
#endif
// Note: for blas.GEMM API in Paddle, it treats all inputs as row-major. // Note: for blas.GEMM API in Paddle, it treats all inputs as row-major.
// here: (transa, transb): nt, input * weight. // here: (transa, transb): nt, input * weight.
CBLAS_TRANSPOSE transA = transA_ ? CblasTrans : CblasNoTrans; CBLAS_TRANSPOSE transA = transA_ ? CblasTrans : CblasNoTrans;
...@@ -85,7 +116,29 @@ class AttnMatMul { ...@@ -85,7 +116,29 @@ class AttnMatMul {
phi::DenseTensor* d_input, phi::DenseTensor* d_input,
phi::DenseTensor* d_weight, phi::DenseTensor* d_weight,
phi::DenseTensor* d_bias, phi::DenseTensor* d_bias,
bool use_addto = false) { bool use_addto = false,
bool fused = false) {
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
if (compute_bias_ && fused) {
ComputeFusedGemmEpilogueBackward<T>(dev_ctx_,
d_output,
input,
weight,
nullptr,
bsz_seq_, // M
output_size_, // N
input_size_, // K
transA_,
transB_,
"none",
d_input,
d_weight,
d_bias,
use_addto);
return;
}
#endif
T alpha = static_cast<T>(1.0); T alpha = static_cast<T>(1.0);
T beta_dA = use_addto ? static_cast<T>(1.0) : static_cast<T>(0.0); T beta_dA = use_addto ? static_cast<T>(1.0) : static_cast<T>(0.0);
T beta_dB = static_cast<T>(0.0); T beta_dB = static_cast<T>(0.0);
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/operators/fused/fused_softmax_mask.cu.h" #include "paddle/fluid/operators/fused/fused_softmax_mask.cu.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" #include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/dropout_impl.cu.h" #include "paddle/phi/kernels/funcs/dropout_impl.cu.h"
......
...@@ -209,7 +209,8 @@ void ComputeGatingLinearForward(const framework::ExecutionContext &ctx, ...@@ -209,7 +209,8 @@ void ComputeGatingLinearForward(const framework::ExecutionContext &ctx,
const GateAttentionConfig<T> &config, const GateAttentionConfig<T> &config,
const phi::DenseTensor *query, const phi::DenseTensor *query,
const phi::DenseTensor *fmha_out, const phi::DenseTensor *fmha_out,
phi::DenseTensor *gate_out) { phi::DenseTensor *gate_bias_out,
bool use_fused_matmul_bias) {
auto *gate_weight = ctx.Input<phi::DenseTensor>("GateWeight"); auto *gate_weight = ctx.Input<phi::DenseTensor>("GateWeight");
auto *gate_bias = ctx.Input<phi::DenseTensor>("GateBias"); auto *gate_bias = ctx.Input<phi::DenseTensor>("GateBias");
...@@ -220,14 +221,18 @@ void ComputeGatingLinearForward(const framework::ExecutionContext &ctx, ...@@ -220,14 +221,18 @@ void ComputeGatingLinearForward(const framework::ExecutionContext &ctx,
int m = config.batch_size * config.seq_len_m * config.seq_len_r; int m = config.batch_size * config.seq_len_m * config.seq_len_r;
int n = config.num_heads * config.head_dim; int n = config.num_heads * config.head_dim;
int k = config.q_dim; int k = config.q_dim;
auto gate_attn_compute = auto gate_linear =
AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true); AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true);
gate_attn_compute.ComputeForward( gate_linear.ComputeForward(gate_weight,
gate_weight, query, gate_bias, gate_out, gate_out); query,
gate_bias,
gate_bias_out,
gate_bias_out,
use_fused_matmul_bias);
// gate_out = sigmoid(gate_out) * fmha_out // gate_out = sigmoid(gate_out) * fmha_out
std::vector<const phi::DenseTensor *> ins = {gate_out, fmha_out}; std::vector<const phi::DenseTensor *> ins = {gate_bias_out, fmha_out};
std::vector<phi::DenseTensor *> outs = {gate_out}; std::vector<phi::DenseTensor *> outs = {gate_bias_out};
phi::funcs::ElementwiseKernel<T>( phi::funcs::ElementwiseKernel<T>(
ctx.cuda_device_context(), ins, &outs, SigmoidMultiplyFunctor<T>()); ctx.cuda_device_context(), ins, &outs, SigmoidMultiplyFunctor<T>());
} }
...@@ -239,10 +244,12 @@ void ComputeGatingLinearBackward(const framework::ExecutionContext &ctx, ...@@ -239,10 +244,12 @@ void ComputeGatingLinearBackward(const framework::ExecutionContext &ctx,
const phi::DenseTensor *fmha_out, const phi::DenseTensor *fmha_out,
const phi::DenseTensor *gate_out_grad, const phi::DenseTensor *gate_out_grad,
phi::DenseTensor *query_grad, phi::DenseTensor *query_grad,
phi::DenseTensor *fmha_out_grad) { phi::DenseTensor *fmha_out_grad,
bool use_fused_matmul_bias) {
const auto *gate_weight = ctx.Input<phi::DenseTensor>("GateWeight"); const auto *gate_weight = ctx.Input<phi::DenseTensor>("GateWeight");
const auto *gate_bias = ctx.Input<phi::DenseTensor>("GateBias"); const auto *gate_bias = ctx.Input<phi::DenseTensor>("GateBias");
auto &dev_ctx = ctx.template device_context<phi::GPUContext>(); auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
// Re-compute gate_bias_out // Re-compute gate_bias_out
phi::DenseTensor gate_bias_out; phi::DenseTensor gate_bias_out;
gate_bias_out.Resize(config.gate_out_dims); gate_bias_out.Resize(config.gate_out_dims);
...@@ -251,10 +258,14 @@ void ComputeGatingLinearBackward(const framework::ExecutionContext &ctx, ...@@ -251,10 +258,14 @@ void ComputeGatingLinearBackward(const framework::ExecutionContext &ctx,
int m = config.batch_size * config.seq_len_m * config.seq_len_r; int m = config.batch_size * config.seq_len_m * config.seq_len_r;
int n = config.num_heads * config.head_dim; int n = config.num_heads * config.head_dim;
int k = config.q_dim; int k = config.q_dim;
auto gate_attn_compute = auto gate_linear =
AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true); AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true);
gate_attn_compute.ComputeForward( gate_linear.ComputeForward(gate_weight,
gate_weight, query, gate_bias, &gate_bias_out, &gate_bias_out); query,
gate_bias,
&gate_bias_out,
&gate_bias_out,
use_fused_matmul_bias);
// Gradient of sigmoid(gate_bias_out) * fmha_out // Gradient of sigmoid(gate_bias_out) * fmha_out
// Compute inplace and save gate_bias_out_grad to gate_bias_out. // Compute inplace and save gate_bias_out_grad to gate_bias_out.
...@@ -272,19 +283,22 @@ void ComputeGatingLinearBackward(const framework::ExecutionContext &ctx, ...@@ -272,19 +283,22 @@ void ComputeGatingLinearBackward(const framework::ExecutionContext &ctx,
dev_ctx.Alloc<T>(gate_weight_grad, gate_weight_grad->numel() * sizeof(T)); dev_ctx.Alloc<T>(gate_weight_grad, gate_weight_grad->numel() * sizeof(T));
dev_ctx.Alloc<T>(gate_bias_grad, gate_bias_grad->numel() * sizeof(T)); dev_ctx.Alloc<T>(gate_bias_grad, gate_bias_grad->numel() * sizeof(T));
gate_attn_compute.ComputeBackward(query, gate_linear.ComputeBackward(query,
gate_weight, gate_weight,
&gate_bias_out, &gate_bias_out,
query_grad, query_grad,
gate_weight_grad, gate_weight_grad,
gate_bias_grad); gate_bias_grad,
false,
use_fused_matmul_bias);
} }
template <typename T> template <typename T>
void ComputeOutputLinearForward(const framework::ExecutionContext &ctx, void ComputeOutputLinearForward(const framework::ExecutionContext &ctx,
const GateAttentionConfig<T> &config, const GateAttentionConfig<T> &config,
const phi::DenseTensor *fmha_or_gate_out, const phi::DenseTensor *fmha_or_gate_out,
phi::DenseTensor *out) { phi::DenseTensor *out,
bool use_fused_matmul_bias) {
const auto *out_linear_weight = const auto *out_linear_weight =
ctx.Input<phi::DenseTensor>("OutLinearWeight"); ctx.Input<phi::DenseTensor>("OutLinearWeight");
const auto *out_linear_bias = ctx.Input<phi::DenseTensor>("OutLinearBias"); const auto *out_linear_bias = ctx.Input<phi::DenseTensor>("OutLinearBias");
...@@ -293,17 +307,22 @@ void ComputeOutputLinearForward(const framework::ExecutionContext &ctx, ...@@ -293,17 +307,22 @@ void ComputeOutputLinearForward(const framework::ExecutionContext &ctx,
int m = config.batch_size * config.seq_len_m * config.seq_len_r; int m = config.batch_size * config.seq_len_m * config.seq_len_r;
int n = config.q_dim; int n = config.q_dim;
int k = config.num_heads * config.head_dim; int k = config.num_heads * config.head_dim;
auto out_linear_compute = auto out_linear =
AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true); AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true);
out_linear_compute.ComputeForward( out_linear.ComputeForward(out_linear_weight,
out_linear_weight, fmha_or_gate_out, out_linear_bias, out, out); fmha_or_gate_out,
out_linear_bias,
out,
out,
use_fused_matmul_bias);
} }
template <typename T> template <typename T>
void ComputeOutputLinearBackward(const framework::ExecutionContext &ctx, void ComputeOutputLinearBackward(const framework::ExecutionContext &ctx,
const GateAttentionGradConfig<T> &config, const GateAttentionGradConfig<T> &config,
const phi::DenseTensor *input, const phi::DenseTensor *input,
phi::DenseTensor *input_grad) { phi::DenseTensor *input_grad,
bool use_fused_matmul_bias) {
auto &dev_ctx = ctx.template device_context<phi::GPUContext>(); auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
const auto *out_grad = const auto *out_grad =
ctx.Input<phi::DenseTensor>(framework::GradVarName("Out")); ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"));
...@@ -323,14 +342,16 @@ void ComputeOutputLinearBackward(const framework::ExecutionContext &ctx, ...@@ -323,14 +342,16 @@ void ComputeOutputLinearBackward(const framework::ExecutionContext &ctx,
int m = config.batch_size * config.seq_len_m * config.seq_len_r; int m = config.batch_size * config.seq_len_m * config.seq_len_r;
int n = config.q_dim; int n = config.q_dim;
int k = config.num_heads * config.head_dim; int k = config.num_heads * config.head_dim;
auto out_linear_compute = auto out_linear =
AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true); AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true);
out_linear_compute.ComputeBackward(input, out_linear.ComputeBackward(input,
out_linear_weight, out_linear_weight,
out_grad, out_grad,
input_grad, input_grad,
out_linear_weight_grad, out_linear_weight_grad,
out_linear_bias_grad); out_linear_bias_grad,
false,
use_fused_matmul_bias);
} }
template <typename T> template <typename T>
...@@ -358,6 +379,7 @@ class FusedGateAttentionOpKernel : public framework::OpKernel<T> { ...@@ -358,6 +379,7 @@ class FusedGateAttentionOpKernel : public framework::OpKernel<T> {
const bool merge_qkv = ctx.Attr<bool>("merge_qkv"); const bool merge_qkv = ctx.Attr<bool>("merge_qkv");
const bool has_gating = ctx.Attr<bool>("has_gating"); const bool has_gating = ctx.Attr<bool>("has_gating");
bool use_fused_matmul_bias = true;
auto &dev_ctx = ctx.template device_context<phi::GPUContext>(); auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
AllocWithDebugInfo<T>(dev_ctx, "softmax_out", softmax_out); AllocWithDebugInfo<T>(dev_ctx, "softmax_out", softmax_out);
AllocWithDebugInfo<T>(dev_ctx, "fmha_out", fmha_out); AllocWithDebugInfo<T>(dev_ctx, "fmha_out", fmha_out);
...@@ -413,12 +435,14 @@ class FusedGateAttentionOpKernel : public framework::OpKernel<T> { ...@@ -413,12 +435,14 @@ class FusedGateAttentionOpKernel : public framework::OpKernel<T> {
// 3. Gating Linear // 3. Gating Linear
if (has_gating) { if (has_gating) {
ComputeGatingLinearForward<T>(ctx, config, query, fmha_out, gate_out); ComputeGatingLinearForward<T>(
ctx, config, query, fmha_out, gate_out, use_fused_matmul_bias);
} }
// 4. Output Linear // 4. Output Linear
phi::DenseTensor *fmha_or_gate_out = has_gating ? gate_out : fmha_out; phi::DenseTensor *fmha_or_gate_out = has_gating ? gate_out : fmha_out;
ComputeOutputLinearForward<T>(ctx, config, fmha_or_gate_out, out); ComputeOutputLinearForward<T>(
ctx, config, fmha_or_gate_out, out, use_fused_matmul_bias);
} }
}; };
...@@ -454,6 +478,7 @@ class FusedGateAttentionGradKernel : public framework::OpKernel<T> { ...@@ -454,6 +478,7 @@ class FusedGateAttentionGradKernel : public framework::OpKernel<T> {
bool has_gating = ctx.Attr<bool>("has_gating"); bool has_gating = ctx.Attr<bool>("has_gating");
bool merge_qkv = ctx.Attr<bool>("merge_qkv"); bool merge_qkv = ctx.Attr<bool>("merge_qkv");
bool use_fused_matmul_bias = true;
auto &dev_ctx = ctx.template device_context<phi::GPUContext>(); auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
AllocWithDebugInfo<T>(dev_ctx, "query_grad", query_grad); AllocWithDebugInfo<T>(dev_ctx, "query_grad", query_grad);
...@@ -468,7 +493,8 @@ class FusedGateAttentionGradKernel : public framework::OpKernel<T> { ...@@ -468,7 +493,8 @@ class FusedGateAttentionGradKernel : public framework::OpKernel<T> {
phi::DenseTensor gate_out_grad; phi::DenseTensor gate_out_grad;
gate_out_grad.Resize(config.gate_out_dims); gate_out_grad.Resize(config.gate_out_dims);
AllocWithDebugInfo<T>(dev_ctx, "gate_out_grad", &gate_out_grad); AllocWithDebugInfo<T>(dev_ctx, "gate_out_grad", &gate_out_grad);
ComputeOutputLinearBackward<T>(ctx, config, gate_out, &gate_out_grad); ComputeOutputLinearBackward<T>(
ctx, config, gate_out, &gate_out_grad, use_fused_matmul_bias);
// 2. Gradient of Gating Linear // 2. Gradient of Gating Linear
// Forward: gate_out = Sigmoid(Linear(fmha_out)) * fmha_out // Forward: gate_out = Sigmoid(Linear(fmha_out)) * fmha_out
...@@ -478,10 +504,12 @@ class FusedGateAttentionGradKernel : public framework::OpKernel<T> { ...@@ -478,10 +504,12 @@ class FusedGateAttentionGradKernel : public framework::OpKernel<T> {
fmha_out, fmha_out,
&gate_out_grad, &gate_out_grad,
query_grad, query_grad,
&fmha_out_grad); &fmha_out_grad,
use_fused_matmul_bias);
} else { } else {
// 1. Gradient of Output Linear: out = Linear(fmha_grad) // 1. Gradient of Output Linear: out = Linear(fmha_grad)
ComputeOutputLinearBackward<T>(ctx, config, fmha_out, &fmha_out_grad); ComputeOutputLinearBackward<T>(
ctx, config, fmha_out, &fmha_out_grad, use_fused_matmul_bias);
} }
// 3. Gradient of FMHA // 3. Gradient of FMHA
......
...@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and ...@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h" #include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
......
...@@ -16,14 +16,14 @@ limitations under the License. */ ...@@ -16,14 +16,14 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h" #include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/scope_guard.h"
#include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/dynload/cublasLt.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
#if CUDA_VERSION >= 11060
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class FusedGemmEpilogueKernel : public framework::OpKernel<T> { class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
public: public:
...@@ -42,294 +42,36 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> { ...@@ -42,294 +42,36 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
bool trans_y = ctx.Attr<bool>("trans_y"); bool trans_y = ctx.Attr<bool>("trans_y");
std::string activation = ctx.Attr<std::string>("activation"); std::string activation = ctx.Attr<std::string>("activation");
VLOG(10) << "trans_x = " << trans_x << " , trans_y = " << trans_y
<< " , activation = " << activation;
bool enable_auxiliary = reserve_space == nullptr ? false : true;
dev_ctx.Alloc<T>(out, out->numel() * sizeof(T)); dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
auto* out_data = out->data<T>();
// (M * K) * (K * N)
auto x_mat_dims = auto x_mat_dims =
phi::flatten_to_2d(x->dims(), trans_x ? 1 : x->dims().size() - 1); phi::flatten_to_2d(x->dims(), trans_x ? 1 : x->dims().size() - 1);
// (M * K) * (K * N)
int64_t M = trans_x ? x_mat_dims[1] : x_mat_dims[0]; int64_t M = trans_x ? x_mat_dims[1] : x_mat_dims[0];
int64_t K = trans_y ? y->dims()[1] : y->dims()[0]; int64_t K = trans_y ? y->dims()[1] : y->dims()[0];
int64_t N = trans_y ? y->dims()[0] : y->dims()[1]; int64_t N = trans_y ? y->dims()[0] : y->dims()[1];
cudaDataType_t mat_type = CUDA_R_32F; ComputeFusedGemmEpilogueForward<T>(dev_ctx,
cudaDataType_t scale_type = CUDA_R_32F; x,
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; y,
if (std::is_same<T, paddle::platform::float16>::value) { bias,
mat_type = CUDA_R_16F; M,
} N,
if (std::is_same<T, platform::bfloat16>::value) { K,
mat_type = CUDA_R_16BF; trans_x,
} trans_y,
if (std::is_same<T, double>::value) { activation,
mat_type = CUDA_R_64F; out,
scale_type = CUDA_R_64F; reserve_space);
compute_type = CUBLAS_COMPUTE_64F;
}
cublasLtMatmulDesc_t operation_desc = NULL;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
&operation_desc, compute_type, scale_type));
cublasOperation_t transx = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t transy = trans_y ? CUBLAS_OP_T : CUBLAS_OP_N;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc,
CUBLASLT_MATMUL_DESC_TRANSB,
&transx,
sizeof(transx)));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc,
CUBLASLT_MATMUL_DESC_TRANSA,
&transy,
sizeof(transy)));
cublasLtEpilogue_t epiloque_func =
get_epilogue_type_(activation, enable_auxiliary);
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epiloque_func,
sizeof(epiloque_func)));
const T* bias_data = bias->data<T>();
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc,
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias_data,
sizeof(bias_data)));
if (enable_auxiliary && activation != "none") {
// Note (Ming Huang): The initialization of ReseveSpace is happened in the
// dev_ctx.Alloc. Therefore, we set real date type up here.
if (activation == "relu") {
paddle::experimental::DataType rs_type =
paddle::experimental::DataType::BOOL;
size_t reserve_space_size =
phi::product(reserve_space->dims()) * SizeOf(rs_type);
dev_ctx.Alloc(reserve_space, rs_type, reserve_space_size);
} else {
size_t reserve_space_size =
phi::product(reserve_space->dims()) * sizeof(T);
dev_ctx.Alloc<T>(reserve_space, reserve_space_size);
}
void* aux_data = reserve_space->data();
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
&aux_data,
sizeof(aux_data)));
int64_t aux_ld = N;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
&aux_ld,
sizeof(aux_ld)));
} }
cublasLtMatrixLayout_t x_desc = NULL, y_desc = NULL, out_desc = NULL;
if (trans_x)
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&x_desc, mat_type, M, K, M));
else
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&x_desc, mat_type, K, M, K));
if (trans_y)
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&y_desc, mat_type, K, N, K));
else
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&y_desc, mat_type, N, K, N));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&out_desc, mat_type, N, M, N));
cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
// NOTE(zengjinle): I do not know whether the 4MB workspace size is
// "enough". I just followed the settings from the NVIDIA MLPerf BERT code.
size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
cudaStream_t stream = dev_ctx.stream();
memory::allocation::AllocationPtr workspace = memory::Alloc(
dev_ctx.GetPlace(),
workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
double alpha64 = 1.0, beta64 = 0.0;
float alpha32 = 1.0f, beta32 = 0.0f;
void *alpha = nullptr, *beta = nullptr;
if (std::is_same<T, double>::value) {
alpha = &alpha64;
beta = &beta64;
} else {
alpha = &alpha32;
beta = &beta32;
}
const auto* y_data = y->data<T>();
const auto* x_data = x->data<T>();
auto algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle,
operation_desc,
y_desc,
x_desc,
out_desc,
alpha,
beta,
y_data,
x_data,
out_data,
stream,
workspace->ptr(),
workspace_size);
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmul(lt_handle,
operation_desc,
alpha,
y_data,
y_desc,
x_data,
x_desc,
beta,
out_data,
out_desc,
out_data,
out_desc,
algo,
workspace->ptr(),
workspace_size,
stream));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescDestroy(operation_desc));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(y_desc));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(x_desc));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(out_desc));
}
private:
static cublasLtEpilogue_t get_epilogue_type_(const std::string& activation,
bool enable_auxiliary) {
if (activation == "relu") {
return enable_auxiliary ? CUBLASLT_EPILOGUE_RELU_AUX_BIAS
: CUBLASLT_EPILOGUE_RELU_BIAS;
} else if (activation == "gelu") {
return enable_auxiliary ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS
: CUBLASLT_EPILOGUE_GELU_BIAS;
} else if (activation == "none") {
return CUBLASLT_EPILOGUE_BIAS;
} else {
PADDLE_ENFORCE_EQ(
true,
false,
platform::errors::InvalidArgument(
"The activation attribute of fused_gemm_epilogue op should be"
" one of {\"none\", \"relu\", \"gelu\"}. But received %s."
"But received activation=%s.",
activation));
}
}
};
enum FusedGEMMGradInType { kDX = 0, kDY = 1, kDZ = 2 };
template <bool TransX, bool TransY>
struct FusedGEMMGradTrait;
template <>
struct FusedGEMMGradTrait<false, false> {
static constexpr auto kXGradA = FusedGEMMGradInType::kDZ;
static constexpr auto kXGradB = FusedGEMMGradInType::kDY;
static constexpr auto kXGradATrans = false;
static constexpr auto kXGradBTrans = true;
static constexpr auto kYGradA = FusedGEMMGradInType::kDX;
static constexpr auto kYGradB = FusedGEMMGradInType::kDZ;
static constexpr auto kYGradATrans = true;
static constexpr auto kYGradBTrans = false;
};
template <>
struct FusedGEMMGradTrait<true, false> {
static constexpr auto kXGradA = FusedGEMMGradInType::kDY;
static constexpr auto kXGradB = FusedGEMMGradInType::kDZ;
static constexpr auto kXGradATrans = false;
static constexpr auto kXGradBTrans = true;
static constexpr auto kYGradA = FusedGEMMGradInType::kDX;
static constexpr auto kYGradB = FusedGEMMGradInType::kDZ;
static constexpr auto kYGradATrans = false;
static constexpr auto kYGradBTrans = false;
};
template <>
struct FusedGEMMGradTrait<false, true> {
static constexpr auto kXGradA = FusedGEMMGradInType::kDZ;
static constexpr auto kXGradB = FusedGEMMGradInType::kDY;
static constexpr auto kXGradATrans = false;
static constexpr auto kXGradBTrans = false;
static constexpr auto kYGradA = FusedGEMMGradInType::kDZ;
static constexpr auto kYGradB = FusedGEMMGradInType::kDX;
static constexpr auto kYGradATrans = true;
static constexpr auto kYGradBTrans = false;
};
template <>
struct FusedGEMMGradTrait<true, true> {
static constexpr auto kXGradA = FusedGEMMGradInType::kDY;
static constexpr auto kXGradB = FusedGEMMGradInType::kDZ;
static constexpr auto kXGradATrans = true;
static constexpr auto kXGradBTrans = true;
static constexpr auto kYGradA = FusedGEMMGradInType::kDZ;
static constexpr auto kYGradB = FusedGEMMGradInType::kDX;
static constexpr auto kYGradATrans = true;
static constexpr auto kYGradBTrans = true;
}; };
static constexpr auto BoolToCuBlasEnum(bool transpose) {
return transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
}
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> { class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
bool transpose_x = ctx.Attr<bool>("trans_x");
bool transpose_y = ctx.Attr<bool>("trans_y");
if (transpose_x) {
if (transpose_y) {
ComputeImpl<true, true>(ctx);
} else {
ComputeImpl<true, false>(ctx);
}
} else {
if (transpose_y) {
ComputeImpl<false, true>(ctx);
} else {
ComputeImpl<false, false>(ctx);
}
}
}
private:
template <bool TransX, bool TransY>
static void ComputeImpl(const framework::ExecutionContext& ctx) {
using Trait = FusedGEMMGradTrait<TransX, TransY>;
auto& dev_ctx = ctx.template device_context<phi::GPUContext>(); auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
const phi::DenseTensor* dout = ctx.Input<phi::DenseTensor>("DOut"); const phi::DenseTensor* dout = ctx.Input<phi::DenseTensor>("DOut");
const phi::DenseTensor* x = ctx.Input<phi::DenseTensor>("X"); const phi::DenseTensor* x = ctx.Input<phi::DenseTensor>("X");
const phi::DenseTensor* y = ctx.Input<phi::DenseTensor>("Y"); const phi::DenseTensor* y = ctx.Input<phi::DenseTensor>("Y");
...@@ -342,352 +84,33 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> { ...@@ -342,352 +84,33 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
std::string activation_grad = ctx.Attr<std::string>("activation_grad"); std::string activation_grad = ctx.Attr<std::string>("activation_grad");
VLOG(10) << "trans_x = " << TransX << " , trans_y = " << TransY bool trans_x = ctx.Attr<bool>("trans_x");
<< " , activation_grad = " << activation_grad; bool trans_y = ctx.Attr<bool>("trans_y");
auto x_mat_dims =
phi::flatten_to_2d(x->dims(), TransX ? 1 : x->dims().size() - 1);
// (M * K) * (K * N) // (M * K) * (K * N)
int64_t M = TransX ? x_mat_dims[1] : x_mat_dims[0]; auto x_mat_dims =
int64_t K = TransY ? y->dims()[1] : y->dims()[0]; phi::flatten_to_2d(x->dims(), trans_x ? 1 : x->dims().size() - 1);
int64_t N = TransY ? y->dims()[0] : y->dims()[1]; int64_t M = trans_x ? x_mat_dims[1] : x_mat_dims[0];
int64_t K = trans_y ? y->dims()[1] : y->dims()[0];
VLOG(10) << "M = " << M << " , K = " << K << " , N = " << N; int64_t N = trans_y ? y->dims()[0] : y->dims()[1];
cudaDataType_t mat_type = CUDA_R_32F;
cudaDataType_t scale_type = CUDA_R_32F;
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
if (std::is_same<T, paddle::platform::float16>::value) {
mat_type = CUDA_R_16F;
}
if (std::is_same<T, platform::bfloat16>::value) {
mat_type = CUDA_R_16BF;
}
if (std::is_same<T, double>::value) {
mat_type = CUDA_R_64F;
scale_type = CUDA_R_64F;
compute_type = CUBLAS_COMPUTE_64F;
}
cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
// NOTE(zengjinle): I do not know whether the 4MB workspace size is
// "enough". I just followed the settings from the NVIDIA MLPerf BERT code.
size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
const cublasLtMatmulAlgo_t* algo = nullptr;
cudaStream_t stream = dev_ctx.stream();
double alpha64 = 1.0, beta64 = 0.0;
float alpha32 = 1.0f, beta32 = 0.0f;
void *alpha = nullptr, *beta = nullptr;
if (std::is_same<T, double>::value) {
alpha = &alpha64;
beta = &beta64;
} else {
alpha = &alpha32;
beta = &beta32;
}
cublasLtMatrixLayout_t dout_desc = nullptr, dout_trans_desc = nullptr;
cublasLtMatrixLayout_t x_desc = nullptr, x_trans_desc = nullptr;
cublasLtMatrixLayout_t y_desc = nullptr, y_trans_desc = nullptr;
cublasLtMatrixLayout_t dx_desc = nullptr, dy_desc = nullptr;
cublasLtMatmulDesc_t dx_operation_desc = nullptr,
dy_operation_desc = nullptr;
DEFINE_PADDLE_SCOPE_GUARD([&] {
auto descs = {dout_desc,
dout_trans_desc,
x_desc,
x_trans_desc,
y_desc,
y_trans_desc,
dx_desc,
dy_desc};
for (auto desc : descs) {
if (desc) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(desc));
}
}
if (dx_operation_desc) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescDestroy(dx_operation_desc));
}
if (dy_operation_desc) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescDestroy(dy_operation_desc));
}
});
auto x_row = TransX ? K : M;
auto x_col = TransX ? M : K;
auto y_row = TransY ? N : K;
auto y_col = TransY ? K : N;
auto z_row = TransX ? N : M;
auto z_col = TransX ? M : N;
// dx = func(dout, y)
if (dx) {
constexpr auto kXGradAIsDZ = (Trait::kXGradA == FusedGEMMGradInType::kDZ);
cublasLtMatrixLayout_t *dx_dout_desc, *dx_y_desc;
if (TransX) {
dx_dout_desc = &dout_trans_desc;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(
dx_dout_desc, mat_type, z_row, z_col, z_row));
} else {
dx_dout_desc = &dout_desc;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(
dx_dout_desc, mat_type, z_col, z_row, z_col));
}
dx_y_desc = &y_trans_desc;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
dx_y_desc, mat_type, y_col, y_row, y_col));
auto& a_desc = kXGradAIsDZ ? (*dx_dout_desc) : (*dx_y_desc);
auto& b_desc = kXGradAIsDZ ? (*dx_y_desc) : (*dx_dout_desc);
auto a_trans = BoolToCuBlasEnum(Trait::kXGradATrans);
auto b_trans = BoolToCuBlasEnum(Trait::kXGradBTrans);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&dx_desc, mat_type, x_col, x_row, x_col));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
&dx_operation_desc, compute_type, scale_type));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc,
CUBLASLT_MATMUL_DESC_TRANSB,
&a_trans,
sizeof(a_trans)));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc,
CUBLASLT_MATMUL_DESC_TRANSA,
&b_trans,
sizeof(b_trans)));
cublasLtEpilogue_t epiloque_func_for_dx =
get_epilogue_type_(activation_grad);
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epiloque_func_for_dx,
sizeof(epiloque_func_for_dx)));
if (activation_grad != "none") {
auto* aux_data = reserve_space->data();
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
&aux_data,
sizeof(aux_data)));
int64_t aux_ld = TransX ? M : K;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
&aux_ld,
sizeof(aux_ld)));
}
auto dx_workspace = memory::Alloc(
dev_ctx.GetPlace(),
workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
auto* dx_data = dev_ctx.Alloc<T>(dx, dx->numel() * sizeof(T));
const auto* y_data = y->data<T>();
const auto* dout_data = dout->data<T>();
const auto* a_data = kXGradAIsDZ ? dout_data : y_data;
const auto* b_data = kXGradAIsDZ ? y_data : dout_data;
auto algo =
GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle,
dx_operation_desc,
b_desc,
a_desc,
dx_desc,
alpha,
beta,
b_data,
a_data,
dx_data,
stream,
dx_workspace->ptr(),
workspace_size);
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmul(lt_handle,
dx_operation_desc,
alpha,
b_data,
b_desc,
a_data,
a_desc,
beta,
dx_data,
dx_desc,
dx_data,
dx_desc,
algo,
dx_workspace->ptr(),
workspace_size,
stream));
}
// dy = func(dout, x)
if (dy) {
constexpr auto kYGradAIsDZ = (Trait::kYGradA == FusedGEMMGradInType::kDZ);
cublasLtMatrixLayout_t *dy_dout_desc = nullptr, *dy_x_desc = nullptr;
if (TransX) {
dy_dout_desc = &dout_trans_desc;
if (dout_trans_desc == nullptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(
dy_dout_desc, mat_type, z_row, z_col, z_row));
}
} else {
dy_dout_desc = &dout_desc;
if (dout_desc == nullptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(
dy_dout_desc, mat_type, z_col, z_row, z_col));
}
}
dy_x_desc = &x_trans_desc;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
dy_x_desc, mat_type, x_col, x_row, x_col));
auto& a_desc = kYGradAIsDZ ? (*dy_dout_desc) : (*dy_x_desc);
auto& b_desc = kYGradAIsDZ ? (*dy_x_desc) : (*dy_dout_desc);
auto a_trans = BoolToCuBlasEnum(Trait::kYGradATrans);
auto b_trans = BoolToCuBlasEnum(Trait::kYGradBTrans);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&dy_desc, mat_type, y_col, y_row, y_col));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
&dy_operation_desc, compute_type, scale_type));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc,
CUBLASLT_MATMUL_DESC_TRANSB,
&a_trans,
sizeof(a_trans)));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc,
CUBLASLT_MATMUL_DESC_TRANSA,
&b_trans,
sizeof(b_trans)));
cublasLtEpilogue_t epiloque_func_for_dy;
if (dbias == nullptr) {
epiloque_func_for_dy = CUBLASLT_EPILOGUE_DEFAULT;
} else {
if (TransY) {
epiloque_func_for_dy = CUBLASLT_EPILOGUE_BGRADB;
} else {
epiloque_func_for_dy = CUBLASLT_EPILOGUE_BGRADA;
}
}
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epiloque_func_for_dy,
sizeof(epiloque_func_for_dy)));
if (dbias) {
auto* dbias_data = dev_ctx.Alloc<T>(dbias, dbias->numel() * sizeof(T));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc,
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
&dbias_data,
sizeof(dbias_data)));
}
auto dy_workspace = memory::Alloc(
dev_ctx.GetPlace(),
workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
auto* dy_data = dev_ctx.Alloc<T>(dy, dy->numel() * sizeof(T));
const auto* dout_data = dout->data<T>();
const auto* x_data = x->data<T>();
const auto* a_data = kYGradAIsDZ ? dout_data : x_data;
const auto* b_data = kYGradAIsDZ ? x_data : dout_data;
auto algo =
GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle,
dy_operation_desc,
b_desc,
a_desc,
dy_desc,
alpha,
beta,
b_data,
a_data,
dy_data,
stream,
dy_workspace->ptr(),
workspace_size);
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmul(lt_handle,
dy_operation_desc,
alpha,
b_data,
b_desc,
a_data,
a_desc,
beta,
dy_data,
dy_desc,
dy_data,
dy_desc,
algo,
dy_workspace->ptr(),
workspace_size,
stream));
}
}
private: ComputeFusedGemmEpilogueBackward<T>(dev_ctx,
static cublasLtEpilogue_t get_epilogue_type_( dout,
const std::string& activation_grad) { x,
if (activation_grad == "relu_grad") { y,
return CUBLASLT_EPILOGUE_DRELU; reserve_space,
} else if (activation_grad == "gelu_grad") { M,
return CUBLASLT_EPILOGUE_DGELU; N,
} else if (activation_grad == "none") { K,
return CUBLASLT_EPILOGUE_DEFAULT; trans_x,
} else { trans_y,
PADDLE_ENFORCE_EQ( activation_grad,
true, dx,
false, dy,
platform::errors::InvalidArgument( dbias);
"The activation_grad attribute of fused_gemm_epilogue op should "
"be"
" one of {\"none\", \"relu\", \"gelu\"}. But received %s."
"But received activation_grad=%s.",
activation_grad));
}
} }
}; };
#endif
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
...@@ -15,21 +15,26 @@ limitations under the License. */ ...@@ -15,21 +15,26 @@ limitations under the License. */
#pragma once #pragma once
#include <algorithm>
#include <mutex>
#include <unordered_map>
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h> // NOLINT
#include "cuda.h" // NOLINT #include "cuda.h" // NOLINT
#if CUDA_VERSION >= 11060 #if CUDA_VERSION >= 11060
#include <algorithm>
#include <mutex>
#include <unordered_map>
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "paddle/fluid/framework/scope_guard.h"
#include "paddle/fluid/platform/dynload/cublasLt.h" #include "paddle/fluid/platform/dynload/cublasLt.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/backends/gpu/cuda/cuda_helper.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/optional.h" #include "paddle/utils/optional.h"
DECLARE_int64(cublaslt_exhaustive_search_times); DECLARE_int64(cublaslt_exhaustive_search_times);
...@@ -39,27 +44,27 @@ namespace operators { ...@@ -39,27 +44,27 @@ namespace operators {
class GemmEpilogueAlgoCache { class GemmEpilogueAlgoCache {
public: public:
static GemmEpilogueAlgoCache &Instance() { static GemmEpilogueAlgoCache& Instance() {
static GemmEpilogueAlgoCache instance( static GemmEpilogueAlgoCache instance(
FLAGS_cublaslt_exhaustive_search_times); FLAGS_cublaslt_exhaustive_search_times);
return instance; return instance;
} }
GemmEpilogueAlgoCache(GemmEpilogueAlgoCache const &) = delete; GemmEpilogueAlgoCache(GemmEpilogueAlgoCache const&) = delete;
void operator=(GemmEpilogueAlgoCache const &) = delete; void operator=(GemmEpilogueAlgoCache const&) = delete;
cublasLtMatmulAlgo_t *GetGemmAlgo(cublasLtHandle_t lt_handle, cublasLtMatmulAlgo_t* GetGemmAlgo(cublasLtHandle_t lt_handle,
cublasLtMatmulDesc_t op_desc, cublasLtMatmulDesc_t op_desc,
cublasLtMatrixLayout_t a_desc, cublasLtMatrixLayout_t a_desc,
cublasLtMatrixLayout_t b_desc, cublasLtMatrixLayout_t b_desc,
cublasLtMatrixLayout_t c_desc, cublasLtMatrixLayout_t c_desc,
const void *alpha, const void* alpha,
const void *beta, const void* beta,
const void *a, const void* a,
const void *b, const void* b,
void *c, void* c,
cudaStream_t stream, cudaStream_t stream,
void *workspace, void* workspace,
size_t workspace_size) { size_t workspace_size) {
if (search_times_ <= 0) return nullptr; if (search_times_ <= 0) return nullptr;
...@@ -207,7 +212,7 @@ class GemmEpilogueAlgoCache { ...@@ -207,7 +212,7 @@ class GemmEpilogueAlgoCache {
<< ") not found in GemmEpilogueAlgoCache"; << ") not found in GemmEpilogueAlgoCache";
std::lock_guard<std::mutex> lock(cache_mutex_); std::lock_guard<std::mutex> lock(cache_mutex_);
auto &algo_in_map = map_[seed]; auto& algo_in_map = map_[seed];
algo_in_map = ret; algo_in_map = ret;
return &algo_in_map; return &algo_in_map;
} }
...@@ -223,8 +228,8 @@ class GemmEpilogueAlgoCache { ...@@ -223,8 +228,8 @@ class GemmEpilogueAlgoCache {
std::mutex cache_mutex_; std::mutex cache_mutex_;
void HashMatmulDesc_(cublasLtMatmulDesc_t desc, void HashMatmulDesc_(cublasLtMatmulDesc_t desc,
int64_t *seed, int64_t* seed,
const std::hash<int64_t> &hash_fn) { const std::hash<int64_t>& hash_fn) {
size_t size_to_write; size_t size_to_write;
int trans_a, trans_b; int trans_a, trans_b;
uint32_t epilogue; uint32_t epilogue;
...@@ -258,8 +263,8 @@ class GemmEpilogueAlgoCache { ...@@ -258,8 +263,8 @@ class GemmEpilogueAlgoCache {
} }
void HashMatrixLayoutDesc_(cublasLtMatrixLayout_t desc, void HashMatrixLayoutDesc_(cublasLtMatrixLayout_t desc,
int64_t *seed, int64_t* seed,
const std::hash<int64_t> &hash_fn) { const std::hash<int64_t>& hash_fn) {
size_t size_to_write; size_t size_to_write;
uint32_t dtype; uint32_t dtype;
int32_t batch; int32_t batch;
...@@ -317,15 +322,665 @@ class GemmEpilogueAlgoCache { ...@@ -317,15 +322,665 @@ class GemmEpilogueAlgoCache {
HashValue_(seed, hash_fn, static_cast<int64_t>(batch_offset)); HashValue_(seed, hash_fn, static_cast<int64_t>(batch_offset));
} }
void HashValue_(int64_t *seed, void HashValue_(int64_t* seed,
const std::hash<int64_t> &hash_fn, const std::hash<int64_t>& hash_fn,
int64_t value) { int64_t value) {
*seed ^= hash_fn(value) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2); *seed ^= hash_fn(value) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
} }
}; };
static cublasLtEpilogue_t GetEpilogueType(const std::string& activation,
bool enable_auxiliary) {
if (activation == "relu") {
return enable_auxiliary ? CUBLASLT_EPILOGUE_RELU_AUX_BIAS
: CUBLASLT_EPILOGUE_RELU_BIAS;
} else if (activation == "gelu") {
return enable_auxiliary ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS
: CUBLASLT_EPILOGUE_GELU_BIAS;
} else if (activation == "none") {
return CUBLASLT_EPILOGUE_BIAS;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The activation attribute of fused_gemm_epilogue op should be"
" one of {\"none\", \"relu\", \"gelu\"}. But received %s."
"But received activation=%s.",
activation));
}
}
template <typename T>
void ComputeFusedGemmEpilogueForward(const phi::GPUContext& dev_ctx,
const phi::DenseTensor* x,
const phi::DenseTensor* y,
const phi::DenseTensor* bias,
int64_t M,
int64_t N,
int64_t K,
bool trans_x,
bool trans_y,
const std::string& activation,
phi::DenseTensor* out,
phi::DenseTensor* reserve_space) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
VLOG(6) << "x.shape={" << x->dims() << "}, y.shape={" << y->dims()
<< "}, out.shape={" << out->dims() << "}, M=" << M << ", N=" << N
<< ", K=" << K << ", trans_x=" << trans_x << ", trans_y=" << trans_y
<< ", activation=" << activation
<< ", reserve_space=" << reserve_space;
bool enable_auxiliary = reserve_space == nullptr ? false : true;
auto* out_data = out->data<T>();
cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType<T>();
cudaDataType_t scale_type = phi::backends::gpu::ToCudaDataType<MT>();
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
if (std::is_same<T, double>::value) {
compute_type = CUBLAS_COMPUTE_64F;
}
cublasLtMatmulDesc_t operation_desc = NULL;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
&operation_desc, compute_type, scale_type));
cublasOperation_t transx = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t transy = trans_y ? CUBLAS_OP_T : CUBLAS_OP_N;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &transx, sizeof(transx)));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &transy, sizeof(transy)));
cublasLtEpilogue_t epiloque_func =
GetEpilogueType(activation, enable_auxiliary);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epiloque_func,
sizeof(epiloque_func)));
const T* bias_data = bias->data<T>();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc,
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias_data,
sizeof(bias_data)));
if (enable_auxiliary && activation != "none") {
// Note (Ming Huang): The initialization of ReseveSpace is happened in the
// dev_ctx.Alloc. Therefore, we set real date type up here.
if (activation == "relu") {
paddle::experimental::DataType rs_type =
paddle::experimental::DataType::BOOL;
size_t reserve_space_size =
phi::product(reserve_space->dims()) * SizeOf(rs_type);
dev_ctx.Alloc(reserve_space, rs_type, reserve_space_size);
} else {
size_t reserve_space_size =
phi::product(reserve_space->dims()) * sizeof(T);
dev_ctx.Alloc<T>(reserve_space, reserve_space_size);
}
void* aux_data = reserve_space->data();
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
&aux_data,
sizeof(aux_data)));
int64_t aux_ld = N;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
operation_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
&aux_ld,
sizeof(aux_ld)));
}
cublasLtMatrixLayout_t x_desc = NULL, y_desc = NULL, out_desc = NULL;
if (trans_x) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&x_desc, mat_type, M, K, M));
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&x_desc, mat_type, K, M, K));
}
if (trans_y) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&y_desc, mat_type, K, N, K));
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&y_desc, mat_type, N, K, N));
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&out_desc, mat_type, N, M, N));
cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
// NOTE(zengjinle): I do not know whether the 4MB workspace size is
// "enough". I just followed the settings from the NVIDIA MLPerf BERT code.
size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
cudaStream_t stream = dev_ctx.stream();
memory::allocation::AllocationPtr workspace = memory::Alloc(
dev_ctx.GetPlace(),
workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
MT alpha = static_cast<MT>(1);
MT beta = static_cast<MT>(0);
const auto* y_data = y->data<T>();
const auto* x_data = x->data<T>();
auto algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle,
operation_desc,
y_desc,
x_desc,
out_desc,
&alpha,
&beta,
y_data,
x_data,
out_data,
stream,
workspace->ptr(),
workspace_size);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul(lt_handle,
operation_desc,
&alpha,
y_data,
y_desc,
x_data,
x_desc,
&beta,
out_data,
out_desc,
out_data,
out_desc,
algo,
workspace->ptr(),
workspace_size,
stream));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescDestroy(operation_desc));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(y_desc));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(x_desc));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(out_desc));
}
enum FusedGEMMGradInType { kDX = 0, kDY = 1, kDZ = 2 };
template <bool TransX, bool TransY>
struct FusedGEMMGradTrait;
template <>
struct FusedGEMMGradTrait<false, false> {
static constexpr auto kXGradA = FusedGEMMGradInType::kDZ;
static constexpr auto kXGradB = FusedGEMMGradInType::kDY;
static constexpr auto kXGradATrans = false;
static constexpr auto kXGradBTrans = true;
static constexpr auto kYGradA = FusedGEMMGradInType::kDX;
static constexpr auto kYGradB = FusedGEMMGradInType::kDZ;
static constexpr auto kYGradATrans = true;
static constexpr auto kYGradBTrans = false;
};
template <>
struct FusedGEMMGradTrait<true, false> {
static constexpr auto kXGradA = FusedGEMMGradInType::kDY;
static constexpr auto kXGradB = FusedGEMMGradInType::kDZ;
static constexpr auto kXGradATrans = false;
static constexpr auto kXGradBTrans = true;
static constexpr auto kYGradA = FusedGEMMGradInType::kDX;
static constexpr auto kYGradB = FusedGEMMGradInType::kDZ;
static constexpr auto kYGradATrans = false;
static constexpr auto kYGradBTrans = false;
};
template <>
struct FusedGEMMGradTrait<false, true> {
static constexpr auto kXGradA = FusedGEMMGradInType::kDZ;
static constexpr auto kXGradB = FusedGEMMGradInType::kDY;
static constexpr auto kXGradATrans = false;
static constexpr auto kXGradBTrans = false;
static constexpr auto kYGradA = FusedGEMMGradInType::kDZ;
static constexpr auto kYGradB = FusedGEMMGradInType::kDX;
static constexpr auto kYGradATrans = true;
static constexpr auto kYGradBTrans = false;
};
template <>
struct FusedGEMMGradTrait<true, true> {
static constexpr auto kXGradA = FusedGEMMGradInType::kDY;
static constexpr auto kXGradB = FusedGEMMGradInType::kDZ;
static constexpr auto kXGradATrans = true;
static constexpr auto kXGradBTrans = true;
static constexpr auto kYGradA = FusedGEMMGradInType::kDZ;
static constexpr auto kYGradB = FusedGEMMGradInType::kDX;
static constexpr auto kYGradATrans = true;
static constexpr auto kYGradBTrans = true;
};
static constexpr auto BoolToCuBlasEnum(bool transpose) {
return transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
}
static cublasLtEpilogue_t GetEpilogueGradType(
const std::string& activation_grad) {
if (activation_grad == "relu_grad") {
return CUBLASLT_EPILOGUE_DRELU;
} else if (activation_grad == "gelu_grad") {
return CUBLASLT_EPILOGUE_DGELU;
} else if (activation_grad == "none") {
return CUBLASLT_EPILOGUE_DEFAULT;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The activation_grad attribute of fused_gemm_epilogue op should "
"be one of {\"none\", \"relu\", \"gelu\"}. But received %s."
"But received activation_grad=%s.",
activation_grad));
}
}
template <typename T, bool TransX, bool TransY>
void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
const phi::DenseTensor* dout,
const phi::DenseTensor* x,
const phi::DenseTensor* y,
const phi::DenseTensor* reserve_space,
int64_t M,
int64_t N,
int64_t K,
const std::string activation_grad,
phi::DenseTensor* dx,
phi::DenseTensor* dy,
phi::DenseTensor* dbias,
bool use_addto) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
using Trait = FusedGEMMGradTrait<TransX, TransY>;
cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType<T>();
cudaDataType_t scale_type = phi::backends::gpu::ToCudaDataType<MT>();
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
if (std::is_same<T, double>::value) {
compute_type = CUBLAS_COMPUTE_64F;
}
cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
// NOTE(zengjinle): I do not know whether the 4MB workspace size is
// "enough". I just followed the settings from the NVIDIA MLPerf BERT code.
size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
const cublasLtMatmulAlgo_t* algo = nullptr;
cudaStream_t stream = dev_ctx.stream();
MT alpha = static_cast<MT>(1.0);
MT beta_dx = use_addto ? static_cast<MT>(1.0) : static_cast<MT>(0.0);
MT beta_dy = static_cast<MT>(0.0);
cublasLtMatrixLayout_t dout_desc = nullptr, dout_trans_desc = nullptr;
cublasLtMatrixLayout_t x_desc = nullptr, x_trans_desc = nullptr;
cublasLtMatrixLayout_t y_desc = nullptr, y_trans_desc = nullptr;
cublasLtMatrixLayout_t dx_desc = nullptr, dy_desc = nullptr;
cublasLtMatmulDesc_t dx_operation_desc = nullptr, dy_operation_desc = nullptr;
DEFINE_PADDLE_SCOPE_GUARD([&] {
auto descs = {dout_desc,
dout_trans_desc,
x_desc,
x_trans_desc,
y_desc,
y_trans_desc,
dx_desc,
dy_desc};
for (auto desc : descs) {
if (desc) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(desc));
}
}
if (dx_operation_desc) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescDestroy(dx_operation_desc));
}
if (dy_operation_desc) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescDestroy(dy_operation_desc));
}
});
auto x_row = TransX ? K : M;
auto x_col = TransX ? M : K;
auto y_row = TransY ? N : K;
auto y_col = TransY ? K : N;
auto z_row = TransX ? N : M;
auto z_col = TransX ? M : N;
// dx = func(dout, y)
if (dx) {
constexpr auto kXGradAIsDZ = (Trait::kXGradA == FusedGEMMGradInType::kDZ);
cublasLtMatrixLayout_t *dx_dout_desc, *dx_y_desc;
if (TransX) {
dx_dout_desc = &dout_trans_desc;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
dx_dout_desc, mat_type, z_row, z_col, z_row));
} else {
dx_dout_desc = &dout_desc;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
dx_dout_desc, mat_type, z_col, z_row, z_col));
}
dx_y_desc = &y_trans_desc;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
dx_y_desc, mat_type, y_col, y_row, y_col));
auto& a_desc = kXGradAIsDZ ? (*dx_dout_desc) : (*dx_y_desc);
auto& b_desc = kXGradAIsDZ ? (*dx_y_desc) : (*dx_dout_desc);
auto a_trans = BoolToCuBlasEnum(Trait::kXGradATrans);
auto b_trans = BoolToCuBlasEnum(Trait::kXGradBTrans);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&dx_desc, mat_type, x_col, x_row, x_col));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
&dx_operation_desc, compute_type, scale_type));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc,
CUBLASLT_MATMUL_DESC_TRANSB,
&a_trans,
sizeof(a_trans)));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc,
CUBLASLT_MATMUL_DESC_TRANSA,
&b_trans,
sizeof(b_trans)));
cublasLtEpilogue_t epiloque_func_for_dx =
GetEpilogueGradType(activation_grad);
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epiloque_func_for_dx,
sizeof(epiloque_func_for_dx)));
if (activation_grad != "none") {
auto* aux_data = reserve_space->data();
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
&aux_data,
sizeof(aux_data)));
int64_t aux_ld = TransX ? M : K;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
&aux_ld,
sizeof(aux_ld)));
}
auto dx_workspace = memory::Alloc(
dev_ctx.GetPlace(),
workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
auto* dx_data = dev_ctx.Alloc<T>(dx, dx->numel() * sizeof(T));
const auto* y_data = y->data<T>();
const auto* dout_data = dout->data<T>();
const auto* a_data = kXGradAIsDZ ? dout_data : y_data;
const auto* b_data = kXGradAIsDZ ? y_data : dout_data;
auto algo =
GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle,
dx_operation_desc,
b_desc,
a_desc,
dx_desc,
&alpha,
&beta_dx,
b_data,
a_data,
dx_data,
stream,
dx_workspace->ptr(),
workspace_size);
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmul(lt_handle,
dx_operation_desc,
&alpha,
b_data,
b_desc,
a_data,
a_desc,
&beta_dx,
dx_data,
dx_desc,
dx_data,
dx_desc,
algo,
dx_workspace->ptr(),
workspace_size,
stream));
}
// dy = func(dout, x)
if (dy) {
constexpr auto kYGradAIsDZ = (Trait::kYGradA == FusedGEMMGradInType::kDZ);
cublasLtMatrixLayout_t *dy_dout_desc = nullptr, *dy_x_desc = nullptr;
if (TransX) {
dy_dout_desc = &dout_trans_desc;
if (dout_trans_desc == nullptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(
dy_dout_desc, mat_type, z_row, z_col, z_row));
}
} else {
dy_dout_desc = &dout_desc;
if (dout_desc == nullptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(
dy_dout_desc, mat_type, z_col, z_row, z_col));
}
}
dy_x_desc = &x_trans_desc;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
dy_x_desc, mat_type, x_col, x_row, x_col));
auto& a_desc = kYGradAIsDZ ? (*dy_dout_desc) : (*dy_x_desc);
auto& b_desc = kYGradAIsDZ ? (*dy_x_desc) : (*dy_dout_desc);
auto a_trans = BoolToCuBlasEnum(Trait::kYGradATrans);
auto b_trans = BoolToCuBlasEnum(Trait::kYGradBTrans);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&dy_desc, mat_type, y_col, y_row, y_col));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
&dy_operation_desc, compute_type, scale_type));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc,
CUBLASLT_MATMUL_DESC_TRANSB,
&a_trans,
sizeof(a_trans)));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc,
CUBLASLT_MATMUL_DESC_TRANSA,
&b_trans,
sizeof(b_trans)));
cublasLtEpilogue_t epiloque_func_for_dy;
if (dbias == nullptr) {
epiloque_func_for_dy = CUBLASLT_EPILOGUE_DEFAULT;
} else {
if (TransY) {
epiloque_func_for_dy = CUBLASLT_EPILOGUE_BGRADB;
} else {
epiloque_func_for_dy = CUBLASLT_EPILOGUE_BGRADA;
}
}
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epiloque_func_for_dy,
sizeof(epiloque_func_for_dy)));
if (dbias) {
auto* dbias_data = dev_ctx.Alloc<T>(dbias, dbias->numel() * sizeof(T));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc,
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
&dbias_data,
sizeof(dbias_data)));
}
auto dy_workspace = memory::Alloc(
dev_ctx.GetPlace(),
workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
auto* dy_data = dev_ctx.Alloc<T>(dy, dy->numel() * sizeof(T));
const auto* dout_data = dout->data<T>();
const auto* x_data = x->data<T>();
const auto* a_data = kYGradAIsDZ ? dout_data : x_data;
const auto* b_data = kYGradAIsDZ ? x_data : dout_data;
auto algo =
GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle,
dy_operation_desc,
b_desc,
a_desc,
dy_desc,
&alpha,
&beta_dy,
b_data,
a_data,
dy_data,
stream,
dy_workspace->ptr(),
workspace_size);
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmul(lt_handle,
dy_operation_desc,
&alpha,
b_data,
b_desc,
a_data,
a_desc,
&beta_dy,
dy_data,
dy_desc,
dy_data,
dy_desc,
algo,
dy_workspace->ptr(),
workspace_size,
stream));
}
}
template <typename T>
void ComputeFusedGemmEpilogueBackward(const phi::GPUContext& dev_ctx,
const phi::DenseTensor* dout,
const phi::DenseTensor* x,
const phi::DenseTensor* y,
const phi::DenseTensor* reserve_space,
int64_t M,
int64_t N,
int64_t K,
bool trans_x,
bool trans_y,
const std::string& activation_grad,
phi::DenseTensor* dx,
phi::DenseTensor* dy,
phi::DenseTensor* dbias,
bool use_addto = false) {
VLOG(10) << "M=" << M << ", K=" << K << ", N=" << N << ", trans_x=" << trans_x
<< ", trans_y=" << trans_y
<< ", activation_grad=" << activation_grad;
if (trans_x) {
if (trans_y) {
ComputeFusedGemmEpilogueBackwardImpl<T, true, true>(dev_ctx,
dout,
x,
y,
reserve_space,
M,
N,
K,
activation_grad,
dx,
dy,
dbias,
use_addto);
} else {
ComputeFusedGemmEpilogueBackwardImpl<T, true, false>(dev_ctx,
dout,
x,
y,
reserve_space,
M,
N,
K,
activation_grad,
dx,
dy,
dbias,
use_addto);
}
} else {
if (trans_y) {
ComputeFusedGemmEpilogueBackwardImpl<T, false, true>(dev_ctx,
dout,
x,
y,
reserve_space,
M,
N,
K,
activation_grad,
dx,
dy,
dbias,
use_addto);
} else {
ComputeFusedGemmEpilogueBackwardImpl<T, false, false>(dev_ctx,
dout,
x,
y,
reserve_space,
M,
N,
K,
activation_grad,
dx,
dy,
dbias,
use_addto);
}
}
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#endif #endif
#endif #endif
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
// This file has been adapted from FasterTransformer file: // This file has been adapted from FasterTransformer file:
// https://github.com/NVIDIA/FasterTransformer/blob/v4.0/fastertransformer/cuda/masked_multihead_attention.cu // https://github.com/NVIDIA/FasterTransformer/blob/v4.0/fastertransformer/cuda/masked_multihead_attention.cu
// We add License in the head. // We add License in the head.
......
...@@ -14,6 +14,12 @@ ...@@ -14,6 +14,12 @@
#pragma once #pragma once
#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h> // NOLINT
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
namespace phi { namespace phi {
namespace backends { namespace backends {
namespace gpu { namespace gpu {
...@@ -69,6 +75,22 @@ namespace gpu { ...@@ -69,6 +75,22 @@ namespace gpu {
for (index_type i = __index__; __index__ < (num); \ for (index_type i = __index__; __index__ < (num); \
__index__ += __stride__, i = __index__) __index__ += __stride__, i = __index__)
template <typename T>
cudaDataType_t ToCudaDataType() {
if (std::is_same<T, float>::value) {
return CUDA_R_32F;
} else if (std::is_same<T, double>::value) {
return CUDA_R_64F;
} else if (std::is_same<T, phi::dtype::float16>::value) {
return CUDA_R_16F;
#if CUDA_VERSION >= 11000
} else if (std::is_same<T, phi::dtype::bfloat16>::value) {
return CUDA_R_16BF;
#endif
}
}
} // namespace gpu } // namespace gpu
} // namespace backends } // namespace backends
} // namespace phi } // namespace phi
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册