diff --git a/paddle/fluid/framework/transfer_scope_cache.cc b/paddle/fluid/framework/transfer_scope_cache.cc index f6219a14173094d15e9c60a2e26f98da1b04ec2e..e52a8317e2113a9489f8c05bcf47bc96bea33c64 100644 --- a/paddle/fluid/framework/transfer_scope_cache.cc +++ b/paddle/fluid/framework/transfer_scope_cache.cc @@ -17,28 +17,16 @@ namespace paddle { namespace framework { -// Holds all the transfer scope across the process. std::unordered_map& global_transfer_data_cache() { - typedef std::unordered_map map_t; - thread_local std::unique_ptr x(new map_t); + thread_local auto* x = new std::unordered_map; return *x; } -// Holds all the transfer scope for this thread. std::unordered_set& global_transfer_scope_cache() { - typedef std::unordered_set set_t; - thread_local std::unique_ptr x(new set_t); + thread_local auto* x = new std::unordered_set; return *x; } -// Try to create a transfer scope. If one cached scope has match the -// requirement, just return that one. -// Inputs: -// @type0: the source kernel type. -// @type1: the target kernel type. -// @scope: the execution scope of this op. -// Returns: A scope used to hold the transfer data across the different kernel -// type. Scope* TryCreateTransferScope(OpKernelType type0, OpKernelType type1, const Scope* scope) { Scope* new_scope{nullptr}; @@ -58,5 +46,27 @@ Scope* TryCreateTransferScope(OpKernelType type0, OpKernelType type1, return new_scope; } +void RemoveKidsFromTransferScopeCache(Scope* scope) { + auto it = global_transfer_scope_cache().find(scope); + if (it != global_transfer_scope_cache().end()) { + global_transfer_scope_cache().erase(it); + } + for (auto* s : scope->kids()) { + auto it = global_transfer_scope_cache().find(s); + if (it != global_transfer_scope_cache().end()) { + global_transfer_scope_cache().erase(it); + } + } + + // remove global transfer data cache + auto& cache = global_transfer_data_cache(); + for (auto it = cache.begin(); it != cache.end();) { + if (it->second == scope) + it = cache.erase(it); + else + it++; + } +} + } // namespace framework } // namespace paddle