Unverified · Commit 10f8637c · Authored by Yiqun Liu · Committed by GitHub

Fix wrong reduce_dims in fused_gate_attention and optimize the memory usage. (#43216)

* Polish codes and memory usage for fused_gate_attention.

* Fix wrong reduce_dims in fused_gate_attention when computing gradient of nonbatched_bias.

Parent 5a6c974e
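The reduce_dims fix concerns the gradient of NonbatchedBias. As the new comment in ComputeBiasMaskSoftmaxBackward states, qk_out_grad has shape [batch_size, seq_len_m, num_heads, seq_len_r, m_size] and the bias gradient is its reduction to [batch_size, 1, num_heads, seq_len_r, m_size], i.e. a sum over the broadcast seq_len_m axis only; the old code reduced over axes {0, 1}, which also sums over the batch dimension. A minimal NumPy sketch of the idea (shapes and names are illustrative, not code from the patch):

import numpy as np

batch_size, seq_len_m, num_heads, seq_len_r, m_size = 2, 3, 4, 5, 5
qk_out_grad = np.random.rand(batch_size, seq_len_m, num_heads, seq_len_r, m_size)

# Forward: qk_out = qk + nonbatched_bias, with the bias broadcast over seq_len_m.
# Backward: the bias gradient therefore sums over the broadcast axis only.
bias_grad = qk_out_grad.sum(axis=1, keepdims=True)   # correct: keeps the batch axis
wrong_bias_grad = qk_out_grad.sum(axis=(0, 1))       # old behaviour: also sums over batch

print(bias_grad.shape)        # (2, 1, 4, 5, 5)
print(wrong_bias_grad.shape)  # (4, 5, 5)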
@@ -14,11 +14,11 @@ limitations under the License. */

 #pragma once

-#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/operators/transpose_op.cu.h"
 #include "paddle/phi/kernels/funcs/broadcast_function.h"
 #include "paddle/phi/kernels/funcs/elementwise_base.h"
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
+#include "paddle/phi/kernels/funcs/reduce_function.h"
 #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"

 namespace paddle {
@@ -27,19 +27,29 @@ namespace operators {
 using Tensor = framework::Tensor;

 inline std::string MemoryDebugString(const Tensor& t) {
+  int device_id = platform::GetCurrentDeviceId();
+  int64_t allocated =
+      memory::DeviceMemoryStatCurrentValue("Allocated", device_id);
+  int64_t reserved =
+      memory::DeviceMemoryStatCurrentValue("Reserved", device_id);
+
   std::stringstream ss;
   ss << "shape=[" << t.dims()
      << "], size=" << static_cast<float>(t.memory_size()) / (1 << 20)
-     << " MB, ptr=" << t.data();
-
-  size_t total = 0;
-  size_t available = 0;
-  platform::GpuMemoryUsage(&available, &total);
-  ss << "; memory allocated="
-     << static_cast<float>(total - available) / (1 << 20) << " MB";
+     << " MB, ptr=" << t.data()
+     << "; [MEMORY] allocated=" << static_cast<float>(allocated) / (1 << 20)
+     << " MB"
+     << ", reserved=" << static_cast<float>(reserved) / (1 << 20) << " MB";
   return ss.str();
 }

+template <typename T>
+void AllocWithDebugInfo(const platform::CUDADeviceContext& dev_ctx,
+                        const std::string& info, Tensor* t) {
+  t->mutable_data<T>(dev_ctx.GetPlace());
+  VLOG(4) << info << ": " << MemoryDebugString(*t);
+}
+
 template <typename T>
 struct TernaryAddFunctor {
   inline HOSTDEVICE T operator()(T a, T b, T c) const { return a + b + c; }
@@ -48,6 +58,11 @@ struct TernaryAddFunctor {
 template <typename T>
 struct GateAttentionConfig {
  public:
+  const platform::CUDADeviceContext& dev_ctx;
+
+  bool merge_qkv;
+  bool has_gating;
+
   int64_t batch_size;
   int64_t seq_len_m;
   int64_t seq_len_r;
@@ -70,9 +85,11 @@ struct GateAttentionConfig {
   phi::DDim qktv_out_dims;
   phi::DDim gate_out_dims;

-  GateAttentionConfig(const Tensor* query, const Tensor* key,
+  GateAttentionConfig(const platform::CUDADeviceContext& dev_ctx,
+                      const Tensor* query, const Tensor* key,
                       const Tensor* query_weight, const Tensor* qkv_weight,
-                      bool merge_qkv) {
+                      bool merge_qkv, bool has_gating)
+      : dev_ctx(dev_ctx), merge_qkv(merge_qkv), has_gating(has_gating) {
     // query: shape=[batch_size, seq_len_m, seq_len_r, q_dim]
     batch_size = query->dims()[0];
     seq_len_m = query->dims()[1];
@@ -131,59 +148,68 @@ struct GateAttentionConfig {
     return batch_size * seq_len_m * seq_len_r * num_heads * key_dim;
   }

-  Tensor* GetQKVOut(const platform::CUDADeviceContext& dev_ctx) {
+  Tensor* GetQKVOut() {
     if (!qkv_out.IsInitialized()) {
       qkv_out.Resize(qkv_out_dims);
-      qkv_out.mutable_data<T>(dev_ctx.GetPlace());
-      VLOG(4) << "qkv_out: " << MemoryDebugString(qkv_out);
+      AllocWithDebugInfo<T>(dev_ctx, "qkv_out", &qkv_out);
     }
     return &qkv_out;
   }

-  Tensor* GetQueryOut(const platform::CUDADeviceContext& dev_ctx) {
+  Tensor* GetQueryOut() {
     if (!query_out.IsInitialized()) {
       query_out.Resize(q_out_dims);
-      query_out.mutable_data<T>(dev_ctx.GetPlace());
-      VLOG(4) << "query_out: " << MemoryDebugString(query_out);
+      AllocWithDebugInfo<T>(dev_ctx, "query_out", &query_out);
     }
     return &query_out;
   }

-  Tensor* GetKeyOut(const platform::CUDADeviceContext& dev_ctx) {
+  Tensor* GetKeyOut() {
     if (!key_out.IsInitialized()) {
       key_out.Resize(kv_out_dims);
-      key_out.mutable_data<T>(dev_ctx.GetPlace());
-      VLOG(4) << "key_out: " << MemoryDebugString(key_out);
+      AllocWithDebugInfo<T>(dev_ctx, "key_out", &key_out);
     }
     return &key_out;
   }

-  Tensor* GetValueOut(const platform::CUDADeviceContext& dev_ctx) {
+  Tensor* GetValueOut() {
    if (!value_out.IsInitialized()) {
       value_out.Resize(kv_out_dims);
-      value_out.mutable_data<T>(dev_ctx.GetPlace());
-      VLOG(4) << "value_out: " << MemoryDebugString(value_out);
+      AllocWithDebugInfo<T>(dev_ctx, "value_out", &value_out);
     }
     return &value_out;
   }

-  Tensor* GetQKOut(const platform::CUDADeviceContext& dev_ctx,
-                   Tensor* softmax_out) {
+  Tensor* GetQKOut(Tensor* softmax_out) {
     // softmax_dim = qk_out_dim[-1] = qk_out_dim[rank - 1]
     int softmax_dim = m_size;
     if (!softmax_out || phi::UseCudnnSoftmax<T>(dev_ctx, softmax_dim, true)) {
       // Not sure whether cudnn softmax can execute inplace.
       if (!qkv_out.IsInitialized()) {
         qk_out.Resize(qk_out_dims);
-        qk_out.mutable_data<T>(dev_ctx.GetPlace());
-        VLOG(4) << "qk_out: " << MemoryDebugString(qk_out);
+        AllocWithDebugInfo<T>(dev_ctx, "qk_out", &qk_out);
       }
       return &qk_out;
     } else {
+      // Enable inplace softmax.
       return softmax_out;
     }
   }

+  Tensor* GetQKTVOut(Tensor* gate_out) {
+    if (has_gating && gate_out) {
+      // Reuse gate_out.
+      gate_out->Resize(qktv_out_dims);
+      return gate_out;
+    } else {
+      if (!qktv_out.IsInitialized()) {
+        qktv_out.Resize(qktv_out_dims);
+        AllocWithDebugInfo<T>(dev_ctx, "qktv_out", &qktv_out);
+      }
+      return &qktv_out;
+    }
+  }
+
   void ClearQKVOut() {
     if (qkv_out.IsInitialized()) {
       qkv_out.clear();
@@ -196,9 +222,14 @@ struct GateAttentionConfig {
     }
   }

+  void ClearQKTVOut() {
+    if (qktv_out.IsInitialized()) {
+      qktv_out.clear();
+    }
+  }
+
  protected:
   Tensor qkv_out;
+  // QKV is not merged
   Tensor query_out;
   Tensor key_out;
   Tensor value_out;
@@ -207,63 +238,60 @@ struct GateAttentionConfig {
   // softmax_out = softmax(qk_out + nonbatched_bias + src_mask)
   // The shape of qk_out, softmax_out is the same, thus can be called inplace.
   Tensor qk_out;
+  // qktv_out may reuse gate_out.
+  Tensor qktv_out;
 };

 template <typename T>
 struct GateAttentionGradConfig : public GateAttentionConfig<T> {
  public:
-  GateAttentionGradConfig(const Tensor* query, const Tensor* key,
+  GateAttentionGradConfig(const platform::CUDADeviceContext& dev_ctx,
+                          const Tensor* query, const Tensor* key,
                           const Tensor* query_weight, const Tensor* qkv_weight,
-                          bool merge_qkv)
-      : GateAttentionConfig<T>(query, key, query_weight, qkv_weight,
-                               merge_qkv) {}
+                          bool merge_qkv, bool has_gating)
+      : GateAttentionConfig<T>(dev_ctx, query, key, query_weight, qkv_weight,
+                               merge_qkv, has_gating) {}

-  Tensor* GetQKVOutGrad(const platform::CUDADeviceContext& dev_ctx) {
+  Tensor* GetQKVOutGrad() {
     if (!qkv_out_grad.IsInitialized()) {
       qkv_out_grad.Resize(this->qkv_out_dims);
-      qkv_out_grad.mutable_data<T>(dev_ctx.GetPlace());
-      VLOG(4) << "qkv_out_grad: " << MemoryDebugString(qkv_out_grad);
+      AllocWithDebugInfo<T>(this->dev_ctx, "qkv_out_grad", &qkv_out_grad);
     }
     return &qkv_out_grad;
   }

-  Tensor* GetQueryOutGrad(const platform::CUDADeviceContext& dev_ctx) {
+  Tensor* GetQueryOutGrad() {
     if (!query_out_grad.IsInitialized()) {
       query_out_grad.Resize(this->q_out_dims);
-      query_out_grad.mutable_data<T>(dev_ctx.GetPlace());
-      VLOG(4) << "query_out_grad: " << MemoryDebugString(query_out_grad);
+      AllocWithDebugInfo<T>(this->dev_ctx, "query_out_grad", &query_out_grad);
     }
     return &query_out_grad;
   }

-  Tensor* GetKeyOutGrad(const platform::CUDADeviceContext& dev_ctx) {
+  Tensor* GetKeyOutGrad() {
     if (!key_out_grad.IsInitialized()) {
       key_out_grad.Resize(this->kv_out_dims);
-      key_out_grad.mutable_data<T>(dev_ctx.GetPlace());
-      VLOG(4) << "key_out_grad: " << MemoryDebugString(key_out_grad);
+      AllocWithDebugInfo<T>(this->dev_ctx, "key_out_grad", &key_out_grad);
     }
     return &key_out_grad;
   }

-  Tensor* GetValueOutGrad(const platform::CUDADeviceContext& dev_ctx) {
+  Tensor* GetValueOutGrad() {
     if (!value_out_grad.IsInitialized()) {
       value_out_grad.Resize(this->kv_out_dims);
-      value_out_grad.mutable_data<T>(dev_ctx.GetPlace());
-      VLOG(4) << "value_out_grad: " << MemoryDebugString(value_out_grad);
+      AllocWithDebugInfo<T>(this->dev_ctx, "value_out_grad", &value_out_grad);
     }
     return &value_out_grad;
   }

-  Tensor* GetQKOutGrad(const platform::CUDADeviceContext& dev_ctx,
-                       Tensor* softmax_out_grad) {
+  Tensor* GetQKOutGrad(Tensor* softmax_out_grad) {
     // softmax_dim = qk_out_dim[-1] = qk_out_dim[rank - 1]
     int softmax_dim = this->m_size;
     if (!softmax_out_grad ||
-        phi::UseCudnnSoftmax<T>(dev_ctx, softmax_dim, true)) {
+        phi::UseCudnnSoftmax<T>(this->dev_ctx, softmax_dim, true)) {
       if (!qk_out_grad.IsInitialized()) {
         qk_out_grad.Resize(this->qk_out_dims);
-        qk_out_grad.mutable_data<T>(dev_ctx.GetPlace());
-        VLOG(4) << "qk_out_grad: " << MemoryDebugString(qk_out_grad);
+        AllocWithDebugInfo<T>(this->dev_ctx, "qk_out_grad", &qk_out_grad);
       }
       return &qk_out_grad;
     } else {
@@ -288,7 +316,7 @@ class FMHAGateRef {
   void ComputeForward(const Tensor* nonbatched_bias, const Tensor* src_mask,
                       Tensor* q_transpose_out, Tensor* k_transpose_out,
                       Tensor* v_transpose_out, Tensor* qkv_transpose_out,
-                      Tensor* softmax_out, Tensor* fmha_out,
+                      Tensor* softmax_out, Tensor* fmha_out, Tensor* gate_out,
                       GateAttentionConfig<T>* config) {
     T* q_ptr = nullptr;
     T* k_ptr = nullptr;
@@ -300,7 +328,7 @@ class FMHAGateRef {
           platform::errors::NotFound("The input qkv_transpose_out can not be "
                                      "nullptr when merge_qkv is true."));

-      Tensor* qkv_out = config->GetQKVOut(dev_ctx_);
+      Tensor* qkv_out = config->GetQKVOut();
       ComputeQKVTransposeForward(*qkv_out, qkv_transpose_out);
       config->ClearQKVOut();
@@ -323,9 +351,9 @@ class FMHAGateRef {
           platform::errors::NotFound("The input v_transpose_out can not be "
                                      "nullptr when merge_qkv is false."));

-      Tensor* query_out = config->GetQueryOut(dev_ctx_);
-      Tensor* key_out = config->GetKeyOut(dev_ctx_);
-      Tensor* value_out = config->GetValueOut(dev_ctx_);
+      Tensor* query_out = config->GetQueryOut();
+      Tensor* key_out = config->GetKeyOut();
+      Tensor* value_out = config->GetValueOut();
       ComputeQKVTransposeForward(*query_out, *key_out, *value_out,
                                  q_transpose_out, k_transpose_out,
                                  v_transpose_out);
@@ -340,7 +368,7 @@ class FMHAGateRef {
     // [batch_size, seq_len_m, num_heads, seq_len_r, key_dim] *
     // [batch_size, seq_len_m, num_heads, m_size, key_dim]
     // -> [batch_size, seq_len_m, num_heads, seq_len_r, m_size]
-    Tensor* qk_out = config->GetQKOut(dev_ctx_, softmax_out);
+    Tensor* qk_out = config->GetQKOut(softmax_out);
     T* qk_out_ptr = qk_out->data<T>();

     int64_t gemm_batch_size =
@@ -362,9 +390,8 @@ class FMHAGateRef {
     // [batch_size, seq_len_m, num_heads, seq_len_r, m_size] *
     // [batch_size, seq_len_m, num_heads, m_size, key_dim]
     // -> [batch_size, seq_len_m, num_heads, seq_len_r, key_dim]
-    Tensor qktv_out;
-    qktv_out.Resize(config->qktv_out_dims);
-    T* qktv_out_ptr = qktv_out.mutable_data<T>(dev_ctx_.GetPlace());
+    Tensor* qktv_out = config->GetQKTVOut(gate_out);
+    T* qktv_out_ptr = qktv_out->data<T>();

     gemm_m = config->seq_len_r;
     gemm_n = config->key_dim;
@@ -375,7 +402,11 @@ class FMHAGateRef {
                        gemm_m, gemm_n, gemm_k, gemm_batch_size);

     // fmha_out = transpose(qktv_out)
-    ComputeQKTVTransposeForward(qktv_out, fmha_out);
+    ComputeQKTVTransposeForward(*qktv_out, fmha_out);
+    config->ClearQKTVOut();
+    if (config->has_gating) {
+      gate_out->Resize(config->gate_out_dims);
+    }
   }

   void ComputeBackward(const Tensor* q_transpose_out,
@@ -409,8 +440,10 @@ class FMHAGateRef {
       v_ptr = k_ptr + q_size;

       qkv_transpose_out_grad.Resize(config->qkv_transpose_out_dims);
+      AllocWithDebugInfo<T>(dev_ctx_, "qkv_transpose_out_grad",
+                            &qkv_transpose_out_grad);

-      q_grad_ptr = qkv_transpose_out_grad.mutable_data<T>(dev_ctx_.GetPlace());
+      q_grad_ptr = qkv_transpose_out_grad.data<T>();
       k_grad_ptr = q_grad_ptr + q_size;
       v_grad_ptr = k_grad_ptr + q_size;
     } else {
@@ -442,7 +475,7 @@ class FMHAGateRef {

     Tensor softmax_out_grad;
     softmax_out_grad.Resize(config->softmax_out_dims);
-    softmax_out_grad.mutable_data<T>(dev_ctx_.GetPlace());
+    AllocWithDebugInfo<T>(dev_ctx_, "softmax_out_grad", &softmax_out_grad);

     int64_t gemm_batch_size =
         config->batch_size * config->seq_len_m * config->num_heads;
@@ -450,7 +483,7 @@ class FMHAGateRef {
       // Forward: fmha_out = transpose(qktv_out)
       Tensor qktv_out_grad;
       qktv_out_grad.Resize(config->qktv_out_dims);
-      T* qktv_out_grad_ptr = qktv_out_grad.mutable_data<T>(dev_ctx_.GetPlace());
+      AllocWithDebugInfo<T>(dev_ctx_, "qktv_out_grad", &qktv_out_grad);
       ComputeQKTVTransposeBackward(*fmha_out_grad, &qktv_out_grad);

       // Forward: qktv_out = BatchedGEMM(softmax_out, V)
@@ -461,6 +494,7 @@ class FMHAGateRef {
       int64_t gemm_k = config->seq_len_r;

       const T* softmax_out_ptr = softmax_out->data<T>();
+      const T* qktv_out_grad_ptr = qktv_out_grad.data<T>();
       ComputeBatchedGEMM(softmax_out_ptr, qktv_out_grad_ptr, v_grad_ptr, true,
                          false, gemm_m, gemm_n, gemm_k, gemm_batch_size);
@@ -474,7 +508,7 @@ class FMHAGateRef {
                          true, gemm_m, gemm_n, gemm_k, gemm_batch_size);
     }

-    Tensor* qk_out_grad = config->GetQKOutGrad(dev_ctx_, &softmax_out_grad);
+    Tensor* qk_out_grad = config->GetQKOutGrad(&softmax_out_grad);
     ComputeBiasMaskSoftmaxBackward(&softmax_out_grad, softmax_out,
                                    src_mask_grad, qk_out_grad,
                                    nonbatched_bias_grad);
@@ -498,12 +532,12 @@ class FMHAGateRef {
                        gemm_n, gemm_k, gemm_batch_size, alpha);

     if (merge_qkv_) {
-      Tensor* qkv_out_grad = config->GetQKVOutGrad(dev_ctx_);
+      Tensor* qkv_out_grad = config->GetQKVOutGrad();
       ComputeQKVTransposeBackward(qkv_transpose_out_grad, qkv_out_grad);
     } else {
-      Tensor* q_out_grad = config->GetQueryOutGrad(dev_ctx_);
-      Tensor* k_out_grad = config->GetKeyOutGrad(dev_ctx_);
-      Tensor* v_out_grad = config->GetValueOutGrad(dev_ctx_);
+      Tensor* q_out_grad = config->GetQueryOutGrad();
+      Tensor* k_out_grad = config->GetKeyOutGrad();
+      Tensor* v_out_grad = config->GetValueOutGrad();
       ComputeQKVTransposeBackward(q_transpose_out_grad, k_transpose_out_grad,
                                   v_transpose_out_grad, q_out_grad, k_out_grad,
                                   v_out_grad);
@@ -578,12 +612,12 @@ class FMHAGateRef {
     if (nonbatched_bias) {
       std::vector<const Tensor*> ins = {qk_out, nonbatched_bias, src_mask};
       std::vector<Tensor*> outs = {qk_out};
-      phi::funcs::BroadcastKernel<ElementwiseType::kTernary, T, T>(
+      phi::funcs::BroadcastKernel<phi::ElementwiseType::kTernary, T, T>(
          dev_ctx_, ins, &outs, -1, TernaryAddFunctor<T>());
     } else {
       std::vector<const Tensor*> ins = {qk_out, src_mask};
       std::vector<Tensor*> outs = {qk_out};
-      phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
+      phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
          dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
     }
     phi::SoftmaxForwardCUDAKernelDriver<T>(dev_ctx_, *qk_out, -1, softmax_out);
@@ -614,12 +648,12 @@ class FMHAGateRef {
     phi::SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx_, *softmax_out,
                                             *softmax_out_grad, -1, qk_out_grad);

-    // [1, bs, num_head, seq_l, seq_l] -> [bs, num_head, seq_l, seq_l]
     if (nonbatched_bias_grad) {
-      gpuStream_t stream = dev_ctx_.stream();
-      TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+      // [batch_size, seq_len_m, num_heads, seq_len_r, m_size] ->
+      // [batch_size, 1, num_heads, seq_len_r, m_size]
+      phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
           dev_ctx_, *qk_out_grad, nonbatched_bias_grad,
-          kps::IdentityFunctor<T>(), {0, 1}, stream);
+          kps::IdentityFunctor<T>(), {1});
     }
   }
...
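Two of the header changes above are pure memory optimizations: GetQKOut can return softmax_out so the softmax runs in place, and GetQKTVOut lets the batched GEMM write into the gate_out buffer when gating is enabled, resizing gate_out back to gate_out_dims after fmha_out has been transposed out of it. The reuse is legal because fmha_out is just a transpose of qktv_out, so the two layouts hold the same number of elements. A small NumPy sketch of that bookkeeping (the sizes and the gate_out layout are illustrative assumptions, not values from the patch):

import numpy as np

batch_size, seq_len_m, seq_len_r, num_heads, key_dim = 1, 2, 3, 4, 8

# qktv_out layout per the code comments: [batch, seq_len_m, num_heads, seq_len_r, key_dim];
# gate_out is assumed to share fmha_out's layout: [batch, seq_len_m, seq_len_r, num_heads, key_dim].
qktv_dims = (batch_size, seq_len_m, num_heads, seq_len_r, key_dim)
gate_dims = (batch_size, seq_len_m, seq_len_r, num_heads, key_dim)

# Same element count, so one allocation can be resized to serve either view.
assert np.prod(qktv_dims) == np.prod(gate_dims)

buffer = np.empty(int(np.prod(gate_dims)), dtype=np.float32)
qktv_view = buffer.reshape(qktv_dims)  # used while computing softmax_out @ V
gate_view = buffer.reshape(gate_dims)  # later reinterpreted as gate_out storage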
@@ -214,7 +214,7 @@ class FusedGateAttentionGradOp : public framework::OperatorWithKernel {
                    "fused_aate_attention_arad");

     if (ctx->Attrs().Get<bool>("has_gating")) {
-      for (auto& name : {"GateWeight", "GateBias", "GateOut"}) {
+      for (auto& name : {"GateWeight", "GateBias"}) {
         ctx->SetOutputDim(framework::GradVarName(name), ctx->GetInputDim(name));
       }
     }
@@ -224,9 +224,6 @@ class FusedGateAttentionGradOp : public framework::OperatorWithKernel {
                         ctx->GetInputDim("NonbatchedBias"));
     }

-    ctx->SetOutputDim(framework::GradVarName("FMHAOut"),
-                      ctx->GetInputDim("FMHAOut"));
-
     ctx->SetOutputDim(framework::GradVarName("OutLinearWeight"),
                       ctx->GetInputDim("OutLinearWeight"));
     ctx->SetOutputDim(framework::GradVarName("OutLinearBias"),
@@ -270,8 +267,6 @@ class FusedGateAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
     }

     op->SetInput("FMHAOut", this->Output("FMHAOut"));
-    op->SetOutput(framework::GradVarName("FMHAOut"),
-                  this->OutputGrad("FMHAOut"));

     if (this->HasInput("NonbatchedBias")) {
       op->SetInput("NonbatchedBias", this->Input("NonbatchedBias"));
@@ -292,8 +287,6 @@ class FusedGateAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
                     this->InputGrad("GateBias"));

       op->SetInput("GateOut", this->Output("GateOut"));
-      op->SetOutput(framework::GradVarName("GateOut"),
-                    this->OutputGrad("GateOut"));
     }

     op->SetInput("OutLinearWeight", this->Input("OutLinearWeight"));
...
@@ -79,11 +79,11 @@ void ComputeMergedQKVMatmulForward(const framework::ExecutionContext &ctx,
 }

 template <typename T>
-Tensor *ComputeMergedQKVMatmulBackward(const framework::ExecutionContext &ctx,
-                                       const GateAttentionGradConfig<T> &config,
-                                       const Tensor *query,
-                                       const Tensor *qkv_out_grad,
-                                       Tensor *query_grad, bool use_addto) {
+void ComputeMergedQKVMatmulBackward(const framework::ExecutionContext &ctx,
+                                    const GateAttentionGradConfig<T> &config,
+                                    const Tensor *query,
+                                    const Tensor *qkv_out_grad,
+                                    Tensor *query_grad, bool use_addto) {
   auto *qkv_weight = ctx.Input<Tensor>("QKVWeight");
   auto *qkv_weight_grad =
       ctx.Output<Tensor>(framework::GradVarName("QKVWeight"));
@@ -97,7 +97,6 @@ Tensor *ComputeMergedQKVMatmulBackward(const framework::ExecutionContext &ctx,
       AttnMatMul<T>(ctx.cuda_device_context(), false, true, m, n, k, false);
   qkv_compute.ComputeBackward(query, qkv_weight, qkv_out_grad, query_grad,
                               qkv_weight_grad, nullptr, use_addto);
-  return query_grad;
 }

 template <typename T>
@@ -137,12 +136,14 @@ void ComputeSeparatedQKVMatmulForward(const framework::ExecutionContext &ctx,
 }

 template <typename T>
-Tensor *ComputeSeparatedQKVMatmulBackward(
-    const framework::ExecutionContext &ctx,
-    const GateAttentionGradConfig<T> &config, const Tensor *query,
-    const Tensor *key, const Tensor *query_out_grad, const Tensor *key_out_grad,
-    const Tensor *value_out_grad, Tensor *query_grad, Tensor *key_grad,
-    bool use_addto) {
+void ComputeSeparatedQKVMatmulBackward(const framework::ExecutionContext &ctx,
+                                       const GateAttentionGradConfig<T> &config,
+                                       const Tensor *query, const Tensor *key,
+                                       const Tensor *query_out_grad,
+                                       const Tensor *key_out_grad,
+                                       const Tensor *value_out_grad,
+                                       Tensor *query_grad, Tensor *key_grad,
+                                       bool use_addto) {
   // Gradient of GEMM(key, k_weight)
   const auto *key_weight = ctx.Input<Tensor>("KeyWeight");
   auto *key_weight_grad =
@@ -179,22 +180,16 @@ Tensor *ComputeSeparatedQKVMatmulBackward(
                     q_n, q_k, false);
   q_compute.ComputeBackward(query, query_weight, query_out_grad, query_grad,
                             query_weight_grad, nullptr, use_addto);
-  return query_grad;
 }

 template <typename T>
-Tensor *ComputeGatingLinearForward(const framework::ExecutionContext &ctx,
-                                   const GateAttentionConfig<T> &config,
-                                   const Tensor *query,
-                                   const Tensor *fmha_out) {
+void ComputeGatingLinearForward(const framework::ExecutionContext &ctx,
+                                const GateAttentionConfig<T> &config,
+                                const Tensor *query, const Tensor *fmha_out,
+                                Tensor *gate_out) {
   auto *gate_weight = ctx.Input<Tensor>("GateWeight");
   auto *gate_bias = ctx.Input<Tensor>("GateBias");

-  auto *gate_out = ctx.Output<Tensor>("GateOut");
-  gate_out->mutable_data<T>(ctx.GetPlace());
-  VLOG(4) << "[ComputeGatingLinearForward] gate_out: "
-          << MemoryDebugString(*gate_out);
-
   // The first gate_bias_out stores the result of the multiplication,
   // and the second gate_bias_out stores the result of the multiplication +
   // bias.
@@ -212,16 +207,14 @@ Tensor *ComputeGatingLinearForward(const framework::ExecutionContext &ctx,
   std::vector<Tensor *> outs = {gate_out};
   phi::funcs::ElementwiseKernel<T>(ctx.cuda_device_context(), ins, &outs,
                                    SigmoidMultiplyFunctor<T>());
-  return gate_out;
 }

 template <typename T>
-Tensor *ComputeGatingLinearBackward(const framework::ExecutionContext &ctx,
-                                    const GateAttentionGradConfig<T> &config,
-                                    const Tensor *fmha_out,
-                                    const Tensor *gate_out_grad,
-                                    Tensor *query_grad, Tensor *fmha_out_grad) {
-  const auto *query = ctx.Input<Tensor>("Query");
+void ComputeGatingLinearBackward(const framework::ExecutionContext &ctx,
+                                 const GateAttentionGradConfig<T> &config,
+                                 const Tensor *query, const Tensor *fmha_out,
+                                 const Tensor *gate_out_grad,
+                                 Tensor *query_grad, Tensor *fmha_out_grad) {
   const auto *gate_weight = ctx.Input<Tensor>("GateWeight");
   const auto *gate_bias = ctx.Input<Tensor>("GateBias");
@@ -255,20 +248,15 @@ Tensor *ComputeGatingLinearBackward(const framework::ExecutionContext &ctx,
   gate_attn_compute.ComputeBackward(query, gate_weight, &gate_bias_out,
                                     query_grad, gate_weight_grad,
                                     gate_bias_grad);
-  return fmha_out_grad;
 }

 template <typename T>
-Tensor *ComputeOutputLinearForward(const framework::ExecutionContext &ctx,
-                                   const GateAttentionConfig<T> &config,
-                                   const Tensor *fmha_or_gate_out) {
+void ComputeOutputLinearForward(const framework::ExecutionContext &ctx,
+                                const GateAttentionConfig<T> &config,
+                                const Tensor *fmha_or_gate_out, Tensor *out) {
   const auto *out_linear_weight = ctx.Input<Tensor>("OutLinearWeight");
   const auto *out_linear_bias = ctx.Input<Tensor>("OutLinearBias");

-  auto *out = ctx.Output<Tensor>("Out");
-  out->mutable_data<T>(ctx.GetPlace());
-  VLOG(4) << "[ComputeOutputLinearForward] out: " << MemoryDebugString(*out);
-
   // out = GEMM(fmha_or_gate_out, out_linear_weight) + out_linear_bias
   int m = config.batch_size * config.seq_len_m * config.seq_len_r;
   int n = config.q_dim;
@@ -277,28 +265,24 @@ Tensor *ComputeOutputLinearForward(const framework::ExecutionContext &ctx,
       AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true);
   out_linear_compute.ComputeForward(out_linear_weight, fmha_or_gate_out,
                                     out_linear_bias, out, out);
-  return out;
 }

 template <typename T>
-Tensor *ComputeOutputLinearBackward(const framework::ExecutionContext &ctx,
-                                    const GateAttentionGradConfig<T> &config,
-                                    bool has_gating) {
-  std::string input_name = has_gating ? "GateOut" : "FMHAOut";
-
+void ComputeOutputLinearBackward(const framework::ExecutionContext &ctx,
+                                 const GateAttentionGradConfig<T> &config,
+                                 const Tensor *input, Tensor *input_grad) {
   const auto *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
   const auto *out_linear_weight = ctx.Input<Tensor>("OutLinearWeight");
-  const auto *input = ctx.Input<Tensor>(input_name);

   auto *out_linear_weight_grad =
       ctx.Output<Tensor>(framework::GradVarName("OutLinearWeight"));
   auto *out_linear_bias_grad =
       ctx.Output<Tensor>(framework::GradVarName("OutLinearBias"));
-  auto *input_grad = ctx.Output<Tensor>(framework::GradVarName(input_name));

   out_linear_weight_grad->mutable_data<T>(ctx.GetPlace());
   out_linear_bias_grad->mutable_data<T>(ctx.GetPlace());
-  input_grad->mutable_data<T>(ctx.GetPlace());
+
+  auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();

   int m = config.batch_size * config.seq_len_m * config.seq_len_r;
   int n = config.q_dim;
@@ -308,7 +292,6 @@ Tensor *ComputeOutputLinearBackward(const framework::ExecutionContext &ctx,
   out_linear_compute.ComputeBackward(input, out_linear_weight, out_grad,
                                      input_grad, out_linear_weight_grad,
                                      out_linear_bias_grad);
-  return input_grad;
 }

 template <typename T>
@@ -330,56 +313,64 @@ class FusedGateAttentionOpKernel : public framework::OpKernel<T> {
     auto *softmax_out = ctx.Output<Tensor>("SoftmaxOut");
     auto *fmha_out = ctx.Output<Tensor>("FMHAOut");
+    auto *gate_out = ctx.Output<Tensor>("GateOut");
+    auto *out = ctx.Output<Tensor>("Out");

     const bool merge_qkv = ctx.Attr<bool>("merge_qkv");
     const bool has_gating = ctx.Attr<bool>("has_gating");

-    // When seq_len_r = m_size, q_dim = kv_dim, QKV matmul can be merged.
     auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    GateAttentionConfig<T> config(query, key, query_weight, qkv_weight,
-                                  merge_qkv);
+    AllocWithDebugInfo<T>(dev_ctx, "softmax_out", softmax_out);
+    AllocWithDebugInfo<T>(dev_ctx, "fmha_out", fmha_out);
+    if (has_gating) {
+      AllocWithDebugInfo<T>(dev_ctx, "gate_out", gate_out);
+    }
+    AllocWithDebugInfo<T>(dev_ctx, "out", out);
+
+    // When seq_len_r = m_size, q_dim = kv_dim, QKV matmul can be merged.
+    GateAttentionConfig<T> config(dev_ctx, query, key, query_weight, qkv_weight,
+                                  merge_qkv, has_gating);

     if (merge_qkv) {
+      PADDLE_ENFORCE_EQ(!key || query == key, true,
+                        platform::errors::InvalidArgument(
+                            "key is expected to be nullptr or the same as "
+                            "query, but recieved key=%p, query=%p.",
+                            key, query));
+
       // 1. Merged QKV Matmul: einsum(nbhqk,nbkhc -> nbqhc)
-      Tensor *qkv_out = config.GetQKVOut(dev_ctx);
+      Tensor *qkv_out = config.GetQKVOut();
       ComputeMergedQKVMatmulForward<T>(ctx, config, query, qkv_out);

-      qkv_transpose_out->mutable_data<T>(ctx.GetPlace());
-      VLOG(4) << "qkv_transpose_out:" << MemoryDebugString(*qkv_transpose_out);
+      AllocWithDebugInfo<T>(dev_ctx, "qkv_transpose_out", qkv_transpose_out);
     } else {
       // 1. Separated QKV Matmul
-      Tensor *query_out = config.GetQueryOut(dev_ctx);
-      Tensor *key_out = config.GetKeyOut(dev_ctx);
-      Tensor *value_out = config.GetValueOut(dev_ctx);
+      Tensor *query_out = config.GetQueryOut();
+      Tensor *key_out = config.GetKeyOut();
+      Tensor *value_out = config.GetValueOut();
       ComputeSeparatedQKVMatmulForward<T>(ctx, config, query, key, query_out,
                                           key_out, value_out);

-      q_transpose_out->mutable_data<T>(ctx.GetPlace());
-      k_transpose_out->mutable_data<T>(ctx.GetPlace());
-      v_transpose_out->mutable_data<T>(ctx.GetPlace());
-      VLOG(4) << "q_transpose_out: " << MemoryDebugString(*q_transpose_out);
-      VLOG(4) << "k_transpose_out: " << MemoryDebugString(*k_transpose_out);
-      VLOG(4) << "v_transpose_out: " << MemoryDebugString(*v_transpose_out);
+      AllocWithDebugInfo<T>(dev_ctx, "q_transpose_out", q_transpose_out);
+      AllocWithDebugInfo<T>(dev_ctx, "k_transpose_out", k_transpose_out);
+      AllocWithDebugInfo<T>(dev_ctx, "v_transpose_out", v_transpose_out);
     }

-    softmax_out->mutable_data<T>(ctx.GetPlace());
-    fmha_out->mutable_data<T>(ctx.GetPlace());
-    VLOG(4) << "softmax_out: " << MemoryDebugString(*softmax_out);
-    VLOG(4) << "fmha_out: " << MemoryDebugString(*fmha_out);
-
     // 2. FMHA
     auto fmha_compute = FMHAGateRef<T>(dev_ctx, merge_qkv);
-    fmha_compute.ComputeForward(
-        nonbatched_bias, src_mask, q_transpose_out, k_transpose_out,
-        v_transpose_out, qkv_transpose_out, softmax_out, fmha_out, &config);
+    fmha_compute.ComputeForward(nonbatched_bias, src_mask, q_transpose_out,
+                                k_transpose_out, v_transpose_out,
+                                qkv_transpose_out, softmax_out, fmha_out,
+                                gate_out, &config);

     // 3. Gating Linear
-    Tensor *fmha_or_gate_out = !has_gating ? fmha_out
-                                           : ComputeGatingLinearForward<T>(
-                                                 ctx, config, query, fmha_out);
+    if (has_gating) {
+      ComputeGatingLinearForward<T>(ctx, config, query, fmha_out, gate_out);
+    }

     // 4. Output Linear
-    ComputeOutputLinearForward<T>(ctx, config, fmha_or_gate_out);
+    Tensor *fmha_or_gate_out = has_gating ? gate_out : fmha_out;
+    ComputeOutputLinearForward<T>(ctx, config, fmha_or_gate_out, out);
   }
 };
@@ -387,9 +378,6 @@ template <typename T>
 class FusedGateAttentionGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    const auto has_gating = ctx.Attr<bool>("has_gating");
-    const auto merge_qkv = ctx.Attr<bool>("merge_qkv");
-
     // forward input
     const auto *query = ctx.Input<Tensor>("Query");
     const auto *key = ctx.Input<Tensor>("Key");
@@ -403,56 +391,68 @@ class FusedGateAttentionGradKernel : public framework::OpKernel<T> {
     const auto *qkv_transpose_out = ctx.Input<Tensor>("QKVTransposeOut");
     const auto *softmax_out = ctx.Input<Tensor>("SoftmaxOut");
     const auto *fmha_out = ctx.Input<Tensor>("FMHAOut");
+    const auto *gate_out = ctx.Input<Tensor>("GateOut");

     // backward output
     auto *query_grad = ctx.Output<Tensor>(framework::GradVarName("Query"));
-    query_grad->mutable_data<T>(ctx.GetPlace());
     auto *nonbatched_bias_grad =
         ctx.Output<Tensor>(framework::GradVarName("NonbatchedBias"));
-    auto *fmha_out_grad = ctx.Output<Tensor>(framework::GradVarName("FMHAOut"));
+
+    bool has_gating = ctx.Attr<bool>("has_gating");
+    bool merge_qkv = ctx.Attr<bool>("merge_qkv");

     auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    GateAttentionGradConfig<T> config(query, key, query_weight, qkv_weight,
-                                      merge_qkv);
-
-    // 1. Gradient of Output Linear
-    Tensor *fhma_or_gate_out_grad =
-        ComputeOutputLinearBackward<T>(ctx, config, has_gating);
-
-    // 2. Gradient of Gating Linear
+    AllocWithDebugInfo<T>(dev_ctx, "query_grad", query_grad);
+
+    GateAttentionGradConfig<T> config(dev_ctx, query, key, query_weight,
+                                      qkv_weight, merge_qkv, has_gating);
+
+    Tensor fmha_out_grad;
+    fmha_out_grad.Resize(config.gate_out_dims);
+    AllocWithDebugInfo<T>(dev_ctx, "fmha_out_grad", &fmha_out_grad);
     if (has_gating) {
-      // fhma_or_gate_out_grad is actually gate_out_grad.
-      fmha_out_grad->mutable_data<T>(ctx.GetPlace());
-      ComputeGatingLinearBackward<T>(ctx, config, fmha_out,
-                                     fhma_or_gate_out_grad, query_grad,
-                                     fmha_out_grad);
+      // 1. Gradient of Output Linear: out = Linear(gate_out)
+      Tensor gate_out_grad;
+      gate_out_grad.Resize(config.gate_out_dims);
+      AllocWithDebugInfo<T>(dev_ctx, "gate_out_grad", &gate_out_grad);
+      ComputeOutputLinearBackward<T>(ctx, config, gate_out, &gate_out_grad);
+
+      // 2. Gradient of Gating Linear
+      // Forward: gate_out = Sigmoid(Linear(fmha_out)) * fmha_out
+      ComputeGatingLinearBackward<T>(ctx, config, query, fmha_out,
+                                     &gate_out_grad, query_grad,
+                                     &fmha_out_grad);
+    } else {
+      // 1. Gradient of Output Linear: out = Linear(fmha_grad)
+      ComputeOutputLinearBackward<T>(ctx, config, fmha_out, &fmha_out_grad);
     }

     // 3. Gradient of FMHA
     if (nonbatched_bias_grad) {
-      nonbatched_bias_grad->mutable_data<T>(ctx.GetPlace());
+      AllocWithDebugInfo<T>(dev_ctx, "nonbatched_bias_grad",
+                            nonbatched_bias_grad);
     }

     auto fmha_compute = FMHAGateRef<T>(dev_ctx, merge_qkv);
     fmha_compute.ComputeBackward(
         q_transpose_out, k_transpose_out, v_transpose_out, qkv_transpose_out,
-        softmax_out, fmha_out_grad, nullptr, nonbatched_bias_grad, &config);
+        softmax_out, &fmha_out_grad, nullptr, nonbatched_bias_grad, &config);

     bool use_addto = has_gating ? true : false;
     if (merge_qkv) {
       // 4. Gradient of Merged QKV Matmul
-      Tensor *qkv_out_grad = config.GetQKVOutGrad(dev_ctx);
+      Tensor *qkv_out_grad = config.GetQKVOutGrad();
       ComputeMergedQKVMatmulBackward<T>(ctx, config, query, qkv_out_grad,
                                         query_grad, use_addto);
     } else {
       // 4. Gradient of Separated QKV Matmul
       auto *key_grad = ctx.Output<Tensor>(framework::GradVarName("Key"));
       if (key_grad) {
-        key_grad->mutable_data<T>(ctx.GetPlace());
+        AllocWithDebugInfo<T>(dev_ctx, "key_grad", key_grad);
       }
-      Tensor *query_out_grad = config.GetQueryOutGrad(dev_ctx);
-      Tensor *key_out_grad = config.GetKeyOutGrad(dev_ctx);
-      Tensor *value_out_grad = config.GetValueOutGrad(dev_ctx);
+      Tensor *query_out_grad = config.GetQueryOutGrad();
+      Tensor *key_out_grad = config.GetKeyOutGrad();
+      Tensor *value_out_grad = config.GetValueOutGrad();
       ComputeSeparatedQKVMatmulBackward<T>(
           ctx, config, query, key, query_out_grad, key_out_grad, value_out_grad,
           query_grad, key_grad, use_addto);
...
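For orientation, the gating step whose gradient the rewritten kernel now drives explicitly computes, in the forward pass, a sigmoid gate from the query and applies it elementwise to the attention output (this is what ComputeGatingLinearForward does with gate_weight and gate_bias). A rough NumPy sketch of that forward step, with assumed illustrative shapes rather than the exact layouts used by the op:

import numpy as np

def gating_linear_forward(query, fmha_out, gate_weight, gate_bias):
    # query:       [batch, seq_len_m, seq_len_r, q_dim]
    # fmha_out:    [batch, seq_len_m, seq_len_r, num_heads, key_dim]
    # gate_weight: [q_dim, num_heads, key_dim], gate_bias: [num_heads, key_dim]
    gate = np.einsum("bmrq,qhk->bmrhk", query, gate_weight) + gate_bias
    gate = 1.0 / (1.0 + np.exp(-gate))  # sigmoid
    return gate * fmha_out              # same shape as fmha_out

q = np.random.rand(1, 2, 3, 6)
w = np.random.rand(6, 4, 8)
b = np.zeros((4, 8))
f = np.random.rand(1, 2, 3, 4, 8)
print(gating_linear_forward(q, f, w, b).shape)  # (1, 2, 3, 4, 8)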
@@ -18,7 +18,7 @@ import paddle
 import paddle.nn as nn
 from paddle import tensor
 import unittest
-from op_test import OpTest, convert_float_to_uint16
+from op_test import OpTest, convert_float_to_uint16, convert_uint16_to_float
 from test_sparse_attention_op import get_cuda_version
 from paddle import _C_ops
 from paddle.fluid.framework import default_main_program, _enable_legacy_dygraph
@@ -194,23 +194,36 @@ class TestFusedGateAttentionOp(OpTest):
             return out, query.grad, None

     def check_output_and_grad(self, atol, rtol):
-        out_ref, query_grad_ref, key_grad_ref = self.get_reference_out()
-        out, query_grad, key_grad = self.get_fused_gate_attention_out()
-        np.testing.assert_allclose(out_ref, out.numpy(), atol=atol, rtol=rtol)
-        np.testing.assert_allclose(query_grad_ref,
-                                   query_grad.numpy(),
-                                   atol=atol,
-                                   rtol=rtol)
-        if key_grad_ref is not None and key_grad is not None:
-            np.testing.assert_allclose(key_grad_ref,
-                                       key_grad.numpy(),
-                                       atol=atol,
-                                       rtol=rtol)
+
+        def _convert(value):
+            if self.dtype == "bfloat16":
+                return convert_uint16_to_float(value)
+            return value
+
+        output_names = ["out", "query_grad", "key_grad"]
+        outputs_ref = self.get_reference_out()
+        outputs_fused = self.get_fused_gate_attention_out()
+        for i in range(len(outputs_fused)):
+            ref_res = outputs_ref[i]
+            fused_res = outputs_fused[i]
+            if ref_res is not None and fused_res is not None:
+                print("Checking {}".format(output_names[i]))
+                np.testing.assert_allclose(_convert(ref_res),
+                                           _convert(fused_res.numpy()),
+                                           atol=atol,
+                                           rtol=rtol)

     def test_output_and_grad(self):
         self.check_output_and_grad(atol=1e-5, rtol=1e-5)


+class TestMergeQKVLargeBatchSizeCase(TestFusedGateAttentionOp):
+
+    def config(self):
+        super().config()
+        self.batch_size = 2
+
+
 class TestSeparatedQKVCase(TestFusedGateAttentionOp):

     def config(self):
@@ -243,7 +256,16 @@ class TestMergeQKVFp16Case(TestFusedGateAttentionOp):
         self.dtype = "float16"

     def test_output_and_grad(self):
-        self.check_output_and_grad(atol=1e-1, rtol=1e-5)
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_output_and_grad(atol=1e-1, rtol=1e-5)
+
+
+class TestMergeQKVLargeBatchSizeFp16Case(TestMergeQKVFp16Case):
+
+    def config(self):
+        super().config()
+        self.batch_size = 2


 @unittest.skipIf(
@@ -260,5 +282,12 @@ class TestMergeQKVBF16Case(TestFusedGateAttentionOp):
         self.check_output_and_grad(atol=1e-1, rtol=1e-3)


+class TestMergeQKVLargeBatchSizeBF16Case(TestMergeQKVBF16Case):
+
+    def config(self):
+        super().config()
+        self.batch_size = 2
+
+
 if __name__ == "__main__":
     unittest.main()
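The reworked test converts bfloat16 results back to float32 before comparing, because bfloat16 tensors are surfaced to NumPy as uint16 bit patterns. A rough sketch of what such a conversion amounts to (an assumed equivalent for illustration, not the actual op_test helper):

import numpy as np

def uint16_bfloat16_to_float32(x):
    # A bfloat16 value is the upper 16 bits of a float32; shifting the stored
    # uint16 pattern left by 16 and reinterpreting the bits recovers a float32.
    x = np.asarray(x, dtype=np.uint16)
    return (x.astype(np.uint32) << 16).view(np.float32)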