From 9b7126d05987b725ad3fb31f31298218c860b2f5 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Wed, 15 Jun 2022 20:26:28 +0800 Subject: [PATCH] Optimize prod's python implementation for dygraph. (#43309) * Optimize prod's python implementation for dygraph. * Change key_dim to head_dim. * Add comment in unittest. * Disable TF32 in unittest. --- .../operators/fused/fused_gate_attention.h | 58 +++---- .../fused/fused_gate_attention_op.cc | 20 +-- .../fused/fused_gate_attention_op.cu | 34 ++-- .../unittests/test_fused_gate_attention_op.py | 152 +++++++++++++----- python/paddle/tensor/math.py | 25 +-- 5 files changed, 181 insertions(+), 108 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_gate_attention.h b/paddle/fluid/operators/fused/fused_gate_attention.h index d7ed144f02..3f725d3c1c 100644 --- a/paddle/fluid/operators/fused/fused_gate_attention.h +++ b/paddle/fluid/operators/fused/fused_gate_attention.h @@ -68,7 +68,7 @@ struct GateAttentionConfig { int64_t seq_len_r; int64_t q_dim; int64_t kv_dim; - int64_t key_dim; + int64_t head_dim; int64_t m_size; int64_t num_heads; @@ -103,15 +103,15 @@ struct GateAttentionConfig { "when merge_qkv is true.")); // When q_dim == kv_dim, QKV matmul can be computed merged. - // qkv_weight: shape=[3, num_heads, key_dim, q_dim] + // qkv_weight: shape=[3, num_heads, head_dim, q_dim] num_heads = qkv_weight->dims()[1]; - key_dim = qkv_weight->dims()[2]; + head_dim = qkv_weight->dims()[2]; m_size = seq_len_r; kv_dim = q_dim; - qkv_out_dims = {batch_size, seq_len_m, seq_len_r, 3, num_heads, key_dim}; + qkv_out_dims = {batch_size, seq_len_m, seq_len_r, 3, num_heads, head_dim}; qkv_transpose_out_dims = {3, batch_size, seq_len_m, - num_heads, seq_len_r, key_dim}; + num_heads, seq_len_r, head_dim}; } else { PADDLE_ENFORCE_NOT_NULL( key, @@ -124,28 +124,28 @@ struct GateAttentionConfig { // When q_dim != kv_dim, QKV matmul must be computed saparately. 
// key: shape=[batch_size, seq_len_m, m_size, kv_dim] - // query_w: shape=[q_dim, num_heads, key_dim] + // query_w: shape=[q_dim, num_heads, head_dim] num_heads = query_weight->dims()[1]; - key_dim = query_weight->dims()[2]; + head_dim = query_weight->dims()[2]; m_size = key->dims()[2]; kv_dim = key->dims()[3]; - q_out_dims = {batch_size, seq_len_m, seq_len_r, num_heads, key_dim}; - kv_out_dims = {batch_size, seq_len_m, m_size, num_heads, key_dim}; + q_out_dims = {batch_size, seq_len_m, seq_len_r, num_heads, head_dim}; + kv_out_dims = {batch_size, seq_len_m, m_size, num_heads, head_dim}; q_transpose_out_dims = {batch_size, seq_len_m, num_heads, seq_len_r, - key_dim}; + head_dim}; kv_transpose_out_dims = {batch_size, seq_len_m, num_heads, m_size, - key_dim}; + head_dim}; } qk_out_dims = {batch_size, seq_len_m, num_heads, seq_len_r, m_size}; softmax_out_dims = {batch_size, seq_len_m, num_heads, seq_len_r, m_size}; - qktv_out_dims = {batch_size, seq_len_m, num_heads, seq_len_r, key_dim}; - gate_out_dims = {batch_size, seq_len_m, seq_len_r, num_heads, key_dim}; + qktv_out_dims = {batch_size, seq_len_m, num_heads, seq_len_r, head_dim}; + gate_out_dims = {batch_size, seq_len_m, seq_len_r, num_heads, head_dim}; } int64_t GetQuerySize() const { - return batch_size * seq_len_m * seq_len_r * num_heads * key_dim; + return batch_size * seq_len_m * seq_len_r * num_heads * head_dim; } Tensor* GetQKVOut() { @@ -365,8 +365,8 @@ class FMHAGateRef { } // qk_out = BatchedGEMM(Q, K^T) - // [batch_size, seq_len_m, num_heads, seq_len_r, key_dim] * - // [batch_size, seq_len_m, num_heads, m_size, key_dim] + // [batch_size, seq_len_m, num_heads, seq_len_r, head_dim] * + // [batch_size, seq_len_m, num_heads, m_size, head_dim] // -> [batch_size, seq_len_m, num_heads, seq_len_r, m_size] Tensor* qk_out = config->GetQKOut(softmax_out); T* qk_out_ptr = qk_out->data(); @@ -375,9 +375,9 @@ class FMHAGateRef { config->batch_size * config->seq_len_m * config->num_heads; int64_t gemm_m = config->seq_len_r; int64_t gemm_n = config->m_size; - int64_t gemm_k = config->key_dim; + int64_t gemm_k = config->head_dim; - T alpha = static_cast(1.0 / sqrt(config->key_dim)); + T alpha = static_cast(1.0 / sqrt(config->head_dim)); ComputeBatchedGEMM(q_ptr, k_ptr, qk_out_ptr, false, true, gemm_m, gemm_n, gemm_k, gemm_batch_size, alpha); @@ -388,13 +388,13 @@ class FMHAGateRef { // qktv_out = BatchedGEMM(softmax_out, V) // [batch_size, seq_len_m, num_heads, seq_len_r, m_size] * - // [batch_size, seq_len_m, num_heads, m_size, key_dim] - // -> [batch_size, seq_len_m, num_heads, seq_len_r, key_dim] + // [batch_size, seq_len_m, num_heads, m_size, head_dim] + // -> [batch_size, seq_len_m, num_heads, seq_len_r, head_dim] Tensor* qktv_out = config->GetQKTVOut(gate_out); T* qktv_out_ptr = qktv_out->data(); gemm_m = config->seq_len_r; - gemm_n = config->key_dim; + gemm_n = config->head_dim; gemm_k = config->m_size; T* softmax_out_ptr = softmax_out->data(); @@ -490,7 +490,7 @@ class FMHAGateRef { // Backward: // V_grad = BatchedGEMM(softmax_out^T, qktv_out_grad) (dy = x^T * dout) int64_t gemm_m = config->m_size; - int64_t gemm_n = config->key_dim; + int64_t gemm_n = config->head_dim; int64_t gemm_k = config->seq_len_r; const T* softmax_out_ptr = softmax_out->data(); @@ -501,7 +501,7 @@ class FMHAGateRef { // Backward: softmax_out_grad = qktv_out_grad * V^T (dx = dout * y^T) gemm_m = config->seq_len_r; gemm_n = config->m_size; - gemm_k = config->key_dim; + gemm_k = config->head_dim; T* softmax_out_grad_ptr = softmax_out_grad.data(); 
ComputeBatchedGEMM(qktv_out_grad_ptr, v_ptr, softmax_out_grad_ptr, false, @@ -516,9 +516,9 @@ class FMHAGateRef { // Forward: qk_out = BatchedGEMM(Q, K^T) // Backward: k_grad = BatchedGEMM(qk_out_grad^T, Q) (dy = dout^t * x) int64_t gemm_m = config->m_size; - int64_t gemm_n = config->key_dim; + int64_t gemm_n = config->head_dim; int64_t gemm_k = config->seq_len_r; - T alpha = static_cast(1.0 / sqrt(config->key_dim)); + T alpha = static_cast(1.0 / sqrt(config->head_dim)); T* qk_out_grad_ptr = qk_out_grad->data(); ComputeBatchedGEMM(qk_out_grad_ptr, q_ptr, k_grad_ptr, true, false, gemm_m, @@ -526,7 +526,7 @@ class FMHAGateRef { // Backward: q_grad = BatchedGEMM(qk_out_grad, K) (dx = dout * y) gemm_m = config->seq_len_r; - gemm_n = config->key_dim; + gemm_n = config->head_dim; gemm_k = config->m_size; ComputeBatchedGEMM(qk_out_grad_ptr, k_ptr, q_grad_ptr, false, false, gemm_m, gemm_n, gemm_k, gemm_batch_size, alpha); @@ -570,8 +570,8 @@ class FMHAGateRef { v_out_grad); } - // [batch_size, seq_len_m, seq_len_r, 3, num_heads, key_dim] -> - // [3, batch_size, seq_len_m, num_heads, seq_len_r, key_dim] + // [batch_size, seq_len_m, seq_len_r, 3, num_heads, head_dim] -> + // [3, batch_size, seq_len_m, num_heads, seq_len_r, head_dim] void ComputeQKVTransposeForward(const Tensor& qkv_out, Tensor* qkv_transpose_out) { int ndims = 6; @@ -610,7 +610,7 @@ class FMHAGateRef { const Tensor* src_mask, Tensor* qk_out, Tensor* softmax_out) { if (nonbatched_bias) { - std::vector ins = {qk_out, nonbatched_bias, src_mask}; + std::vector ins = {qk_out, src_mask, nonbatched_bias}; std::vector outs = {qk_out}; phi::funcs::BroadcastKernel( dev_ctx_, ins, &outs, -1, TernaryAddFunctor()); diff --git a/paddle/fluid/operators/fused/fused_gate_attention_op.cc b/paddle/fluid/operators/fused/fused_gate_attention_op.cc index 506f437b1a..e814601785 100644 --- a/paddle/fluid/operators/fused/fused_gate_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_gate_attention_op.cc @@ -47,10 +47,10 @@ class FusedGateAttentionOp : public framework::OperatorWithKernel { int seq_len_m = input_q_dims[1]; int seq_len_r = input_q_dims[2]; - int num_head, m_size, key_dim; + int num_head, m_size, head_dim; if (ctx->Attrs().Get("merge_qkv")) { // QKV's input: [batch_size, seq_len_m, seq_len_r, qkv_dim] - // QKV's weight: [3, num_head, key_dim, qkv_dim] + // QKV's weight: [3, num_head, head_dim, qkv_dim] OP_INOUT_CHECK(ctx->HasInput("QKVWeight"), "Input", "QKVWeight", "fused_gate_attention"); OP_INOUT_CHECK(ctx->HasOutput("QKVTransposeOut"), "Output", @@ -59,11 +59,11 @@ class FusedGateAttentionOp : public framework::OperatorWithKernel { auto qkv_w_dims = ctx->GetInputDim("QKVWeight"); num_head = qkv_w_dims[1]; - key_dim = qkv_w_dims[2]; + head_dim = qkv_w_dims[2]; m_size = seq_len_r; ctx->SetOutputDim("QKVTransposeOut", {3, batch_size, seq_len_m, num_head, - seq_len_r, key_dim}); + seq_len_r, head_dim}); } else { OP_INOUT_CHECK(ctx->HasInput("QueryWeight"), "Input", "QueryWeight", "fused_gate_attention"); @@ -76,21 +76,21 @@ class FusedGateAttentionOp : public framework::OperatorWithKernel { auto q_w_dims = ctx->GetInputDim("QueryWeight"); num_head = q_w_dims[1]; - key_dim = q_w_dims[2]; + head_dim = q_w_dims[2]; m_size = input_k_dims[2]; ctx->SetOutputDim("QueryTransposeOut", - {batch_size, seq_len_m, num_head, seq_len_r, key_dim}); + {batch_size, seq_len_m, num_head, seq_len_r, head_dim}); ctx->SetOutputDim("KeyTransposeOut", - {batch_size, seq_len_m, num_head, m_size, key_dim}); + {batch_size, seq_len_m, num_head, m_size, head_dim}); 
ctx->SetOutputDim("ValueTransposeOut", - {batch_size, seq_len_m, num_head, m_size, key_dim}); + {batch_size, seq_len_m, num_head, m_size, head_dim}); } ctx->SetOutputDim("SoftmaxOut", {batch_size, seq_len_m, num_head, seq_len_r, m_size}); ctx->SetOutputDim("FMHAOut", - {batch_size, seq_len_m, seq_len_r, num_head, key_dim}); + {batch_size, seq_len_m, seq_len_r, num_head, head_dim}); if (ctx->Attrs().Get("has_gating")) { OP_INOUT_CHECK(ctx->HasInput("GateWeight"), "Input", "GateWeight", @@ -98,7 +98,7 @@ class FusedGateAttentionOp : public framework::OperatorWithKernel { OP_INOUT_CHECK(ctx->HasInput("GateBias"), "Input", "GateBias", "fused_gate_attention"); ctx->SetOutputDim("GateOut", - {batch_size, seq_len_m, seq_len_r, num_head, key_dim}); + {batch_size, seq_len_m, seq_len_r, num_head, head_dim}); } ctx->SetOutputDim("Out", ctx->GetInputDim("Query")); diff --git a/paddle/fluid/operators/fused/fused_gate_attention_op.cu b/paddle/fluid/operators/fused/fused_gate_attention_op.cu index ebc9a4f98d..c0f4d158aa 100644 --- a/paddle/fluid/operators/fused/fused_gate_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_gate_attention_op.cu @@ -65,13 +65,13 @@ void ComputeMergedQKVMatmulForward(const framework::ExecutionContext &ctx, const GateAttentionConfig &config, const Tensor *query, Tensor *qkv_out) { // query: shape=[batch_size, seq_len_m, seq_len_r, qkv_dim] - // qkv_weight: shape=[3, num_heads, key_dim, qkv_dim] - // qkv_out: shape=[batch_size, seq_len_m, seq_len_r, 3, num_heads, key_dim] + // qkv_weight: shape=[3, num_heads, head_dim, qkv_dim] + // qkv_out: shape=[batch_size, seq_len_m, seq_len_r, 3, num_heads, head_dim] auto *qkv_weight = ctx.Input("QKVWeight"); // qkv_out = GEMM(query, qkv_weight^T) int m = config.batch_size * config.seq_len_m * config.seq_len_r; - int n = 3 * config.num_heads * config.key_dim; + int n = 3 * config.num_heads * config.head_dim; int k = config.q_dim; auto qkv_compute = AttnMatMul(ctx.cuda_device_context(), false, true, m, n, k, false); @@ -91,7 +91,7 @@ void ComputeMergedQKVMatmulBackward(const framework::ExecutionContext &ctx, // Gradient of GEMM(query, qkv_weight) int m = config.batch_size * config.seq_len_m * config.seq_len_r; - int n = 3 * config.num_heads * config.key_dim; + int n = 3 * config.num_heads * config.head_dim; int k = config.q_dim; auto qkv_compute = AttnMatMul(ctx.cuda_device_context(), false, true, m, n, k, false); @@ -111,10 +111,10 @@ void ComputeSeparatedQKVMatmulForward(const framework::ExecutionContext &ctx, // query_out = GEMM(query, query_weight) // query: shape=[batch_size, seq_len_m, seq_len_r, q_dim] - // query_weight: shape=[q_dim, num_heads, key_dim] - // query_out: shape=[batch_size, seq_len_m, seq_len_r, num_heads, key_dim] + // query_weight: shape=[q_dim, num_heads, head_dim] + // query_out: shape=[batch_size, seq_len_m, seq_len_r, num_heads, head_dim] int q_m = config.batch_size * config.seq_len_m * config.seq_len_r; - int q_n = config.num_heads * config.key_dim; + int q_n = config.num_heads * config.head_dim; int q_k = config.q_dim; auto q_compute = AttnMatMul(ctx.cuda_device_context(), false, false, q_m, q_n, q_k, false); @@ -122,10 +122,10 @@ void ComputeSeparatedQKVMatmulForward(const framework::ExecutionContext &ctx, // k_out = GEMM(key, key_weight) // key: shape=[batch_size, seq_len_m, m_size, kv_dim] - // key_weight: shape=[kv_dim, num_heads, key_dim] - // key_out: shape=[batch_size, seq_len_m, m_size, num_heads, key_dim] + // key_weight: shape=[kv_dim, num_heads, head_dim] + // key_out: shape=[batch_size, 
seq_len_m, m_size, num_heads, head_dim] int kv_m = config.batch_size * config.seq_len_m * config.m_size; - int kv_n = config.num_heads * config.key_dim; + int kv_n = config.num_heads * config.head_dim; int kv_k = config.kv_dim; auto kv_compute = AttnMatMul(ctx.cuda_device_context(), false, false, kv_m, kv_n, kv_k, false); @@ -151,7 +151,7 @@ void ComputeSeparatedQKVMatmulBackward(const framework::ExecutionContext &ctx, key_weight_grad->mutable_data(ctx.GetPlace()); int kv_m = config.batch_size * config.seq_len_m * config.m_size; - int kv_n = config.num_heads * config.key_dim; + int kv_n = config.num_heads * config.head_dim; int kv_k = config.kv_dim; auto kv_compute = AttnMatMul(ctx.cuda_device_context(), false, false, kv_m, kv_n, kv_k, false); @@ -174,7 +174,7 @@ void ComputeSeparatedQKVMatmulBackward(const framework::ExecutionContext &ctx, query_weight_grad->mutable_data(ctx.GetPlace()); int q_m = config.batch_size * config.seq_len_m * config.seq_len_r; - int q_n = config.num_heads * config.key_dim; + int q_n = config.num_heads * config.head_dim; int q_k = config.q_dim; auto q_compute = AttnMatMul(ctx.cuda_device_context(), false, false, q_m, q_n, q_k, false); @@ -195,7 +195,7 @@ void ComputeGatingLinearForward(const framework::ExecutionContext &ctx, // bias. // gate_out = GEMM(query, gate_weight) + gate_bias int m = config.batch_size * config.seq_len_m * config.seq_len_r; - int n = config.num_heads * config.key_dim; + int n = config.num_heads * config.head_dim; int k = config.q_dim; auto gate_attn_compute = AttnMatMul(ctx.cuda_device_context(), false, false, m, n, k, true); @@ -224,7 +224,7 @@ void ComputeGatingLinearBackward(const framework::ExecutionContext &ctx, gate_bias_out.mutable_data(ctx.GetPlace()); int m = config.batch_size * config.seq_len_m * config.seq_len_r; - int n = config.num_heads * config.key_dim; + int n = config.num_heads * config.head_dim; int k = config.q_dim; auto gate_attn_compute = AttnMatMul(ctx.cuda_device_context(), false, false, m, n, k, true); @@ -260,7 +260,7 @@ void ComputeOutputLinearForward(const framework::ExecutionContext &ctx, // out = GEMM(fmha_or_gate_out, out_linear_weight) + out_linear_bias int m = config.batch_size * config.seq_len_m * config.seq_len_r; int n = config.q_dim; - int k = config.num_heads * config.key_dim; + int k = config.num_heads * config.head_dim; auto out_linear_compute = AttnMatMul(ctx.cuda_device_context(), false, false, m, n, k, true); out_linear_compute.ComputeForward(out_linear_weight, fmha_or_gate_out, @@ -282,11 +282,9 @@ void ComputeOutputLinearBackward(const framework::ExecutionContext &ctx, out_linear_weight_grad->mutable_data(ctx.GetPlace()); out_linear_bias_grad->mutable_data(ctx.GetPlace()); - auto &dev_ctx = ctx.template device_context(); - int m = config.batch_size * config.seq_len_m * config.seq_len_r; int n = config.q_dim; - int k = config.num_heads * config.key_dim; + int k = config.num_heads * config.head_dim; auto out_linear_compute = AttnMatMul(ctx.cuda_device_context(), false, false, m, n, k, true); out_linear_compute.ComputeBackward(input, out_linear_weight, out_grad, diff --git a/python/paddle/fluid/tests/unittests/test_fused_gate_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_gate_attention_op.py index 52418bba63..0aad7ec758 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_gate_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_gate_attention_op.py @@ -12,6 +12,11 @@ # See the License for the specific language governing permissions and # limitations under 
the License. +import os + +os.environ['NVIDIA_TF32_OVERRIDE'] = "0" +os.environ['FLAGS_new_einsum'] = "0" + import numpy as np import paddle @@ -47,7 +52,7 @@ class TestFusedGateAttentionOp(OpTest): self.res_len = 5 self.q_dim = 6 self.num_heads = 2 - self.key_dim = 4 + self.head_dim = 4 self.m_size = self.res_len self.kv_dim = self.q_dim self.out_dim = self.q_dim @@ -65,12 +70,12 @@ class TestFusedGateAttentionOp(OpTest): np.random.seed(123) self.query = _random( (self.batch_size, self.msa_len, self.res_len, self.q_dim)) - self.q_weight = _random((self.q_dim, self.num_heads, self.key_dim)) - self.k_weight = _random((self.kv_dim, self.num_heads, self.key_dim)) - self.v_weight = _random((self.kv_dim, self.num_heads, self.key_dim)) + self.q_weight = _random((self.q_dim, self.num_heads, self.head_dim)) + self.k_weight = _random((self.kv_dim, self.num_heads, self.head_dim)) + self.v_weight = _random((self.kv_dim, self.num_heads, self.head_dim)) if self.merge_qkv: self.key = None - # (3, self.num_heads, self.key_dim, self.q_dim) + # (3, self.num_heads, self.head_dim, self.q_dim) q_weight_t = np.transpose(self.q_weight, axes=[1, 2, 0]) k_weight_t = np.transpose(self.k_weight, axes=[1, 2, 0]) v_weight_t = np.transpose(self.v_weight, axes=[1, 2, 0]) @@ -88,15 +93,22 @@ class TestFusedGateAttentionOp(OpTest): (self.batch_size, 1, self.num_heads, self.res_len, self.m_size)) if self.has_gating: - self.gating_w = _random((self.q_dim, self.num_heads, self.key_dim)) - self.gating_b = _random((self.num_heads, self.key_dim)) + self.gating_w = _random((self.q_dim, self.num_heads, self.head_dim)) + self.gating_b = _random((self.num_heads, self.head_dim)) - self.output_w = _random((self.num_heads, self.key_dim, self.out_dim)) + self.output_w = _random((self.num_heads, self.head_dim, self.out_dim)) self.output_b = _random((self.out_dim)) self.dout = _random( (self.batch_size, self.msa_len, self.res_len, self.q_dim)) + def collect_outputs(self, query, key, softmax_out, fmha_out, gate_out, out): + outputs = [ + softmax_out, fmha_out, gate_out if self.has_gating else None, out, + query.grad, None if self.merge_qkv else key.grad + ] + return outputs + def get_reference_out(self): paddle.disable_static(place=paddle.CUDAPlace(0)) @@ -108,44 +120,85 @@ class TestFusedGateAttentionOp(OpTest): v_weight = paddle.to_tensor(self.v_weight, stop_gradient=False) src_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True) - c = self.key_dim**(-0.5) - # [batch_size, msa_len, num_heads, res_len, key_dim] + c = self.head_dim**(-0.5) + # [batch_size, msa_len, res_len, q_dim], [q_dim, num_heads, head_dim] + # -> [batch_size, msa_len, res_len, num_heads, head_dim] q = paddle.einsum('nbqa,ahc->nbqhc', query, q_weight) * c - # [batch_size, msa_len, num_heads, m_size, key_dim] + # [batch_size, msa_len, m_size, kv_dim], [kv_dim, num_heads, head_dim] + # -> [batch_size, msa_len, m_size, num_heads, head_dim] k = paddle.einsum('nbka,ahc->nbkhc', key, k_weight) - # [batch_size, msa_len, num_heads, m_size, key_dim] + # [batch_size, msa_len, m_size, kv_dim], [kv_dim, num_heads, head_dim] + # -> [batch_size, msa_len, m_size, num_heads, head_dim] v = paddle.einsum('nbka,ahc->nbkhc', key, v_weight) - # [batch_size, msa_len, num_heads, res_len, m_size] + # [batch_size, msa_len, res_len, num_heads, head_dim], [batch_size, msa_len, m_size, num_heads, head_dim] + # -> [batch_size, msa_len, num_heads, res_len, m_size] logits = paddle.einsum('nbqhc,nbkhc->nbhqk', q, k) # qk_out + # [batch_size, msa_len, num_heads, res_len, m_size], [batch_size, 
mas_len, 1, 1, m_size] + # -> [batch_size, msa_len, num_heads, res_len, m_size] logits = logits + src_mask if self.bias_attr: nonbatched_bias = paddle.to_tensor(self.nonbatched_bias, stop_gradient=False) + # [batch_size, msa_len, num_heads, res_len, m_size], [batch_size, 1, num_heads, res_len, m_size] + # -> [batch_size, msa_len, num_heads, res_len, m_size] logits = logits + nonbatched_bias - weights = nn.functional.softmax(logits) # softmax_out - weighted_avg = paddle.einsum('nbhqk,nbkhc->nbqhc', weights, v) + # [batch_size, msa_len, num_heads, res_len, m_size] + softmax_out = nn.functional.softmax(logits) + # [batch_size, msa_len, num_heads, res_len, m_size], [batch_size, msa_len, m_size, num_heads, head_dim] + # -> [batch_size, msa_len, res_len, num_heads, head_dim] + # fmha_out = paddle.einsum('nbhqk,nbkhc->nbqhc', softmax_out, v) + v_trans = paddle.transpose(v, perm=[0, 1, 3, 2, 4]) + qktv_out = paddle.matmul(softmax_out, v_trans) + fmha_out = paddle.transpose(qktv_out, perm=[0, 1, 3, 2, 4]) if self.has_gating: gating_w = paddle.to_tensor(self.gating_w, stop_gradient=False) gating_b = paddle.to_tensor(self.gating_b, stop_gradient=False) - gate_values = paddle.einsum('nbqc,chv->nbqhv', query, - gating_w) + gating_b + # [batch_size, msa_len, res_len, q_dim], [q_dim, num_heads, head_dim] + # -> [batch_size, msa_len, res_len, num_heads, head_dim] + # gate_values = paddle.einsum('nbqc,chv->nbqhv', query, + # gating_w) + gating_b + gating_w_2d = paddle.reshape( + gating_w, shape=[self.q_dim, self.num_heads * self.head_dim]) + gate_values_4d = paddle.matmul(query, gating_w_2d) + gate_values = paddle.reshape( + gate_values_4d, + shape=[ + self.batch_size, self.msa_len, self.res_len, self.num_heads, + self.head_dim + ]) + gating_b gate_values = nn.functional.sigmoid(gate_values) - weighted_avg = weighted_avg * gate_values + gate_out = fmha_out * gate_values + else: + gate_out = fmha_out output_b = paddle.to_tensor(self.output_b, stop_gradient=False) output_w = paddle.to_tensor(self.output_w, stop_gradient=False) - out = paddle.einsum('nbqhc,hco->nbqo', weighted_avg, - output_w) + output_b + # [batch_size, msa_len, res_len, num_heads, head_dim], [num_heads, head_dim, out_dim] + # -> [batch_size, msa_len, res_len, out_dim] + # out = paddle.einsum('nbqhc,hco->nbqo', gate_out, + # output_w) + output_b + gate_out_2d = paddle.reshape( + gate_out, + shape=[ + self.batch_size * self.msa_len * self.res_len, + self.num_heads * self.head_dim + ]) + output_w_2d = paddle.reshape( + output_w, shape=[self.num_heads * self.head_dim, self.out_dim]) + out_2d = paddle.matmul(gate_out_2d, output_w_2d) + out = paddle.reshape( + out_2d, + shape=[self.batch_size, self.msa_len, self.res_len, self.out_dim + ]) + output_b + paddle.autograd.backward([out], [paddle.to_tensor(self.dout)], retain_graph=True) - if self.merge_qkv: - return out, query.grad, None - else: - return out, query.grad, key.grad + return self.collect_outputs(query, key, softmax_out, fmha_out, gate_out, + out) def get_fused_gate_attention_out(self): paddle.disable_static(place=paddle.CUDAPlace(0)) @@ -181,40 +234,59 @@ class TestFusedGateAttentionOp(OpTest): output_w = paddle.to_tensor(self.output_w, stop_gradient=False) output_b = paddle.to_tensor(self.output_b, stop_gradient=False) - _, _, _, _, _, _, _, out = _C_ops.fused_gate_attention( + _, _, _, _, softmax_out, fmha_out, gate_out, out = _C_ops.fused_gate_attention( query, key, q_weight, k_weight, v_weight, qkv_weight, nonbatched_bias, src_mask, gating_w, gating_b, output_w, output_b, 
'has_gating', self.has_gating, 'merge_qkv', self.merge_qkv) paddle.autograd.backward([out], [paddle.to_tensor(self.dout)], retain_graph=True) - if key is not None: - return out, query.grad, key.grad - else: - return out, query.grad, None + return self.collect_outputs(query, key, softmax_out, fmha_out, gate_out, + out) - def check_output_and_grad(self, atol, rtol): + def check(self, ref, out, atol, rtol, check_equal, name): def _convert(value): if self.dtype == "bfloat16": return convert_uint16_to_float(value) return value - output_names = ["out", "query_grad", "key_grad"] + if check_equal: + self.assertTrue( + np.equal(_convert(ref), _convert(out)).all(), + "Checking < {} > failed!".format(name)) + else: + np.testing.assert_allclose( + _convert(ref), + _convert(out), + atol=atol, + rtol=rtol, + err_msg="Checking < {} > failed!".format(name)) + + def check_output_and_grad(self, atol, rtol): + output_names = [ + "softmax_out", "fmha_out", "gate_out", "out", "query_grad", + "key_grad" + ] outputs_ref = self.get_reference_out() outputs_fused = self.get_fused_gate_attention_out() - for i in range(len(outputs_fused)): + for i in range(len(output_names)): ref_res = outputs_ref[i] fused_res = outputs_fused[i] if ref_res is not None and fused_res is not None: - print("Checking {}".format(output_names[i])) - np.testing.assert_allclose(_convert(ref_res), - _convert(fused_res.numpy()), - atol=atol, - rtol=rtol) + # The python implementation of einsum is likely to call + # matmul(x, y, transpose_x=False, transpose_y=True). With different + # transpose_x and transpose_y, cublas will launch different kernels + # and the result cannot be exactly equal. + # Because the arguments of matmul in einsum is the the same as + # that in fused ops, check_equal is set to False and we use allclose + # to check the correctness. + check_equal = False + self.check(ref_res.numpy(), fused_res.numpy(), atol, rtol, + check_equal, output_names[i]) def test_output_and_grad(self): - self.check_output_and_grad(atol=1e-5, rtol=1e-5) + self.check_output_and_grad(atol=1e-5, rtol=1e-6) class TestMergeQKVLargeBatchSizeCase(TestFusedGateAttentionOp): @@ -234,7 +306,7 @@ class TestSeparatedQKVCase(TestFusedGateAttentionOp): self.res_len = 5 self.q_dim = 6 self.num_heads = 2 - self.key_dim = 4 + self.head_dim = 4 self.m_size = 4 self.kv_dim = 2 self.out_dim = self.q_dim @@ -279,7 +351,7 @@ class TestMergeQKVBF16Case(TestFusedGateAttentionOp): self.dtype = "bfloat16" def test_output_and_grad(self): - self.check_output_and_grad(atol=1e-1, rtol=1e-3) + self.check_output_and_grad(atol=1e-1, rtol=1e-2) class TestMergeQKVLargeBatchSizeBF16Case(TestMergeQKVBF16Case): diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 1cb350f4d7..0b5cc5bf64 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -3263,9 +3263,7 @@ def prod(x, axis=None, keepdim=False, dtype=None, name=None): if x.dtype != convert_np_dtype_to_dtype_(dtype): x = cast(x, dtype) - input = x dim = axis - keep_dim = keepdim if dim is not None and not isinstance(dim, list): if isinstance(dim, tuple): dim = list(dim) @@ -3275,24 +3273,29 @@ def prod(x, axis=None, keepdim=False, dtype=None, name=None): raise TypeError( "The type of axis must be int, list or tuple, but received {}". 
format(type(dim))) + + reduce_all = True if dim is None or len(dim) == 0 or len(dim) == len(x.shape) else False + if dim is None or len(dim) == 0: + dim = [0] + if in_dygraph_mode(): - return _C_ops.final_state_reduce_prod( - input, dim if dim != None and dim != [] else [0], keep_dim, True if - dim == None or dim == [] or len(dim) == len(input.shape) else False) + return _C_ops.final_state_reduce_prod(x, dim, keepdim, reduce_all) + if _in_legacy_dygraph(): + return _C_ops.reduce_prod( + x, 'dim', dim, 'keep_dim', keepdim, 'reduce_all', reduce_all) helper = LayerHelper('reduce_prod', **locals()) check_variable_and_dtype( - input, 'input', ['float32', 'float64', 'int32', 'int64'], 'reduce_prod') + x, 'x/input', ['float32', 'float64', 'int32', 'int64'], 'reduce_prod') out = helper.create_variable_for_type_inference(dtype=helper.input_dtype()) helper.append_op( type='reduce_prod', - inputs={'X': input}, + inputs={'X': x}, outputs={'Out': out}, attrs={ - 'dim': dim if dim != None and dim != [] else [0], - 'keep_dim': keep_dim, - 'reduce_all': True if dim == None or dim == [] or - len(dim) == len(input.shape) else False + 'dim': dim, + 'keep_dim': keepdim, + 'reduce_all': reduce_all }) return out -- GitLab
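
Usage sketch (not part of the patch): a small, hedged illustration of the `paddle.prod` dispatch path that this commit reworks, where `axis` is normalized into a `dim` list plus a `reduce_all` flag before calling the dygraph op. It assumes a PaddlePaddle build that includes this change and a working dygraph environment; the example tensor and the printed values are illustrative only.

import paddle

# 2 x 3 example tensor used only for illustration.
x = paddle.to_tensor([[1.0, 2.0, 3.0],
                      [4.0, 5.0, 6.0]])

# axis=None: prod() maps this to dim=[0] with reduce_all=True,
# so every element is multiplied together: 1*2*3*4*5*6 = 720.
print(paddle.prod(x))

# axis=1 with keepdim=True: dim becomes [1] and reduce_all stays False;
# the result has shape [2, 1] with values [[6.], [120.]].
print(paddle.prod(x, axis=1, keepdim=True))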