未验证 提交 f3909020 编写于 作者: A Adam 提交者: GitHub

Add mechanism for blocking oneDNN cache clearing (#26502)

* Add mechanism for blocking oneDNN cache clearing

* Review changes and Add thread guards
上级 7d3e46e1
......@@ -29,6 +29,11 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/framework/variable.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
DECLARE_bool(use_mkldnn);
namespace paddle {
namespace operators {
......@@ -262,6 +267,9 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
}
VLOG(2) << "The number of sub scopes after forward: "
<< out_scope_vec->front()->kids().size();
#ifdef PADDLE_WITH_MKLDNN
if (FLAGS_use_mkldnn) DontClearMKLDNNCache(ctx.GetPlace());
#endif
}
};
......
......@@ -464,9 +464,21 @@ MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
return cur_paddle_data_layout;
}
void MKLDNNDeviceContext::ResetBlobMap() const {
void MKLDNNDeviceContext::ResetBlobMap() {
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
if (!block_next_cache_clearing_) {
VLOG(3) << "Clearing DNNL cache.";
p_blobmap_->clear();
} else {
VLOG(3) << "Prevented Clearing DNNL cache.";
block_next_cache_clearing_ = false;
}
}
void MKLDNNDeviceContext::BlockNextCacheClearing() {
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
VLOG(3) << "Next DNNL cache clearing has been blocked.";
block_next_cache_clearing_ = true;
}
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
......
......@@ -520,7 +520,10 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
const mkldnn::engine& GetEngine() const { return engine_; }
// Remove all entries from the blob map
void ResetBlobMap() const;
void ResetBlobMap();
// Prevent next ResetBlobMap()
void BlockNextCacheClearing();
// Get the ShapeBlob size in cur_mkldnn_session_id.
size_t GetShapeBlobSize() const;
......@@ -539,6 +542,7 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
mkldnn::engine engine_;
std::shared_ptr<BlobMap> p_blobmap_;
std::shared_ptr<std::mutex> p_mutex_;
bool block_next_cache_clearing_ = false;
};
#endif
......
......@@ -129,6 +129,16 @@ inline void ClearMKLDNNCache(const platform::Place& place) {
}
}
inline void DontClearMKLDNNCache(const platform::Place& place) {
// Clear mkl-dnn cache,
if (platform::is_cpu_place(place)) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(place);
dev_ctx->BlockNextCacheClearing();
}
}
template <typename Type>
mkldnn::memory::data_type MKLDNNGetDataType() {
return mkldnn::memory::data_type::undef;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册