Unverified commit 9b7126d0, authored by Yiqun Liu, committed by GitHub

Optimize prod's python implementation for dygraph. (#43309)

* Optimize prod's python implementation for dygraph.

* Change key_dim to head_dim.

* Add comment in unittest.

* Disable TF32 in unittest.
Parent 4d0ca02b
@@ -68,7 +68,7 @@ struct GateAttentionConfig {
int64_t seq_len_r;
int64_t q_dim;
int64_t kv_dim;
int64_t key_dim;
int64_t head_dim;
int64_t m_size;
int64_t num_heads;
@@ -103,15 +103,15 @@ struct GateAttentionConfig {
"when merge_qkv is true."));
// When q_dim == kv_dim, QKV matmul can be computed merged.
// qkv_weight: shape=[3, num_heads, key_dim, q_dim]
// qkv_weight: shape=[3, num_heads, head_dim, q_dim]
num_heads = qkv_weight->dims()[1];
key_dim = qkv_weight->dims()[2];
head_dim = qkv_weight->dims()[2];
m_size = seq_len_r;
kv_dim = q_dim;
qkv_out_dims = {batch_size, seq_len_m, seq_len_r, 3, num_heads, key_dim};
qkv_out_dims = {batch_size, seq_len_m, seq_len_r, 3, num_heads, head_dim};
qkv_transpose_out_dims = {3, batch_size, seq_len_m,
num_heads, seq_len_r, key_dim};
num_heads, seq_len_r, head_dim};
} else {
PADDLE_ENFORCE_NOT_NULL(
key,
@@ -124,28 +124,28 @@ struct GateAttentionConfig {
// When q_dim != kv_dim, QKV matmul must be computed separately.
// key: shape=[batch_size, seq_len_m, m_size, kv_dim]
// query_w: shape=[q_dim, num_heads, key_dim]
// query_w: shape=[q_dim, num_heads, head_dim]
num_heads = query_weight->dims()[1];
key_dim = query_weight->dims()[2];
head_dim = query_weight->dims()[2];
m_size = key->dims()[2];
kv_dim = key->dims()[3];
q_out_dims = {batch_size, seq_len_m, seq_len_r, num_heads, key_dim};
kv_out_dims = {batch_size, seq_len_m, m_size, num_heads, key_dim};
q_out_dims = {batch_size, seq_len_m, seq_len_r, num_heads, head_dim};
kv_out_dims = {batch_size, seq_len_m, m_size, num_heads, head_dim};
q_transpose_out_dims = {batch_size, seq_len_m, num_heads, seq_len_r,
key_dim};
head_dim};
kv_transpose_out_dims = {batch_size, seq_len_m, num_heads, m_size,
key_dim};
head_dim};
}
qk_out_dims = {batch_size, seq_len_m, num_heads, seq_len_r, m_size};
softmax_out_dims = {batch_size, seq_len_m, num_heads, seq_len_r, m_size};
qktv_out_dims = {batch_size, seq_len_m, num_heads, seq_len_r, key_dim};
gate_out_dims = {batch_size, seq_len_m, seq_len_r, num_heads, key_dim};
qktv_out_dims = {batch_size, seq_len_m, num_heads, seq_len_r, head_dim};
gate_out_dims = {batch_size, seq_len_m, seq_len_r, num_heads, head_dim};
}
int64_t GetQuerySize() const {
return batch_size * seq_len_m * seq_len_r * num_heads * key_dim;
return batch_size * seq_len_m * seq_len_r * num_heads * head_dim;
}
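For orientation, a minimal Python sketch of the dimension bookkeeping above for the merge_qkv=true path; the sizes are made up for illustration and are not taken from this diff:

```python
# Illustrative only: mirrors the dimension bookkeeping above with made-up sizes
# (merge_qkv=true path, so m_size == seq_len_r).
batch_size, seq_len_m, seq_len_r, num_heads, head_dim = 2, 3, 5, 2, 4
m_size = seq_len_r

qkv_out_dims = [batch_size, seq_len_m, seq_len_r, 3, num_heads, head_dim]
qkv_transpose_out_dims = [3, batch_size, seq_len_m, num_heads, seq_len_r, head_dim]

# GetQuerySize(): element count of one transposed Q/K/V slice.
query_size = batch_size * seq_len_m * seq_len_r * num_heads * head_dim
print(query_size)  # 2 * 3 * 5 * 2 * 4 = 240
```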
Tensor* GetQKVOut() {
@@ -365,8 +365,8 @@ class FMHAGateRef {
}
// qk_out = BatchedGEMM(Q, K^T)
// [batch_size, seq_len_m, num_heads, seq_len_r, key_dim] *
// [batch_size, seq_len_m, num_heads, m_size, key_dim]
// [batch_size, seq_len_m, num_heads, seq_len_r, head_dim] *
// [batch_size, seq_len_m, num_heads, m_size, head_dim]
// -> [batch_size, seq_len_m, num_heads, seq_len_r, m_size]
Tensor* qk_out = config->GetQKOut(softmax_out);
T* qk_out_ptr = qk_out->data<T>();
@@ -375,9 +375,9 @@ class FMHAGateRef {
config->batch_size * config->seq_len_m * config->num_heads;
int64_t gemm_m = config->seq_len_r;
int64_t gemm_n = config->m_size;
int64_t gemm_k = config->key_dim;
int64_t gemm_k = config->head_dim;
T alpha = static_cast<T>(1.0 / sqrt(config->key_dim));
T alpha = static_cast<T>(1.0 / sqrt(config->head_dim));
ComputeBatchedGEMM(q_ptr, k_ptr, qk_out_ptr, false, true, gemm_m, gemm_n,
gemm_k, gemm_batch_size, alpha);
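The scaled Q*K^T batched matmul above can be sketched in Python as follows; the shapes are hypothetical and this only illustrates the computation, not the fused kernel itself:

```python
# Illustration of qk_out = alpha * BatchedGEMM(Q, K^T) with alpha = 1/sqrt(head_dim).
# Shapes are hypothetical; paddle.matmul batches over the leading dimensions.
import math
import paddle

batch_size, seq_len_m, num_heads, seq_len_r, m_size, head_dim = 2, 3, 2, 5, 5, 4
q = paddle.rand([batch_size, seq_len_m, num_heads, seq_len_r, head_dim])
k = paddle.rand([batch_size, seq_len_m, num_heads, m_size, head_dim])

alpha = 1.0 / math.sqrt(head_dim)
# transpose_y=True corresponds to trans_b=true in ComputeBatchedGEMM above.
qk_out = alpha * paddle.matmul(q, k, transpose_y=True)
print(qk_out.shape)  # [2, 3, 2, 5, 5]
```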
@@ -388,13 +388,13 @@ class FMHAGateRef {
// qktv_out = BatchedGEMM(softmax_out, V)
// [batch_size, seq_len_m, num_heads, seq_len_r, m_size] *
// [batch_size, seq_len_m, num_heads, m_size, key_dim]
// -> [batch_size, seq_len_m, num_heads, seq_len_r, key_dim]
// [batch_size, seq_len_m, num_heads, m_size, head_dim]
// -> [batch_size, seq_len_m, num_heads, seq_len_r, head_dim]
Tensor* qktv_out = config->GetQKTVOut(gate_out);
T* qktv_out_ptr = qktv_out->data<T>();
gemm_m = config->seq_len_r;
gemm_n = config->key_dim;
gemm_n = config->head_dim;
gemm_k = config->m_size;
T* softmax_out_ptr = softmax_out->data<T>();
@@ -490,7 +490,7 @@ class FMHAGateRef {
// Backward:
// V_grad = BatchedGEMM(softmax_out^T, qktv_out_grad) (dy = x^T * dout)
int64_t gemm_m = config->m_size;
int64_t gemm_n = config->key_dim;
int64_t gemm_n = config->head_dim;
int64_t gemm_k = config->seq_len_r;
const T* softmax_out_ptr = softmax_out->data<T>();
@@ -501,7 +501,7 @@ class FMHAGateRef {
// Backward: softmax_out_grad = qktv_out_grad * V^T (dx = dout * y^T)
gemm_m = config->seq_len_r;
gemm_n = config->m_size;
gemm_k = config->key_dim;
gemm_k = config->head_dim;
T* softmax_out_grad_ptr = softmax_out_grad.data<T>();
ComputeBatchedGEMM(qktv_out_grad_ptr, v_ptr, softmax_out_grad_ptr, false,
@@ -516,9 +516,9 @@ class FMHAGateRef {
// Forward: qk_out = BatchedGEMM(Q, K^T)
// Backward: k_grad = BatchedGEMM(qk_out_grad^T, Q) (dy = dout^t * x)
int64_t gemm_m = config->m_size;
int64_t gemm_n = config->key_dim;
int64_t gemm_n = config->head_dim;
int64_t gemm_k = config->seq_len_r;
T alpha = static_cast<T>(1.0 / sqrt(config->key_dim));
T alpha = static_cast<T>(1.0 / sqrt(config->head_dim));
T* qk_out_grad_ptr = qk_out_grad->data<T>();
ComputeBatchedGEMM(qk_out_grad_ptr, q_ptr, k_grad_ptr, true, false, gemm_m,
@@ -526,7 +526,7 @@ class FMHAGateRef {
// Backward: q_grad = BatchedGEMM(qk_out_grad, K) (dx = dout * y)
gemm_m = config->seq_len_r;
gemm_n = config->key_dim;
gemm_n = config->head_dim;
gemm_k = config->m_size;
ComputeBatchedGEMM(qk_out_grad_ptr, k_ptr, q_grad_ptr, false, false, gemm_m,
gemm_n, gemm_k, gemm_batch_size, alpha);
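The dx/dy expressions in the comments above are the standard matmul gradients. A small Paddle autograd sketch (hypothetical 2-D shapes, alpha omitted, not the fused kernel path) that checks them numerically:

```python
# Numerically checks the gradient formulas from the comments above for
# C = X * Y^T: dX = dOut * Y and dY = dOut^T * X.
import paddle

x = paddle.rand([5, 4])  # e.g. [seq_len_r, head_dim]
y = paddle.rand([6, 4])  # e.g. [m_size, head_dim]
x.stop_gradient = False
y.stop_gradient = False

c = paddle.matmul(x, y, transpose_y=True)  # [5, 6]
dout = paddle.ones_like(c)
paddle.autograd.backward([c], [dout])

assert paddle.allclose(x.grad, paddle.matmul(dout, y)).item()                    # dx = dout * y
assert paddle.allclose(y.grad, paddle.matmul(dout, x, transpose_x=True)).item()  # dy = dout^T * x
```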
@@ -570,8 +570,8 @@ class FMHAGateRef {
v_out_grad);
}
// [batch_size, seq_len_m, seq_len_r, 3, num_heads, key_dim] ->
// [3, batch_size, seq_len_m, num_heads, seq_len_r, key_dim]
// [batch_size, seq_len_m, seq_len_r, 3, num_heads, head_dim] ->
// [3, batch_size, seq_len_m, num_heads, seq_len_r, head_dim]
void ComputeQKVTransposeForward(const Tensor& qkv_out,
Tensor* qkv_transpose_out) {
int ndims = 6;
@@ -610,7 +610,7 @@ class FMHAGateRef {
const Tensor* src_mask, Tensor* qk_out,
Tensor* softmax_out) {
if (nonbatched_bias) {
std::vector<const Tensor*> ins = {qk_out, nonbatched_bias, src_mask};
std::vector<const Tensor*> ins = {qk_out, src_mask, nonbatched_bias};
std::vector<Tensor*> outs = {qk_out};
phi::funcs::BroadcastKernel<phi::ElementwiseType::kTernary, T, T>(
dev_ctx_, ins, &outs, -1, TernaryAddFunctor<T>());
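A sketch of the broadcasted ternary add that TernaryAddFunctor performs above; the shapes follow the comments in the reference test below, and the values are random and purely illustrative:

```python
# Sketch of the broadcasted ternary add: qk_out = qk_out + src_mask + nonbatched_bias.
import paddle

batch_size, msa_len, num_heads, res_len, m_size = 2, 3, 2, 5, 5
qk_out = paddle.rand([batch_size, msa_len, num_heads, res_len, m_size])
src_mask = paddle.rand([batch_size, msa_len, 1, 1, m_size])
nonbatched_bias = paddle.rand([batch_size, 1, num_heads, res_len, m_size])

out = qk_out + src_mask + nonbatched_bias  # broadcast over the size-1 axes
print(out.shape)  # [2, 3, 2, 5, 5]
```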
......
@@ -47,10 +47,10 @@ class FusedGateAttentionOp : public framework::OperatorWithKernel {
int seq_len_m = input_q_dims[1];
int seq_len_r = input_q_dims[2];
int num_head, m_size, key_dim;
int num_head, m_size, head_dim;
if (ctx->Attrs().Get<bool>("merge_qkv")) {
// QKV's input: [batch_size, seq_len_m, seq_len_r, qkv_dim]
// QKV's weight: [3, num_head, key_dim, qkv_dim]
// QKV's weight: [3, num_head, head_dim, qkv_dim]
OP_INOUT_CHECK(ctx->HasInput("QKVWeight"), "Input", "QKVWeight",
"fused_gate_attention");
OP_INOUT_CHECK(ctx->HasOutput("QKVTransposeOut"), "Output",
@@ -59,11 +59,11 @@ class FusedGateAttentionOp : public framework::OperatorWithKernel {
auto qkv_w_dims = ctx->GetInputDim("QKVWeight");
num_head = qkv_w_dims[1];
key_dim = qkv_w_dims[2];
head_dim = qkv_w_dims[2];
m_size = seq_len_r;
ctx->SetOutputDim("QKVTransposeOut", {3, batch_size, seq_len_m, num_head,
seq_len_r, key_dim});
seq_len_r, head_dim});
} else {
OP_INOUT_CHECK(ctx->HasInput("QueryWeight"), "Input", "QueryWeight",
"fused_gate_attention");
@@ -76,21 +76,21 @@ class FusedGateAttentionOp : public framework::OperatorWithKernel {
auto q_w_dims = ctx->GetInputDim("QueryWeight");
num_head = q_w_dims[1];
key_dim = q_w_dims[2];
head_dim = q_w_dims[2];
m_size = input_k_dims[2];
ctx->SetOutputDim("QueryTransposeOut",
{batch_size, seq_len_m, num_head, seq_len_r, key_dim});
{batch_size, seq_len_m, num_head, seq_len_r, head_dim});
ctx->SetOutputDim("KeyTransposeOut",
{batch_size, seq_len_m, num_head, m_size, key_dim});
{batch_size, seq_len_m, num_head, m_size, head_dim});
ctx->SetOutputDim("ValueTransposeOut",
{batch_size, seq_len_m, num_head, m_size, key_dim});
{batch_size, seq_len_m, num_head, m_size, head_dim});
}
ctx->SetOutputDim("SoftmaxOut",
{batch_size, seq_len_m, num_head, seq_len_r, m_size});
ctx->SetOutputDim("FMHAOut",
{batch_size, seq_len_m, seq_len_r, num_head, key_dim});
{batch_size, seq_len_m, seq_len_r, num_head, head_dim});
if (ctx->Attrs().Get<bool>("has_gating")) {
OP_INOUT_CHECK(ctx->HasInput("GateWeight"), "Input", "GateWeight",
@@ -98,7 +98,7 @@ class FusedGateAttentionOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("GateBias"), "Input", "GateBias",
"fused_gate_attention");
ctx->SetOutputDim("GateOut",
{batch_size, seq_len_m, seq_len_r, num_head, key_dim});
{batch_size, seq_len_m, seq_len_r, num_head, head_dim});
}
ctx->SetOutputDim("Out", ctx->GetInputDim("Query"));
......
@@ -65,13 +65,13 @@ void ComputeMergedQKVMatmulForward(const framework::ExecutionContext &ctx,
const GateAttentionConfig<T> &config,
const Tensor *query, Tensor *qkv_out) {
// query: shape=[batch_size, seq_len_m, seq_len_r, qkv_dim]
// qkv_weight: shape=[3, num_heads, key_dim, qkv_dim]
// qkv_out: shape=[batch_size, seq_len_m, seq_len_r, 3, num_heads, key_dim]
// qkv_weight: shape=[3, num_heads, head_dim, qkv_dim]
// qkv_out: shape=[batch_size, seq_len_m, seq_len_r, 3, num_heads, head_dim]
auto *qkv_weight = ctx.Input<Tensor>("QKVWeight");
// qkv_out = GEMM(query, qkv_weight^T)
int m = config.batch_size * config.seq_len_m * config.seq_len_r;
int n = 3 * config.num_heads * config.key_dim;
int n = 3 * config.num_heads * config.head_dim;
int k = config.q_dim;
auto qkv_compute =
AttnMatMul<T>(ctx.cuda_device_context(), false, true, m, n, k, false);
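A Python sketch of the merged QKV projection above, qkv_out = GEMM(query, qkv_weight^T) with m = batch_size * seq_len_m * seq_len_r, n = 3 * num_heads * head_dim, k = q_dim; shapes are hypothetical:

```python
# Sketch of the merged QKV projection: qkv_out = GEMM(query, qkv_weight^T).
import paddle

batch_size, seq_len_m, seq_len_r, q_dim, num_heads, head_dim = 2, 3, 5, 6, 2, 4
query = paddle.rand([batch_size, seq_len_m, seq_len_r, q_dim])
qkv_weight = paddle.rand([3, num_heads, head_dim, q_dim])

m = batch_size * seq_len_m * seq_len_r
n = 3 * num_heads * head_dim
k = q_dim
qkv_out = paddle.matmul(paddle.reshape(query, [m, k]),
                        paddle.reshape(qkv_weight, [n, k]),
                        transpose_y=True)  # [m, n]
qkv_out = paddle.reshape(
    qkv_out, [batch_size, seq_len_m, seq_len_r, 3, num_heads, head_dim])
print(qkv_out.shape)  # [2, 3, 5, 3, 2, 4]
```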
@@ -91,7 +91,7 @@ void ComputeMergedQKVMatmulBackward(const framework::ExecutionContext &ctx,
// Gradient of GEMM(query, qkv_weight)
int m = config.batch_size * config.seq_len_m * config.seq_len_r;
int n = 3 * config.num_heads * config.key_dim;
int n = 3 * config.num_heads * config.head_dim;
int k = config.q_dim;
auto qkv_compute =
AttnMatMul<T>(ctx.cuda_device_context(), false, true, m, n, k, false);
@@ -111,10 +111,10 @@ void ComputeSeparatedQKVMatmulForward(const framework::ExecutionContext &ctx,
// query_out = GEMM(query, query_weight)
// query: shape=[batch_size, seq_len_m, seq_len_r, q_dim]
// query_weight: shape=[q_dim, num_heads, key_dim]
// query_out: shape=[batch_size, seq_len_m, seq_len_r, num_heads, key_dim]
// query_weight: shape=[q_dim, num_heads, head_dim]
// query_out: shape=[batch_size, seq_len_m, seq_len_r, num_heads, head_dim]
int q_m = config.batch_size * config.seq_len_m * config.seq_len_r;
int q_n = config.num_heads * config.key_dim;
int q_n = config.num_heads * config.head_dim;
int q_k = config.q_dim;
auto q_compute = AttnMatMul<T>(ctx.cuda_device_context(), false, false, q_m,
q_n, q_k, false);
@@ -122,10 +122,10 @@ void ComputeSeparatedQKVMatmulForward(const framework::ExecutionContext &ctx,
// k_out = GEMM(key, key_weight)
// key: shape=[batch_size, seq_len_m, m_size, kv_dim]
// key_weight: shape=[kv_dim, num_heads, key_dim]
// key_out: shape=[batch_size, seq_len_m, m_size, num_heads, key_dim]
// key_weight: shape=[kv_dim, num_heads, head_dim]
// key_out: shape=[batch_size, seq_len_m, m_size, num_heads, head_dim]
int kv_m = config.batch_size * config.seq_len_m * config.m_size;
int kv_n = config.num_heads * config.key_dim;
int kv_n = config.num_heads * config.head_dim;
int kv_k = config.kv_dim;
auto kv_compute = AttnMatMul<T>(ctx.cuda_device_context(), false, false, kv_m,
kv_n, kv_k, false);
@@ -151,7 +151,7 @@ void ComputeSeparatedQKVMatmulBackward(const framework::ExecutionContext &ctx,
key_weight_grad->mutable_data<T>(ctx.GetPlace());
int kv_m = config.batch_size * config.seq_len_m * config.m_size;
int kv_n = config.num_heads * config.key_dim;
int kv_n = config.num_heads * config.head_dim;
int kv_k = config.kv_dim;
auto kv_compute = AttnMatMul<T>(ctx.cuda_device_context(), false, false, kv_m,
kv_n, kv_k, false);
@@ -174,7 +174,7 @@ void ComputeSeparatedQKVMatmulBackward(const framework::ExecutionContext &ctx,
query_weight_grad->mutable_data<T>(ctx.GetPlace());
int q_m = config.batch_size * config.seq_len_m * config.seq_len_r;
int q_n = config.num_heads * config.key_dim;
int q_n = config.num_heads * config.head_dim;
int q_k = config.q_dim;
auto q_compute = AttnMatMul<T>(ctx.cuda_device_context(), false, false, q_m,
q_n, q_k, false);
@@ -195,7 +195,7 @@ void ComputeGatingLinearForward(const framework::ExecutionContext &ctx,
// bias.
// gate_out = GEMM(query, gate_weight) + gate_bias
int m = config.batch_size * config.seq_len_m * config.seq_len_r;
int n = config.num_heads * config.key_dim;
int n = config.num_heads * config.head_dim;
int k = config.q_dim;
auto gate_attn_compute =
AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true);
@@ -224,7 +224,7 @@ void ComputeGatingLinearBackward(const framework::ExecutionContext &ctx,
gate_bias_out.mutable_data<T>(ctx.GetPlace());
int m = config.batch_size * config.seq_len_m * config.seq_len_r;
int n = config.num_heads * config.key_dim;
int n = config.num_heads * config.head_dim;
int k = config.q_dim;
auto gate_attn_compute =
AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true);
@@ -260,7 +260,7 @@ void ComputeOutputLinearForward(const framework::ExecutionContext &ctx,
// out = GEMM(fmha_or_gate_out, out_linear_weight) + out_linear_bias
int m = config.batch_size * config.seq_len_m * config.seq_len_r;
int n = config.q_dim;
int k = config.num_heads * config.key_dim;
int k = config.num_heads * config.head_dim;
auto out_linear_compute =
AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true);
out_linear_compute.ComputeForward(out_linear_weight, fmha_or_gate_out,
@@ -282,11 +282,9 @@ void ComputeOutputLinearBackward(const framework::ExecutionContext &ctx,
out_linear_weight_grad->mutable_data<T>(ctx.GetPlace());
out_linear_bias_grad->mutable_data<T>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
int m = config.batch_size * config.seq_len_m * config.seq_len_r;
int n = config.q_dim;
int k = config.num_heads * config.key_dim;
int k = config.num_heads * config.head_dim;
auto out_linear_compute =
AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true);
out_linear_compute.ComputeBackward(input, out_linear_weight, out_grad,
......
@@ -12,6 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
os.environ['NVIDIA_TF32_OVERRIDE'] = "0"
os.environ['FLAGS_new_einsum'] = "0"
import numpy as np
import paddle
@@ -47,7 +52,7 @@ class TestFusedGateAttentionOp(OpTest):
self.res_len = 5
self.q_dim = 6
self.num_heads = 2
self.key_dim = 4
self.head_dim = 4
self.m_size = self.res_len
self.kv_dim = self.q_dim
self.out_dim = self.q_dim
@@ -65,12 +70,12 @@ class TestFusedGateAttentionOp(OpTest):
np.random.seed(123)
self.query = _random(
(self.batch_size, self.msa_len, self.res_len, self.q_dim))
self.q_weight = _random((self.q_dim, self.num_heads, self.key_dim))
self.k_weight = _random((self.kv_dim, self.num_heads, self.key_dim))
self.v_weight = _random((self.kv_dim, self.num_heads, self.key_dim))
self.q_weight = _random((self.q_dim, self.num_heads, self.head_dim))
self.k_weight = _random((self.kv_dim, self.num_heads, self.head_dim))
self.v_weight = _random((self.kv_dim, self.num_heads, self.head_dim))
if self.merge_qkv:
self.key = None
# (3, self.num_heads, self.key_dim, self.q_dim)
# (3, self.num_heads, self.head_dim, self.q_dim)
q_weight_t = np.transpose(self.q_weight, axes=[1, 2, 0])
k_weight_t = np.transpose(self.k_weight, axes=[1, 2, 0])
v_weight_t = np.transpose(self.v_weight, axes=[1, 2, 0])
@@ -88,15 +93,22 @@ class TestFusedGateAttentionOp(OpTest):
(self.batch_size, 1, self.num_heads, self.res_len, self.m_size))
if self.has_gating:
self.gating_w = _random((self.q_dim, self.num_heads, self.key_dim))
self.gating_b = _random((self.num_heads, self.key_dim))
self.gating_w = _random((self.q_dim, self.num_heads, self.head_dim))
self.gating_b = _random((self.num_heads, self.head_dim))
self.output_w = _random((self.num_heads, self.key_dim, self.out_dim))
self.output_w = _random((self.num_heads, self.head_dim, self.out_dim))
self.output_b = _random((self.out_dim))
self.dout = _random(
(self.batch_size, self.msa_len, self.res_len, self.q_dim))
def collect_outputs(self, query, key, softmax_out, fmha_out, gate_out, out):
outputs = [
softmax_out, fmha_out, gate_out if self.has_gating else None, out,
query.grad, None if self.merge_qkv else key.grad
]
return outputs
def get_reference_out(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
@@ -108,44 +120,85 @@ class TestFusedGateAttentionOp(OpTest):
v_weight = paddle.to_tensor(self.v_weight, stop_gradient=False)
src_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True)
c = self.key_dim**(-0.5)
# [batch_size, msa_len, num_heads, res_len, key_dim]
c = self.head_dim**(-0.5)
# [batch_size, msa_len, res_len, q_dim], [q_dim, num_heads, head_dim]
# -> [batch_size, msa_len, res_len, num_heads, head_dim]
q = paddle.einsum('nbqa,ahc->nbqhc', query, q_weight) * c
# [batch_size, msa_len, num_heads, m_size, key_dim]
# [batch_size, msa_len, m_size, kv_dim], [kv_dim, num_heads, head_dim]
# -> [batch_size, msa_len, m_size, num_heads, head_dim]
k = paddle.einsum('nbka,ahc->nbkhc', key, k_weight)
# [batch_size, msa_len, num_heads, m_size, key_dim]
# [batch_size, msa_len, m_size, kv_dim], [kv_dim, num_heads, head_dim]
# -> [batch_size, msa_len, m_size, num_heads, head_dim]
v = paddle.einsum('nbka,ahc->nbkhc', key, v_weight)
# [batch_size, msa_len, num_heads, res_len, m_size]
# [batch_size, msa_len, res_len, num_heads, head_dim], [batch_size, msa_len, m_size, num_heads, head_dim]
# -> [batch_size, msa_len, num_heads, res_len, m_size]
logits = paddle.einsum('nbqhc,nbkhc->nbhqk', q, k) # qk_out
# [batch_size, msa_len, num_heads, res_len, m_size], [batch_size, msa_len, 1, 1, m_size]
# -> [batch_size, msa_len, num_heads, res_len, m_size]
logits = logits + src_mask
if self.bias_attr:
nonbatched_bias = paddle.to_tensor(self.nonbatched_bias,
stop_gradient=False)
# [batch_size, msa_len, num_heads, res_len, m_size], [batch_size, 1, num_heads, res_len, m_size]
# -> [batch_size, msa_len, num_heads, res_len, m_size]
logits = logits + nonbatched_bias
weights = nn.functional.softmax(logits) # softmax_out
weighted_avg = paddle.einsum('nbhqk,nbkhc->nbqhc', weights, v)
# [batch_size, msa_len, num_heads, res_len, m_size]
softmax_out = nn.functional.softmax(logits)
# [batch_size, msa_len, num_heads, res_len, m_size], [batch_size, msa_len, m_size, num_heads, head_dim]
# -> [batch_size, msa_len, res_len, num_heads, head_dim]
# fmha_out = paddle.einsum('nbhqk,nbkhc->nbqhc', softmax_out, v)
v_trans = paddle.transpose(v, perm=[0, 1, 3, 2, 4])
qktv_out = paddle.matmul(softmax_out, v_trans)
fmha_out = paddle.transpose(qktv_out, perm=[0, 1, 3, 2, 4])
if self.has_gating:
gating_w = paddle.to_tensor(self.gating_w, stop_gradient=False)
gating_b = paddle.to_tensor(self.gating_b, stop_gradient=False)
gate_values = paddle.einsum('nbqc,chv->nbqhv', query,
gating_w) + gating_b
# [batch_size, msa_len, res_len, q_dim], [q_dim, num_heads, head_dim]
# -> [batch_size, msa_len, res_len, num_heads, head_dim]
# gate_values = paddle.einsum('nbqc,chv->nbqhv', query,
# gating_w) + gating_b
gating_w_2d = paddle.reshape(
gating_w, shape=[self.q_dim, self.num_heads * self.head_dim])
gate_values_4d = paddle.matmul(query, gating_w_2d)
gate_values = paddle.reshape(
gate_values_4d,
shape=[
self.batch_size, self.msa_len, self.res_len, self.num_heads,
self.head_dim
]) + gating_b
gate_values = nn.functional.sigmoid(gate_values)
weighted_avg = weighted_avg * gate_values
gate_out = fmha_out * gate_values
else:
gate_out = fmha_out
output_b = paddle.to_tensor(self.output_b, stop_gradient=False)
output_w = paddle.to_tensor(self.output_w, stop_gradient=False)
out = paddle.einsum('nbqhc,hco->nbqo', weighted_avg,
output_w) + output_b
# [batch_size, msa_len, res_len, num_heads, head_dim], [num_heads, head_dim, out_dim]
# -> [batch_size, msa_len, res_len, out_dim]
# out = paddle.einsum('nbqhc,hco->nbqo', gate_out,
# output_w) + output_b
gate_out_2d = paddle.reshape(
gate_out,
shape=[
self.batch_size * self.msa_len * self.res_len,
self.num_heads * self.head_dim
])
output_w_2d = paddle.reshape(
output_w, shape=[self.num_heads * self.head_dim, self.out_dim])
out_2d = paddle.matmul(gate_out_2d, output_w_2d)
out = paddle.reshape(
out_2d,
shape=[self.batch_size, self.msa_len, self.res_len, self.out_dim
]) + output_b
paddle.autograd.backward([out], [paddle.to_tensor(self.dout)],
retain_graph=True)
if self.merge_qkv:
return out, query.grad, None
else:
return out, query.grad, key.grad
return self.collect_outputs(query, key, softmax_out, fmha_out, gate_out,
out)
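The reference above now uses transpose/reshape plus matmul in place of einsum (the einsum calls are kept as comments). A small NumPy check of one of those equivalences, with hypothetical shapes:

```python
# Illustrative NumPy check that the transpose + matmul form used above matches
# the commented-out einsum 'nbhqk,nbkhc->nbqhc'.
import numpy as np

n, b, h, q, k, c = 2, 3, 2, 5, 5, 4
softmax_out = np.random.rand(n, b, h, q, k)
v = np.random.rand(n, b, k, h, c)

ref = np.einsum('nbhqk,nbkhc->nbqhc', softmax_out, v)

v_trans = np.transpose(v, (0, 1, 3, 2, 4))  # [n, b, h, k, c]
qktv = np.matmul(softmax_out, v_trans)      # [n, b, h, q, c]
fmha = np.transpose(qktv, (0, 1, 3, 2, 4))  # [n, b, q, h, c]

np.testing.assert_allclose(ref, fmha, rtol=1e-12)
```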
def get_fused_gate_attention_out(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
@@ -181,40 +234,59 @@ class TestFusedGateAttentionOp(OpTest):
output_w = paddle.to_tensor(self.output_w, stop_gradient=False)
output_b = paddle.to_tensor(self.output_b, stop_gradient=False)
_, _, _, _, _, _, _, out = _C_ops.fused_gate_attention(
_, _, _, _, softmax_out, fmha_out, gate_out, out = _C_ops.fused_gate_attention(
query, key, q_weight, k_weight, v_weight, qkv_weight,
nonbatched_bias, src_mask, gating_w, gating_b, output_w, output_b,
'has_gating', self.has_gating, 'merge_qkv', self.merge_qkv)
paddle.autograd.backward([out], [paddle.to_tensor(self.dout)],
retain_graph=True)
if key is not None:
return out, query.grad, key.grad
else:
return out, query.grad, None
return self.collect_outputs(query, key, softmax_out, fmha_out, gate_out,
out)
def check_output_and_grad(self, atol, rtol):
def check(self, ref, out, atol, rtol, check_equal, name):
def _convert(value):
if self.dtype == "bfloat16":
return convert_uint16_to_float(value)
return value
output_names = ["out", "query_grad", "key_grad"]
if check_equal:
self.assertTrue(
np.equal(_convert(ref), _convert(out)).all(),
"Checking < {} > failed!".format(name))
else:
np.testing.assert_allclose(
_convert(ref),
_convert(out),
atol=atol,
rtol=rtol,
err_msg="Checking < {} > failed!".format(name))
def check_output_and_grad(self, atol, rtol):
output_names = [
"softmax_out", "fmha_out", "gate_out", "out", "query_grad",
"key_grad"
]
outputs_ref = self.get_reference_out()
outputs_fused = self.get_fused_gate_attention_out()
for i in range(len(outputs_fused)):
for i in range(len(output_names)):
ref_res = outputs_ref[i]
fused_res = outputs_fused[i]
if ref_res is not None and fused_res is not None:
print("Checking {}".format(output_names[i]))
np.testing.assert_allclose(_convert(ref_res),
_convert(fused_res.numpy()),
atol=atol,
rtol=rtol)
# The Python implementation of einsum is likely to call
# matmul(x, y, transpose_x=False, transpose_y=True). With different
# transpose_x and transpose_y, cuBLAS launches different kernels
# and the results cannot be bitwise equal.
# Because the matmul arguments used inside einsum are not necessarily the
# same as those in the fused op, check_equal is set to False and allclose
# is used to check the correctness.
check_equal = False
self.check(ref_res.numpy(), fused_res.numpy(), atol, rtol,
check_equal, output_names[i])
def test_output_and_grad(self):
self.check_output_and_grad(atol=1e-5, rtol=1e-5)
self.check_output_and_grad(atol=1e-5, rtol=1e-6)
class TestMergeQKVLargeBatchSizeCase(TestFusedGateAttentionOp):
@@ -234,7 +306,7 @@ class TestSeparatedQKVCase(TestFusedGateAttentionOp):
self.res_len = 5
self.q_dim = 6
self.num_heads = 2
self.key_dim = 4
self.head_dim = 4
self.m_size = 4
self.kv_dim = 2
self.out_dim = self.q_dim
@@ -279,7 +351,7 @@ class TestMergeQKVBF16Case(TestFusedGateAttentionOp):
self.dtype = "bfloat16"
def test_output_and_grad(self):
self.check_output_and_grad(atol=1e-1, rtol=1e-3)
self.check_output_and_grad(atol=1e-1, rtol=1e-2)
class TestMergeQKVLargeBatchSizeBF16Case(TestMergeQKVBF16Case):
......
@@ -3263,9 +3263,7 @@ def prod(x, axis=None, keepdim=False, dtype=None, name=None):
if x.dtype != convert_np_dtype_to_dtype_(dtype):
x = cast(x, dtype)
input = x
dim = axis
keep_dim = keepdim
if dim is not None and not isinstance(dim, list):
if isinstance(dim, tuple):
dim = list(dim)
@@ -3275,24 +3273,29 @@
raise TypeError(
"The type of axis must be int, list or tuple, but received {}".
format(type(dim)))
reduce_all = True if dim is None or len(dim) == 0 or len(dim) == len(x.shape) else False
if dim is None or len(dim) == 0:
dim = [0]
if in_dygraph_mode():
return _C_ops.final_state_reduce_prod(
input, dim if dim != None and dim != [] else [0], keep_dim, True if
dim == None or dim == [] or len(dim) == len(input.shape) else False)
return _C_ops.final_state_reduce_prod(x, dim, keepdim, reduce_all)
if _in_legacy_dygraph():
return _C_ops.reduce_prod(
x, 'dim', dim, 'keep_dim', keepdim, 'reduce_all', reduce_all)
helper = LayerHelper('reduce_prod', **locals())
check_variable_and_dtype(
input, 'input', ['float32', 'float64', 'int32', 'int64'], 'reduce_prod')
x, 'x/input', ['float32', 'float64', 'int32', 'int64'], 'reduce_prod')
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
helper.append_op(
type='reduce_prod',
inputs={'X': input},
inputs={'X': x},
outputs={'Out': out},
attrs={
'dim': dim if dim != None and dim != [] else [0],
'keep_dim': keep_dim,
'reduce_all': True if dim == None or dim == [] or
len(dim) == len(input.shape) else False
'dim': dim,
'keep_dim': keepdim,
'reduce_all': reduce_all
})
return out
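A short usage sketch of paddle.prod exercising the paths touched above (axis=None reduces over all dims, a tuple axis is normalized to a list, keepdim keeps the reduced dims); the values are illustrative:

```python
# Usage sketch for the rewritten prod.
import paddle

x = paddle.to_tensor([[1.0, 2.0, 3.0],
                      [4.0, 5.0, 6.0]])

print(paddle.prod(x))                              # 720.0 (reduce_all path)
print(paddle.prod(x, axis=1))                      # [6., 120.]
print(paddle.prod(x, axis=(0, 1)))                 # 720.0 (tuple axis -> list)
print(paddle.prod(x, axis=0, keepdim=True).shape)  # [1, 3]
```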
......