From 1a533ed2de02ce5541a3f7adaf3e00c1ffae3fe4 Mon Sep 17 00:00:00 2001 From: Zhaolong Xing Date: Thu, 5 Mar 2020 10:25:19 +0800 Subject: [PATCH] [BUG]: Multihead matmul op's output size should be BxSx(N*H) (#22848) test=develop --- paddle/fluid/operators/fused/multihead_matmul_op.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/fused/multihead_matmul_op.cu b/paddle/fluid/operators/fused/multihead_matmul_op.cu index 2500f66c67..de40ded24e 100644 --- a/paddle/fluid/operators/fused/multihead_matmul_op.cu +++ b/paddle/fluid/operators/fused/multihead_matmul_op.cu @@ -440,13 +440,11 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel { auto &bias_qk = detail::Ref(context.Input("BiasQK"), "Cannot find QK"); - auto *out = context.Output("Out"); auto *input_d = input->data(); auto *w_d = w->data(); auto *bias_d = bias->data(); auto *bias_qk_d = bias_qk.data(); - auto *output_d = out->mutable_data(context.GetPlace()); T scale = static_cast(context.Attr("alpha")); int head_number = context.Attr("head_number"); @@ -463,6 +461,10 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel { int all_head_size = w_dims[2]; int head_size = all_head_size / head_number; + auto *out = context.Output("Out"); + out->Resize({batch, seq_len, all_head_size}); + auto *output_d = out->mutable_data(context.GetPlace()); + // (B*S, hidden) const Tensor input_matrix = framework::ReshapeToMatrix(*input, 2 /*x_num_col_dims */); -- GitLab