未验证 提交 10f8637c 编写于 作者: Y Yiqun Liu 提交者: GitHub

Fix wrong reduce_dims in fused_gate_attention and optimize the memory usage. (#43216)

* Polish the code and optimize memory usage for fused_gate_attention.

* Fix wrong reduce_dims in fused_gate_attention when computing gradient of nonbatched_bias.
上级 5a6c974e
...@@ -14,11 +14,11 @@ limitations under the License. */ ...@@ -14,11 +14,11 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/transpose_op.cu.h" #include "paddle/fluid/operators/transpose_op.cu.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h" #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
namespace paddle { namespace paddle {
...@@ -27,19 +27,29 @@ namespace operators { ...@@ -27,19 +27,29 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
// Builds a one-line debug string for tensor `t`: its dims, its memory_size()
// converted to MB, and its data pointer, followed by device memory figures.
// The rows below interleave the pre-change and post-change versions of this
// function: the pre-change rows report (total - available) obtained from
// platform::GpuMemoryUsage, while the post-change rows report the "Allocated"
// and "Reserved" values read through memory::DeviceMemoryStatCurrentValue for
// the device returned by platform::GetCurrentDeviceId().
inline std::string MemoryDebugString(const Tensor& t) { inline std::string MemoryDebugString(const Tensor& t) {
int device_id = platform::GetCurrentDeviceId();
int64_t allocated =
memory::DeviceMemoryStatCurrentValue("Allocated", device_id);
int64_t reserved =
memory::DeviceMemoryStatCurrentValue("Reserved", device_id);
std::stringstream ss; std::stringstream ss;
ss << "shape=[" << t.dims() ss << "shape=[" << t.dims()
<< "], size=" << static_cast<float>(t.memory_size()) / (1 << 20) << "], size=" << static_cast<float>(t.memory_size()) / (1 << 20)
<< " MB, ptr=" << t.data(); << " MB, ptr=" << t.data()
<< "; [MEMORY] allocated=" << static_cast<float>(allocated) / (1 << 20)
size_t total = 0; << " MB"
size_t available = 0; << ", reserved=" << static_cast<float>(reserved) / (1 << 20) << " MB";
platform::GpuMemoryUsage(&available, &total);
ss << "; memory allocated="
<< static_cast<float>(total - available) / (1 << 20) << " MB";
return ss.str(); return ss.str();
} }
// Allocates storage of element type T for `t` on the place reported by
// `dev_ctx`, then emits a VLOG(4) line of the form "<info>: <debug string>"
// where the debug string comes from MemoryDebugString(*t).
//
// dev_ctx: device context whose GetPlace() selects where to allocate.
// info:    label written in front of the debug string.
// t:       tensor to allocate; the caller is expected to have set its dims.
template <typename T>
void AllocWithDebugInfo(const platform::CUDADeviceContext& dev_ctx,
                        const std::string& info, Tensor* t) {
  const auto target_place = dev_ctx.GetPlace();
  t->mutable_data<T>(target_place);
  VLOG(4) << info << ": " << MemoryDebugString(*t);
}
template <typename T> template <typename T>
struct TernaryAddFunctor { struct TernaryAddFunctor {
inline HOSTDEVICE T operator()(T a, T b, T c) const { return a + b + c; } inline HOSTDEVICE T operator()(T a, T b, T c) const { return a + b + c; }
...@@ -48,6 +58,11 @@ struct TernaryAddFunctor { ...@@ -48,6 +58,11 @@ struct TernaryAddFunctor {
template <typename T> template <typename T>
struct GateAttentionConfig { struct GateAttentionConfig {
public: public:
const platform::CUDADeviceContext& dev_ctx;
bool merge_qkv;
bool has_gating;
int64_t batch_size; int64_t batch_size;
int64_t seq_len_m; int64_t seq_len_m;
int64_t seq_len_r; int64_t seq_len_r;
...@@ -70,9 +85,11 @@ struct GateAttentionConfig { ...@@ -70,9 +85,11 @@ struct GateAttentionConfig {
phi::DDim qktv_out_dims; phi::DDim qktv_out_dims;
phi::DDim gate_out_dims; phi::DDim gate_out_dims;
GateAttentionConfig(const Tensor* query, const Tensor* key, GateAttentionConfig(const platform::CUDADeviceContext& dev_ctx,
const Tensor* query, const Tensor* key,
const Tensor* query_weight, const Tensor* qkv_weight, const Tensor* query_weight, const Tensor* qkv_weight,
bool merge_qkv) { bool merge_qkv, bool has_gating)
: dev_ctx(dev_ctx), merge_qkv(merge_qkv), has_gating(has_gating) {
// query: shape=[batch_size, seq_len_m, seq_len_r, q_dim] // query: shape=[batch_size, seq_len_m, seq_len_r, q_dim]
batch_size = query->dims()[0]; batch_size = query->dims()[0];
seq_len_m = query->dims()[1]; seq_len_m = query->dims()[1];
...@@ -131,59 +148,68 @@ struct GateAttentionConfig { ...@@ -131,59 +148,68 @@ struct GateAttentionConfig {
return batch_size * seq_len_m * seq_len_r * num_heads * key_dim; return batch_size * seq_len_m * seq_len_r * num_heads * key_dim;
} }
// Returns a pointer to the member tensor qkv_out. If qkv_out is not yet
// initialized it is first resized to qkv_out_dims and allocated with element
// type T. Each row shows the pre-change code followed by the post-change code:
// the post-change version takes no dev_ctx argument and allocates through
// AllocWithDebugInfo instead of mutable_data + VLOG.
Tensor* GetQKVOut(const platform::CUDADeviceContext& dev_ctx) { Tensor* GetQKVOut() {
if (!qkv_out.IsInitialized()) { if (!qkv_out.IsInitialized()) {
qkv_out.Resize(qkv_out_dims); qkv_out.Resize(qkv_out_dims);
qkv_out.mutable_data<T>(dev_ctx.GetPlace()); AllocWithDebugInfo<T>(dev_ctx, "qkv_out", &qkv_out);
VLOG(4) << "qkv_out: " << MemoryDebugString(qkv_out);
} }
return &qkv_out; return &qkv_out;
} }
// Returns a pointer to the member tensor query_out. If query_out is not yet
// initialized it is first resized to q_out_dims and allocated with element
// type T. Each row shows the pre-change code followed by the post-change code.
Tensor* GetQueryOut(const platform::CUDADeviceContext& dev_ctx) { Tensor* GetQueryOut() {
if (!query_out.IsInitialized()) { if (!query_out.IsInitialized()) {
query_out.Resize(q_out_dims); query_out.Resize(q_out_dims);
query_out.mutable_data<T>(dev_ctx.GetPlace()); AllocWithDebugInfo<T>(dev_ctx, "query_out", &query_out);
VLOG(4) << "query_out: " << MemoryDebugString(query_out);
} }
return &query_out; return &query_out;
} }
// Returns a pointer to the member tensor key_out. If key_out is not yet
// initialized it is first resized to kv_out_dims and allocated with element
// type T. Each row shows the pre-change code followed by the post-change code.
Tensor* GetKeyOut(const platform::CUDADeviceContext& dev_ctx) { Tensor* GetKeyOut() {
if (!key_out.IsInitialized()) { if (!key_out.IsInitialized()) {
key_out.Resize(kv_out_dims); key_out.Resize(kv_out_dims);
key_out.mutable_data<T>(dev_ctx.GetPlace()); AllocWithDebugInfo<T>(dev_ctx, "key_out", &key_out);
VLOG(4) << "key_out: " << MemoryDebugString(key_out);
} }
return &key_out; return &key_out;
} }
// Returns a pointer to the member tensor value_out. If value_out is not yet
// initialized it is first resized to kv_out_dims (the same dims used for
// key_out) and allocated with element type T. Each row shows the pre-change
// code followed by the post-change code.
Tensor* GetValueOut(const platform::CUDADeviceContext& dev_ctx) { Tensor* GetValueOut() {
if (!value_out.IsInitialized()) { if (!value_out.IsInitialized()) {
value_out.Resize(kv_out_dims); value_out.Resize(kv_out_dims);
value_out.mutable_data<T>(dev_ctx.GetPlace()); AllocWithDebugInfo<T>(dev_ctx, "value_out", &value_out);
VLOG(4) << "value_out: " << MemoryDebugString(value_out);
} }
return &value_out; return &value_out;
} }
// Returns the tensor that should hold qk_out.
//
// softmax_out: optional tensor that will later receive the softmax result.
//
// When softmax_out is null, or phi::UseCudnnSoftmax reports that the cuDNN
// softmax path is used for a softmax dimension of m_size, a separate member
// tensor qk_out is returned; it is resized to qk_out_dims and allocated on
// first use. Otherwise softmax_out itself is returned so that the softmax can
// be computed in place.
Tensor* GetQKOut(Tensor* softmax_out) {
  // softmax_dim = qk_out_dim[-1] = qk_out_dim[rank - 1]
  int softmax_dim = m_size;
  if (!softmax_out || phi::UseCudnnSoftmax<T>(dev_ctx, softmax_dim, true)) {
    // Not sure whether cudnn softmax can execute inplace.
    // The guard must test qk_out itself (mirroring GetQKOutGrad, which tests
    // qk_out_grad). Testing qkv_out here would hand back an unallocated
    // qk_out whenever qkv_out is still alive.
    if (!qk_out.IsInitialized()) {
      qk_out.Resize(qk_out_dims);
      AllocWithDebugInfo<T>(dev_ctx, "qk_out", &qk_out);
    }
    return &qk_out;
  } else {
    // Enable inplace softmax.
    return softmax_out;
  }
}
// Returns the tensor that should hold qktv_out.
//
// gate_out: optional tensor whose storage may be borrowed.
//
// When has_gating is set and gate_out is non-null, gate_out is resized to
// qktv_out_dims and returned, so no extra buffer is used. Otherwise the
// member tensor qktv_out is returned, after being resized to qktv_out_dims
// and allocated if it was not initialized yet.
Tensor* GetQKTVOut(Tensor* gate_out) {
  const bool borrow_gate_buffer = has_gating && (gate_out != nullptr);
  if (borrow_gate_buffer) {
    // Reuse gate_out.
    gate_out->Resize(qktv_out_dims);
    return gate_out;
  }

  if (qktv_out.IsInitialized()) {
    return &qktv_out;
  }
  qktv_out.Resize(qktv_out_dims);
  AllocWithDebugInfo<T>(dev_ctx, "qktv_out", &qktv_out);
  return &qktv_out;
}
void ClearQKVOut() { void ClearQKVOut() {
if (qkv_out.IsInitialized()) { if (qkv_out.IsInitialized()) {
qkv_out.clear(); qkv_out.clear();
...@@ -196,9 +222,14 @@ struct GateAttentionConfig { ...@@ -196,9 +222,14 @@ struct GateAttentionConfig {
} }
} }
// Calls clear() on the member tensor qktv_out when it is initialized; does
// nothing otherwise (for example when gate_out was borrowed instead).
void ClearQKTVOut() {
  if (!qktv_out.IsInitialized()) {
    return;
  }
  qktv_out.clear();
}
protected: protected:
Tensor qkv_out; Tensor qkv_out;
// QKV is not merged
Tensor query_out; Tensor query_out;
Tensor key_out; Tensor key_out;
Tensor value_out; Tensor value_out;
...@@ -207,63 +238,60 @@ struct GateAttentionConfig { ...@@ -207,63 +238,60 @@ struct GateAttentionConfig {
// softmax_out = softmax(qk_out + nonbatched_bias + src_mask) // softmax_out = softmax(qk_out + nonbatched_bias + src_mask)
// The shape of qk_out, softmax_out is the same, thus can be called inplace. // The shape of qk_out, softmax_out is the same, thus can be called inplace.
Tensor qk_out; Tensor qk_out;
// qktv_out may reuse gate_out.
Tensor qktv_out;
}; };
template <typename T> template <typename T>
struct GateAttentionGradConfig : public GateAttentionConfig<T> { struct GateAttentionGradConfig : public GateAttentionConfig<T> {
public: public:
// Constructor: forwards every argument unchanged to the GateAttentionConfig<T>
// base-class constructor and has an empty body. Each row shows the pre-change
// code followed by the post-change code; the post-change version additionally
// accepts and forwards dev_ctx and has_gating.
GateAttentionGradConfig(const Tensor* query, const Tensor* key, GateAttentionGradConfig(const platform::CUDADeviceContext& dev_ctx,
const Tensor* query, const Tensor* key,
const Tensor* query_weight, const Tensor* qkv_weight, const Tensor* query_weight, const Tensor* qkv_weight,
bool merge_qkv) bool merge_qkv, bool has_gating)
: GateAttentionConfig<T>(query, key, query_weight, qkv_weight, : GateAttentionConfig<T>(dev_ctx, query, key, query_weight, qkv_weight,
merge_qkv) {} merge_qkv, has_gating) {}
// Returns a pointer to the member tensor qkv_out_grad. If it is not yet
// initialized it is first resized to the base class's qkv_out_dims and
// allocated with element type T. Each row shows the pre-change code followed
// by the post-change code; the post-change version uses the dev_ctx stored in
// the base class (this->dev_ctx) instead of a parameter.
Tensor* GetQKVOutGrad(const platform::CUDADeviceContext& dev_ctx) { Tensor* GetQKVOutGrad() {
if (!qkv_out_grad.IsInitialized()) { if (!qkv_out_grad.IsInitialized()) {
qkv_out_grad.Resize(this->qkv_out_dims); qkv_out_grad.Resize(this->qkv_out_dims);
qkv_out_grad.mutable_data<T>(dev_ctx.GetPlace()); AllocWithDebugInfo<T>(this->dev_ctx, "qkv_out_grad", &qkv_out_grad);
VLOG(4) << "qkv_out_grad: " << MemoryDebugString(qkv_out_grad);
} }
return &qkv_out_grad; return &qkv_out_grad;
} }
// Returns a pointer to the member tensor query_out_grad. If it is not yet
// initialized it is first resized to the base class's q_out_dims and
// allocated with element type T. Each row shows the pre-change code followed
// by the post-change code.
Tensor* GetQueryOutGrad(const platform::CUDADeviceContext& dev_ctx) { Tensor* GetQueryOutGrad() {
if (!query_out_grad.IsInitialized()) { if (!query_out_grad.IsInitialized()) {
query_out_grad.Resize(this->q_out_dims); query_out_grad.Resize(this->q_out_dims);
query_out_grad.mutable_data<T>(dev_ctx.GetPlace()); AllocWithDebugInfo<T>(this->dev_ctx, "query_out_grad", &query_out_grad);
VLOG(4) << "query_out_grad: " << MemoryDebugString(query_out_grad);
} }
return &query_out_grad; return &query_out_grad;
} }
// Returns a pointer to the member tensor key_out_grad. If it is not yet
// initialized it is first resized to the base class's kv_out_dims and
// allocated with element type T. Each row shows the pre-change code followed
// by the post-change code.
Tensor* GetKeyOutGrad(const platform::CUDADeviceContext& dev_ctx) { Tensor* GetKeyOutGrad() {
if (!key_out_grad.IsInitialized()) { if (!key_out_grad.IsInitialized()) {
key_out_grad.Resize(this->kv_out_dims); key_out_grad.Resize(this->kv_out_dims);
key_out_grad.mutable_data<T>(dev_ctx.GetPlace()); AllocWithDebugInfo<T>(this->dev_ctx, "key_out_grad", &key_out_grad);
VLOG(4) << "key_out_grad: " << MemoryDebugString(key_out_grad);
} }
return &key_out_grad; return &key_out_grad;
} }
// Returns a pointer to the member tensor value_out_grad. If it is not yet
// initialized it is first resized to the base class's kv_out_dims and
// allocated with element type T. Each row shows the pre-change code followed
// by the post-change code.
Tensor* GetValueOutGrad(const platform::CUDADeviceContext& dev_ctx) { Tensor* GetValueOutGrad() {
if (!value_out_grad.IsInitialized()) { if (!value_out_grad.IsInitialized()) {
value_out_grad.Resize(this->kv_out_dims); value_out_grad.Resize(this->kv_out_dims);
value_out_grad.mutable_data<T>(dev_ctx.GetPlace()); AllocWithDebugInfo<T>(this->dev_ctx, "value_out_grad", &value_out_grad);
VLOG(4) << "value_out_grad: " << MemoryDebugString(value_out_grad);
} }
return &value_out_grad; return &value_out_grad;
} }
Tensor* GetQKOutGrad(const platform::CUDADeviceContext& dev_ctx, Tensor* GetQKOutGrad(Tensor* softmax_out_grad) {
Tensor* softmax_out_grad) {
// softmax_dim = qk_out_dim[-1] = qk_out_dim[rank - 1] // softmax_dim = qk_out_dim[-1] = qk_out_dim[rank - 1]
int softmax_dim = this->m_size; int softmax_dim = this->m_size;
if (!softmax_out_grad || if (!softmax_out_grad ||
phi::UseCudnnSoftmax<T>(dev_ctx, softmax_dim, true)) { phi::UseCudnnSoftmax<T>(this->dev_ctx, softmax_dim, true)) {
if (!qk_out_grad.IsInitialized()) { if (!qk_out_grad.IsInitialized()) {
qk_out_grad.Resize(this->qk_out_dims); qk_out_grad.Resize(this->qk_out_dims);
qk_out_grad.mutable_data<T>(dev_ctx.GetPlace()); AllocWithDebugInfo<T>(this->dev_ctx, "qk_out_grad", &qk_out_grad);
VLOG(4) << "qk_out_grad: " << MemoryDebugString(qk_out_grad);
} }
return &qk_out_grad; return &qk_out_grad;
} else { } else {
...@@ -288,7 +316,7 @@ class FMHAGateRef { ...@@ -288,7 +316,7 @@ class FMHAGateRef {
void ComputeForward(const Tensor* nonbatched_bias, const Tensor* src_mask, void ComputeForward(const Tensor* nonbatched_bias, const Tensor* src_mask,
Tensor* q_transpose_out, Tensor* k_transpose_out, Tensor* q_transpose_out, Tensor* k_transpose_out,
Tensor* v_transpose_out, Tensor* qkv_transpose_out, Tensor* v_transpose_out, Tensor* qkv_transpose_out,
Tensor* softmax_out, Tensor* fmha_out, Tensor* softmax_out, Tensor* fmha_out, Tensor* gate_out,
GateAttentionConfig<T>* config) { GateAttentionConfig<T>* config) {
T* q_ptr = nullptr; T* q_ptr = nullptr;
T* k_ptr = nullptr; T* k_ptr = nullptr;
...@@ -300,7 +328,7 @@ class FMHAGateRef { ...@@ -300,7 +328,7 @@ class FMHAGateRef {
platform::errors::NotFound("The input qkv_transpose_out can not be " platform::errors::NotFound("The input qkv_transpose_out can not be "
"nullptr when merge_qkv is true.")); "nullptr when merge_qkv is true."));
Tensor* qkv_out = config->GetQKVOut(dev_ctx_); Tensor* qkv_out = config->GetQKVOut();
ComputeQKVTransposeForward(*qkv_out, qkv_transpose_out); ComputeQKVTransposeForward(*qkv_out, qkv_transpose_out);
config->ClearQKVOut(); config->ClearQKVOut();
...@@ -323,9 +351,9 @@ class FMHAGateRef { ...@@ -323,9 +351,9 @@ class FMHAGateRef {
platform::errors::NotFound("The input v_transpose_out can not be " platform::errors::NotFound("The input v_transpose_out can not be "
"nullptr when merge_qkv is false.")); "nullptr when merge_qkv is false."));
Tensor* query_out = config->GetQueryOut(dev_ctx_); Tensor* query_out = config->GetQueryOut();
Tensor* key_out = config->GetKeyOut(dev_ctx_); Tensor* key_out = config->GetKeyOut();
Tensor* value_out = config->GetValueOut(dev_ctx_); Tensor* value_out = config->GetValueOut();
ComputeQKVTransposeForward(*query_out, *key_out, *value_out, ComputeQKVTransposeForward(*query_out, *key_out, *value_out,
q_transpose_out, k_transpose_out, q_transpose_out, k_transpose_out,
v_transpose_out); v_transpose_out);
...@@ -340,7 +368,7 @@ class FMHAGateRef { ...@@ -340,7 +368,7 @@ class FMHAGateRef {
// [batch_size, seq_len_m, num_heads, seq_len_r, key_dim] * // [batch_size, seq_len_m, num_heads, seq_len_r, key_dim] *
// [batch_size, seq_len_m, num_heads, m_size, key_dim] // [batch_size, seq_len_m, num_heads, m_size, key_dim]
// -> [batch_size, seq_len_m, num_heads, seq_len_r, m_size] // -> [batch_size, seq_len_m, num_heads, seq_len_r, m_size]
Tensor* qk_out = config->GetQKOut(dev_ctx_, softmax_out); Tensor* qk_out = config->GetQKOut(softmax_out);
T* qk_out_ptr = qk_out->data<T>(); T* qk_out_ptr = qk_out->data<T>();
int64_t gemm_batch_size = int64_t gemm_batch_size =
...@@ -362,9 +390,8 @@ class FMHAGateRef { ...@@ -362,9 +390,8 @@ class FMHAGateRef {
// [batch_size, seq_len_m, num_heads, seq_len_r, m_size] * // [batch_size, seq_len_m, num_heads, seq_len_r, m_size] *
// [batch_size, seq_len_m, num_heads, m_size, key_dim] // [batch_size, seq_len_m, num_heads, m_size, key_dim]
// -> [batch_size, seq_len_m, num_heads, seq_len_r, key_dim] // -> [batch_size, seq_len_m, num_heads, seq_len_r, key_dim]
Tensor qktv_out; Tensor* qktv_out = config->GetQKTVOut(gate_out);
qktv_out.Resize(config->qktv_out_dims); T* qktv_out_ptr = qktv_out->data<T>();
T* qktv_out_ptr = qktv_out.mutable_data<T>(dev_ctx_.GetPlace());
gemm_m = config->seq_len_r; gemm_m = config->seq_len_r;
gemm_n = config->key_dim; gemm_n = config->key_dim;
...@@ -375,7 +402,11 @@ class FMHAGateRef { ...@@ -375,7 +402,11 @@ class FMHAGateRef {
gemm_m, gemm_n, gemm_k, gemm_batch_size); gemm_m, gemm_n, gemm_k, gemm_batch_size);
// fmha_out = transpose(qktv_out) // fmha_out = transpose(qktv_out)
ComputeQKTVTransposeForward(qktv_out, fmha_out); ComputeQKTVTransposeForward(*qktv_out, fmha_out);
config->ClearQKTVOut();
if (config->has_gating) {
gate_out->Resize(config->gate_out_dims);
}
} }
void ComputeBackward(const Tensor* q_transpose_out, void ComputeBackward(const Tensor* q_transpose_out,
...@@ -409,8 +440,10 @@ class FMHAGateRef { ...@@ -409,8 +440,10 @@ class FMHAGateRef {
v_ptr = k_ptr + q_size; v_ptr = k_ptr + q_size;
qkv_transpose_out_grad.Resize(config->qkv_transpose_out_dims); qkv_transpose_out_grad.Resize(config->qkv_transpose_out_dims);
AllocWithDebugInfo<T>(dev_ctx_, "qkv_transpose_out_grad",
&qkv_transpose_out_grad);
q_grad_ptr = qkv_transpose_out_grad.mutable_data<T>(dev_ctx_.GetPlace()); q_grad_ptr = qkv_transpose_out_grad.data<T>();
k_grad_ptr = q_grad_ptr + q_size; k_grad_ptr = q_grad_ptr + q_size;
v_grad_ptr = k_grad_ptr + q_size; v_grad_ptr = k_grad_ptr + q_size;
} else { } else {
...@@ -442,7 +475,7 @@ class FMHAGateRef { ...@@ -442,7 +475,7 @@ class FMHAGateRef {
Tensor softmax_out_grad; Tensor softmax_out_grad;
softmax_out_grad.Resize(config->softmax_out_dims); softmax_out_grad.Resize(config->softmax_out_dims);
softmax_out_grad.mutable_data<T>(dev_ctx_.GetPlace()); AllocWithDebugInfo<T>(dev_ctx_, "softmax_out_grad", &softmax_out_grad);
int64_t gemm_batch_size = int64_t gemm_batch_size =
config->batch_size * config->seq_len_m * config->num_heads; config->batch_size * config->seq_len_m * config->num_heads;
...@@ -450,7 +483,7 @@ class FMHAGateRef { ...@@ -450,7 +483,7 @@ class FMHAGateRef {
// Forward: fmha_out = transpose(qktv_out) // Forward: fmha_out = transpose(qktv_out)
Tensor qktv_out_grad; Tensor qktv_out_grad;
qktv_out_grad.Resize(config->qktv_out_dims); qktv_out_grad.Resize(config->qktv_out_dims);
T* qktv_out_grad_ptr = qktv_out_grad.mutable_data<T>(dev_ctx_.GetPlace()); AllocWithDebugInfo<T>(dev_ctx_, "qktv_out_grad", &qktv_out_grad);
ComputeQKTVTransposeBackward(*fmha_out_grad, &qktv_out_grad); ComputeQKTVTransposeBackward(*fmha_out_grad, &qktv_out_grad);
// Forward: qktv_out = BatchedGEMM(softmax_out, V) // Forward: qktv_out = BatchedGEMM(softmax_out, V)
...@@ -461,6 +494,7 @@ class FMHAGateRef { ...@@ -461,6 +494,7 @@ class FMHAGateRef {
int64_t gemm_k = config->seq_len_r; int64_t gemm_k = config->seq_len_r;
const T* softmax_out_ptr = softmax_out->data<T>(); const T* softmax_out_ptr = softmax_out->data<T>();
const T* qktv_out_grad_ptr = qktv_out_grad.data<T>();
ComputeBatchedGEMM(softmax_out_ptr, qktv_out_grad_ptr, v_grad_ptr, true, ComputeBatchedGEMM(softmax_out_ptr, qktv_out_grad_ptr, v_grad_ptr, true,
false, gemm_m, gemm_n, gemm_k, gemm_batch_size); false, gemm_m, gemm_n, gemm_k, gemm_batch_size);
...@@ -474,7 +508,7 @@ class FMHAGateRef { ...@@ -474,7 +508,7 @@ class FMHAGateRef {
true, gemm_m, gemm_n, gemm_k, gemm_batch_size); true, gemm_m, gemm_n, gemm_k, gemm_batch_size);
} }
Tensor* qk_out_grad = config->GetQKOutGrad(dev_ctx_, &softmax_out_grad); Tensor* qk_out_grad = config->GetQKOutGrad(&softmax_out_grad);
ComputeBiasMaskSoftmaxBackward(&softmax_out_grad, softmax_out, ComputeBiasMaskSoftmaxBackward(&softmax_out_grad, softmax_out,
src_mask_grad, qk_out_grad, src_mask_grad, qk_out_grad,
nonbatched_bias_grad); nonbatched_bias_grad);
...@@ -498,12 +532,12 @@ class FMHAGateRef { ...@@ -498,12 +532,12 @@ class FMHAGateRef {
gemm_n, gemm_k, gemm_batch_size, alpha); gemm_n, gemm_k, gemm_batch_size, alpha);
if (merge_qkv_) { if (merge_qkv_) {
Tensor* qkv_out_grad = config->GetQKVOutGrad(dev_ctx_); Tensor* qkv_out_grad = config->GetQKVOutGrad();
ComputeQKVTransposeBackward(qkv_transpose_out_grad, qkv_out_grad); ComputeQKVTransposeBackward(qkv_transpose_out_grad, qkv_out_grad);
} else { } else {
Tensor* q_out_grad = config->GetQueryOutGrad(dev_ctx_); Tensor* q_out_grad = config->GetQueryOutGrad();
Tensor* k_out_grad = config->GetKeyOutGrad(dev_ctx_); Tensor* k_out_grad = config->GetKeyOutGrad();
Tensor* v_out_grad = config->GetValueOutGrad(dev_ctx_); Tensor* v_out_grad = config->GetValueOutGrad();
ComputeQKVTransposeBackward(q_transpose_out_grad, k_transpose_out_grad, ComputeQKVTransposeBackward(q_transpose_out_grad, k_transpose_out_grad,
v_transpose_out_grad, q_out_grad, k_out_grad, v_transpose_out_grad, q_out_grad, k_out_grad,
v_out_grad); v_out_grad);
...@@ -578,12 +612,12 @@ class FMHAGateRef { ...@@ -578,12 +612,12 @@ class FMHAGateRef {
if (nonbatched_bias) { if (nonbatched_bias) {
std::vector<const Tensor*> ins = {qk_out, nonbatched_bias, src_mask}; std::vector<const Tensor*> ins = {qk_out, nonbatched_bias, src_mask};
std::vector<Tensor*> outs = {qk_out}; std::vector<Tensor*> outs = {qk_out};
phi::funcs::BroadcastKernel<ElementwiseType::kTernary, T, T>( phi::funcs::BroadcastKernel<phi::ElementwiseType::kTernary, T, T>(
dev_ctx_, ins, &outs, -1, TernaryAddFunctor<T>()); dev_ctx_, ins, &outs, -1, TernaryAddFunctor<T>());
} else { } else {
std::vector<const Tensor*> ins = {qk_out, src_mask}; std::vector<const Tensor*> ins = {qk_out, src_mask};
std::vector<Tensor*> outs = {qk_out}; std::vector<Tensor*> outs = {qk_out};
phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>( phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>()); dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
} }
phi::SoftmaxForwardCUDAKernelDriver<T>(dev_ctx_, *qk_out, -1, softmax_out); phi::SoftmaxForwardCUDAKernelDriver<T>(dev_ctx_, *qk_out, -1, softmax_out);
...@@ -614,12 +648,12 @@ class FMHAGateRef { ...@@ -614,12 +648,12 @@ class FMHAGateRef {
phi::SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx_, *softmax_out, phi::SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx_, *softmax_out,
*softmax_out_grad, -1, qk_out_grad); *softmax_out_grad, -1, qk_out_grad);
// [1, bs, num_head, seq_l, seq_l] -> [bs, num_head, seq_l, seq_l]
if (nonbatched_bias_grad) { if (nonbatched_bias_grad) {
gpuStream_t stream = dev_ctx_.stream(); // [batch_size, seq_len_m, num_heads, seq_len_r, m_size] ->
TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>( // [batch_size, 1, num_heads, seq_len_r, m_size]
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx_, *qk_out_grad, nonbatched_bias_grad, dev_ctx_, *qk_out_grad, nonbatched_bias_grad,
kps::IdentityFunctor<T>(), {0, 1}, stream); kps::IdentityFunctor<T>(), {1});
} }
} }
......
...@@ -214,7 +214,7 @@ class FusedGateAttentionGradOp : public framework::OperatorWithKernel { ...@@ -214,7 +214,7 @@ class FusedGateAttentionGradOp : public framework::OperatorWithKernel {
"fused_aate_attention_arad"); "fused_aate_attention_arad");
if (ctx->Attrs().Get<bool>("has_gating")) { if (ctx->Attrs().Get<bool>("has_gating")) {
for (auto& name : {"GateWeight", "GateBias", "GateOut"}) { for (auto& name : {"GateWeight", "GateBias"}) {
ctx->SetOutputDim(framework::GradVarName(name), ctx->GetInputDim(name)); ctx->SetOutputDim(framework::GradVarName(name), ctx->GetInputDim(name));
} }
} }
...@@ -224,9 +224,6 @@ class FusedGateAttentionGradOp : public framework::OperatorWithKernel { ...@@ -224,9 +224,6 @@ class FusedGateAttentionGradOp : public framework::OperatorWithKernel {
ctx->GetInputDim("NonbatchedBias")); ctx->GetInputDim("NonbatchedBias"));
} }
ctx->SetOutputDim(framework::GradVarName("FMHAOut"),
ctx->GetInputDim("FMHAOut"));
ctx->SetOutputDim(framework::GradVarName("OutLinearWeight"), ctx->SetOutputDim(framework::GradVarName("OutLinearWeight"),
ctx->GetInputDim("OutLinearWeight")); ctx->GetInputDim("OutLinearWeight"));
ctx->SetOutputDim(framework::GradVarName("OutLinearBias"), ctx->SetOutputDim(framework::GradVarName("OutLinearBias"),
...@@ -270,8 +267,6 @@ class FusedGateAttentionGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -270,8 +267,6 @@ class FusedGateAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
} }
op->SetInput("FMHAOut", this->Output("FMHAOut")); op->SetInput("FMHAOut", this->Output("FMHAOut"));
op->SetOutput(framework::GradVarName("FMHAOut"),
this->OutputGrad("FMHAOut"));
if (this->HasInput("NonbatchedBias")) { if (this->HasInput("NonbatchedBias")) {
op->SetInput("NonbatchedBias", this->Input("NonbatchedBias")); op->SetInput("NonbatchedBias", this->Input("NonbatchedBias"));
...@@ -292,8 +287,6 @@ class FusedGateAttentionGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -292,8 +287,6 @@ class FusedGateAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
this->InputGrad("GateBias")); this->InputGrad("GateBias"));
op->SetInput("GateOut", this->Output("GateOut")); op->SetInput("GateOut", this->Output("GateOut"));
op->SetOutput(framework::GradVarName("GateOut"),
this->OutputGrad("GateOut"));
} }
op->SetInput("OutLinearWeight", this->Input("OutLinearWeight")); op->SetInput("OutLinearWeight", this->Input("OutLinearWeight"));
......
...@@ -79,7 +79,7 @@ void ComputeMergedQKVMatmulForward(const framework::ExecutionContext &ctx, ...@@ -79,7 +79,7 @@ void ComputeMergedQKVMatmulForward(const framework::ExecutionContext &ctx,
} }
template <typename T> template <typename T>
Tensor *ComputeMergedQKVMatmulBackward(const framework::ExecutionContext &ctx, void ComputeMergedQKVMatmulBackward(const framework::ExecutionContext &ctx,
const GateAttentionGradConfig<T> &config, const GateAttentionGradConfig<T> &config,
const Tensor *query, const Tensor *query,
const Tensor *qkv_out_grad, const Tensor *qkv_out_grad,
...@@ -97,7 +97,6 @@ Tensor *ComputeMergedQKVMatmulBackward(const framework::ExecutionContext &ctx, ...@@ -97,7 +97,6 @@ Tensor *ComputeMergedQKVMatmulBackward(const framework::ExecutionContext &ctx,
AttnMatMul<T>(ctx.cuda_device_context(), false, true, m, n, k, false); AttnMatMul<T>(ctx.cuda_device_context(), false, true, m, n, k, false);
qkv_compute.ComputeBackward(query, qkv_weight, qkv_out_grad, query_grad, qkv_compute.ComputeBackward(query, qkv_weight, qkv_out_grad, query_grad,
qkv_weight_grad, nullptr, use_addto); qkv_weight_grad, nullptr, use_addto);
return query_grad;
} }
template <typename T> template <typename T>
...@@ -137,11 +136,13 @@ void ComputeSeparatedQKVMatmulForward(const framework::ExecutionContext &ctx, ...@@ -137,11 +136,13 @@ void ComputeSeparatedQKVMatmulForward(const framework::ExecutionContext &ctx,
} }
template <typename T> template <typename T>
Tensor *ComputeSeparatedQKVMatmulBackward( void ComputeSeparatedQKVMatmulBackward(const framework::ExecutionContext &ctx,
const framework::ExecutionContext &ctx, const GateAttentionGradConfig<T> &config,
const GateAttentionGradConfig<T> &config, const Tensor *query, const Tensor *query, const Tensor *key,
const Tensor *key, const Tensor *query_out_grad, const Tensor *key_out_grad, const Tensor *query_out_grad,
const Tensor *value_out_grad, Tensor *query_grad, Tensor *key_grad, const Tensor *key_out_grad,
const Tensor *value_out_grad,
Tensor *query_grad, Tensor *key_grad,
bool use_addto) { bool use_addto) {
// Gradient of GEMM(key, k_weight) // Gradient of GEMM(key, k_weight)
const auto *key_weight = ctx.Input<Tensor>("KeyWeight"); const auto *key_weight = ctx.Input<Tensor>("KeyWeight");
...@@ -179,22 +180,16 @@ Tensor *ComputeSeparatedQKVMatmulBackward( ...@@ -179,22 +180,16 @@ Tensor *ComputeSeparatedQKVMatmulBackward(
q_n, q_k, false); q_n, q_k, false);
q_compute.ComputeBackward(query, query_weight, query_out_grad, query_grad, q_compute.ComputeBackward(query, query_weight, query_out_grad, query_grad,
query_weight_grad, nullptr, use_addto); query_weight_grad, nullptr, use_addto);
return query_grad;
} }
template <typename T> template <typename T>
Tensor *ComputeGatingLinearForward(const framework::ExecutionContext &ctx, void ComputeGatingLinearForward(const framework::ExecutionContext &ctx,
const GateAttentionConfig<T> &config, const GateAttentionConfig<T> &config,
const Tensor *query, const Tensor *query, const Tensor *fmha_out,
const Tensor *fmha_out) { Tensor *gate_out) {
auto *gate_weight = ctx.Input<Tensor>("GateWeight"); auto *gate_weight = ctx.Input<Tensor>("GateWeight");
auto *gate_bias = ctx.Input<Tensor>("GateBias"); auto *gate_bias = ctx.Input<Tensor>("GateBias");
auto *gate_out = ctx.Output<Tensor>("GateOut");
gate_out->mutable_data<T>(ctx.GetPlace());
VLOG(4) << "[ComputeGatingLinearForward] gate_out: "
<< MemoryDebugString(*gate_out);
// The first gate_bias_out stores the result of the multiplication, // The first gate_bias_out stores the result of the multiplication,
// and the second gate_bias_out stores the result of the multiplication + // and the second gate_bias_out stores the result of the multiplication +
// bias. // bias.
...@@ -212,16 +207,14 @@ Tensor *ComputeGatingLinearForward(const framework::ExecutionContext &ctx, ...@@ -212,16 +207,14 @@ Tensor *ComputeGatingLinearForward(const framework::ExecutionContext &ctx,
std::vector<Tensor *> outs = {gate_out}; std::vector<Tensor *> outs = {gate_out};
phi::funcs::ElementwiseKernel<T>(ctx.cuda_device_context(), ins, &outs, phi::funcs::ElementwiseKernel<T>(ctx.cuda_device_context(), ins, &outs,
SigmoidMultiplyFunctor<T>()); SigmoidMultiplyFunctor<T>());
return gate_out;
} }
template <typename T> template <typename T>
Tensor *ComputeGatingLinearBackward(const framework::ExecutionContext &ctx, void ComputeGatingLinearBackward(const framework::ExecutionContext &ctx,
const GateAttentionGradConfig<T> &config, const GateAttentionGradConfig<T> &config,
const Tensor *fmha_out, const Tensor *query, const Tensor *fmha_out,
const Tensor *gate_out_grad, const Tensor *gate_out_grad,
Tensor *query_grad, Tensor *fmha_out_grad) { Tensor *query_grad, Tensor *fmha_out_grad) {
const auto *query = ctx.Input<Tensor>("Query");
const auto *gate_weight = ctx.Input<Tensor>("GateWeight"); const auto *gate_weight = ctx.Input<Tensor>("GateWeight");
const auto *gate_bias = ctx.Input<Tensor>("GateBias"); const auto *gate_bias = ctx.Input<Tensor>("GateBias");
...@@ -255,20 +248,15 @@ Tensor *ComputeGatingLinearBackward(const framework::ExecutionContext &ctx, ...@@ -255,20 +248,15 @@ Tensor *ComputeGatingLinearBackward(const framework::ExecutionContext &ctx,
gate_attn_compute.ComputeBackward(query, gate_weight, &gate_bias_out, gate_attn_compute.ComputeBackward(query, gate_weight, &gate_bias_out,
query_grad, gate_weight_grad, query_grad, gate_weight_grad,
gate_bias_grad); gate_bias_grad);
return fmha_out_grad;
} }
template <typename T> template <typename T>
Tensor *ComputeOutputLinearForward(const framework::ExecutionContext &ctx, void ComputeOutputLinearForward(const framework::ExecutionContext &ctx,
const GateAttentionConfig<T> &config, const GateAttentionConfig<T> &config,
const Tensor *fmha_or_gate_out) { const Tensor *fmha_or_gate_out, Tensor *out) {
const auto *out_linear_weight = ctx.Input<Tensor>("OutLinearWeight"); const auto *out_linear_weight = ctx.Input<Tensor>("OutLinearWeight");
const auto *out_linear_bias = ctx.Input<Tensor>("OutLinearBias"); const auto *out_linear_bias = ctx.Input<Tensor>("OutLinearBias");
auto *out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
VLOG(4) << "[ComputeOutputLinearForward] out: " << MemoryDebugString(*out);
// out = GEMM(fmha_or_gate_out, out_linear_weight) + out_linear_bias // out = GEMM(fmha_or_gate_out, out_linear_weight) + out_linear_bias
int m = config.batch_size * config.seq_len_m * config.seq_len_r; int m = config.batch_size * config.seq_len_m * config.seq_len_r;
int n = config.q_dim; int n = config.q_dim;
...@@ -277,28 +265,24 @@ Tensor *ComputeOutputLinearForward(const framework::ExecutionContext &ctx, ...@@ -277,28 +265,24 @@ Tensor *ComputeOutputLinearForward(const framework::ExecutionContext &ctx,
AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true); AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true);
out_linear_compute.ComputeForward(out_linear_weight, fmha_or_gate_out, out_linear_compute.ComputeForward(out_linear_weight, fmha_or_gate_out,
out_linear_bias, out, out); out_linear_bias, out, out);
return out;
} }
template <typename T> template <typename T>
Tensor *ComputeOutputLinearBackward(const framework::ExecutionContext &ctx, void ComputeOutputLinearBackward(const framework::ExecutionContext &ctx,
const GateAttentionGradConfig<T> &config, const GateAttentionGradConfig<T> &config,
bool has_gating) { const Tensor *input, Tensor *input_grad) {
std::string input_name = has_gating ? "GateOut" : "FMHAOut";
const auto *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out")); const auto *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
const auto *out_linear_weight = ctx.Input<Tensor>("OutLinearWeight"); const auto *out_linear_weight = ctx.Input<Tensor>("OutLinearWeight");
const auto *input = ctx.Input<Tensor>(input_name);
auto *out_linear_weight_grad = auto *out_linear_weight_grad =
ctx.Output<Tensor>(framework::GradVarName("OutLinearWeight")); ctx.Output<Tensor>(framework::GradVarName("OutLinearWeight"));
auto *out_linear_bias_grad = auto *out_linear_bias_grad =
ctx.Output<Tensor>(framework::GradVarName("OutLinearBias")); ctx.Output<Tensor>(framework::GradVarName("OutLinearBias"));
auto *input_grad = ctx.Output<Tensor>(framework::GradVarName(input_name));
out_linear_weight_grad->mutable_data<T>(ctx.GetPlace()); out_linear_weight_grad->mutable_data<T>(ctx.GetPlace());
out_linear_bias_grad->mutable_data<T>(ctx.GetPlace()); out_linear_bias_grad->mutable_data<T>(ctx.GetPlace());
input_grad->mutable_data<T>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
int m = config.batch_size * config.seq_len_m * config.seq_len_r; int m = config.batch_size * config.seq_len_m * config.seq_len_r;
int n = config.q_dim; int n = config.q_dim;
...@@ -308,7 +292,6 @@ Tensor *ComputeOutputLinearBackward(const framework::ExecutionContext &ctx, ...@@ -308,7 +292,6 @@ Tensor *ComputeOutputLinearBackward(const framework::ExecutionContext &ctx,
out_linear_compute.ComputeBackward(input, out_linear_weight, out_grad, out_linear_compute.ComputeBackward(input, out_linear_weight, out_grad,
input_grad, out_linear_weight_grad, input_grad, out_linear_weight_grad,
out_linear_bias_grad); out_linear_bias_grad);
return input_grad;
} }
template <typename T> template <typename T>
...@@ -330,56 +313,64 @@ class FusedGateAttentionOpKernel : public framework::OpKernel<T> { ...@@ -330,56 +313,64 @@ class FusedGateAttentionOpKernel : public framework::OpKernel<T> {
auto *softmax_out = ctx.Output<Tensor>("SoftmaxOut"); auto *softmax_out = ctx.Output<Tensor>("SoftmaxOut");
auto *fmha_out = ctx.Output<Tensor>("FMHAOut"); auto *fmha_out = ctx.Output<Tensor>("FMHAOut");
auto *gate_out = ctx.Output<Tensor>("GateOut");
auto *out = ctx.Output<Tensor>("Out");
const bool merge_qkv = ctx.Attr<bool>("merge_qkv"); const bool merge_qkv = ctx.Attr<bool>("merge_qkv");
const bool has_gating = ctx.Attr<bool>("has_gating"); const bool has_gating = ctx.Attr<bool>("has_gating");
// When seq_len_r = m_size, q_dim = kv_dim, QKV matmul can be merged.
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>(); auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
GateAttentionConfig<T> config(query, key, query_weight, qkv_weight, AllocWithDebugInfo<T>(dev_ctx, "softmax_out", softmax_out);
merge_qkv); AllocWithDebugInfo<T>(dev_ctx, "fmha_out", fmha_out);
if (has_gating) {
AllocWithDebugInfo<T>(dev_ctx, "gate_out", gate_out);
}
AllocWithDebugInfo<T>(dev_ctx, "out", out);
// When seq_len_r = m_size, q_dim = kv_dim, QKV matmul can be merged.
GateAttentionConfig<T> config(dev_ctx, query, key, query_weight, qkv_weight,
merge_qkv, has_gating);
if (merge_qkv) { if (merge_qkv) {
PADDLE_ENFORCE_EQ(!key || query == key, true,
platform::errors::InvalidArgument(
"key is expected to be nullptr or the same as "
"query, but recieved key=%p, query=%p.",
key, query));
// 1. Merged QKV Matmul: einsum(nbhqk,nbkhc -> nbqhc) // 1. Merged QKV Matmul: einsum(nbhqk,nbkhc -> nbqhc)
Tensor *qkv_out = config.GetQKVOut(dev_ctx); Tensor *qkv_out = config.GetQKVOut();
ComputeMergedQKVMatmulForward<T>(ctx, config, query, qkv_out); ComputeMergedQKVMatmulForward<T>(ctx, config, query, qkv_out);
qkv_transpose_out->mutable_data<T>(ctx.GetPlace()); AllocWithDebugInfo<T>(dev_ctx, "qkv_transpose_out", qkv_transpose_out);
VLOG(4) << "qkv_transpose_out:" << MemoryDebugString(*qkv_transpose_out);
} else { } else {
// 1. Separated QKV Matmul // 1. Separated QKV Matmul
Tensor *query_out = config.GetQueryOut(dev_ctx); Tensor *query_out = config.GetQueryOut();
Tensor *key_out = config.GetKeyOut(dev_ctx); Tensor *key_out = config.GetKeyOut();
Tensor *value_out = config.GetValueOut(dev_ctx); Tensor *value_out = config.GetValueOut();
ComputeSeparatedQKVMatmulForward<T>(ctx, config, query, key, query_out, ComputeSeparatedQKVMatmulForward<T>(ctx, config, query, key, query_out,
key_out, value_out); key_out, value_out);
q_transpose_out->mutable_data<T>(ctx.GetPlace()); AllocWithDebugInfo<T>(dev_ctx, "q_transpose_out", q_transpose_out);
k_transpose_out->mutable_data<T>(ctx.GetPlace()); AllocWithDebugInfo<T>(dev_ctx, "k_transpose_out", k_transpose_out);
v_transpose_out->mutable_data<T>(ctx.GetPlace()); AllocWithDebugInfo<T>(dev_ctx, "v_transpose_out", v_transpose_out);
VLOG(4) << "q_transpose_out: " << MemoryDebugString(*q_transpose_out);
VLOG(4) << "k_transpose_out: " << MemoryDebugString(*k_transpose_out);
VLOG(4) << "v_transpose_out: " << MemoryDebugString(*v_transpose_out);
} }
softmax_out->mutable_data<T>(ctx.GetPlace());
fmha_out->mutable_data<T>(ctx.GetPlace());
VLOG(4) << "softmax_out: " << MemoryDebugString(*softmax_out);
VLOG(4) << "fmha_out: " << MemoryDebugString(*fmha_out);
// 2. FMHA // 2. FMHA
auto fmha_compute = FMHAGateRef<T>(dev_ctx, merge_qkv); auto fmha_compute = FMHAGateRef<T>(dev_ctx, merge_qkv);
fmha_compute.ComputeForward( fmha_compute.ComputeForward(nonbatched_bias, src_mask, q_transpose_out,
nonbatched_bias, src_mask, q_transpose_out, k_transpose_out, k_transpose_out, v_transpose_out,
v_transpose_out, qkv_transpose_out, softmax_out, fmha_out, &config); qkv_transpose_out, softmax_out, fmha_out,
gate_out, &config);
// 3. Gating Linear // 3. Gating Linear
Tensor *fmha_or_gate_out = !has_gating ? fmha_out if (has_gating) {
: ComputeGatingLinearForward<T>( ComputeGatingLinearForward<T>(ctx, config, query, fmha_out, gate_out);
ctx, config, query, fmha_out); }
// 4. Output Linear // 4. Output Linear
ComputeOutputLinearForward<T>(ctx, config, fmha_or_gate_out); Tensor *fmha_or_gate_out = has_gating ? gate_out : fmha_out;
ComputeOutputLinearForward<T>(ctx, config, fmha_or_gate_out, out);
} }
}; };
...@@ -387,9 +378,6 @@ template <typename T> ...@@ -387,9 +378,6 @@ template <typename T>
class FusedGateAttentionGradKernel : public framework::OpKernel<T> { class FusedGateAttentionGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
const auto has_gating = ctx.Attr<bool>("has_gating");
const auto merge_qkv = ctx.Attr<bool>("merge_qkv");
// forward input // forward input
const auto *query = ctx.Input<Tensor>("Query"); const auto *query = ctx.Input<Tensor>("Query");
const auto *key = ctx.Input<Tensor>("Key"); const auto *key = ctx.Input<Tensor>("Key");
...@@ -403,56 +391,68 @@ class FusedGateAttentionGradKernel : public framework::OpKernel<T> { ...@@ -403,56 +391,68 @@ class FusedGateAttentionGradKernel : public framework::OpKernel<T> {
const auto *qkv_transpose_out = ctx.Input<Tensor>("QKVTransposeOut"); const auto *qkv_transpose_out = ctx.Input<Tensor>("QKVTransposeOut");
const auto *softmax_out = ctx.Input<Tensor>("SoftmaxOut"); const auto *softmax_out = ctx.Input<Tensor>("SoftmaxOut");
const auto *fmha_out = ctx.Input<Tensor>("FMHAOut"); const auto *fmha_out = ctx.Input<Tensor>("FMHAOut");
const auto *gate_out = ctx.Input<Tensor>("GateOut");
// backward output // backward output
auto *query_grad = ctx.Output<Tensor>(framework::GradVarName("Query")); auto *query_grad = ctx.Output<Tensor>(framework::GradVarName("Query"));
query_grad->mutable_data<T>(ctx.GetPlace());
auto *nonbatched_bias_grad = auto *nonbatched_bias_grad =
ctx.Output<Tensor>(framework::GradVarName("NonbatchedBias")); ctx.Output<Tensor>(framework::GradVarName("NonbatchedBias"));
auto *fmha_out_grad = ctx.Output<Tensor>(framework::GradVarName("FMHAOut"));
bool has_gating = ctx.Attr<bool>("has_gating");
bool merge_qkv = ctx.Attr<bool>("merge_qkv");
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>(); auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
GateAttentionGradConfig<T> config(query, key, query_weight, qkv_weight, AllocWithDebugInfo<T>(dev_ctx, "query_grad", query_grad);
merge_qkv);
// 1. Gradient of Output Linear GateAttentionGradConfig<T> config(dev_ctx, query, key, query_weight,
Tensor *fhma_or_gate_out_grad = qkv_weight, merge_qkv, has_gating);
ComputeOutputLinearBackward<T>(ctx, config, has_gating);
// 2. Gradient of Gating Linear Tensor fmha_out_grad;
fmha_out_grad.Resize(config.gate_out_dims);
AllocWithDebugInfo<T>(dev_ctx, "fmha_out_grad", &fmha_out_grad);
if (has_gating) { if (has_gating) {
// fhma_or_gate_out_grad is actually gate_out_grad. // 1. Gradient of Output Linear: out = Linear(gate_out)
fmha_out_grad->mutable_data<T>(ctx.GetPlace()); Tensor gate_out_grad;
ComputeGatingLinearBackward<T>(ctx, config, fmha_out, gate_out_grad.Resize(config.gate_out_dims);
fhma_or_gate_out_grad, query_grad, AllocWithDebugInfo<T>(dev_ctx, "gate_out_grad", &gate_out_grad);
fmha_out_grad); ComputeOutputLinearBackward<T>(ctx, config, gate_out, &gate_out_grad);
// 2. Gradient of Gating Linear
// Forward: gate_out = Sigmoid(Linear(fmha_out)) * fmha_out
ComputeGatingLinearBackward<T>(ctx, config, query, fmha_out,
&gate_out_grad, query_grad,
&fmha_out_grad);
} else {
// 1. Gradient of Output Linear: out = Linear(fmha_out)
ComputeOutputLinearBackward<T>(ctx, config, fmha_out, &fmha_out_grad);
} }
// 3. Gradient of FMHA // 3. Gradient of FMHA
if (nonbatched_bias_grad) { if (nonbatched_bias_grad) {
nonbatched_bias_grad->mutable_data<T>(ctx.GetPlace()); AllocWithDebugInfo<T>(dev_ctx, "nonbatched_bias_grad",
nonbatched_bias_grad);
} }
auto fmha_compute = FMHAGateRef<T>(dev_ctx, merge_qkv); auto fmha_compute = FMHAGateRef<T>(dev_ctx, merge_qkv);
fmha_compute.ComputeBackward( fmha_compute.ComputeBackward(
q_transpose_out, k_transpose_out, v_transpose_out, qkv_transpose_out, q_transpose_out, k_transpose_out, v_transpose_out, qkv_transpose_out,
softmax_out, fmha_out_grad, nullptr, nonbatched_bias_grad, &config); softmax_out, &fmha_out_grad, nullptr, nonbatched_bias_grad, &config);
bool use_addto = has_gating ? true : false; bool use_addto = has_gating ? true : false;
if (merge_qkv) { if (merge_qkv) {
// 4. Gradient of Merged QKV Matmul // 4. Gradient of Merged QKV Matmul
Tensor *qkv_out_grad = config.GetQKVOutGrad(dev_ctx); Tensor *qkv_out_grad = config.GetQKVOutGrad();
ComputeMergedQKVMatmulBackward<T>(ctx, config, query, qkv_out_grad, ComputeMergedQKVMatmulBackward<T>(ctx, config, query, qkv_out_grad,
query_grad, use_addto); query_grad, use_addto);
} else { } else {
// 4. Gradient of Separated QKV Matmul // 4. Gradient of Separated QKV Matmul
auto *key_grad = ctx.Output<Tensor>(framework::GradVarName("Key")); auto *key_grad = ctx.Output<Tensor>(framework::GradVarName("Key"));
if (key_grad) { if (key_grad) {
key_grad->mutable_data<T>(ctx.GetPlace()); AllocWithDebugInfo<T>(dev_ctx, "key_grad", key_grad);
} }
Tensor *query_out_grad = config.GetQueryOutGrad(dev_ctx); Tensor *query_out_grad = config.GetQueryOutGrad();
Tensor *key_out_grad = config.GetKeyOutGrad(dev_ctx); Tensor *key_out_grad = config.GetKeyOutGrad();
Tensor *value_out_grad = config.GetValueOutGrad(dev_ctx); Tensor *value_out_grad = config.GetValueOutGrad();
ComputeSeparatedQKVMatmulBackward<T>( ComputeSeparatedQKVMatmulBackward<T>(
ctx, config, query, key, query_out_grad, key_out_grad, value_out_grad, ctx, config, query, key, query_out_grad, key_out_grad, value_out_grad,
query_grad, key_grad, use_addto); query_grad, key_grad, use_addto);
......
...@@ -18,7 +18,7 @@ import paddle ...@@ -18,7 +18,7 @@ import paddle
import paddle.nn as nn import paddle.nn as nn
from paddle import tensor from paddle import tensor
import unittest import unittest
from op_test import OpTest, convert_float_to_uint16 from op_test import OpTest, convert_float_to_uint16, convert_uint16_to_float
from test_sparse_attention_op import get_cuda_version from test_sparse_attention_op import get_cuda_version
from paddle import _C_ops from paddle import _C_ops
from paddle.fluid.framework import default_main_program, _enable_legacy_dygraph from paddle.fluid.framework import default_main_program, _enable_legacy_dygraph
...@@ -194,16 +194,22 @@ class TestFusedGateAttentionOp(OpTest): ...@@ -194,16 +194,22 @@ class TestFusedGateAttentionOp(OpTest):
return out, query.grad, None return out, query.grad, None
def check_output_and_grad(self, atol, rtol): def check_output_and_grad(self, atol, rtol):
out_ref, query_grad_ref, key_grad_ref = self.get_reference_out()
out, query_grad, key_grad = self.get_fused_gate_attention_out() def _convert(value):
np.testing.assert_allclose(out_ref, out.numpy(), atol=atol, rtol=rtol) if self.dtype == "bfloat16":
np.testing.assert_allclose(query_grad_ref, return convert_uint16_to_float(value)
query_grad.numpy(), return value
atol=atol,
rtol=rtol) output_names = ["out", "query_grad", "key_grad"]
if key_grad_ref is not None and key_grad is not None: outputs_ref = self.get_reference_out()
np.testing.assert_allclose(key_grad_ref, outputs_fused = self.get_fused_gate_attention_out()
key_grad.numpy(), for i in range(len(outputs_fused)):
ref_res = outputs_ref[i]
fused_res = outputs_fused[i]
if ref_res is not None and fused_res is not None:
print("Checking {}".format(output_names[i]))
np.testing.assert_allclose(_convert(ref_res),
_convert(fused_res.numpy()),
atol=atol, atol=atol,
rtol=rtol) rtol=rtol)
...@@ -211,6 +217,13 @@ class TestFusedGateAttentionOp(OpTest): ...@@ -211,6 +217,13 @@ class TestFusedGateAttentionOp(OpTest):
self.check_output_and_grad(atol=1e-5, rtol=1e-5) self.check_output_and_grad(atol=1e-5, rtol=1e-5)
class TestMergeQKVLargeBatchSizeCase(TestFusedGateAttentionOp):
def config(self):
super().config()
self.batch_size = 2
class TestSeparatedQKVCase(TestFusedGateAttentionOp): class TestSeparatedQKVCase(TestFusedGateAttentionOp):
def config(self): def config(self):
...@@ -243,9 +256,18 @@ class TestMergeQKVFp16Case(TestFusedGateAttentionOp): ...@@ -243,9 +256,18 @@ class TestMergeQKVFp16Case(TestFusedGateAttentionOp):
self.dtype = "float16" self.dtype = "float16"
def test_output_and_grad(self): def test_output_and_grad(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_and_grad(atol=1e-1, rtol=1e-5) self.check_output_and_grad(atol=1e-1, rtol=1e-5)
class TestMergeQKVLargeBatchSizeFp16Case(TestMergeQKVFp16Case):
def config(self):
super().config()
self.batch_size = 2
@unittest.skipIf( @unittest.skipIf(
not core.is_compiled_with_cuda() or get_cuda_version() < 11000, not core.is_compiled_with_cuda() or get_cuda_version() < 11000,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.3" "core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
...@@ -260,5 +282,12 @@ class TestMergeQKVBF16Case(TestFusedGateAttentionOp): ...@@ -260,5 +282,12 @@ class TestMergeQKVBF16Case(TestFusedGateAttentionOp):
self.check_output_and_grad(atol=1e-1, rtol=1e-3) self.check_output_and_grad(atol=1e-1, rtol=1e-3)
class TestMergeQKVLargeBatchSizeBF16Case(TestMergeQKVBF16Case):
def config(self):
super().config()
self.batch_size = 2
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册