From 7f03ae9aed7451f18c4e489f18097d08fc2bb462 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 14 Jan 2022 13:17:35 +0800 Subject: [PATCH] fix(imperative): reduce tls usage GitOrigin-RevId: a716b2ae9806e498b8f9d22a981235a8cb9d69e5 --- imperative/src/impl/op_def.cpp | 24 ++++++++++++------------ imperative/src/impl/ops/utility.cpp | 16 ++++++++++++++-- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/imperative/src/impl/op_def.cpp b/imperative/src/impl/op_def.cpp index ccc14ed57..6bf1cc4bc 100644 --- a/imperative/src/impl/op_def.cpp +++ b/imperative/src/impl/op_def.cpp @@ -77,16 +77,16 @@ EncodedSubgraph OpDef::make_backward_graph( const SmallVector& output_has_grad) { using BackwardGraphCache = OpMethResultCache, SmallVector>; - thread_local BackwardGraphCache cache; - decltype(cache)::key_t cache_key{ + thread_local auto cache = std::make_unique(); + BackwardGraphCache::key_t cache_key{ const_cast(def).shared_from_this(), inputs, {input_requires_grad, output_has_grad}}; - auto iter = cache.find(cache_key); - if (iter == cache.end()) { - iter = cache.insert({cache_key, def.trait()->make_backward_graph( - def, inputs, input_requires_grad, - output_has_grad)}) + auto iter = cache->find(cache_key); + if (iter == cache->end()) { + iter = cache->insert({cache_key, def.trait()->make_backward_graph( + def, inputs, input_requires_grad, + output_has_grad)}) .first; } return iter->second; @@ -100,12 +100,12 @@ EncodedSubgraph OpDef::make_forward_graph( const OpDef& def, const SmallVector& inputs) { using ForwardGraphCache = OpMethResultCache, SmallVector>; - thread_local ForwardGraphCache cache; - decltype(cache)::key_t cache_key{ + thread_local auto cache = std::make_unique(); + ForwardGraphCache::key_t cache_key{ const_cast(def).shared_from_this(), inputs}; - auto iter = cache.find(cache_key); - if (iter == cache.end()) { - iter = cache.insert({cache_key, def.trait()->make_forward_graph(def, inputs)}) + auto iter = cache->find(cache_key); + if (iter == cache->end()) { + iter = cache->insert({cache_key, def.trait()->make_forward_graph(def, inputs)}) .first; } return iter->second; diff --git a/imperative/src/impl/ops/utility.cpp b/imperative/src/impl/ops/utility.cpp index 89aa5f5a0..c67ceae60 100644 --- a/imperative/src/impl/ops/utility.cpp +++ b/imperative/src/impl/ops/utility.cpp @@ -34,8 +34,20 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { return inputs; } +auto make_backward_graph( + const OpDef& def, const SmallVector& inputs, + const SmallVector& input_requires_grad, + const SmallVector& output_has_grad) { + Subgraph graph; + graph.inputs = {1, 2, 3}; + graph.outputs = {3}; + graph.exprs = {}; + return EncodedSubgraph::make(graph); +} + OP_TRAIT_REG(FastpathCopy, FastpathCopy) .apply_on_var_node(apply_on_var_node) + .make_backward_graph(make_backward_graph) .fallback(); } // namespace fastpathcopy } // namespace @@ -290,10 +302,10 @@ ComputingGraphHolder& get_computing_graph( std::shared_ptr compiled_op, SmallVector descs) { using ComputingGraphHolderCache = OpMethResultCache>>; - thread_local ComputingGraphHolderCache cache; + thread_local auto cache = std::make_unique(); thread_local size_t nr_cg_holders = 0; ComputingGraphHolderCache::key_t cache_key = {compiled_op, descs}; - auto& cg_holder_queue = cache[cache_key]; + auto& cg_holder_queue = (*cache)[cache_key]; std::unique_ptr holder; if (!cg_holder_queue.empty()) { // pick one -- GitLab