diff --git a/imperative/src/impl/op_def.cpp b/imperative/src/impl/op_def.cpp index ccc14ed5788f3075eb4d3096d041a41c41e5f9c6..6bf1cc4bc3d4668e3bc4db0bf1de9a55396b10a0 100644 --- a/imperative/src/impl/op_def.cpp +++ b/imperative/src/impl/op_def.cpp @@ -77,16 +77,16 @@ EncodedSubgraph OpDef::make_backward_graph( const SmallVector& output_has_grad) { using BackwardGraphCache = OpMethResultCache, SmallVector>; - thread_local BackwardGraphCache cache; - decltype(cache)::key_t cache_key{ + thread_local auto cache = std::make_unique(); + BackwardGraphCache::key_t cache_key{ const_cast(def).shared_from_this(), inputs, {input_requires_grad, output_has_grad}}; - auto iter = cache.find(cache_key); - if (iter == cache.end()) { - iter = cache.insert({cache_key, def.trait()->make_backward_graph( - def, inputs, input_requires_grad, - output_has_grad)}) + auto iter = cache->find(cache_key); + if (iter == cache->end()) { + iter = cache->insert({cache_key, def.trait()->make_backward_graph( + def, inputs, input_requires_grad, + output_has_grad)}) .first; } return iter->second; @@ -100,12 +100,12 @@ EncodedSubgraph OpDef::make_forward_graph( const OpDef& def, const SmallVector& inputs) { using ForwardGraphCache = OpMethResultCache, SmallVector>; - thread_local ForwardGraphCache cache; - decltype(cache)::key_t cache_key{ + thread_local auto cache = std::make_unique(); + ForwardGraphCache::key_t cache_key{ const_cast(def).shared_from_this(), inputs}; - auto iter = cache.find(cache_key); - if (iter == cache.end()) { - iter = cache.insert({cache_key, def.trait()->make_forward_graph(def, inputs)}) + auto iter = cache->find(cache_key); + if (iter == cache->end()) { + iter = cache->insert({cache_key, def.trait()->make_forward_graph(def, inputs)}) .first; } return iter->second; diff --git a/imperative/src/impl/ops/utility.cpp b/imperative/src/impl/ops/utility.cpp index 89aa5f5a05263cb6d6afa824064236e1159f92c8..c67ceae605ed0b260f9686e0f2ad40c741357781 100644 --- a/imperative/src/impl/ops/utility.cpp +++ b/imperative/src/impl/ops/utility.cpp @@ -34,8 +34,20 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { return inputs; } +auto make_backward_graph( + const OpDef& def, const SmallVector& inputs, + const SmallVector& input_requires_grad, + const SmallVector& output_has_grad) { + Subgraph graph; + graph.inputs = {1, 2, 3}; + graph.outputs = {3}; + graph.exprs = {}; + return EncodedSubgraph::make(graph); +} + OP_TRAIT_REG(FastpathCopy, FastpathCopy) .apply_on_var_node(apply_on_var_node) + .make_backward_graph(make_backward_graph) .fallback(); } // namespace fastpathcopy } // namespace @@ -290,10 +302,10 @@ ComputingGraphHolder& get_computing_graph( std::shared_ptr compiled_op, SmallVector descs) { using ComputingGraphHolderCache = OpMethResultCache>>; - thread_local ComputingGraphHolderCache cache; + thread_local auto cache = std::make_unique(); thread_local size_t nr_cg_holders = 0; ComputingGraphHolderCache::key_t cache_key = {compiled_op, descs}; - auto& cg_holder_queue = cache[cache_key]; + auto& cg_holder_queue = (*cache)[cache_key]; std::unique_ptr holder; if (!cg_holder_queue.empty()) { // pick one