未验证 提交 fe32879d 编写于 作者: T Tao Luo 提交者: GitHub

add mkldnn shapeblob cache clear strategy (#18513)

* add mkldnn shapeblob cache clear strategy

test=develop

* refine with comments

test=develop

* make cache clear strategy more safey

test=develop

* add lock for GetShapeBlobSize

test=develop
上级 e576f266
...@@ -172,5 +172,61 @@ TEST(Analyzer_MM_DNN, compare_determine) { ...@@ -172,5 +172,61 @@ TEST(Analyzer_MM_DNN, compare_determine) {
input_slots_all); input_slots_all);
} }
#ifdef PADDLE_WITH_MKLDNN
void TestMkldnnCacheClear(int mkldnn_input_shape_cache_capacity) {
AnalysisConfig config;
SetConfig(&config);
config.EnableMKLDNN();
// TODO(luotao): explicit following settings will be deprecated after enhance
// config.EnableMKLDNN() interface.
if (mkldnn_input_shape_cache_capacity > 0) {
platform::set_cur_mkldnn_session_id(
platform::kMKLDNNSessionID_CacheClearing);
platform::set_cur_input_shape_cache_capacity(
mkldnn_input_shape_cache_capacity);
}
std::vector<PaddleTensor> input, output;
auto predictor = CreatePaddlePredictor<AnalysisConfig>(config);
int sample_num = 10;
DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
auto &pool = platform::DeviceContextPool::Instance();
auto *dev_ctx = dynamic_cast<platform::MKLDNNDeviceContext *>(
pool.Get(platform::CPUPlace()));
for (int i = 0; i < sample_num; i++) {
PrepareInputs(&input, &data, FLAGS_batch_size);
if (mkldnn_input_shape_cache_capacity > 0) {
std::stringstream ss;
for (size_t i = 0; i < input.size(); i++) {
for (size_t j = 0; j < input[i].shape.size(); ++j) {
ss << input[i].shape[j] << "-";
}
}
// TODO(luotao): explicit following settings will be deprecated after
// enhance config.EnableMKLDNN() interface.
platform::set_cur_input_shape_str(ss.str());
}
predictor->Run(input, &output, 1);
}
if (mkldnn_input_shape_cache_capacity > 0) {
PADDLE_ENFORCE_EQ(dev_ctx->GetShapeBlobSize(),
mkldnn_input_shape_cache_capacity);
} else {
PADDLE_ENFORCE_EQ(dev_ctx->GetShapeBlobSize(), 1UL);
}
dev_ctx->ResetBlobMap();
}
TEST(Analyzer_MM_DNN, mkldnn_cache_clear) {
// 0 means do not use cache clear strategy.
TestMkldnnCacheClear(0);
// 4 means use cache clear strategy, and the
// mkldnn_input_shape_cache_capacity is 4.
TestMkldnnCacheClear(4);
}
#endif
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -407,6 +407,9 @@ thread_local size_t cur_mkldnn_session_id = kMKLDNNSessionID_Default; ...@@ -407,6 +407,9 @@ thread_local size_t cur_mkldnn_session_id = kMKLDNNSessionID_Default;
// - For fixed-shape, it's a null string in default. // - For fixed-shape, it's a null string in default.
// - For dynamic-shape, it's user specific. // - For dynamic-shape, it's user specific.
thread_local std::string cur_input_shape_str = ""; thread_local std::string cur_input_shape_str = "";
// the cache capacity of different input shapes for MKLDNN.
// Default 1 means fixed input shape, not dynamic shape.
thread_local int cur_input_shape_cache_capacity = 1;
} // namespace } // namespace
void set_cur_mkldnn_session_id(size_t sid) { cur_mkldnn_session_id = sid; } void set_cur_mkldnn_session_id(size_t sid) { cur_mkldnn_session_id = sid; }
...@@ -414,37 +417,58 @@ size_t get_cur_mkldnn_session_id(void) { return cur_mkldnn_session_id; } ...@@ -414,37 +417,58 @@ size_t get_cur_mkldnn_session_id(void) { return cur_mkldnn_session_id; }
void set_cur_input_shape_str(std::string input_shape_str) { void set_cur_input_shape_str(std::string input_shape_str) {
cur_input_shape_str = input_shape_str; cur_input_shape_str = input_shape_str;
} }
std::string get_cur_input_shape_str(void) { return cur_input_shape_str; } void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity) {
cur_input_shape_cache_capacity = input_shape_cache_capacity;
}
void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); } void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); }
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
std::lock_guard<std::mutex> lock(*p_mutex_);
BlobMap* pMap = p_blobmap_.get();
auto map_it = pMap->find(cur_mkldnn_session_id);
if (map_it == pMap->end()) {
LOG(FATAL) << "MKLDNNDeviceContext don't find cur_mkldnn_session_id : "
<< cur_mkldnn_session_id;
}
return map_it->second->size();
}
void MKLDNNDeviceContext::SetBlob(const std::string& name, void MKLDNNDeviceContext::SetBlob(const std::string& name,
std::shared_ptr<void> data) const { std::shared_ptr<void> data) const {
BlobMap* pMap = p_blobmap_.get(); BlobMap* pMap = p_blobmap_.get();
std::shared_ptr<ShapeBlob> sBlob = nullptr; std::shared_ptr<ShapeBlob> sBlob = nullptr;
std::shared_ptr<KeyBlob> pBlob = nullptr; std::shared_ptr<KeyBlob> pBlob = nullptr;
int tid = platform::get_cur_mkldnn_session_id(); int sid = platform::get_cur_mkldnn_session_id();
std::lock_guard<std::mutex> lock(*p_mutex_); std::lock_guard<std::mutex> lock(*p_mutex_);
// Find ShapeBlob for current thread // Find ShapeBlob for current mkldnn session id.
auto map_it = pMap->find(tid); auto map_it = pMap->find(sid);
if (map_it == pMap->end()) { if (map_it == pMap->end()) {
// 1st time to set blob in current thread // 1st time to set blob in current thread
sBlob = std::shared_ptr<ShapeBlob>(new ShapeBlob()); sBlob = std::shared_ptr<ShapeBlob>(new ShapeBlob());
(*pMap)[tid] = sBlob; (*pMap)[sid] = sBlob;
VLOG(2) << "SetBlob: tid=" << tid << ", add new tid\n"; VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
} else { } else {
sBlob = map_it->second; sBlob = map_it->second;
} }
// Find KeyBlob for current input shape // Find KeyBlob for current input shape
std::string cur_input_shape_str = platform::get_cur_input_shape_str();
auto key_it = sBlob->find(cur_input_shape_str); auto key_it = sBlob->find(cur_input_shape_str);
if (key_it == sBlob->end()) { if (key_it == sBlob->end()) {
// In cache clearing mode, cur_input_shape_cache_capacity defines
// max pblob capacity
if ((sid == kMKLDNNSessionID_CacheClearing) &&
(sBlob->size() >=
static_cast<size_t>(cur_input_shape_cache_capacity))) {
VLOG(2) << "sid=" << sid
<< ", remove all blobs of shape: " << sBlob->begin()->first;
sBlob->erase(sBlob->begin()->first);
}
pBlob = std::shared_ptr<KeyBlob>(new KeyBlob()); pBlob = std::shared_ptr<KeyBlob>(new KeyBlob());
(*sBlob)[cur_input_shape_str] = pBlob; (*sBlob)[cur_input_shape_str] = pBlob;
} else { } else {
...@@ -458,7 +482,7 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name, ...@@ -458,7 +482,7 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
} else { } else {
blob_it->second = data; // set data to existing blob blob_it->second = data; // set data to existing blob
} }
VLOG(2) << "SetBlob: tid=" << tid << ", add blob=" << name << "\n"; VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
// lock will be automatically released when out of scope // lock will be automatically released when out of scope
return; return;
} }
...@@ -469,23 +493,22 @@ std::shared_ptr<void> MKLDNNDeviceContext::GetBlob( ...@@ -469,23 +493,22 @@ std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
std::shared_ptr<ShapeBlob> sBlob = nullptr; std::shared_ptr<ShapeBlob> sBlob = nullptr;
std::shared_ptr<KeyBlob> pBlob = nullptr; std::shared_ptr<KeyBlob> pBlob = nullptr;
int tid = platform::get_cur_mkldnn_session_id(); int sid = platform::get_cur_mkldnn_session_id();
std::lock_guard<std::mutex> lock(*p_mutex_); std::lock_guard<std::mutex> lock(*p_mutex_);
// Find ShapeBlob for current thread firstly // Find ShapeBlob for current mkldnn session id firstly
auto map_it = pMap->find(tid); auto map_it = pMap->find(sid);
if (map_it == pMap->end()) { if (map_it == pMap->end()) {
VLOG(2) << "GetBlob: tid=" << tid << ", miss tid\n"; VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
return nullptr; return nullptr;
} }
std::string cur_input_shape_str = platform::get_cur_input_shape_str();
sBlob = map_it->second; sBlob = map_it->second;
// Find KeyBlob for current input shape secondly // Find KeyBlob for current input shape secondly
auto sBlob_it = sBlob->find(cur_input_shape_str); auto sBlob_it = sBlob->find(cur_input_shape_str);
if (sBlob_it == sBlob->end()) { if (sBlob_it == sBlob->end()) {
VLOG(2) << "GetBlob: tid=" << cur_input_shape_str VLOG(2) << "GetBlob: sid=" << cur_input_shape_str
<< ", miss input_shape_str\n"; << ", miss input_shape_str\n";
return nullptr; return nullptr;
} }
...@@ -495,11 +518,11 @@ std::shared_ptr<void> MKLDNNDeviceContext::GetBlob( ...@@ -495,11 +518,11 @@ std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
auto key_it = pBlob->find(name); auto key_it = pBlob->find(name);
if (key_it == pBlob->end()) { if (key_it == pBlob->end()) {
VLOG(2) << "GetBlob tid=" << tid << ", miss blob=" << name << "\n"; VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
return nullptr; return nullptr;
} }
VLOG(2) << "GetBlob tid=" << tid << ", get blob=" << name << "\n"; VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
// lock will be automatically released when out of scope // lock will be automatically released when out of scope
return key_it->second; return key_it->second;
} }
......
...@@ -396,7 +396,7 @@ constexpr size_t kMKLDNNSessionID_CacheClearing = -1; ...@@ -396,7 +396,7 @@ constexpr size_t kMKLDNNSessionID_CacheClearing = -1;
void set_cur_mkldnn_session_id(size_t); void set_cur_mkldnn_session_id(size_t);
size_t get_cur_mkldnn_session_id(void); size_t get_cur_mkldnn_session_id(void);
void set_cur_input_shape_str(std::string input_shape_str); void set_cur_input_shape_str(std::string input_shape_str);
std::string get_cur_input_shape_str(void); void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity);
class MKLDNNDeviceContext : public CPUDeviceContext { class MKLDNNDeviceContext : public CPUDeviceContext {
public: public:
...@@ -408,6 +408,9 @@ class MKLDNNDeviceContext : public CPUDeviceContext { ...@@ -408,6 +408,9 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
// Remove all entries from the blob map // Remove all entries from the blob map
void ResetBlobMap() const; void ResetBlobMap() const;
// Get the ShapeBlob size in cur_mkldnn_session_id.
size_t GetShapeBlobSize() const;
// Set data to blob (i.e. name/data pair). Create blob if not existing // Set data to blob (i.e. name/data pair). Create blob if not existing
void SetBlob(const std::string& name, std::shared_ptr<void> data) const; void SetBlob(const std::string& name, std::shared_ptr<void> data) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册