提交 cfcb96d2 编写于 作者: J Jacek Czaja 提交者: Tao Luo

[MKL-DNN] Fix int8 performance regression (#18758)

test=develop

- Optimization of thread-ID-to-string conversion (hash instead of stringstream)

test=develop
上级 e0a2d4df
...@@ -83,11 +83,9 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx, ...@@ -83,11 +83,9 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
std::to_string(multi_input[0]->format())); std::to_string(multi_input[0]->format()));
if (platform::get_cur_mkldnn_session_id() == if (platform::get_cur_mkldnn_session_id() ==
platform::kMKLDNNSessionID_Default) { platform::kMKLDNNSessionID_Default) {
auto tid = std::this_thread::get_id();
std::stringstream ss;
ss << tid;
platform::MKLDNNHandler::AppendKey(&key, "-t:"); platform::MKLDNNHandler::AppendKey(&key, "-t:");
platform::MKLDNNHandler::AppendKey(&key, ss.str()); platform::MKLDNNHandler::AppendKey(
&key, platform::MKLDNNHandler::ThreadIDasStr());
} }
return key; return key;
} }
......
...@@ -408,12 +408,21 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -408,12 +408,21 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd; std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
std::shared_ptr<platform::ConvMKLDNNHandler> handler; std::shared_ptr<platform::ConvMKLDNNHandler> handler;
auto prim_key = key + "@conv_p"; // This is workaround for hacky implementation
auto dst_key = key + "@dst_mem_p"; // of conv int8 mkl-dnn. Once conv fp32 and conv int8
auto src_key = key + "@src_mem_p"; // are merged/unified, this will disappear
auto user_src_key = key + "@user_src_mem_p"; std::string key_tid = "";
auto src_reorder_key = key + "@src_mem_preorder_p"; if (platform::get_cur_mkldnn_session_id() ==
auto residual_reorder_key = key + "@residual_data_mem_preorder_p"; platform::kMKLDNNSessionID_Default) {
key_tid = "-t:" + platform::MKLDNNHandler::ThreadIDasStr();
}
auto prim_key = key + key_tid + "@conv_p";
auto dst_key = key + key_tid + "@dst_mem_p";
auto src_key = key + key_tid + "@src_mem_p";
auto user_src_key = key + key_tid + "@user_src_mem_p";
auto src_reorder_key = key + key_tid + "@src_mem_preorder_p";
auto residual_reorder_key = key + key_tid + "@residual_data_mem_preorder_p";
conv_p = std::static_pointer_cast<mkldnn::convolution_forward>( conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(
dev_ctx.GetBlob(prim_key)); dev_ctx.GetBlob(prim_key));
......
...@@ -34,14 +34,11 @@ class MKLDNNHandler { ...@@ -34,14 +34,11 @@ class MKLDNNHandler {
MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine, MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
const std::string& base_key) const std::string& base_key)
: dev_ctx_(dev_ctx), engine_(engine), key_common_(base_key) { : dev_ctx_(dev_ctx), engine_(engine), key_common_(base_key) {
// TODO(jczaja): Make it faster
auto tid = std::this_thread::get_id();
std::stringstream ss;
ss << tid;
key_ = key_common_ + "-t:" + ss.str();
if (platform::get_cur_mkldnn_session_id() != if (platform::get_cur_mkldnn_session_id() !=
platform::kMKLDNNSessionID_Default) { platform::kMKLDNNSessionID_Default) {
key_ = key_common_; key_ = key_common_;
} else {
key_ = key_common_ + "-t:" + MKLDNNHandler::ThreadIDasStr();
} }
} }
...@@ -205,6 +202,11 @@ class MKLDNNHandler { ...@@ -205,6 +202,11 @@ class MKLDNNHandler {
return target_memory_p; return target_memory_p;
} }
// Renders the calling thread's identity as a decimal string.
// std::thread::id is hashed to a size_t first (much cheaper than
// streaming the id through a std::stringstream); the result is
// stable for the lifetime of the thread, so it is safe to embed
// in per-thread cache keys.
static std::string ThreadIDasStr(void) {
  const std::hash<std::thread::id> hasher;
  const auto tid_hash = hasher(std::this_thread::get_id());
  return std::to_string(tid_hash);
}
static std::string GetHash(mkldnn::memory::dims& operand_dims, // NOLINT static std::string GetHash(mkldnn::memory::dims& operand_dims, // NOLINT
const std::string& suffix) { const std::string& suffix) {
return dims2str(operand_dims) + suffix; return dims2str(operand_dims) + suffix;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册