diff --git a/paddle/fluid/operators/run_program_op.h b/paddle/fluid/operators/run_program_op.h index c0fbc336e46b64fc6ee43ef1a7372e413c5c3213..1c493fc6be093a2af8f58c8e78d1be43de34306f 100644 --- a/paddle/fluid/operators/run_program_op.h +++ b/paddle/fluid/operators/run_program_op.h @@ -29,6 +29,11 @@ limitations under the License. */ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/var_type_traits.h" #include "paddle/fluid/framework/variable.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif + +DECLARE_bool(use_mkldnn); namespace paddle { namespace operators { @@ -262,6 +267,9 @@ class RunProgramOpKernel : public framework::OpKernel { } VLOG(2) << "The number of sub scopes after forward: " << out_scope_vec->front()->kids().size(); +#ifdef PADDLE_WITH_MKLDNN + if (FLAGS_use_mkldnn) DontClearMKLDNNCache(ctx.GetPlace()); +#endif } }; diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 9a50281076942b0364441c77ef80fe65cab49c6c..29982c13c8ca88bc8b4a168f92e4116a283a97e8 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -464,9 +464,21 @@ MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) { return cur_paddle_data_layout; } -void MKLDNNDeviceContext::ResetBlobMap() const { - VLOG(3) << "Clearing DNNL cache."; - p_blobmap_->clear(); +void MKLDNNDeviceContext::ResetBlobMap() { + std::lock_guard lock(*p_mutex_); + if (!block_next_cache_clearing_) { + VLOG(3) << "Clearing DNNL cache."; + p_blobmap_->clear(); + } else { + VLOG(3) << "Prevented Clearing DNNL cache."; + block_next_cache_clearing_ = false; + } +} + +void MKLDNNDeviceContext::BlockNextCacheClearing() { + std::lock_guard lock(*p_mutex_); + VLOG(3) << "Next DNNL cache clearing has been blocked."; + block_next_cache_clearing_ = true; } size_t MKLDNNDeviceContext::GetShapeBlobSize() const { diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 3c476f4f08b1042056063e2e777cbf6eb3497a29..8bfdfc8a1c6033a79c197e1cd425197f77079bda 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -520,7 +520,10 @@ class MKLDNNDeviceContext : public CPUDeviceContext { const mkldnn::engine& GetEngine() const { return engine_; } // Remove all entries from the blob map - void ResetBlobMap() const; + void ResetBlobMap(); + + // Prevent next ResetBlobMap() + void BlockNextCacheClearing(); // Get the ShapeBlob size in cur_mkldnn_session_id. size_t GetShapeBlobSize() const; @@ -539,6 +542,7 @@ class MKLDNNDeviceContext : public CPUDeviceContext { mkldnn::engine engine_; std::shared_ptr p_blobmap_; std::shared_ptr p_mutex_; + bool block_next_cache_clearing_ = false; }; #endif diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index c74c47b7d84820f089d4e657f8bddccc5de8d727..3782eb684f21f8c09e9dac124082ae596fe5d1bc 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -129,6 +129,16 @@ inline void ClearMKLDNNCache(const platform::Place& place) { } } +inline void DontClearMKLDNNCache(const platform::Place& place) { + // Clear mkl-dnn cache, + if (platform::is_cpu_place(place)) { + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + platform::MKLDNNDeviceContext* dev_ctx = + (platform::MKLDNNDeviceContext*)pool.Get(place); + dev_ctx->BlockNextCacheClearing(); + } +} + template mkldnn::memory::data_type MKLDNNGetDataType() { return mkldnn::memory::data_type::undef;