From 493fbfd75b0983de4a08afb859be104398a0af22 Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Thu, 27 Oct 2022 08:22:46 +0200 Subject: [PATCH] Update of PHI transpose_grad (#47311) * - halfway transforming transpose grad - Fixes - buildable * - lint * rerunning the process --- paddle/phi/backends/onednn/onednn_reuse.h | 64 ------------------- .../kernels/onednn/transpose_grad_kernel.cc | 34 +++++----- 2 files changed, 16 insertions(+), 82 deletions(-) diff --git a/paddle/phi/backends/onednn/onednn_reuse.h b/paddle/phi/backends/onednn/onednn_reuse.h index cd8c076b28..d1810090cd 100644 --- a/paddle/phi/backends/onednn/onednn_reuse.h +++ b/paddle/phi/backends/onednn/onednn_reuse.h @@ -1085,69 +1085,5 @@ class ClipOneDNNHandler to_void_cast(input_data)); } }; -template -class TransposeOneDNNHandler { - public: - TransposeOneDNNHandler(const OneDNNContext& dev_ctx, - std::vector& dims, // NOLINT - std::vector& axis, // NOLINT - dnnl::engine engine) - : dev_ctx_(dev_ctx), - dims_(dims), - axis_(axis), - logical_axis_(dims.size(), 0), - engine_(engine) {} - - std::shared_ptr AcquireSrcMemory(const OneDNNMemoryFormat& fmt, - void* ptr) { - // Make memory descriptor using input format, unless it - // cannot be trusted (nchw) then make up memory fmt manually - for (size_t i = 0; i < this->logical_axis_.size(); ++i) { - this->logical_axis_[i] = i; - } - - auto src_md = fmt != OneDNNMemoryFormat::nchw - ? OneDNNMemDesc(dims_, OneDNNGetDataType(), fmt) - : Axis2MemoryDesc(dims_, logical_axis_); - return std::make_shared(src_md, engine_, ptr); - } - - std::shared_ptr AcquireDstMemory(DenseTensor* output, - Place place) { - auto dst_md = Axis2MemoryDesc(dims_, axis_); - auto dst_data = dev_ctx_.Alloc(output); - return std::make_shared(dst_md, engine_, dst_data); - } - - std::shared_ptr AcquireTranspose( - std::shared_ptr dst_memory_p, - std::shared_ptr src_memory_p) { - return std::make_shared(*(src_memory_p), *(dst_memory_p)); - } - - protected: - dnnl::memory::desc Axis2MemoryDesc(std::vector& nchw_tz, // NOLINT - std::vector& axis // NOLINT - ) { - size_t ndims = axis.size(); - - std::vector strides(ndims); - unsigned int total_stride = 1; - for (int i = ndims - 1; i >= 0; --i) { - strides[axis[i]] = total_stride; - total_stride *= nchw_tz[axis[i]]; - } - dnnl::memory::desc mem_d(nchw_tz, OneDNNGetDataType(), strides); - - return mem_d; - } - - private: - const OneDNNContext& dev_ctx_; - std::vector dims_; - std::vector axis_; - std::vector logical_axis_; - dnnl::engine engine_; -}; } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/onednn/transpose_grad_kernel.cc b/paddle/phi/kernels/onednn/transpose_grad_kernel.cc index 09f410c61c..a754cdffed 100644 --- a/paddle/phi/kernels/onednn/transpose_grad_kernel.cc +++ b/paddle/phi/kernels/onednn/transpose_grad_kernel.cc @@ -31,35 +31,33 @@ void TransposeGradKernel(const Context& dev_ctx, if (!x_grad) return; const auto& onednn_engine = dev_ctx.GetEngine(); - std::vector reversed_axis(axis); + if (axis.size() == 1) { paddle::framework::TensorCopy(out_grad, out_grad.place(), x_grad); - x_grad->set_format(out_grad.format()); + x_grad->set_mem_desc(out_grad.mem_desc()); return; } - for (size_t i = 0; i < axis.size(); i++) { - reversed_axis[axis[i]] = i; - } + std::vector out_grad_tz = vectorize(out_grad.dims()); + funcs::ReorderOneDNNHandler reorder_handler( + out_grad_tz, + out_grad.dtype(), + funcs::ToOneDNNDataType(out_grad.dtype()), + onednn_engine); - const T* out_grad_data = out_grad.data(); - dev_ctx.template Alloc(x_grad); - auto nchw_tz = vectorize(out_grad.dims()); + auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( + out_grad.mem_desc(), funcs::to_void_cast(out_grad.data())); - funcs::TransposeOneDNNHandler handler( - dev_ctx, nchw_tz, reversed_axis, onednn_engine); + auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( + x_grad, out_grad.mem_desc(), dev_ctx.GetPlace()); - auto transpose_src_memory_p = handler.AcquireSrcMemory( - out_grad.format(), funcs::to_void_cast(out_grad_data)); - auto transpose_dst_memory_p = - handler.AcquireDstMemory(x_grad, dev_ctx.GetPlace()); - auto transpose_p = - handler.AcquireTranspose(transpose_dst_memory_p, transpose_src_memory_p); + auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p, + reorder_src_memory_p); auto& astream = OneDNNContext::tls().get_stream(); - transpose_p->execute( - astream, *transpose_src_memory_p, *transpose_dst_memory_p); + reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); astream.wait(); + x_grad->set_mem_desc(reorder_dst_memory_p->get_desc().permute_axes(axis)); } } // namespace phi -- GitLab