Commit 7f03ae9a authored by Megvii Engine Team

fix(imperative): reduce tls usage

GitOrigin-RevId: a716b2ae9806e498b8f9d22a981235a8cb9d69e5
Parent 85ea882c
@@ -77,14 +77,14 @@ EncodedSubgraph OpDef::make_backward_graph(
         const SmallVector<bool>& output_has_grad) {
     using BackwardGraphCache =
             OpMethResultCache<EncodedSubgraph, SmallVector<bool>, SmallVector<bool>>;
-    thread_local BackwardGraphCache cache;
-    decltype(cache)::key_t cache_key{
+    thread_local auto cache = std::make_unique<BackwardGraphCache>();
+    BackwardGraphCache::key_t cache_key{
             const_cast<OpDef&>(def).shared_from_this(),
             inputs,
             {input_requires_grad, output_has_grad}};
-    auto iter = cache.find(cache_key);
-    if (iter == cache.end()) {
-        iter = cache.insert({cache_key, def.trait()->make_backward_graph(
+    auto iter = cache->find(cache_key);
+    if (iter == cache->end()) {
+        iter = cache->insert({cache_key, def.trait()->make_backward_graph(
                                                 def, inputs, input_requires_grad,
                                                 output_has_grad)})
                        .first;
@@ -100,12 +100,12 @@ EncodedSubgraph OpDef::make_forward_graph(
         const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
     using ForwardGraphCache =
             OpMethResultCache<EncodedSubgraph, SmallVector<bool>, SmallVector<bool>>;
-    thread_local ForwardGraphCache cache;
-    decltype(cache)::key_t cache_key{
+    thread_local auto cache = std::make_unique<ForwardGraphCache>();
+    ForwardGraphCache::key_t cache_key{
             const_cast<OpDef&>(def).shared_from_this(), inputs};
-    auto iter = cache.find(cache_key);
-    if (iter == cache.end()) {
-        iter = cache.insert({cache_key, def.trait()->make_forward_graph(def, inputs)})
+    auto iter = cache->find(cache_key);
+    if (iter == cache->end()) {
+        iter = cache->insert({cache_key, def.trait()->make_forward_graph(def, inputs)})
                        .first;
     }
     return iter->second;
...
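The two hunks above, in OpDef::make_backward_graph and OpDef::make_forward_graph, swap a thread_local cache object for a thread_local std::unique_ptr that owns the cache on the heap, so only a pointer remains in thread-local storage. Below is a minimal standalone sketch of that pattern; BigCache and lookup_or_compute are illustrative names, not code from this repository.

// Minimal sketch of the pattern applied in this commit (illustrative types,
// not MegEngine code): keep only a pointer in thread-local storage and put
// the cache itself on the heap, constructed on first use per thread.
#include <map>
#include <memory>
#include <string>

struct BigCache {
    std::map<std::string, int> entries;  // stand-in for OpMethResultCache
};

int lookup_or_compute(const std::string& key) {
    // Before: `thread_local BigCache cache;` placed the whole object in TLS.
    // After: only a unique_ptr lives in TLS; the cache is heap-allocated.
    thread_local auto cache = std::make_unique<BigCache>();
    auto iter = cache->entries.find(key);
    if (iter == cache->entries.end()) {
        // compute and memoize (here: just the key length)
        iter = cache->entries.insert({key, static_cast<int>(key.size())}).first;
    }
    return iter->second;
}

int main() {
    return lookup_or_compute("conv2d") == 6 ? 0 : 1;
}

The per-thread lifetime is unchanged (the unique_ptr destroys the cache at thread exit), but the TLS segment now only needs to hold one pointer per cache.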
@@ -34,8 +34,20 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     return inputs;
 }
 
+auto make_backward_graph(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
+        const SmallVector<bool>& input_requires_grad,
+        const SmallVector<bool>& output_has_grad) {
+    Subgraph graph;
+    graph.inputs = {1, 2, 3};
+    graph.outputs = {3};
+    graph.exprs = {};
+    return EncodedSubgraph::make(graph);
+}
+
 OP_TRAIT_REG(FastpathCopy, FastpathCopy)
         .apply_on_var_node(apply_on_var_node)
+        .make_backward_graph(make_backward_graph)
         .fallback();
 } // namespace fastpathcopy
 } // namespace
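The hunk above also registers a backward graph for FastpathCopy. The sketch below is a reading aid under an assumption, not MegEngine code: if the backward subgraph's inputs are bound, in order, to the forward input, forward output, and output gradient, then inputs = {1, 2, 3} with outputs = {3} and no expressions simply returns the incoming output gradient unchanged, i.e. an identity backward graph.

// Simplified stand-in types; not MegEngine's real Subgraph/EncodedSubgraph.
#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <vector>

struct MiniSubgraph {
    std::vector<size_t> inputs;   // variable indices bound to the caller's arguments
    std::vector<size_t> outputs;  // variable indices returned to the caller
    // exprs omitted: an empty body means outputs are taken directly from inputs
};

// Evaluate a subgraph with no expressions: each output is whatever value
// was bound to that variable index by the inputs.
std::vector<double> eval(const MiniSubgraph& g, const std::vector<double>& args) {
    std::unordered_map<size_t, double> env;
    for (size_t i = 0; i < g.inputs.size(); ++i) env[g.inputs[i]] = args[i];
    std::vector<double> out;
    for (size_t idx : g.outputs) out.push_back(env.at(idx));
    return out;
}

int main() {
    // inputs {1, 2, 3} ~ (forward input, forward output, output gradient);
    // outputs {3} ~ return the output gradient unchanged (identity backward).
    MiniSubgraph backward{{1, 2, 3}, {3}};
    auto grads = eval(backward, {/*x=*/5.0, /*y=*/5.0, /*dy=*/1.5});
    assert(grads.size() == 1 && grads[0] == 1.5);
    return 0;
}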
@@ -290,10 +302,10 @@ ComputingGraphHolder& get_computing_graph(
         std::shared_ptr<OpDef> compiled_op, SmallVector<LogicalTensorDesc> descs) {
     using ComputingGraphHolderCache =
             OpMethResultCache<std::queue<std::unique_ptr<ComputingGraphHolder>>>;
-    thread_local ComputingGraphHolderCache cache;
+    thread_local auto cache = std::make_unique<ComputingGraphHolderCache>();
     thread_local size_t nr_cg_holders = 0;
     ComputingGraphHolderCache::key_t cache_key = {compiled_op, descs};
-    auto& cg_holder_queue = cache[cache_key];
+    auto& cg_holder_queue = (*cache)[cache_key];
     std::unique_ptr<ComputingGraphHolder> holder;
     if (!cg_holder_queue.empty()) {
         // pick one
...