Unverified commit 9b7126d0, authored by Yiqun Liu, committed by GitHub

Optimize prod's python implementation for dygraph. (#43309)

* Optimize prod's python implementation for dygraph.

* Change key_dim to head_dim.

* Add comment in unittest.

* Disable TF32 in unittest.
Parent 4d0ca02b
@@ -68,7 +68,7 @@ struct GateAttentionConfig {
   int64_t seq_len_r;
   int64_t q_dim;
   int64_t kv_dim;
-  int64_t key_dim;
+  int64_t head_dim;
   int64_t m_size;
   int64_t num_heads;
@@ -103,15 +103,15 @@ struct GateAttentionConfig {
                             "when merge_qkv is true."));
       // When q_dim == kv_dim, QKV matmul can be computed merged.
-      // qkv_weight: shape=[3, num_heads, key_dim, q_dim]
+      // qkv_weight: shape=[3, num_heads, head_dim, q_dim]
       num_heads = qkv_weight->dims()[1];
-      key_dim = qkv_weight->dims()[2];
+      head_dim = qkv_weight->dims()[2];
       m_size = seq_len_r;
       kv_dim = q_dim;
-      qkv_out_dims = {batch_size, seq_len_m, seq_len_r, 3, num_heads, key_dim};
+      qkv_out_dims = {batch_size, seq_len_m, seq_len_r, 3, num_heads, head_dim};
       qkv_transpose_out_dims = {3, batch_size, seq_len_m,
-                                num_heads, seq_len_r, key_dim};
+                                num_heads, seq_len_r, head_dim};
     } else {
       PADDLE_ENFORCE_NOT_NULL(
           key,
@@ -124,28 +124,28 @@ struct GateAttentionConfig {
       // When q_dim != kv_dim, QKV matmul must be computed separately.
       // key: shape=[batch_size, seq_len_m, m_size, kv_dim]
-      // query_w: shape=[q_dim, num_heads, key_dim]
+      // query_w: shape=[q_dim, num_heads, head_dim]
       num_heads = query_weight->dims()[1];
-      key_dim = query_weight->dims()[2];
+      head_dim = query_weight->dims()[2];
       m_size = key->dims()[2];
       kv_dim = key->dims()[3];
-      q_out_dims = {batch_size, seq_len_m, seq_len_r, num_heads, key_dim};
-      kv_out_dims = {batch_size, seq_len_m, m_size, num_heads, key_dim};
+      q_out_dims = {batch_size, seq_len_m, seq_len_r, num_heads, head_dim};
+      kv_out_dims = {batch_size, seq_len_m, m_size, num_heads, head_dim};
       q_transpose_out_dims = {batch_size, seq_len_m, num_heads, seq_len_r,
-                              key_dim};
+                              head_dim};
       kv_transpose_out_dims = {batch_size, seq_len_m, num_heads, m_size,
-                               key_dim};
+                               head_dim};
     }
     qk_out_dims = {batch_size, seq_len_m, num_heads, seq_len_r, m_size};
     softmax_out_dims = {batch_size, seq_len_m, num_heads, seq_len_r, m_size};
-    qktv_out_dims = {batch_size, seq_len_m, num_heads, seq_len_r, key_dim};
-    gate_out_dims = {batch_size, seq_len_m, seq_len_r, num_heads, key_dim};
+    qktv_out_dims = {batch_size, seq_len_m, num_heads, seq_len_r, head_dim};
+    gate_out_dims = {batch_size, seq_len_m, seq_len_r, num_heads, head_dim};
   }
   int64_t GetQuerySize() const {
-    return batch_size * seq_len_m * seq_len_r * num_heads * key_dim;
+    return batch_size * seq_len_m * seq_len_r * num_heads * head_dim;
   }
   Tensor* GetQKVOut() {
@@ -365,8 +365,8 @@ class FMHAGateRef {
     }
    // qk_out = BatchedGEMM(Q, K^T)
-    // [batch_size, seq_len_m, num_heads, seq_len_r, key_dim] *
-    // [batch_size, seq_len_m, num_heads, m_size, key_dim]
+    // [batch_size, seq_len_m, num_heads, seq_len_r, head_dim] *
+    // [batch_size, seq_len_m, num_heads, m_size, head_dim]
     // -> [batch_size, seq_len_m, num_heads, seq_len_r, m_size]
     Tensor* qk_out = config->GetQKOut(softmax_out);
     T* qk_out_ptr = qk_out->data<T>();
@@ -375,9 +375,9 @@ class FMHAGateRef {
         config->batch_size * config->seq_len_m * config->num_heads;
     int64_t gemm_m = config->seq_len_r;
     int64_t gemm_n = config->m_size;
-    int64_t gemm_k = config->key_dim;
-    T alpha = static_cast<T>(1.0 / sqrt(config->key_dim));
+    int64_t gemm_k = config->head_dim;
+    T alpha = static_cast<T>(1.0 / sqrt(config->head_dim));
     ComputeBatchedGEMM(q_ptr, k_ptr, qk_out_ptr, false, true, gemm_m, gemm_n,
                        gemm_k, gemm_batch_size, alpha);
@@ -388,13 +388,13 @@ class FMHAGateRef {
     // qktv_out = BatchedGEMM(softmax_out, V)
     // [batch_size, seq_len_m, num_heads, seq_len_r, m_size] *
-    // [batch_size, seq_len_m, num_heads, m_size, key_dim]
-    // -> [batch_size, seq_len_m, num_heads, seq_len_r, key_dim]
+    // [batch_size, seq_len_m, num_heads, m_size, head_dim]
+    // -> [batch_size, seq_len_m, num_heads, seq_len_r, head_dim]
     Tensor* qktv_out = config->GetQKTVOut(gate_out);
     T* qktv_out_ptr = qktv_out->data<T>();
     gemm_m = config->seq_len_r;
-    gemm_n = config->key_dim;
+    gemm_n = config->head_dim;
     gemm_k = config->m_size;
     T* softmax_out_ptr = softmax_out->data<T>();
@@ -490,7 +490,7 @@ class FMHAGateRef {
     // Backward:
     // V_grad = BatchedGEMM(softmax_out^T, qktv_out_grad) (dy = x^T * dout)
     int64_t gemm_m = config->m_size;
-    int64_t gemm_n = config->key_dim;
+    int64_t gemm_n = config->head_dim;
     int64_t gemm_k = config->seq_len_r;
     const T* softmax_out_ptr = softmax_out->data<T>();
@@ -501,7 +501,7 @@ class FMHAGateRef {
     // Backward: softmax_out_grad = qktv_out_grad * V^T (dx = dout * y^T)
     gemm_m = config->seq_len_r;
     gemm_n = config->m_size;
-    gemm_k = config->key_dim;
+    gemm_k = config->head_dim;
     T* softmax_out_grad_ptr = softmax_out_grad.data<T>();
     ComputeBatchedGEMM(qktv_out_grad_ptr, v_ptr, softmax_out_grad_ptr, false,
@@ -516,9 +516,9 @@ class FMHAGateRef {
     // Forward: qk_out = BatchedGEMM(Q, K^T)
     // Backward: k_grad = BatchedGEMM(qk_out_grad^T, Q) (dy = dout^t * x)
     int64_t gemm_m = config->m_size;
-    int64_t gemm_n = config->key_dim;
+    int64_t gemm_n = config->head_dim;
     int64_t gemm_k = config->seq_len_r;
-    T alpha = static_cast<T>(1.0 / sqrt(config->key_dim));
+    T alpha = static_cast<T>(1.0 / sqrt(config->head_dim));
     T* qk_out_grad_ptr = qk_out_grad->data<T>();
     ComputeBatchedGEMM(qk_out_grad_ptr, q_ptr, k_grad_ptr, true, false, gemm_m,
@@ -526,7 +526,7 @@ class FMHAGateRef {
     // Backward: q_grad = BatchedGEMM(qk_out_grad, K) (dx = dout * y)
     gemm_m = config->seq_len_r;
-    gemm_n = config->key_dim;
+    gemm_n = config->head_dim;
     gemm_k = config->m_size;
     ComputeBatchedGEMM(qk_out_grad_ptr, k_ptr, q_grad_ptr, false, false, gemm_m,
                        gemm_n, gemm_k, gemm_batch_size, alpha);
@@ -570,8 +570,8 @@ class FMHAGateRef {
                              v_out_grad);
     }
-  // [batch_size, seq_len_m, seq_len_r, 3, num_heads, key_dim] ->
-  // [3, batch_size, seq_len_m, num_heads, seq_len_r, key_dim]
+  // [batch_size, seq_len_m, seq_len_r, 3, num_heads, head_dim] ->
+  // [3, batch_size, seq_len_m, num_heads, seq_len_r, head_dim]
   void ComputeQKVTransposeForward(const Tensor& qkv_out,
                                   Tensor* qkv_transpose_out) {
     int ndims = 6;
@@ -610,7 +610,7 @@ class FMHAGateRef {
                               const Tensor* src_mask, Tensor* qk_out,
                               Tensor* softmax_out) {
     if (nonbatched_bias) {
-      std::vector<const Tensor*> ins = {qk_out, nonbatched_bias, src_mask};
+      std::vector<const Tensor*> ins = {qk_out, src_mask, nonbatched_bias};
       std::vector<Tensor*> outs = {qk_out};
       phi::funcs::BroadcastKernel<phi::ElementwiseType::kTernary, T, T>(
           dev_ctx_, ins, &outs, -1, TernaryAddFunctor<T>());
......
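For orientation, the batched GEMMs above implement a standard scaled dot-product attention over the shapes named in the comments, with head_dim (the renamed key_dim) as the contraction size and 1/sqrt(head_dim) as the scale. The NumPy sketch below illustrates the forward shapes only; the sizes are made-up illustrative values, not taken from this commit.

import numpy as np

# Illustrative sizes (assumptions for the sketch, not from the commit).
batch_size, seq_len_m, num_heads, seq_len_r, m_size, head_dim = 2, 3, 4, 5, 5, 8
q = np.random.rand(batch_size, seq_len_m, num_heads, seq_len_r, head_dim)
k = np.random.rand(batch_size, seq_len_m, num_heads, m_size, head_dim)
v = np.random.rand(batch_size, seq_len_m, num_heads, m_size, head_dim)

# qk_out = BatchedGEMM(Q, K^T), scaled by 1 / sqrt(head_dim).
qk_out = np.matmul(q, np.swapaxes(k, -1, -2)) / np.sqrt(head_dim)

# Softmax over the last (m_size) axis.
softmax_out = np.exp(qk_out - qk_out.max(axis=-1, keepdims=True))
softmax_out /= softmax_out.sum(axis=-1, keepdims=True)

# qktv_out = BatchedGEMM(softmax_out, V).
qktv_out = np.matmul(softmax_out, v)
print(qktv_out.shape)  # (batch_size, seq_len_m, num_heads, seq_len_r, head_dim)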
@@ -47,10 +47,10 @@ class FusedGateAttentionOp : public framework::OperatorWithKernel {
     int seq_len_m = input_q_dims[1];
     int seq_len_r = input_q_dims[2];
-    int num_head, m_size, key_dim;
+    int num_head, m_size, head_dim;
     if (ctx->Attrs().Get<bool>("merge_qkv")) {
       // QKV's input: [batch_size, seq_len_m, seq_len_r, qkv_dim]
-      // QKV's weight: [3, num_head, key_dim, qkv_dim]
+      // QKV's weight: [3, num_head, head_dim, qkv_dim]
       OP_INOUT_CHECK(ctx->HasInput("QKVWeight"), "Input", "QKVWeight",
                      "fused_gate_attention");
       OP_INOUT_CHECK(ctx->HasOutput("QKVTransposeOut"), "Output",
@@ -59,11 +59,11 @@ class FusedGateAttentionOp : public framework::OperatorWithKernel {
       auto qkv_w_dims = ctx->GetInputDim("QKVWeight");
       num_head = qkv_w_dims[1];
-      key_dim = qkv_w_dims[2];
+      head_dim = qkv_w_dims[2];
       m_size = seq_len_r;
       ctx->SetOutputDim("QKVTransposeOut", {3, batch_size, seq_len_m, num_head,
-                                            seq_len_r, key_dim});
+                                            seq_len_r, head_dim});
     } else {
       OP_INOUT_CHECK(ctx->HasInput("QueryWeight"), "Input", "QueryWeight",
                      "fused_gate_attention");
@@ -76,21 +76,21 @@ class FusedGateAttentionOp : public framework::OperatorWithKernel {
       auto q_w_dims = ctx->GetInputDim("QueryWeight");
       num_head = q_w_dims[1];
-      key_dim = q_w_dims[2];
+      head_dim = q_w_dims[2];
       m_size = input_k_dims[2];
       ctx->SetOutputDim("QueryTransposeOut",
-                        {batch_size, seq_len_m, num_head, seq_len_r, key_dim});
+                        {batch_size, seq_len_m, num_head, seq_len_r, head_dim});
       ctx->SetOutputDim("KeyTransposeOut",
-                        {batch_size, seq_len_m, num_head, m_size, key_dim});
+                        {batch_size, seq_len_m, num_head, m_size, head_dim});
       ctx->SetOutputDim("ValueTransposeOut",
-                        {batch_size, seq_len_m, num_head, m_size, key_dim});
+                        {batch_size, seq_len_m, num_head, m_size, head_dim});
     }
     ctx->SetOutputDim("SoftmaxOut",
                       {batch_size, seq_len_m, num_head, seq_len_r, m_size});
     ctx->SetOutputDim("FMHAOut",
-                      {batch_size, seq_len_m, seq_len_r, num_head, key_dim});
+                      {batch_size, seq_len_m, seq_len_r, num_head, head_dim});
     if (ctx->Attrs().Get<bool>("has_gating")) {
       OP_INOUT_CHECK(ctx->HasInput("GateWeight"), "Input", "GateWeight",
@@ -98,7 +98,7 @@ class FusedGateAttentionOp : public framework::OperatorWithKernel {
       OP_INOUT_CHECK(ctx->HasInput("GateBias"), "Input", "GateBias",
                      "fused_gate_attention");
       ctx->SetOutputDim("GateOut",
-                        {batch_size, seq_len_m, seq_len_r, num_head, key_dim});
+                        {batch_size, seq_len_m, seq_len_r, num_head, head_dim});
     }
     ctx->SetOutputDim("Out", ctx->GetInputDim("Query"));
......
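The shape inference above can be summarized as a small table of output shapes. The helper below is a hypothetical Python restatement of those rules (the function name and structure are mine, not part of the operator); it may be handy when checking the expected dimensions of each output.

# Hypothetical restatement of FusedGateAttentionOp::InferShape output shapes.
def fused_gate_attention_output_shapes(batch_size, seq_len_m, seq_len_r,
                                       num_head, head_dim, m_size, merge_qkv):
    shapes = {}
    if merge_qkv:
        # QKV weight is [3, num_head, head_dim, qkv_dim]; m_size equals seq_len_r.
        shapes["QKVTransposeOut"] = [3, batch_size, seq_len_m, num_head,
                                     seq_len_r, head_dim]
    else:
        shapes["QueryTransposeOut"] = [batch_size, seq_len_m, num_head,
                                       seq_len_r, head_dim]
        shapes["KeyTransposeOut"] = [batch_size, seq_len_m, num_head,
                                     m_size, head_dim]
        shapes["ValueTransposeOut"] = [batch_size, seq_len_m, num_head,
                                       m_size, head_dim]
    shapes["SoftmaxOut"] = [batch_size, seq_len_m, num_head, seq_len_r, m_size]
    shapes["FMHAOut"] = [batch_size, seq_len_m, seq_len_r, num_head, head_dim]
    # GateOut only exists when has_gating is true.
    shapes["GateOut"] = [batch_size, seq_len_m, seq_len_r, num_head, head_dim]
    return shapes

print(fused_gate_attention_output_shapes(1, 3, 5, 2, 4, 5, True))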
@@ -65,13 +65,13 @@ void ComputeMergedQKVMatmulForward(const framework::ExecutionContext &ctx,
                                    const GateAttentionConfig<T> &config,
                                    const Tensor *query, Tensor *qkv_out) {
   // query: shape=[batch_size, seq_len_m, seq_len_r, qkv_dim]
-  // qkv_weight: shape=[3, num_heads, key_dim, qkv_dim]
-  // qkv_out: shape=[batch_size, seq_len_m, seq_len_r, 3, num_heads, key_dim]
+  // qkv_weight: shape=[3, num_heads, head_dim, qkv_dim]
+  // qkv_out: shape=[batch_size, seq_len_m, seq_len_r, 3, num_heads, head_dim]
   auto *qkv_weight = ctx.Input<Tensor>("QKVWeight");
   // qkv_out = GEMM(query, qkv_weight^T)
   int m = config.batch_size * config.seq_len_m * config.seq_len_r;
-  int n = 3 * config.num_heads * config.key_dim;
+  int n = 3 * config.num_heads * config.head_dim;
   int k = config.q_dim;
   auto qkv_compute =
       AttnMatMul<T>(ctx.cuda_device_context(), false, true, m, n, k, false);
@@ -91,7 +91,7 @@ void ComputeMergedQKVMatmulBackward(const framework::ExecutionContext &ctx,
   // Gradient of GEMM(query, qkv_weight)
   int m = config.batch_size * config.seq_len_m * config.seq_len_r;
-  int n = 3 * config.num_heads * config.key_dim;
+  int n = 3 * config.num_heads * config.head_dim;
   int k = config.q_dim;
   auto qkv_compute =
       AttnMatMul<T>(ctx.cuda_device_context(), false, true, m, n, k, false);
@@ -111,10 +111,10 @@ void ComputeSeparatedQKVMatmulForward(const framework::ExecutionContext &ctx,
   // query_out = GEMM(query, query_weight)
   // query: shape=[batch_size, seq_len_m, seq_len_r, q_dim]
-  // query_weight: shape=[q_dim, num_heads, key_dim]
-  // query_out: shape=[batch_size, seq_len_m, seq_len_r, num_heads, key_dim]
+  // query_weight: shape=[q_dim, num_heads, head_dim]
+  // query_out: shape=[batch_size, seq_len_m, seq_len_r, num_heads, head_dim]
   int q_m = config.batch_size * config.seq_len_m * config.seq_len_r;
-  int q_n = config.num_heads * config.key_dim;
+  int q_n = config.num_heads * config.head_dim;
   int q_k = config.q_dim;
   auto q_compute = AttnMatMul<T>(ctx.cuda_device_context(), false, false, q_m,
                                  q_n, q_k, false);
@@ -122,10 +122,10 @@ void ComputeSeparatedQKVMatmulForward(const framework::ExecutionContext &ctx,
   // k_out = GEMM(key, key_weight)
   // key: shape=[batch_size, seq_len_m, m_size, kv_dim]
-  // key_weight: shape=[kv_dim, num_heads, key_dim]
-  // key_out: shape=[batch_size, seq_len_m, m_size, num_heads, key_dim]
+  // key_weight: shape=[kv_dim, num_heads, head_dim]
+  // key_out: shape=[batch_size, seq_len_m, m_size, num_heads, head_dim]
   int kv_m = config.batch_size * config.seq_len_m * config.m_size;
-  int kv_n = config.num_heads * config.key_dim;
+  int kv_n = config.num_heads * config.head_dim;
   int kv_k = config.kv_dim;
   auto kv_compute = AttnMatMul<T>(ctx.cuda_device_context(), false, false, kv_m,
                                   kv_n, kv_k, false);
@@ -151,7 +151,7 @@ void ComputeSeparatedQKVMatmulBackward(const framework::ExecutionContext &ctx,
   key_weight_grad->mutable_data<T>(ctx.GetPlace());
   int kv_m = config.batch_size * config.seq_len_m * config.m_size;
-  int kv_n = config.num_heads * config.key_dim;
+  int kv_n = config.num_heads * config.head_dim;
   int kv_k = config.kv_dim;
   auto kv_compute = AttnMatMul<T>(ctx.cuda_device_context(), false, false, kv_m,
                                   kv_n, kv_k, false);
@@ -174,7 +174,7 @@ void ComputeSeparatedQKVMatmulBackward(const framework::ExecutionContext &ctx,
   query_weight_grad->mutable_data<T>(ctx.GetPlace());
   int q_m = config.batch_size * config.seq_len_m * config.seq_len_r;
-  int q_n = config.num_heads * config.key_dim;
+  int q_n = config.num_heads * config.head_dim;
   int q_k = config.q_dim;
   auto q_compute = AttnMatMul<T>(ctx.cuda_device_context(), false, false, q_m,
                                  q_n, q_k, false);
@@ -195,7 +195,7 @@ void ComputeGatingLinearForward(const framework::ExecutionContext &ctx,
   // bias.
   // gate_out = GEMM(query, gate_weight) + gate_bias
   int m = config.batch_size * config.seq_len_m * config.seq_len_r;
-  int n = config.num_heads * config.key_dim;
+  int n = config.num_heads * config.head_dim;
   int k = config.q_dim;
   auto gate_attn_compute =
       AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true);
@@ -224,7 +224,7 @@ void ComputeGatingLinearBackward(const framework::ExecutionContext &ctx,
   gate_bias_out.mutable_data<T>(ctx.GetPlace());
   int m = config.batch_size * config.seq_len_m * config.seq_len_r;
-  int n = config.num_heads * config.key_dim;
+  int n = config.num_heads * config.head_dim;
   int k = config.q_dim;
   auto gate_attn_compute =
       AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true);
@@ -260,7 +260,7 @@ void ComputeOutputLinearForward(const framework::ExecutionContext &ctx,
   // out = GEMM(fmha_or_gate_out, out_linear_weight) + out_linear_bias
   int m = config.batch_size * config.seq_len_m * config.seq_len_r;
   int n = config.q_dim;
-  int k = config.num_heads * config.key_dim;
+  int k = config.num_heads * config.head_dim;
   auto out_linear_compute =
       AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true);
   out_linear_compute.ComputeForward(out_linear_weight, fmha_or_gate_out,
@@ -282,11 +282,9 @@ void ComputeOutputLinearBackward(const framework::ExecutionContext &ctx,
   out_linear_weight_grad->mutable_data<T>(ctx.GetPlace());
   out_linear_bias_grad->mutable_data<T>(ctx.GetPlace());
-  auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
   int m = config.batch_size * config.seq_len_m * config.seq_len_r;
   int n = config.q_dim;
-  int k = config.num_heads * config.key_dim;
+  int k = config.num_heads * config.head_dim;
   auto out_linear_compute =
       AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true);
   out_linear_compute.ComputeBackward(input, out_linear_weight, out_grad,
......
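All of the kernels above reduce to 2-D GEMMs by flattening the leading dimensions. As an illustration of the m/n/k convention used in ComputeGatingLinearForward (m = batch_size * seq_len_m * seq_len_r, n = num_heads * head_dim, k = q_dim), here is a NumPy sketch; the tensor sizes are illustrative assumptions, not values from the commit.

import numpy as np

# Illustrative sizes (assumptions for the sketch).
batch_size, seq_len_m, seq_len_r, q_dim, num_heads, head_dim = 2, 3, 5, 6, 2, 4
query = np.random.rand(batch_size, seq_len_m, seq_len_r, q_dim)
gate_weight = np.random.rand(q_dim, num_heads, head_dim)
gate_bias = np.random.rand(num_heads, head_dim)

m = batch_size * seq_len_m * seq_len_r
n = num_heads * head_dim
k = q_dim

# gate_out = GEMM(query, gate_weight) + gate_bias, computed as one [m, k] x [k, n] matmul.
gate_out = np.matmul(query.reshape(m, k), gate_weight.reshape(k, n)) + gate_bias.reshape(n)
gate_out = gate_out.reshape(batch_size, seq_len_m, seq_len_r, num_heads, head_dim)
print(gate_out.shape)  # (batch_size, seq_len_m, seq_len_r, num_heads, head_dim)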
@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
+os.environ['NVIDIA_TF32_OVERRIDE'] = "0"
+os.environ['FLAGS_new_einsum'] = "0"
 import numpy as np
 import paddle
@@ -47,7 +52,7 @@ class TestFusedGateAttentionOp(OpTest):
         self.res_len = 5
         self.q_dim = 6
         self.num_heads = 2
-        self.key_dim = 4
+        self.head_dim = 4
         self.m_size = self.res_len
         self.kv_dim = self.q_dim
         self.out_dim = self.q_dim
@@ -65,12 +70,12 @@ class TestFusedGateAttentionOp(OpTest):
         np.random.seed(123)
         self.query = _random(
             (self.batch_size, self.msa_len, self.res_len, self.q_dim))
-        self.q_weight = _random((self.q_dim, self.num_heads, self.key_dim))
-        self.k_weight = _random((self.kv_dim, self.num_heads, self.key_dim))
-        self.v_weight = _random((self.kv_dim, self.num_heads, self.key_dim))
+        self.q_weight = _random((self.q_dim, self.num_heads, self.head_dim))
+        self.k_weight = _random((self.kv_dim, self.num_heads, self.head_dim))
+        self.v_weight = _random((self.kv_dim, self.num_heads, self.head_dim))
         if self.merge_qkv:
             self.key = None
-            # (3, self.num_heads, self.key_dim, self.q_dim)
+            # (3, self.num_heads, self.head_dim, self.q_dim)
             q_weight_t = np.transpose(self.q_weight, axes=[1, 2, 0])
             k_weight_t = np.transpose(self.k_weight, axes=[1, 2, 0])
             v_weight_t = np.transpose(self.v_weight, axes=[1, 2, 0])
@@ -88,15 +93,22 @@ class TestFusedGateAttentionOp(OpTest):
                 (self.batch_size, 1, self.num_heads, self.res_len, self.m_size))
         if self.has_gating:
-            self.gating_w = _random((self.q_dim, self.num_heads, self.key_dim))
-            self.gating_b = _random((self.num_heads, self.key_dim))
-        self.output_w = _random((self.num_heads, self.key_dim, self.out_dim))
+            self.gating_w = _random((self.q_dim, self.num_heads, self.head_dim))
+            self.gating_b = _random((self.num_heads, self.head_dim))
+        self.output_w = _random((self.num_heads, self.head_dim, self.out_dim))
         self.output_b = _random((self.out_dim))
         self.dout = _random(
             (self.batch_size, self.msa_len, self.res_len, self.q_dim))

+    def collect_outputs(self, query, key, softmax_out, fmha_out, gate_out, out):
+        outputs = [
+            softmax_out, fmha_out, gate_out if self.has_gating else None, out,
+            query.grad, None if self.merge_qkv else key.grad
+        ]
+        return outputs

     def get_reference_out(self):
         paddle.disable_static(place=paddle.CUDAPlace(0))
@@ -108,44 +120,85 @@ class TestFusedGateAttentionOp(OpTest):
         v_weight = paddle.to_tensor(self.v_weight, stop_gradient=False)
         src_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True)
-        c = self.key_dim**(-0.5)
-        # [batch_size, msa_len, num_heads, res_len, key_dim]
+        c = self.head_dim**(-0.5)
+        # [batch_size, msa_len, res_len, q_dim], [q_dim, num_heads, head_dim]
+        # -> [batch_size, msa_len, res_len, num_heads, head_dim]
         q = paddle.einsum('nbqa,ahc->nbqhc', query, q_weight) * c
-        # [batch_size, msa_len, num_heads, m_size, key_dim]
+        # [batch_size, msa_len, m_size, kv_dim], [kv_dim, num_heads, head_dim]
+        # -> [batch_size, msa_len, m_size, num_heads, head_dim]
         k = paddle.einsum('nbka,ahc->nbkhc', key, k_weight)
-        # [batch_size, msa_len, num_heads, m_size, key_dim]
+        # [batch_size, msa_len, m_size, kv_dim], [kv_dim, num_heads, head_dim]
+        # -> [batch_size, msa_len, m_size, num_heads, head_dim]
         v = paddle.einsum('nbka,ahc->nbkhc', key, v_weight)
-        # [batch_size, msa_len, num_heads, res_len, m_size]
+        # [batch_size, msa_len, res_len, num_heads, head_dim], [batch_size, msa_len, m_size, num_heads, head_dim]
+        # -> [batch_size, msa_len, num_heads, res_len, m_size]
         logits = paddle.einsum('nbqhc,nbkhc->nbhqk', q, k)  # qk_out
+        # [batch_size, msa_len, num_heads, res_len, m_size], [batch_size, msa_len, 1, 1, m_size]
+        # -> [batch_size, msa_len, num_heads, res_len, m_size]
         logits = logits + src_mask
         if self.bias_attr:
             nonbatched_bias = paddle.to_tensor(self.nonbatched_bias,
                                                stop_gradient=False)
+            # [batch_size, msa_len, num_heads, res_len, m_size], [batch_size, 1, num_heads, res_len, m_size]
+            # -> [batch_size, msa_len, num_heads, res_len, m_size]
             logits = logits + nonbatched_bias
-        weights = nn.functional.softmax(logits)  # softmax_out
-        weighted_avg = paddle.einsum('nbhqk,nbkhc->nbqhc', weights, v)
+        # [batch_size, msa_len, num_heads, res_len, m_size]
+        softmax_out = nn.functional.softmax(logits)
+        # [batch_size, msa_len, num_heads, res_len, m_size], [batch_size, msa_len, m_size, num_heads, head_dim]
+        # -> [batch_size, msa_len, res_len, num_heads, head_dim]
+        # fmha_out = paddle.einsum('nbhqk,nbkhc->nbqhc', softmax_out, v)
+        v_trans = paddle.transpose(v, perm=[0, 1, 3, 2, 4])
+        qktv_out = paddle.matmul(softmax_out, v_trans)
+        fmha_out = paddle.transpose(qktv_out, perm=[0, 1, 3, 2, 4])
         if self.has_gating:
             gating_w = paddle.to_tensor(self.gating_w, stop_gradient=False)
             gating_b = paddle.to_tensor(self.gating_b, stop_gradient=False)
-            gate_values = paddle.einsum('nbqc,chv->nbqhv', query,
-                                        gating_w) + gating_b
+            # [batch_size, msa_len, res_len, q_dim], [q_dim, num_heads, head_dim]
+            # -> [batch_size, msa_len, res_len, num_heads, head_dim]
+            # gate_values = paddle.einsum('nbqc,chv->nbqhv', query,
+            #                             gating_w) + gating_b
+            gating_w_2d = paddle.reshape(
+                gating_w, shape=[self.q_dim, self.num_heads * self.head_dim])
+            gate_values_4d = paddle.matmul(query, gating_w_2d)
+            gate_values = paddle.reshape(
+                gate_values_4d,
+                shape=[
+                    self.batch_size, self.msa_len, self.res_len, self.num_heads,
+                    self.head_dim
                ]) + gating_b
             gate_values = nn.functional.sigmoid(gate_values)
-            weighted_avg = weighted_avg * gate_values
+            gate_out = fmha_out * gate_values
+        else:
+            gate_out = fmha_out
         output_b = paddle.to_tensor(self.output_b, stop_gradient=False)
         output_w = paddle.to_tensor(self.output_w, stop_gradient=False)
-        out = paddle.einsum('nbqhc,hco->nbqo', weighted_avg,
-                            output_w) + output_b
+        # [batch_size, msa_len, res_len, num_heads, head_dim], [num_heads, head_dim, out_dim]
+        # -> [batch_size, msa_len, res_len, out_dim]
+        # out = paddle.einsum('nbqhc,hco->nbqo', gate_out,
+        #                     output_w) + output_b
+        gate_out_2d = paddle.reshape(
+            gate_out,
+            shape=[
+                self.batch_size * self.msa_len * self.res_len,
+                self.num_heads * self.head_dim
+            ])
+        output_w_2d = paddle.reshape(
+            output_w, shape=[self.num_heads * self.head_dim, self.out_dim])
+        out_2d = paddle.matmul(gate_out_2d, output_w_2d)
+        out = paddle.reshape(
+            out_2d,
+            shape=[self.batch_size, self.msa_len, self.res_len, self.out_dim
                    ]) + output_b
         paddle.autograd.backward([out], [paddle.to_tensor(self.dout)],
                                  retain_graph=True)
-        if self.merge_qkv:
-            return out, query.grad, None
-        else:
-            return out, query.grad, key.grad
+        return self.collect_outputs(query, key, softmax_out, fmha_out, gate_out,
+                                    out)

     def get_fused_gate_attention_out(self):
         paddle.disable_static(place=paddle.CUDAPlace(0))
@@ -181,40 +234,59 @@ class TestFusedGateAttentionOp(OpTest):
         output_w = paddle.to_tensor(self.output_w, stop_gradient=False)
         output_b = paddle.to_tensor(self.output_b, stop_gradient=False)
-        _, _, _, _, _, _, _, out = _C_ops.fused_gate_attention(
+        _, _, _, _, softmax_out, fmha_out, gate_out, out = _C_ops.fused_gate_attention(
             query, key, q_weight, k_weight, v_weight, qkv_weight,
             nonbatched_bias, src_mask, gating_w, gating_b, output_w, output_b,
             'has_gating', self.has_gating, 'merge_qkv', self.merge_qkv)
         paddle.autograd.backward([out], [paddle.to_tensor(self.dout)],
                                  retain_graph=True)
-        if key is not None:
-            return out, query.grad, key.grad
-        else:
-            return out, query.grad, None
+        return self.collect_outputs(query, key, softmax_out, fmha_out, gate_out,
+                                    out)

-    def check_output_and_grad(self, atol, rtol):
+    def check(self, ref, out, atol, rtol, check_equal, name):
         def _convert(value):
             if self.dtype == "bfloat16":
                 return convert_uint16_to_float(value)
             return value

-        output_names = ["out", "query_grad", "key_grad"]
+        if check_equal:
+            self.assertTrue(
+                np.equal(_convert(ref), _convert(out)).all(),
+                "Checking < {} > failed!".format(name))
+        else:
+            np.testing.assert_allclose(
+                _convert(ref),
+                _convert(out),
+                atol=atol,
+                rtol=rtol,
+                err_msg="Checking < {} > failed!".format(name))

+    def check_output_and_grad(self, atol, rtol):
+        output_names = [
+            "softmax_out", "fmha_out", "gate_out", "out", "query_grad",
+            "key_grad"
+        ]
         outputs_ref = self.get_reference_out()
         outputs_fused = self.get_fused_gate_attention_out()
-        for i in range(len(outputs_fused)):
+        for i in range(len(output_names)):
             ref_res = outputs_ref[i]
             fused_res = outputs_fused[i]
             if ref_res is not None and fused_res is not None:
-                print("Checking {}".format(output_names[i]))
-                np.testing.assert_allclose(_convert(ref_res),
-                                           _convert(fused_res.numpy()),
-                                           atol=atol,
-                                           rtol=rtol)
+                # The python implementation of einsum is likely to call
+                # matmul(x, y, transpose_x=False, transpose_y=True). With different
+                # transpose_x and transpose_y, cublas will launch different kernels
+                # and the result cannot be exactly equal.
+                # Because the arguments of matmul in einsum is the same as
+                # that in fused ops, check_equal is set to False and we use allclose
+                # to check the correctness.
                check_equal = False
+                self.check(ref_res.numpy(), fused_res.numpy(), atol, rtol,
+                           check_equal, output_names[i])

     def test_output_and_grad(self):
-        self.check_output_and_grad(atol=1e-5, rtol=1e-5)
+        self.check_output_and_grad(atol=1e-5, rtol=1e-6)

 class TestMergeQKVLargeBatchSizeCase(TestFusedGateAttentionOp):
@@ -234,7 +306,7 @@ class TestSeparatedQKVCase(TestFusedGateAttentionOp):
         self.res_len = 5
         self.q_dim = 6
         self.num_heads = 2
-        self.key_dim = 4
+        self.head_dim = 4
         self.m_size = 4
         self.kv_dim = 2
         self.out_dim = self.q_dim
@@ -279,7 +351,7 @@ class TestMergeQKVBF16Case(TestFusedGateAttentionOp):
         self.dtype = "bfloat16"

     def test_output_and_grad(self):
-        self.check_output_and_grad(atol=1e-1, rtol=1e-3)
+        self.check_output_and_grad(atol=1e-1, rtol=1e-2)

 class TestMergeQKVLargeBatchSizeBF16Case(TestMergeQKVBF16Case):
......
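The reference implementation above deliberately replaces paddle.einsum('nbhqk,nbkhc->nbqhc', softmax_out, v) with a transpose + matmul + transpose sequence so that it issues the same matmul call pattern as the fused op (see the unittest comment about transpose_x/transpose_y and cuBLAS kernel selection). The NumPy check below confirms that the two formulations compute the same tensor; the sizes are illustrative, not the ones used in the test.

import numpy as np

# Illustrative sizes: n=batch, b=msa_len, h=num_heads, q=res_len, k=m_size, c=head_dim.
n, b, h, q, k, c = 2, 3, 2, 5, 5, 4
softmax_out = np.random.rand(n, b, h, q, k)
v = np.random.rand(n, b, k, h, c)

# einsum form used before this commit.
ref = np.einsum('nbhqk,nbkhc->nbqhc', softmax_out, v)

# transpose + matmul + transpose form used by the updated reference.
v_trans = np.transpose(v, (0, 1, 3, 2, 4))          # [n, b, h, k, c]
qktv_out = np.matmul(softmax_out, v_trans)          # [n, b, h, q, c]
fmha_out = np.transpose(qktv_out, (0, 1, 3, 2, 4))  # [n, b, q, h, c]

print(np.allclose(ref, fmha_out))  # True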
@@ -3263,9 +3263,7 @@ def prod(x, axis=None, keepdim=False, dtype=None, name=None):
         if x.dtype != convert_np_dtype_to_dtype_(dtype):
             x = cast(x, dtype)
-    input = x
     dim = axis
-    keep_dim = keepdim
     if dim is not None and not isinstance(dim, list):
         if isinstance(dim, tuple):
             dim = list(dim)
@@ -3275,24 +3273,29 @@ def prod(x, axis=None, keepdim=False, dtype=None, name=None):
             raise TypeError(
                 "The type of axis must be int, list or tuple, but received {}".
                 format(type(dim)))
+    reduce_all = True if dim is None or len(dim) == 0 or len(dim) == len(x.shape) else False
+    if dim is None or len(dim) == 0:
+        dim = [0]
     if in_dygraph_mode():
-        return _C_ops.final_state_reduce_prod(
-            input, dim if dim != None and dim != [] else [0], keep_dim, True if
-            dim == None or dim == [] or len(dim) == len(input.shape) else False)
+        return _C_ops.final_state_reduce_prod(x, dim, keepdim, reduce_all)
+    if _in_legacy_dygraph():
+        return _C_ops.reduce_prod(
+            x, 'dim', dim, 'keep_dim', keepdim, 'reduce_all', reduce_all)
     helper = LayerHelper('reduce_prod', **locals())
     check_variable_and_dtype(
-        input, 'input', ['float32', 'float64', 'int32', 'int64'], 'reduce_prod')
+        x, 'x/input', ['float32', 'float64', 'int32', 'int64'], 'reduce_prod')
     out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
     helper.append_op(
         type='reduce_prod',
-        inputs={'X': input},
+        inputs={'X': x},
         outputs={'Out': out},
         attrs={
-            'dim': dim if dim != None and dim != [] else [0],
-            'keep_dim': keep_dim,
-            'reduce_all': True if dim == None or dim == [] or
-            len(dim) == len(input.shape) else False
+            'dim': dim,
+            'keep_dim': keepdim,
+            'reduce_all': reduce_all
         })
     return out
......
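After this change, dim and reduce_all are computed once and the dygraph paths call the op directly with plain arguments. A short usage sketch of the public paddle.prod API follows; the values are illustrative, not taken from the patch.

import paddle

x = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(paddle.prod(x))                        # product over all elements -> 720.0
print(paddle.prod(x, axis=1))                # per-row products -> [6., 120.]
print(paddle.prod(x, axis=1, keepdim=True))  # keeps the reduced axis, shape [2, 1]
print(paddle.prod(x, axis=[0, 1]))           # reducing over every axis also sets reduce_all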