Unverified  Commit 10f8637c authored by Yiqun Liu, committed by GitHub

Fix wrong reduce_dims in fused_gate_attention and optimize the memory usage. (#43216)

* Polish codes and memory usage for fused_gate_attention.

* Fix wrong reduce_dims in fused_gate_attention when computing gradient of nonbatched_bias.
Parent 5a6c974e
......@@ -14,11 +14,11 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/transpose_op.cu.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
namespace paddle {
......@@ -27,19 +27,29 @@ namespace operators {
using Tensor = framework::Tensor;
inline std::string MemoryDebugString(const Tensor& t) {
int device_id = platform::GetCurrentDeviceId();
int64_t allocated =
memory::DeviceMemoryStatCurrentValue("Allocated", device_id);
int64_t reserved =
memory::DeviceMemoryStatCurrentValue("Reserved", device_id);
std::stringstream ss;
ss << "shape=[" << t.dims()
<< "], size=" << static_cast<float>(t.memory_size()) / (1 << 20)
<< " MB, ptr=" << t.data();
size_t total = 0;
size_t available = 0;
platform::GpuMemoryUsage(&available, &total);
ss << "; memory allocated="
<< static_cast<float>(total - available) / (1 << 20) << " MB";
<< " MB, ptr=" << t.data()
<< "; [MEMORY] allocated=" << static_cast<float>(allocated) / (1 << 20)
<< " MB"
<< ", reserved=" << static_cast<float>(reserved) / (1 << 20) << " MB";
return ss.str();
}
template <typename T>
void AllocWithDebugInfo(const platform::CUDADeviceContext& dev_ctx,
const std::string& info, Tensor* t) {
t->mutable_data<T>(dev_ctx.GetPlace());
VLOG(4) << info << ": " << MemoryDebugString(*t);
}
template <typename T>
struct TernaryAddFunctor {
inline HOSTDEVICE T operator()(T a, T b, T c) const { return a + b + c; }
......@@ -48,6 +58,11 @@ struct TernaryAddFunctor {
template <typename T>
struct GateAttentionConfig {
public:
const platform::CUDADeviceContext& dev_ctx;
bool merge_qkv;
bool has_gating;
int64_t batch_size;
int64_t seq_len_m;
int64_t seq_len_r;
......@@ -70,9 +85,11 @@ struct GateAttentionConfig {
phi::DDim qktv_out_dims;
phi::DDim gate_out_dims;
GateAttentionConfig(const Tensor* query, const Tensor* key,
GateAttentionConfig(const platform::CUDADeviceContext& dev_ctx,
const Tensor* query, const Tensor* key,
const Tensor* query_weight, const Tensor* qkv_weight,
bool merge_qkv) {
bool merge_qkv, bool has_gating)
: dev_ctx(dev_ctx), merge_qkv(merge_qkv), has_gating(has_gating) {
// query: shape=[batch_size, seq_len_m, seq_len_r, q_dim]
batch_size = query->dims()[0];
seq_len_m = query->dims()[1];
......@@ -131,59 +148,68 @@ struct GateAttentionConfig {
return batch_size * seq_len_m * seq_len_r * num_heads * key_dim;
}
Tensor* GetQKVOut(const platform::CUDADeviceContext& dev_ctx) {
Tensor* GetQKVOut() {
if (!qkv_out.IsInitialized()) {
qkv_out.Resize(qkv_out_dims);
qkv_out.mutable_data<T>(dev_ctx.GetPlace());
VLOG(4) << "qkv_out: " << MemoryDebugString(qkv_out);
AllocWithDebugInfo<T>(dev_ctx, "qkv_out", &qkv_out);
}
return &qkv_out;
}
Tensor* GetQueryOut(const platform::CUDADeviceContext& dev_ctx) {
Tensor* GetQueryOut() {
if (!query_out.IsInitialized()) {
query_out.Resize(q_out_dims);
query_out.mutable_data<T>(dev_ctx.GetPlace());
VLOG(4) << "query_out: " << MemoryDebugString(query_out);
AllocWithDebugInfo<T>(dev_ctx, "query_out", &query_out);
}
return &query_out;
}
Tensor* GetKeyOut(const platform::CUDADeviceContext& dev_ctx) {
Tensor* GetKeyOut() {
if (!key_out.IsInitialized()) {
key_out.Resize(kv_out_dims);
key_out.mutable_data<T>(dev_ctx.GetPlace());
VLOG(4) << "key_out: " << MemoryDebugString(key_out);
AllocWithDebugInfo<T>(dev_ctx, "key_out", &key_out);
}
return &key_out;
}
Tensor* GetValueOut(const platform::CUDADeviceContext& dev_ctx) {
Tensor* GetValueOut() {
if (!value_out.IsInitialized()) {
value_out.Resize(kv_out_dims);
value_out.mutable_data<T>(dev_ctx.GetPlace());
VLOG(4) << "value_out: " << MemoryDebugString(value_out);
AllocWithDebugInfo<T>(dev_ctx, "value_out", &value_out);
}
return &value_out;
}
Tensor* GetQKOut(const platform::CUDADeviceContext& dev_ctx,
Tensor* softmax_out) {
Tensor* GetQKOut(Tensor* softmax_out) {
// softmax_dim = qk_out_dim[-1] = qk_out_dim[rank - 1]
int softmax_dim = m_size;
if (!softmax_out || phi::UseCudnnSoftmax<T>(dev_ctx, softmax_dim, true)) {
// Not sure whether cudnn softmax can execute inplace.
if (!qk_out.IsInitialized()) {
qk_out.Resize(qk_out_dims);
qk_out.mutable_data<T>(dev_ctx.GetPlace());
VLOG(4) << "qk_out: " << MemoryDebugString(qk_out);
AllocWithDebugInfo<T>(dev_ctx, "qk_out", &qk_out);
}
return &qk_out;
} else {
// Enable inplace softmax.
return softmax_out;
}
}
Tensor* GetQKTVOut(Tensor* gate_out) {
if (has_gating && gate_out) {
// Reuse gate_out.
gate_out->Resize(qktv_out_dims);
return gate_out;
} else {
if (!qktv_out.IsInitialized()) {
qktv_out.Resize(qktv_out_dims);
AllocWithDebugInfo<T>(dev_ctx, "qktv_out", &qktv_out);
}
return &qktv_out;
}
}
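The reuse in GetQKTVOut is only safe because gate_out and qktv_out hold the same number of elements. A minimal NumPy sketch of that invariant, assuming the layouts suggested by the shape comments in this file (qktv_out: [batch_size, seq_len_m, num_heads, seq_len_r, key_dim]; gate_out/fmha_out: [batch_size, seq_len_m, seq_len_r, num_heads, key_dim]); the concrete sizes are illustrative only.

    import numpy as np

    # Illustrative sizes only.
    batch_size, seq_len_m, seq_len_r, num_heads, key_dim = 1, 2, 3, 4, 8

    qktv_out_dims = (batch_size, seq_len_m, num_heads, seq_len_r, key_dim)
    gate_out_dims = (batch_size, seq_len_m, seq_len_r, num_heads, key_dim)

    # Equal element counts make it legal to Resize() gate_out's buffer to
    # qktv_out_dims, use it as scratch, and Resize() it back afterwards.
    assert np.prod(qktv_out_dims) == np.prod(gate_out_dims)

    buf = np.empty(gate_out_dims, dtype=np.float32)   # stands in for gate_out's storage
    qktv_view = buf.reshape(qktv_out_dims)            # "Resize" without a new allocation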
void ClearQKVOut() {
if (qkv_out.IsInitialized()) {
qkv_out.clear();
......@@ -196,9 +222,14 @@ struct GateAttentionConfig {
}
}
void ClearQKTVOut() {
if (qktv_out.IsInitialized()) {
qktv_out.clear();
}
}
protected:
Tensor qkv_out;
// QKV is not merged
Tensor query_out;
Tensor key_out;
Tensor value_out;
......@@ -207,63 +238,60 @@ struct GateAttentionConfig {
// softmax_out = softmax(qk_out + nonbatched_bias + src_mask)
// The shape of qk_out, softmax_out is the same, thus can be called inplace.
Tensor qk_out;
// qktv_out may reuse gate_out.
Tensor qktv_out;
};
template <typename T>
struct GateAttentionGradConfig : public GateAttentionConfig<T> {
public:
GateAttentionGradConfig(const Tensor* query, const Tensor* key,
GateAttentionGradConfig(const platform::CUDADeviceContext& dev_ctx,
const Tensor* query, const Tensor* key,
const Tensor* query_weight, const Tensor* qkv_weight,
bool merge_qkv)
: GateAttentionConfig<T>(query, key, query_weight, qkv_weight,
merge_qkv) {}
bool merge_qkv, bool has_gating)
: GateAttentionConfig<T>(dev_ctx, query, key, query_weight, qkv_weight,
merge_qkv, has_gating) {}
Tensor* GetQKVOutGrad(const platform::CUDADeviceContext& dev_ctx) {
Tensor* GetQKVOutGrad() {
if (!qkv_out_grad.IsInitialized()) {
qkv_out_grad.Resize(this->qkv_out_dims);
qkv_out_grad.mutable_data<T>(dev_ctx.GetPlace());
VLOG(4) << "qkv_out_grad: " << MemoryDebugString(qkv_out_grad);
AllocWithDebugInfo<T>(this->dev_ctx, "qkv_out_grad", &qkv_out_grad);
}
return &qkv_out_grad;
}
Tensor* GetQueryOutGrad(const platform::CUDADeviceContext& dev_ctx) {
Tensor* GetQueryOutGrad() {
if (!query_out_grad.IsInitialized()) {
query_out_grad.Resize(this->q_out_dims);
query_out_grad.mutable_data<T>(dev_ctx.GetPlace());
VLOG(4) << "query_out_grad: " << MemoryDebugString(query_out_grad);
AllocWithDebugInfo<T>(this->dev_ctx, "query_out_grad", &query_out_grad);
}
return &query_out_grad;
}
Tensor* GetKeyOutGrad(const platform::CUDADeviceContext& dev_ctx) {
Tensor* GetKeyOutGrad() {
if (!key_out_grad.IsInitialized()) {
key_out_grad.Resize(this->kv_out_dims);
key_out_grad.mutable_data<T>(dev_ctx.GetPlace());
VLOG(4) << "key_out_grad: " << MemoryDebugString(key_out_grad);
AllocWithDebugInfo<T>(this->dev_ctx, "key_out_grad", &key_out_grad);
}
return &key_out_grad;
}
Tensor* GetValueOutGrad(const platform::CUDADeviceContext& dev_ctx) {
Tensor* GetValueOutGrad() {
if (!value_out_grad.IsInitialized()) {
value_out_grad.Resize(this->kv_out_dims);
value_out_grad.mutable_data<T>(dev_ctx.GetPlace());
VLOG(4) << "value_out_grad: " << MemoryDebugString(value_out_grad);
AllocWithDebugInfo<T>(this->dev_ctx, "value_out_grad", &value_out_grad);
}
return &value_out_grad;
}
Tensor* GetQKOutGrad(const platform::CUDADeviceContext& dev_ctx,
Tensor* softmax_out_grad) {
Tensor* GetQKOutGrad(Tensor* softmax_out_grad) {
// softmax_dim = qk_out_dim[-1] = qk_out_dim[rank - 1]
int softmax_dim = this->m_size;
if (!softmax_out_grad ||
phi::UseCudnnSoftmax<T>(dev_ctx, softmax_dim, true)) {
phi::UseCudnnSoftmax<T>(this->dev_ctx, softmax_dim, true)) {
if (!qk_out_grad.IsInitialized()) {
qk_out_grad.Resize(this->qk_out_dims);
qk_out_grad.mutable_data<T>(dev_ctx.GetPlace());
VLOG(4) << "qk_out_grad: " << MemoryDebugString(qk_out_grad);
AllocWithDebugInfo<T>(this->dev_ctx, "qk_out_grad", &qk_out_grad);
}
return &qk_out_grad;
} else {
......@@ -288,7 +316,7 @@ class FMHAGateRef {
void ComputeForward(const Tensor* nonbatched_bias, const Tensor* src_mask,
Tensor* q_transpose_out, Tensor* k_transpose_out,
Tensor* v_transpose_out, Tensor* qkv_transpose_out,
Tensor* softmax_out, Tensor* fmha_out,
Tensor* softmax_out, Tensor* fmha_out, Tensor* gate_out,
GateAttentionConfig<T>* config) {
T* q_ptr = nullptr;
T* k_ptr = nullptr;
......@@ -300,7 +328,7 @@ class FMHAGateRef {
platform::errors::NotFound("The input qkv_transpose_out can not be "
"nullptr when merge_qkv is true."));
Tensor* qkv_out = config->GetQKVOut(dev_ctx_);
Tensor* qkv_out = config->GetQKVOut();
ComputeQKVTransposeForward(*qkv_out, qkv_transpose_out);
config->ClearQKVOut();
......@@ -323,9 +351,9 @@ class FMHAGateRef {
platform::errors::NotFound("The input v_transpose_out can not be "
"nullptr when merge_qkv is false."));
Tensor* query_out = config->GetQueryOut(dev_ctx_);
Tensor* key_out = config->GetKeyOut(dev_ctx_);
Tensor* value_out = config->GetValueOut(dev_ctx_);
Tensor* query_out = config->GetQueryOut();
Tensor* key_out = config->GetKeyOut();
Tensor* value_out = config->GetValueOut();
ComputeQKVTransposeForward(*query_out, *key_out, *value_out,
q_transpose_out, k_transpose_out,
v_transpose_out);
......@@ -340,7 +368,7 @@ class FMHAGateRef {
// [batch_size, seq_len_m, num_heads, seq_len_r, key_dim] *
// [batch_size, seq_len_m, num_heads, m_size, key_dim]
// -> [batch_size, seq_len_m, num_heads, seq_len_r, m_size]
Tensor* qk_out = config->GetQKOut(dev_ctx_, softmax_out);
Tensor* qk_out = config->GetQKOut(softmax_out);
T* qk_out_ptr = qk_out->data<T>();
int64_t gemm_batch_size =
......@@ -362,9 +390,8 @@ class FMHAGateRef {
// [batch_size, seq_len_m, num_heads, seq_len_r, m_size] *
// [batch_size, seq_len_m, num_heads, m_size, key_dim]
// -> [batch_size, seq_len_m, num_heads, seq_len_r, key_dim]
Tensor qktv_out;
qktv_out.Resize(config->qktv_out_dims);
T* qktv_out_ptr = qktv_out.mutable_data<T>(dev_ctx_.GetPlace());
Tensor* qktv_out = config->GetQKTVOut(gate_out);
T* qktv_out_ptr = qktv_out->data<T>();
gemm_m = config->seq_len_r;
gemm_n = config->key_dim;
......@@ -375,7 +402,11 @@ class FMHAGateRef {
gemm_m, gemm_n, gemm_k, gemm_batch_size);
// fmha_out = transpose(qktv_out)
ComputeQKTVTransposeForward(qktv_out, fmha_out);
ComputeQKTVTransposeForward(*qktv_out, fmha_out);
config->ClearQKTVOut();
if (config->has_gating) {
gate_out->Resize(config->gate_out_dims);
}
}
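For reference, a NumPy sketch of the shape flow that ComputeForward realizes with batched GEMMs, following the shape comments above; the 1/sqrt(key_dim) scaling and the final transpose layout are assumptions (suggested by the alpha GEMM argument and the fmha_out usage), not read directly from this hunk.

    import numpy as np

    bs, sm, nh, sr, ms, kd = 1, 2, 4, 5, 5, 8   # batch, seq_len_m, heads, seq_len_r, m_size, key_dim
    q = np.random.rand(bs, sm, nh, sr, kd).astype(np.float32)
    k = np.random.rand(bs, sm, nh, ms, kd).astype(np.float32)
    v = np.random.rand(bs, sm, nh, ms, kd).astype(np.float32)

    # qk_out = alpha * Q @ K^T: [bs, sm, nh, sr, kd] x [bs, sm, nh, ms, kd] -> [bs, sm, nh, sr, ms]
    alpha = 1.0 / np.sqrt(kd)                   # assumed scaling
    qk_out = alpha * np.einsum("bmhrc,bmhsc->bmhrs", q, k)

    # softmax over the last (m_size) axis; may run in place of qk_out
    softmax_out = np.exp(qk_out - qk_out.max(-1, keepdims=True))
    softmax_out /= softmax_out.sum(-1, keepdims=True)

    # qktv_out = softmax_out @ V -> [bs, sm, nh, sr, kd]
    qktv_out = np.einsum("bmhrs,bmhsc->bmhrc", softmax_out, v)

    # fmha_out = transpose(qktv_out): move heads after seq_len_r -> [bs, sm, sr, nh, kd]
    fmha_out = qktv_out.transpose(0, 1, 3, 2, 4)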
void ComputeBackward(const Tensor* q_transpose_out,
......@@ -409,8 +440,10 @@ class FMHAGateRef {
v_ptr = k_ptr + q_size;
qkv_transpose_out_grad.Resize(config->qkv_transpose_out_dims);
AllocWithDebugInfo<T>(dev_ctx_, "qkv_transpose_out_grad",
&qkv_transpose_out_grad);
q_grad_ptr = qkv_transpose_out_grad.mutable_data<T>(dev_ctx_.GetPlace());
q_grad_ptr = qkv_transpose_out_grad.data<T>();
k_grad_ptr = q_grad_ptr + q_size;
v_grad_ptr = k_grad_ptr + q_size;
} else {
......@@ -442,7 +475,7 @@ class FMHAGateRef {
Tensor softmax_out_grad;
softmax_out_grad.Resize(config->softmax_out_dims);
softmax_out_grad.mutable_data<T>(dev_ctx_.GetPlace());
AllocWithDebugInfo<T>(dev_ctx_, "softmax_out_grad", &softmax_out_grad);
int64_t gemm_batch_size =
config->batch_size * config->seq_len_m * config->num_heads;
......@@ -450,7 +483,7 @@ class FMHAGateRef {
// Forward: fmha_out = transpose(qktv_out)
Tensor qktv_out_grad;
qktv_out_grad.Resize(config->qktv_out_dims);
T* qktv_out_grad_ptr = qktv_out_grad.mutable_data<T>(dev_ctx_.GetPlace());
AllocWithDebugInfo<T>(dev_ctx_, "qktv_out_grad", &qktv_out_grad);
ComputeQKTVTransposeBackward(*fmha_out_grad, &qktv_out_grad);
// Forward: qktv_out = BatchedGEMM(softmax_out, V)
......@@ -461,6 +494,7 @@ class FMHAGateRef {
int64_t gemm_k = config->seq_len_r;
const T* softmax_out_ptr = softmax_out->data<T>();
const T* qktv_out_grad_ptr = qktv_out_grad.data<T>();
ComputeBatchedGEMM(softmax_out_ptr, qktv_out_grad_ptr, v_grad_ptr, true,
false, gemm_m, gemm_n, gemm_k, gemm_batch_size);
......@@ -474,7 +508,7 @@ class FMHAGateRef {
true, gemm_m, gemm_n, gemm_k, gemm_batch_size);
}
Tensor* qk_out_grad = config->GetQKOutGrad(dev_ctx_, &softmax_out_grad);
Tensor* qk_out_grad = config->GetQKOutGrad(&softmax_out_grad);
ComputeBiasMaskSoftmaxBackward(&softmax_out_grad, softmax_out,
src_mask_grad, qk_out_grad,
nonbatched_bias_grad);
......@@ -498,12 +532,12 @@ class FMHAGateRef {
gemm_n, gemm_k, gemm_batch_size, alpha);
if (merge_qkv_) {
Tensor* qkv_out_grad = config->GetQKVOutGrad(dev_ctx_);
Tensor* qkv_out_grad = config->GetQKVOutGrad();
ComputeQKVTransposeBackward(qkv_transpose_out_grad, qkv_out_grad);
} else {
Tensor* q_out_grad = config->GetQueryOutGrad(dev_ctx_);
Tensor* k_out_grad = config->GetKeyOutGrad(dev_ctx_);
Tensor* v_out_grad = config->GetValueOutGrad(dev_ctx_);
Tensor* q_out_grad = config->GetQueryOutGrad();
Tensor* k_out_grad = config->GetKeyOutGrad();
Tensor* v_out_grad = config->GetValueOutGrad();
ComputeQKVTransposeBackward(q_transpose_out_grad, k_transpose_out_grad,
v_transpose_out_grad, q_out_grad, k_out_grad,
v_out_grad);
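The batched GEMMs in this backward fragment recover the Q and K gradients from qk_out_grad. A NumPy sketch of that math, with the 1/sqrt(key_dim) alpha assumed from the scaled GEMM call above:

    import numpy as np

    bs, sm, nh, sr, ms, kd = 1, 2, 4, 5, 5, 8
    q = np.random.rand(bs, sm, nh, sr, kd).astype(np.float32)
    k = np.random.rand(bs, sm, nh, ms, kd).astype(np.float32)
    qk_out_grad = np.random.rand(bs, sm, nh, sr, ms).astype(np.float32)
    alpha = 1.0 / np.sqrt(kd)

    # Forward: qk_out = alpha * einsum("bmhrc,bmhsc->bmhrs", q, k)
    q_grad = alpha * np.einsum("bmhrs,bmhsc->bmhrc", qk_out_grad, k)
    k_grad = alpha * np.einsum("bmhrs,bmhrc->bmhsc", qk_out_grad, q)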
......@@ -578,12 +612,12 @@ class FMHAGateRef {
if (nonbatched_bias) {
std::vector<const Tensor*> ins = {qk_out, nonbatched_bias, src_mask};
std::vector<Tensor*> outs = {qk_out};
phi::funcs::BroadcastKernel<ElementwiseType::kTernary, T, T>(
phi::funcs::BroadcastKernel<phi::ElementwiseType::kTernary, T, T>(
dev_ctx_, ins, &outs, -1, TernaryAddFunctor<T>());
} else {
std::vector<const Tensor*> ins = {qk_out, src_mask};
std::vector<Tensor*> outs = {qk_out};
phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
}
phi::SoftmaxForwardCUDAKernelDriver<T>(dev_ctx_, *qk_out, -1, softmax_out);
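A NumPy sketch of what the BroadcastKernel calls above compute before the softmax. The qk_out and bias shapes follow the comments in this file; the [batch_size, seq_len_m, 1, 1, m_size] mask layout is an assumption for illustration.

    import numpy as np

    bs, sm, nh, sr, ms = 1, 2, 4, 5, 5
    qk_out = np.random.rand(bs, sm, nh, sr, ms).astype(np.float32)
    nonbatched_bias = np.random.rand(bs, 1, nh, sr, ms).astype(np.float32)  # broadcast over seq_len_m
    src_mask = np.random.rand(bs, sm, 1, 1, ms).astype(np.float32)          # assumed layout

    # TernaryAddFunctor: qk_out = qk_out + nonbatched_bias + src_mask (with broadcasting)
    qk_out = qk_out + nonbatched_bias + src_mask

    # SoftmaxForwardCUDAKernelDriver: softmax along the last axis
    softmax_out = np.exp(qk_out - qk_out.max(-1, keepdims=True))
    softmax_out /= softmax_out.sum(-1, keepdims=True)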
......@@ -614,12 +648,12 @@ class FMHAGateRef {
phi::SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx_, *softmax_out,
*softmax_out_grad, -1, qk_out_grad);
// [1, bs, num_head, seq_l, seq_l] -> [bs, num_head, seq_l, seq_l]
if (nonbatched_bias_grad) {
gpuStream_t stream = dev_ctx_.stream();
TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
// [batch_size, seq_len_m, num_heads, seq_len_r, m_size] ->
// [batch_size, 1, num_heads, seq_len_r, m_size]
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx_, *qk_out_grad, nonbatched_bias_grad,
kps::IdentityFunctor<T>(), {0, 1}, stream);
kps::IdentityFunctor<T>(), {1});
}
}
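The fix replaces reduce dims {0, 1} with {1}: the forward pass broadcasts nonbatched_bias, of shape [batch_size, 1, num_heads, seq_len_r, m_size], over seq_len_m only, so its gradient must sum qk_out_grad over that single axis; reducing over {0, 1} would also collapse the batch dimension. A minimal NumPy check of the intended reduction (sizes are illustrative):

    import numpy as np

    bs, sm, nh, sr, ms = 2, 3, 4, 5, 5
    qk_out_grad = np.random.rand(bs, sm, nh, sr, ms).astype(np.float32)

    nonbatched_bias_grad = qk_out_grad.sum(axis=1, keepdims=True)   # reduce_dims = {1}
    assert nonbatched_bias_grad.shape == (bs, 1, nh, sr, ms)

    wrong = qk_out_grad.sum(axis=(0, 1), keepdims=True)             # old reduce_dims = {0, 1}
    assert wrong.shape == (1, 1, nh, sr, ms)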
......
......@@ -214,7 +214,7 @@ class FusedGateAttentionGradOp : public framework::OperatorWithKernel {
"fused_aate_attention_arad");
if (ctx->Attrs().Get<bool>("has_gating")) {
for (auto& name : {"GateWeight", "GateBias", "GateOut"}) {
for (auto& name : {"GateWeight", "GateBias"}) {
ctx->SetOutputDim(framework::GradVarName(name), ctx->GetInputDim(name));
}
}
......@@ -224,9 +224,6 @@ class FusedGateAttentionGradOp : public framework::OperatorWithKernel {
ctx->GetInputDim("NonbatchedBias"));
}
ctx->SetOutputDim(framework::GradVarName("FMHAOut"),
ctx->GetInputDim("FMHAOut"));
ctx->SetOutputDim(framework::GradVarName("OutLinearWeight"),
ctx->GetInputDim("OutLinearWeight"));
ctx->SetOutputDim(framework::GradVarName("OutLinearBias"),
......@@ -270,8 +267,6 @@ class FusedGateAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
}
op->SetInput("FMHAOut", this->Output("FMHAOut"));
op->SetOutput(framework::GradVarName("FMHAOut"),
this->OutputGrad("FMHAOut"));
if (this->HasInput("NonbatchedBias")) {
op->SetInput("NonbatchedBias", this->Input("NonbatchedBias"));
......@@ -292,8 +287,6 @@ class FusedGateAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
this->InputGrad("GateBias"));
op->SetInput("GateOut", this->Output("GateOut"));
op->SetOutput(framework::GradVarName("GateOut"),
this->OutputGrad("GateOut"));
}
op->SetInput("OutLinearWeight", this->Input("OutLinearWeight"));
......
......@@ -79,7 +79,7 @@ void ComputeMergedQKVMatmulForward(const framework::ExecutionContext &ctx,
}
template <typename T>
Tensor *ComputeMergedQKVMatmulBackward(const framework::ExecutionContext &ctx,
void ComputeMergedQKVMatmulBackward(const framework::ExecutionContext &ctx,
const GateAttentionGradConfig<T> &config,
const Tensor *query,
const Tensor *qkv_out_grad,
......@@ -97,7 +97,6 @@ Tensor *ComputeMergedQKVMatmulBackward(const framework::ExecutionContext &ctx,
AttnMatMul<T>(ctx.cuda_device_context(), false, true, m, n, k, false);
qkv_compute.ComputeBackward(query, qkv_weight, qkv_out_grad, query_grad,
qkv_weight_grad, nullptr, use_addto);
return query_grad;
}
template <typename T>
......@@ -137,11 +136,13 @@ void ComputeSeparatedQKVMatmulForward(const framework::ExecutionContext &ctx,
}
template <typename T>
Tensor *ComputeSeparatedQKVMatmulBackward(
const framework::ExecutionContext &ctx,
const GateAttentionGradConfig<T> &config, const Tensor *query,
const Tensor *key, const Tensor *query_out_grad, const Tensor *key_out_grad,
const Tensor *value_out_grad, Tensor *query_grad, Tensor *key_grad,
void ComputeSeparatedQKVMatmulBackward(const framework::ExecutionContext &ctx,
const GateAttentionGradConfig<T> &config,
const Tensor *query, const Tensor *key,
const Tensor *query_out_grad,
const Tensor *key_out_grad,
const Tensor *value_out_grad,
Tensor *query_grad, Tensor *key_grad,
bool use_addto) {
// Gradient of GEMM(key, k_weight)
const auto *key_weight = ctx.Input<Tensor>("KeyWeight");
......@@ -179,22 +180,16 @@ Tensor *ComputeSeparatedQKVMatmulBackward(
q_n, q_k, false);
q_compute.ComputeBackward(query, query_weight, query_out_grad, query_grad,
query_weight_grad, nullptr, use_addto);
return query_grad;
}
template <typename T>
Tensor *ComputeGatingLinearForward(const framework::ExecutionContext &ctx,
void ComputeGatingLinearForward(const framework::ExecutionContext &ctx,
const GateAttentionConfig<T> &config,
const Tensor *query,
const Tensor *fmha_out) {
const Tensor *query, const Tensor *fmha_out,
Tensor *gate_out) {
auto *gate_weight = ctx.Input<Tensor>("GateWeight");
auto *gate_bias = ctx.Input<Tensor>("GateBias");
auto *gate_out = ctx.Output<Tensor>("GateOut");
gate_out->mutable_data<T>(ctx.GetPlace());
VLOG(4) << "[ComputeGatingLinearForward] gate_out: "
<< MemoryDebugString(*gate_out);
// The first gate_bias_out stores the result of the multiplication,
// and the second gate_bias_out stores the result of the multiplication +
// bias.
......@@ -212,16 +207,14 @@ Tensor *ComputeGatingLinearForward(const framework::ExecutionContext &ctx,
std::vector<Tensor *> outs = {gate_out};
phi::funcs::ElementwiseKernel<T>(ctx.cuda_device_context(), ins, &outs,
SigmoidMultiplyFunctor<T>());
return gate_out;
}
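A NumPy sketch of what this function computes via AttnMatMul and SigmoidMultiplyFunctor: gate_out = sigmoid(query x gate_weight + gate_bias) * fmha_out. The [q_dim, num_heads, key_dim] weight layout is an assumption for illustration.

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    bs, sm, sr, qd, nh, kd = 1, 2, 3, 16, 4, 8
    query    = np.random.rand(bs, sm, sr, qd).astype(np.float32)
    fmha_out = np.random.rand(bs, sm, sr, nh, kd).astype(np.float32)
    gate_weight = np.random.rand(qd, nh, kd).astype(np.float32)   # assumed layout
    gate_bias   = np.random.rand(nh, kd).astype(np.float32)

    # First pass: gate_bias_out = query x gate_weight, then + gate_bias (see the comment above)
    gate_bias_out = np.einsum("bmrq,qhc->bmrhc", query, gate_weight) + gate_bias

    # SigmoidMultiplyFunctor: gate_out = sigmoid(gate_bias_out) * fmha_out
    gate_out = sigmoid(gate_bias_out) * fmha_out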
template <typename T>
Tensor *ComputeGatingLinearBackward(const framework::ExecutionContext &ctx,
void ComputeGatingLinearBackward(const framework::ExecutionContext &ctx,
const GateAttentionGradConfig<T> &config,
const Tensor *fmha_out,
const Tensor *query, const Tensor *fmha_out,
const Tensor *gate_out_grad,
Tensor *query_grad, Tensor *fmha_out_grad) {
const auto *query = ctx.Input<Tensor>("Query");
const auto *gate_weight = ctx.Input<Tensor>("GateWeight");
const auto *gate_bias = ctx.Input<Tensor>("GateBias");
......@@ -255,20 +248,15 @@ Tensor *ComputeGatingLinearBackward(const framework::ExecutionContext &ctx,
gate_attn_compute.ComputeBackward(query, gate_weight, &gate_bias_out,
query_grad, gate_weight_grad,
gate_bias_grad);
return fmha_out_grad;
}
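For reference, the elementwise gradient of gate_out = sigmoid(gate_bias_out) * fmha_out that this backward path applies before calling ComputeBackward on the gating matmul; a minimal NumPy sketch:

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    gate_bias_out = np.random.rand(2, 3).astype(np.float32)   # pre-sigmoid linear output
    fmha_out      = np.random.rand(2, 3).astype(np.float32)
    gate_out_grad = np.random.rand(2, 3).astype(np.float32)

    s = sigmoid(gate_bias_out)
    fmha_out_grad      = gate_out_grad * s                          # gradient w.r.t. fmha_out
    gate_bias_out_grad = gate_out_grad * fmha_out * s * (1.0 - s)   # gradient w.r.t. gate_bias_out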
template <typename T>
Tensor *ComputeOutputLinearForward(const framework::ExecutionContext &ctx,
void ComputeOutputLinearForward(const framework::ExecutionContext &ctx,
const GateAttentionConfig<T> &config,
const Tensor *fmha_or_gate_out) {
const Tensor *fmha_or_gate_out, Tensor *out) {
const auto *out_linear_weight = ctx.Input<Tensor>("OutLinearWeight");
const auto *out_linear_bias = ctx.Input<Tensor>("OutLinearBias");
auto *out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
VLOG(4) << "[ComputeOutputLinearForward] out: " << MemoryDebugString(*out);
// out = GEMM(fmha_or_gate_out, out_linear_weight) + out_linear_bias
int m = config.batch_size * config.seq_len_m * config.seq_len_r;
int n = config.q_dim;
......@@ -277,28 +265,24 @@ Tensor *ComputeOutputLinearForward(const framework::ExecutionContext &ctx,
AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true);
out_linear_compute.ComputeForward(out_linear_weight, fmha_or_gate_out,
out_linear_bias, out, out);
return out;
}
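The output projection is a plain GEMM with m = batch_size * seq_len_m * seq_len_r rows, k = num_heads * key_dim and n = q_dim, plus a bias. A NumPy sketch; flattening the weight to [num_heads * key_dim, q_dim] is an assumed layout for illustration.

    import numpy as np

    bs, sm, sr, nh, kd, qd = 1, 2, 3, 4, 8, 16
    fmha_or_gate_out  = np.random.rand(bs, sm, sr, nh, kd).astype(np.float32)
    out_linear_weight = np.random.rand(nh * kd, qd).astype(np.float32)   # assumed [k, n] layout
    out_linear_bias   = np.random.rand(qd).astype(np.float32)

    x = fmha_or_gate_out.reshape(bs * sm * sr, nh * kd)                  # m x k
    out = (x @ out_linear_weight + out_linear_bias).reshape(bs, sm, sr, qd)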
template <typename T>
Tensor *ComputeOutputLinearBackward(const framework::ExecutionContext &ctx,
void ComputeOutputLinearBackward(const framework::ExecutionContext &ctx,
const GateAttentionGradConfig<T> &config,
bool has_gating) {
std::string input_name = has_gating ? "GateOut" : "FMHAOut";
const Tensor *input, Tensor *input_grad) {
const auto *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
const auto *out_linear_weight = ctx.Input<Tensor>("OutLinearWeight");
const auto *input = ctx.Input<Tensor>(input_name);
auto *out_linear_weight_grad =
ctx.Output<Tensor>(framework::GradVarName("OutLinearWeight"));
auto *out_linear_bias_grad =
ctx.Output<Tensor>(framework::GradVarName("OutLinearBias"));
auto *input_grad = ctx.Output<Tensor>(framework::GradVarName(input_name));
out_linear_weight_grad->mutable_data<T>(ctx.GetPlace());
out_linear_bias_grad->mutable_data<T>(ctx.GetPlace());
input_grad->mutable_data<T>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
int m = config.batch_size * config.seq_len_m * config.seq_len_r;
int n = config.q_dim;
......@@ -308,7 +292,6 @@ Tensor *ComputeOutputLinearBackward(const framework::ExecutionContext &ctx,
out_linear_compute.ComputeBackward(input, out_linear_weight, out_grad,
input_grad, out_linear_weight_grad,
out_linear_bias_grad);
return input_grad;
}
template <typename T>
......@@ -330,56 +313,64 @@ class FusedGateAttentionOpKernel : public framework::OpKernel<T> {
auto *softmax_out = ctx.Output<Tensor>("SoftmaxOut");
auto *fmha_out = ctx.Output<Tensor>("FMHAOut");
auto *gate_out = ctx.Output<Tensor>("GateOut");
auto *out = ctx.Output<Tensor>("Out");
const bool merge_qkv = ctx.Attr<bool>("merge_qkv");
const bool has_gating = ctx.Attr<bool>("has_gating");
// When seq_len_r = m_size, q_dim = kv_dim, QKV matmul can be merged.
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
GateAttentionConfig<T> config(query, key, query_weight, qkv_weight,
merge_qkv);
AllocWithDebugInfo<T>(dev_ctx, "softmax_out", softmax_out);
AllocWithDebugInfo<T>(dev_ctx, "fmha_out", fmha_out);
if (has_gating) {
AllocWithDebugInfo<T>(dev_ctx, "gate_out", gate_out);
}
AllocWithDebugInfo<T>(dev_ctx, "out", out);
// When seq_len_r = m_size, q_dim = kv_dim, QKV matmul can be merged.
GateAttentionConfig<T> config(dev_ctx, query, key, query_weight, qkv_weight,
merge_qkv, has_gating);
if (merge_qkv) {
PADDLE_ENFORCE_EQ(!key || query == key, true,
platform::errors::InvalidArgument(
"key is expected to be nullptr or the same as "
"query, but recieved key=%p, query=%p.",
key, query));
// 1. Merged QKV Matmul: einsum(nbhqk,nbkhc -> nbqhc)
Tensor *qkv_out = config.GetQKVOut(dev_ctx);
Tensor *qkv_out = config.GetQKVOut();
ComputeMergedQKVMatmulForward<T>(ctx, config, query, qkv_out);
qkv_transpose_out->mutable_data<T>(ctx.GetPlace());
VLOG(4) << "qkv_transpose_out:" << MemoryDebugString(*qkv_transpose_out);
AllocWithDebugInfo<T>(dev_ctx, "qkv_transpose_out", qkv_transpose_out);
} else {
// 1. Separated QKV Matmul
Tensor *query_out = config.GetQueryOut(dev_ctx);
Tensor *key_out = config.GetKeyOut(dev_ctx);
Tensor *value_out = config.GetValueOut(dev_ctx);
Tensor *query_out = config.GetQueryOut();
Tensor *key_out = config.GetKeyOut();
Tensor *value_out = config.GetValueOut();
ComputeSeparatedQKVMatmulForward<T>(ctx, config, query, key, query_out,
key_out, value_out);
q_transpose_out->mutable_data<T>(ctx.GetPlace());
k_transpose_out->mutable_data<T>(ctx.GetPlace());
v_transpose_out->mutable_data<T>(ctx.GetPlace());
VLOG(4) << "q_transpose_out: " << MemoryDebugString(*q_transpose_out);
VLOG(4) << "k_transpose_out: " << MemoryDebugString(*k_transpose_out);
VLOG(4) << "v_transpose_out: " << MemoryDebugString(*v_transpose_out);
AllocWithDebugInfo<T>(dev_ctx, "q_transpose_out", q_transpose_out);
AllocWithDebugInfo<T>(dev_ctx, "k_transpose_out", k_transpose_out);
AllocWithDebugInfo<T>(dev_ctx, "v_transpose_out", v_transpose_out);
}
softmax_out->mutable_data<T>(ctx.GetPlace());
fmha_out->mutable_data<T>(ctx.GetPlace());
VLOG(4) << "softmax_out: " << MemoryDebugString(*softmax_out);
VLOG(4) << "fmha_out: " << MemoryDebugString(*fmha_out);
// 2. FMHA
auto fmha_compute = FMHAGateRef<T>(dev_ctx, merge_qkv);
fmha_compute.ComputeForward(
nonbatched_bias, src_mask, q_transpose_out, k_transpose_out,
v_transpose_out, qkv_transpose_out, softmax_out, fmha_out, &config);
fmha_compute.ComputeForward(nonbatched_bias, src_mask, q_transpose_out,
k_transpose_out, v_transpose_out,
qkv_transpose_out, softmax_out, fmha_out,
gate_out, &config);
// 3. Gating Linear
Tensor *fmha_or_gate_out = !has_gating ? fmha_out
: ComputeGatingLinearForward<T>(
ctx, config, query, fmha_out);
if (has_gating) {
ComputeGatingLinearForward<T>(ctx, config, query, fmha_out, gate_out);
}
// 4. Output Linear
ComputeOutputLinearForward<T>(ctx, config, fmha_or_gate_out);
Tensor *fmha_or_gate_out = has_gating ? gate_out : fmha_out;
ComputeOutputLinearForward<T>(ctx, config, fmha_or_gate_out, out);
}
};
......@@ -387,9 +378,6 @@ template <typename T>
class FusedGateAttentionGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto has_gating = ctx.Attr<bool>("has_gating");
const auto merge_qkv = ctx.Attr<bool>("merge_qkv");
// forward input
const auto *query = ctx.Input<Tensor>("Query");
const auto *key = ctx.Input<Tensor>("Key");
......@@ -403,56 +391,68 @@ class FusedGateAttentionGradKernel : public framework::OpKernel<T> {
const auto *qkv_transpose_out = ctx.Input<Tensor>("QKVTransposeOut");
const auto *softmax_out = ctx.Input<Tensor>("SoftmaxOut");
const auto *fmha_out = ctx.Input<Tensor>("FMHAOut");
const auto *gate_out = ctx.Input<Tensor>("GateOut");
// backward output
auto *query_grad = ctx.Output<Tensor>(framework::GradVarName("Query"));
query_grad->mutable_data<T>(ctx.GetPlace());
auto *nonbatched_bias_grad =
ctx.Output<Tensor>(framework::GradVarName("NonbatchedBias"));
auto *fmha_out_grad = ctx.Output<Tensor>(framework::GradVarName("FMHAOut"));
bool has_gating = ctx.Attr<bool>("has_gating");
bool merge_qkv = ctx.Attr<bool>("merge_qkv");
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
GateAttentionGradConfig<T> config(query, key, query_weight, qkv_weight,
merge_qkv);
AllocWithDebugInfo<T>(dev_ctx, "query_grad", query_grad);
// 1. Gradient of Output Linear
Tensor *fhma_or_gate_out_grad =
ComputeOutputLinearBackward<T>(ctx, config, has_gating);
GateAttentionGradConfig<T> config(dev_ctx, query, key, query_weight,
qkv_weight, merge_qkv, has_gating);
// 2. Gradient of Gating Linear
Tensor fmha_out_grad;
fmha_out_grad.Resize(config.gate_out_dims);
AllocWithDebugInfo<T>(dev_ctx, "fmha_out_grad", &fmha_out_grad);
if (has_gating) {
// fhma_or_gate_out_grad is actually gate_out_grad.
fmha_out_grad->mutable_data<T>(ctx.GetPlace());
ComputeGatingLinearBackward<T>(ctx, config, fmha_out,
fhma_or_gate_out_grad, query_grad,
fmha_out_grad);
// 1. Gradient of Output Linear: out = Linear(gate_out)
Tensor gate_out_grad;
gate_out_grad.Resize(config.gate_out_dims);
AllocWithDebugInfo<T>(dev_ctx, "gate_out_grad", &gate_out_grad);
ComputeOutputLinearBackward<T>(ctx, config, gate_out, &gate_out_grad);
// 2. Gradient of Gating Linear
// Forward: gate_out = Sigmoid(Linear(fmha_out)) * fmha_out
ComputeGatingLinearBackward<T>(ctx, config, query, fmha_out,
&gate_out_grad, query_grad,
&fmha_out_grad);
} else {
// 1. Gradient of Output Linear: out = Linear(fmha_out)
ComputeOutputLinearBackward<T>(ctx, config, fmha_out, &fmha_out_grad);
}
// 3. Gradient of FMHA
if (nonbatched_bias_grad) {
nonbatched_bias_grad->mutable_data<T>(ctx.GetPlace());
AllocWithDebugInfo<T>(dev_ctx, "nonbatched_bias_grad",
nonbatched_bias_grad);
}
auto fmha_compute = FMHAGateRef<T>(dev_ctx, merge_qkv);
fmha_compute.ComputeBackward(
q_transpose_out, k_transpose_out, v_transpose_out, qkv_transpose_out,
softmax_out, fmha_out_grad, nullptr, nonbatched_bias_grad, &config);
softmax_out, &fmha_out_grad, nullptr, nonbatched_bias_grad, &config);
bool use_addto = has_gating ? true : false;
if (merge_qkv) {
// 4. Gradient of Merged QKV Matmul
Tensor *qkv_out_grad = config.GetQKVOutGrad(dev_ctx);
Tensor *qkv_out_grad = config.GetQKVOutGrad();
ComputeMergedQKVMatmulBackward<T>(ctx, config, query, qkv_out_grad,
query_grad, use_addto);
} else {
// 4. Gradient of Separated QKV Matmul
auto *key_grad = ctx.Output<Tensor>(framework::GradVarName("Key"));
if (key_grad) {
key_grad->mutable_data<T>(ctx.GetPlace());
AllocWithDebugInfo<T>(dev_ctx, "key_grad", key_grad);
}
Tensor *query_out_grad = config.GetQueryOutGrad(dev_ctx);
Tensor *key_out_grad = config.GetKeyOutGrad(dev_ctx);
Tensor *value_out_grad = config.GetValueOutGrad(dev_ctx);
Tensor *query_out_grad = config.GetQueryOutGrad();
Tensor *key_out_grad = config.GetKeyOutGrad();
Tensor *value_out_grad = config.GetValueOutGrad();
ComputeSeparatedQKVMatmulBackward<T>(
ctx, config, query, key, query_out_grad, key_out_grad, value_out_grad,
query_grad, key_grad, use_addto);
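use_addto is true when gating is enabled because query feeds two branches: the gating linear backward already wrote its contribution into query_grad, so the QKV (or separated Q/K/V) matmul backward must accumulate rather than overwrite. A minimal sketch with hypothetical gradient terms:

    import numpy as np

    query_grad = np.zeros((2, 3), dtype=np.float32)

    # Hypothetical contributions from the two consumers of `query`.
    grad_from_gating_linear = np.random.rand(2, 3).astype(np.float32)
    grad_from_qkv_matmul    = np.random.rand(2, 3).astype(np.float32)

    query_grad[...] = grad_from_gating_linear    # written by ComputeGatingLinearBackward
    query_grad[...] += grad_from_qkv_matmul      # use_addto = true: accumulate, do not overwrite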
......
......@@ -18,7 +18,7 @@ import paddle
import paddle.nn as nn
from paddle import tensor
import unittest
from op_test import OpTest, convert_float_to_uint16
from op_test import OpTest, convert_float_to_uint16, convert_uint16_to_float
from test_sparse_attention_op import get_cuda_version
from paddle import _C_ops
from paddle.fluid.framework import default_main_program, _enable_legacy_dygraph
......@@ -194,16 +194,22 @@ class TestFusedGateAttentionOp(OpTest):
return out, query.grad, None
def check_output_and_grad(self, atol, rtol):
out_ref, query_grad_ref, key_grad_ref = self.get_reference_out()
out, query_grad, key_grad = self.get_fused_gate_attention_out()
np.testing.assert_allclose(out_ref, out.numpy(), atol=atol, rtol=rtol)
np.testing.assert_allclose(query_grad_ref,
query_grad.numpy(),
atol=atol,
rtol=rtol)
if key_grad_ref is not None and key_grad is not None:
np.testing.assert_allclose(key_grad_ref,
key_grad.numpy(),
def _convert(value):
if self.dtype == "bfloat16":
return convert_uint16_to_float(value)
return value
output_names = ["out", "query_grad", "key_grad"]
outputs_ref = self.get_reference_out()
outputs_fused = self.get_fused_gate_attention_out()
for i in range(len(outputs_fused)):
ref_res = outputs_ref[i]
fused_res = outputs_fused[i]
if ref_res is not None and fused_res is not None:
print("Checking {}".format(output_names[i]))
np.testing.assert_allclose(_convert(ref_res),
_convert(fused_res.numpy()),
atol=atol,
rtol=rtol)
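bfloat16 results come back as uint16 arrays, so both sides are mapped to float32 before the tolerance check. A rough NumPy equivalent of that mapping, for illustration only (the test itself uses convert_uint16_to_float from op_test):

    import numpy as np

    def uint16_to_float(x):
        # bfloat16 keeps the high 16 bits of a float32: shift back and reinterpret.
        x = np.asarray(x, dtype=np.uint16)
        return (x.astype(np.uint32) << 16).view(np.float32)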
......@@ -211,6 +217,13 @@ class TestFusedGateAttentionOp(OpTest):
self.check_output_and_grad(atol=1e-5, rtol=1e-5)
class TestMergeQKVLargeBatchSizeCase(TestFusedGateAttentionOp):
def config(self):
super().config()
self.batch_size = 2
class TestSeparatedQKVCase(TestFusedGateAttentionOp):
def config(self):
......@@ -243,9 +256,18 @@ class TestMergeQKVFp16Case(TestFusedGateAttentionOp):
self.dtype = "float16"
def test_output_and_grad(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_and_grad(atol=1e-1, rtol=1e-5)
class TestMergeQKVLargeBatchSizeFp16Case(TestMergeQKVFp16Case):
def config(self):
super().config()
self.batch_size = 2
@unittest.skipIf(
not core.is_compiled_with_cuda() or get_cuda_version() < 11000,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
......@@ -260,5 +282,12 @@ class TestMergeQKVBF16Case(TestFusedGateAttentionOp):
self.check_output_and_grad(atol=1e-1, rtol=1e-3)
class TestMergeQKVLargeBatchSizeBF16Case(TestMergeQKVBF16Case):
def config(self):
super().config()
self.batch_size = 2
if __name__ == "__main__":
unittest.main()