diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index f11edb9a41bdcbcb33efc600f1d7d8f70fb76f45..c163f0edf16238ef7467c3563c417b26e2bf0923 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -557,6 +557,7 @@ void Executor::EnableMKLDNN(const ProgramDesc& program) { } } } + platform::AttachPointerHashToMKLDNNKey(this, place_); #else LOG(WARNING) << "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option"; diff --git a/paddle/fluid/framework/naive_executor.cc b/paddle/fluid/framework/naive_executor.cc index be405a2cfb6b202e365aafbc46a9aea0c8e543e8..943997be2e12b7a2218008dc020e8212d53232ab 100644 --- a/paddle/fluid/framework/naive_executor.cc +++ b/paddle/fluid/framework/naive_executor.cc @@ -44,6 +44,9 @@ void NaiveExecutor::Prepare(Scope *scope, const ProgramDesc &program_desc, } void NaiveExecutor::Run() { +#ifdef PADDLE_WITH_MKLDNN + platform::AttachPointerHashToMKLDNNKey(this, place_); +#endif for (auto &op : ops_) { VLOG(4) << std::this_thread::get_id() << " run " << op->DebugStringEx(scope_) << " on scope " << scope_; diff --git a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc index bb475b4e543660e4de6b8460ee97e573e25cf8ef..114daaecb59369658191b382a0471d30448a7462 100644 --- a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc @@ -160,7 +160,7 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { std::string key = platform::CreateKey( paddle::framework::vectorize(multi_input[0]->dims()), multi_input.size(), ctx.OutputName("Out"), dt, - platform::ThreadIDasStr()); + platform::ThreadIDasStr(), dev_ctx.GetKeySuffix()); const std::string key_prim = key + "@concat_p"; const std::string key_concat_pd = key + "@concat_pd"; diff --git a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc index d560e80a332b568dc41bbe4e72b9e6d99ece298e..6f0987deeabf50de1cd91a5f7fc0a461b35fa1f6 100644 --- a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc @@ -361,7 +361,8 @@ class FCPrimitiveFactory { void CacheWeightsAndBias(const MKLDNNDeviceContext& dev_ctx, const ExecutionContext& ctx) { - const std::string key = platform::CreateKey(platform::ThreadIDasStr()); + const std::string key = + platform::CreateKey(platform::ThreadIDasStr(), dev_ctx.GetKeySuffix()); const std::string weights_key = key + ctx.InputName("W"); const std::string bias_key = key + ctx.InputName("Bias"); dev_ctx.SetBlob(weights_key, weights_); @@ -532,8 +533,9 @@ static void ExecuteFc(const ExecutionContext& ctx, const LoDTensor* input, bool fuse_relu, bool force_fp32_output) { auto& dev_ctx = ctx.template device_context(); const std::string prim_key = platform::CreateKey( - platform::ThreadIDasStr(), input->format(), input->dims()[0], - framework::vectorize(w->dims()), ctx.OutputName("Out")); + platform::ThreadIDasStr(), dev_ctx.GetKeySuffix(), input->format(), + input->dims()[0], framework::vectorize(w->dims()), + ctx.OutputName("Out")); constexpr bool is_int8 = std::is_same::value || std::is_same::value; bool is_bfloat16 = std::is_same::value; diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc index 21f94c07c1fea6ca0a9f57fa2bc198238b8ce94c..1f2216cbed2b256b15df21956da6741affd8b296 100644 --- a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc @@ -337,8 +337,8 @@ static std::shared_ptr> GetPrimitiveFactory( const auto& dev_ctx = ctx.template device_context(); const auto batch_size = ctx.Input("X")->dims()[0]; - const std::string key = - platform::CreateKey(platform::ThreadIDasStr(), batch_size, out_name); + const std::string key = platform::CreateKey( + platform::ThreadIDasStr(), dev_ctx.GetKeySuffix(), batch_size, out_name); auto factory = std::static_pointer_cast>(dev_ctx.GetBlob(key)); diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index e8b1d587121dc7ed31dc3362c5061ec51a8dafde..074106f3f205121c145fa98e2b766c1cb8354bc3 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -535,6 +535,10 @@ class MKLDNNDeviceContext : public CPUDeviceContext { // Remove all entries from the blob map void ResetBlobMap(); + // Set a suffix to be added to key + void SetKeySuffix(const std::string& suffix) { key_suffix_ = suffix; } + const std::string& GetKeySuffix(void) const { return key_suffix_; } + // Prevent next ResetBlobMap() void BlockNextCacheClearing(); @@ -556,6 +560,7 @@ class MKLDNNDeviceContext : public CPUDeviceContext { std::shared_ptr p_blobmap_; std::shared_ptr p_mutex_; bool block_next_cache_clearing_ = false; + std::string key_suffix_; // Key identifying current Executor }; #endif diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index 67b68183cc8470c9dadb4d67e6a39a2fd889c4e8..34f5759e4cd01607a63946174d2726ed00b8693c 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -433,6 +433,23 @@ inline void AppendKey(std::string* key, const std::vector& dims) { } } +inline unsigned int HashPointer(uintptr_t ptr) { + // Get four less meaningful digits in decimal numerals + return ptr % 1000; +} + +// If MKLDNN build and CPU place then register suffix in DeviceContext +inline void AttachPointerHashToMKLDNNKey(void* ptr, + const platform::Place& place) { + if (platform::is_cpu_place(place)) { + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + platform::MKLDNNDeviceContext* dev_ctx = + (platform::MKLDNNDeviceContext*)pool.Get(place); + dev_ctx->SetKeySuffix("E" + std::to_string(platform::HashPointer( + reinterpret_cast(ptr)))); + } +} + template inline std::string CreateKey(ArgTypes&&... args) { std::string key; diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 8649b90321c13ba9605b1de7e581d831bde62bdd..90266f6c2099b9667a78f9ca6d29c7ceec2a74bb 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -51,6 +51,7 @@ class MKLDNNHandlerT { } else { key_ = key_common_ + "-t:" + ThreadIDasStr(); } + key_ += dev_ctx.GetKeySuffix(); } std::shared_ptr AcquireForwardPrimitive() { @@ -316,6 +317,7 @@ class MKLDNNHandler { } else { key_ = key_common_ + "-t:" + ThreadIDasStr(); } + key_ += dev_ctx.GetKeySuffix(); } std::shared_ptr AcquireSrcMemory(