diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc index 521f423fb022098e6930c333af6b5e54c502cb7e..72cac9bc9fac9d9199e1f45db16e529adef2a676 100644 --- a/paddle/fluid/operators/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -15,6 +15,8 @@ #include "paddle/fluid/operators/conv_op.h" #include "paddle/fluid/platform/mkldnn_helper.h" +#include "paddle/fluid/framework/data_layout_transform.h" + namespace paddle { namespace operators { @@ -57,6 +59,11 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler { return conv_pd_->dst_primitive_desc().get_size(); } + mkldnn::memory::format GetDstFormat() const { + return static_cast( + conv_pd_->dst_primitive_desc().desc().data.format); + } + size_t GetDiffWeightsMemorySize() const { return conv_bwd_weights_pd_->diff_weights_primitive_desc().get_size(); } @@ -108,6 +115,20 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler { "@data-weights_mem_p", pipeline); } + std::shared_ptr AcquireResidualDataMemory( + const mkldnn::memory::desc& md, void* ptr) { + return this->AcquireMemory(md, ptr, "@user_residual_data_mem_p"); + } + + std::shared_ptr AcquireDstMemoryFromResidualDataMemory( + const std::shared_ptr& user_residual_memory_p, + void* dst_ptr, + std::vector& pipeline) { // NOLINT + return this->AcquireMemory(user_residual_memory_p, + this->AcquireDstMemoryFromPrimitive(dst_ptr), + "@residual_data_mem_p", pipeline); + } + std::shared_ptr AcquireDiffSrcMemoryFromDataPrimitive( void* ptr) { return this->AcquireMemoryFromPrimitive( @@ -386,7 +407,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { auto user_weights_memory_p = handler.AcquireWeightsMemory( user_weights_md, to_void_cast(filter_data)); - T* output_data = nullptr; + // create reorder primitive if the input format is not the preferred one + auto src_memory_p = + handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline); + auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive( + user_weights_memory_p, pipeline, is_test); + + std::shared_ptr dst_memory_p; if (fuse_residual_conn) { auto residual_param = ctx.Input("ResidualData"); @@ -399,21 +426,34 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { "Output and elementwise parameter need to have the " "same dimension sizes"); - output->ShareDataWith(*residual_param); - output_data = output->mutable_data(ctx.GetPlace()); + if (residual_param->format() != handler.GetDstFormat()) { + auto output_data = + output->mutable_data(ctx.GetPlace(), handler.GetDstMemorySize()); + auto residual_data_tz = + paddle::framework::vectorize2int(residual_param->dims()); + auto residual_data_type = + paddle::framework::ToMKLDNNDataType(residual_param->type()); + + auto user_residual_md = platform::MKLDNNMemDesc( + residual_data_tz, residual_data_type, residual_param->format()); + auto user_residual_memory_p = handler.AcquireResidualDataMemory( + user_residual_md, to_void_cast(residual_param_data)); + + dst_memory_p = handler.AcquireDstMemoryFromResidualDataMemory( + user_residual_memory_p, to_void_cast(output_data), pipeline); + } else { + output->ShareDataWith(*residual_param); + auto output_data = output->mutable_data(ctx.GetPlace()); + dst_memory_p = + handler.AcquireDstMemoryFromPrimitive(to_void_cast(output_data)); + } } else { - output_data = + auto output_data = output->mutable_data(ctx.GetPlace(), handler.GetDstMemorySize()); + dst_memory_p = + handler.AcquireDstMemoryFromPrimitive(to_void_cast(output_data)); } - // create reorder primitive if the input format is not the preferred one - auto src_memory_p = - handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline); - auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive( - user_weights_memory_p, pipeline, is_test); - auto dst_memory_p = - handler.AcquireDstMemoryFromPrimitive(to_void_cast(output_data)); - // create convolution op primitive std::shared_ptr conv_p; if (bias) { diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index c0a2543ba5d8ff8f34cb6231c51cb5053a6a9481..814012e6c1fad414d10f5a64af283bed57e11fe3 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -187,6 +187,29 @@ class MKLDNNHandler { return mem_p; } + std::shared_ptr AcquireMemory( + const std::shared_ptr& user_memory_p, + const std::shared_ptr& target_memory_p, + const std::string& suffix, + std::vector& pipeline) { // NOLINT + auto local_key = key_ + suffix; + auto key_reorder_p = key_ + suffix + "reorder_p"; + + auto stored_reorder_p = std::static_pointer_cast( + dev_ctx_.GetBlob(key_reorder_p)); + + if (stored_reorder_p) { + pipeline.push_back(*stored_reorder_p); + } else { + auto reorder_p = + std::make_shared(*user_memory_p, *target_memory_p); + dev_ctx_.SetBlob(key_reorder_p, reorder_p); + pipeline.push_back(*reorder_p); + } + + return target_memory_p; + } + std::shared_ptr AcquireMemory( mkldnn::memory::primitive_desc& mpd, // NOLINT mkldnn::memory::primitive_desc& user_mpd, // NOLINT