未验证 提交 448d6de9 编写于 作者: E emailweixu 提交者: GitHub

Merge pull request #5562 from emailweixu/fix_matmal

Fix matmal_op for debug mode
...@@ -74,11 +74,10 @@ Tensor CombineBatchAndN(const framework::ExecutionContext& context, ...@@ -74,11 +74,10 @@ Tensor CombineBatchAndN(const framework::ExecutionContext& context,
Tensor output; Tensor output;
auto in_dims = input.dims(); auto in_dims = input.dims();
if (in_dims.size() == 3) { if (in_dims.size() == 3) {
output.Resize(in_dims); output.Resize({in_dims[1], in_dims[0], in_dims[2]});
output.mutable_data<T>(context.GetPlace()); output.mutable_data<T>(context.GetPlace());
EigenTranspose<Place, T, 3>(context, input, output, {1, 0, 2}); EigenTranspose<Place, T, 3>(context, input, output, {1, 0, 2});
std::vector<int64_t> out_dims = {in_dims[1], in_dims[0] * in_dims[2]}; output.Resize({in_dims[1], in_dims[0] * in_dims[2]});
output.Resize(make_ddim(out_dims));
} else { } else {
output.ShareDataWith(input); output.ShareDataWith(input);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册