diff --git a/paddle/fluid/operators/transpose_mkldnn_op.cc b/paddle/fluid/operators/transpose_mkldnn_op.cc
index 37f1cadc7d2ff248e8b6dcb3f0c8ba09f8ccd8b5..2f133c9e251388e9e78a6a49ca66a45a56eef76e 100644
--- a/paddle/fluid/operators/transpose_mkldnn_op.cc
+++ b/paddle/fluid/operators/transpose_mkldnn_op.cc
@@ -32,7 +32,7 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     const bool is_test = ctx.Attr<bool>("is_test");
     PADDLE_ENFORCE(
         is_test == true,
-        "ConvTransposeMKLDNN works only for inference!. Set is_test = True");
+        "TransposeMKLDNN works only for inference!. Set is_test = True");
     auto& dev_ctx =
         ctx.template device_context<platform::MKLDNNDeviceContext>();
     const auto& mkldnn_engine = dev_ctx.GetEngine();
@@ -47,69 +47,24 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       return;
     }

-    std::vector<int> nchw_axis(ndims, 0);
-    for (size_t i = 0; i < nchw_axis.size(); ++i) {
-      nchw_axis[i] = i;
-    }
-
     std::vector<int> nchw_tz = paddle::framework::vectorize2int(input->dims());
-    std::string data_format = ctx.Attr<std::string>("data_format");
-
-    auto src_md =
-        input->format() != mkldnn::memory::format::nchw
-            ? platform::MKLDNNMemDesc(nchw_tz, platform::MKLDNNGetDataType<T>(),
-                                      input->format())
-            : Axis2MemoryDesc(nchw_tz, nchw_axis);
-
-    this->TransposeKernel(ctx.GetPlace(), Axis2MemoryDesc(nchw_tz, axis),
-                          src_md, output, input_data, nchw_tz, mkldnn_engine);
-  }
-
- protected:
-  mkldnn::memory::desc Axis2MemoryDesc(std::vector<int>& nchw_tz,
-                                       std::vector<int>& axis) const {
-    mkldnn_memory_desc_t mem_fmt;
-
-    mem_fmt.primitive_kind = mkldnn_memory;
-    mem_fmt.ndims = axis.size();
-    for (unsigned int i = 0; i < nchw_tz.size(); ++i) {
-      mem_fmt.dims[i] = nchw_tz[i];  // logical dimensions (nchw format,
-                                     // regardless physical layout)
-    }
-    mem_fmt.data_type = mkldnn_f32;
-    mem_fmt.format = mkldnn_blocked;
-
-    unsigned int total_stride = 1;
-    for (int i = nchw_tz.size() - 1; i >= 0; --i) {
-      mem_fmt.layout_desc.blocking.padding_dims[i] =
-          nchw_tz[i];  // logical dimensions (nchw format, regardless physical
-                       // layout)
-      mem_fmt.layout_desc.blocking.block_dims[i] = 1;
-      mem_fmt.layout_desc.blocking.offset_padding_to_data[i] = 0;  // no offset
-      mem_fmt.layout_desc.blocking.strides[0][axis[i]] = total_stride;
-      mem_fmt.layout_desc.blocking.strides[1][axis[i]] = 1;
-      total_stride *= nchw_tz[axis[i]];
-    }
-    mem_fmt.layout_desc.blocking.offset_padding = 0;  // no initial offset
-    return mem_fmt;
-  }
-
-  void TransposeKernel(platform::Place place, mkldnn::memory::desc md_o,
-                       mkldnn::memory::desc md_i, Tensor* output,
-                       const T* data_i, std::vector<int>& nchw_dims,
-                       const mkldnn::engine& eng) const {
-    // Make Memory primitive descriptors
-    auto mpd_o = mkldnn::memory::primitive_desc(md_o, eng);
-    auto mpd_i = mkldnn::memory::primitive_desc(md_i, eng);
+    const std::string key = platform::TransposeMKLDNNHandler::GetHash(
+        nchw_tz, axis, ctx.op().Output("Out"));

-    auto data_o = output->mutable_data<T>(
-        place, paddle::memory::Allocator::kDefault, mpd_o.get_size());
+    platform::TransposeMKLDNNHandler handler(nchw_tz, axis, dev_ctx,
+                                             mkldnn_engine, key);

-    auto src = mkldnn::memory(mpd_i, (T*)(data_i));
-    auto dst = mkldnn::memory(mpd_o, data_o);
+    auto transpose_src_memory_p = handler.AcquireSrcMemory(
+        input->format(), platform::to_void_cast<T>(input_data));
+    auto transpose_dst_memory_p =
+        handler.AcquireDstMemory(output, ctx.GetPlace());
+    auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p,
+                                                transpose_src_memory_p);

-    auto r = mkldnn::reorder(src, dst);
-    mkldnn::stream(mkldnn::stream::kind::eager).submit({r}).wait();
+    std::vector<mkldnn::primitive> pipeline;
+    pipeline.push_back(*transpose_p);
+    mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
   }
 };
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h
index 1c6421f3fa6ffbe7d3c682611def9e87d2fae5b0..23f00406de6190dfef91e259d6af358b5dac1713 100644
--- a/paddle/fluid/platform/mkldnn_reuse.h
+++ b/paddle/fluid/platform/mkldnn_reuse.h
@@ -197,6 +197,130 @@ class MKLDNNHandler {
   bool is_reusing_;
 };

+class TransposeMKLDNNHandler : public MKLDNNHandler {
+ public:
+  TransposeMKLDNNHandler(std::vector<int>& dims, std::vector<int>& axis,
+                         const platform::MKLDNNDeviceContext& dev_ctx,
+                         mkldnn::engine engine, const std::string& base_key)
+      : platform::MKLDNNHandler(dev_ctx, engine, base_key),
+        dims_(dims),
+        axis_(axis),
+        logical_axis_(dims.size(), 0) {}
+
+  std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
+      const mkldnn::memory::format& fmt, void* ptr) {
+    auto local_key = key_ + "@user_src_mem_p";
+    auto mem_p =
+        std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
+    PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
+                   "Fail to find mem primitive in device context");
+    if (mem_p == nullptr) {
+      // Make memory descriptor using input format, unless it
+      // cannot be trusted (nchw) then make up memory fmt manually
+      for (size_t i = 0; i < logical_axis_.size(); ++i) {
+        logical_axis_[i] = i;
+      }
+      auto src_md = fmt != mkldnn::memory::format::nchw
+                        ? platform::MKLDNNMemDesc(
+                              dims_, platform::MKLDNNGetDataType<float>(), fmt)
+                        : Axis2MemoryDesc(dims_, logical_axis_);
+      mem_p = std::make_shared<mkldnn::memory>(
+          mkldnn::memory::primitive_desc{src_md, engine_}, ptr);
+      dev_ctx_.SetBlob(local_key, mem_p);
+    } else {
+      mem_p->set_data_handle(ptr);
+      // Mark that reusing happened. All primitives from operator instance
+      // should be reused or none of them. So we check consistency
+      is_reusing_ = true;
+    }
+    return mem_p;
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output,
+                                                   platform::Place place) {
+    auto local_key = key_ + "@user_dst_mem_p";
+    auto mem_p =
+        std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
+    PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
+                   "Fail to find mem primitive in device context");
+    if (mem_p == nullptr) {
+      auto dst_mdp = mkldnn::memory::primitive_desc{
+          Axis2MemoryDesc(dims_, axis_), engine_};
+
+      auto dst_data = output->mutable_data<float>(
+          place, paddle::memory::Allocator::kDefault, dst_mdp.get_size());
+
+      mem_p = std::make_shared<mkldnn::memory>(dst_mdp, dst_data);
+      dev_ctx_.SetBlob(local_key, mem_p);
+    } else {
+      auto dst_data = output->mutable_data<float>(place);
+      mem_p->set_data_handle(dst_data);
+      // Mark that reusing happened. All primitives from operator instance
+      // should be reused or none of them. So we check consistency
+      is_reusing_ = true;
+    }
+    return mem_p;
+  }
+
+  std::shared_ptr<mkldnn::reorder> AcquireTranspose(
+      std::shared_ptr<mkldnn::memory> dst_memory_p,
+      std::shared_ptr<mkldnn::memory> src_memory_p) {
+    auto prim_key = key_ + "@transpose_p";
+    auto transpose_p =
+        std::static_pointer_cast<mkldnn::reorder>(dev_ctx_.GetBlob(prim_key));
+    PADDLE_ENFORCE((transpose_p != nullptr) || (is_reusing_ == false),
+                   "Fail to find transpose primitive in device context");
+    if (transpose_p == nullptr) {
+      transpose_p =
+          std::make_shared<mkldnn::reorder>(*(src_memory_p), *(dst_memory_p));
+      dev_ctx_.SetBlob(prim_key, transpose_p);
+    } else {
+      is_reusing_ = true;
+    }
+    return transpose_p;
+  }
+
+  static std::string GetHash(std::vector<int>& shape,  // NOLINT
+                             std::vector<int>& axis,   // NOLINT
+                             const std::string& suffix) {
+    return dims2str(shape) + dims2str(axis) + suffix;
+  }
+
+ protected:
+  mkldnn_memory_desc_t Axis2MemoryDesc(std::vector<int>& nchw_tz,
+                                       std::vector<int>& axis) {
+    mkldnn_memory_desc_t mem_fmt;
+
+    mem_fmt.primitive_kind = mkldnn_memory;
+    mem_fmt.ndims = axis.size();
+    for (unsigned int i = 0; i < nchw_tz.size(); ++i) {
+      mem_fmt.dims[i] = nchw_tz[i];  // logical dimensions (nchw format,
+                                     // regardless physical layout)
+    }
+    mem_fmt.data_type = mkldnn_f32;
+    mem_fmt.format = mkldnn_blocked;
+
+    unsigned int total_stride = 1;
+    for (int i = nchw_tz.size() - 1; i >= 0; --i) {
+      mem_fmt.layout_desc.blocking.padding_dims[i] =
+          nchw_tz[i];  // logical dimensions (nchw format, regardless physical
+                       // layout)
+      mem_fmt.layout_desc.blocking.block_dims[i] = 1;
+      mem_fmt.layout_desc.blocking.offset_padding_to_data[i] = 0;  // no offset
+      mem_fmt.layout_desc.blocking.strides[0][axis[i]] = total_stride;
+      mem_fmt.layout_desc.blocking.strides[1][axis[i]] = 1;
+      total_stride *= nchw_tz[axis[i]];
+    }
+    mem_fmt.layout_desc.blocking.offset_padding = 0;  // no initial offset
+    return mem_fmt;
+  }
+
+ private:
+  std::vector<int> dims_;
+  std::vector<int> axis_;
+  std::vector<int> logical_axis_;
+};
+
 template <class forward_t, class backward_data_t, class backward_weights_t>
 class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
  public:
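For readers unfamiliar with the stride trick that Axis2MemoryDesc relies on, here is a minimal standalone sketch (plain C++, no MKL-DNN dependency; the helper name TransposedStrides is made up for illustration). It shows how strides derived from the permutation make a plain element-by-element copy, which is essentially what the reorder primitive submitted by the kernel does, produce the transposed tensor.

#include <cstdio>
#include <vector>

// Mirror of the total_stride loop in Axis2MemoryDesc: walk the permutation
// from the innermost output dimension outwards and record, for every logical
// (input-order) dimension, its stride inside the destination buffer.
std::vector<int> TransposedStrides(const std::vector<int>& dims,
                                   const std::vector<int>& axis) {
  std::vector<int> strides(dims.size(), 0);
  int total_stride = 1;
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
    strides[axis[i]] = total_stride;
    total_stride *= dims[axis[i]];
  }
  return strides;
}

int main() {
  // A 2x3 tensor stored row-major, transposed with axis = {1, 0}.
  std::vector<int> dims = {2, 3};
  std::vector<int> axis = {1, 0};
  std::vector<float> src = {0, 1, 2, 3, 4, 5};
  std::vector<float> dst(src.size());

  std::vector<int> src_strides = {3, 1};                         // row-major
  std::vector<int> dst_strides = TransposedStrides(dims, axis);  // {1, 2}

  // The "reorder": copy each logical element (i, j) from its source offset
  // to the offset dictated by the transposed strides.
  for (int i = 0; i < dims[0]; ++i)
    for (int j = 0; j < dims[1]; ++j)
      dst[i * dst_strides[0] + j * dst_strides[1]] =
          src[i * src_strides[0] + j * src_strides[1]];

  for (float v : dst) std::printf("%g ", v);  // prints 0 3 1 4 2 5
  std::printf("\n");
  return 0;
}

Viewed with the permuted dimensions {3, 2}, dst is exactly the contiguous transpose. This is why AcquireDstMemory only needs the blocked descriptor built by Axis2MemoryDesc plus a single cached reorder primitive, rather than a dedicated transpose kernel.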