未验证 提交 db2b6b65 编写于 作者: P pawelpiotrowicz 提交者: GitHub

Hide globals & redesign restore PR (#24279)

test=develop
上级 4a105f80
...@@ -124,9 +124,10 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, ...@@ -124,9 +124,10 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
"TransDataLayoutFromMKLDNN only supports transform from MKLDNN to " "TransDataLayoutFromMKLDNN only supports transform from MKLDNN to "
"non-MKLDNN"); "non-MKLDNN");
innerTransDataLayoutFromMKLDNN(in_layout, innerTransDataLayoutFromMKLDNN(
paddle::platform::get_cur_paddle_data_layout(), in_layout,
in, out, place); paddle::platform::MKLDNNDeviceContext::tls().get_cur_paddle_data_layout(),
in, out, place);
} }
void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout, void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
......
...@@ -59,7 +59,8 @@ void TransformData(const OpKernelType &expected_kernel_type, ...@@ -59,7 +59,8 @@ void TransformData(const OpKernelType &expected_kernel_type,
// For NHWC data we need reshape of tensors as MKL-DNN // For NHWC data we need reshape of tensors as MKL-DNN
// is expecting NHWC dims description order // is expecting NHWC dims description order
platform::MatchShapeToLayout(&out, lin, lout); platform::MatchShapeToLayout(&out, lin, lout);
paddle::platform::set_cur_paddle_data_layout(lin); paddle::platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
lin);
out.set_layout(DataLayout::kMKLDNN); out.set_layout(DataLayout::kMKLDNN);
out.set_format(out_format); out.set_format(out_format);
} else { } else {
......
...@@ -89,7 +89,8 @@ Executor::~Executor() { ...@@ -89,7 +89,8 @@ Executor::~Executor() {
platform::MKLDNNDeviceContext* dev_ctx = platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(place_); (platform::MKLDNNDeviceContext*)pool.Get(place_);
dev_ctx->ResetBlobMap(); dev_ctx->ResetBlobMap();
platform::set_cur_paddle_data_layout(paddle::framework::DataLayout::kNCHW); platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
paddle::framework::DataLayout::kNCHW);
} }
#endif #endif
} }
......
...@@ -1155,8 +1155,8 @@ Scope* OperatorWithKernel::PrepareData( ...@@ -1155,8 +1155,8 @@ Scope* OperatorWithKernel::PrepareData(
if ((tensor_in->layout() == DataLayout::kMKLDNN) && if ((tensor_in->layout() == DataLayout::kMKLDNN) &&
(var->IsType<LoDTensor>() == true) && (var->IsType<LoDTensor>() == true) &&
(expected_kernel_key.data_layout_ != DataLayout::kMKLDNN) && (expected_kernel_key.data_layout_ != DataLayout::kMKLDNN) &&
(paddle::platform::get_cur_paddle_data_layout() == (paddle::platform::MKLDNNDeviceContext::tls()
DataLayout::kNHWC)) { .get_cur_paddle_data_layout() == DataLayout::kNHWC)) {
// Mixed execution : MKL-DNN and GPU is not supported! // Mixed execution : MKL-DNN and GPU is not supported!
if (!new_scope) { if (!new_scope) {
new_scope = &scope.NewScope(); new_scope = &scope.NewScope();
......
...@@ -244,13 +244,14 @@ bool AnalysisPredictor::PrepareExecutor() { ...@@ -244,13 +244,14 @@ bool AnalysisPredictor::PrepareExecutor() {
void AnalysisPredictor::MkldnnPreSet(const std::vector<PaddleTensor> &inputs) { void AnalysisPredictor::MkldnnPreSet(const std::vector<PaddleTensor> &inputs) {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
VLOG(2) << "AnalysisPredictor::Run get_cur_mkldnn_session_id=" VLOG(2) << "AnalysisPredictor::Run get_cur_mkldnn_session_id="
<< platform::get_cur_mkldnn_session_id(); << platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id();
// In cache clearing mode. // In cache clearing mode.
if (config_.mkldnn_cache_capacity_ > 0) { if (config_.mkldnn_cache_capacity_ > 0) {
VLOG(2) << "In mkldnn cache clear mode."; VLOG(2) << "In mkldnn cache clear mode.";
platform::set_cur_mkldnn_session_id( platform::MKLDNNDeviceContext::tls().set_cur_mkldnn_session_id(
platform::kMKLDNNSessionID_CacheClearing); platform::MKLDNNDeviceContextThreadLocals::
platform::set_cur_input_shape_cache_capacity( kMKLDNNSessionID_CacheClearing);
platform::MKLDNNDeviceContext::tls().set_cur_input_shape_cache_capacity(
config_.mkldnn_cache_capacity_); config_.mkldnn_cache_capacity_);
// Set current_input_shape for caching dynamic shape. // Set current_input_shape for caching dynamic shape.
std::stringstream ss; std::stringstream ss;
...@@ -260,7 +261,7 @@ void AnalysisPredictor::MkldnnPreSet(const std::vector<PaddleTensor> &inputs) { ...@@ -260,7 +261,7 @@ void AnalysisPredictor::MkldnnPreSet(const std::vector<PaddleTensor> &inputs) {
} }
} }
VLOG(2) << "Set input shape=" << ss.str(); VLOG(2) << "Set input shape=" << ss.str();
platform::set_cur_input_shape_str(ss.str()); platform::MKLDNNDeviceContext::tls().set_cur_input_shape_str(ss.str());
} }
#endif #endif
} }
...@@ -277,10 +278,10 @@ void AnalysisPredictor::MkldnnPostReset() { ...@@ -277,10 +278,10 @@ void AnalysisPredictor::MkldnnPostReset() {
CHECK_LE(shape_blob_size, CHECK_LE(shape_blob_size,
static_cast<size_t>(config_.mkldnn_cache_capacity_)); static_cast<size_t>(config_.mkldnn_cache_capacity_));
} }
paddle::platform::set_cur_mkldnn_session_id( paddle::platform::MKLDNNDeviceContext::tls().set_cur_mkldnn_session_id(
platform::kMKLDNNSessionID_Default); platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default);
platform::set_cur_input_shape_cache_capacity(0); platform::MKLDNNDeviceContext::tls().set_cur_input_shape_cache_capacity(0);
platform::set_cur_input_shape_str(""); platform::MKLDNNDeviceContext::tls().set_cur_input_shape_str("");
} }
#endif #endif
} }
......
...@@ -34,10 +34,10 @@ static void DataCopy(const framework::LoDTensor &src_item, ...@@ -34,10 +34,10 @@ static void DataCopy(const framework::LoDTensor &src_item,
// Convert to desired Paddle layout, apart from grads of filter // Convert to desired Paddle layout, apart from grads of filter
// as params are not a subject to paddle's data_format // as params are not a subject to paddle's data_format
framework::innerTransDataLayoutFromMKLDNN( framework::innerTransDataLayoutFromMKLDNN(
src_item.layout(), src_item.layout(), fetch_var_name == framework::GradVarName("Filter")
fetch_var_name == framework::GradVarName("Filter") ? framework::DataLayout::kNCHW
? framework::DataLayout::kNCHW : paddle::platform::MKLDNNDeviceContext::tls()
: paddle::platform::get_cur_paddle_data_layout(), .get_cur_paddle_data_layout(),
src_item, &out, platform::CPUPlace()); src_item, &out, platform::CPUPlace());
TensorCopySync(out, platform::CPUPlace(), dst_item); TensorCopySync(out, platform::CPUPlace(), dst_item);
} else { } else {
......
...@@ -446,8 +446,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -446,8 +446,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// of conv int8 mkl-dnn. Once conv fp32 and conv int8 // of conv int8 mkl-dnn. Once conv fp32 and conv int8
// are merged/unified, this will disappear // are merged/unified, this will disappear
std::string key_tid = ""; std::string key_tid = "";
if (platform::get_cur_mkldnn_session_id() == if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() ==
platform::kMKLDNNSessionID_Default) { platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) {
key_tid = "-t:" + platform::ThreadIDasStr(); key_tid = "-t:" + platform::ThreadIDasStr();
} }
......
...@@ -375,36 +375,37 @@ MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place) ...@@ -375,36 +375,37 @@ MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
p_mutex_.reset(new std::mutex()); p_mutex_.reset(new std::mutex());
} }
namespace { MKLDNNDeviceContextThreadLocals::Body::Body() {
// Current mkldnn session id. cur_mkldnn_session_id = kMKLDNNSessionID_Default;
thread_local size_t cur_mkldnn_session_id = kMKLDNNSessionID_Default; cur_input_shape_str = "";
// Current data input shape string. cur_input_shape_cache_capacity = 1;
// - For fixed-shape, it's a null string in default. cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
// - For dynamic-shape, it's user specific. }
thread_local std::string cur_input_shape_str = "";
// the cache capacity of different input shapes for MKLDNN. void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
// Default 1 means fixed input shape, not dynamic shape. size_t sid) {
thread_local int cur_input_shape_cache_capacity = 1; cur_mkldnn_session_id = sid;
// Recently registered data_format. This is needed to }
// know for converting MKL-DNN Tensor to non MKL-DNN size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
thread_local paddle::framework::DataLayout cur_paddle_data_layout = return cur_mkldnn_session_id;
paddle::framework::DataLayout::kNCHW; }
} // namespace
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
void set_cur_mkldnn_session_id(size_t sid) { cur_mkldnn_session_id = sid; } std::string input_shape_str) {
size_t get_cur_mkldnn_session_id(void) { return cur_mkldnn_session_id; }
void set_cur_input_shape_str(std::string input_shape_str) {
cur_input_shape_str = input_shape_str; cur_input_shape_str = input_shape_str;
} }
void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity) { void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
int input_shape_cache_capacity) {
cur_input_shape_cache_capacity = input_shape_cache_capacity; cur_input_shape_cache_capacity = input_shape_cache_capacity;
} }
void set_cur_paddle_data_layout(framework::DataLayout dl) { void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
framework::DataLayout dl) {
cur_paddle_data_layout = dl; cur_paddle_data_layout = dl;
} }
framework::DataLayout get_cur_paddle_data_layout(void) { framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
return cur_paddle_data_layout; return cur_paddle_data_layout;
} }
...@@ -414,32 +415,32 @@ void MKLDNNDeviceContext::ResetBlobMap() const { ...@@ -414,32 +415,32 @@ void MKLDNNDeviceContext::ResetBlobMap() const {
} }
size_t MKLDNNDeviceContext::GetShapeBlobSize() const { size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
std::lock_guard<std::mutex> lock(*p_mutex_); std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
BlobMap* pMap = p_blobmap_.get(); BlobMap* pMap = p_blobmap_.get();
auto map_it = pMap->find(cur_mkldnn_session_id); auto map_it = pMap->find(tls().cur_mkldnn_session_id);
if (map_it == pMap->end()) { if (map_it == pMap->end()) {
LOG(FATAL) << "MKLDNNDeviceContext don't find cur_mkldnn_session_id : " LOG(FATAL) << "MKLDNNDeviceContext don't find cur_mkldnn_session_id : "
<< cur_mkldnn_session_id; << tls().cur_mkldnn_session_id;
} }
return map_it->second->size(); return map_it->second->size();
} }
void MKLDNNDeviceContext::SetBlob(const std::string& name, void MKLDNNDeviceContext::SetBlob(const std::string& name,
std::shared_ptr<void> data) const { BlobPtr_t<void> data) const {
BlobMap* pMap = p_blobmap_.get(); BlobMap* pMap = p_blobmap_.get();
std::shared_ptr<ShapeBlob> sBlob = nullptr; BlobPtr_t<ShapeBlob> sBlob = nullptr;
std::shared_ptr<KeyBlob> pBlob = nullptr; BlobPtr_t<KeyBlob> pBlob = nullptr;
int sid = platform::get_cur_mkldnn_session_id(); int sid = tls().get_cur_mkldnn_session_id();
std::lock_guard<std::mutex> lock(*p_mutex_); std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
// Find ShapeBlob for current mkldnn session id. // Find ShapeBlob for current mkldnn session id.
auto map_it = pMap->find(sid); auto map_it = pMap->find(sid);
if (map_it == pMap->end()) { if (map_it == pMap->end()) {
// 1st time to set blob in current thread // 1st time to set blob in current thread
sBlob = std::shared_ptr<ShapeBlob>(new ShapeBlob()); sBlob = std::make_shared<ShapeBlob>();
(*pMap)[sid] = sBlob; (*pMap)[sid] = sBlob;
VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n"; VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
} else { } else {
...@@ -447,21 +448,22 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name, ...@@ -447,21 +448,22 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
} }
// Find KeyBlob for current input shape // Find KeyBlob for current input shape
auto key_it = sBlob->find(cur_input_shape_str); auto key_it = sBlob->find(tls().cur_input_shape_str);
if (key_it == sBlob->end()) { if (key_it == sBlob->end()) {
// In cache clearing mode, cur_input_shape_cache_capacity defines // In cache clearing mode, cur_input_shape_cache_capacity defines
// max pblob capacity // max pblob capacity
if ((static_cast<size_t>(sid) == kMKLDNNSessionID_CacheClearing) && if ((static_cast<size_t>(sid) ==
MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
sBlob->size() && sBlob->size() &&
(sBlob->size() >= (sBlob->size() >=
static_cast<size_t>(cur_input_shape_cache_capacity))) { static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
VLOG(2) << "sid=" << sid VLOG(2) << "sid=" << sid
<< ", remove all blobs of shape: " << sBlob->begin()->first; << ", remove all blobs of shape: " << sBlob->begin()->first;
sBlob->erase(sBlob->begin()->first); sBlob->erase(sBlob->begin()->first);
} }
pBlob = std::shared_ptr<KeyBlob>(new KeyBlob()); pBlob = std::make_shared<KeyBlob>();
(*sBlob)[cur_input_shape_str] = pBlob; (*sBlob)[tls().cur_input_shape_str] = pBlob;
} else { } else {
pBlob = key_it->second; pBlob = key_it->second;
} }
...@@ -478,15 +480,15 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name, ...@@ -478,15 +480,15 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
return; return;
} }
std::shared_ptr<void> MKLDNNDeviceContext::GetBlob( MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
const std::string& name) const { const std::string& name) const {
BlobMap* pMap = p_blobmap_.get(); BlobMap* pMap = p_blobmap_.get();
std::shared_ptr<ShapeBlob> sBlob = nullptr; BlobPtr_t<ShapeBlob> sBlob = nullptr;
std::shared_ptr<KeyBlob> pBlob = nullptr; BlobPtr_t<KeyBlob> pBlob = nullptr;
int sid = platform::get_cur_mkldnn_session_id(); int sid = tls().get_cur_mkldnn_session_id();
std::lock_guard<std::mutex> lock(*p_mutex_); std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
// Find ShapeBlob for current mkldnn session id firstly // Find ShapeBlob for current mkldnn session id firstly
auto map_it = pMap->find(sid); auto map_it = pMap->find(sid);
...@@ -497,9 +499,9 @@ std::shared_ptr<void> MKLDNNDeviceContext::GetBlob( ...@@ -497,9 +499,9 @@ std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
sBlob = map_it->second; sBlob = map_it->second;
// Find KeyBlob for current input shape secondly // Find KeyBlob for current input shape secondly
auto sBlob_it = sBlob->find(cur_input_shape_str); auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
if (sBlob_it == sBlob->end()) { if (sBlob_it == sBlob->end()) {
VLOG(2) << "GetBlob: sid=" << cur_input_shape_str VLOG(2) << "GetBlob: sid=" << tls().cur_input_shape_str
<< ", miss input_shape_str\n"; << ", miss input_shape_str\n";
return nullptr; return nullptr;
} }
......
...@@ -421,30 +421,66 @@ struct DefaultDeviceContextType<platform::CUDAPinnedPlace> { ...@@ -421,30 +421,66 @@ struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
#endif #endif
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// Following three maps are used to cache MKLDNN primitives.
// There relations are: class MKLDNNDeviceContextThreadLocals {
// - BlobMap = Map<cur_thread_id, ShapeBlob> // default mkldnn session id
// - ShapeBlob = Map<cur_input_shape_str, KeyBlob>
// - KeyBlob = Map<blob_name, blob> typedef MKLDNNDeviceContextThreadLocals self;
// Where: struct Body {
using KeyBlob = std::unordered_map<std::string, std::shared_ptr<void>>; size_t cur_mkldnn_session_id;
using ShapeBlob = std::unordered_map<std::string, std::shared_ptr<KeyBlob>>; // Current data input shape string.
using BlobMap = std::unordered_map<int, std::shared_ptr<ShapeBlob>>; // - For fixed-shape, it's a null string in default.
// - For dynamic-shape, it's user specific.
// default mkldnn session id std::string cur_input_shape_str;
constexpr size_t kMKLDNNSessionID_Default = 0; // the cache capacity of different input shapes for MKLDNN.
// mkldnn session id for cache clearing mode // Default 1 means fixed input shape, not dynamic shape.
constexpr size_t kMKLDNNSessionID_CacheClearing = -1; int cur_input_shape_cache_capacity;
// Recently registered data_format. This is needed to
void set_cur_mkldnn_session_id(size_t); // know for converting MKL-DNN Tensor to non MKL-DNN
size_t get_cur_mkldnn_session_id(void); paddle::framework::DataLayout cur_paddle_data_layout;
void set_cur_input_shape_str(std::string input_shape_str);
void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity); Body();
void set_cur_paddle_data_layout(framework::DataLayout); void set_cur_mkldnn_session_id(size_t sid);
framework::DataLayout get_cur_paddle_data_layout(void); size_t get_cur_mkldnn_session_id(void);
void set_cur_input_shape_str(std::string input_shape_str);
void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity);
void set_cur_paddle_data_layout(framework::DataLayout dl);
framework::DataLayout get_cur_paddle_data_layout(void);
};
MKLDNNDeviceContextThreadLocals() = default;
MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) =
delete;
public:
// default mkldnn session id
static constexpr size_t kMKLDNNSessionID_Default = 0;
// mkldnn session id for cache clearing mode
static constexpr size_t kMKLDNNSessionID_CacheClearing = -1;
static Body& fetch() {
thread_local Body b;
return b;
}
};
class MKLDNNDeviceContext : public CPUDeviceContext { class MKLDNNDeviceContext : public CPUDeviceContext {
public: public:
template <class T>
using BlobPtr_t = std::shared_ptr<T>;
template <class P1, class P2>
using umap_value_smart_t = std::unordered_map<P1, BlobPtr_t<P2>>;
template <class T>
using umap_key_string_t = umap_value_smart_t<std::string, T>;
// Following three maps are used to cache MKLDNN primitives.
// There relations are:
// - BlobMap = Map<cur_thread_id, ShapeBlob>
// - ShapeBlob = Map<cur_input_shape_str, KeyBlob>
// - KeyBlob = Map<blob_name, blob>
using KeyBlob = umap_key_string_t<void>;
using ShapeBlob = umap_key_string_t<KeyBlob>;
using BlobMap = umap_value_smart_t<int, ShapeBlob>;
explicit MKLDNNDeviceContext(CPUPlace place); explicit MKLDNNDeviceContext(CPUPlace place);
/* \brief Get the active engine */ /* \brief Get the active engine */
...@@ -462,6 +498,10 @@ class MKLDNNDeviceContext : public CPUDeviceContext { ...@@ -462,6 +498,10 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
// Find a saved blob. Return nullptr if not found // Find a saved blob. Return nullptr if not found
std::shared_ptr<void> GetBlob(const std::string& name) const; std::shared_ptr<void> GetBlob(const std::string& name) const;
static auto tls() -> decltype(MKLDNNDeviceContextThreadLocals::fetch()) {
return MKLDNNDeviceContextThreadLocals::fetch();
}
private: private:
mkldnn::engine engine_; mkldnn::engine engine_;
std::shared_ptr<BlobMap> p_blobmap_; std::shared_ptr<BlobMap> p_blobmap_;
......
...@@ -42,8 +42,8 @@ class MKLDNNHandlerT { ...@@ -42,8 +42,8 @@ class MKLDNNHandlerT {
key_common_(base_key), key_common_(base_key),
fwd_pd_(nullptr), fwd_pd_(nullptr),
bwd_pd_(nullptr) { bwd_pd_(nullptr) {
if (platform::get_cur_mkldnn_session_id() != if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() !=
platform::kMKLDNNSessionID_Default) { platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) {
key_ = key_common_; key_ = key_common_;
} else { } else {
key_ = key_common_ + "-t:" + ThreadIDasStr(); key_ = key_common_ + "-t:" + ThreadIDasStr();
...@@ -177,8 +177,8 @@ class MKLDNNHandler { ...@@ -177,8 +177,8 @@ class MKLDNNHandler {
MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine, MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
const std::string& base_key) const std::string& base_key)
: dev_ctx_(dev_ctx), engine_(engine), key_common_(base_key) { : dev_ctx_(dev_ctx), engine_(engine), key_common_(base_key) {
if (platform::get_cur_mkldnn_session_id() != if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() !=
platform::kMKLDNNSessionID_Default) { platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) {
key_ = key_common_; key_ = key_common_;
} else { } else {
key_ = key_common_ + "-t:" + ThreadIDasStr(); key_ = key_common_ + "-t:" + ThreadIDasStr();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册