diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc index 5bfa1aaa696d5cbe8bdcb94d708746259952740f..909cd5895b2260cacb6e4ed56077b65ea6a8d62d 100644 --- a/paddle/fluid/operators/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -18,9 +18,6 @@ namespace paddle { namespace operators { -using conv_bwd_data = mkldnn::convolution_backward_data; -using conv_bwd_weights = mkldnn::convolution_backward_weights; -using conv_fwd = mkldnn::convolution_forward; using framework::DataLayout; using mkldnn::memory; using mkldnn::primitive; @@ -39,6 +36,72 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler { conv_pd_ = conv_pd; } + ConvMKLDNNHandler( + std::shared_ptr conv_pd, + std::shared_ptr + conv_bwd_data_pd, + std::shared_ptr + conv_bwd_weights_pd, + const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine, + const std::string& base_key) + : platform::MKLDNNHandler(dev_ctx, engine, base_key), + conv_pd_(conv_pd), + conv_bwd_weights_pd_(conv_bwd_weights_pd), + conv_bwd_data_pd_(conv_bwd_data_pd) { + // If we are in Grad operatgor then update a key with BWD suffix to + // distinguish from FWD memory primitives + key_ += "-BWD"; + } + + std::shared_ptr AcquireSrcMemoryFromWeightsPrimitive( + const std::shared_ptr user_memory_p, + std::vector& pipeline) { + auto src_pd = conv_bwd_weights_pd_->src_primitive_desc(); + auto user_pd = user_memory_p->get_primitive_desc(); + return this->AcquireMemory(src_pd, user_pd, user_memory_p, + "@weights-src_mem_p", pipeline); + } + + std::shared_ptr AcquireDiffDstMemoryFromWeightsPrimitive( + const std::shared_ptr user_memory_p, + std::vector& pipeline) { + auto diff_dst_pd = conv_bwd_weights_pd_->diff_dst_primitive_desc(); + auto user_pd = user_memory_p->get_primitive_desc(); + return this->AcquireMemory(diff_dst_pd, user_pd, user_memory_p, + "@weights-diff_dst_mem_p", pipeline); + } + + std::shared_ptr AcquireDiffWeightsMemoryFromWeightsPrimitive( + void* ptr) { + return this->AcquireMemoryFromPrimitive( + conv_bwd_weights_pd_->diff_weights_primitive_desc(), ptr, + "@diff_weights_mem_p"); + } + + std::shared_ptr AcquireDiffDstMemoryFromDataPrimitive( + const std::shared_ptr user_memory_p, + std::vector& pipeline) { + auto diff_dst_pd = conv_bwd_data_pd_->diff_dst_primitive_desc(); + auto user_pd = user_memory_p->get_primitive_desc(); + return this->AcquireMemory(diff_dst_pd, user_pd, user_memory_p, + "@data-diff_dst_mem_p", pipeline); + } + + std::shared_ptr AcquireWeightsMemoryFromDataPrimitive( + const std::shared_ptr user_weights_memory_p, + std::vector& pipeline) { + auto weights_pd = conv_bwd_data_pd_->weights_primitive_desc(); + auto user_pd = user_weights_memory_p->get_primitive_desc(); + return this->AcquireMemory(weights_pd, user_pd, user_weights_memory_p, + "@data-weights_mem_p", pipeline); + } + + std::shared_ptr AcquireDiffSrcMemoryFromDataPrimitive( + void* ptr) { + return this->AcquireMemoryFromPrimitive( + conv_bwd_data_pd_->diff_src_primitive_desc(), ptr, "@diff_src_mem_p"); + } + std::shared_ptr AcquireDstMemoryFromPrimitive(void* ptr) { return this->AcquireMemoryFromPrimitive(conv_pd_->dst_primitive_desc(), ptr, "@dst_mem_p"); @@ -68,7 +131,6 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler { std::shared_ptr weights_memory_p, std::shared_ptr dst_memory_p) { auto prim_key = key_ + "@conv_p"; - auto prim_desc_key = key_ + "@conv_pd"; auto conv_p = std::static_pointer_cast( dev_ctx_.GetBlob(prim_key)); PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false), @@ -85,6 +147,54 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler { return conv_p; } + std::shared_ptr + AcquireConvolutionBackwardWeights( + std::shared_ptr src_memory_p, + std::shared_ptr diff_dst_memory_p, + std::shared_ptr diff_weights_memory_p) { + auto prim_key = key_ + "@conv_bwd_weights_p"; + auto conv_bwd_weights_p = + std::static_pointer_cast( + dev_ctx_.GetBlob(prim_key)); + PADDLE_ENFORCE( + (conv_bwd_weights_p != nullptr) || (is_reusing_ == false), + "Fail to find convolution bwd weights primitive in device context"); + if (conv_bwd_weights_p == nullptr) { + // create backward conv primitive for weights + conv_bwd_weights_p = + std::make_shared( + *conv_bwd_weights_pd_, *src_memory_p, *diff_dst_memory_p, + *diff_weights_memory_p); + dev_ctx_.SetBlob(prim_key, conv_bwd_weights_p); + } else { + is_reusing_ = true; + } + return conv_bwd_weights_p; + } + + std::shared_ptr + AcquireConvolutionBackwardData( + std::shared_ptr diff_dst_memory_p, + std::shared_ptr weights_memory_p, + std::shared_ptr diff_src_memory_p) { + auto prim_key = key_ + "@conv_bwd_data_p"; + auto conv_bwd_data_p = + std::static_pointer_cast( + dev_ctx_.GetBlob(prim_key)); + PADDLE_ENFORCE( + (conv_bwd_data_p != nullptr) || (is_reusing_ == false), + "Fail to find convolution bwd data primitive in device context"); + if (conv_bwd_data_p == nullptr) { + conv_bwd_data_p = std::make_shared( + *conv_bwd_data_pd_, *diff_dst_memory_p, *weights_memory_p, + *diff_src_memory_p); + dev_ctx_.SetBlob(prim_key, conv_bwd_data_p); + } else { + is_reusing_ = true; + } + return conv_bwd_data_p; + } + // Generate keys for storing/retriving primitives for this operator // TODO(jczaja): Make hashing function more optimial static std::string GetHash(memory::dims& input_dims, @@ -100,6 +210,10 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler { private: std::shared_ptr conv_pd_; + std::shared_ptr + conv_bwd_weights_pd_; + std::shared_ptr + conv_bwd_data_pd_; }; template @@ -174,8 +288,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { dst_tz, platform::MKLDNNGetDataType(), memory::format::any); // create a conv primitive descriptor and save it for usage in backward - std::shared_ptr conv_pd = ConvFwdPrimitiveDesc( - src_md, weights_md, dst_md, strides, paddings, mkldnn_engine); + std::shared_ptr conv_pd = + ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings, + mkldnn_engine); // Save conv_pd/src_memory/weights_memory for backward pass dev_ctx.SetBlob(key_conv_pd, conv_pd); @@ -208,21 +323,24 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { } private: - std::unique_ptr ConvFwdPrimitiveDesc( - const memory::desc& src, const memory::desc& weights, - const memory::desc& dst, const std::vector& strides, - const std::vector& paddings, const mkldnn::engine& engine) const { + std::unique_ptr + ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights, + const memory::desc& dst, const std::vector& strides, + const std::vector& paddings, + const mkldnn::engine& engine) const { memory::dims stride_dims = {strides[0], strides[1]}; memory::dims padding_dims = {paddings[0], paddings[1]}; - auto conv_desc = - conv_fwd::desc(mkldnn::prop_kind::forward, mkldnn::convolution_direct, - src, weights, dst, stride_dims, padding_dims, - padding_dims, mkldnn::padding_kind::zero); + auto conv_desc = mkldnn::convolution_forward::desc( + mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights, + dst, stride_dims, padding_dims, padding_dims, + mkldnn::padding_kind::zero); - auto p_conv_pd = new conv_fwd::primitive_desc(conv_desc, engine); + auto p_conv_pd = + new mkldnn::convolution_forward::primitive_desc(conv_desc, engine); - return std::unique_ptr(p_conv_pd); + return std::unique_ptr( + p_conv_pd); } }; @@ -290,147 +408,108 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { dilations, groups, ctx.op().Input("Output")); const std::string key_conv_pd = key + "@conv_pd"; + std::vector pipeline; - // create mkldnn memory from input tensors (input/weights/output_grad) - auto user_src_memory = memory( - {{{src_tz}, memory::data_type::f32, input->format()}, mkldnn_engine}, - to_void_cast(input_data)); - auto user_weights_memory = - memory({{{weights_tz}, memory::data_type::f32, filter->format()}, - mkldnn_engine}, - to_void_cast(filter_data)); - auto user_diff_dst_memory = - memory({{{dst_tz}, memory::data_type::f32, output_grad->format()}, - mkldnn_engine}, - to_void_cast(output_grad_data)); + // Create user memory descriptors + auto user_src_md = platform::MKLDNNMemDesc( + {src_tz}, platform::MKLDNNGetDataType(), input->format()); + auto user_weights_md = platform::MKLDNNMemDesc( + {weights_tz}, platform::MKLDNNGetDataType(), filter->format()); + auto user_diff_dst_md = platform::MKLDNNMemDesc( + {dst_tz}, platform::MKLDNNGetDataType(), output_grad->format()); /* create memory descriptor for conv backward without specified format * ('any') which lets a primitive (conv backward in this case) choose * the memory format preferred for best performance */ - auto src_md = platform::MKLDNNMemDesc(src_tz, memory::data_type::f32, - memory::format::any); - auto diff_src_md = platform::MKLDNNMemDesc(src_tz, memory::data_type::f32, - memory::format::any); + auto src_md = platform::MKLDNNMemDesc( + src_tz, platform::MKLDNNGetDataType(), memory::format::any); + auto diff_src_md = platform::MKLDNNMemDesc( + src_tz, platform::MKLDNNGetDataType(), memory::format::any); auto weights_md = platform::MKLDNNMemDesc( - weights_tz, memory::data_type::f32, memory::format::any); + weights_tz, platform::MKLDNNGetDataType(), memory::format::any); auto diff_weights_md = platform::MKLDNNMemDesc( - weights_tz, memory::data_type::f32, memory::format::any); - auto diff_dst_md = platform::MKLDNNMemDesc(dst_tz, memory::data_type::f32, - memory::format::any); + weights_tz, platform::MKLDNNGetDataType(), memory::format::any); + auto diff_dst_md = platform::MKLDNNMemDesc( + dst_tz, platform::MKLDNNGetDataType(), memory::format::any); // Retrieve conv_pd from device context - auto conv_pd = std::static_pointer_cast( - dev_ctx.GetBlob(key_conv_pd)); + auto conv_pd = + std::static_pointer_cast( + dev_ctx.GetBlob(key_conv_pd)); PADDLE_ENFORCE(conv_pd != nullptr, "Fail to find conv_pd in device context"); + // create backward convolution weights primitive descriptor + auto conv_bwd_weights_desc = mkldnn::convolution_backward_weights::desc( + mkldnn::convolution_direct, src_md, diff_weights_md, diff_dst_md, + strides, paddings, paddings, mkldnn::padding_kind::zero); + auto conv_bwd_weights_pd = + std::make_shared( + conv_bwd_weights_desc, mkldnn_engine, *conv_pd); + + // create backward convolution data primitive descriptor + auto conv_bwd_data_desc = mkldnn::convolution_backward_data::desc( + mkldnn::convolution_direct, diff_src_md, weights_md, diff_dst_md, + strides, paddings, paddings, mkldnn::padding_kind::zero); + auto conv_bwd_data_pd = + std::make_shared( + conv_bwd_data_desc, mkldnn_engine, *conv_pd); + + ConvMKLDNNHandler handler(conv_pd, conv_bwd_data_pd, conv_bwd_weights_pd, + dev_ctx, mkldnn_engine, key); + + // create mkldnn memory from input tensors (data/weights) + auto user_src_memory_p = + handler.AcquireSrcMemory(user_src_md, to_void_cast(input_data)); + auto user_weights_memory_p = handler.AcquireWeightsMemory( + user_weights_md, to_void_cast(filter_data)); + auto user_diff_dst_memory_p = handler.AcquireDiffDstMemory( + user_diff_dst_md, to_void_cast(output_grad_data)); + // create backward conv primitive for weights if (filter_grad) { - // create backward convolution primitive descriptor - auto conv_bwd_weights_desc = conv_bwd_weights::desc( - mkldnn::convolution_direct, src_md, diff_weights_md, diff_dst_md, - strides, paddings, paddings, mkldnn::padding_kind::zero); - auto conv_bwd_weights_pd = conv_bwd_weights::primitive_desc( - conv_bwd_weights_desc, mkldnn_engine, *conv_pd); - - // create reorder primitive if the input format is not the preferred one - auto src_memory = user_src_memory; - primitive reorder_src; - bool is_src_reordered = false; - if (memory::primitive_desc(conv_bwd_weights_pd.src_primitive_desc()) != - user_src_memory.get_primitive_desc()) { - src_memory = memory(conv_bwd_weights_pd.src_primitive_desc()); - reorder_src = reorder(user_src_memory, src_memory); - is_src_reordered = true; - } - - auto diff_dst_memory_4filter = user_diff_dst_memory; - primitive reorder_diff_dst_4filter; - bool is_diff_dst_reordered_4filter = false; - if (memory::primitive_desc( - conv_bwd_weights_pd.diff_dst_primitive_desc()) != - user_diff_dst_memory.get_primitive_desc()) { - diff_dst_memory_4filter = - memory(conv_bwd_weights_pd.diff_dst_primitive_desc()); - reorder_diff_dst_4filter = - reorder(user_diff_dst_memory, diff_dst_memory_4filter); - is_diff_dst_reordered_4filter = true; - } - - // create mkldnn memory for output (i.e. diff weights) - auto diff_weights_memory = - memory(conv_bwd_weights_pd.diff_weights_primitive_desc(), - reinterpret_cast(filter_grad_data)); + auto src_memory_p = handler.AcquireSrcMemoryFromWeightsPrimitive( + user_src_memory_p, pipeline); - // create backward conv primitive for weights - auto conv_bwd_weights_prim = - conv_bwd_weights(conv_bwd_weights_pd, src_memory, - diff_dst_memory_4filter, diff_weights_memory); - - // push primitive and execute it - std::vector pipeline; - if (is_src_reordered) pipeline.push_back(reorder_src); - if (is_diff_dst_reordered_4filter) - pipeline.push_back(reorder_diff_dst_4filter); - pipeline.push_back(conv_bwd_weights_prim); - stream(stream::kind::eager).submit(pipeline).wait(); + auto diff_dst_memory_4filter_p = + handler.AcquireDiffDstMemoryFromWeightsPrimitive( + user_diff_dst_memory_p, pipeline); + + auto diff_weights_memory_p = + handler.AcquireDiffWeightsMemoryFromWeightsPrimitive( + reinterpret_cast(filter_grad_data)); + + auto conv_bwd_weights_p = handler.AcquireConvolutionBackwardWeights( + src_memory_p, diff_dst_memory_4filter_p, diff_weights_memory_p); + + // push primitive to stream and wait until it's executed + pipeline.push_back(*conv_bwd_weights_p); filter_grad->set_layout(DataLayout::kMKLDNN); - filter_grad->set_format(GetMKLDNNFormat(diff_weights_memory)); + filter_grad->set_format(GetMKLDNNFormat(*diff_weights_memory_p)); } if (input_grad) { - // create backward convolution primitive descriptor - auto conv_bwd_data_desc = conv_bwd_data::desc( - mkldnn::convolution_direct, diff_src_md, weights_md, diff_dst_md, - strides, paddings, paddings, mkldnn::padding_kind::zero); - auto conv_bwd_data_pd = conv_bwd_data::primitive_desc( - conv_bwd_data_desc, mkldnn_engine, *conv_pd); - - // create reorder primitive if the input format is not the preferred one - auto weights_memory = user_weights_memory; - primitive reorder_weights; - bool is_weights_reordered = false; - if (memory::primitive_desc(conv_bwd_data_pd.weights_primitive_desc()) != - user_weights_memory.get_primitive_desc()) { - weights_memory = memory(conv_bwd_data_pd.weights_primitive_desc()); - reorder_weights = reorder(user_weights_memory, weights_memory); - is_weights_reordered = true; - } - - auto diff_dst_memory_4data = user_diff_dst_memory; - primitive reorder_diff_dst_4data; - bool is_diff_dst_reordered_4data = false; - if (memory::primitive_desc(conv_bwd_data_pd.diff_dst_primitive_desc()) != - user_diff_dst_memory.get_primitive_desc()) { - diff_dst_memory_4data = - memory(conv_bwd_data_pd.diff_dst_primitive_desc()); - reorder_diff_dst_4data = - reorder(user_diff_dst_memory, diff_dst_memory_4data); - is_diff_dst_reordered_4data = true; - } - - // create mkldnn memory for output (i.e. diff src) - auto diff_src_memory = memory(conv_bwd_data_pd.diff_src_primitive_desc(), - reinterpret_cast(input_grad_data)); - - // create backward conv primitive for data - auto conv_bwd_data_prim = - conv_bwd_data(conv_bwd_data_pd, diff_dst_memory_4data, weights_memory, - diff_src_memory); - - // push primitive and execute it - std::vector pipeline; - if (is_weights_reordered) pipeline.push_back(reorder_weights); - if (is_diff_dst_reordered_4data) - pipeline.push_back(reorder_diff_dst_4data); - pipeline.push_back(conv_bwd_data_prim); - stream(stream::kind::eager).submit(pipeline).wait(); + auto weights_memory_p = handler.AcquireWeightsMemoryFromDataPrimitive( + user_weights_memory_p, pipeline); + + auto diff_dst_memory_4data_p = + handler.AcquireDiffDstMemoryFromDataPrimitive(user_diff_dst_memory_p, + pipeline); + + auto diff_src_memory_p = handler.AcquireDiffSrcMemoryFromDataPrimitive( + reinterpret_cast(input_grad_data)); + + auto conv_bwd_data_p = handler.AcquireConvolutionBackwardData( + diff_dst_memory_4data_p, weights_memory_p, diff_src_memory_p); + + pipeline.push_back(*conv_bwd_data_p); input_grad->set_layout(DataLayout::kMKLDNN); - input_grad->set_format(GetMKLDNNFormat(diff_src_memory)); + input_grad->set_format(GetMKLDNNFormat(*diff_src_memory_p)); } + stream(stream::kind::eager).submit(pipeline).wait(); } // Compute() };