提交 7f03ae9a 编写于 作者: M Megvii Engine Team

fix(imperative): reduce tls usage

GitOrigin-RevId: a716b2ae9806e498b8f9d22a981235a8cb9d69e5
上级 85ea882c
......@@ -77,14 +77,14 @@ EncodedSubgraph OpDef::make_backward_graph(
const SmallVector<bool>& output_has_grad) {
using BackwardGraphCache =
OpMethResultCache<EncodedSubgraph, SmallVector<bool>, SmallVector<bool>>;
thread_local BackwardGraphCache cache;
decltype(cache)::key_t cache_key{
thread_local auto cache = std::make_unique<BackwardGraphCache>();
BackwardGraphCache::key_t cache_key{
const_cast<OpDef&>(def).shared_from_this(),
inputs,
{input_requires_grad, output_has_grad}};
auto iter = cache.find(cache_key);
if (iter == cache.end()) {
iter = cache.insert({cache_key, def.trait()->make_backward_graph(
auto iter = cache->find(cache_key);
if (iter == cache->end()) {
iter = cache->insert({cache_key, def.trait()->make_backward_graph(
def, inputs, input_requires_grad,
output_has_grad)})
.first;
......@@ -100,12 +100,12 @@ EncodedSubgraph OpDef::make_forward_graph(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
using ForwardGraphCache =
OpMethResultCache<EncodedSubgraph, SmallVector<bool>, SmallVector<bool>>;
thread_local ForwardGraphCache cache;
decltype(cache)::key_t cache_key{
thread_local auto cache = std::make_unique<ForwardGraphCache>();
ForwardGraphCache::key_t cache_key{
const_cast<OpDef&>(def).shared_from_this(), inputs};
auto iter = cache.find(cache_key);
if (iter == cache.end()) {
iter = cache.insert({cache_key, def.trait()->make_forward_graph(def, inputs)})
auto iter = cache->find(cache_key);
if (iter == cache->end()) {
iter = cache->insert({cache_key, def.trait()->make_forward_graph(def, inputs)})
.first;
}
return iter->second;
......
......@@ -34,8 +34,20 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
return inputs;
}
auto make_backward_graph(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
const SmallVector<bool>& input_requires_grad,
const SmallVector<bool>& output_has_grad) {
Subgraph graph;
graph.inputs = {1, 2, 3};
graph.outputs = {3};
graph.exprs = {};
return EncodedSubgraph::make(graph);
}
OP_TRAIT_REG(FastpathCopy, FastpathCopy)
.apply_on_var_node(apply_on_var_node)
.make_backward_graph(make_backward_graph)
.fallback();
} // namespace fastpathcopy
} // namespace
......@@ -290,10 +302,10 @@ ComputingGraphHolder& get_computing_graph(
std::shared_ptr<OpDef> compiled_op, SmallVector<LogicalTensorDesc> descs) {
using ComputingGraphHolderCache =
OpMethResultCache<std::queue<std::unique_ptr<ComputingGraphHolder>>>;
thread_local ComputingGraphHolderCache cache;
thread_local auto cache = std::make_unique<ComputingGraphHolderCache>();
thread_local size_t nr_cg_holders = 0;
ComputingGraphHolderCache::key_t cache_key = {compiled_op, descs};
auto& cg_holder_queue = cache[cache_key];
auto& cg_holder_queue = (*cache)[cache_key];
std::unique_ptr<ComputingGraphHolder> holder;
if (!cg_holder_queue.empty()) {
// pick one
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册