提交 cfcb96d2 编写于 作者: J Jacek Czaja 提交者: Tao Luo

[MKL-DNN] Fix int8 performance regression (#18758)

test=develop

- optimization of TID to string

test=develop
上级 e0a2d4df
......@@ -83,11 +83,9 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
std::to_string(multi_input[0]->format()));
if (platform::get_cur_mkldnn_session_id() ==
platform::kMKLDNNSessionID_Default) {
auto tid = std::this_thread::get_id();
std::stringstream ss;
ss << tid;
platform::MKLDNNHandler::AppendKey(&key, "-t:");
platform::MKLDNNHandler::AppendKey(&key, ss.str());
platform::MKLDNNHandler::AppendKey(
&key, platform::MKLDNNHandler::ThreadIDasStr());
}
return key;
}
......
......@@ -408,12 +408,21 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
std::shared_ptr<platform::ConvMKLDNNHandler> handler;
auto prim_key = key + "@conv_p";
auto dst_key = key + "@dst_mem_p";
auto src_key = key + "@src_mem_p";
auto user_src_key = key + "@user_src_mem_p";
auto src_reorder_key = key + "@src_mem_preorder_p";
auto residual_reorder_key = key + "@residual_data_mem_preorder_p";
// This is workaround for hacky implementation
// of conv int8 mkl-dnn. Once conv fp32 and conv int8
// are merged/unified, this will disappear
std::string key_tid = "";
if (platform::get_cur_mkldnn_session_id() ==
platform::kMKLDNNSessionID_Default) {
key_tid = "-t:" + platform::MKLDNNHandler::ThreadIDasStr();
}
auto prim_key = key + key_tid + "@conv_p";
auto dst_key = key + key_tid + "@dst_mem_p";
auto src_key = key + key_tid + "@src_mem_p";
auto user_src_key = key + key_tid + "@user_src_mem_p";
auto src_reorder_key = key + key_tid + "@src_mem_preorder_p";
auto residual_reorder_key = key + key_tid + "@residual_data_mem_preorder_p";
conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(
dev_ctx.GetBlob(prim_key));
......
......@@ -34,14 +34,11 @@ class MKLDNNHandler {
MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
const std::string& base_key)
: dev_ctx_(dev_ctx), engine_(engine), key_common_(base_key) {
// TODO(jczaja): Make it faster
auto tid = std::this_thread::get_id();
std::stringstream ss;
ss << tid;
key_ = key_common_ + "-t:" + ss.str();
if (platform::get_cur_mkldnn_session_id() !=
platform::kMKLDNNSessionID_Default) {
key_ = key_common_;
} else {
key_ = key_common_ + "-t:" + MKLDNNHandler::ThreadIDasStr();
}
}
......@@ -205,6 +202,11 @@ class MKLDNNHandler {
return target_memory_p;
}
static std::string ThreadIDasStr(void) {
return std::to_string(
std::hash<std::thread::id>()(std::this_thread::get_id()));
}
static std::string GetHash(mkldnn::memory::dims& operand_dims, // NOLINT
const std::string& suffix) {
return dims2str(operand_dims) + suffix;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册