未验证 提交 8185cecd 编写于 作者: Y yeliang2258 提交者: GitHub

Fix a bug in transpose2 when run native cpu (#44659)

* fix a bug in transpose2 about mkldnn

* fix bug
上级 b54abbe8
...@@ -356,7 +356,7 @@ void GpuCpuMapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const { ...@@ -356,7 +356,7 @@ void GpuCpuMapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
size_t x_rank = x_shape.size(); size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size(); size_t y_rank = y_shape.size();
flag = flag && x_rank >= 2 && y_rank == 2; flag = flag && x_rank >= 2 && y_rank == 2;
flag = flag && x_shape[x_rank - 1] == y_shape[0];
if (flag) { if (flag) {
if (!IsCompat(subgraph, g)) { if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "GpuCpuMapMatmulV2ToMulPass in op compat failed."; LOG(WARNING) << "GpuCpuMapMatmulV2ToMulPass in op compat failed.";
......
...@@ -80,7 +80,7 @@ class TransposeOp : public framework::OperatorWithKernel { ...@@ -80,7 +80,7 @@ class TransposeOp : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// Here we need to match dims to paddle layout // Here we need to match dims to paddle layout
// as we are producing non-oneDNN result // as we are producing non-oneDNN result
if ((x_dims.size() >= 3) && if (ctx->IsRunMKLDNNKernel() && (x_dims.size() >= 3) &&
(paddle::platform::MKLDNNDeviceContext::tls() (paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout() == framework::DataLayout::kNHWC)) { .get_cur_paddle_data_layout() == framework::DataLayout::kNHWC)) {
auto dims = phi::vectorize<int>(x_dims); auto dims = phi::vectorize<int>(x_dims);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册