diff --git a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc index f8c9c9d86a9953231424ef53157123a35275cc78..000e31aad9ac9a6fc150c0432dfffd75d375423f 100644 --- a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc @@ -214,10 +214,7 @@ class MatMulMKLDNNHandler } astream.wait(); - auto format = - MKLDNNFormatForSize(out->dims().size(), dnnl::memory::format_tag::nchw); - out->set_format(format); - out->set_layout(DataLayout::kMKLDNN); + out->set_mem_desc(dst_memory_p->get_desc().reshape(out->dims())); } std::shared_ptr AcquireDstMemory( @@ -651,10 +648,18 @@ void ExecuteMatMulV2(const ExecutionContext &ctx, auto &astream = MKLDNNDeviceContext::tls().get_stream(); matmul_p->execute(astream, matmul_args); astream.wait(); - auto format = - MKLDNNFormatForSize(out->dims().size(), dnnl::memory::format_tag::nchw); - out->set_format(format); - out->set_layout(DataLayout::kMKLDNN); + + // TODO(jczaja): Explain why int8 format of dst is ABCD and do not need + // permute + if (IsOutputFused(ctx) && !IsInt8()) { + auto axis = ctx.Attr>("fused_transpose_Out"); + auto permuted_md = dst_memory_p->get_desc().permute_axes(axis); + out->set_mem_desc( + permuted_md.reshape(phi::vectorize(out->dims()))); + } else { + out->set_mem_desc( + dst_memory_p->get_desc().reshape(phi::vectorize(out->dims()))); + } } template @@ -836,8 +841,7 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel { reduction_p->execute(astream, reduction_args); astream.wait(); - dx->set_format(paddle::platform::GetMKLDNNFormat( - dst_memory_p->get_desc().reshape(squeezed_dims))); + dx->set_mem_desc(dst_memory_p->get_desc().reshape(squeezed_dims)); } std::vector ExtendDimsWithOnes(const std::vector &dims, @@ -1119,9 +1123,8 @@ void MatMulGradMKLDNNKernel::ExecuteMatMulGrad( matmul_p->execute(astream, matmul_args); astream.wait(); - out->set_layout(framework::DataLayout::kMKLDNN); - out->set_format(platform::GetMKLDNNFormat( - dst_memory_p->get_desc().reshape(vectorize(out->dims())))); + out->set_mem_desc( + dst_memory_p->get_desc().reshape(vectorize(out->dims()))); } template @@ -1184,13 +1187,13 @@ void MatMulGradMKLDNNKernel::RunKernel(const ExecutionContext &ctx) const { if (dx) { if (dx_dims != x.dims()) { dx->Resize(dx_dims); - dx->set_format(x.format()); + dx->set_mem_desc(x.mem_desc()); } } if (dy) { if (dy_dims != y.dims()) { dy->Resize(dy_dims); - dy->set_format(y.format()); + dy->set_mem_desc(y.mem_desc()); } } } diff --git a/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc index e727a4fe9fb4888c1032c7325b295171a34884a6..e9150b0c58f76da1ba7a1c5adc7472cfeb596938 100644 --- a/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc @@ -221,7 +221,7 @@ class MulPrimitiveFactory { to_void_cast(x_tmp.data())); x_tmp.Resize(data->dims()); - x_tmp.set_format(platform::GetMKLDNNFormat(dst_mdesc)); + x_tmp.set_mem_desc(dst_mdesc); data_matrix = framework::ReshapeToMatrix(x_tmp, num_col_dims); } else { data_matrix = framework::ReshapeToMatrix(*data, num_col_dims); @@ -235,11 +235,7 @@ class MulPrimitiveFactory { const Tensor *in) { x_input_->set_data_handle(to_void_cast(in->data())); output_->set_data_handle(out->mutable_data(ctx.GetPlace())); - - if (out->format() == MKLDNNMemoryFormat::undef) { - auto output_format = platform::GetMKLDNNFormat(*output_); - out->set_format((MKLDNNMemoryFormat)output_format); - } + out->set_mem_desc(output_->get_desc()); } template @@ -272,7 +268,7 @@ class MulPrimitiveFactory { auto buffer_size = dst_desc.get_size(); OT *output_data = output->mutable_data(ctx.GetPlace(), buffer_size); - output->set_format(paddle::platform::GetMKLDNNFormat(dst_desc)); + output->set_mem_desc(dst_desc); return memory(dst_desc, engine_, to_void_cast(output_data)); } @@ -392,9 +388,10 @@ class MulMKLDNNINT8Kernel : public framework::OpKernel { if (out_dims.size() != 2) { out->Resize(out_dims); } - out->set_layout(DataLayout::kMKLDNN); - out->set_format(platform::MKLDNNFormatForSize(out_dims.size(), - MKLDNNMemoryFormat::nchw)); + + auto in_md = dnnl::memory::desc(*dnnl_primitive_desc_query_md( + mul.get_primitive_desc(), dnnl_query_dst_md, 0)); + out->set_mem_desc(in_md.reshape(phi::vectorize(out->dims()))); } }; @@ -442,10 +439,11 @@ class MulMKLDNNKernel : public framework::OpKernel { matmul_p->execute(astream, matmul_args); astream.wait(); - out->set_layout(framework::DataLayout::kMKLDNN); - // plain output formats are enforced inside handler - out->set_format(platform::MKLDNNFormatForSize( - out->dims().size(), dnnl::memory::format_tag::nchw)); + // This kernel is flattening dims so then we need to unflattened version + // that should be set in out reshape require plain layout, but + // MatmulV2MKLDNNHanlder enforces one so it should work + out->set_mem_desc( + dst_memory_p->get_desc().reshape(phi::vectorize(out->dims()))); } private: diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 933ac4f12e3c4ef4a300689729fb4151f35cb82f..ca099cb65d67c04188055dfb1805c916932ae58b 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -196,7 +196,8 @@ class MatMulV2MKLDNNHandler out_strides[i] = out_ddims[i + 1] * out_strides[i + 1]; } - if (!IsInt8() && !IsBfloat16() && is_output_fused) { + // TODO(jczaja): Why not for int8?? + if (!IsInt8() && is_output_fused) { out_strides = FakeTransposeStrides(out_ddims); } diff --git a/paddle/phi/kernels/transfer_layout_kernel.cc b/paddle/phi/kernels/transfer_layout_kernel.cc index 25a986ea82fb020457dee5b7741d5bd7e70238a6..be232b7c671e9baa6244215426ecacad69833b09 100644 --- a/paddle/phi/kernels/transfer_layout_kernel.cc +++ b/paddle/phi/kernels/transfer_layout_kernel.cc @@ -121,8 +121,10 @@ void TransferLayoutMKLDNN(const Context& dev_ctx, OneDNNContext::tls().set_cur_paddle_data_layout(src_layout); } - out->set_layout(DataLayout::ONEDNN); - out->set_format(out_format); + dnnl::memory::desc out_mem_desc(vectorize(out->dims()), + funcs::ToOneDNNDataType(x.dtype()), + out_format); + out->set_mem_desc(out_mem_desc); } else if (src_layout == DataLayout::ONEDNN && dst_layout != DataLayout::ONEDNN) { // Case2 - transfrom from MKLDNN OPKernel to Non-MKLDNN OPKernel