未验证 提交 3c49f08e 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] Second fix to #33021 (#33471)

* - Second fix

- fix

* - fix
上级 681778d8
......@@ -343,8 +343,6 @@ void AnalysisPredictor::MkldnnPreSet(
platform::MKLDNNDeviceContext::tls().set_cur_mkldnn_session_id(
platform::MKLDNNDeviceContextThreadLocals::
kMKLDNNSessionID_CacheClearing);
platform::MKLDNNDeviceContext::tls().set_cur_input_shape_cache_capacity(
config_.mkldnn_cache_capacity_);
// Set current_input_shape for caching dynamic shape.
std::stringstream ss;
for (size_t i = 0; i < inputs_shape.size(); ++i) {
......@@ -355,6 +353,9 @@ void AnalysisPredictor::MkldnnPreSet(
VLOG(2) << "Set input shape=" << ss.str();
platform::MKLDNNDeviceContext::tls().set_cur_input_shape_str(ss.str());
}
platform::MKLDNNDeviceContext::tls().set_cur_input_shape_cache_capacity(
config_.mkldnn_cache_capacity_);
#endif
}
......@@ -370,10 +371,9 @@ void AnalysisPredictor::MkldnnPostReset() {
CHECK_LE(shape_blob_size,
static_cast<size_t>(config_.mkldnn_cache_capacity_));
}
paddle::platform::MKLDNNDeviceContext::tls().set_cur_mkldnn_session_id(
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default);
platform::MKLDNNDeviceContext::tls().set_cur_input_shape_cache_capacity(0);
platform::MKLDNNDeviceContext::tls().set_cur_input_shape_str("");
// We cannot reset to the default cache settings
// as there maybe CopyToCPU method used and oneDNN
// primitives are used there so cache would grow
}
#endif
}
......
......@@ -120,6 +120,19 @@ void validate_cache_onednn(int cache_capacity = 1) {
file.close();
infer_file.close();
// Pick first output tensor from model
// as internally reorders may be called
// so it will impact cache size
auto output_names = predictor->GetOutputNames();
auto output_t = predictor->GetOutputTensor(output_names[0]);
std::vector<int> output_shape = output_t->shape();
size_t out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
std::multiplies<int>());
std::vector<float> out_data;
out_data.resize(out_num);
output_t->CopyToCpu(out_data.data());
// Release predictor (relevant cache should be emptied)
predictor.reset(nullptr);
cache_filling.push_back(GetNumCachedObjects());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册