diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index 09386fc31ee31b8ce2bfb3caf5aad053e2b2544c..1b69dd7ea00c7cce45e2ef2691ea14f03e59318d 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -706,7 +706,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { platform::CreateKey(dev_ctx, src_tz, src_dt, ctx.InputName("Input") + ctx.InputName("Filter")); - const std::string key_conv_pd = key + "@conv_pd"; bool need_s8_to_u8 = false; std::shared_ptr conv_p; std::shared_ptr src_memory_p; @@ -721,6 +720,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { // are merged/unified, this will disappear auto key_tid = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); + const std::string key_conv_pd = key_tid + "@conv_pd"; auto prim_key = key_tid + "@conv_p"; auto dst_key = key_tid + "@dst_mem_p"; auto src_key = key_tid + "@src_mem_p"; @@ -731,12 +731,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { auto src_reorder_key = key_tid + "@src_mem_preorder_p"; auto residual_reorder_key = key_tid + "@residual_data_mem_preorder_p"; - conv_p = std::static_pointer_cast( - dev_ctx.GetBlob(prim_key)); + conv_pd = + std::static_pointer_cast( + dev_ctx.GetBlob(key_conv_pd)); auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - if (conv_p == nullptr || !is_test) { + if (conv_pd == nullptr || !is_test) { float fuse_alpha = ctx.Attr("fuse_alpha"); float fuse_beta = ctx.Attr("fuse_beta"); bool force_fp32_output = ctx.Attr("force_fp32_output"); @@ -946,7 +947,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { } // create convolution op primitive - auto scale_bias_key = key + "@scale_bias"; conv_p = handler->AcquireConvolution(); if (bias) { const K* bias_data = bias->data(); @@ -1000,13 +1000,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { dev_ctx.GetBlob(weights_key)); dst_memory_p = std::static_pointer_cast(dev_ctx.GetBlob(dst_key)); - conv_pd = - std::static_pointer_cast( - dev_ctx.GetBlob(key_conv_pd)); - if (conv_pd) { - handler.reset(new platform::ConvMKLDNNHandler(conv_pd, dev_ctx, - mkldnn_engine, key)); - } + conv_p = std::static_pointer_cast( + dev_ctx.GetBlob(prim_key)); + handler.reset(new platform::ConvMKLDNNHandler(conv_pd, dev_ctx, + mkldnn_engine, key)); if (fuse_residual_conn) { auto residual_param = ctx.Input("ResidualData"); diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 29a3f8e9dcd3cd628972c0ac77dfedfc85ed9854..d6ab9e50a066e01663800fb424d2edd2a5dc4b9f 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -603,7 +603,6 @@ class MKLDNNHandler { const std::string& base_key) : dev_ctx_(dev_ctx), engine_(engine), - key_common_(base_key), key_(platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, base_key)) { platform::MKLDNNDeviceContext::tls().log_lib_version(); } @@ -789,7 +788,6 @@ class MKLDNNHandler { protected: const MKLDNNDeviceContext& dev_ctx_; mkldnn::engine engine_; - std::string key_common_; std::string key_; }; @@ -1371,42 +1369,34 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { // Conv PD has to be passed to Grad op that // may be exxecuted by diffrent thread, hence // for that one we use key that does not contain TID - const std::string key_conv_pd = key_common_ + "@conv_pd"; + const std::string key_conv_pd = key_ + "@conv_pd"; conv_pd_ = std::static_pointer_cast( dev_ctx_.GetBlob(key_conv_pd)); if (conv_pd_ == nullptr) { - static std::mutex acquire_barrier; - std::lock_guard block_threads_until_finish_this_job( - acquire_barrier); - - conv_pd_ = std::static_pointer_cast( - dev_ctx_.GetBlob(key_conv_pd)); - if (conv_pd_ == nullptr) { - mkldnn::memory::dims stride_dims = strides; - mkldnn::memory::dims dilations_dims = dilations; - auto mkldnn_paddings = ToMkldnnPadding(paddings); - - auto conv_desc = - bias ? typename forward_t::desc( - fwd_prop_kind, convolutional_algorithm::T, - src, weights, *bias, dst, stride_dims, dilations_dims, - mkldnn_paddings[0], mkldnn_paddings[1]) - : typename forward_t::desc( - fwd_prop_kind, convolutional_algorithm::T, - src, weights, dst, stride_dims, dilations_dims, - mkldnn_paddings[0], mkldnn_paddings[1]); - - mkldnn::primitive_attr conv_attr = - CreatePostOps(fuse_activation, fuse_alpha, fuse_beta, - fuse_residual_conn, output_shift_scale, sum_scale); - - conv_pd_.reset(new typename forward_t::primitive_desc( - conv_desc, conv_attr, engine)); - // Save conv_pd/src_memory/weights_memory for backward pass - dev_ctx_.SetBlob(key_conv_pd, conv_pd_); - } + mkldnn::memory::dims stride_dims = strides; + mkldnn::memory::dims dilations_dims = dilations; + auto mkldnn_paddings = ToMkldnnPadding(paddings); + + auto conv_desc = + bias ? typename forward_t::desc( + fwd_prop_kind, convolutional_algorithm::T, src, + weights, *bias, dst, stride_dims, dilations_dims, + mkldnn_paddings[0], mkldnn_paddings[1]) + : typename forward_t::desc( + fwd_prop_kind, convolutional_algorithm::T, src, + weights, dst, stride_dims, dilations_dims, + mkldnn_paddings[0], mkldnn_paddings[1]); + + mkldnn::primitive_attr conv_attr = + CreatePostOps(fuse_activation, fuse_alpha, fuse_beta, + fuse_residual_conn, output_shift_scale, sum_scale); + + conv_pd_.reset( + new typename forward_t::primitive_desc(conv_desc, conv_attr, engine)); + // Save conv_pd/src_memory/weights_memory for backward pass + dev_ctx_.SetBlob(key_conv_pd, conv_pd_); } return conv_pd_;