未验证 提交 2d0bb2c3 编写于 作者: J JingZhuangzhuang 提交者: GitHub

fix infer tans scope (#45203)

* fix infer tans scop

* fix infer trans scope

* fic infer trans scope

* fic infer trans scope
Co-authored-by: Ndingjiawei <327396238@qq.com>
上级 133f608f
......@@ -27,6 +27,13 @@ std::unordered_set<Scope*>& global_transfer_scope_cache() {
return *x;
}
std::unordered_map<const Scope*, std::unordered_set<size_t>>&
global_transfer_scope_key() {
thread_local auto* x =
new std::unordered_map<const Scope*, std::unordered_set<size_t>>;
return *x;
}
Scope* TryCreateTransferScope(OpKernelType type0,
OpKernelType type1,
const Scope* scope) {
......@@ -36,6 +43,8 @@ Scope* TryCreateTransferScope(OpKernelType type0,
infer_cache_key =
CombineHash(infer_cache_key, std::hash<const Scope*>()(scope));
global_transfer_scope_key()[scope].insert(infer_cache_key);
auto it = global_transfer_data_cache().find(infer_cache_key);
if (it != global_transfer_data_cache().end()) {
new_scope = global_transfer_data_cache()[infer_cache_key];
......
......@@ -31,6 +31,9 @@ std::unordered_map<size_t, Scope*>& global_transfer_data_cache();
std::unordered_set<Scope*>& global_transfer_scope_cache();
std::unordered_map<const Scope*, std::unordered_set<size_t>>&
global_transfer_scope_key();
// Combine two hash values to a single hash.
static size_t CombineHash(size_t seed, size_t a) {
return (seed ^ a) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
......
......@@ -33,6 +33,7 @@
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/framework/version.h"
#include "paddle/fluid/inference/analysis/helper.h"
......@@ -1928,6 +1929,15 @@ AnalysisPredictor::~AnalysisPredictor() {
"./profile.log");
}
if (sub_scope_) {
if (framework::global_transfer_scope_key().find(sub_scope_) !=
framework::global_transfer_scope_key().end()) {
auto scope_key_set = framework::global_transfer_scope_key()[sub_scope_];
for (auto iter = scope_key_set.begin(); iter != scope_key_set.end();
iter++) {
framework::global_transfer_data_cache().erase(*iter);
}
framework::global_transfer_scope_key().erase(sub_scope_);
}
scope_->DeleteScope(sub_scope_);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册