未验证 提交 79da263b 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #14032 from sfraczek/sfraczek/fix-test-multithreading-mkldnn

fix test resnet50 multi-threading on mkldnn
...@@ -160,7 +160,8 @@ static void PrintTime(int batch_size, int repeat, int num_threads, int tid, ...@@ -160,7 +160,8 @@ static void PrintTime(int batch_size, int repeat, int num_threads, int tid,
double latency, int epoch = 1) { double latency, int epoch = 1) {
LOG(INFO) << "====== batch_size: " << batch_size << ", repeat: " << repeat LOG(INFO) << "====== batch_size: " << batch_size << ", repeat: " << repeat
<< ", threads: " << num_threads << ", thread id: " << tid << ", threads: " << num_threads << ", thread id: " << tid
<< ", latency: " << latency << "ms ======"; << ", latency: " << latency << "ms, fps: " << 1 / (latency / 1000.f)
<< " ======";
if (epoch > 1) { if (epoch > 1) {
int samples = batch_size * epoch; int samples = batch_size * epoch;
LOG(INFO) << "====== sample number: " << samples LOG(INFO) << "====== sample number: " << samples
......
...@@ -139,6 +139,9 @@ void TestMultiThreadPrediction( ...@@ -139,6 +139,9 @@ void TestMultiThreadPrediction(
} }
for (int tid = 0; tid < num_threads; ++tid) { for (int tid = 0; tid < num_threads; ++tid) {
threads.emplace_back([&, tid]() { threads.emplace_back([&, tid]() {
#ifdef PADDLE_WITH_MKLDNN
platform::set_cur_thread_id(static_cast<int>(tid) + 1);
#endif
// Each thread should have local inputs and outputs. // Each thread should have local inputs and outputs.
// The inputs of each thread are all the same. // The inputs of each thread are all the same.
std::vector<std::vector<PaddleTensor>> inputs_tid = inputs; std::vector<std::vector<PaddleTensor>> inputs_tid = inputs;
......
...@@ -296,38 +296,73 @@ Place CUDAPinnedDeviceContext::GetPlace() const { return place_; } ...@@ -296,38 +296,73 @@ Place CUDAPinnedDeviceContext::GetPlace() const { return place_; }
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// Constructs the MKLDNN device context: a CPU context plus an MKLDNN CPU
// engine and the per-thread blob storage guarded by a mutex.
// Both p_blobmap_ and p_mutex_ are created directly in the member-init list
// (the original default-constructed them and then called reset(), doing two
// steps where one suffices); make_shared also allocates in a single step.
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place),
      engine_(mkldnn::engine::cpu, 0),
      p_blobmap_(std::make_shared<BlobMap>()),
      p_mutex_(std::make_shared<std::mutex>()) {}
namespace {
// Per-thread storage slot for the current thread's MKLDNN id.
// Defaults to 0 until a thread explicitly assigns itself an id.
int& thread_id_slot() {
  thread_local int id = 0;
  return id;
}
}  // namespace

// Record the MKLDNN thread id of the calling thread.
void set_cur_thread_id(int tid) { thread_id_slot() = tid; }

// Return the MKLDNN thread id previously set on the calling thread (0 if
// never set).
int get_cur_thread_id(void) { return thread_id_slot(); }
void MKLDNNDeviceContext::SetBlob(const std::string& name, void MKLDNNDeviceContext::SetBlob(const std::string& name,
std::shared_ptr<void> data) const { std::shared_ptr<void> data) const {
std::unordered_map<std::string, std::shared_ptr<void>>* p; BlobMap* pMap = p_blobmap_.get();
p = p_blobs_.get(); std::shared_ptr<KeyBlob> pBlob = nullptr;
int tid = platform::get_cur_thread_id();
auto it = p->find(name); std::lock_guard<std::mutex> lock(*p_mutex_.get());
if (it == p->end()) { // Find KeyBlob for current thread
(*p)[name] = data; // create new blob auto map_it = pMap->find(tid);
if (map_it == pMap->end()) {
// 1st time to set blob in current thread
pBlob = std::shared_ptr<KeyBlob>(new KeyBlob());
(*pMap)[tid] = pBlob;
} else { } else {
it->second = data; // set data to existing blob pBlob = map_it->second;
} }
// Find Key in found (or newly created) KeyBlob
auto key_it = pBlob->find(name);
if (key_it == pBlob->end()) {
(*pBlob)[name] = data; // create new blob
} else {
key_it->second = data; // set data to existing blob
}
// lock will be automatically released when out of scope
return; return;
} }
std::shared_ptr<void> MKLDNNDeviceContext::GetBlob( std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
const std::string& name) const { const std::string& name) const {
std::unordered_map<std::string, std::shared_ptr<void>>* p; BlobMap* pMap = p_blobmap_.get();
p = p_blobs_.get(); std::shared_ptr<KeyBlob> pBlob = nullptr;
auto it = p->find(name); int tid = platform::get_cur_thread_id();
if (it != p->end()) { std::lock_guard<std::mutex> lock(*p_mutex_.get());
return it->second;
} // Find KeyBlob for current thread firstly
auto map_it = pMap->find(tid);
if (map_it == pMap->end()) return nullptr;
pBlob = map_it->second;
// Find Blob via name
auto key_it = pBlob->find(name);
if (key_it == pBlob->end()) return nullptr;
return nullptr; // lock will be automatically released when out of scope
return key_it->second;
} }
#endif #endif
......
...@@ -176,6 +176,12 @@ struct DefaultDeviceContextType<platform::CUDAPinnedPlace> { ...@@ -176,6 +176,12 @@ struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
#endif #endif
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// Blobs cached by one thread: maps a blob's name to its type-erased payload.
using KeyBlob = std::unordered_map<std::string, std::shared_ptr<void>>;
// Per-thread blob storage: maps a thread id (see get_cur_thread_id) to the
// KeyBlob holding that thread's cached MKLDNN objects.
using BlobMap = std::unordered_map<int, std::shared_ptr<KeyBlob>>;

// Set/get the MKLDNN thread id of the calling thread (thread-local storage;
// 0 until explicitly assigned).
void set_cur_thread_id(int);
int get_cur_thread_id(void);
class MKLDNNDeviceContext : public CPUDeviceContext { class MKLDNNDeviceContext : public CPUDeviceContext {
public: public:
explicit MKLDNNDeviceContext(CPUPlace place); explicit MKLDNNDeviceContext(CPUPlace place);
...@@ -191,8 +197,8 @@ class MKLDNNDeviceContext : public CPUDeviceContext { ...@@ -191,8 +197,8 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
private: private:
mkldnn::engine engine_; mkldnn::engine engine_;
std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<void>>> std::shared_ptr<BlobMap> p_blobmap_;
p_blobs_; std::shared_ptr<std::mutex> p_mutex_;
}; };
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册