diff --git a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc index b7b0f33ade85c26d096994949c8ded3c2d7f85ac..a0de6064b252b09cf7eb2193db6c81755cfabffa 100644 --- a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc @@ -24,70 +24,6 @@ namespace operators { using Tensor = phi::DenseTensor; using framework::DataLayout; -template -class TransposeMKLDNNHandler { - public: - TransposeMKLDNNHandler(std::vector& dims, // NOLINT - std::vector& axis, // NOLINT - dnnl::engine engine) - : dims_(dims), - axis_(axis), - logical_axis_(dims.size(), 0), - engine_(engine) {} - - std::shared_ptr AcquireSrcMemory(const MKLDNNMemoryFormat& fmt, - void* ptr) { - // Make memory descriptor using input format, unless it - // cannot be trusted (nchw) then make up memory fmt manually - for (size_t i = 0; i < this->logical_axis_.size(); ++i) { - this->logical_axis_[i] = i; - } - - auto src_md = fmt != MKLDNNMemoryFormat::nchw - ? platform::MKLDNNMemDesc( - dims_, platform::MKLDNNGetDataType(), fmt) - : Axis2MemoryDesc(dims_, logical_axis_); - return std::make_shared(src_md, engine_, ptr); - } - - std::shared_ptr AcquireDstMemory(phi::DenseTensor* output, - platform::Place place) { - auto dst_md = Axis2MemoryDesc(dims_, axis_); - auto dst_data = output->mutable_data(place, dst_md.get_size()); - return std::make_shared(dst_md, engine_, dst_data); - } - - std::shared_ptr AcquireTranspose( - std::shared_ptr dst_memory_p, - std::shared_ptr src_memory_p) { - return std::make_shared(*(src_memory_p), *(dst_memory_p)); - } - - protected: - dnnl::memory::desc Axis2MemoryDesc(std::vector& nchw_tz, // NOLINT - std::vector& axis // NOLINT - ) { - size_t ndims = axis.size(); - - std::vector strides(ndims); - unsigned int total_stride = 1; - for (int i = ndims - 1; i >= 0; --i) { - strides[axis[i]] = total_stride; - total_stride *= nchw_tz[axis[i]]; - } - dnnl::memory::desc mem_d( - nchw_tz, platform::MKLDNNGetDataType(), strides); - - return mem_d; - } - - private: - std::vector dims_; - std::vector axis_; - std::vector logical_axis_; - dnnl::engine engine_; -}; - template class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel { public: @@ -98,37 +34,84 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel { "Operator DNNL Transpose must use CPUPlace")); auto& dev_ctx = ctx.template device_context(); - const auto& mkldnn_engine = dev_ctx.GetEngine(); - std::vector axis = ctx.Attr>("axis"); - int ndims = axis.size(); - auto* input = ctx.Input("X"); - auto* output = ctx.Output("Out"); - const T* input_data = input->data(); + const auto& dnnl_engine = dev_ctx.GetEngine(); + std::vector transpose_axis = ctx.Attr>("axis"); + int ndims = transpose_axis.size(); + auto* x = ctx.Input("X"); + auto* out = ctx.Output("Out"); + + auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); if (ndims == 1) { - framework::TensorCopy(*input, input->place(), output); - output->set_format(input->format()); + framework::TensorCopy(*x, x->place(), out); + out->set_mem_desc(x->mem_desc()); return; } - auto nchw_tz = phi::vectorize(input->dims()); + auto x_vec_dims = phi::vectorize(x->dims()); - TransposeMKLDNNHandler handler(nchw_tz, axis, mkldnn_engine); + framework::proto::VarType::Type x_paddle_type = + framework::TransToProtoVarType(x->dtype()); + dnnl::memory::data_type x_type = framework::ToMKLDNNDataType(x_paddle_type); + platform::ReorderMKLDNNHandler reorder_handler( + x_vec_dims, x_paddle_type, x_type, dnnl_engine); - auto transpose_src_memory_p = handler.AcquireSrcMemory( - input->format(), platform::to_void_cast(input_data)); - auto transpose_dst_memory_p = - handler.AcquireDstMemory(output, ctx.GetPlace()); - auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p, - transpose_src_memory_p); + auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( + x->mem_desc(), platform::to_void_cast(x->data())); - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - transpose_p->execute( - astream, *transpose_src_memory_p, *transpose_dst_memory_p); + auto dst_md = + dnnl::memory::desc(x_vec_dims, + x->mem_desc().data_type(), + platform::GetPlainMKLDNNFormat(x_vec_dims.size())); + // a trick is used here to fake transpose of out_md, so later it will be + // "untransposed", leaving output data in plain format tag + auto dst_strides = FakeTranposeStrides(dst_md, transpose_axis); + + dst_md = + dnnl::memory::desc(x_vec_dims, x->mem_desc().data_type(), dst_strides); + auto dst_data = + out->mutable_data(ctx.GetPlace(), x->type(), dst_md.get_size()); + + auto reorder_dst_memory_p = + std::make_shared(dst_md, dnnl_engine, dst_data); + + auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p, + reorder_src_memory_p); + + reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); astream.wait(); - output->set_layout(DataLayout::kNCHW); - output->set_format(MKLDNNMemoryFormat::undef); + out->set_mem_desc(reorder_dst_memory_p->get_desc().permute_axes( + TransposeToPermuteAxis(transpose_axis))); + } + + private: + // it is needed because oneDNN's permute axis understand axes order in + // different way PaddlePaddle's transpose + std::vector TransposeToPermuteAxis( + const std::vector& transpose_axis) const { + std::vector permute_axis(transpose_axis.size()); + + for (size_t i = 0; i < transpose_axis.size(); ++i) { + permute_axis[transpose_axis[i]] = i; + } + return permute_axis; + } + + std::vector FakeTranposeStrides( + const dnnl::memory::desc& dst_md, + const std::vector& transpose_axis) const { + std::vector fake_strides(transpose_axis.size()); + auto dims = dst_md.dims(); + int total_stride = 1; + int ndims = static_cast(dims.size()); + + for (int i = ndims - 1; i >= 0; --i) { + fake_strides[transpose_axis[i]] = total_stride; + total_stride *= dims[transpose_axis[i]]; + } + + return fake_strides; } }; @@ -140,43 +123,47 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel { true, paddle::platform::errors::PreconditionNotMet( "Operator DNNL TransposeGrad must use CPUPlace")); - auto* out_grad = ctx.Input(framework::GradVarName("Out")); - auto* x_grad = ctx.Output(framework::GradVarName("X")); - if (!x_grad) return; + + const auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + if (!dx) return; auto& dev_ctx = ctx.template device_context(); - const auto& mkldnn_engine = dev_ctx.GetEngine(); - std::vector axis = ctx.Attr>("axis"); - std::vector reversed_axis(axis); - int ndims = axis.size(); + const auto& dnnl_engine = dev_ctx.GetEngine(); + std::vector transpose_axis = ctx.Attr>("axis"); + + auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); + + int ndims = transpose_axis.size(); if (ndims == 1) { - framework::TensorCopy(*out_grad, out_grad->place(), x_grad); - x_grad->set_format(out_grad->format()); + framework::TensorCopy(*dout, dout->place(), dx); + dx->set_mem_desc(dout->mem_desc()); return; } - for (size_t i = 0; i < axis.size(); i++) { - reversed_axis[axis[i]] = i; - } + auto dout_vec_dims = phi::vectorize(dout->dims()); - const T* out_grad_data = out_grad->data(); - x_grad->mutable_data(ctx.GetPlace()); + framework::proto::VarType::Type dout_paddle_type = + framework::TransToProtoVarType(dout->dtype()); + dnnl::memory::data_type dout_type = + framework::ToMKLDNNDataType(dout_paddle_type); - auto nchw_tz = phi::vectorize(out_grad->dims()); + platform::ReorderMKLDNNHandler reorder_handler( + dout_vec_dims, dout_paddle_type, dout_type, dnnl_engine); - TransposeMKLDNNHandler handler(nchw_tz, reversed_axis, mkldnn_engine); + auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( + dout->mem_desc(), platform::to_void_cast(dout->data())); - auto transpose_src_memory_p = handler.AcquireSrcMemory( - out_grad->format(), platform::to_void_cast(out_grad_data)); - auto transpose_dst_memory_p = - handler.AcquireDstMemory(x_grad, ctx.GetPlace()); - auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p, - transpose_src_memory_p); + auto reorder_dst_memory_p = + reorder_handler.AcquireDstMemory(dx, dout->mem_desc(), ctx.GetPlace()); - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - transpose_p->execute( - astream, *transpose_src_memory_p, *transpose_dst_memory_p); + auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p, + reorder_src_memory_p); + + reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); astream.wait(); + dx->set_mem_desc( + reorder_dst_memory_p->get_desc().permute_axes(transpose_axis)); } };