From 84bb45c054e34f53c290b63d1defd76ac5ab2f94 Mon Sep 17 00:00:00 2001
From: Jacek Czaja
Date: Tue, 11 Jun 2019 08:11:59 +0200
Subject: [PATCH] [MKL-DNN] Thread-Safety for MKL-DNN reusing Part 1 (#17965)

* - removed is_reusing_

* - Added TID to keys for reusing apart from softmax PD

* - compilation fix

* - Yet another compilation fix

* - Batch Norm and Conv adapted

* - Fix to softmax MT

* - Fixes to MT code of MKL-DNN

* - Lint fixes

test=develop
---
 .../operators/mkldnn/batch_norm_mkldnn_op.cc |  36 ++---
 .../operators/mkldnn/softmax_mkldnn_op.cc    |  34 +++--
 paddle/fluid/platform/mkldnn_reuse.h         | 129 ++++++------
 3 files changed, 77 insertions(+), 122 deletions(-)

diff --git a/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
index 911c4d22ee5..40f7231c125 100644
--- a/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
@@ -61,20 +61,25 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
   std::shared_ptr<batch_norm_fwd::primitive_desc>
   AcquireBatchNormPrimitiveDescriptor(const batch_norm_fwd::desc &bn_fwd_desc,
                                       const mkldnn::engine &engine) {
-    const std::string key_batch_norm_fwd_pd = key_ + "@bn_fwd_pd";
-    auto batch_norm_pd =
-        std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
-            dev_ctx_.GetBlob(key_batch_norm_fwd_pd));
-
-    if (batch_norm_pd == nullptr) {
-      batch_norm_pd_.reset(
-          new batch_norm_fwd::primitive_desc(bn_fwd_desc, engine));
-      dev_ctx_.SetBlob(key_batch_norm_fwd_pd, batch_norm_pd_);
-    } else {
-      batch_norm_pd_ = batch_norm_pd;
-      is_reusing_ = true;
+    // BatchNorm PD has to be passed to Grad op that
+    // may be executed by diffrent thread, hence
+    // for that one we use key that does not contain TID
+    const std::string key_batch_norm_fwd_pd = key_common_ + "@bn_fwd_pd";
+    batch_norm_pd_ = std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
+        dev_ctx_.GetBlob(key_batch_norm_fwd_pd));
+
+    if (batch_norm_pd_ == nullptr) {
+      static std::mutex acquire_barrier;
+      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
+          acquire_barrier);
+      batch_norm_pd_ = std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
+          dev_ctx_.GetBlob(key_batch_norm_fwd_pd));
+      if (batch_norm_pd_ == nullptr) {
+        batch_norm_pd_.reset(
+            new batch_norm_fwd::primitive_desc(bn_fwd_desc, engine));
+        dev_ctx_.SetBlob(key_batch_norm_fwd_pd, batch_norm_pd_);
+      }
     }
-
     return batch_norm_pd_;
   }
 
@@ -87,9 +92,6 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
     auto batch_norm_p =
         std::static_pointer_cast<batch_norm_fwd>(dev_ctx_.GetBlob(prim_key));
 
-    PADDLE_ENFORCE((batch_norm_p != nullptr) || !is_reusing_,
-                   "Fail to find batch norm primitive in device context");
-
     if (batch_norm_p == nullptr) {
       if (is_test) {
         batch_norm_p = std::make_shared<batch_norm_fwd>(
@@ -104,8 +106,6 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
       }
 
       dev_ctx_.SetBlob(prim_key, batch_norm_p);
-    } else {
-      is_reusing_ = true;
     }
 
     return batch_norm_p;
diff --git a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
index 1b3f33d345f..a01dd512a37 100644
--- a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
@@ -54,18 +54,24 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
   std::shared_ptr<softmax_forward::primitive_desc>
   AcquireSoftmaxPrimitiveDescriptor(const softmax_forward::desc& softmax_desc,
                                     const mkldnn::engine& engine) {
-    const std::string key_softmax_pd = key_ + "@softmax_pd";
+    // Softmax PD has to be passed to Grad op that
+    // may be executed by diffrent thread, hence
+    // for that one we use key that does not contain TID
+    const std::string key_softmax_pd = key_common_ + "@softmax_pd";
 
-    auto softmax_pd = std::static_pointer_cast<softmax_forward::primitive_desc>(
+    softmax_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
         dev_ctx_.GetBlob(key_softmax_pd));
-
-    if (softmax_pd == nullptr) {
-      softmax_pd_.reset(
-          new softmax_forward::primitive_desc(softmax_desc, engine));
-      dev_ctx_.SetBlob(key_softmax_pd, softmax_pd_);
-    } else {
-      softmax_pd_ = softmax_pd;
-      is_reusing_ = true;
+    if (softmax_pd_ == nullptr) {
+      static std::mutex acquire_barrier;
+      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
+          acquire_barrier);
+      softmax_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
+          dev_ctx_.GetBlob(key_softmax_pd));
+      if (softmax_pd_ == nullptr) {
+        softmax_pd_.reset(
+            new softmax_forward::primitive_desc(softmax_desc, engine));
+        dev_ctx_.SetBlob(key_softmax_pd, softmax_pd_);
+      }
     }
 
     return softmax_pd_;
@@ -79,15 +85,11 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
     auto softmax_p = std::static_pointer_cast<mkldnn::softmax_forward>(
         dev_ctx_.GetBlob(prim_key));
-    PADDLE_ENFORCE((softmax_p != nullptr) || (is_reusing_ == false),
-                   "Fail to find softmax primitive in device context");
     if (softmax_p == nullptr) {
       softmax_p = std::make_shared<mkldnn::softmax_forward>(
           *softmax_pd_, *(static_cast<mkldnn::memory*>(src_memory_p.get())),
           *(static_cast<mkldnn::memory*>(dst_memory_p.get())));
       dev_ctx_.SetBlob(prim_key, softmax_p);
-    } else {
-      is_reusing_ = true;
     }
 
     return softmax_p;
@@ -100,15 +102,11 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
     auto prim_key = key_ + "@softmax_bwd_p";
     auto softmax_bwd_p = std::static_pointer_cast<mkldnn::softmax_backward>(
         dev_ctx_.GetBlob(prim_key));
-    PADDLE_ENFORCE((softmax_bwd_p != nullptr) || (is_reusing_ == false),
-                   "Fail to find softmax backward primitive in device context");
     if (softmax_bwd_p == nullptr) {
       softmax_bwd_p = std::make_shared<mkldnn::softmax_backward>(
           *softmax_bwd_pd_, *dst_memory_p, *diff_dst_memory_p,
           *diff_src_memory_p);
       dev_ctx_.SetBlob(prim_key, softmax_bwd_p);
-    } else {
-      is_reusing_ = true;
     }
 
     return softmax_bwd_p;
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h
index fa36a49fb88..f1fb6b156ae 100644
--- a/paddle/fluid/platform/mkldnn_reuse.h
+++ b/paddle/fluid/platform/mkldnn_reuse.h
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once
 
 #include <memory>
+#include <sstream>
 #include <string>
 #include <vector>
 #include "boost/optional.hpp"
@@ -31,10 +32,13 @@ class MKLDNNHandler {
  public:
   MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
                 const std::string& base_key)
-      : dev_ctx_(dev_ctx),
-        engine_(engine),
-        key_(base_key),
-        is_reusing_(false) {}
+      : dev_ctx_(dev_ctx), engine_(engine), key_common_(base_key) {
+    // TODO(jczaja): Make it faster
+    auto tid = std::this_thread::get_id();
+    std::stringstream ss;
+    ss << tid;
+    key_ = key_common_ + "-t:" + ss.str();
+  }
 
   std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
       const mkldnn::memory::desc& md, void* ptr) {
@@ -73,16 +77,11 @@ class MKLDNNHandler {
     auto local_key = key_ + suffix;
     auto mem_p =
         std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
-    PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
-                   "Fail to find mem primitive in device context");
     if (mem_p == nullptr) {
       mem_p = std::make_shared<mkldnn::memory>(mdp, ptr);
       dev_ctx_.SetBlob(local_key, mem_p);
     } else {
       mem_p->set_data_handle(ptr);
-      // Mark that reusing happenned. All primitives from operator instance
-      // should be reused or none of them. So we check consistency
-      is_reusing_ = true;
     }
     return mem_p;
   }
@@ -96,8 +95,6 @@ class MKLDNNHandler {
     auto local_key = key_ + suffix;
     auto mem_p =
         std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
-    PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
-                   "Fail to find mem primitive in device context");
     if (mem_p == nullptr) {
       // Call custom reorder/preprocessing func if available
       if (custom_func) {
@@ -111,9 +108,6 @@ class MKLDNNHandler {
       dev_ctx_.SetBlob(local_key, mem_p);
     } else {
       mem_p->set_data_handle(ptr);
-      // Mark that reusing happenned. All primitives from operator instance
-      // should be reused or none of them. So we check consistency
-      is_reusing_ = true;
     }
     return mem_p;
   }
@@ -155,8 +149,6 @@ class MKLDNNHandler {
     auto target_memory_p =
         std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
-    PADDLE_ENFORCE((target_memory_p != nullptr) || (is_reusing_ == false),
-                   "Fail to find mem primitive in device context");
     if (target_memory_p == nullptr) {
       target_memory_p = user_memory_p;
       std::shared_ptr<mkldnn::primitive> reorder_p;
@@ -187,7 +179,6 @@ class MKLDNNHandler {
       if (reorder_p != nullptr) {
         pipeline.push_back(*reorder_p);
       }
-      is_reusing_ = true;
     }
     return target_memory_p;
   }
@@ -268,7 +259,7 @@ class MKLDNNHandler {
   const MKLDNNDeviceContext& dev_ctx_;
   mkldnn::engine engine_;
   std::string key_;
-  bool is_reusing_;
+  std::string key_common_;
 
  public:
   static constexpr int MaxKeyLength = 256;
@@ -290,8 +281,6 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
     auto local_key = key_ + "@user_src_mem_p";
     auto mem_p =
         std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
-    PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
-                   " find mem primitive in device context");
     if (mem_p == nullptr) {
       // Make memory descriptor using input format, unless it
       // cannot be trusted (nchw) then make up memory fmt manually
@@ -307,9 +296,6 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
       dev_ctx_.SetBlob(local_key, mem_p);
     } else {
       mem_p->set_data_handle(ptr);
-      // Mark that reusing happenned. All primitives from operator instance
-      // should be reused or none of them. So we check consistency
-      is_reusing_ = true;
     }
     return mem_p;
   }
@@ -319,8 +305,6 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
     auto local_key = key_ + "@user_dst_mem_p";
     auto mem_p =
         std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
-    PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
-                   " find mem primitive in device context");
     if (mem_p == nullptr) {
       auto dst_mdp = mkldnn::memory::primitive_desc{
           Axis2MemoryDesc(dims_, axis_), engine_};
@@ -332,9 +316,6 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
     } else {
       auto dst_data = output->mutable_data<float>(place);
       mem_p->set_data_handle(dst_data);
-      // Mark that reusing happenned. All primitives from operator instance
-      // should be reused or none of them. So we check consistency
-      is_reusing_ = true;
     }
     return mem_p;
   }
@@ -345,14 +326,10 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
     auto prim_key = key_ + "@transpose_p";
     auto transpose_p =
         std::static_pointer_cast<mkldnn::reorder>(dev_ctx_.GetBlob(prim_key));
-    PADDLE_ENFORCE((transpose_p != nullptr) || (is_reusing_ == false),
-                   "Fail to find convolution primitive in device context");
     if (transpose_p == nullptr) {
       transpose_p = std::make_shared<mkldnn::reorder>(*(src_memory_p),
                                                       *(dst_memory_p));
       dev_ctx_.SetBlob(prim_key, transpose_p);
-    } else {
-      is_reusing_ = true;
     }
     return transpose_p;
   }
@@ -416,8 +393,6 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
     auto local_key = key_ + "@user_src_mem_p";
     auto mem_p =
         std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
-    PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
-                   " find mem primitive in device context");
     if (mem_p == nullptr) {
       auto src_md = platform::MKLDNNMemDesc(dims_, dtype_, fmt);
       mem_p = std::make_shared<mkldnn::memory>(
@@ -425,7 +400,6 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
       dev_ctx_.SetBlob(local_key, mem_p);
     } else {
       mem_p->set_data_handle(ptr);
-      is_reusing_ = true;
     }
     return mem_p;
   }
@@ -436,8 +410,6 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
     auto local_key = key_ + "@user_dst_mem_p";
     auto mem_p =
         std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
-    PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
-                   " find mem primitive in device context");
     if (mem_p == nullptr) {
       auto dst_md = platform::MKLDNNMemDesc(dims_, dtype_, fmt);
       auto dst_mdp = mkldnn::memory::primitive_desc{dst_md, engine_};
@@ -449,7 +421,6 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
     } else {
       auto dst_data = output->mutable_data(place, vtype_);
       mem_p->set_data_handle(dst_data);
-      is_reusing_ = true;
     }
     return mem_p;
   }
@@ -460,14 +431,10 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
     auto prim_key = key_ + "@reorder_p";
     auto reorder_p =
         std::static_pointer_cast<mkldnn::reorder>(dev_ctx_.GetBlob(prim_key));
-    PADDLE_ENFORCE((reorder_p != nullptr) || (is_reusing_ == false),
-                   "Fail to find convolution primitive in device context");
     if (reorder_p == nullptr) {
       reorder_p = std::make_shared<mkldnn::reorder>(*(src_memory_p),
                                                     *(dst_memory_p));
       dev_ctx_.SetBlob(prim_key, reorder_p);
-    } else {
-      is_reusing_ = true;
     }
     return reorder_p;
   }
@@ -695,35 +662,43 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
       const bool fuse_relu, const bool fuse_residual_conn,
       const bool fuse_brelu, const float fuse_brelu_threshold,
       mkldnn::prop_kind fwd_prop_kind) {
-    const std::string key_conv_pd = key_ + "@conv_pd";
+    // Conv PD has to be passed to Grad op that
+    // may be exxecuted by diffrent thread, hence
+    // for that one we use key that does not contain TID
+    const std::string key_conv_pd = key_common_ + "@conv_pd";
 
-    auto conv_pd = std::static_pointer_cast<typename forward_t::primitive_desc>(
+    conv_pd_ = std::static_pointer_cast<typename forward_t::primitive_desc>(
         dev_ctx_.GetBlob(key_conv_pd));
-    if (conv_pd == nullptr) {
-      mkldnn::memory::dims stride_dims = strides;
-      mkldnn::memory::dims padding_dims = paddings;
-
-      auto conv_desc =
-          bias ? typename forward_t::desc(
-                     fwd_prop_kind, convolutional_algorithm<forward_t>::T, src,
-                     weights, *bias, dst, stride_dims, padding_dims,
-                     padding_dims, mkldnn::padding_kind::zero)
-               : typename forward_t::desc(
-                     fwd_prop_kind, convolutional_algorithm<forward_t>::T, src,
-                     weights, dst, stride_dims, padding_dims, padding_dims,
-                     mkldnn::padding_kind::zero);
-
-      mkldnn::primitive_attr conv_attr = CreatePostOps(
-          fuse_relu, fuse_residual_conn, fuse_brelu, fuse_brelu_threshold);
-
-      conv_pd_.reset(
-          new typename forward_t::primitive_desc(conv_desc, conv_attr, engine));
-      // Save conv_pd/src_memory/weights_memory for backward pass
-      dev_ctx_.SetBlob(key_conv_pd, conv_pd_);
-    } else {
-      conv_pd_ = conv_pd;
-      is_reusing_ = true;
+    if (conv_pd_ == nullptr) {
+      static std::mutex acquire_barrier;
+      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
+          acquire_barrier);
+
+      conv_pd_ = std::static_pointer_cast<typename forward_t::primitive_desc>(
+          dev_ctx_.GetBlob(key_conv_pd));
+      if (conv_pd_ == nullptr) {
+        mkldnn::memory::dims stride_dims = strides;
+        mkldnn::memory::dims padding_dims = paddings;
+
+        auto conv_desc =
+            bias ? typename forward_t::desc(
+                       fwd_prop_kind, convolutional_algorithm<forward_t>::T,
+                       src, weights, *bias, dst, stride_dims, padding_dims,
+                       padding_dims, mkldnn::padding_kind::zero)
+                 : typename forward_t::desc(
+                       fwd_prop_kind, convolutional_algorithm<forward_t>::T,
+                       src, weights, dst, stride_dims, padding_dims,
+                       padding_dims, mkldnn::padding_kind::zero);
+
+        mkldnn::primitive_attr conv_attr = CreatePostOps(
+            fuse_relu, fuse_residual_conn, fuse_brelu, fuse_brelu_threshold);
+
+        conv_pd_.reset(new typename forward_t::primitive_desc(
+            conv_desc, conv_attr, engine));
+        // Save conv_pd/src_memory/weights_memory for backward pass
+        dev_ctx_.SetBlob(key_conv_pd, conv_pd_);
+      }
     }
 
     return conv_pd_;
@@ -736,15 +711,11 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
     auto prim_key = key_ + "@conv_p";
     auto conv_p =
         std::static_pointer_cast<forward_t>(dev_ctx_.GetBlob(prim_key));
-    PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false),
-                   "Fail to find convolution primitive in device context");
     if (conv_p == nullptr) {
       conv_p = std::make_shared<forward_t>(*conv_pd_, *src_memory_p,
                                            *weights_memory_p, *dst_memory_p);
       dev_ctx_.SetBlob(prim_key, conv_p);
-    } else {
-      is_reusing_ = true;
     }
     return conv_p;
   }
@@ -757,16 +728,12 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
     auto prim_key = key_ + "@conv_p";
     auto conv_p =
         std::static_pointer_cast<forward_t>(dev_ctx_.GetBlob(prim_key));
-    PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false),
-                   "Fail to find convolution primitive in device context");
     if (conv_p == nullptr) {
       conv_p = std::make_shared<forward_t>(*conv_pd_, *src_memory_p,
                                            *weights_memory_p, *bias_memory_p,
                                            *dst_memory_p);
       dev_ctx_.SetBlob(prim_key, conv_p);
-    } else {
-      is_reusing_ = true;
     }
     return conv_p;
   }
@@ -778,17 +745,12 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
     auto prim_key = key_ + "@conv_bwd_weights_p";
     auto conv_bwd_weights_p = std::static_pointer_cast<backward_weights_t>(
         dev_ctx_.GetBlob(prim_key));
-    PADDLE_ENFORCE(
-        (conv_bwd_weights_p != nullptr) || (is_reusing_ == false),
-        "Fail to find convolution bwd weights primitive in device context");
     if (conv_bwd_weights_p == nullptr) {
       // create backward conv primitive for weights
       conv_bwd_weights_p = std::make_shared<backward_weights_t>(
           *conv_bwd_weights_pd_, *src_memory_p, *diff_dst_memory_p,
          *diff_weights_memory_p);
       dev_ctx_.SetBlob(prim_key, conv_bwd_weights_p);
-    } else {
-      is_reusing_ = true;
     }
     return conv_bwd_weights_p;
   }
@@ -800,16 +762,11 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
     auto prim_key = key_ + "@conv_bwd_data_p";
     auto conv_bwd_data_p = std::static_pointer_cast<backward_data_t>(
         dev_ctx_.GetBlob(prim_key));
-    PADDLE_ENFORCE(
-        (conv_bwd_data_p != nullptr) || (is_reusing_ == false),
-        "Fail to find convolution bwd data primitive in device context");
     if (conv_bwd_data_p == nullptr) {
       conv_bwd_data_p = std::make_shared<backward_data_t>(
           *conv_bwd_data_pd_, *diff_dst_memory_p, *weights_memory_p,
           *diff_src_memory_p);
       dev_ctx_.SetBlob(prim_key, conv_bwd_data_p);
-    } else {
-      is_reusing_ = true;
     }
     return conv_bwd_data_p;
   }
-- 
GitLab
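
Editor's note: the reuse pattern this patch applies in AcquireSoftmaxPrimitiveDescriptor, AcquireBatchNormPrimitiveDescriptor and AcquireConvolutionPrimitiveDescriptor is (a) a thread-id suffix on the cache key for blobs that stay private to one thread, and (b) double-checked locking for the primitive descriptors that the forward and grad ops share under a TID-free key. The sketch below illustrates that pattern in isolation; BlobCache, ThreadScopedKey and AcquireShared are hypothetical stand-ins for MKLDNNDeviceContext::GetBlob/SetBlob and the handler code, not APIs from this patch.

#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>

// Hypothetical stand-in for the device context's blob map (the real
// MKLDNNDeviceContext guards its map internally as well).
class BlobCache {
 public:
  std::shared_ptr<void> GetBlob(const std::string& key) {
    std::lock_guard<std::mutex> guard(map_mutex_);
    auto it = blobs_.find(key);
    if (it == blobs_.end()) return nullptr;
    return it->second;
  }
  void SetBlob(const std::string& key, std::shared_ptr<void> blob) {
    std::lock_guard<std::mutex> guard(map_mutex_);
    blobs_[key] = std::move(blob);
  }

 private:
  std::mutex map_mutex_;
  std::map<std::string, std::shared_ptr<void>> blobs_;
};

// Per-thread key: same idea as appending "-t:" + thread id in MKLDNNHandler,
// so primitives created by one thread are never handed to another.
std::string ThreadScopedKey(const std::string& base_key) {
  std::stringstream ss;
  ss << std::this_thread::get_id();
  return base_key + "-t:" + ss.str();
}

// Double-checked acquire for a shared blob: a lock-free lookup first, then a
// mutex-guarded re-check before creating, so only one thread builds the
// object and every thread ends up with the same instance.
template <typename T, typename Creator>
std::shared_ptr<T> AcquireShared(BlobCache* cache, const std::string& key,
                                 Creator create) {
  auto blob = std::static_pointer_cast<T>(cache->GetBlob(key));
  if (blob == nullptr) {
    static std::mutex acquire_barrier;
    std::lock_guard<std::mutex> guard(acquire_barrier);
    blob = std::static_pointer_cast<T>(cache->GetBlob(key));
    if (blob == nullptr) {
      blob = create();
      cache->SetBlob(key, blob);
    }
  }
  return blob;
}

int main() {
  BlobCache cache;
  // Shared blob, analogous to a primitive descriptor reused by the grad op.
  auto pd = AcquireShared<int>(&cache, "conv2d@conv_pd",
                               [] { return std::make_shared<int>(42); });
  // Per-thread blob, analogous to a primitive stored under a TID-suffixed key.
  cache.SetBlob(ThreadScopedKey("conv2d@conv_p"), pd);
  std::cout << *pd << "\n";
  return 0;
}

In this sketch, as in the patch, the first lookup stays lock-free on the common already-created path; the mutex only serializes the rare creation path, which is why the blob is looked up a second time after the lock is taken.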