diff --git a/paddle/fluid/operators/fused/fused_gate_attention.h b/paddle/fluid/operators/fused/fused_gate_attention.h index cda33987d68ac757bfb60a9506728180e8553f49..d7ed144f02de75ef6cab9322e05c6ed84838f3dd 100644 --- a/paddle/fluid/operators/fused/fused_gate_attention.h +++ b/paddle/fluid/operators/fused/fused_gate_attention.h @@ -14,11 +14,11 @@ limitations under the License. */ #pragma once -#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" #include "paddle/fluid/operators/transpose_op.cu.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" +#include "paddle/phi/kernels/funcs/reduce_function.h" #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h" namespace paddle { @@ -27,19 +27,29 @@ namespace operators { using Tensor = framework::Tensor; inline std::string MemoryDebugString(const Tensor& t) { + int device_id = platform::GetCurrentDeviceId(); + int64_t allocated = + memory::DeviceMemoryStatCurrentValue("Allocated", device_id); + int64_t reserved = + memory::DeviceMemoryStatCurrentValue("Reserved", device_id); + std::stringstream ss; ss << "shape=[" << t.dims() << "], size=" << static_cast(t.memory_size()) / (1 << 20) - << " MB, ptr=" << t.data(); - - size_t total = 0; - size_t available = 0; - platform::GpuMemoryUsage(&available, &total); - ss << "; memory allocated=" - << static_cast(total - available) / (1 << 20) << " MB"; + << " MB, ptr=" << t.data() + << "; [MEMORY] allocated=" << static_cast(allocated) / (1 << 20) + << " MB" + << ", reserved=" << static_cast(reserved) / (1 << 20) << " MB"; return ss.str(); } +template +void AllocWithDebugInfo(const platform::CUDADeviceContext& dev_ctx, + const std::string& info, Tensor* t) { + t->mutable_data(dev_ctx.GetPlace()); + VLOG(4) << info << ": " << MemoryDebugString(*t); +} + template struct TernaryAddFunctor { inline HOSTDEVICE T operator()(T a, T b, T c) const { return a + b + c; } @@ -48,6 +58,11 @@ struct TernaryAddFunctor { template struct GateAttentionConfig { public: + const platform::CUDADeviceContext& dev_ctx; + + bool merge_qkv; + bool has_gating; + int64_t batch_size; int64_t seq_len_m; int64_t seq_len_r; @@ -70,9 +85,11 @@ struct GateAttentionConfig { phi::DDim qktv_out_dims; phi::DDim gate_out_dims; - GateAttentionConfig(const Tensor* query, const Tensor* key, + GateAttentionConfig(const platform::CUDADeviceContext& dev_ctx, + const Tensor* query, const Tensor* key, const Tensor* query_weight, const Tensor* qkv_weight, - bool merge_qkv) { + bool merge_qkv, bool has_gating) + : dev_ctx(dev_ctx), merge_qkv(merge_qkv), has_gating(has_gating) { // query: shape=[batch_size, seq_len_m, seq_len_r, q_dim] batch_size = query->dims()[0]; seq_len_m = query->dims()[1]; @@ -131,59 +148,68 @@ struct GateAttentionConfig { return batch_size * seq_len_m * seq_len_r * num_heads * key_dim; } - Tensor* GetQKVOut(const platform::CUDADeviceContext& dev_ctx) { + Tensor* GetQKVOut() { if (!qkv_out.IsInitialized()) { qkv_out.Resize(qkv_out_dims); - qkv_out.mutable_data(dev_ctx.GetPlace()); - VLOG(4) << "qkv_out: " << MemoryDebugString(qkv_out); + AllocWithDebugInfo(dev_ctx, "qkv_out", &qkv_out); } return &qkv_out; } - Tensor* GetQueryOut(const platform::CUDADeviceContext& dev_ctx) { + Tensor* GetQueryOut() { if (!query_out.IsInitialized()) { query_out.Resize(q_out_dims); - query_out.mutable_data(dev_ctx.GetPlace()); - VLOG(4) << "query_out: " << MemoryDebugString(query_out); + AllocWithDebugInfo(dev_ctx, 
"query_out", &query_out); } return &query_out; } - Tensor* GetKeyOut(const platform::CUDADeviceContext& dev_ctx) { + Tensor* GetKeyOut() { if (!key_out.IsInitialized()) { key_out.Resize(kv_out_dims); - key_out.mutable_data(dev_ctx.GetPlace()); - VLOG(4) << "key_out: " << MemoryDebugString(key_out); + AllocWithDebugInfo(dev_ctx, "key_out", &key_out); } return &key_out; } - Tensor* GetValueOut(const platform::CUDADeviceContext& dev_ctx) { + Tensor* GetValueOut() { if (!value_out.IsInitialized()) { value_out.Resize(kv_out_dims); - value_out.mutable_data(dev_ctx.GetPlace()); - VLOG(4) << "value_out: " << MemoryDebugString(value_out); + AllocWithDebugInfo(dev_ctx, "value_out", &value_out); } return &value_out; } - Tensor* GetQKOut(const platform::CUDADeviceContext& dev_ctx, - Tensor* softmax_out) { + Tensor* GetQKOut(Tensor* softmax_out) { // softmax_dim = qk_out_dim[-1] = qk_out_dim[rank - 1] int softmax_dim = m_size; if (!softmax_out || phi::UseCudnnSoftmax(dev_ctx, softmax_dim, true)) { // Not sure whether cudnn softmax can execute inplace. if (!qkv_out.IsInitialized()) { qk_out.Resize(qk_out_dims); - qk_out.mutable_data(dev_ctx.GetPlace()); - VLOG(4) << "qk_out: " << MemoryDebugString(qk_out); + AllocWithDebugInfo(dev_ctx, "qk_out", &qk_out); } return &qk_out; } else { + // Enable inplace softmax. return softmax_out; } } + Tensor* GetQKTVOut(Tensor* gate_out) { + if (has_gating && gate_out) { + // Reuse gate_out. + gate_out->Resize(qktv_out_dims); + return gate_out; + } else { + if (!qktv_out.IsInitialized()) { + qktv_out.Resize(qktv_out_dims); + AllocWithDebugInfo(dev_ctx, "qktv_out", &qktv_out); + } + return &qktv_out; + } + } + void ClearQKVOut() { if (qkv_out.IsInitialized()) { qkv_out.clear(); @@ -196,9 +222,14 @@ struct GateAttentionConfig { } } + void ClearQKTVOut() { + if (qktv_out.IsInitialized()) { + qktv_out.clear(); + } + } + protected: Tensor qkv_out; - // QKV is not merged Tensor query_out; Tensor key_out; Tensor value_out; @@ -207,63 +238,60 @@ struct GateAttentionConfig { // softmax_out = softmax(qk_out + nonbatched_bias + src_mask) // The shape of qk_out, softmax_out is the same, thus can be called inplace. Tensor qk_out; + // qktv_out may reuse gate_out. 
+ Tensor qktv_out; }; template struct GateAttentionGradConfig : public GateAttentionConfig { public: - GateAttentionGradConfig(const Tensor* query, const Tensor* key, + GateAttentionGradConfig(const platform::CUDADeviceContext& dev_ctx, + const Tensor* query, const Tensor* key, const Tensor* query_weight, const Tensor* qkv_weight, - bool merge_qkv) - : GateAttentionConfig(query, key, query_weight, qkv_weight, - merge_qkv) {} + bool merge_qkv, bool has_gating) + : GateAttentionConfig(dev_ctx, query, key, query_weight, qkv_weight, + merge_qkv, has_gating) {} - Tensor* GetQKVOutGrad(const platform::CUDADeviceContext& dev_ctx) { + Tensor* GetQKVOutGrad() { if (!qkv_out_grad.IsInitialized()) { qkv_out_grad.Resize(this->qkv_out_dims); - qkv_out_grad.mutable_data(dev_ctx.GetPlace()); - VLOG(4) << "qkv_out_grad: " << MemoryDebugString(qkv_out_grad); + AllocWithDebugInfo(this->dev_ctx, "qkv_out_grad", &qkv_out_grad); } return &qkv_out_grad; } - Tensor* GetQueryOutGrad(const platform::CUDADeviceContext& dev_ctx) { + Tensor* GetQueryOutGrad() { if (!query_out_grad.IsInitialized()) { query_out_grad.Resize(this->q_out_dims); - query_out_grad.mutable_data(dev_ctx.GetPlace()); - VLOG(4) << "query_out_grad: " << MemoryDebugString(query_out_grad); + AllocWithDebugInfo(this->dev_ctx, "query_out_grad", &query_out_grad); } return &query_out_grad; } - Tensor* GetKeyOutGrad(const platform::CUDADeviceContext& dev_ctx) { + Tensor* GetKeyOutGrad() { if (!key_out_grad.IsInitialized()) { key_out_grad.Resize(this->kv_out_dims); - key_out_grad.mutable_data(dev_ctx.GetPlace()); - VLOG(4) << "key_out_grad: " << MemoryDebugString(key_out_grad); + AllocWithDebugInfo(this->dev_ctx, "key_out_grad", &key_out_grad); } return &key_out_grad; } - Tensor* GetValueOutGrad(const platform::CUDADeviceContext& dev_ctx) { + Tensor* GetValueOutGrad() { if (!value_out_grad.IsInitialized()) { value_out_grad.Resize(this->kv_out_dims); - value_out_grad.mutable_data(dev_ctx.GetPlace()); - VLOG(4) << "value_out_grad: " << MemoryDebugString(value_out_grad); + AllocWithDebugInfo(this->dev_ctx, "value_out_grad", &value_out_grad); } return &value_out_grad; } - Tensor* GetQKOutGrad(const platform::CUDADeviceContext& dev_ctx, - Tensor* softmax_out_grad) { + Tensor* GetQKOutGrad(Tensor* softmax_out_grad) { // softmax_dim = qk_out_dim[-1] = qk_out_dim[rank - 1] int softmax_dim = this->m_size; if (!softmax_out_grad || - phi::UseCudnnSoftmax(dev_ctx, softmax_dim, true)) { + phi::UseCudnnSoftmax(this->dev_ctx, softmax_dim, true)) { if (!qk_out_grad.IsInitialized()) { qk_out_grad.Resize(this->qk_out_dims); - qk_out_grad.mutable_data(dev_ctx.GetPlace()); - VLOG(4) << "qk_out_grad: " << MemoryDebugString(qk_out_grad); + AllocWithDebugInfo(this->dev_ctx, "qk_out_grad", &qk_out_grad); } return &qk_out_grad; } else { @@ -288,7 +316,7 @@ class FMHAGateRef { void ComputeForward(const Tensor* nonbatched_bias, const Tensor* src_mask, Tensor* q_transpose_out, Tensor* k_transpose_out, Tensor* v_transpose_out, Tensor* qkv_transpose_out, - Tensor* softmax_out, Tensor* fmha_out, + Tensor* softmax_out, Tensor* fmha_out, Tensor* gate_out, GateAttentionConfig* config) { T* q_ptr = nullptr; T* k_ptr = nullptr; @@ -300,7 +328,7 @@ class FMHAGateRef { platform::errors::NotFound("The input qkv_transpose_out can not be " "nullptr when merge_qkv is true.")); - Tensor* qkv_out = config->GetQKVOut(dev_ctx_); + Tensor* qkv_out = config->GetQKVOut(); ComputeQKVTransposeForward(*qkv_out, qkv_transpose_out); config->ClearQKVOut(); @@ -323,9 +351,9 @@ class FMHAGateRef { 
platform::errors::NotFound("The input v_transpose_out can not be " "nullptr when merge_qkv is false.")); - Tensor* query_out = config->GetQueryOut(dev_ctx_); - Tensor* key_out = config->GetKeyOut(dev_ctx_); - Tensor* value_out = config->GetValueOut(dev_ctx_); + Tensor* query_out = config->GetQueryOut(); + Tensor* key_out = config->GetKeyOut(); + Tensor* value_out = config->GetValueOut(); ComputeQKVTransposeForward(*query_out, *key_out, *value_out, q_transpose_out, k_transpose_out, v_transpose_out); @@ -340,7 +368,7 @@ class FMHAGateRef { // [batch_size, seq_len_m, num_heads, seq_len_r, key_dim] * // [batch_size, seq_len_m, num_heads, m_size, key_dim] // -> [batch_size, seq_len_m, num_heads, seq_len_r, m_size] - Tensor* qk_out = config->GetQKOut(dev_ctx_, softmax_out); + Tensor* qk_out = config->GetQKOut(softmax_out); T* qk_out_ptr = qk_out->data(); int64_t gemm_batch_size = @@ -362,9 +390,8 @@ class FMHAGateRef { // [batch_size, seq_len_m, num_heads, seq_len_r, m_size] * // [batch_size, seq_len_m, num_heads, m_size, key_dim] // -> [batch_size, seq_len_m, num_heads, seq_len_r, key_dim] - Tensor qktv_out; - qktv_out.Resize(config->qktv_out_dims); - T* qktv_out_ptr = qktv_out.mutable_data(dev_ctx_.GetPlace()); + Tensor* qktv_out = config->GetQKTVOut(gate_out); + T* qktv_out_ptr = qktv_out->data(); gemm_m = config->seq_len_r; gemm_n = config->key_dim; @@ -375,7 +402,11 @@ class FMHAGateRef { gemm_m, gemm_n, gemm_k, gemm_batch_size); // fmha_out = transpose(qktv_out) - ComputeQKTVTransposeForward(qktv_out, fmha_out); + ComputeQKTVTransposeForward(*qktv_out, fmha_out); + config->ClearQKTVOut(); + if (config->has_gating) { + gate_out->Resize(config->gate_out_dims); + } } void ComputeBackward(const Tensor* q_transpose_out, @@ -409,8 +440,10 @@ class FMHAGateRef { v_ptr = k_ptr + q_size; qkv_transpose_out_grad.Resize(config->qkv_transpose_out_dims); + AllocWithDebugInfo(dev_ctx_, "qkv_transpose_out_grad", + &qkv_transpose_out_grad); - q_grad_ptr = qkv_transpose_out_grad.mutable_data(dev_ctx_.GetPlace()); + q_grad_ptr = qkv_transpose_out_grad.data(); k_grad_ptr = q_grad_ptr + q_size; v_grad_ptr = k_grad_ptr + q_size; } else { @@ -442,7 +475,7 @@ class FMHAGateRef { Tensor softmax_out_grad; softmax_out_grad.Resize(config->softmax_out_dims); - softmax_out_grad.mutable_data(dev_ctx_.GetPlace()); + AllocWithDebugInfo(dev_ctx_, "softmax_out_grad", &softmax_out_grad); int64_t gemm_batch_size = config->batch_size * config->seq_len_m * config->num_heads; @@ -450,7 +483,7 @@ class FMHAGateRef { // Forward: fmha_out = transpose(qktv_out) Tensor qktv_out_grad; qktv_out_grad.Resize(config->qktv_out_dims); - T* qktv_out_grad_ptr = qktv_out_grad.mutable_data(dev_ctx_.GetPlace()); + AllocWithDebugInfo(dev_ctx_, "qktv_out_grad", &qktv_out_grad); ComputeQKTVTransposeBackward(*fmha_out_grad, &qktv_out_grad); // Forward: qktv_out = BatchedGEMM(softmax_out, V) @@ -461,6 +494,7 @@ class FMHAGateRef { int64_t gemm_k = config->seq_len_r; const T* softmax_out_ptr = softmax_out->data(); + const T* qktv_out_grad_ptr = qktv_out_grad.data(); ComputeBatchedGEMM(softmax_out_ptr, qktv_out_grad_ptr, v_grad_ptr, true, false, gemm_m, gemm_n, gemm_k, gemm_batch_size); @@ -474,7 +508,7 @@ class FMHAGateRef { true, gemm_m, gemm_n, gemm_k, gemm_batch_size); } - Tensor* qk_out_grad = config->GetQKOutGrad(dev_ctx_, &softmax_out_grad); + Tensor* qk_out_grad = config->GetQKOutGrad(&softmax_out_grad); ComputeBiasMaskSoftmaxBackward(&softmax_out_grad, softmax_out, src_mask_grad, qk_out_grad, nonbatched_bias_grad); @@ -498,12 +532,12 @@ class 
FMHAGateRef { gemm_n, gemm_k, gemm_batch_size, alpha); if (merge_qkv_) { - Tensor* qkv_out_grad = config->GetQKVOutGrad(dev_ctx_); + Tensor* qkv_out_grad = config->GetQKVOutGrad(); ComputeQKVTransposeBackward(qkv_transpose_out_grad, qkv_out_grad); } else { - Tensor* q_out_grad = config->GetQueryOutGrad(dev_ctx_); - Tensor* k_out_grad = config->GetKeyOutGrad(dev_ctx_); - Tensor* v_out_grad = config->GetValueOutGrad(dev_ctx_); + Tensor* q_out_grad = config->GetQueryOutGrad(); + Tensor* k_out_grad = config->GetKeyOutGrad(); + Tensor* v_out_grad = config->GetValueOutGrad(); ComputeQKVTransposeBackward(q_transpose_out_grad, k_transpose_out_grad, v_transpose_out_grad, q_out_grad, k_out_grad, v_out_grad); @@ -578,12 +612,12 @@ class FMHAGateRef { if (nonbatched_bias) { std::vector ins = {qk_out, nonbatched_bias, src_mask}; std::vector outs = {qk_out}; - phi::funcs::BroadcastKernel( + phi::funcs::BroadcastKernel( dev_ctx_, ins, &outs, -1, TernaryAddFunctor()); } else { std::vector ins = {qk_out, src_mask}; std::vector outs = {qk_out}; - phi::funcs::BroadcastKernel( + phi::funcs::BroadcastKernel( dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor()); } phi::SoftmaxForwardCUDAKernelDriver(dev_ctx_, *qk_out, -1, softmax_out); @@ -614,12 +648,12 @@ class FMHAGateRef { phi::SoftmaxBackwardCUDAKernelDriver(dev_ctx_, *softmax_out, *softmax_out_grad, -1, qk_out_grad); - // [1, bs, num_head, seq_l, seq_l] -> [bs, num_head, seq_l, seq_l] if (nonbatched_bias_grad) { - gpuStream_t stream = dev_ctx_.stream(); - TensorReduceImpl>( + // [batch_size, seq_len_m, num_heads, seq_len_r, m_size] -> + // [batch_size, 1, num_heads, seq_len_r, m_size] + phi::funcs::ReduceKernel>( dev_ctx_, *qk_out_grad, nonbatched_bias_grad, - kps::IdentityFunctor(), {0, 1}, stream); + kps::IdentityFunctor(), {1}); } } diff --git a/paddle/fluid/operators/fused/fused_gate_attention_op.cc b/paddle/fluid/operators/fused/fused_gate_attention_op.cc index 0bbeabd5fc9cb965241becfc6593dfc8a313f1a7..506f437b1ae5407d598f0a69764bb46eaa7af6ab 100644 --- a/paddle/fluid/operators/fused/fused_gate_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_gate_attention_op.cc @@ -214,7 +214,7 @@ class FusedGateAttentionGradOp : public framework::OperatorWithKernel { "fused_aate_attention_arad"); if (ctx->Attrs().Get("has_gating")) { - for (auto& name : {"GateWeight", "GateBias", "GateOut"}) { + for (auto& name : {"GateWeight", "GateBias"}) { ctx->SetOutputDim(framework::GradVarName(name), ctx->GetInputDim(name)); } } @@ -224,9 +224,6 @@ class FusedGateAttentionGradOp : public framework::OperatorWithKernel { ctx->GetInputDim("NonbatchedBias")); } - ctx->SetOutputDim(framework::GradVarName("FMHAOut"), - ctx->GetInputDim("FMHAOut")); - ctx->SetOutputDim(framework::GradVarName("OutLinearWeight"), ctx->GetInputDim("OutLinearWeight")); ctx->SetOutputDim(framework::GradVarName("OutLinearBias"), @@ -270,8 +267,6 @@ class FusedGateAttentionGradOpMaker : public framework::SingleGradOpMaker { } op->SetInput("FMHAOut", this->Output("FMHAOut")); - op->SetOutput(framework::GradVarName("FMHAOut"), - this->OutputGrad("FMHAOut")); if (this->HasInput("NonbatchedBias")) { op->SetInput("NonbatchedBias", this->Input("NonbatchedBias")); @@ -292,8 +287,6 @@ class FusedGateAttentionGradOpMaker : public framework::SingleGradOpMaker { this->InputGrad("GateBias")); op->SetInput("GateOut", this->Output("GateOut")); - op->SetOutput(framework::GradVarName("GateOut"), - this->OutputGrad("GateOut")); } op->SetInput("OutLinearWeight", this->Input("OutLinearWeight")); diff --git 
a/paddle/fluid/operators/fused/fused_gate_attention_op.cu b/paddle/fluid/operators/fused/fused_gate_attention_op.cu index 8f375a22cc0234c8d9aaec9e272e65f27de65215..ebc9a4f98d0aea9d908c98479bd5de7dbb156dab 100644 --- a/paddle/fluid/operators/fused/fused_gate_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_gate_attention_op.cu @@ -79,11 +79,11 @@ void ComputeMergedQKVMatmulForward(const framework::ExecutionContext &ctx, } template -Tensor *ComputeMergedQKVMatmulBackward(const framework::ExecutionContext &ctx, - const GateAttentionGradConfig &config, - const Tensor *query, - const Tensor *qkv_out_grad, - Tensor *query_grad, bool use_addto) { +void ComputeMergedQKVMatmulBackward(const framework::ExecutionContext &ctx, + const GateAttentionGradConfig &config, + const Tensor *query, + const Tensor *qkv_out_grad, + Tensor *query_grad, bool use_addto) { auto *qkv_weight = ctx.Input("QKVWeight"); auto *qkv_weight_grad = ctx.Output(framework::GradVarName("QKVWeight")); @@ -97,7 +97,6 @@ Tensor *ComputeMergedQKVMatmulBackward(const framework::ExecutionContext &ctx, AttnMatMul(ctx.cuda_device_context(), false, true, m, n, k, false); qkv_compute.ComputeBackward(query, qkv_weight, qkv_out_grad, query_grad, qkv_weight_grad, nullptr, use_addto); - return query_grad; } template @@ -137,12 +136,14 @@ void ComputeSeparatedQKVMatmulForward(const framework::ExecutionContext &ctx, } template -Tensor *ComputeSeparatedQKVMatmulBackward( - const framework::ExecutionContext &ctx, - const GateAttentionGradConfig &config, const Tensor *query, - const Tensor *key, const Tensor *query_out_grad, const Tensor *key_out_grad, - const Tensor *value_out_grad, Tensor *query_grad, Tensor *key_grad, - bool use_addto) { +void ComputeSeparatedQKVMatmulBackward(const framework::ExecutionContext &ctx, + const GateAttentionGradConfig &config, + const Tensor *query, const Tensor *key, + const Tensor *query_out_grad, + const Tensor *key_out_grad, + const Tensor *value_out_grad, + Tensor *query_grad, Tensor *key_grad, + bool use_addto) { // Gradient of GEMM(key, k_weight) const auto *key_weight = ctx.Input("KeyWeight"); auto *key_weight_grad = @@ -179,22 +180,16 @@ Tensor *ComputeSeparatedQKVMatmulBackward( q_n, q_k, false); q_compute.ComputeBackward(query, query_weight, query_out_grad, query_grad, query_weight_grad, nullptr, use_addto); - return query_grad; } template -Tensor *ComputeGatingLinearForward(const framework::ExecutionContext &ctx, - const GateAttentionConfig &config, - const Tensor *query, - const Tensor *fmha_out) { +void ComputeGatingLinearForward(const framework::ExecutionContext &ctx, + const GateAttentionConfig &config, + const Tensor *query, const Tensor *fmha_out, + Tensor *gate_out) { auto *gate_weight = ctx.Input("GateWeight"); auto *gate_bias = ctx.Input("GateBias"); - auto *gate_out = ctx.Output("GateOut"); - gate_out->mutable_data(ctx.GetPlace()); - VLOG(4) << "[ComputeGatingLinearForward] gate_out: " - << MemoryDebugString(*gate_out); - // The first gate_bias_out stores the result of the multiplication, // and the second gate_bias_out stores the result of the multiplication + // bias. 
@@ -212,16 +207,14 @@ Tensor *ComputeGatingLinearForward(const framework::ExecutionContext &ctx, std::vector outs = {gate_out}; phi::funcs::ElementwiseKernel(ctx.cuda_device_context(), ins, &outs, SigmoidMultiplyFunctor()); - return gate_out; } template -Tensor *ComputeGatingLinearBackward(const framework::ExecutionContext &ctx, - const GateAttentionGradConfig &config, - const Tensor *fmha_out, - const Tensor *gate_out_grad, - Tensor *query_grad, Tensor *fmha_out_grad) { - const auto *query = ctx.Input("Query"); +void ComputeGatingLinearBackward(const framework::ExecutionContext &ctx, + const GateAttentionGradConfig &config, + const Tensor *query, const Tensor *fmha_out, + const Tensor *gate_out_grad, + Tensor *query_grad, Tensor *fmha_out_grad) { const auto *gate_weight = ctx.Input("GateWeight"); const auto *gate_bias = ctx.Input("GateBias"); @@ -255,20 +248,15 @@ Tensor *ComputeGatingLinearBackward(const framework::ExecutionContext &ctx, gate_attn_compute.ComputeBackward(query, gate_weight, &gate_bias_out, query_grad, gate_weight_grad, gate_bias_grad); - return fmha_out_grad; } template -Tensor *ComputeOutputLinearForward(const framework::ExecutionContext &ctx, - const GateAttentionConfig &config, - const Tensor *fmha_or_gate_out) { +void ComputeOutputLinearForward(const framework::ExecutionContext &ctx, + const GateAttentionConfig &config, + const Tensor *fmha_or_gate_out, Tensor *out) { const auto *out_linear_weight = ctx.Input("OutLinearWeight"); const auto *out_linear_bias = ctx.Input("OutLinearBias"); - auto *out = ctx.Output("Out"); - out->mutable_data(ctx.GetPlace()); - VLOG(4) << "[ComputeOutputLinearForward] out: " << MemoryDebugString(*out); - // out = GEMM(fmha_or_gate_out, out_linear_weight) + out_linear_bias int m = config.batch_size * config.seq_len_m * config.seq_len_r; int n = config.q_dim; @@ -277,28 +265,24 @@ Tensor *ComputeOutputLinearForward(const framework::ExecutionContext &ctx, AttnMatMul(ctx.cuda_device_context(), false, false, m, n, k, true); out_linear_compute.ComputeForward(out_linear_weight, fmha_or_gate_out, out_linear_bias, out, out); - return out; } template -Tensor *ComputeOutputLinearBackward(const framework::ExecutionContext &ctx, - const GateAttentionGradConfig &config, - bool has_gating) { - std::string input_name = has_gating ? 
"GateOut" : "FMHAOut"; - +void ComputeOutputLinearBackward(const framework::ExecutionContext &ctx, + const GateAttentionGradConfig &config, + const Tensor *input, Tensor *input_grad) { const auto *out_grad = ctx.Input(framework::GradVarName("Out")); const auto *out_linear_weight = ctx.Input("OutLinearWeight"); - const auto *input = ctx.Input(input_name); auto *out_linear_weight_grad = ctx.Output(framework::GradVarName("OutLinearWeight")); auto *out_linear_bias_grad = ctx.Output(framework::GradVarName("OutLinearBias")); - auto *input_grad = ctx.Output(framework::GradVarName(input_name)); out_linear_weight_grad->mutable_data(ctx.GetPlace()); out_linear_bias_grad->mutable_data(ctx.GetPlace()); - input_grad->mutable_data(ctx.GetPlace()); + + auto &dev_ctx = ctx.template device_context(); int m = config.batch_size * config.seq_len_m * config.seq_len_r; int n = config.q_dim; @@ -308,7 +292,6 @@ Tensor *ComputeOutputLinearBackward(const framework::ExecutionContext &ctx, out_linear_compute.ComputeBackward(input, out_linear_weight, out_grad, input_grad, out_linear_weight_grad, out_linear_bias_grad); - return input_grad; } template @@ -330,56 +313,64 @@ class FusedGateAttentionOpKernel : public framework::OpKernel { auto *softmax_out = ctx.Output("SoftmaxOut"); auto *fmha_out = ctx.Output("FMHAOut"); + auto *gate_out = ctx.Output("GateOut"); + auto *out = ctx.Output("Out"); const bool merge_qkv = ctx.Attr("merge_qkv"); const bool has_gating = ctx.Attr("has_gating"); - // When seq_len_r = m_size, q_dim = kv_dim, QKV matmul can be merged. auto &dev_ctx = ctx.template device_context(); - GateAttentionConfig config(query, key, query_weight, qkv_weight, - merge_qkv); + AllocWithDebugInfo(dev_ctx, "softmax_out", softmax_out); + AllocWithDebugInfo(dev_ctx, "fmha_out", fmha_out); + if (has_gating) { + AllocWithDebugInfo(dev_ctx, "gate_out", gate_out); + } + AllocWithDebugInfo(dev_ctx, "out", out); + + // When seq_len_r = m_size, q_dim = kv_dim, QKV matmul can be merged. + GateAttentionConfig config(dev_ctx, query, key, query_weight, qkv_weight, + merge_qkv, has_gating); if (merge_qkv) { + PADDLE_ENFORCE_EQ(!key || query == key, true, + platform::errors::InvalidArgument( + "key is expected to be nullptr or the same as " + "query, but recieved key=%p, query=%p.", + key, query)); + // 1. Merged QKV Matmul: einsum(nbhqk,nbkhc -> nbqhc) - Tensor *qkv_out = config.GetQKVOut(dev_ctx); + Tensor *qkv_out = config.GetQKVOut(); ComputeMergedQKVMatmulForward(ctx, config, query, qkv_out); - qkv_transpose_out->mutable_data(ctx.GetPlace()); - VLOG(4) << "qkv_transpose_out:" << MemoryDebugString(*qkv_transpose_out); + AllocWithDebugInfo(dev_ctx, "qkv_transpose_out", qkv_transpose_out); } else { // 1. 
Separated QKV Matmul - Tensor *query_out = config.GetQueryOut(dev_ctx); - Tensor *key_out = config.GetKeyOut(dev_ctx); - Tensor *value_out = config.GetValueOut(dev_ctx); + Tensor *query_out = config.GetQueryOut(); + Tensor *key_out = config.GetKeyOut(); + Tensor *value_out = config.GetValueOut(); ComputeSeparatedQKVMatmulForward(ctx, config, query, key, query_out, key_out, value_out); - q_transpose_out->mutable_data(ctx.GetPlace()); - k_transpose_out->mutable_data(ctx.GetPlace()); - v_transpose_out->mutable_data(ctx.GetPlace()); - VLOG(4) << "q_transpose_out: " << MemoryDebugString(*q_transpose_out); - VLOG(4) << "k_transpose_out: " << MemoryDebugString(*k_transpose_out); - VLOG(4) << "v_transpose_out: " << MemoryDebugString(*v_transpose_out); + AllocWithDebugInfo(dev_ctx, "q_transpose_out", q_transpose_out); + AllocWithDebugInfo(dev_ctx, "k_transpose_out", k_transpose_out); + AllocWithDebugInfo(dev_ctx, "v_transpose_out", v_transpose_out); } - softmax_out->mutable_data(ctx.GetPlace()); - fmha_out->mutable_data(ctx.GetPlace()); - VLOG(4) << "softmax_out: " << MemoryDebugString(*softmax_out); - VLOG(4) << "fmha_out: " << MemoryDebugString(*fmha_out); - // 2. FMHA auto fmha_compute = FMHAGateRef(dev_ctx, merge_qkv); - fmha_compute.ComputeForward( - nonbatched_bias, src_mask, q_transpose_out, k_transpose_out, - v_transpose_out, qkv_transpose_out, softmax_out, fmha_out, &config); + fmha_compute.ComputeForward(nonbatched_bias, src_mask, q_transpose_out, + k_transpose_out, v_transpose_out, + qkv_transpose_out, softmax_out, fmha_out, + gate_out, &config); // 3. Gating Linear - Tensor *fmha_or_gate_out = !has_gating ? fmha_out - : ComputeGatingLinearForward( - ctx, config, query, fmha_out); + if (has_gating) { + ComputeGatingLinearForward(ctx, config, query, fmha_out, gate_out); + } // 4. Output Linear - ComputeOutputLinearForward(ctx, config, fmha_or_gate_out); + Tensor *fmha_or_gate_out = has_gating ? gate_out : fmha_out; + ComputeOutputLinearForward(ctx, config, fmha_or_gate_out, out); } }; @@ -387,9 +378,6 @@ template class FusedGateAttentionGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { - const auto has_gating = ctx.Attr("has_gating"); - const auto merge_qkv = ctx.Attr("merge_qkv"); - // forward input const auto *query = ctx.Input("Query"); const auto *key = ctx.Input("Key"); @@ -403,56 +391,68 @@ class FusedGateAttentionGradKernel : public framework::OpKernel { const auto *qkv_transpose_out = ctx.Input("QKVTransposeOut"); const auto *softmax_out = ctx.Input("SoftmaxOut"); const auto *fmha_out = ctx.Input("FMHAOut"); + const auto *gate_out = ctx.Input("GateOut"); // backward output auto *query_grad = ctx.Output(framework::GradVarName("Query")); - query_grad->mutable_data(ctx.GetPlace()); auto *nonbatched_bias_grad = ctx.Output(framework::GradVarName("NonbatchedBias")); - auto *fmha_out_grad = ctx.Output(framework::GradVarName("FMHAOut")); + + bool has_gating = ctx.Attr("has_gating"); + bool merge_qkv = ctx.Attr("merge_qkv"); auto &dev_ctx = ctx.template device_context(); - GateAttentionGradConfig config(query, key, query_weight, qkv_weight, - merge_qkv); + AllocWithDebugInfo(dev_ctx, "query_grad", query_grad); - // 1. Gradient of Output Linear - Tensor *fhma_or_gate_out_grad = - ComputeOutputLinearBackward(ctx, config, has_gating); + GateAttentionGradConfig config(dev_ctx, query, key, query_weight, + qkv_weight, merge_qkv, has_gating); - // 2. 
Gradient of Gating Linear + Tensor fmha_out_grad; + fmha_out_grad.Resize(config.gate_out_dims); + AllocWithDebugInfo(dev_ctx, "fmha_out_grad", &fmha_out_grad); if (has_gating) { - // fhma_or_gate_out_grad is actually gate_out_grad. - fmha_out_grad->mutable_data(ctx.GetPlace()); - ComputeGatingLinearBackward(ctx, config, fmha_out, - fhma_or_gate_out_grad, query_grad, - fmha_out_grad); + // 1. Gradient of Output Linear: out = Linear(gate_out) + Tensor gate_out_grad; + gate_out_grad.Resize(config.gate_out_dims); + AllocWithDebugInfo(dev_ctx, "gate_out_grad", &gate_out_grad); + ComputeOutputLinearBackward(ctx, config, gate_out, &gate_out_grad); + + // 2. Gradient of Gating Linear + // Forward: gate_out = Sigmoid(Linear(fmha_out)) * fmha_out + ComputeGatingLinearBackward(ctx, config, query, fmha_out, + &gate_out_grad, query_grad, + &fmha_out_grad); + } else { + // 1. Gradient of Output Linear: out = Linear(fmha_grad) + ComputeOutputLinearBackward(ctx, config, fmha_out, &fmha_out_grad); } // 3. Gradient of FMHA if (nonbatched_bias_grad) { - nonbatched_bias_grad->mutable_data(ctx.GetPlace()); + AllocWithDebugInfo(dev_ctx, "nonbatched_bias_grad", + nonbatched_bias_grad); } auto fmha_compute = FMHAGateRef(dev_ctx, merge_qkv); fmha_compute.ComputeBackward( q_transpose_out, k_transpose_out, v_transpose_out, qkv_transpose_out, - softmax_out, fmha_out_grad, nullptr, nonbatched_bias_grad, &config); + softmax_out, &fmha_out_grad, nullptr, nonbatched_bias_grad, &config); bool use_addto = has_gating ? true : false; if (merge_qkv) { // 4. Gradient of Merged QKV Matmul - Tensor *qkv_out_grad = config.GetQKVOutGrad(dev_ctx); + Tensor *qkv_out_grad = config.GetQKVOutGrad(); ComputeMergedQKVMatmulBackward(ctx, config, query, qkv_out_grad, query_grad, use_addto); } else { // 4. 
Gradient of Separated QKV Matmul auto *key_grad = ctx.Output(framework::GradVarName("Key")); if (key_grad) { - key_grad->mutable_data(ctx.GetPlace()); + AllocWithDebugInfo(dev_ctx, "key_grad", key_grad); } - Tensor *query_out_grad = config.GetQueryOutGrad(dev_ctx); - Tensor *key_out_grad = config.GetKeyOutGrad(dev_ctx); - Tensor *value_out_grad = config.GetValueOutGrad(dev_ctx); + Tensor *query_out_grad = config.GetQueryOutGrad(); + Tensor *key_out_grad = config.GetKeyOutGrad(); + Tensor *value_out_grad = config.GetValueOutGrad(); ComputeSeparatedQKVMatmulBackward( ctx, config, query, key, query_out_grad, key_out_grad, value_out_grad, query_grad, key_grad, use_addto); diff --git a/python/paddle/fluid/tests/unittests/test_fused_gate_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_gate_attention_op.py index edfb46f5813b6d895b30dda1dbd24a279e9fde96..52418bba633f15bf8b4b39a5da96a252c52d840b 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_gate_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_gate_attention_op.py @@ -18,7 +18,7 @@ import paddle import paddle.nn as nn from paddle import tensor import unittest -from op_test import OpTest, convert_float_to_uint16 +from op_test import OpTest, convert_float_to_uint16, convert_uint16_to_float from test_sparse_attention_op import get_cuda_version from paddle import _C_ops from paddle.fluid.framework import default_main_program, _enable_legacy_dygraph @@ -194,23 +194,36 @@ class TestFusedGateAttentionOp(OpTest): return out, query.grad, None def check_output_and_grad(self, atol, rtol): - out_ref, query_grad_ref, key_grad_ref = self.get_reference_out() - out, query_grad, key_grad = self.get_fused_gate_attention_out() - np.testing.assert_allclose(out_ref, out.numpy(), atol=atol, rtol=rtol) - np.testing.assert_allclose(query_grad_ref, - query_grad.numpy(), - atol=atol, - rtol=rtol) - if key_grad_ref is not None and key_grad is not None: - np.testing.assert_allclose(key_grad_ref, - key_grad.numpy(), - atol=atol, - rtol=rtol) + + def _convert(value): + if self.dtype == "bfloat16": + return convert_uint16_to_float(value) + return value + + output_names = ["out", "query_grad", "key_grad"] + outputs_ref = self.get_reference_out() + outputs_fused = self.get_fused_gate_attention_out() + for i in range(len(outputs_fused)): + ref_res = outputs_ref[i] + fused_res = outputs_fused[i] + if ref_res is not None and fused_res is not None: + print("Checking {}".format(output_names[i])) + np.testing.assert_allclose(_convert(ref_res), + _convert(fused_res.numpy()), + atol=atol, + rtol=rtol) def test_output_and_grad(self): self.check_output_and_grad(atol=1e-5, rtol=1e-5) +class TestMergeQKVLargeBatchSizeCase(TestFusedGateAttentionOp): + + def config(self): + super().config() + self.batch_size = 2 + + class TestSeparatedQKVCase(TestFusedGateAttentionOp): def config(self): @@ -243,7 +256,16 @@ class TestMergeQKVFp16Case(TestFusedGateAttentionOp): self.dtype = "float16" def test_output_and_grad(self): - self.check_output_and_grad(atol=1e-1, rtol=1e-5) + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_and_grad(atol=1e-1, rtol=1e-5) + + +class TestMergeQKVLargeBatchSizeFp16Case(TestMergeQKVFp16Case): + + def config(self): + super().config() + self.batch_size = 2 @unittest.skipIf( @@ -260,5 +282,12 @@ class TestMergeQKVBF16Case(TestFusedGateAttentionOp): self.check_output_and_grad(atol=1e-1, rtol=1e-3) +class TestMergeQKVLargeBatchSizeBF16Case(TestMergeQKVBF16Case): + + def config(self): + 
+        super().config()
+        self.batch_size = 2
+
+
 if __name__ == "__main__":
     unittest.main()
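
A minimal numpy sketch (illustration only, not part of the patch) of the shapes behind the change from TensorReduceImpl over dims {0, 1} to phi::funcs::ReduceKernel over dim {1} in ComputeBiasMaskSoftmaxBackward: nonbatched_bias keeps its batch dimension and is broadcast only along seq_len_m, so its gradient reduces over axis 1 alone. The src_mask layout used here is an assumption for illustration.

import numpy as np

batch_size, seq_len_m, num_heads, seq_len_r, m_size = 2, 3, 4, 5, 5

qk_out = np.random.rand(batch_size, seq_len_m, num_heads, seq_len_r, m_size)
# nonbatched_bias is broadcast only along the seq_len_m axis.
nonbatched_bias = np.random.rand(batch_size, 1, num_heads, seq_len_r, m_size)
# Assumed mask layout: broadcast along num_heads and seq_len_r.
src_mask = np.random.rand(batch_size, seq_len_m, 1, 1, m_size)

# Forward (TernaryAddFunctor): qk_out = qk_out + nonbatched_bias + src_mask.
summed = qk_out + nonbatched_bias + src_mask

# Backward: the bias gradient sums the upstream gradient over every axis the
# bias was broadcast along, i.e. only seq_len_m (axis 1), keeping batch_size:
# [batch_size, seq_len_m, num_heads, seq_len_r, m_size]
#     -> [batch_size, 1, num_heads, seq_len_r, m_size]
qk_out_grad = np.ones_like(summed)
nonbatched_bias_grad = qk_out_grad.sum(axis=1, keepdims=True)
assert nonbatched_bias_grad.shape == nonbatched_bias.shape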
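
Likewise, a minimal numpy sketch (illustration only) of the math that the refactored ComputeGatingLinearForward/Backward implement, now writing into a caller-provided gate_out and a locally allocated fmha_out_grad rather than op outputs. Shapes are illustrative; the gate projection is applied to query, as the kernel's arguments and the query_grad it produces indicate.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

batch_size, seq_len_m, seq_len_r, q_dim = 2, 3, 5, 6
num_heads, key_dim = 4, 8  # illustrative sizes only

query = np.random.rand(batch_size, seq_len_m, seq_len_r, q_dim)
fmha_out = np.random.rand(batch_size, seq_len_m, seq_len_r, num_heads, key_dim)
gate_weight = np.random.rand(q_dim, num_heads, key_dim)
gate_bias = np.random.rand(num_heads, key_dim)

# Forward: gate_out = sigmoid(query x gate_weight + gate_bias) * fmha_out,
# the elementwise part being what SigmoidMultiplyFunctor fuses after AttnMatMul.
gate_values = np.einsum("bmrq,qhc->bmrhc", query, gate_weight) + gate_bias
s = sigmoid(gate_values)
gate_out = s * fmha_out

# Backward: from gate_out_grad recover fmha_out_grad (a local buffer in the
# grad kernel now, no longer a GradVarName("FMHAOut") output) and the gradient
# flowing back into the gating matmul, which yields query_grad.
gate_out_grad = np.ones_like(gate_out)
fmha_out_grad = gate_out_grad * s
gate_values_grad = gate_out_grad * fmha_out * s * (1.0 - s)
query_grad = np.einsum("bmrhc,qhc->bmrq", gate_values_grad, gate_weight)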