From 2d0bb2c3961c7bb06746051732b460829e2450dd Mon Sep 17 00:00:00 2001 From: JingZhuangzhuang <75348594+JZZ-NOTE@users.noreply.github.com> Date: Thu, 18 Aug 2022 11:11:28 +0800 Subject: [PATCH] fix infer tans scope (#45203) * fix infer tans scop * fix infer trans scope * fic infer trans scope * fic infer trans scope Co-authored-by: dingjiawei <327396238@qq.com> --- paddle/fluid/framework/transfer_scope_cache.cc | 9 +++++++++ paddle/fluid/framework/transfer_scope_cache.h | 3 +++ paddle/fluid/inference/api/analysis_predictor.cc | 10 ++++++++++ 3 files changed, 22 insertions(+) diff --git a/paddle/fluid/framework/transfer_scope_cache.cc b/paddle/fluid/framework/transfer_scope_cache.cc index e04c4583d33..c812a6dc95a 100644 --- a/paddle/fluid/framework/transfer_scope_cache.cc +++ b/paddle/fluid/framework/transfer_scope_cache.cc @@ -27,6 +27,13 @@ std::unordered_set& global_transfer_scope_cache() { return *x; } +std::unordered_map>& +global_transfer_scope_key() { + thread_local auto* x = + new std::unordered_map>; + return *x; +} + Scope* TryCreateTransferScope(OpKernelType type0, OpKernelType type1, const Scope* scope) { @@ -36,6 +43,8 @@ Scope* TryCreateTransferScope(OpKernelType type0, infer_cache_key = CombineHash(infer_cache_key, std::hash()(scope)); + global_transfer_scope_key()[scope].insert(infer_cache_key); + auto it = global_transfer_data_cache().find(infer_cache_key); if (it != global_transfer_data_cache().end()) { new_scope = global_transfer_data_cache()[infer_cache_key]; diff --git a/paddle/fluid/framework/transfer_scope_cache.h b/paddle/fluid/framework/transfer_scope_cache.h index 7e639de615c..da2e319d5ba 100644 --- a/paddle/fluid/framework/transfer_scope_cache.h +++ b/paddle/fluid/framework/transfer_scope_cache.h @@ -31,6 +31,9 @@ std::unordered_map& global_transfer_data_cache(); std::unordered_set& global_transfer_scope_cache(); +std::unordered_map>& +global_transfer_scope_key(); + // Combine two hash values to a single hash. static size_t CombineHash(size_t seed, size_t a) { return (seed ^ a) + 0x9e3779b9 + (seed << 6) + (seed >> 2); diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index a8962c61e47..49045089ce5 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -33,6 +33,7 @@ #include "paddle/fluid/framework/naive_executor.h" #include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/transfer_scope_cache.h" #include "paddle/fluid/framework/var_type_traits.h" #include "paddle/fluid/framework/version.h" #include "paddle/fluid/inference/analysis/helper.h" @@ -1928,6 +1929,15 @@ AnalysisPredictor::~AnalysisPredictor() { "./profile.log"); } if (sub_scope_) { + if (framework::global_transfer_scope_key().find(sub_scope_) != + framework::global_transfer_scope_key().end()) { + auto scope_key_set = framework::global_transfer_scope_key()[sub_scope_]; + for (auto iter = scope_key_set.begin(); iter != scope_key_set.end(); + iter++) { + framework::global_transfer_data_cache().erase(*iter); + } + framework::global_transfer_scope_key().erase(sub_scope_); + } scope_->DeleteScope(sub_scope_); } -- GitLab