diff --git a/paddle/fluid/operators/fused/multihead_matmul_op.cu b/paddle/fluid/operators/fused/multihead_matmul_op.cu index 2500f66c672733c29d2200e2bdf97597a7cadad4..de40ded24e3791dca72abd48345c3a149e4a11a4 100644 --- a/paddle/fluid/operators/fused/multihead_matmul_op.cu +++ b/paddle/fluid/operators/fused/multihead_matmul_op.cu @@ -440,13 +440,11 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel { auto &bias_qk = detail::Ref(context.Input("BiasQK"), "Cannot find QK"); - auto *out = context.Output("Out"); auto *input_d = input->data(); auto *w_d = w->data(); auto *bias_d = bias->data(); auto *bias_qk_d = bias_qk.data(); - auto *output_d = out->mutable_data(context.GetPlace()); T scale = static_cast(context.Attr("alpha")); int head_number = context.Attr("head_number"); @@ -463,6 +461,10 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel { int all_head_size = w_dims[2]; int head_size = all_head_size / head_number; + auto *out = context.Output("Out"); + out->Resize({batch, seq_len, all_head_size}); + auto *output_d = out->mutable_data(context.GetPlace()); + // (B*S, hidden) const Tensor input_matrix = framework::ReshapeToMatrix(*input, 2 /*x_num_col_dims */);