diff --git a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc index e59b901b6a38d5c3e73d3655ba7bcd93b3486720..f7f7e5f6ad8935eae3ca1aa5f398ce8dd40221ba 100644 --- a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc @@ -42,9 +42,6 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel { auto& astream = OneDNNContext::tls().get_stream(); - platform::SetInMemDescWithLogicalLayoutFusesSupport( - ctx, const_cast(x), x->mem_desc()); - if (ndims == 1) { framework::TensorCopy(*x, x->place(), out); out->set_mem_desc(x->mem_desc()); @@ -82,11 +79,8 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel { reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); astream.wait(); - platform::SetOutMemDescWithLogicalLayoutFusesSupport( - ctx, - out, - reorder_dst_memory_p->get_desc().permute_axes( - TransposeToPermuteAxis(transpose_axis))); + out->set_mem_desc(reorder_dst_memory_p->get_desc().permute_axes( + TransposeToPermuteAxis(transpose_axis))); } private: @@ -180,11 +174,3 @@ REGISTER_OP_KERNEL(transpose_grad, MKLDNN, ::paddle::platform::CPUPlace, ops::TransposeMKLDNNGradOpKernel); - -REGISTER_OP_KERNEL(transpose2, - MKLDNN, - ::paddle::platform::CPUPlace, - ops::TransposeMKLDNNOpKernel, - ops::TransposeMKLDNNOpKernel, - ops::TransposeMKLDNNOpKernel, - ops::TransposeMKLDNNOpKernel); diff --git a/paddle/fluid/operators/ops_extra_info.h b/paddle/fluid/operators/ops_extra_info.h index 94adfaf3b4500b0b58153aec947a64acde59dc77..12df9f96d6d58d23bf0fcd464e84e4d78a44bda5 100644 --- a/paddle/fluid/operators/ops_extra_info.h +++ b/paddle/fluid/operators/ops_extra_info.h @@ -120,6 +120,9 @@ const std::unordered_map {"Scale_weights", ExtraAttrProperty::ONEDNN}, {"x_data_format", ExtraAttrProperty::ONEDNN}, {"y_data_format", ExtraAttrProperty::ONEDNN}, + {"fused_squeeze2_axes", ExtraAttrProperty::ONEDNN}, + {"fused_unsqueeze2_axes", ExtraAttrProperty::ONEDNN}, + {"fused_reshape2_shape", ExtraAttrProperty::ONEDNN}, // ONEDNN pass dedicated attributes {"Activation_scale", ExtraAttrProperty::ONEDNN}, {"Bias_scales", ExtraAttrProperty::ONEDNN}, diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 7a8ef9c939572a13c87d7b482af9468b1f1a2b62..0142fa2afd13de6e5faa3c2b537df05d0a7bd59b 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -151,50 +151,6 @@ static void SetOutMemDescWithLogicalLayoutFusesSupport( } } -static void SetInMemDescWithSqueeze2FuseSupport( - const framework::ExecutionContext& ctx, - phi::DenseTensor* in, - const dnnl::memory::desc& in_md) { - const std::vector fused_squeeze2_axes = - ctx.Attr>("fused_squeeze2_axes"); - const std::set squeeze2_axes_set(fused_squeeze2_axes.begin(), - fused_squeeze2_axes.end()); - const std::vector& x_vec_dims = in_md.dims(); - std::vector squeezed_op_tz( - x_vec_dims.size() - fused_squeeze2_axes.size(), 0); - - int j = 0; - for (size_t i = 0; i < x_vec_dims.size(); ++i) { - if (squeeze2_axes_set.count(i) || - squeeze2_axes_set.count(i - x_vec_dims.size())) { - PADDLE_ENFORCE_EQ( - x_vec_dims[i], - 1, - platform::errors::InvalidArgument( - "Squeeze2 input dim %d should be equal to one, but get %d.", - i, - x_vec_dims[i])); - continue; - } - squeezed_op_tz[j++] = x_vec_dims[i]; - } - - in->set_mem_desc(in_md.reshape(squeezed_op_tz)); - in->Resize(phi::make_ddim(squeezed_op_tz)); -} - -static void SetInMemDescWithLogicalLayoutFusesSupport( - const framework::ExecutionContext& ctx, - phi::DenseTensor* in, - const dnnl::memory::desc& in_md) { - if (ctx.HasAttr("fused_squeeze2_axes")) { - SetInMemDescWithSqueeze2FuseSupport(ctx, in, in_md); - } else { - in->set_mem_desc(in_md); - in->Resize(phi::make_ddim(in_md.dims())); - } -} - template class MatMulV2MKLDNNHandler : public phi::funcs::OneDNNHandlerNoCachingT { diff --git a/paddle/phi/backends/onednn/onednn_reuse.h b/paddle/phi/backends/onednn/onednn_reuse.h index 7f64f8668c91bdfcdfe2fd47762a8826117dbc4b..dbb70cb07aaeca1bb2a4245a7dcdf0d0415cf2bc 100644 --- a/paddle/phi/backends/onednn/onednn_reuse.h +++ b/paddle/phi/backends/onednn/onednn_reuse.h @@ -15,6 +15,7 @@ limitations under the License. */ #include #include +#include #include #include #include @@ -1660,6 +1661,85 @@ class PoolingOneDNNHandler } }; +static void SetOutMemDescWithUnsqueeze2FuseSupport( + const std::vector fused_unsqueeze2_axes, + phi::DenseTensor* out, + const dnnl::memory::desc& out_md) { + const std::vector& op_tz = out_md.dims(); + std::vector unsqueezed_op_tz( + op_tz.size() + fused_unsqueeze2_axes.size(), 0); + + for (const auto& axis : fused_unsqueeze2_axes) { + int positive_axis = axis < 0 ? unsqueezed_op_tz.size() + axis : axis; + unsqueezed_op_tz[positive_axis] = 1; + } + + int j = 0; + for (size_t i = 0; i < unsqueezed_op_tz.size(); ++i) { + if (unsqueezed_op_tz[i] == 0) { + unsqueezed_op_tz[i] = op_tz[j++]; + } + } + out->set_mem_desc(out_md.reshape(unsqueezed_op_tz)); + out->Resize(make_ddim(unsqueezed_op_tz)); +} + +static void SetOutMemDescWithReshape2FuseSupport( + const std::vector fused_reshape2_shape_, + phi::DenseTensor* out, + const dnnl::memory::desc& out_md) { + std::vector fused_reshape2_shape(fused_reshape2_shape_.begin(), + fused_reshape2_shape_.end()); + + const int out_shape_numel = out->numel(); + const int new_shape_numel = std::accumulate(fused_reshape2_shape.begin(), + fused_reshape2_shape.end(), + 1, + std::multiplies()); + + for (size_t i = 0; i < fused_reshape2_shape.size(); ++i) { + if (fused_reshape2_shape[i] == -1) { + fused_reshape2_shape[i] = -out_shape_numel / new_shape_numel; + break; + } + } + + out->set_mem_desc(out_md.reshape(fused_reshape2_shape)); + out->Resize(phi::make_ddim(fused_reshape2_shape)); +} + +static void SetOutMemDescWithLogicalLayoutFusesSupport( + const OneDNNContext& dev_ctx, + phi::DenseTensor* out, + const dnnl::memory::desc& out_md) { + const auto fused_unsqueeze2_axes = + dev_ctx.HasDnnAttr("fused_unsqueeze2_axes") + ? PADDLE_GET_CONST(std::vector, + dev_ctx.GetDnnAttr("fused_unsqueeze2_axes")) + : std::vector(); + const auto fused_reshape2_shape = + dev_ctx.HasDnnAttr("fused_reshape2_shape") + ? PADDLE_GET_CONST(std::vector, + dev_ctx.GetDnnAttr("fused_reshape2_shape")) + : std::vector(); + const auto fused_squeeze2_axes = + dev_ctx.HasDnnAttr("fused_squeeze2_axes") + ? PADDLE_GET_CONST(std::vector, + dev_ctx.GetDnnAttr("fused_squeeze2_axes")) + : std::vector(); + + if (!fused_unsqueeze2_axes.empty()) { + SetOutMemDescWithUnsqueeze2FuseSupport(fused_unsqueeze2_axes, out, out_md); + } else if (!fused_reshape2_shape.empty()) { + SetOutMemDescWithReshape2FuseSupport(fused_reshape2_shape, out, out_md); + } else if (!fused_squeeze2_axes.empty()) { + out->set_mem_desc(out_md); + out->Resize(make_ddim(out_md.dims())); + } else { + out->set_mem_desc(out_md); + } +} + static DDim RowMatrixDimsFromVector(const DDim& x_dim) { return x_dim.size() > 1 ? x_dim : make_ddim({1, x_dim[0]}); } diff --git a/paddle/phi/kernels/onednn/transpose_grad_kernel.cc b/paddle/phi/kernels/onednn/transpose_grad_kernel.cc index a754cdffed14db26e1a78000977cf5c8e593ce1a..64f1f9f610861ba0a8a17479060a8610277ba2e4 100644 --- a/paddle/phi/kernels/onednn/transpose_grad_kernel.cc +++ b/paddle/phi/kernels/onednn/transpose_grad_kernel.cc @@ -63,4 +63,4 @@ void TransposeGradKernel(const Context& dev_ctx, } // namespace phi PD_REGISTER_KERNEL( - transpose_grad, OneDNN, ALL_LAYOUT, phi::TransposeGradKernel, float) {} + transpose_grad, OneDNN, ONEDNN, phi::TransposeGradKernel, float) {} diff --git a/paddle/phi/kernels/onednn/transpose_kernel.cc b/paddle/phi/kernels/onednn/transpose_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..26c89197e0d7f4a34bab041b26ba99364129b40c --- /dev/null +++ b/paddle/phi/kernels/onednn/transpose_kernel.cc @@ -0,0 +1,140 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/transpose_kernel.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +void SetInMemDescWithSqueeze2FuseSupport( + const std::vector fused_squeeze2_axes, + DenseTensor* in, + const dnnl::memory::desc& in_md) { + const std::set squeeze2_axes_set(fused_squeeze2_axes.begin(), + fused_squeeze2_axes.end()); + const std::vector& x_vec_dims = in_md.dims(); + std::vector squeezed_op_tz( + x_vec_dims.size() - fused_squeeze2_axes.size(), 0); + + int j = 0; + for (size_t i = 0; i < x_vec_dims.size(); ++i) { + if (squeeze2_axes_set.count(i) || + squeeze2_axes_set.count(i - x_vec_dims.size())) { + PADDLE_ENFORCE_EQ( + x_vec_dims[i], + 1, + errors::InvalidArgument( + "Squeeze2 input dim %d should be equal to one, but get %d.", + i, + x_vec_dims[i])); + continue; + } + squeezed_op_tz[j++] = x_vec_dims[i]; + } + + in->set_mem_desc(in_md.reshape(squeezed_op_tz)); + in->Resize(make_ddim(squeezed_op_tz)); +} + +void SetInMemDescWithLogicalLayoutFusesSupport( + const OneDNNContext& dev_ctx, + DenseTensor* in, + const dnnl::memory::desc& in_md) { + const auto fused_squeeze2_axes = + dev_ctx.HasDnnAttr("fused_squeeze2_axes") + ? PADDLE_GET_CONST(std::vector, + dev_ctx.GetDnnAttr("fused_squeeze2_axes")) + : std::vector(); + if (fused_squeeze2_axes.empty()) { + in->set_mem_desc(in_md); + in->Resize(make_ddim(in_md.dims())); + } else { + SetInMemDescWithSqueeze2FuseSupport(fused_squeeze2_axes, in, in_md); + } +} + +template +void TransposeKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& axis, + DenseTensor* out) { + PADDLE_ENFORCE_EQ( + dev_ctx.GetPlace().GetType() == AllocationType::CPU, + true, + errors::PreconditionNotMet("oneDNN Transpose kernel must use CPUPlace")); + + SetInMemDescWithLogicalLayoutFusesSupport( + dev_ctx, const_cast(&x), x.mem_desc()); + + if (axis.size() == 1) { + paddle::framework::TensorCopy(x, x.place(), out); + out->set_mem_desc(x.mem_desc()); + return; + } + + auto x_vec_dims = vectorize(x.dims()); + auto x_type = funcs::ToOneDNNDataType(x.dtype()); + funcs::ReorderOneDNNHandler reorder_handler( + x_vec_dims, x.dtype(), x_type, dev_ctx.GetEngine()); + auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( + x.mem_desc(), funcs::to_void_cast(x.data())); + auto dst_md = + dnnl::memory::desc(x_vec_dims, + x.mem_desc().data_type(), + funcs::GetPlainOneDNNFormat(x_vec_dims.size())); + + // a trick is used here to fake transpose of out_md, so later it will be + // "untransposed", leaving output data in plain format tag + std::vector fake_strides(axis.size()); + auto dims = dst_md.dims(); + int total_stride = 1; + for (int i = static_cast(dims.size()) - 1; i >= 0; --i) { + fake_strides[axis[i]] = total_stride; + total_stride *= dims[axis[i]]; + } + dst_md = + dnnl::memory::desc(x_vec_dims, x.mem_desc().data_type(), fake_strides); + auto dst_data = dev_ctx.template Alloc(out); + auto reorder_dst_memory_p = + std::make_shared(dst_md, dev_ctx.GetEngine(), dst_data); + auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p, + reorder_src_memory_p); + + auto& astream = OneDNNContext::tls().get_stream(); + reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); + astream.wait(); + + // it is needed because oneDNN's permute axis understand axes order in + // different way PaddlePaddle's transpose + std::vector permute_axis(axis.size()); + for (size_t i = 0; i < axis.size(); ++i) { + permute_axis[axis[i]] = i; + } + funcs::SetOutMemDescWithLogicalLayoutFusesSupport( + dev_ctx, + out, + reorder_dst_memory_p->get_desc().permute_axes(permute_axis)); +} +} // namespace phi + +PD_REGISTER_KERNEL(transpose, + OneDNN, + ONEDNN, + phi::TransposeKernel, + float, + uint8_t, + int8_t, + phi::dtype::bfloat16) {}