未验证 提交 1a533ed2 编写于 作者: Z Zhaolong Xing 提交者: GitHub

[BUG]: Multihead matmul op's ouput size should be BxSx(N*H) (#22848)

test=develop
上级 1217a521
...@@ -440,13 +440,11 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> { ...@@ -440,13 +440,11 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
auto &bias_qk = detail::Ref(context.Input<framework::Tensor>("BiasQK"), auto &bias_qk = detail::Ref(context.Input<framework::Tensor>("BiasQK"),
"Cannot find QK"); "Cannot find QK");
auto *out = context.Output<framework::Tensor>("Out");
auto *input_d = input->data<T>(); auto *input_d = input->data<T>();
auto *w_d = w->data<T>(); auto *w_d = w->data<T>();
auto *bias_d = bias->data<T>(); auto *bias_d = bias->data<T>();
auto *bias_qk_d = bias_qk.data<T>(); auto *bias_qk_d = bias_qk.data<T>();
auto *output_d = out->mutable_data<T>(context.GetPlace());
T scale = static_cast<T>(context.Attr<float>("alpha")); T scale = static_cast<T>(context.Attr<float>("alpha"));
int head_number = context.Attr<int>("head_number"); int head_number = context.Attr<int>("head_number");
...@@ -463,6 +461,10 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> { ...@@ -463,6 +461,10 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
int all_head_size = w_dims[2]; int all_head_size = w_dims[2];
int head_size = all_head_size / head_number; int head_size = all_head_size / head_number;
auto *out = context.Output<framework::Tensor>("Out");
out->Resize({batch, seq_len, all_head_size});
auto *output_d = out->mutable_data<T>(context.GetPlace());
// (B*S, hidden) // (B*S, hidden)
const Tensor input_matrix = const Tensor input_matrix =
framework::ReshapeToMatrix(*input, 2 /*x_num_col_dims */); framework::ReshapeToMatrix(*input, 2 /*x_num_col_dims */);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册