未验证 提交 db2b6b65 编写于 作者: P pawelpiotrowicz 提交者: GitHub

Hide globals & redesign restore PR (#24279)

test=develop
上级 4a105f80
......@@ -124,8 +124,9 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
"TransDataLayoutFromMKLDNN only supports transform from MKLDNN to "
"non-MKLDNN");
innerTransDataLayoutFromMKLDNN(in_layout,
paddle::platform::get_cur_paddle_data_layout(),
innerTransDataLayoutFromMKLDNN(
in_layout,
paddle::platform::MKLDNNDeviceContext::tls().get_cur_paddle_data_layout(),
in, out, place);
}
......
......@@ -59,7 +59,8 @@ void TransformData(const OpKernelType &expected_kernel_type,
// For NHWC data we need reshape of tensors as MKL-DNN
// is expecting NHWC dims description order
platform::MatchShapeToLayout(&out, lin, lout);
paddle::platform::set_cur_paddle_data_layout(lin);
paddle::platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
lin);
out.set_layout(DataLayout::kMKLDNN);
out.set_format(out_format);
} else {
......
......@@ -89,7 +89,8 @@ Executor::~Executor() {
platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(place_);
dev_ctx->ResetBlobMap();
platform::set_cur_paddle_data_layout(paddle::framework::DataLayout::kNCHW);
platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
paddle::framework::DataLayout::kNCHW);
}
#endif
}
......
......@@ -1155,8 +1155,8 @@ Scope* OperatorWithKernel::PrepareData(
if ((tensor_in->layout() == DataLayout::kMKLDNN) &&
(var->IsType<LoDTensor>() == true) &&
(expected_kernel_key.data_layout_ != DataLayout::kMKLDNN) &&
(paddle::platform::get_cur_paddle_data_layout() ==
DataLayout::kNHWC)) {
(paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout() == DataLayout::kNHWC)) {
// Mixed execution : MKL-DNN and GPU is not supported!
if (!new_scope) {
new_scope = &scope.NewScope();
......
......@@ -244,13 +244,14 @@ bool AnalysisPredictor::PrepareExecutor() {
void AnalysisPredictor::MkldnnPreSet(const std::vector<PaddleTensor> &inputs) {
#ifdef PADDLE_WITH_MKLDNN
VLOG(2) << "AnalysisPredictor::Run get_cur_mkldnn_session_id="
<< platform::get_cur_mkldnn_session_id();
<< platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id();
// In cache clearing mode.
if (config_.mkldnn_cache_capacity_ > 0) {
VLOG(2) << "In mkldnn cache clear mode.";
platform::set_cur_mkldnn_session_id(
platform::kMKLDNNSessionID_CacheClearing);
platform::set_cur_input_shape_cache_capacity(
platform::MKLDNNDeviceContext::tls().set_cur_mkldnn_session_id(
platform::MKLDNNDeviceContextThreadLocals::
kMKLDNNSessionID_CacheClearing);
platform::MKLDNNDeviceContext::tls().set_cur_input_shape_cache_capacity(
config_.mkldnn_cache_capacity_);
// Set current_input_shape for caching dynamic shape.
std::stringstream ss;
......@@ -260,7 +261,7 @@ void AnalysisPredictor::MkldnnPreSet(const std::vector<PaddleTensor> &inputs) {
}
}
VLOG(2) << "Set input shape=" << ss.str();
platform::set_cur_input_shape_str(ss.str());
platform::MKLDNNDeviceContext::tls().set_cur_input_shape_str(ss.str());
}
#endif
}
......@@ -277,10 +278,10 @@ void AnalysisPredictor::MkldnnPostReset() {
CHECK_LE(shape_blob_size,
static_cast<size_t>(config_.mkldnn_cache_capacity_));
}
paddle::platform::set_cur_mkldnn_session_id(
platform::kMKLDNNSessionID_Default);
platform::set_cur_input_shape_cache_capacity(0);
platform::set_cur_input_shape_str("");
paddle::platform::MKLDNNDeviceContext::tls().set_cur_mkldnn_session_id(
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default);
platform::MKLDNNDeviceContext::tls().set_cur_input_shape_cache_capacity(0);
platform::MKLDNNDeviceContext::tls().set_cur_input_shape_str("");
}
#endif
}
......
......@@ -34,10 +34,10 @@ static void DataCopy(const framework::LoDTensor &src_item,
// Convert to desired Paddle layout, apart from grads of filter
// as params are not a subject to paddle's data_format
framework::innerTransDataLayoutFromMKLDNN(
src_item.layout(),
fetch_var_name == framework::GradVarName("Filter")
src_item.layout(), fetch_var_name == framework::GradVarName("Filter")
? framework::DataLayout::kNCHW
: paddle::platform::get_cur_paddle_data_layout(),
: paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout(),
src_item, &out, platform::CPUPlace());
TensorCopySync(out, platform::CPUPlace(), dst_item);
} else {
......
......@@ -446,8 +446,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// of conv int8 mkl-dnn. Once conv fp32 and conv int8
// are merged/unified, this will disappear
std::string key_tid = "";
if (platform::get_cur_mkldnn_session_id() ==
platform::kMKLDNNSessionID_Default) {
if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() ==
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) {
key_tid = "-t:" + platform::ThreadIDasStr();
}
......
......@@ -375,36 +375,37 @@ MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
p_mutex_.reset(new std::mutex());
}
namespace {
// Current mkldnn session id.
thread_local size_t cur_mkldnn_session_id = kMKLDNNSessionID_Default;
// Current data input shape string.
// - For fixed-shape, it's a null string in default.
// - For dynamic-shape, it's user specific.
thread_local std::string cur_input_shape_str = "";
// the cache capacity of different input shapes for MKLDNN.
// Default 1 means fixed input shape, not dynamic shape.
thread_local int cur_input_shape_cache_capacity = 1;
// Recently registered data_format. This is needed to
// know for converting MKL-DNN Tensor to non MKL-DNN
thread_local paddle::framework::DataLayout cur_paddle_data_layout =
paddle::framework::DataLayout::kNCHW;
} // namespace
void set_cur_mkldnn_session_id(size_t sid) { cur_mkldnn_session_id = sid; }
size_t get_cur_mkldnn_session_id(void) { return cur_mkldnn_session_id; }
void set_cur_input_shape_str(std::string input_shape_str) {
MKLDNNDeviceContextThreadLocals::Body::Body() {
cur_mkldnn_session_id = kMKLDNNSessionID_Default;
cur_input_shape_str = "";
cur_input_shape_cache_capacity = 1;
cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
}
void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
size_t sid) {
cur_mkldnn_session_id = sid;
}
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
return cur_mkldnn_session_id;
}
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
std::string input_shape_str) {
cur_input_shape_str = input_shape_str;
}
void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity) {
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
int input_shape_cache_capacity) {
cur_input_shape_cache_capacity = input_shape_cache_capacity;
}
void set_cur_paddle_data_layout(framework::DataLayout dl) {
void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
framework::DataLayout dl) {
cur_paddle_data_layout = dl;
}
framework::DataLayout get_cur_paddle_data_layout(void) {
framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
return cur_paddle_data_layout;
}
......@@ -414,32 +415,32 @@ void MKLDNNDeviceContext::ResetBlobMap() const {
}
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
std::lock_guard<std::mutex> lock(*p_mutex_);
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
BlobMap* pMap = p_blobmap_.get();
auto map_it = pMap->find(cur_mkldnn_session_id);
auto map_it = pMap->find(tls().cur_mkldnn_session_id);
if (map_it == pMap->end()) {
LOG(FATAL) << "MKLDNNDeviceContext don't find cur_mkldnn_session_id : "
<< cur_mkldnn_session_id;
<< tls().cur_mkldnn_session_id;
}
return map_it->second->size();
}
void MKLDNNDeviceContext::SetBlob(const std::string& name,
std::shared_ptr<void> data) const {
BlobPtr_t<void> data) const {
BlobMap* pMap = p_blobmap_.get();
std::shared_ptr<ShapeBlob> sBlob = nullptr;
std::shared_ptr<KeyBlob> pBlob = nullptr;
BlobPtr_t<ShapeBlob> sBlob = nullptr;
BlobPtr_t<KeyBlob> pBlob = nullptr;
int sid = platform::get_cur_mkldnn_session_id();
int sid = tls().get_cur_mkldnn_session_id();
std::lock_guard<std::mutex> lock(*p_mutex_);
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
// Find ShapeBlob for current mkldnn session id.
auto map_it = pMap->find(sid);
if (map_it == pMap->end()) {
// 1st time to set blob in current thread
sBlob = std::shared_ptr<ShapeBlob>(new ShapeBlob());
sBlob = std::make_shared<ShapeBlob>();
(*pMap)[sid] = sBlob;
VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
} else {
......@@ -447,21 +448,22 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
}
// Find KeyBlob for current input shape
auto key_it = sBlob->find(cur_input_shape_str);
auto key_it = sBlob->find(tls().cur_input_shape_str);
if (key_it == sBlob->end()) {
// In cache clearing mode, cur_input_shape_cache_capacity defines
// max pblob capacity
if ((static_cast<size_t>(sid) == kMKLDNNSessionID_CacheClearing) &&
if ((static_cast<size_t>(sid) ==
MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
sBlob->size() &&
(sBlob->size() >=
static_cast<size_t>(cur_input_shape_cache_capacity))) {
static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
VLOG(2) << "sid=" << sid
<< ", remove all blobs of shape: " << sBlob->begin()->first;
sBlob->erase(sBlob->begin()->first);
}
pBlob = std::shared_ptr<KeyBlob>(new KeyBlob());
(*sBlob)[cur_input_shape_str] = pBlob;
pBlob = std::make_shared<KeyBlob>();
(*sBlob)[tls().cur_input_shape_str] = pBlob;
} else {
pBlob = key_it->second;
}
......@@ -478,15 +480,15 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
return;
}
std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
const std::string& name) const {
BlobMap* pMap = p_blobmap_.get();
std::shared_ptr<ShapeBlob> sBlob = nullptr;
std::shared_ptr<KeyBlob> pBlob = nullptr;
BlobPtr_t<ShapeBlob> sBlob = nullptr;
BlobPtr_t<KeyBlob> pBlob = nullptr;
int sid = platform::get_cur_mkldnn_session_id();
int sid = tls().get_cur_mkldnn_session_id();
std::lock_guard<std::mutex> lock(*p_mutex_);
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
// Find ShapeBlob for current mkldnn session id firstly
auto map_it = pMap->find(sid);
......@@ -497,9 +499,9 @@ std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
sBlob = map_it->second;
// Find KeyBlob for current input shape secondly
auto sBlob_it = sBlob->find(cur_input_shape_str);
auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
if (sBlob_it == sBlob->end()) {
VLOG(2) << "GetBlob: sid=" << cur_input_shape_str
VLOG(2) << "GetBlob: sid=" << tls().cur_input_shape_str
<< ", miss input_shape_str\n";
return nullptr;
}
......
......@@ -421,30 +421,66 @@ struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
#endif
#ifdef PADDLE_WITH_MKLDNN
// Following three maps are used to cache MKLDNN primitives.
// There relations are:
// - BlobMap = Map<cur_thread_id, ShapeBlob>
// - ShapeBlob = Map<cur_input_shape_str, KeyBlob>
// - KeyBlob = Map<blob_name, blob>
// Where:
using KeyBlob = std::unordered_map<std::string, std::shared_ptr<void>>;
using ShapeBlob = std::unordered_map<std::string, std::shared_ptr<KeyBlob>>;
using BlobMap = std::unordered_map<int, std::shared_ptr<ShapeBlob>>;
// default mkldnn session id
constexpr size_t kMKLDNNSessionID_Default = 0;
// mkldnn session id for cache clearing mode
constexpr size_t kMKLDNNSessionID_CacheClearing = -1;
void set_cur_mkldnn_session_id(size_t);
size_t get_cur_mkldnn_session_id(void);
void set_cur_input_shape_str(std::string input_shape_str);
void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity);
void set_cur_paddle_data_layout(framework::DataLayout);
framework::DataLayout get_cur_paddle_data_layout(void);
class MKLDNNDeviceContextThreadLocals {
// default mkldnn session id
typedef MKLDNNDeviceContextThreadLocals self;
struct Body {
size_t cur_mkldnn_session_id;
// Current data input shape string.
// - For fixed-shape, it's a null string in default.
// - For dynamic-shape, it's user specific.
std::string cur_input_shape_str;
// the cache capacity of different input shapes for MKLDNN.
// Default 1 means fixed input shape, not dynamic shape.
int cur_input_shape_cache_capacity;
// Recently registered data_format. This is needed to
// know for converting MKL-DNN Tensor to non MKL-DNN
paddle::framework::DataLayout cur_paddle_data_layout;
Body();
void set_cur_mkldnn_session_id(size_t sid);
size_t get_cur_mkldnn_session_id(void);
void set_cur_input_shape_str(std::string input_shape_str);
void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity);
void set_cur_paddle_data_layout(framework::DataLayout dl);
framework::DataLayout get_cur_paddle_data_layout(void);
};
MKLDNNDeviceContextThreadLocals() = default;
MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) =
delete;
public:
// default mkldnn session id
static constexpr size_t kMKLDNNSessionID_Default = 0;
// mkldnn session id for cache clearing mode
static constexpr size_t kMKLDNNSessionID_CacheClearing = -1;
static Body& fetch() {
thread_local Body b;
return b;
}
};
class MKLDNNDeviceContext : public CPUDeviceContext {
public:
template <class T>
using BlobPtr_t = std::shared_ptr<T>;
template <class P1, class P2>
using umap_value_smart_t = std::unordered_map<P1, BlobPtr_t<P2>>;
template <class T>
using umap_key_string_t = umap_value_smart_t<std::string, T>;
// Following three maps are used to cache MKLDNN primitives.
// There relations are:
// - BlobMap = Map<cur_thread_id, ShapeBlob>
// - ShapeBlob = Map<cur_input_shape_str, KeyBlob>
// - KeyBlob = Map<blob_name, blob>
using KeyBlob = umap_key_string_t<void>;
using ShapeBlob = umap_key_string_t<KeyBlob>;
using BlobMap = umap_value_smart_t<int, ShapeBlob>;
explicit MKLDNNDeviceContext(CPUPlace place);
/* \brief Get the active engine */
......@@ -462,6 +498,10 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
// Find a saved blob. Return nullptr if not found
std::shared_ptr<void> GetBlob(const std::string& name) const;
static auto tls() -> decltype(MKLDNNDeviceContextThreadLocals::fetch()) {
return MKLDNNDeviceContextThreadLocals::fetch();
}
private:
mkldnn::engine engine_;
std::shared_ptr<BlobMap> p_blobmap_;
......
......@@ -42,8 +42,8 @@ class MKLDNNHandlerT {
key_common_(base_key),
fwd_pd_(nullptr),
bwd_pd_(nullptr) {
if (platform::get_cur_mkldnn_session_id() !=
platform::kMKLDNNSessionID_Default) {
if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() !=
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) {
key_ = key_common_;
} else {
key_ = key_common_ + "-t:" + ThreadIDasStr();
......@@ -177,8 +177,8 @@ class MKLDNNHandler {
MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
const std::string& base_key)
: dev_ctx_(dev_ctx), engine_(engine), key_common_(base_key) {
if (platform::get_cur_mkldnn_session_id() !=
platform::kMKLDNNSessionID_Default) {
if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() !=
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) {
key_ = key_common_;
} else {
key_ = key_common_ + "-t:" + ThreadIDasStr();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册