From 5c073a4db215f3c441ed76145398131c2b561708 Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Mon, 26 Nov 2018 15:53:37 +0800 Subject: [PATCH] fix transfer cache thread_local bug (#14581) --- .../fluid/framework/transfer_scope_cache.cc | 38 +++++++------------ 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/paddle/fluid/framework/transfer_scope_cache.cc b/paddle/fluid/framework/transfer_scope_cache.cc index e52a8317e21..f6219a14173 100644 --- a/paddle/fluid/framework/transfer_scope_cache.cc +++ b/paddle/fluid/framework/transfer_scope_cache.cc @@ -17,16 +17,28 @@ namespace paddle { namespace framework { +// Holds all the transfer scope across the process. std::unordered_map& global_transfer_data_cache() { - thread_local auto* x = new std::unordered_map; + typedef std::unordered_map map_t; + thread_local std::unique_ptr x(new map_t); return *x; } +// Holds all the transfer scope for this thread. std::unordered_set& global_transfer_scope_cache() { - thread_local auto* x = new std::unordered_set; + typedef std::unordered_set set_t; + thread_local std::unique_ptr x(new set_t); return *x; } +// Try to create a transfer scope. If one cached scope has match the +// requirement, just return that one. +// Inputs: +// @type0: the source kernel type. +// @type1: the target kernel type. +// @scope: the execution scope of this op. +// Returns: A scope used to hold the transfer data across the different kernel +// type. Scope* TryCreateTransferScope(OpKernelType type0, OpKernelType type1, const Scope* scope) { Scope* new_scope{nullptr}; @@ -46,27 +58,5 @@ Scope* TryCreateTransferScope(OpKernelType type0, OpKernelType type1, return new_scope; } -void RemoveKidsFromTransferScopeCache(Scope* scope) { - auto it = global_transfer_scope_cache().find(scope); - if (it != global_transfer_scope_cache().end()) { - global_transfer_scope_cache().erase(it); - } - for (auto* s : scope->kids()) { - auto it = global_transfer_scope_cache().find(s); - if (it != global_transfer_scope_cache().end()) { - global_transfer_scope_cache().erase(it); - } - } - - // remove global transfer data cache - auto& cache = global_transfer_data_cache(); - for (auto it = cache.begin(); it != cache.end();) { - if (it->second == scope) - it = cache.erase(it); - else - it++; - } -} - } // namespace framework } // namespace paddle -- GitLab